shelly-data 0.6.0

Data-layer primitives for Shelly LiveView (schemas, changesets, repo, migrations).
Documentation
use crate::{
    adapter::AdapterKind,
    error::{DataError, DataResult},
    query::Query,
    repo::{MemoryRepo, Repo, Row, StoredRow},
};
use std::{
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex,
    },
    time::{Duration, Instant},
};

/// Boxed, pinned, `Send` future returned by the [`AsyncRepo`] query methods.
///
/// It resolves to a [`DataResult<T>`] and may borrow from the repository and the call
/// arguments for the lifetime `'a`.
pub type AsyncRepoFuture<'a, T> = Pin<Box<dyn Future<Output = DataResult<T>> + Send + 'a>>;

/// Cloneable flag used to ask in-flight async queries to stop.
///
/// Every clone shares the same underlying atomic flag, so cancelling (or resetting) through
/// any clone is observed by all of them. A default-constructed token starts un-cancelled.
#[derive(Debug, Clone, Default)]
pub struct AsyncCancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl AsyncCancellationToken {
    /// Marks this token, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        self.set(true);
    }

    /// Clears the cancellation flag so the token (and its clones) can be reused.
    pub fn reset(&self) {
        self.set(false);
    }

    /// Returns `true` if [`cancel`](Self::cancel) was called and not undone by
    /// [`reset`](Self::reset).
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }

    /// Stores `value` into the shared flag with sequentially-consistent ordering.
    fn set(&self, value: bool) {
        self.cancelled.store(value, Ordering::SeqCst);
    }
}

/// Per-call limits for [`AsyncRepo`] operations: an optional deadline and a cancellation
/// token.
///
/// [`Default`] yields a context with no deadline and a fresh, un-cancelled token.
#[derive(Debug, Clone, Default)]
pub struct AsyncQueryContext {
    /// Instant at which the query counts as timed out; `None` disables the deadline.
    pub deadline: Option<Instant>,
    /// Token whose cancellation makes [`ensure_active`](Self::ensure_active) fail.
    pub cancellation: AsyncCancellationToken,
}

impl AsyncQueryContext {
    /// Sets the deadline to `timeout_ms` milliseconds from now (a `0` timeout is clamped up
    /// to 1 ms).
    ///
    /// If the resulting instant cannot be represented (e.g. `timeout_ms == u64::MAX`), the
    /// timeout is treated as unbounded and the deadline is cleared. The previous
    /// `Instant + Duration` form panicked on that overflow.
    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
        let timeout = Duration::from_millis(timeout_ms.max(1));
        self.deadline = Instant::now().checked_add(timeout);
        self
    }

    /// Sets an absolute deadline, replacing any previous one.
    pub fn with_deadline(mut self, deadline: Instant) -> Self {
        self.deadline = Some(deadline);
        self
    }

    /// Replaces the cancellation token (typically with a clone the caller keeps to cancel
    /// later).
    pub fn with_cancellation_token(mut self, token: AsyncCancellationToken) -> Self {
        self.cancellation = token;
        self
    }

    /// Returns `true` once this context's token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancellation.is_cancelled()
    }

    /// Returns `true` when a deadline is set and the current time has reached it.
    pub fn is_deadline_exceeded(&self) -> bool {
        // `>=` so a deadline equal to the current instant counts as reached; a strict `>`
        // made zero-slack deadlines depend on clock granularity.
        matches!(self.deadline, Some(deadline) if Instant::now() >= deadline)
    }

    /// Fails if a query running under this context should no longer proceed.
    ///
    /// # Errors
    ///
    /// Returns [`DataError::Query`] if the token has been cancelled (checked first) or if
    /// the deadline has been reached.
    pub fn ensure_active(&self) -> DataResult<()> {
        if self.is_cancelled() {
            return Err(DataError::Query(
                "async query cancelled by cancellation token".to_string(),
            ));
        }
        if self.is_deadline_exceeded() {
            return Err(DataError::Query(
                "async query deadline exceeded before completion".to_string(),
            ));
        }
        Ok(())
    }
}

