potato-agent 0.23.0

Potato brands
Documentation
use super::{user_state_store::UserStateStore, validate_db_path, StoreError};
use crate::agents::session::SessionSnapshot;
use async_trait::async_trait;
use sqlx::{Pool, Sqlite, SqlitePool};
use std::sync::Arc;

/// A `UserStateStore` implementation persisted in a SQLite database.
///
/// Cloning is cheap: every clone shares the same `Arc`-wrapped connection pool.
#[derive(Clone, Debug)]
pub struct SqliteUserStateStore {
    /// Connection pool shared by every query this store issues.
    pool: Arc<Pool<Sqlite>>,
}

impl SqliteUserStateStore {
    pub async fn new(path: &str) -> Result<Self, StoreError> {
        let url = validate_db_path(path)?;
        let pool = SqlitePool::connect(&url)
            .await
            .map_err(|e| StoreError::Connection(e.to_string()))?;
        let store = Self {
            pool: Arc::new(pool),
        };
        store.init_tables().await?;
        Ok(store)
    }

    pub async fn in_memory() -> Result<Self, StoreError> {
        let pool = SqlitePool::connect("sqlite::memory:")
            .await
            .map_err(|e| StoreError::Connection(e.to_string()))?;
        let store = Self {
            pool: Arc::new(pool),
        };
        store.init_tables().await?;
        Ok(store)
    }

    async fn init_tables(&self) -> Result<(), StoreError> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS user_state (
                app_name TEXT NOT NULL,
                user_id TEXT NOT NULL,
                state_json TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                PRIMARY KEY (app_name, user_id)
            )",
        )
        .execute(self.pool.as_ref())
        .await
        .map_err(|e| StoreError::Backend(e.to_string()))?;

        Ok(())
    }
}

#[async_trait]
impl UserStateStore for SqliteUserStateStore {
    /// Fetches the snapshot stored for (`app_name`, `user_id`).
    ///
    /// Returns `Ok(None)` when no row matches the pair. Query failures map to
    /// `StoreError::Backend`. A stored row whose JSON cannot be decoded into a
    /// `SessionSnapshot` is converted into a `StoreError` via `?`.
    async fn load(
        &self,
        app_name: &str,
        user_id: &str,
    ) -> Result<Option<SessionSnapshot>, StoreError> {
        let row: Option<(String,)> = sqlx::query_as(
            "SELECT state_json FROM user_state WHERE app_name = ? AND user_id = ?",
        )
        .bind(app_name)
        .bind(user_id)
        .fetch_optional(self.pool.as_ref())
        .await
        .map_err(|err| StoreError::Backend(err.to_string()))?;

        // Decode only if a row was found; a decode error is propagated with `?`.
        let decoded = row
            .map(|(state_json,)| serde_json::from_str::<SessionSnapshot>(&state_json))
            .transpose()?;
        Ok(decoded)
    }

    /// Serialises `snapshot` to JSON and writes it for (`app_name`, `user_id`).
    ///
    /// Uses `INSERT OR REPLACE`, so an existing row for the pair is replaced.
    /// `updated_at` is set to the current UTC time in RFC 3339 form.
    async fn save(
        &self,
        app_name: &str,
        user_id: &str,
        snapshot: &SessionSnapshot,
    ) -> Result<(), StoreError> {
        let state_json = serde_json::to_string(snapshot)?;
        let updated_at = chrono::Utc::now().to_rfc3339();

        let statement = sqlx::query(
            "INSERT OR REPLACE INTO user_state (app_name, user_id, state_json, updated_at)
             VALUES (?, ?, ?, ?)",
        )
        .bind(app_name)
        .bind(user_id)
        .bind(&state_json)
        .bind(&updated_at);

        statement
            .execute(self.pool.as_ref())
            .await
            .map(|_| ())
            .map_err(|err| StoreError::Backend(err.to_string()))
    }

    /// Deletes the row for (`app_name`, `user_id`).
    ///
    /// Deleting a pair that has no row is not an error, because the affected
    /// row count is ignored.
    async fn delete(&self, app_name: &str, user_id: &str) -> Result<(), StoreError> {
        let outcome = sqlx::query("DELETE FROM user_state WHERE app_name = ? AND user_id = ?")
            .bind(app_name)
            .bind(user_id)
            .execute(self.pool.as_ref())
            .await;

        match outcome {
            Ok(_) => Ok(()),
            Err(err) => Err(StoreError::Backend(err.to_string())),
        }
    }
}