use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use dashmap::DashMap;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::{self, BoxStream};
use thiserror::Error;
use tokio::sync::broadcast;
use crate::Command;
use crate::subscription::{SubscriptionId, SubscriptionSource};
use super::cache::CacheEntry;
use super::config::QueryConfig;
#[derive(Error, Debug, Clone)]
pub enum QueryError {
#[error("Fetch failed: {0}")]
FetchError(String),
#[error("Network error: {0}")]
NetworkError(String),
}
/// Lifecycle state of a query as observed by a subscriber.
#[derive(Clone, Debug)]
pub enum QueryState<T> {
    /// No value to show: either nothing is cached yet or the key was just
    /// invalidated, and a fetch is about to run.
    Loading,
    /// A value is available.
    Success {
        /// The fetched or cached value.
        data: T,
        /// Set when `data` came from a cache entry judged stale against the
        /// client's configured stale time; a refetch follows.
        is_stale: bool,
    },
    /// The most recent fetch failed; carries the error's display text.
    Error(String),
}
/// A single item emitted by a query subscription.
#[derive(Clone, Debug)]
pub struct QueryResult<T> {
    /// The query's state at the moment this item was emitted.
    pub state: QueryState<T>,
}
impl<T> QueryResult<T> {
pub const fn data(&self) -> Option<&T> {
match &self.state {
QueryState::Success { data, .. } => Some(data),
_ => None,
}
}
pub const fn is_loading(&self) -> bool {
matches!(self.state, QueryState::Loading)
}
pub const fn is_success(&self) -> bool {
matches!(self.state, QueryState::Success { .. })
}
pub const fn is_error(&self) -> bool {
matches!(self.state, QueryState::Error(_))
}
pub const fn is_stale(&self) -> bool {
matches!(self.state, QueryState::Success { is_stale: true, .. })
}
}
/// Shared cache plus invalidation hub used by every [`Query`] built on it.
///
/// Clones share the same cache (through the `Arc`) and the same broadcast
/// channel, so an invalidation issued through one clone reaches queries
/// created from any other.
#[derive(Clone, Debug)]
pub struct QueryClient {
    /// Type-erased cache: each value is a boxed `CacheEntry<T>` inserted by
    /// `set_cache` and recovered by downcasting in `get_cache`.
    cache: Arc<DashMap<String, Box<dyn Any + Send + Sync>>>,
    /// Broadcasts keys of invalidated queries to watching query streams.
    invalidation_tx: broadcast::Sender<String>,
    /// Settings (such as `stale_time`) consulted by queries reading the cache.
    config: QueryConfig,
}
impl QueryClient {
#[must_use]
pub fn new() -> Self {
Self::with_config(QueryConfig::default())
}
#[must_use]
pub fn with_config(config: QueryConfig) -> Self {
let (invalidation_tx, _) = broadcast::channel(100);
Self {
cache: Arc::new(DashMap::new()),
invalidation_tx,
config,
}
}
pub fn invalidate<Msg>(&self, key: &impl ToString) -> Command<Msg>
where
Msg: Send + 'static,
{
let tx = self.invalidation_tx.clone();
let key_string = key.to_string();
Command {
stream: Some(
futures::stream::once(async move {
let _ = tx.send(key_string);
})
.filter_map(|()| async { None })
.boxed(),
),
}
}
fn subscribe_invalidation(&self) -> broadcast::Receiver<String> {
self.invalidation_tx.subscribe()
}
fn get_cache<T: Clone + Send + Sync + 'static>(&self, key: &str) -> Option<CacheEntry<T>> {
self.cache
.get(key)
.and_then(|entry| entry.downcast_ref::<CacheEntry<T>>().cloned())
}
fn set_cache<T: Clone + Send + Sync + 'static>(&self, key: String, entry: CacheEntry<T>) {
self.cache.insert(key, Box::new(entry));
}
const fn config(&self) -> &QueryConfig {
&self.config
}
}
impl Default for QueryClient {
fn default() -> Self {
Self::new()
}
}
/// A keyed data source whose value is fetched on demand, cached in a
/// [`QueryClient`], and refetched when its key is invalidated.
///
/// Queries on the same client with the same key and value type share a cache
/// entry.
pub struct Query<V> {
    /// Cache key; also the only field fed to this type's `Hash` impl.
    key: String,
    /// Called once per fetch; each call must produce a new future.
    fetcher: Arc<dyn Fn() -> BoxFuture<'static, Result<V, QueryError>> + Send + Sync>,
    /// Client providing the cache and the invalidation channel.
    client: Arc<QueryClient>,
}
impl<V> Query<V>
where
    V: Clone + Send + Sync + 'static,
{
    /// Builds a query identified by `key` that loads its value through
    /// `fetcher` and caches it in `client`.
    ///
    /// `fetcher` is invoked for every fetch: the initial load, a refetch of
    /// stale cached data, and each refetch after invalidation.
    pub fn new<F>(key: &impl ToString, fetcher: F, client: Arc<QueryClient>) -> Self
    where
        F: Fn() -> BoxFuture<'static, Result<V, QueryError>> + Send + Sync + 'static,
    {
        Self {
            fetcher: Arc::new(fetcher),
            key: key.to_string(),
            client,
        }
    }
}
impl<V> SubscriptionSource for Query<V>
where
V: Clone + Send + Sync + 'static,
{
type Output = QueryResult<V>;
fn stream(&self) -> BoxStream<'static, Self::Output> {
let key = self.key.clone();
let fetcher = self.fetcher.clone();
let client = self.client.clone();
stream::unfold(State::Initial, move |state| {
let key = key.clone();
let fetcher = fetcher.clone();
let client = client.clone();
async move {
match state {
State::Initial => {
if let Some(mut cached) = client.get_cache::<V>(&key) {
let is_stale = cached.check_staleness(client.config().stale_time);
let result = QueryResult {
state: QueryState::Success {
data: cached.data.clone(),
is_stale,
},
};
if is_stale {
Some((result, State::Refetching))
} else {
let rx = client.subscribe_invalidation();
Some((result, State::Watching { rx }))
}
} else {
let result = QueryResult {
state: QueryState::Loading,
};
Some((result, State::Fetching))
}
}
State::Fetching => {
match fetcher().await {
Ok(data) => {
let entry = CacheEntry::new(data.clone());
client.set_cache(key.clone(), entry);
let result = QueryResult {
state: QueryState::Success {
data,
is_stale: false,
},
};
let rx = client.subscribe_invalidation();
Some((result, State::Watching { rx }))
}
Err(e) => {
let result = QueryResult {
state: QueryState::Error(e.to_string()),
};
let rx = client.subscribe_invalidation();
Some((result, State::Watching { rx }))
}
}
}
State::Refetching => {
match fetcher().await {
Ok(data) => {
let entry = CacheEntry::new(data.clone());
client.set_cache(key.clone(), entry);
let result = QueryResult {
state: QueryState::Success {
data,
is_stale: false,
},
};
let rx = client.subscribe_invalidation();
Some((result, State::Watching { rx }))
}
Err(e) => {
let result = QueryResult {
state: QueryState::Error(e.to_string()),
};
let rx = client.subscribe_invalidation();
Some((result, State::Watching { rx }))
}
}
}
State::Watching { mut rx } => {
loop {
match rx.recv().await {
Ok(invalidated_key) if invalidated_key == key => {
let result = QueryResult {
state: QueryState::Loading,
};
return Some((result, State::Fetching));
}
Ok(_) => {
}
Err(_) => {
return None;
}
}
}
}
}
}
})
.boxed()
}
fn id(&self) -> SubscriptionId {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
SubscriptionId::of::<Self>(hasher.finish())
}
}
impl<V> Hash for Query<V> {
fn hash<H>(&self, hasher: &mut H)
where
H: std::hash::Hasher,
{
self.key.hash(hasher);
}
}
/// Internal step of the `unfold` state machine behind `Query::stream`.
enum State {
    /// Nothing emitted yet; the cache is checked first.
    Initial,
    /// Run the fetcher (cache miss or after an invalidation), then watch.
    Fetching,
    /// A stale cached value was just emitted; fetch a fresh one, then watch.
    Refetching,
    /// An outcome has been emitted; wait for this query's key to be invalidated.
    Watching {
        /// Receiver on the client's invalidation broadcast.
        rx: broadcast::Receiver<String>,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_query_result_data() {
        let result = QueryResult {
            state: QueryState::Success {
                data: 42,
                is_stale: false,
            },
        };
        assert_eq!(result.data(), Some(&42));
        let result: QueryResult<i32> = QueryResult {
            state: QueryState::Loading,
        };
        assert_eq!(result.data(), None);
        let result: QueryResult<i32> = QueryResult {
            state: QueryState::Error("error".to_string()),
        };
        assert_eq!(result.data(), None);
    }

    #[test]
    fn test_query_result_predicates() {
        let loading: QueryResult<i32> = QueryResult {
            state: QueryState::Loading,
        };
        assert!(loading.is_loading());
        assert!(!loading.is_success());
        assert!(!loading.is_error());
        assert!(!loading.is_stale());
        let success = QueryResult {
            state: QueryState::Success {
                data: 42,
                is_stale: false,
            },
        };
        assert!(!success.is_loading());
        assert!(success.is_success());
        assert!(!success.is_error());
        assert!(!success.is_stale());
        let stale = QueryResult {
            state: QueryState::Success {
                data: 42,
                is_stale: true,
            },
        };
        assert!(!stale.is_loading());
        assert!(stale.is_success());
        assert!(!stale.is_error());
        assert!(stale.is_stale());
        let error: QueryResult<i32> = QueryResult {
            state: QueryState::Error("error".to_string()),
        };
        assert!(!error.is_loading());
        assert!(!error.is_success());
        assert!(error.is_error());
        assert!(!error.is_stale());
    }

    #[test]
    fn test_query_client_new() {
        let client = QueryClient::new();
        assert_eq!(client.cache.len(), 0);
        assert_eq!(client.config.stale_time, Duration::from_secs(0));
    }

    #[test]
    fn test_query_client_with_config() {
        let config = QueryConfig::new(Duration::from_secs(30), Duration::from_secs(300));
        let client = QueryClient::with_config(config);
        assert_eq!(client.config.stale_time, Duration::from_secs(30));
        assert_eq!(client.config.cache_time, Duration::from_secs(300));
    }

    #[test]
    fn test_query_client_cache_operations() {
        let client = QueryClient::new();
        assert!(client.get_cache::<i32>("key1").is_none());
        let entry = CacheEntry::new(42);
        client.set_cache("key1".to_string(), entry);
        let cached = client.get_cache::<i32>("key1");
        assert!(cached.is_some());
        if let Some(entry) = cached {
            assert_eq!(entry.data, 42);
        }
    }

    #[test]
    fn test_query_error_display() {
        let err = QueryError::FetchError("test error".to_string());
        assert_eq!(err.to_string(), "Fetch failed: test error");
        let err = QueryError::NetworkError("network error".to_string());
        assert_eq!(err.to_string(), "Network error: network error");
    }

    #[tokio::test]
    async fn test_invalidate_command_execution() {
        use futures::StreamExt;
        let client = QueryClient::new();
        let cmd: Command<()> = client.invalidate(&"test-key");
        let stream = cmd
            .stream
            .expect("invalidate should produce a command with a stream");
        let actions: Vec<_> = stream.collect().await;
        assert!(
            actions.is_empty(),
            "invalidate should not produce any messages"
        );
    }

    #[tokio::test]
    async fn test_invalidate_broadcasts_notification() {
        use futures::StreamExt;
        let client = QueryClient::new();
        let mut rx = client.subscribe_invalidation();
        let cmd: Command<()> = client.invalidate(&"test-key");
        if let Some(stream) = cmd.stream {
            let _: Vec<_> = stream.collect().await;
        }
        let result = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv()).await;
        let key = result
            .expect("Should receive notification within timeout")
            .expect("Channel should not be closed");
        assert_eq!(key, "test-key");
    }