/// Asynchronous CRUD interface over table-addressed rows.
///
/// Every query method takes an [`AsyncQueryContext`] carrying that call's deadline and
/// cancellation token, and returns an [`AsyncRepoFuture`] that may borrow its arguments for
/// `'a`. Implementors must be `Send + Sync`.
pub trait AsyncRepo: Send + Sync {
    /// Returns the adapter kind this repository was built for.
    fn adapter_kind(&self) -> AdapterKind;

    /// Inserts `data` as a new row in `table`, resolving to the stored row.
    fn insert<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow>;

    /// Updates the row identified by `id` in `table` with `data`, resolving to the stored
    /// row. NOTE(review): whether fields are merged or replaced is implementation-defined —
    /// confirm against `Repo::update`.
    fn update<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow>;

    /// Deletes the row identified by `id` from `table`.
    fn delete<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, ()>;

    /// Looks up the row identified by `id` in `table`, resolving to `None` if it is absent.
    fn find<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, Option<StoredRow>>;

    /// Returns the rows of `table` selected by `query`.
    fn list<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        query: &'a Query,
    ) -> AsyncRepoFuture<'a, Vec<StoredRow>>;
}

/// [`AsyncRepo`] adapter over a synchronous [`MemoryRepo`] guarded by a mutex.
///
/// Cloning is cheap: clones share the same underlying repository through an `Arc`.
#[derive(Clone)]
pub struct AsyncMemoryRepo {
    inner: Arc<Mutex<MemoryRepo>>,
    adapter_kind: AdapterKind,
}

impl AsyncMemoryRepo {
    /// Wraps `repo`, caching its adapter kind so it can be reported without locking.
    pub fn new(repo: MemoryRepo) -> Self {
        // Field initialisers run in source order, so the adapter kind is read before
        // `repo` is moved into the mutex.
        Self {
            adapter_kind: repo.adapter_kind(),
            inner: Arc::new(Mutex::new(repo)),
        }
    }
}

impl AsyncRepo for AsyncMemoryRepo {
    fn adapter_kind(&self) -> AdapterKind {
        self.adapter_kind
    }

    fn insert<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow> {
        Box::pin(async move {
            context.ensure_active()?;
            let result = self
                .inner
                .lock()
                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
                .insert(table, data);
            context.ensure_active()?;
            result
        })
    }

    fn update<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
        data: Row,
    ) -> AsyncRepoFuture<'a, StoredRow> {
        Box::pin(async move {
            context.ensure_active()?;
            let result = self
                .inner
                .lock()
                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
                .update(table, id, data);
            context.ensure_active()?;
            result
        })
    }

    fn delete<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, ()> {
        Box::pin(async move {
            context.ensure_active()?;
            let result = self
                .inner
                .lock()
                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
                .delete(table, id);
            context.ensure_active()?;
            result
        })
    }

    fn find<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        id: u64,
    ) -> AsyncRepoFuture<'a, Option<StoredRow>> {
        Box::pin(async move {
            context.ensure_active()?;
            let result = self
                .inner
                .lock()
                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
                .find(table, id);
            context.ensure_active()?;
            result
        })
    }

    fn list<'a>(
        &'a self,
        context: &'a AsyncQueryContext,
        table: &'a str,
        query: &'a Query,
    ) -> AsyncRepoFuture<'a, Vec<StoredRow>> {
        Box::pin(async move {
            context.ensure_active()?;
            let result = self
                .inner
                .lock()
                .map_err(|_| DataError::Query("async memory repo lock poisoned".to_string()))?
                .list(table, query);
            context.ensure_active()?;
            result
        })
    }
}

#[cfg(test)]
mod tests {
    use super::{AsyncCancellationToken, AsyncMemoryRepo, AsyncQueryContext, AsyncRepo};
    use crate::{adapter_for, AdapterKind, DatabaseConfig, MemoryRepo, Query, Row};
    use serde_json::json;
    use std::{
        panic::{catch_unwind, AssertUnwindSafe},
        time::{Duration, Instant},
    };

