use async_trait::async_trait;
use indexmap::IndexSet;
#[allow(unused_imports)] use meerkat_core::Session;
use meerkat_core::service::{
CreateSessionRequest, SessionError, SessionInfo, SessionQuery, SessionService, SessionSummary,
SessionUsage, SessionView, StartTurnRequest,
};
use meerkat_core::types::{RunResult, SessionId};
use meerkat_store::SessionStore;
use std::sync::Arc;
use crate::ephemeral::{EphemeralSessionService, SessionAgentBuilder};
/// Session service that layers a durable [`SessionStore`] over an in-memory
/// [`EphemeralSessionService`].
///
/// Every `SessionService` call is first delegated to `inner`. After a successful
/// `create_session` / `start_turn`, the full session is exported from `inner` and saved
/// to `store`. `read`, `list` and `archive` also consult `store`, so sessions that are
/// only persisted (not live) remain readable, listable and archivable.
pub struct PersistentSessionService<B: SessionAgentBuilder> {
// In-memory service handling live sessions; consulted first by every trait method.
inner: EphemeralSessionService<B>,
// Durable store: written after create/turn, read as a fallback for non-live sessions.
store: Arc<dyn SessionStore>,
}
impl<B: SessionAgentBuilder + 'static> PersistentSessionService<B> {
    /// Constructs the service.
    ///
    /// `builder` and `max_sessions` are handed straight to
    /// `EphemeralSessionService::new` for the in-memory layer; `store` is kept as the
    /// persistence backend.
    pub fn new(builder: B, max_sessions: usize, store: Arc<dyn SessionStore>) -> Self {
        let inner = EphemeralSessionService::new(builder, max_sessions);
        Self { inner, store }
    }
}
#[async_trait]
impl<B: SessionAgentBuilder + 'static> SessionService for PersistentSessionService<B> {
    /// Creates a session through the in-memory service, then saves a full snapshot of it.
    ///
    /// # Errors
    /// Propagates any error from the inner service. Returns `SessionError::Store` if the
    /// snapshot cannot be saved. NOTE(review): in that case the live session has already
    /// been created but is not persisted — confirm callers expect this.
    async fn create_session(&self, req: CreateSessionRequest) -> Result<RunResult, SessionError> {
        let result = self.inner.create_session(req).await?;
        self.persist_full_session(&result.session_id).await?;
        Ok(result)
    }

    /// Runs a turn on a live session, then re-saves the full session snapshot.
    ///
    /// # Errors
    /// Propagates inner-service errors. Returns `SessionError::Store` if saving fails,
    /// after the turn has already run.
    async fn start_turn(
        &self,
        id: &SessionId,
        req: StartTurnRequest,
    ) -> Result<RunResult, SessionError> {
        let result = self.inner.start_turn(id, req).await?;
        self.persist_full_session(id).await?;
        Ok(result)
    }

    /// Interrupts a live session. Pure delegation; the store is not touched.
    async fn interrupt(&self, id: &SessionId) -> Result<(), SessionError> {
        self.inner.interrupt(id).await
    }

    /// Returns a view of the session, preferring the live in-memory copy.
    ///
    /// When the inner service reports `NotFound`, the session is loaded from the store
    /// and an inactive view (`is_active: false`) is built from the persisted data.
    ///
    /// # Errors
    /// `NotFound` if neither the inner service nor the store has the session,
    /// `Store` if the store load fails, or any other inner-service error unchanged.
    async fn read(&self, id: &SessionId) -> Result<SessionView, SessionError> {
        match self.inner.read(id).await {
            Ok(view) => Ok(view),
            Err(SessionError::NotFound { .. }) => {
                let session = self
                    .store
                    .load(id)
                    .await
                    .map_err(|e| SessionError::Store(Box::new(e)))?
                    .ok_or_else(|| SessionError::NotFound { id: id.clone() })?;
                Ok(SessionView {
                    state: SessionInfo {
                        session_id: session.id().clone(),
                        created_at: session.created_at(),
                        updated_at: session.updated_at(),
                        message_count: session.messages().len(),
                        is_active: false,
                        last_assistant_text: session.last_assistant_text(),
                    },
                    billing: SessionUsage {
                        total_tokens: session.total_tokens(),
                        usage: session.total_usage(),
                    },
                })
            }
            Err(e) => Err(e),
        }
    }

    /// Lists live sessions followed by persisted-only sessions, then paginates.
    ///
    /// A session present both live and in the store is reported once, using the live
    /// summary. Only `query.offset` and `query.limit` are applied, and only after the
    /// merge. The inner service is queried with `SessionQuery::default()`, so pagination
    /// covers the combined list. NOTE(review): any other `SessionQuery` fields are not
    /// forwarded here — confirm that is intended.
    async fn list(&self, query: SessionQuery) -> Result<Vec<SessionSummary>, SessionError> {
        let mut summaries = self.inner.list(SessionQuery::default()).await?;
        let live_ids: IndexSet<_> = summaries.iter().map(|s| s.session_id.clone()).collect();

        let stored = self
            .store
            .list(meerkat_store::SessionFilter::default())
            .await
            .map_err(|e| SessionError::Store(Box::new(e)))?;

        summaries.extend(
            stored
                .into_iter()
                .filter(|meta| !live_ids.contains(&meta.id))
                .map(|meta| SessionSummary {
                    session_id: meta.id,
                    created_at: meta.created_at,
                    updated_at: meta.updated_at,
                    message_count: meta.message_count,
                    total_tokens: meta.total_tokens,
                    is_active: false,
                }),
        );

        // `skip` past the end yields an empty result, matching the old `clear()` branch.
        let offset = query.offset.unwrap_or(0);
        let limit = query.limit.unwrap_or(usize::MAX);
        Ok(summaries.into_iter().skip(offset).take(limit).collect())
    }

    /// Archives a session in memory and removes its persisted snapshot.
    ///
    /// Succeeds if the session was archived live, or if it existed in the store and was
    /// deleted there.
    ///
    /// # Errors
    /// - Any inner-service error other than `NotFound` is returned immediately and the
    ///   stored snapshot is left untouched.
    /// - `Store` if the store existence check or delete fails.
    /// - The inner `NotFound` if the session is in neither place.
    async fn archive(&self, id: &SessionId) -> Result<(), SessionError> {
        let live_result = match self.inner.archive(id).await {
            // The session exists live but could not be archived: do not delete its
            // persisted copy and then report success.
            Err(e) if !matches!(e, SessionError::NotFound { .. }) => return Err(e),
            other => other,
        };

        let in_store = self
            .store
            .exists(id)
            .await
            .map_err(|e| SessionError::Store(Box::new(e)))?;
        if in_store {
            self.store
                .delete(id)
                .await
                .map_err(|e| SessionError::Store(Box::new(e)))?;
        }

        // Found in the store counts as success; otherwise report the live outcome
        // (Ok if archived in memory, NotFound if the session is unknown everywhere).
        if in_store {
            Ok(())
        } else {
            live_result
        }
    }
}
impl<B: SessionAgentBuilder + 'static> PersistentSessionService<B> {
pub async fn event_injector(
&self,
session_id: &SessionId,
) -> Option<std::sync::Arc<dyn meerkat_core::SubscribableInjector>> {
self.inner.event_injector(session_id).await
}
pub async fn comms_runtime(
&self,
session_id: &SessionId,
) -> Option<std::sync::Arc<dyn meerkat_core::agent::CommsRuntime>> {
self.inner.comms_runtime(session_id).await
}
pub async fn wait_session_registered(&self) {
self.inner.wait_session_registered().await;
}
pub async fn shutdown(&self) {
self.inner.shutdown().await;
}
pub async fn load_persisted(&self, id: &SessionId) -> Result<Option<Session>, SessionError> {
self.store
.load(id)
.await
.map_err(|e| SessionError::Store(Box::new(e)))
}
async fn persist_full_session(&self, id: &SessionId) -> Result<(), SessionError> {
let session = self.inner.export_session(id).await?;
self.store
.save(&session)
.await
.map_err(|e| SessionError::Store(Box::new(e)))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use meerkat_store::MemoryStore;

    // NOTE(review): these tests drive `MemoryStore` directly rather than
    // `PersistentSessionService`, because building the service requires a concrete
    // `SessionAgentBuilder`. They cover the store calls the service relies on
    // (`save`, `load`, `exists`, `delete`); service-level tests still need a test builder.

    /// A saved session can be loaded back by id (the call `load_persisted` wraps).
    #[tokio::test]
    async fn test_persistent_load_persisted_returns_stored_session() {
        let store: Arc<dyn SessionStore> = Arc::new(MemoryStore::new());
        let session = Session::new();
        let id = session.id().clone();
        store.save(&session).await.unwrap();
        let loaded = store.load(&id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id(), &id);
    }

    /// Loading an id that was never saved yields `None`, not an error.
    #[tokio::test]
    async fn test_persistent_load_persisted_returns_none_for_unknown() {
        let store: Arc<dyn SessionStore> = Arc::new(MemoryStore::new());
        let unknown = SessionId::new();
        let loaded = store.load(&unknown).await.unwrap();
        assert!(loaded.is_none());
    }

    /// Deleting a stored session makes it unloadable (the removal step of `archive`).
    #[tokio::test]
    async fn test_persistent_archive_deletes_from_store() {
        let store: Arc<dyn SessionStore> = Arc::new(MemoryStore::new());
        let session = Session::new();
        let id = session.id().clone();
        store.save(&session).await.unwrap();
        assert!(store.load(&id).await.unwrap().is_some());
        store.delete(&id).await.unwrap();
        assert!(store.load(&id).await.unwrap().is_none());
    }

    /// `archive` uses `SessionStore::exists` to decide whether a store-only session is
    /// archivable. Check that `exists` reports a saved session, stops reporting it after
    /// deletion, and is false for an id that was never saved.
    #[tokio::test]
    async fn test_persistent_archive_store_only_session_succeeds() {
        let store: Arc<dyn SessionStore> = Arc::new(MemoryStore::new());
        let session = Session::new();
        let id = session.id().clone();
        store.save(&session).await.unwrap();
        assert!(store.exists(&id).await.unwrap());
        store.delete(&id).await.unwrap();
        assert!(!store.exists(&id).await.unwrap());
        assert!(store.load(&id).await.unwrap().is_none());
        assert!(!store.exists(&SessionId::new()).await.unwrap());
    }
}