Documentation
// Copyright 2024 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    hash::{Hash, Hasher},
    sync::Arc,
};

use async_trait::async_trait;
use caches::{lru::CacheError, Cache, DefaultHashBuilder, LRUCache};
use hashers::builtin::DefaultHasher;
use ordered_float::OrderedFloat;
use tokio::sync::RwLock;

use crate::{
    evaluation::functions::aggregation::ValueAccumulator,
    interface::{
        AccumulatorIndex, IndexError, LazySortedSetStore, ResultIndex, ResultKey, ResultOwner,
        ResultSequence, ResultSequenceCounter,
    },
};

/// A [`ResultIndex`] wrapper that fronts another result index with in-memory LRU caches.
///
/// Accumulator values and sorted-set value counts are served from the caches when present and
/// loaded from the wrapped index on a miss. Writes are sent to the wrapped index and mirrored
/// into the matching cache. Operations without a cache (`get_next`, sequence tracking) are
/// forwarded directly to the wrapped index.
pub struct CachedResultIndex {
    /// The wrapped index: consulted on every cache miss and written on every update.
    inner: Arc<dyn ResultIndex>,

    /// Accumulator values, keyed by a 64-bit hash of the `(ResultOwner, ResultKey)` pair.
    value_cache: Arc<RwLock<LRUCache<u64, ValueAccumulator, DefaultHashBuilder>>>,
    /// Sorted-set value counts, keyed by `(set_id, value)`.
    set_count_cache: Arc<RwLock<LRUCache<(u64, OrderedFloat<f64>), isize, DefaultHashBuilder>>>,
}

impl CachedResultIndex {
    /// Wraps `inner` in two LRU caches, each bounded to `cache_size` entries.
    ///
    /// One cache holds accumulator values and the other holds sorted-set value counts.
    ///
    /// # Errors
    /// Propagates the [`CacheError`] returned by `LRUCache::new` if either cache cannot be
    /// built with the requested capacity.
    pub fn new(inner: Arc<dyn ResultIndex>, cache_size: usize) -> Result<Self, CacheError> {
        log::info!("using cached result index with cache size {cache_size}");

        // Field initialisers run top to bottom, so the value cache is still built (and can
        // fail) before the set-count cache, exactly as with separate `let` bindings.
        Ok(Self {
            inner,
            value_cache: Arc::new(RwLock::new(LRUCache::new(cache_size)?)),
            set_count_cache: Arc::new(RwLock::new(LRUCache::new(cache_size)?)),
        })
    }
}

#[async_trait]
impl AccumulatorIndex for CachedResultIndex {
    /// Clears the wrapped index and empties both caches.
    ///
    /// Both cache locks are acquired *before* the wrapped index is cleared. `set` also holds
    /// the value-cache lock across its inner write, so the two operations are serialized.
    /// Without this, a `set` could write the wrapped index, have that write discarded by the
    /// clear, and then put the discarded value back into the freshly purged cache.
    ///
    /// The caches are purged even if the inner clear fails. Dropping cached entries can never
    /// make the cache incorrect, whereas keeping them after a partially applied clear could.
    ///
    /// # Errors
    /// Returns any error reported by the wrapped index's `clear`.
    async fn clear(&self) -> Result<(), IndexError> {
        // Lock order (value cache, then set-count cache) is the only place both are held.
        let mut value_cache = self.value_cache.write().await;
        let mut set_count_cache = self.set_count_cache.write().await;

        let result = self.inner.clear().await;

        value_cache.purge();
        set_count_cache.purge();

        result
    }

    /// Returns the accumulator stored for `(key, owner)`, preferring the cached copy.
    ///
    /// On a miss the value is loaded from the wrapped index and cached if present. Absent
    /// values are not cached. The write lock is required because the LRU `get` takes
    /// `&mut self`. The lock is kept across the inner read so that this call cannot insert
    /// a value it read before a concurrent `set` completed.
    ///
    /// # Errors
    /// Returns any error reported by the wrapped index's `get`.
    async fn get(
        &self,
        key: &ResultKey,
        owner: &ResultOwner,
    ) -> Result<Option<ValueAccumulator>, IndexError> {
        let cache_key = get_hash_key(owner, key);
        let mut cache = self.value_cache.write().await;

        if let Some(hit) = cache.get(&cache_key) {
            return Ok(Some(hit.clone()));
        }

        let loaded = self.inner.get(key, owner).await?;
        if let Some(v) = &loaded {
            _ = cache.put(cache_key, v.clone());
        }
        Ok(loaded)
    }