    /// Wraps a fresh `MemoryRepo` for `adapter` (no URL configured) in an `AsyncMemoryRepo`.
    fn async_repo_for(adapter: AdapterKind) -> AsyncMemoryRepo {
        let config = DatabaseConfig {
            adapter,
            url: None,
            url_env: None,
        };
        let driver = adapter_for(&config).expect("driver");
        AsyncMemoryRepo::new(MemoryRepo::new(driver))
    }

    #[tokio::test]
    async fn async_memory_repo_supports_insert_and_list() {
        let async_repo = async_repo_for(AdapterKind::SingleStore);

        let mut account = Row::new();
        account.insert("account".to_string(), json!("Acme"));
        let context = AsyncQueryContext::default().with_timeout_ms(500);

        async_repo
            .insert(&context, "accounts", account)
            .await
            .expect("insert");
        let listed = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .expect("list");
        assert_eq!(listed.len(), 1);
    }

    #[tokio::test]
    async fn async_query_context_cancellation_stops_query() {
        let async_repo = async_repo_for(AdapterKind::Postgres);

        let token = AsyncCancellationToken::default();
        token.cancel();
        let context = AsyncQueryContext::default().with_cancellation_token(token);

        let outcome = async_repo.list(&context, "accounts", &Query::new()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn async_query_context_deadline_exceeded_stops_query() {
        let async_repo = async_repo_for(AdapterKind::MySql);

        // A 1 ms budget followed by a 5 ms sleep guarantees the deadline has passed.
        let context = AsyncQueryContext::default().with_timeout_ms(1);
        tokio::time::sleep(Duration::from_millis(5)).await;

        let outcome = async_repo.find(&context, "accounts", 1).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn async_query_context_reset_and_explicit_deadline_paths_are_exercised() {
        let token = AsyncCancellationToken::default();
        token.cancel();
        assert!(token.is_cancelled());
        token.reset();
        assert!(!token.is_cancelled());

        // The token is live again, so only the already-elapsed deadline can fail the check.
        let one_ms_ago = Instant::now()
            .checked_sub(Duration::from_millis(1))
            .expect("past instant");
        let context = AsyncQueryContext::default()
            .with_deadline(one_ms_ago)
            .with_cancellation_token(token);
        assert!(context.ensure_active().is_err());
    }

    #[tokio::test]
    async fn async_memory_repo_supports_update_and_delete() {
        let async_repo = async_repo_for(AdapterKind::ClickHouse);
        let context = AsyncQueryContext::default().with_timeout_ms(500);

        let mut draft = Row::new();
        draft.insert("name".to_string(), json!("Draft"));
        let inserted = async_repo
            .insert(&context, "accounts", draft)
            .await
            .expect("insert");
        let id = inserted.id;

        let mut published = Row::new();
        published.insert("name".to_string(), json!("Published"));
        let updated = async_repo
            .update(&context, "accounts", id, published)
            .await
            .expect("update");
        assert_eq!(updated.data.get("name"), Some(&json!("Published")));

        async_repo
            .delete(&context, "accounts", id)
            .await
            .expect("delete");
        let found = async_repo
            .find(&context, "accounts", id)
            .await
            .expect("find after delete");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn async_memory_repo_reports_lock_poisoned_errors_and_adapter_kind() {
        let async_repo = async_repo_for(AdapterKind::OpenSearch);
        assert_eq!(async_repo.adapter_kind(), AdapterKind::OpenSearch);

        // Panicking while the guard is alive poisons the repo mutex.
        let _ = catch_unwind(AssertUnwindSafe(|| {
            let _guard = async_repo.inner.lock().expect("repo lock");
            panic!("poison async memory repo lock");
        }));

        let context = AsyncQueryContext::default().with_timeout_ms(100);
        let message = async_repo
            .list(&context, "accounts", &Query::new())
            .await
            .unwrap_err()
            .to_string();
        assert!(message.contains("async memory repo lock poisoned"));
    }
}