use crate::{
adapter::AdapterKind,
error::{DataError, DataResult},
query::Query,
repo::{MemoryRepo, Repo, Row, StoredRow},
};
use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
time::{Duration, Instant},
};
/// Boxed, pinned future returned by every [`AsyncRepo`] operation.
///
/// The future is `Send`, may borrow from the repository and the call arguments for `'a`,
/// and resolves to a [`DataResult`] carrying `T`.
pub type AsyncRepoFuture<'a, T> =
    Pin<Box<dyn Future<Output = DataResult<T>> + 'a + Send>>;
/// Cooperative cancellation flag that can be shared between tasks.
///
/// Clones share a single `Arc<AtomicBool>`, so cancelling or resetting through any clone
/// is visible through all of them. Every access uses `Ordering::SeqCst`. A default token
/// starts out not cancelled.
#[derive(Debug, Clone, Default)]
pub struct AsyncCancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl AsyncCancellationToken {
    /// Writes `flag` into the shared cancellation state.
    fn set_cancelled(&self, flag: bool) {
        self.cancelled.store(flag, Ordering::SeqCst);
    }

    /// Marks this token, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        self.set_cancelled(true);
    }

    /// Clears the cancellation flag so the token (and its clones) can be reused.
    pub fn reset(&self) {
        self.set_cancelled(false);
    }

    /// Returns `true` if [`cancel`](Self::cancel) was called and not undone by
    /// [`reset`](Self::reset).
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }
}
/// Per-call execution limits for async repository operations.
///
/// Holds an optional deadline and a cancellation token. The default context has no deadline
/// and a token that is not cancelled, so [`ensure_active`](Self::ensure_active) always
/// succeeds until one of them is set.
#[derive(Debug, Clone, Default)]
pub struct AsyncQueryContext {
    /// Instant after which the query counts as expired; `None` means no deadline.
    pub deadline: Option<Instant>,
    /// Token whose cancellation makes [`ensure_active`](Self::ensure_active) fail.
    pub cancellation: AsyncCancellationToken,
}

impl AsyncQueryContext {
    /// Sets the deadline to `timeout_ms` milliseconds from now.
    ///
    /// Values below 1 ms are clamped to 1 ms. If `now + timeout` cannot be represented as an
    /// [`Instant`] (e.g. `u64::MAX`), the timeout is effectively unbounded and the context is
    /// left without a deadline. Plain `Instant + Duration` would panic on overflow here.
    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
        let timeout = Duration::from_millis(timeout_ms.max(1));
        self.deadline = Instant::now().checked_add(timeout);
        self
    }

    /// Sets an explicit deadline instant, replacing any previous deadline.
    pub fn with_deadline(mut self, deadline: Instant) -> Self {
        self.deadline = Some(deadline);
        self
    }

    /// Replaces the context's cancellation token with `token`.
    ///
    /// Because token clones share state, cancelling `token` (or any clone of it) afterwards
    /// is observed by this context.
    pub fn with_cancellation_token(mut self, token: AsyncCancellationToken) -> Self {
        self.cancellation = token;
        self
    }

    /// Returns `true` if the context's cancellation token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancellation.is_cancelled()
    }

    /// Returns `true` if a deadline is set and the current instant is strictly past it.
    pub fn is_deadline_exceeded(&self) -> bool {
        matches!(self.deadline, Some(deadline) if Instant::now() > deadline)
    }

    /// Checks that the query may still proceed.
    ///
    /// # Errors
    ///
    /// Returns `DataError::Query` if the token is cancelled (checked first) or if the
    /// deadline has been exceeded.
    pub fn ensure_active(&self) -> DataResult<()> {
        if self.is_cancelled() {
            return Err(DataError::Query(
                "async query cancelled by cancellation token".to_string(),
            ));
        }
        if self.is_deadline_exceeded() {
            return Err(DataError::Query(
                "async query deadline exceeded before completion".to_string(),
            ));
        }
        Ok(())
    }
}
/// Asynchronous CRUD interface over named tables of [`Row`] data.
///
/// Every operation receives an [`AsyncQueryContext`] carrying an optional deadline and a
/// cancellation token. Every operation returns an [`AsyncRepoFuture`] that may borrow from
/// `self` and the call arguments for `'a`. Implementors must be `Send + Sync`.
pub trait AsyncRepo: Send + Sync {
    /// Reports which database adapter backs this repository.
    fn adapter_kind(&self) -> AdapterKind;

