use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::RwLock;
use crate::case::Case;
use crate::session::Session;
/// Persistence backend for [`Case`] records.
///
/// Implementors must be `Send + Sync`, so a single store can be shared across
/// threads and tasks.
#[async_trait]
pub trait CaseStore: Send + Sync {
    /// Backend-specific initialisation hook.
    ///
    /// The default implementation does nothing and succeeds.
    async fn setup(&self) -> Result<()> {
        Ok(())
    }

    /// Inserts or updates `case` (upsert semantics; the in-memory store keys
    /// records on `case.case_key`).
    async fn upsert(&self, case: &Case) -> Result<()>;

    /// Looks up the case stored under `case_key`, yielding `Ok(None)` when
    /// there is no such case.
    async fn get_by_key(&self, case_key: &str) -> Result<Option<Case>>;

    /// Returns every case that belongs to the session `session_id`.
    ///
    /// The result may be empty. The in-memory store gives no ordering guarantee.
    async fn get_by_session(&self, session_id: &str) -> Result<Vec<Case>>;
}
/// Persistence backend for [`Session`] records.
///
/// Implementors must be `Send + Sync`, so a single store can be shared across
/// threads and tasks.
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Backend-specific initialisation hook.
    ///
    /// The default implementation does nothing and succeeds.
    async fn setup(&self) -> Result<()> {
        Ok(())
    }

    /// Inserts or updates `session` (upsert semantics; the in-memory store
    /// keys records on `session.session_id`).
    async fn upsert(&self, session: &Session) -> Result<()>;

    /// Fetches the session identified by `session_id`, yielding `Ok(None)` when
    /// it does not exist.
    async fn get(&self, session_id: &str) -> Result<Option<Session>>;

    /// Removes the session identified by `session_id`.
    ///
    /// The in-memory store treats deleting an unknown id as a successful no-op.
    async fn delete(&self, session_id: &str) -> Result<()>;
}
/// State recorded for one step of a case.
///
/// `PartialEq`/`Eq` are derived so entries can be compared directly, for
/// example in tests or when de-duplicating results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateEntry {
    /// Key of the case this entry belongs to.
    pub case_key: String,
    /// Name of the step this state is stored under.
    pub step: String,
    /// Payload exactly as it was passed to `save`; the store does not interpret it.
    pub data: String,
    /// Time of the most recent save (the in-memory store uses `Utc::now()`).
    pub updated_at: DateTime<Utc>,
}
/// State recorded for one step of a session.
///
/// `PartialEq`/`Eq` are derived so entries can be compared directly, for
/// example in tests or when de-duplicating results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionStateEntry {
    /// Id of the session this entry belongs to.
    pub session_id: String,
    /// Name of the step this state is stored under.
    pub step: String,
    /// Payload exactly as it was passed to `save_session`; the store does not
    /// interpret it.
    pub data: String,
    /// Time of the most recent save (the in-memory store uses `Utc::now()`).
    pub updated_at: DateTime<Utc>,
}
/// Persistence backend for per-step state, scoped to either a case or a session.
///
/// Only the case-scoped methods are required. The session-scoped methods have
/// default implementations that always fail, so a backend without session
/// support only needs to implement the case-scoped half.
#[async_trait]
pub trait StateStore: Send + Sync {
    /// Backend-specific initialisation hook.
    ///
    /// The default implementation does nothing and succeeds.
    async fn setup(&self) -> Result<()> {
        Ok(())
    }

    // ----- case-scoped state (required) -----

    /// Saves `data` as the state of `step` for the case `case_key`.
    async fn save(&self, case_key: &str, step: &str, data: &str) -> Result<()>;

    /// Fetches the state of `step` for `case_key`, or `Ok(None)` if absent.
    async fn get(&self, case_key: &str, step: &str) -> Result<Option<StateEntry>>;

    /// Returns every stored step entry for `case_key`.
    async fn get_all(&self, case_key: &str) -> Result<Vec<StateEntry>>;

    /// Removes every stored step entry for `case_key`.
    async fn delete_by_case(&self, case_key: &str) -> Result<()>;

    // ----- session-scoped state (optional) -----

    /// Saves `data` as the state of `step` for the session `session_id`.
    ///
    /// # Errors
    /// The default implementation always returns an "unsupported" error.
    async fn save_session(&self, _session_id: &str, _step: &str, _data: &str) -> Result<()> {
        session_state_unsupported()
    }

    /// Fetches the state of `step` for `session_id`, or `Ok(None)` if absent.
    ///
    /// # Errors
    /// The default implementation always returns an "unsupported" error.
    async fn get_session(
        &self,
        _session_id: &str,
        _step: &str,
    ) -> Result<Option<SessionStateEntry>> {
        session_state_unsupported()
    }

