use chrono::Utc;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::error::SessionError;
use crate::model::{DataSensitivity, SessionId, SessionStatus, TaskSession};
/// Parameters for opening a new task session.
///
/// `SessionStore::create` and `SessionStore::create_if_under_cap` copy these
/// fields into the `TaskSession` they build.
pub struct CreateSessionRequest {
/// Agent that owns the session; compared against the caller in
/// `use_session` / `use_session_batch` and used for per-agent counting.
pub agent_id: Uuid,
/// Delegation chain recorded at creation time; stored on the session as-is.
pub delegation_chain_snapshot: Vec<String>,
/// Free-text intent; stored on the session and included in the log event
/// emitted by `SessionStore::create`.
pub declared_intent: String,
/// Tool names stored on the session. Presumably the allow-list consulted by
/// `TaskSession::is_tool_authorized` — confirm in the model module.
pub authorized_tools: Vec<String>,
/// Credential identifiers stored on the session; not interpreted in this file.
// NOTE(review): this field is read when the session is built, so the
// `dead_code` allowance below may be stale — confirm before removing.
#[allow(dead_code)]
pub authorized_credentials: Vec<String>,
/// Requested session time limit. `SessionStore::create` raises values below
/// one second to one second. Presumably the lifetime checked by
/// `TaskSession::is_expired` — confirm in the model module.
pub time_limit: chrono::Duration,
/// Upper bound on `calls_made`; `use_session_batch` rejects a batch that
/// would push the count past this value.
pub call_budget: u64,
/// Optional cap on calls per rate window. When `None`, `use_session_batch`
/// performs no rate check.
pub rate_limit_per_minute: Option<u64>,
/// Length of the rate-limit window, in seconds.
pub rate_limit_window_secs: u64,
/// Sensitivity ceiling stored on the session; not interpreted in this file.
pub data_sensitivity_ceiling: DataSensitivity,
}
/// In-memory registry of task sessions keyed by `SessionId`.
///
/// The map lives behind `Arc<RwLock<..>>`, so cloning a `SessionStore` yields
/// another handle to the same underlying map rather than a copy of it.
#[derive(Clone)]
pub struct SessionStore {
// Shared session map; mutating methods take the write lock, lookups and
// listings take the read lock.
sessions: Arc<RwLock<HashMap<SessionId, TaskSession>>>,
}
impl SessionStore {
/// Creates a store that holds no sessions.
pub fn new() -> Self {
    let empty = HashMap::new();
    Self {
        sessions: Arc::new(RwLock::new(empty)),
    }
}
pub async fn create(&self, req: CreateSessionRequest) -> TaskSession {
let time_limit = if req.time_limit < chrono::Duration::seconds(1) {
tracing::warn!(
requested = ?req.time_limit,
"session time_limit below minimum, clamping to 1 second"
);
chrono::Duration::seconds(1)
} else {
req.time_limit
};
let session = TaskSession {
session_id: Uuid::new_v4(),
agent_id: req.agent_id,
delegation_chain_snapshot: req.delegation_chain_snapshot,
declared_intent: req.declared_intent,
authorized_tools: req.authorized_tools,
authorized_credentials: req.authorized_credentials,
time_limit,
call_budget: req.call_budget,
calls_made: 0,
rate_limit_per_minute: req.rate_limit_per_minute,
rate_window_start: Utc::now(),
rate_window_calls: 0,
rate_limit_window_secs: req.rate_limit_window_secs,
data_sensitivity_ceiling: req.data_sensitivity_ceiling,
created_at: Utc::now(),
status: SessionStatus::Active,
};
tracing::info!(
session_id = %session.session_id,
agent_id = %session.agent_id,
intent = %session.declared_intent,
budget = session.call_budget,
"created task session"
);
let mut sessions = self.sessions.write().await;
sessions.insert(session.session_id, session.clone());
session
}
pub async fn create_if_under_cap(
&self,
req: CreateSessionRequest,
max_sessions: u64,
) -> Result<TaskSession, SessionError> {
let mut sessions = self.sessions.write().await;
let active_count = sessions
.values()
.filter(|s| s.agent_id == req.agent_id && s.status == SessionStatus::Active)
.count() as u64;
if active_count >= max_sessions {
return Err(SessionError::TooManySessions {
agent_id: req.agent_id.to_string(),
max: max_sessions,
current: active_count,
});
}
let session = TaskSession {
session_id: Uuid::new_v4(),
agent_id: req.agent_id,
delegation_chain_snapshot: req.delegation_chain_snapshot,
declared_intent: req.declared_intent,
authorized_tools: req.authorized_tools,
authorized_credentials: req.authorized_credentials,
time_limit: req.time_limit,
call_budget: req.call_budget,
calls_made: 0,
rate_limit_per_minute: req.rate_limit_per_minute,
rate_window_start: Utc::now(),
rate_window_calls: 0,
rate_limit_window_secs: req.rate_limit_window_secs,
data_sensitivity_ceiling: req.data_sensitivity_ceiling,
created_at: Utc::now(),
status: SessionStatus::Active,
};
sessions.insert(session.session_id, session.clone());
Ok(session)
}
/// Validates and records a single tool call against a session.
///
/// Checks run in this order and the first failure is returned: session
/// lookup, owner match (only when `requesting_agent_id` is `Some`), closed,
/// expired (which also marks the session `Expired`), budget, tool
/// authorization, rate limit. On success `calls_made` is incremented and a
/// snapshot of the updated session is returned.
pub async fn use_session(
    &self,
    session_id: SessionId,
    tool_name: &str,
    requesting_agent_id: Option<Uuid>,
) -> Result<TaskSession, SessionError> {
    let mut guard = self.sessions.write().await;
    let Some(entry) = guard.get_mut(&session_id) else {
        return Err(SessionError::NotFound(session_id));
    };

    // Ownership is only enforced when the caller identifies itself.
    match requesting_agent_id {
        Some(actual) if actual != entry.agent_id => {
            return Err(SessionError::AgentMismatch {
                session_id,
                expected: entry.agent_id,
                actual,
            });
        }
        _ => {}
    }

    if entry.status == SessionStatus::Closed {
        return Err(SessionError::AlreadyClosed(session_id));
    }

    if entry.is_expired() {
        entry.status = SessionStatus::Expired;
        return Err(SessionError::Expired(session_id));
    }

    if entry.is_budget_exceeded() {
        let (limit, used) = (entry.call_budget, entry.calls_made);
        return Err(SessionError::BudgetExceeded {
            session_id,
            limit,
            used,
        });
    }

    let authorized = entry.is_tool_authorized(tool_name);
    if !authorized {
        return Err(SessionError::ToolNotAuthorized {
            session_id,
            tool: tool_name.into(),
        });
    }

    let rate_limited = entry.check_rate_limit();
    if rate_limited {
        let limit_per_minute = entry.rate_limit_per_minute.unwrap_or_default();
        return Err(SessionError::RateLimited {
            session_id,
            limit_per_minute,
        });
    }

    entry.calls_made += 1;
    tracing::debug!(
        session_id = %session_id,
        tool = tool_name,
        calls = entry.calls_made,
        budget = entry.call_budget,
        "session tool call recorded"
    );
    Ok(entry.clone())
}
/// Validates and records a batch of tool calls as one all-or-nothing step.
///
/// Checks run in this order and the first failure is returned: session
/// lookup, owner match (only when `requesting_agent_id` is `Some`), closed,
/// expired (which also marks the session `Expired`), budget for the whole
/// batch, authorization of every tool, rate limit for the whole batch.
/// Call counters are only modified after every check has passed.
///
/// The batch is always compared against the rate limit, including when it
/// opens a fresh rate window: a batch larger than the limit is rejected
/// rather than being let through because the previous window has elapsed.
///
/// An empty batch passes the checks and consumes no budget.
///
/// # Errors
///
/// `NotFound`, `AgentMismatch`, `AlreadyClosed`, `Expired`, `BudgetExceeded`,
/// `ToolNotAuthorized` or `RateLimited`, per the checks above.
pub async fn use_session_batch(
    &self,
    session_id: SessionId,
    tool_names: &[&str],
    requesting_agent_id: Option<Uuid>,
) -> Result<TaskSession, SessionError> {
    let mut sessions = self.sessions.write().await;
    let session = sessions
        .get_mut(&session_id)
        .ok_or(SessionError::NotFound(session_id))?;

    if let Some(agent_id) = requesting_agent_id
        && agent_id != session.agent_id
    {
        return Err(SessionError::AgentMismatch {
            session_id,
            expected: session.agent_id,
            actual: agent_id,
        });
    }
    if session.status == SessionStatus::Closed {
        return Err(SessionError::AlreadyClosed(session_id));
    }
    if session.is_expired() {
        session.status = SessionStatus::Expired;
        return Err(SessionError::Expired(session_id));
    }

    let batch_size = tool_names.len() as u64;

    // Saturating arithmetic: an overflowing sum must read as "over budget",
    // never wrap around to a small number.
    let projected_calls = session.calls_made.saturating_add(batch_size);
    if projected_calls > session.call_budget {
        return Err(SessionError::BudgetExceeded {
            session_id,
            limit: session.call_budget,
            used: session.calls_made,
        });
    }

    for tool_name in tool_names {
        if !session.is_tool_authorized(tool_name) {
            return Err(SessionError::ToolNotAuthorized {
                session_id,
                tool: (*tool_name).into(),
            });
        }
    }

    if let Some(limit) = session.rate_limit_per_minute {
        // A single clock reading drives both the check and the update so the
        // two can never disagree about which window the batch belongs to.
        let now = chrono::Utc::now();
        // NOTE(review): `as i64` wraps for window lengths above `i64::MAX`
        // seconds; kept as-is to match the existing behavior.
        let window = chrono::Duration::seconds(session.rate_limit_window_secs as i64);
        let window_elapsed = now - session.rate_window_start >= window;

        // Calls already counted in the window this batch will land in.
        let calls_in_window = if window_elapsed {
            0
        } else {
            session.rate_window_calls
        };
        let projected_window_calls = calls_in_window.saturating_add(batch_size);
        if projected_window_calls > limit {
            return Err(SessionError::RateLimited {
                session_id,
                limit_per_minute: limit,
            });
        }

        // Every check has passed; commit the rate-window bookkeeping.
        if window_elapsed {
            session.rate_window_start = now;
        }
        session.rate_window_calls = projected_window_calls;
    }

    session.calls_made = projected_calls;
    tracing::debug!(
        session_id = %session_id,
        batch_size = batch_size,
        calls = session.calls_made,
        budget = session.call_budget,
        "session batch tool calls recorded"
    );
    Ok(session.clone())
}
/// Marks a session `Closed` and returns a snapshot of it.
///
/// # Errors
///
/// `SessionError::NotFound` for an unknown id, `SessionError::AlreadyClosed`
/// when the stored status is already `Closed`.
pub async fn close(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
    let mut guard = self.sessions.write().await;
    match guard.get_mut(&session_id) {
        None => Err(SessionError::NotFound(session_id)),
        Some(entry) if entry.status == SessionStatus::Closed => {
            Err(SessionError::AlreadyClosed(session_id))
        }
        Some(entry) => {
            entry.status = SessionStatus::Closed;
            tracing::info!(session_id = %session_id, "session closed");
            Ok(entry.clone())
        }
    }
}
/// Returns a copy of the stored session, or `SessionError::NotFound` when
/// no session has that id.
pub async fn get(&self, session_id: SessionId) -> Result<TaskSession, SessionError> {
    let guard = self.sessions.read().await;
    match guard.get(&session_id) {
        Some(found) => Ok(found.clone()),
        None => Err(SessionError::NotFound(session_id)),
    }
}
/// Returns copies of every stored session, in the map's iteration order.
pub async fn list_all(&self) -> Vec<TaskSession> {
    let guard = self.sessions.read().await;
    let mut snapshot = Vec::with_capacity(guard.len());
    snapshot.extend(guard.values().cloned());
    snapshot
}
/// Returns copies of every stored session owned by `agent_id`, regardless
/// of status.
pub async fn list_for_agent(&self, agent_id: Uuid) -> Vec<TaskSession> {
    let guard = self.sessions.read().await;
    let mut owned = Vec::new();
    for candidate in guard.values() {
        if candidate.agent_id == agent_id {
            owned.push(candidate.clone());
        }
    }
    owned
}
/// Counts the sessions owned by `agent_id` whose stored status is `Active`.
///
/// Only the stored status is consulted; `is_expired` is not evaluated here.
pub async fn count_active_for_agent(&self, agent_id: uuid::Uuid) -> u64 {
    let guard = self.sessions.read().await;
    let mut active = 0u64;
    for candidate in guard.values() {
        if candidate.agent_id == agent_id && candidate.status == SessionStatus::Active {
            active += 1;
        }
    }
    active
}
/// Marks every `Active` session owned by `agent_id` as `Closed` and returns
/// how many sessions were changed. Sessions in any other status are left
/// untouched.
pub async fn close_sessions_for_agent(&self, agent_id: uuid::Uuid) -> usize {
    let mut guard = self.sessions.write().await;
    guard
        .values_mut()
        .filter(|s| s.agent_id == agent_id && s.status == SessionStatus::Active)
        .fold(0usize, |closed_so_far, session| {
            session.status = SessionStatus::Closed;
            tracing::info!(
                session_id = %session.session_id,
                agent_id = %agent_id,
                "closed session due to agent deactivation"
            );
            closed_so_far + 1
        })
}
/// Removes every session that reports `is_expired()` or whose status is
/// `Closed`, and returns the number of sessions removed.
pub async fn cleanup_expired(&self) -> usize {
    let mut guard = self.sessions.write().await;
    let initial = guard.len();

    guard.retain(|_, session| {
        // Expiry is checked first, so an expired-and-closed session is
        // logged as expired.
        if session.is_expired() {
            tracing::debug!(session_id = %session.session_id, "cleaning up expired session");
            return false;
        }
        let closed = session.status == SessionStatus::Closed;
        if closed {
            tracing::debug!(session_id = %session.session_id, "cleaning up closed session");
        }
        !closed
    });

    let removed = initial - guard.len();
    if removed != 0 {
        tracing::info!(removed, "cleaned up expired/closed sessions");
    }
    removed
}
}
impl Default for SessionStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Baseline request shared by the tests: a fresh random agent, two
/// authorized tools, a one-hour limit, a budget of five calls, no rate limit.
fn test_create_request() -> CreateSessionRequest {
    let authorized_tools = vec![String::from("read_file"), String::from("list_dir")];
    CreateSessionRequest {
        agent_id: Uuid::new_v4(),
        delegation_chain_snapshot: Vec::new(),
        declared_intent: String::from("read and analyze files"),
        authorized_tools,
        authorized_credentials: Vec::new(),
        time_limit: chrono::Duration::hours(1),
        call_budget: 5,
        rate_limit_per_minute: None,
        rate_limit_window_secs: 60,
        data_sensitivity_ceiling: DataSensitivity::Internal,
    }
}
/// A new session starts active with zero calls; one authorized call moves
/// the counter to one.
#[tokio::test]
async fn create_and_use_session() {
    let store = SessionStore::new();
    let created = store.create(test_create_request()).await;
    assert_eq!(created.calls_made, 0);
    assert!(created.is_active());

    let after_call = store
        .use_session(created.session_id, "read_file", None)
        .await
        .unwrap();
    assert_eq!(after_call.calls_made, 1);
}
#[tokio::test]
async fn budget_enforcement() {
let store = SessionStore::new();
let mut req = test_create_request();
req.call_budget = 2;
let session = store.create(req).await;
store
.use_session(session.session_id, "read_file", None)
.await
.unwrap();
store
.use_session(session.session_id, "read_file", None)
.await
.unwrap();
let result = store
.use_session(session.session_id, "read_file", None)
.await;
assert!(matches!(result, Err(SessionError::BudgetExceeded { .. })));
}
/// A listed tool is accepted; a tool missing from `authorized_tools` is
/// rejected with `ToolNotAuthorized`.
#[tokio::test]
async fn tool_whitelist_enforcement() {
    let store = SessionStore::new();
    let sid = store.create(test_create_request()).await.session_id;

    store.use_session(sid, "read_file", None).await.unwrap();

    let denied = store.use_session(sid, "delete_file", None).await;
    assert!(matches!(
        denied,
        Err(SessionError::ToolNotAuthorized { .. })
    ));
}
#[tokio::test]
async fn session_expiry() {
let store = SessionStore::new();
let mut req = test_create_request();
req.time_limit = chrono::Duration::seconds(1);
let session = store.create(req).await;
tokio::time::sleep(std::time::Duration::from_millis(1100)).await;
let result = store
.use_session(session.session_id, "read_file", None)
.await;
assert!(matches!(result, Err(SessionError::Expired(_))));
}
/// Once closed, a session rejects further calls with `AlreadyClosed`.
#[tokio::test]
async fn close_and_reuse() {
    let store = SessionStore::new();
    let sid = store.create(test_create_request()).await.session_id;
    store.close(sid).await.unwrap();

    let after_close = store.use_session(sid, "read_file", None).await;
    assert!(matches!(after_close, Err(SessionError::AlreadyClosed(_))));
}
#[tokio::test]
async fn cleanup_expired_sessions() {
let store = SessionStore::new();
let mut req = test_create_request();
req.time_limit = chrono::Duration::seconds(1);
store.create(req).await;
let valid_req = test_create_request();
store.create(valid_req).await;
tokio::time::sleep(std::time::Duration::from_millis(1100)).await;
let removed = store.cleanup_expired().await;
assert_eq!(removed, 1);
}
/// Using an id that was never stored yields `NotFound`.
#[tokio::test]
async fn session_not_found() {
    let store = SessionStore::new();
    let unknown_id = Uuid::new_v4();

    let lookup = store.use_session(unknown_id, "anything", None).await;
    assert!(matches!(lookup, Err(SessionError::NotFound(_))));
}
#[tokio::test]
async fn rate_limit_enforcement() {
let store = SessionStore::new();
let mut req = test_create_request();
req.rate_limit_per_minute = Some(3);
req.call_budget = 100; let session = store.create(req).await;
store
.use_session(session.session_id, "read_file", None)
.await
.unwrap();
store
.use_session(session.session_id, "read_file", None)
.await
.unwrap();
store
.use_session(session.session_id, "read_file", None)
.await
.unwrap();
let result = store
.use_session(session.session_id, "read_file", None)
.await;
assert!(
matches!(result, Err(SessionError::RateLimited { .. })),
"expected RateLimited, got {result:?}"
);
}
#[tokio::test]
async fn no_rate_limit_when_unset() {
let store = SessionStore::new();
let mut req = test_create_request();
req.rate_limit_per_minute = None;
req.call_budget = 100;
let session = store.create(req).await;
for _ in 0..10 {
store
.use_session(session.session_id, "read_file", None)
.await
.unwrap();
}
}
#[tokio::test]
async fn batch_validation_atomicity() {
let store = SessionStore::new();
let mut req = test_create_request();
req.call_budget = 10;
req.authorized_tools = vec!["read_file".into(), "list_dir".into()];
let session = store.create(req).await;
let result = store
.use_session_batch(session.session_id, &["read_file", "delete_file"], None)
.await;
assert!(
matches!(result, Err(SessionError::ToolNotAuthorized { .. })),
"expected ToolNotAuthorized, got {result:?}"
);
let s = store.get(session.session_id).await.unwrap();
assert_eq!(
s.calls_made, 0,
"no budget should be consumed on batch failure"
);
}
#[tokio::test]
async fn batch_budget_enforcement() {
let store = SessionStore::new();
let mut req = test_create_request();
req.call_budget = 3;
req.authorized_tools = vec!["read_file".into()];
let session = store.create(req).await;
let result = store
.use_session_batch(
session.session_id,
&["read_file", "read_file", "read_file", "read_file"],
None,
)
.await;
assert!(
matches!(result, Err(SessionError::BudgetExceeded { .. })),
"expected BudgetExceeded, got {result:?}"
);
let s = store.get(session.session_id).await.unwrap();
assert_eq!(
s.calls_made, 0,
"no budget should be consumed on batch failure"
);
}
#[tokio::test]
async fn batch_rate_limit_enforcement() {
let store = SessionStore::new();
let mut req = test_create_request();
req.call_budget = 100;
req.rate_limit_per_minute = Some(3);
req.authorized_tools = vec!["read_file".into()];
let session = store.create(req).await;
let result = store
.use_session_batch(
session.session_id,
&["read_file", "read_file", "read_file", "read_file"],
None,
)
.await;
assert!(
matches!(result, Err(SessionError::RateLimited { .. })),
"expected RateLimited, got {result:?}"
);
}
/// An empty batch is accepted and leaves `calls_made` at zero.
#[tokio::test]
async fn empty_batch_succeeds() {
    let store = SessionStore::new();
    let sid = store.create(test_create_request()).await.session_id;

    let no_tools: [&str; 0] = [];
    let after = store
        .use_session_batch(sid, &no_tools, None)
        .await
        .unwrap();
    assert_eq!(after.calls_made, 0, "empty batch must not consume budget");
}
/// Cleanup removes closed sessions too, after which the id is no longer
/// found.
#[tokio::test]
async fn cleanup_also_removes_closed() {
    let store = SessionStore::new();
    let sid = store.create(test_create_request()).await.session_id;
    store.close(sid).await.unwrap();

    let removed = store.cleanup_expired().await;
    assert_eq!(removed, 1, "closed session should be cleaned up");

    let lookup = store.get(sid).await;
    assert!(
        matches!(lookup, Err(SessionError::NotFound(_))),
        "closed session should be removed after cleanup"
    );
}
#[tokio::test]
async fn zero_budget_session() {
let store = SessionStore::new();
let mut req = test_create_request();
req.call_budget = 0;
let session = store.create(req).await;
let result = store
.use_session(session.session_id, "read_file", None)
.await;
assert!(
matches!(result, Err(SessionError::BudgetExceeded { .. })),
"zero-budget session must reject the first call, got {result:?}"
);
}
#[tokio::test]
async fn deactivation_closes_agent_sessions() {
let store = SessionStore::new();
let agent_id = Uuid::new_v4();
let other_agent = Uuid::new_v4();
for _ in 0..3 {
let mut req = test_create_request();
req.agent_id = agent_id;
store.create(req).await;
}
let mut other_req = test_create_request();
other_req.agent_id = other_agent;
let other_session = store.create(other_req).await;
let closed = store.close_sessions_for_agent(agent_id).await;
assert_eq!(closed, 3);
let all = store.list_all().await;
for s in &all {
if s.agent_id == agent_id {
assert_eq!(s.status, SessionStatus::Closed);
}
}
let other = store.get(other_session.session_id).await.unwrap();
assert_eq!(other.status, SessionStatus::Active);
}
#[tokio::test]
async fn concurrent_budget_enforcement() {
let store = SessionStore::new();
let mut req = test_create_request();
req.call_budget = 5;
req.authorized_tools = vec!["read_file".into()];
let session = store.create(req).await;
let successes = Arc::new(std::sync::atomic::AtomicU64::new(0));
let failures = Arc::new(std::sync::atomic::AtomicU64::new(0));
let mut handles = Vec::new();
for _ in 0..10 {
let store = store.clone();
let sid = session.session_id;
let s = successes.clone();
let f = failures.clone();
handles.push(tokio::spawn(async move {
match store.use_session(sid, "read_file", None).await {
Ok(_) => {
s.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
Err(SessionError::BudgetExceeded { .. }) => {
f.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
Err(e) => panic!("unexpected error: {e:?}"),
}
}));
}
for h in handles {
h.await.unwrap();
}
assert_eq!(
successes.load(std::sync::atomic::Ordering::Relaxed),
5,
"exactly 5 calls should succeed"
);
assert_eq!(
failures.load(std::sync::atomic::Ordering::Relaxed),
5,
"exactly 5 calls should fail with BudgetExceeded"
);
}
/// A caller presenting a different agent id is rejected with
/// `AgentMismatch`; the owning agent is accepted.
#[tokio::test]
async fn agent_mismatch_rejected() {
    let store = SessionStore::new();
    let session = store.create(test_create_request()).await;
    let sid = session.session_id;
    let owner_id = session.agent_id;
    let attacker_id = Uuid::new_v4();

    let result = store.use_session(sid, "read_file", Some(attacker_id)).await;
    assert!(
        matches!(result, Err(SessionError::AgentMismatch { .. })),
        "different agent must be rejected, got {result:?}"
    );

    let owner_call = store.use_session(sid, "read_file", Some(owner_id)).await;
    assert!(owner_call.is_ok(), "session owner should succeed");
}
/// The batch path also rejects a caller presenting a different agent id.
#[tokio::test]
async fn batch_agent_mismatch_rejected() {
    let store = SessionStore::new();
    let sid = store.create(test_create_request()).await.session_id;
    let attacker_id = Uuid::new_v4();

    let batch = ["read_file"];
    let result = store
        .use_session_batch(sid, &batch, Some(attacker_id))
        .await;
    assert!(
        matches!(result, Err(SessionError::AgentMismatch { .. })),
        "batch with wrong agent must be rejected, got {result:?}"
    );
}
}