    /// Looks up row `id` in `table`, resolving to `None` when no row matches.
    fn find<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, Option<StoredRow>>;

    /// Lists the rows of `table` selected by `query`.
    fn list<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        query: &'a Query,
    ) -> AsyncRepoFuture<'a, Vec<StoredRow>>;

    /// Inserts `data` into `table`, resolving to the stored row.
    fn insert<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow>;

    /// Updates row `id` in `table` with `data`, resolving to the stored row.
    fn update<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow>;

    /// Deletes row `id` from `table`.
    fn delete<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, ()>;
}
/// [`AsyncRepo`] adapter over a synchronous [`MemoryRepo`].
///
/// The repository sits behind an `Arc<Mutex<_>>`, so clones of this value share the same
/// underlying data. The adapter kind is read once at construction and cached.
#[derive(Clone)]
pub struct AsyncMemoryRepo {
    inner: Arc<Mutex<MemoryRepo>>,
    adapter_kind: AdapterKind,
}

impl AsyncMemoryRepo {
    /// Wraps `repo`, recording its adapter kind before moving it behind the mutex.
    pub fn new(repo: MemoryRepo) -> Self {
        Self {
            adapter_kind: repo.adapter_kind(),
            inner: Arc::new(Mutex::new(repo)),
        }
    }
}
impl AsyncRepo for AsyncMemoryRepo {
fn adapter_kind(&self) -> AdapterKind {
self.adapter_kind
}
fn insert<'a>(
&'a self,
context: &'a AsyncQueryContext,
table: &'a str,
data: Row,
) -> AsyncRepoFuture<'a, StoredRow> {
Box::pin(async move {
context.ensure_active()?;
let result = self
.inner
.lock()
.map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
.insert(table, data);
context.ensure_active()?;
result
})
}
fn update<'a>(
&'a self,
context: &'a AsyncQueryContext,
table: &'a str,
id: u64,
data: Row,
) -> AsyncRepoFuture<'a, StoredRow> {
Box::pin(async move {
context.ensure_active()?;
let result = self
.inner
.lock()
.map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
.update(table, id, data);
context.ensure_active()?;
result
})
}
fn delete<'a>(
&'a self,
context: &'a AsyncQueryContext,
table: &'a str,
id: u64,
) -> AsyncRepoFuture<'a, ()> {
Box::pin(async move {
context.ensure_active()?;
let result = self
.inner
.lock()
.map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
.delete(table, id);
context.ensure_active()?;
result
})
}
fn find<'a>(
&'a self,
context: &'a AsyncQueryContext,
table: &'a str,
id: u64,
) -> AsyncRepoFuture<'a, Option<StoredRow>> {
Box::pin(async move {
context.ensure_active()?;
let result = self
.inner
.lock()
.map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
.find(table, id);
context.ensure_active()?;
result
})
}
fn list<'a>(
&'a self,
context: &'a AsyncQueryContext,
table: &'a str,
query: &'a Query,
) -> AsyncRepoFuture<'a, Vec<StoredRow>> {
Box::pin(async move {
context.ensure_active()?;
let result = self
.inner
.lock()
.map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
.list(table, query);
context.ensure_active()?;
result
})
}
}
#[cfg(test)]
mod tests {
    use super::{AsyncCancellationToken, AsyncMemoryRepo, AsyncQueryContext, AsyncRepo};
    use crate::{adapter_for, AdapterKind, DatabaseConfig, MemoryRepo, Query, Row};
    use serde_json::json;
    use std::{
        panic::{catch_unwind, AssertUnwindSafe},
        time::{Duration, Instant},
    };

