use super::{Driver, Error};
use async_trait::async_trait;
use std::{
	collections::HashMap,
	sync::Mutex,
	time::{Duration, SystemTime},
};

/// A [`Driver`] that keeps every entry in process memory.
///
/// All entries live in a [`HashMap`] guarded by a [`Mutex`], so each operation takes
/// exclusive access to the store. Nothing is persisted: entries disappear when the
/// driver is dropped.
#[derive(Debug, Default)]
// Allowed: the type name intentionally shares a word with its module's name.
#[allow(clippy::module_name_repetitions)]
pub struct MemoryDriver {
	/// Stored entries keyed by cache key. Each value holds the raw bytes and an
	/// optional wall-clock expiry timestamp (`None` means the entry never expires).
	cache: Mutex<HashMap<String, (Vec<u8>, Option<SystemTime>)>>,
}

impl MemoryDriver {
	/// Creates a driver with no stored entries.
	///
	/// Equivalent to [`MemoryDriver::default`].
	#[must_use]
	pub fn new() -> Self {
		Default::default()
	}
}

/// Returns `true` when an entry with expiry timestamp `expires_at` is stale at `now`.
///
/// Entries stored without an expiry (`None`) never go stale. An entry is still live at
/// the exact instant it expires and becomes stale strictly afterwards.
///
/// NOTE(review): `SystemTime` is wall-clock time and can jump when the system clock is
/// adjusted, which lengthens or shortens TTLs. `Instant` would be monotonic, but
/// switching would require changing the stored field type.
fn is_expired(expires_at: Option<SystemTime>, now: SystemTime) -> bool {
	matches!(expires_at, Some(expires_at) if expires_at < now)
}

#[async_trait]
impl Driver for MemoryDriver {
	/// Returns a copy of the bytes stored under `key`.
	///
	/// Returns `None` if the key is missing or its entry has expired. An expired entry is
	/// evicted here: the mutex already grants exclusive access, so removal is cheap and
	/// stops stale values from lingering in memory.
	///
	/// # Panics
	///
	/// Panics if the internal mutex has been poisoned.
	async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
		let mut cache = self.cache.lock().unwrap();

		let Some((data, expires_at)) = cache.get(key) else {
			return Ok(None);
		};

		if is_expired(*expires_at, SystemTime::now()) {
			cache.remove(key);
			return Ok(None);
		}

		Ok(Some(data.clone()))
	}

	/// Reports whether `key` holds a live (non-expired) entry.
	///
	/// Applies the same expiry rule as `get`, so `has` and `get` always agree. An expired
	/// entry is evicted as a side effect.
	///
	/// # Panics
	///
	/// Panics if the internal mutex has been poisoned.
	async fn has(&self, key: &str) -> Result<bool, Error> {
		let mut cache = self.cache.lock().unwrap();

		let Some((_, expires_at)) = cache.get(key) else {
			return Ok(false);
		};

		if is_expired(*expires_at, SystemTime::now()) {
			cache.remove(key);
			return Ok(false);
		}

		Ok(true)
	}

	/// Stores `value` under `key`, replacing any previous entry.
	///
	/// If `duration` is `Some`, the entry expires that long after now. If it is `None`,
	/// or the resulting timestamp cannot be represented, the entry never expires.
	///
	/// # Panics
	///
	/// Panics if the internal mutex has been poisoned.
	async fn put(
		&self,
		key: &str,
		value: Vec<u8>,
		duration: Option<Duration>,
	) -> Result<(), Error> {
		// `SystemTime + Duration` panics on overflow (e.g. `Duration::MAX`). `checked_add`
		// instead yields `None`, so an unrepresentably long TTL means "never expires".
		let expires_at = duration.and_then(|duration| SystemTime::now().checked_add(duration));

		self.cache
			.lock()
			.unwrap()
			.insert(key.to_owned(), (value, expires_at));

		Ok(())
	}

	/// Removes the entry stored under `key`, if any. Removing a missing key is not an error.
	///
	/// # Panics
	///
	/// Panics if the internal mutex has been poisoned.
	async fn forget(&self, key: &str) -> Result<(), Error> {
		self.cache.lock().unwrap().remove(key);

		Ok(())
	}

	/// Removes every stored entry.
	///
	/// # Panics
	///
	/// Panics if the internal mutex has been poisoned.
	async fn flush(&self) -> Result<(), Error> {
		self.cache.lock().unwrap().clear();

		Ok(())
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use crate::Cache;

	/// Exercises a full put / read / forget round-trip through the high-level `Cache`.
	#[tokio::test]
	async fn test_memory_driver() {
		const KEY: &str = "foo";
		let value = String::from("bar");
		let cache = Cache::new(MemoryDriver::new());

		// Before anything is stored, the key is absent.
		let initial = cache.get::<String>(KEY).await.unwrap();
		assert_eq!(initial, None);
		assert!(!cache.has(KEY).await.unwrap());

		// Store with a generous TTL so the entry cannot expire mid-test.
		let ttl = Duration::from_secs(10);
		cache.put(KEY, &value, ttl).await.unwrap();

		let stored = cache.get::<String>(KEY).await.unwrap();
		assert_eq!(stored, Some(value.clone()));
		assert!(cache.has(KEY).await.unwrap());

		// Forgetting the key makes it absent again.
		cache.forget(KEY).await.unwrap();

		let after_forget = cache.get::<String>(KEY).await.unwrap();
		assert_eq!(after_forget, None);
		assert!(!cache.has(KEY).await.unwrap());
	}
}