    /// Returns every stored step entry for `session_id`.
    ///
    /// # Errors
    /// The default implementation always returns an "unsupported" error.
    async fn get_all_session(&self, _session_id: &str) -> Result<Vec<SessionStateEntry>> {
        session_state_unsupported()
    }

    /// Removes every stored step entry for `session_id`.
    ///
    /// # Errors
    /// The default implementation always returns an "unsupported" error.
    async fn delete_by_session(&self, _session_id: &str) -> Result<()> {
        session_state_unsupported()
    }
}

/// Builds the error returned by the default session-scoped [`StateStore`]
/// methods. It is generic over the success type so every default can share it.
fn session_state_unsupported<T>() -> Result<T> {
    Err(anyhow::anyhow!(
        "session-scoped state not supported by this store"
    ))
}
/// [`CaseStore`] that keeps every case in a process-local map keyed by
/// `case_key`. Nothing is persisted beyond the lifetime of the value.
#[derive(Default)]
pub struct InMemoryCaseStore {
    cases: RwLock<HashMap<String, Case>>,
}

impl InMemoryCaseStore {
    /// Acquires `lock` for reading. If the lock is poisoned, the guard is
    /// recovered instead of panicking.
    ///
    /// Poisoning only records that a writer panicked while holding the lock.
    /// The sole write section here is a single `insert`, with no multi-step
    /// update that a panic could leave half-done. The map is therefore still
    /// consistent, and it is better to keep serving than to turn every later
    /// call into a panic.
    fn read_lock<T>(lock: &RwLock<T>) -> std::sync::RwLockReadGuard<'_, T> {
        lock.read().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires `lock` for writing. Poisoning is recovered from for the same
    /// reason as in [`Self::read_lock`].
    fn write_lock<T>(lock: &RwLock<T>) -> std::sync::RwLockWriteGuard<'_, T> {
        lock.write().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

#[async_trait]
impl CaseStore for InMemoryCaseStore {
    /// Stores a clone of `case` under `case.case_key`, overwriting any previous one.
    async fn upsert(&self, case: &Case) -> Result<()> {
        let mut cases = Self::write_lock(&self.cases);
        cases.insert(case.case_key.clone(), case.clone());
        Ok(())
    }

    /// Returns a clone of the case stored under `case_key`, if any.
    async fn get_by_key(&self, case_key: &str) -> Result<Option<Case>> {
        let cases = Self::read_lock(&self.cases);
        Ok(cases.get(case_key).cloned())
    }

    /// Returns clones of all cases whose `session_id` equals `session_id`.
    ///
    /// This scans every stored case, and the result order follows `HashMap`
    /// iteration, which is unspecified.
    async fn get_by_session(&self, session_id: &str) -> Result<Vec<Case>> {
        let cases = Self::read_lock(&self.cases);
        let matching = cases
            .values()
            .filter(|c| c.session_id == session_id)
            .cloned()
            .collect();
        Ok(matching)
    }
}
/// [`SessionStore`] that keeps every session in a process-local map keyed by
/// `session_id`. Nothing is persisted beyond the lifetime of the value.
#[derive(Default)]
pub struct InMemorySessionStore {
    sessions: RwLock<HashMap<String, Session>>,
}

impl InMemorySessionStore {
    /// Acquires `lock` for reading. If the lock is poisoned, the guard is
    /// recovered instead of panicking.
    ///
    /// Poisoning only records that a writer panicked while holding the lock.
    /// Each write section here is a single `insert` or `remove`, with no
    /// multi-step update that a panic could leave half-done. The map is
    /// therefore still consistent, and it is better to keep serving than to
    /// turn every later call into a panic.
    fn read_lock<T>(lock: &RwLock<T>) -> std::sync::RwLockReadGuard<'_, T> {
        lock.read().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires `lock` for writing. Poisoning is recovered from for the same
    /// reason as in [`Self::read_lock`].
    fn write_lock<T>(lock: &RwLock<T>) -> std::sync::RwLockWriteGuard<'_, T> {
        lock.write().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Stores a clone of `session` under `session.session_id`, overwriting any
    /// previous one.
    async fn upsert(&self, session: &Session) -> Result<()> {
        let mut sessions = Self::write_lock(&self.sessions);
        sessions.insert(session.session_id.clone(), session.clone());
        Ok(())
    }

    /// Returns a clone of the session stored under `session_id`, if any.
    async fn get(&self, session_id: &str) -> Result<Option<Session>> {
        let sessions = Self::read_lock(&self.sessions);
        Ok(sessions.get(session_id).cloned())
    }