    /// Builds an [`AsyncMemoryRepo`] over a fresh [`MemoryRepo`] whose driver comes from
    /// `adapter_for` with no URL configured.
    fn async_memory_repo(adapter: AdapterKind) -> AsyncMemoryRepo {
        let driver = adapter_for(&DatabaseConfig {
            adapter,
            url: None,
            url_env: None,
        })
        .expect("driver");
        AsyncMemoryRepo::new(MemoryRepo::new(driver))
    }

    #[tokio::test]
    async fn async_memory_repo_supports_insert_and_list() {
        let async_repo = async_memory_repo(AdapterKind::SingleStore);
        let mut row = Row::new();
        row.insert("account".to_string(), json!("Acme"));
        let context = AsyncQueryContext::default().with_timeout_ms(500);
        async_repo
            .insert(&context, "accounts", row)
            .await
            .expect("insert");
        let rows = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect("list");
        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn async_query_context_cancellation_stops_query() {
        let async_repo = async_memory_repo(AdapterKind::Postgres);
        let token = AsyncCancellationToken::default();
        token.cancel();
        let context = AsyncQueryContext::default().with_cancellation_token(token);
        let err = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect_err("cancelled context must reject the query");
        // Assert on the reason: "accounts" was never created, so a bare `is_err()` could
        // also pass if querying a missing table fails for an unrelated reason.
        assert!(err.to_string().contains("cancelled by cancellation token"));
    }

    #[tokio::test]
    async fn async_query_context_deadline_exceeded_stops_query() {
        let async_repo = async_memory_repo(AdapterKind::MySql);
        let context = AsyncQueryContext::default().with_timeout_ms(1);
        tokio::time::sleep(Duration::from_millis(5)).await;
        let err = async_repo
            .find(&context, "accounts", 1)
            .await
            .expect_err("expired deadline must reject the query");
        assert!(err.to_string().contains("deadline exceeded"));
    }

    #[test]
    fn async_query_context_reset_and_explicit_deadline_paths_are_exercised() {
        let token = AsyncCancellationToken::default();
        token.cancel();
        assert!(token.is_cancelled());
        token.reset();
        assert!(!token.is_cancelled());
        let past_deadline = Instant::now()
            .checked_sub(Duration::from_millis(1))
            .expect("past instant");
        let context = AsyncQueryContext::default()
            .with_deadline(past_deadline)
            .with_cancellation_token(token);
        // The token was reset, so the failure must come from the deadline check.
        let err = context
            .ensure_active()
            .expect_err("past deadline must be rejected");
        assert!(err.to_string().contains("deadline exceeded"));
    }

    #[tokio::test]
    async fn async_memory_repo_supports_update_and_delete() {
        let async_repo = async_memory_repo(AdapterKind::ClickHouse);
        let context = AsyncQueryContext::default().with_timeout_ms(500);
        let mut initial = Row::new();
        initial.insert("name".to_string(), json!("Draft"));
        let inserted = async_repo
            .insert(&context, "accounts", initial)
            .await
            .expect("insert");
        let mut updated = Row::new();
        updated.insert("name".to_string(), json!("Published"));
        let updated_row = async_repo
            .update(&context, "accounts", inserted.id, updated)
            .await
            .expect("update");
        assert_eq!(updated_row.data.get("name"), Some(&json!("Published")));
        async_repo
            .delete(&context, "accounts", inserted.id)
            .await
            .expect("delete");
        let found = async_repo
            .find(&context, "accounts", inserted.id)
            .await
            .expect("find after delete");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn async_memory_repo_reports_lock_poisoned_errors_and_adapter_kind() {
        let async_repo = async_memory_repo(AdapterKind::OpenSearch);
        assert_eq!(async_repo.adapter_kind(), AdapterKind::OpenSearch);
        // Panic while holding the guard to poison the mutex.
        let _ = catch_unwind(AssertUnwindSafe(|| {
            let _guard = async_repo.inner.lock().expect("repo lock");
            panic!("poison async memory repo lock");
        }));
        let context = AsyncQueryContext::default().with_timeout_ms(100);
        let err = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect_err("poisoned lock must be reported");
        assert!(err.to_string().contains("async memory repo lock poisoned"));
    }
}