    /// Writes `value` for `(key, owner)` to the wrapped index and mirrors it in the cache.
    /// A `None` value evicts the cached entry.
    ///
    /// The cache lock is held across the inner write so that `set` and `clear` are
    /// serialized. If the inner write fails, the cached entry is evicted rather than kept.
    /// The wrapped index may or may not have applied the change, and the next `get` will
    /// reload the authoritative value.
    ///
    /// # Errors
    /// Returns any error reported by the wrapped index's `set`.
    async fn set(
        &self,
        key: ResultKey,
        owner: ResultOwner,
        value: Option<ValueAccumulator>,
    ) -> Result<(), IndexError> {
        let cache_key = get_hash_key(&owner, &key);
        let mut cache = self.value_cache.write().await;

        if let Err(err) = self.inner.set(key, owner, value.clone()).await {
            _ = cache.remove(&cache_key);
            return Err(err);
        }

        match value {
            Some(v) => _ = cache.put(cache_key, v),
            None => _ = cache.remove(&cache_key),
        }

        Ok(())
    }
}

#[async_trait]
impl LazySortedSetStore for CachedResultIndex {
    /// Returns the entry that follows `value` in set `set_id`, read directly from the wrapped
    /// index. Ordered traversal is not cached; only per-value counts are.
    async fn get_next(
        &self,
        set_id: u64,
        value: Option<OrderedFloat<f64>>,
    ) -> Result<Option<(OrderedFloat<f64>, isize)>, IndexError> {
        self.inner.get_next(set_id, value).await
    }

    /// Returns the count recorded for `value` in set `set_id`, preferring the cached copy.
    ///
    /// On a miss the count is read from the wrapped index while the cache lock is held. This
    /// prevents a concurrent `increment_value_count` from landing between the inner read and
    /// the cache insert.
    ///
    /// # Errors
    /// Returns any error reported by the wrapped index's `get_value_count`.
    async fn get_value_count(
        &self,
        set_id: u64,
        value: OrderedFloat<f64>,
    ) -> Result<isize, IndexError> {
        let cache_key = (set_id, value);
        let mut cache = self.set_count_cache.write().await;

        if let Some(count) = cache.get(&cache_key) {
            return Ok(*count);
        }

        let count = self.inner.get_value_count(set_id, value).await?;
        _ = cache.put(cache_key, count);
        Ok(count)
    }

    /// Adds `delta` to the count of `value` in set `set_id`.
    ///
    /// The cache lock is taken *before* the wrapped index is updated. Otherwise a concurrent
    /// `get_value_count` could cache the already-incremented inner count, and the `+= delta`
    /// below would count the change twice.
    ///
    /// Only a count that is already cached is adjusted. On a miss the true total is unknown
    /// here because the entry may simply have been evicted, so caching `delta` would be wrong.
    /// The next `get_value_count` reloads the count from the wrapped index instead.
    ///
    /// # Errors
    /// Returns any error reported by the wrapped index's `increment_value_count`. In that case
    /// the cached count is evicted, because the inner store may or may not have applied it.
    async fn increment_value_count(
        &self,
        set_id: u64,
        value: OrderedFloat<f64>,
        delta: isize,
    ) -> Result<(), IndexError> {
        let cache_key = (set_id, value);
        let mut cache = self.set_count_cache.write().await;

        if let Err(err) = self
            .inner
            .increment_value_count(set_id, value, delta)
            .await
        {
            _ = cache.remove(&cache_key);
            return Err(err);
        }

        if let Some(count) = cache.get_mut(&cache_key) {
            *count += delta;
        }

        Ok(())
    }
}

#[async_trait]
impl ResultSequenceCounter for CachedResultIndex {
    /// Records `sequence` / `source_change_id` in the wrapped index; nothing here is cached.
    async fn apply_sequence(
        &self,
        sequence: u64,
        source_change_id: &str,
    ) -> Result<(), IndexError> {
        let backing = &self.inner;
        backing.apply_sequence(sequence, source_change_id).await
    }

    /// Reads the current sequence straight from the wrapped index (uncached).
    async fn get_sequence(&self) -> Result<ResultSequence, IndexError> {
        let backing = &self.inner;
        backing.get_sequence().await
    }
}

// Empty impl: no `ResultIndex` items need implementing here. The behaviour comes from the
// `AccumulatorIndex`, `LazySortedSetStore` and `ResultSequenceCounter` impls above (presumably
// the supertraits of `ResultIndex` — confirm in `crate::interface`).
impl ResultIndex for CachedResultIndex {}

/// Collapses an `(owner, key)` pair into the single `u64` used to key `value_cache`.
///
/// Hashing the tuple feeds the owner and then the key into the hasher, one after the other.
///
/// NOTE(review): two distinct pairs that collide on this 64-bit hash would share one cache
/// slot, so a lookup could return the other pair's accumulator. The odds are tiny but
/// non-zero.
fn get_hash_key(owner: &ResultOwner, key: &ResultKey) -> u64 {
    let mut state = DefaultHasher::new();
    (owner, key).hash(&mut state);
    state.finish()
}