use crate::flow::state::{InMemoryState, SessionState};
use async_trait::async_trait;
use moka::future::Cache;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use tracing::info;
/// Shared, type-erased handle to any session store implementation.
pub type SessionStore = Arc<dyn SessionStoreType>;

/// Storage for per-session state, plus a binding from external channel keys to session ids.
#[async_trait]
pub trait SessionStoreType: Send + Sync + Debug {
    // --- session state -------------------------------------------------------------------

    /// Looks up the state stored under `session_id` without creating it.
    async fn get(&self, session_id: &str) -> Option<SessionState>;

    /// Returns the state stored under `session_id`, creating it when absent.
    async fn get_or_create(&self, session_id: &str) -> SessionState;

    /// Removes the session identified by `session_id`.
    async fn remove(&self, session_id: &str);

    // --- channel bindings ----------------------------------------------------------------

    /// Looks up the session id bound to `channel_key` without creating one.
    async fn get_channel(&self, channel_key: &str) -> Option<String>;

    /// Returns the session id bound to `channel_key`, creating a binding when absent.
    async fn get_or_create_channel(&self, channel_key: &str) -> String;

    // --- bulk ----------------------------------------------------------------------------

    /// Discards stored sessions and channel bindings. This is a plain (non-`async`) method.
    fn clear(&self);
}
/// In-memory [`SessionStoreType`] built from three moka caches.
#[derive(Clone, Debug)]
pub struct InMemorySessionStore {
    /// Session id -> that session's state.
    cache: Cache<String, Arc<InMemoryState>>,
    /// Channel key -> session id currently bound to the channel.
    by_channel: Cache<String, String>,
    /// Session id -> channel key; the inverse of `by_channel`, consulted when a session is removed.
    reverse_map: Cache<String, String>,
}
impl InMemorySessionStore {
    /// Creates a new store wrapped in an [`Arc`].
    ///
    /// All three internal caches are built with the same `time_to_idle` of `ttl_secs`
    /// seconds. The session-state cache registers an eviction listener that logs, at `info`
    /// level, the key and the removal cause moka reports.
    pub fn new(ttl_secs: u64) -> Arc<Self> {
        let idle_timeout = Duration::from_secs(ttl_secs);

        Arc::new(Self {
            cache: Cache::builder()
                .time_to_idle(idle_timeout)
                .eviction_listener(|key: Arc<String>, _value: Arc<InMemoryState>, cause| {
                    info!("Session expired: key={}, cause={:?}", key, cause,);
                })
                .build(),
            by_channel: Cache::builder().time_to_idle(idle_timeout).build(),
            reverse_map: Cache::builder().time_to_idle(idle_timeout).build(),
        })
    }
}
#[async_trait]
impl SessionStoreType for InMemorySessionStore {
    /// Drops the state for `session_id` and, when the session was created through
    /// `get_or_create_channel`, its channel binding.
    ///
    /// The forward `channel -> session` entry is invalidated only while it still points at
    /// `session_id`. If the channel has since been bound to a different session, that newer
    /// binding is left intact.
    async fn remove(&self, session_id: &str) {
        self.cache.invalidate(session_id).await;
        if let Some(channel_key) = self.reverse_map.get(session_id).await {
            self.reverse_map.invalidate(session_id).await;
            // A stale reverse entry must not wipe out a channel that was re-bound to a new
            // session (e.g. after its forward entry idled out and was recreated).
            if self.by_channel.get(&channel_key).await.as_deref() == Some(session_id) {
                self.by_channel.invalidate(&channel_key).await;
            }
        }
    }

    /// Invalidates every session, every channel binding and the reverse index.
    fn clear(&self) {
        self.cache.invalidate_all();
        self.by_channel.invalidate_all();
        // The reverse index must be reset too. Otherwise a leftover `session -> channel`
        // entry would let a later `remove` of a pre-clear session invalidate a binding
        // created after the clear.
        self.reverse_map.invalidate_all();
    }

    /// Returns the session id bound to `channel_key`, if any.
    async fn get_channel(&self, channel_key: &str) -> Option<String> {
        let session_id = self.by_channel.get(channel_key).await?;
        // `reverse_map` is otherwise read only in `remove`, so normal traffic would never
        // refresh its idle timer. Touch it here so it does not expire before the forward
        // binding it mirrors (which would leave `remove` unable to find the channel).
        let _ = self.reverse_map.get(&session_id).await;
        Some(session_id)
    }