    #[tokio::test]
    async fn test_invalidate_nonexistent_key() {
        use futures::StreamExt;
        let client = QueryClient::new();
        let cmd: Command<()> = client.invalidate(&"nonexistent");
        if let Some(stream) = cmd.stream {
            let actions: Vec<_> = stream.collect().await;
            assert!(actions.is_empty());
        }
        let mut rx = client.subscribe_invalidation();
        let cmd: Command<()> = client.invalidate(&"nonexistent");
        if let Some(stream) = cmd.stream {
            let _: Vec<_> = stream.collect().await;
        }
        let result = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv()).await;
        let key = result
            .expect("Should receive notification within timeout")
            .expect("Channel should not be closed");
        assert_eq!(key, "nonexistent");
    }

    #[test]
    fn test_query_id_consistency() {
        let client = Arc::new(QueryClient::new());
        let query1 = Query::new(
            &"user-123",
            || Box::pin(async { Ok::<i32, QueryError>(42) }),
            client.clone(),
        );
        let query2 = Query::new(
            &"user-123",
            || Box::pin(async { Ok::<i32, QueryError>(42) }),
            client,
        );
        assert_eq!(query1.id(), query2.id());
    }

    #[test]
    fn test_query_id_different_keys() {
        let client = Arc::new(QueryClient::new());
        let query1 = Query::new(
            &"user-123",
            || Box::pin(async { Ok::<i32, QueryError>(42) }),
            client.clone(),
        );
        let query2 = Query::new(
            &"user-456",
            || Box::pin(async { Ok::<i32, QueryError>(42) }),
            client,
        );
        assert_ne!(query1.id(), query2.id());
    }

    #[test]
    fn test_query_id_same_key_different_type() {
        let client = Arc::new(QueryClient::new());
        let query1 = Query::new(
            &"data",
            || Box::pin(async { Ok::<i32, QueryError>(42) }),
            client.clone(),
        );
        let query2 = Query::new(
            &"data",
            || Box::pin(async { Ok::<String, QueryError>("test".to_string()) }),
            client,
        );
        assert_ne!(query1.id(), query2.id());
    }

    /// A cache miss emits `Loading`, then the fetched value, which is also cached.
    #[tokio::test]
    async fn test_query_stream_fetches_and_caches() {
        use futures::StreamExt;
        let client = Arc::new(QueryClient::new());
        let query = Query::new(
            &"answer",
            || Box::pin(async { Ok::<i32, QueryError>(42) }),
            client.clone(),
        );
        let mut results = query.stream();
        let first = results
            .next()
            .await
            .expect("stream should start with a loading state");
        assert!(first.is_loading());
        let second = results
            .next()
            .await
            .expect("stream should yield the fetched value");
        assert_eq!(second.data(), Some(&42));
        assert!(!second.is_stale());
        let cached = client.get_cache::<i32>("answer").map(|entry| entry.data);
        assert_eq!(cached, Some(42));
    }

    /// A failing fetch is surfaced as an `Error` state and nothing is cached.
    #[tokio::test]
    async fn test_query_stream_reports_fetch_error() {
        use futures::StreamExt;
        let client = Arc::new(QueryClient::new());
        let query = Query::new(
            &"broken",
            || {
                Box::pin(async {
                    Err::<i32, QueryError>(QueryError::FetchError("boom".to_string()))
                })
            },
            client.clone(),
        );
        let mut results = query.stream();
        assert!(results
            .next()
            .await
            .expect("stream should start with a loading state")
            .is_loading());
        let failed = results.next().await.expect("stream should yield the fetch outcome");
        assert!(failed.is_error());
        assert!(matches!(&failed.state, QueryState::Error(msg) if msg == "Fetch failed: boom"));
        assert!(client.get_cache::<i32>("broken").is_none());
    }

    /// Invalidating a watched key re-emits `Loading` and runs the fetcher again.
    #[tokio::test]
    async fn test_query_stream_refetches_after_invalidation() {
        use futures::StreamExt;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let client = Arc::new(QueryClient::new());
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let query = Query::new(
            &"counter",
            move || {
                let counter = Arc::clone(&counter);
                Box::pin(async move {
                    Ok::<usize, QueryError>(counter.fetch_add(1, Ordering::SeqCst) + 1)
                })
            },
            client.clone(),
        );
        let mut results = query.stream();
        assert!(results
            .next()
            .await
            .expect("stream should start with a loading state")
            .is_loading());
        let first = results.next().await.expect("stream should yield the first fetch");
        assert_eq!(first.data(), Some(&1));
        let cmd: Command<()> = client.invalidate(&"counter");
        if let Some(stream) = cmd.stream {
            let _: Vec<_> = stream.collect().await;
        }
        let reloading = tokio::time::timeout(Duration::from_millis(100), results.next())
            .await
            .expect("invalidation should be observed within timeout")
            .expect("stream should still be running");
        assert!(reloading.is_loading());
        let second = tokio::time::timeout(Duration::from_millis(100), results.next())
            .await
            .expect("refetch should finish within timeout")
            .expect("stream should yield the refetched value");
        assert_eq!(second.data(), Some(&2));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}