use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use uuid::Uuid;
use crate::error::{ClawError, ClawResult};
/// One key/value entry belonging to a session, with a creation time and an expiry time.
///
/// Instances are written by `ContextStore::insert` and read back by
/// `ContextStore::get_active`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextRecord {
    /// Unique identifier of the record; persisted in its string form.
    pub id: Uuid,
    /// Session the record belongs to; `ContextStore::get_active` filters on this column.
    pub session_id: String,
    /// Caller-chosen key naming this entry.
    pub key: String,
    /// Arbitrary JSON payload; persisted as serialized JSON text.
    pub value: serde_json::Value,
    /// When the record was created (UTC). `get_active` returns records ordered by this field.
    pub created_at: DateTime<Utc>,
    /// Expiry instant (UTC). A record is "active" only while `expires_at` is strictly after
    /// the current time; `purge_expired` deletes it once `expires_at` is at or before now.
    pub expires_at: DateTime<Utc>,
}
/// Async data-access handle for the `context` table.
///
/// The store does not own its connections: it borrows a [`SqlitePool`] and issues
/// every query through it, so it cannot outlive that pool.
#[derive(Debug)]
pub struct ContextStore<'pool> {
    /// Borrowed connection pool that all queries execute against.
    pool: &'pool SqlitePool,
}
impl<'a> ContextStore<'a> {
pub fn new(pool: &'a SqlitePool) -> Self {
ContextStore { pool }
}
pub async fn insert(&self, record: &ContextRecord) -> ClawResult<()> {
sqlx::query(
r#"
INSERT INTO context (id, session_id, key, value, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)
"#,
)
.bind(record.id.to_string())
.bind(&record.session_id)
.bind(&record.key)
.bind(serde_json::to_string(&record.value)?)
.bind(record.created_at.to_rfc3339())
.bind(record.expires_at.to_rfc3339())
.execute(self.pool)
.await?;
Ok(())
}
pub async fn purge_expired(&self) -> ClawResult<u64> {
let now = Utc::now().to_rfc3339();
let result = sqlx::query("DELETE FROM context WHERE expires_at <= ?")
.bind(now)
.execute(self.pool)
.await?;
Ok(result.rows_affected())
}
pub async fn get_active(&self, session_id: &str) -> ClawResult<Vec<ContextRecord>> {
let now = Utc::now().to_rfc3339();
let rows = sqlx::query_as::<_, (String, String, String, String, String, String)>(
"SELECT id, session_id, key, value, created_at, expires_at \
FROM context WHERE session_id = ? AND expires_at > ? ORDER BY created_at ASC",
)
.bind(session_id)
.bind(now)
.fetch_all(self.pool)
.await?;
rows.into_iter()
.map(|(id, session_id, key, value, created_at, expires_at)| {
Ok(ContextRecord {
id: Uuid::parse_str(&id).map_err(|e| ClawError::Store(e.to_string()))?,
session_id,
key,
value: serde_json::from_str(&value)?,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map_err(|e| ClawError::Store(e.to_string()))?
.with_timezone(&Utc),
expires_at: DateTime::parse_from_rfc3339(&expires_at)
.map_err(|e| ClawError::Store(e.to_string()))?
.with_timezone(&Utc),
})
})
.collect()
}
}