    /// Returns the session id bound to `channel_key`. When none exists, a fresh UUIDv4
    /// session id is created and recorded in both directions.
    ///
    /// Creation goes through moka's `get_with`, which evaluates the init future at most once
    /// per key even under concurrent callers. Two simultaneous first requests on the same
    /// channel therefore receive the same session id instead of racing to overwrite each
    /// other.
    async fn get_or_create_channel(&self, channel_key: &str) -> String {
        if let Some(session_id) = self.get_channel(channel_key).await {
            return session_id;
        }
        self.by_channel
            .get_with(channel_key.to_string(), async {
                let session_id = uuid::Uuid::new_v4().to_string();
                self.reverse_map
                    .insert(session_id.clone(), channel_key.to_string())
                    .await;
                session_id
            })
            .await
    }

    /// Returns the state for `session_id`, or `None` if it is not cached.
    async fn get(&self, session_id: &str) -> Option<SessionState> {
        self.cache
            .get(session_id)
            .await
            .map(|state| state as SessionState)
    }

    /// Returns the state for `session_id`, creating a new one when none is cached.
    ///
    /// Uses moka's `get_with`, so concurrent first calls for the same id share one state
    /// object. Otherwise each call would insert its own state and all but one caller's
    /// writes would be lost.
    async fn get_or_create(&self, session_id: &str) -> SessionState {
        let state = self
            .cache
            .get_with(session_id.to_string(), async { InMemoryState::new() })
            .await;
        state as SessionState
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flow::state::StateValue;

    /// The value written under `"foo"` throughout these tests.
    fn bar() -> StateValue {
        StateValue::String("bar".into())
    }

    #[tokio::test]
    async fn test_session_store_create_and_retrieve() {
        const SID: &str = "abc123";
        let store = InMemorySessionStore::new(60);

        let first = store.get_or_create(SID).await;
        first.set("foo".to_string(), bar());

        // A second lookup of the same id must observe the write made through the first handle.
        let again = store.get_or_create(SID).await;
        assert_eq!(again.get("foo"), Some(bar()));
    }

    #[tokio::test]
    async fn test_session_store_removal() {
        const SID: &str = "abc123";
        let store = InMemorySessionStore::new(60);

        let before = store.get_or_create(SID).await;
        before.set("foo".to_string(), bar());
        store.remove(SID).await;

        // After removal the id maps to a brand-new, empty state.
        let after = store.get_or_create(SID).await;
        assert_eq!(after.get("foo"), None);
    }

    #[tokio::test]
    async fn test_get_channel_none() {
        let store = InMemorySessionStore::new(60);
        let lookup = store.get_channel("telegram|chat_123").await;
        assert!(lookup.is_none(), "No session should exist yet");
    }

    #[tokio::test]
    async fn test_get_or_create_channel_creates() {
        const CHANNEL: &str = "telegram|chat_123";
        let store = InMemorySessionStore::new(60);

        let created = store.get_or_create_channel(CHANNEL).await;
        let looked_up = store.get_channel(CHANNEL).await;
        assert_eq!(looked_up.as_deref(), Some(created.as_str()));

        store.remove(&created).await;
        assert!(store.get_channel(CHANNEL).await.is_none());
    }

    #[tokio::test]
    async fn test_clear_sessions() {
        let store = InMemorySessionStore::new(60);

        let original = store.get_or_create("session1").await;
        original.set("foo".to_string(), bar());
        store.clear();

        let fresh = store.get_or_create("session1").await;
        assert_eq!(fresh.get("foo"), None);
    }

    #[tokio::test]
    async fn test_reverse_index_cleanup() {
        const CHANNEL: &str = "telegram|chat_999";
        let store = InMemorySessionStore::new(60);

        let sid = store.get_or_create_channel(CHANNEL).await;
        store.remove(&sid).await;

        // Both directions of the channel binding must be gone.
        assert!(store.get_channel(CHANNEL).await.is_none());
        assert!(store.reverse_map.get(sid.as_str()).await.is_none());
    }
}