    /// Removes the session stored under `session_id`. An unknown id is a no-op.
    async fn delete(&self, session_id: &str) -> Result<()> {
        let mut sessions = Self::write_lock(&self.sessions);
        sessions.remove(session_id);
        Ok(())
    }
}
/// [`StateStore`] that keeps all state in process-local maps and supports both
/// case-scoped and session-scoped entries.
///
/// Each scope has its own map and lock, keyed by `(owner id, step)`. Saving
/// the same step twice therefore overwrites the earlier entry.
#[derive(Default)]
pub struct InMemoryStateStore {
    /// Case-scoped entries keyed by `(case_key, step)`.
    entries: RwLock<HashMap<(String, String), StateEntry>>,
    /// Session-scoped entries keyed by `(session_id, step)`.
    session_entries: RwLock<HashMap<(String, String), SessionStateEntry>>,
}

impl InMemoryStateStore {
    /// Acquires `lock` for reading. If the lock is poisoned, the guard is
    /// recovered instead of panicking.
    ///
    /// Poisoning only records that a writer panicked while holding the lock.
    /// Every write section in this store is a single `insert` or `retain` on
    /// one map, with no multi-step update that a panic could leave half-done.
    /// The data is therefore still consistent, and it is better to keep
    /// serving than to turn every later call into a panic.
    fn read_lock<T>(lock: &RwLock<T>) -> std::sync::RwLockReadGuard<'_, T> {
        lock.read().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires `lock` for writing. Poisoning is recovered from for the same
    /// reason as in [`Self::read_lock`].
    fn write_lock<T>(lock: &RwLock<T>) -> std::sync::RwLockWriteGuard<'_, T> {
        lock.write().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

#[async_trait]
impl StateStore for InMemoryStateStore {
    /// Stores `data` for (`case_key`, `step`), overwriting any previous entry.
    async fn save(&self, case_key: &str, step: &str, data: &str) -> Result<()> {
        let mut entries = Self::write_lock(&self.entries);
        // The timestamp is taken while the write lock is held, so the order in
        // which saves land in the map matches the order of their `updated_at`.
        entries.insert(
            (case_key.to_string(), step.to_string()),
            StateEntry {
                case_key: case_key.to_string(),
                step: step.to_string(),
                data: data.to_string(),
                updated_at: Utc::now(),
            },
        );
        Ok(())
    }

    /// Returns a clone of the entry for (`case_key`, `step`), if any.
    async fn get(&self, case_key: &str, step: &str) -> Result<Option<StateEntry>> {
        let entries = Self::read_lock(&self.entries);
        Ok(entries
            .get(&(case_key.to_string(), step.to_string()))
            .cloned())
    }

    /// Returns clones of every entry for `case_key`.
    ///
    /// This scans all case-scoped entries, and the result order follows
    /// `HashMap` iteration, which is unspecified.
    async fn get_all(&self, case_key: &str) -> Result<Vec<StateEntry>> {
        let entries = Self::read_lock(&self.entries);
        let matching = entries
            .values()
            .filter(|e| e.case_key == case_key)
            .cloned()
            .collect();
        Ok(matching)
    }

    /// Drops every entry whose key belongs to `case_key`.
    async fn delete_by_case(&self, case_key: &str) -> Result<()> {
        let mut entries = Self::write_lock(&self.entries);
        entries.retain(|(ck, _), _| ck != case_key);
        Ok(())
    }

    /// Stores `data` for (`session_id`, `step`), overwriting any previous entry.
    async fn save_session(&self, session_id: &str, step: &str, data: &str) -> Result<()> {
        let mut entries = Self::write_lock(&self.session_entries);
        // As in `save`, the timestamp is taken under the write lock.
        entries.insert(
            (session_id.to_string(), step.to_string()),
            SessionStateEntry {
                session_id: session_id.to_string(),
                step: step.to_string(),
                data: data.to_string(),
                updated_at: Utc::now(),
            },
        );
        Ok(())
    }

    /// Returns a clone of the entry for (`session_id`, `step`), if any.
    async fn get_session(&self, session_id: &str, step: &str) -> Result<Option<SessionStateEntry>> {
        let entries = Self::read_lock(&self.session_entries);
        Ok(entries
            .get(&(session_id.to_string(), step.to_string()))
            .cloned())
    }

    /// Returns clones of every entry for `session_id`.
    ///
    /// This scans all session-scoped entries, and the result order is
    /// unspecified.
    async fn get_all_session(&self, session_id: &str) -> Result<Vec<SessionStateEntry>> {
        let entries = Self::read_lock(&self.session_entries);
        let matching = entries
            .values()
            .filter(|e| e.session_id == session_id)
            .cloned()
            .collect();
        Ok(matching)
    }

    /// Drops every entry whose key belongs to `session_id`.
    async fn delete_by_session(&self, session_id: &str) -> Result<()> {
        let mut entries = Self::write_lock(&self.session_entries);
        entries.retain(|(sid, _), _| sid != session_id);
        Ok(())
    }
}