#[cfg(feature = "dynamodb")]
use std::sync::Arc;
#[cfg(feature = "dynamodb")]
use aws_sdk_dynamodb::types::AttributeValue;
#[cfg(feature = "dynamodb")]
use aws_sdk_dynamodb::Client as DynamoDbClient;
#[cfg(feature = "dynamodb")]
use tracing::{debug, info};
#[cfg(feature = "dynamodb")]
use crate::error::AgentKitError;
#[cfg(feature = "dynamodb")]
use super::types::SessionRecord;
use std::collections::HashMap;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::error::Result;
use super::types::SessionItem;
/// Process-local session store: a `HashMap` guarded by an async (`tokio`) mutex.
///
/// Sessions exist only for the lifetime of this value; nothing is persisted.
#[derive(Default)]
pub struct InMemorySessionService {
/// Sessions keyed by their session id.
sessions: Mutex<HashMap<String, SessionItem>>,
}
impl InMemorySessionService {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Current wall-clock time as fractional seconds since the Unix epoch.
    ///
    /// A system clock set before the epoch yields `0.0` instead of an error.
    fn now_secs() -> f64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs_f64()
    }

    /// Builds a fresh session with empty state, no event and the current time.
    fn new_item(id: String, app_name: &str, user_id: &str) -> SessionItem {
        SessionItem {
            id,
            app_name: app_name.to_string(),
            user_id: user_id.to_string(),
            state: HashMap::new(),
            event: None,
            last_update_time: Self::now_secs(),
        }
    }

    /// Creates and stores a new session and returns a copy of it.
    ///
    /// When `session_id` is `None` a random UUID v4 is used. A session already
    /// stored under the same id is replaced.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub async fn create_session(
        &self,
        app_name: &str,
        user_id: &str,
        session_id: Option<String>,
    ) -> Result<SessionItem> {
        let id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
        let item = Self::new_item(id.clone(), app_name, user_id);
        self.sessions.lock().await.insert(id, item.clone());
        Ok(item)
    }

    /// Returns a copy of the session with `session_id`, or `None` if unknown.
    pub async fn get_session(&self, session_id: &str) -> Option<SessionItem> {
        self.sessions.lock().await.get(session_id).cloned()
    }

    /// Returns the session with `session_id`, creating it if it does not exist.
    ///
    /// The lookup and the insertion happen under a single lock acquisition, so
    /// concurrent upserts of the same id always agree on one session instead of
    /// the later one overwriting the earlier. For an existing session, `app_name`
    /// and `user_id` are ignored.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub async fn upsert_session(
        &self,
        app_name: &str,
        user_id: &str,
        session_id: &str,
    ) -> Result<SessionItem> {
        let mut sessions = self.sessions.lock().await;
        let item = sessions
            .entry(session_id.to_string())
            .or_insert_with(|| Self::new_item(session_id.to_string(), app_name, user_id));
        Ok(item.clone())
    }

    /// Stores `event` as the session's latest event and replaces its state wholesale.
    ///
    /// An unknown `session_id` is silently ignored (the call still returns `Ok`).
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub async fn append_event(
        &self,
        session_id: &str,
        event: serde_json::Value,
        state: HashMap<String, serde_json::Value>,
    ) -> Result<()> {
        if let Some(session) = self.sessions.lock().await.get_mut(session_id) {
            session.event = Some(event);
            session.state = state;
            session.last_update_time = Self::now_secs();
        }
        Ok(())
    }
}
#[cfg(feature = "dynamodb")]
/// Session store that persists each appended event as a row in a DynamoDB table.
///
/// Newly created sessions are held in `cache` until their first event is persisted.
pub struct DynamoDbSessionService {
/// DynamoDB client used for all table operations.
client: DynamoDbClient,
/// Name of the table that session rows are written to and read from.
table_name: String,
/// Sessions created via `create_session` but not yet written to DynamoDB,
/// keyed by session id; an entry is evicted once an event for it is persisted.
cache: Mutex<HashMap<String, SessionItem>>,
/// Seconds added to the write time to produce each row's `expires_at`
/// (from `SESSION_EXPIRES_IN`, default 3600).
session_ttl_secs: i64,
}
#[cfg(feature = "dynamodb")]
impl DynamoDbSessionService {
pub async fn new(
table_name: impl Into<String>,
endpoint_url: Option<String>,
region: Option<String>,
) -> Result<Arc<Self>> {
let mut builder = aws_config::defaults(aws_config::BehaviorVersion::latest());
if let Some(region) = region {
builder = builder.region(aws_config::Region::new(region));
}
let sdk_config = builder.load().await;
let mut dynamo_builder = aws_sdk_dynamodb::config::Builder::from(&sdk_config);
if let Some(ep) = endpoint_url {
dynamo_builder = dynamo_builder.endpoint_url(ep);
}
let client = DynamoDbClient::from_conf(dynamo_builder.build());
let table_name: String = table_name.into();
let ttl = std::env::var("SESSION_EXPIRES_IN")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(3600i64);
info!("DynamoDB session service initialised (table: {})", table_name);
Ok(Arc::new(Self {
client,
table_name,
cache: Mutex::new(HashMap::new()),
session_ttl_secs: ttl,
}))
}
pub async fn create_session(
&self,
app_name: &str,
user_id: &str,
session_id: Option<String>,
) -> Result<SessionItem> {
let id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
let item = SessionItem {
id: id.clone(),
app_name: app_name.to_string(),
user_id: user_id.to_string(),
state: HashMap::new(),
event: None,
last_update_time: now,
};
self.cache.lock().await.insert(id, item.clone());
debug!("Session {} cached in memory (not yet persisted)", &item.id);
Ok(item)
}
pub async fn get_session(&self, app_name: &str, session_id: &str) -> Result<Option<SessionItem>> {
if let Some(s) = self.cache.lock().await.get(session_id) {
debug!("Session {} found in memory cache", session_id);
return Ok(Some(s.clone()));
}
let pk = format!("{app_name}#{session_id}");
let result = self
.client
.query()
.table_name(&self.table_name)
.key_condition_expression("pk = :pk")
.expression_attribute_values(":pk", AttributeValue::S(pk))
.send()
.await
.map_err(|e| AgentKitError::DynamoDB(e.to_string()))?;
let items = result.items.unwrap_or_default();
if items.is_empty() {
return Ok(None);
}
let mut last_update = 0.0f64;
let mut session_state: HashMap<String, serde_json::Value> = HashMap::new();
let mut last_item: Option<SessionItem> = None;
for row in &items {
let record_json = serde_json::to_value(
row.iter()
.map(|(k, v)| (k.clone(), attribute_value_to_json(v)))
.collect::<HashMap<_, _>>(),
)
.unwrap_or_default();
if let Ok(record) = serde_json::from_value::<SessionRecord>(record_json) {
let ev = &record.event;
if ev.last_update_time > last_update {
last_update = ev.last_update_time;
session_state = ev.state.clone();
}
last_item = Some(ev.clone());
}
}
let mut session = last_item.unwrap_or_else(|| SessionItem {
id: session_id.to_string(),
app_name: app_name.to_string(),
user_id: String::new(),
state: HashMap::new(),
event: None,
last_update_time: last_update,
});
session.state = session_state;
Ok(Some(session))
}
pub async fn upsert_session(
&self,
app_name: &str,
user_id: &str,
session_id: &str,
) -> Result<SessionItem> {
match self.get_session(app_name, session_id).await? {
Some(s) => Ok(s),
None => self.create_session(app_name, user_id, Some(session_id.to_string())).await,
}
}
pub async fn append_event(
&self,
session: &SessionItem,
event: serde_json::Value,
state: HashMap<String, serde_json::Value>,
) -> Result<()> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default();
let timestamp = now.as_secs_f64();
let expires_at = now.as_secs() as i64 + self.session_ttl_secs;
let record = SessionRecord {
pk: format!("{}#{}", session.app_name, session.id),
sk: format!("{}", (timestamp * 1000.0) as u64),
event: SessionItem {
event: Some(event),
state,
last_update_time: timestamp,
..session.clone()
},
user_id: session.user_id.clone(),
expires_at,
};
let item_map = dynamo_item_from_record(&record)?;
self.client
.put_item()
.table_name(&self.table_name)
.set_item(Some(item_map))
.send()
.await
.map_err(|e| AgentKitError::DynamoDB(e.to_string()))?;
self.cache.lock().await.remove(&session.id);
debug!("Session {} persisted to DynamoDB and evicted from cache", session.id);
Ok(())
}
pub async fn delete_session(&self, app_name: &str, session_id: &str) -> Result<()> {
let pk = format!("{app_name}#{session_id}");
self.client
.delete_item()
.table_name(&self.table_name)
.key("pk", AttributeValue::S(pk))
.send()
.await
.map_err(|e| AgentKitError::DynamoDB(e.to_string()))?;
self.cache.lock().await.remove(session_id);
Ok(())
}
}
#[cfg(feature = "dynamodb")]
fn attribute_value_to_json(av: &AttributeValue) -> serde_json::Value {
match av {
AttributeValue::S(s) => serde_json::Value::String(s.clone()),
AttributeValue::N(n) => {
if let Ok(i) = n.parse::<i64>() {
serde_json::json!(i)
} else if let Ok(f) = n.parse::<f64>() {
serde_json::json!(f)
} else {
serde_json::Value::String(n.clone())
}
}
AttributeValue::Bool(b) => serde_json::Value::Bool(*b),
AttributeValue::Null(_) => serde_json::Value::Null,
AttributeValue::L(list) => {
serde_json::Value::Array(list.iter().map(attribute_value_to_json).collect())
}
AttributeValue::M(map) => serde_json::Value::Object(
map.iter()
.map(|(k, v)| (k.clone(), attribute_value_to_json(v)))
.collect(),
),
_ => serde_json::Value::Null,
}
}
#[cfg(feature = "dynamodb")]
fn dynamo_item_from_record(
record: &SessionRecord,
) -> Result<HashMap<String, AttributeValue>> {
let json = serde_json::to_value(record)?;
let obj = json.as_object().ok_or_else(|| AgentKitError::DynamoDB("Serialization error".into()))?;
Ok(obj.iter().map(|(k, v)| (k.clone(), json_to_attribute_value(v))).collect())
}
#[cfg(feature = "dynamodb")]
fn json_to_attribute_value(v: &serde_json::Value) -> AttributeValue {
match v {
serde_json::Value::String(s) => AttributeValue::S(s.clone()),
serde_json::Value::Number(n) => AttributeValue::N(n.to_string()),
serde_json::Value::Bool(b) => AttributeValue::Bool(*b),
serde_json::Value::Null => AttributeValue::Null(true),
serde_json::Value::Array(arr) => {
AttributeValue::L(arr.iter().map(json_to_attribute_value).collect())
}
serde_json::Value::Object(map) => AttributeValue::M(
map.iter()
.map(|(k, v)| (k.clone(), json_to_attribute_value(v)))
.collect(),
),
}
}