use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::response::IntoResponse;
use axum::Json;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, RwLock};
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt;
use crate::error_response::{api_error, ErrorCode};
/// Lifecycle state of an agent session.
///
/// Serialized as a snake_case string, e.g. `"running"` or `"awaiting_approval"`.
/// `SessionRegistry::cancel_session` only accepts sessions that are `Running`
/// or `AwaitingApproval`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionStatus {
Running,
AwaitingApproval,
Completed,
Failed,
Cancelled,
}
/// Per-kind tallies of the agent events recorded for a session.
///
/// `SessionRegistry::emit_event` increments the field matching the
/// `AgentEventData` variant it receives. This struct has no `rename_all`, so
/// its fields serialize in snake_case even when nested inside the camelCase
/// `AgentSession`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EventCounts {
/// Number of `AgentEventData::Thought` events.
pub thoughts: u32,
/// Number of `AgentEventData::ToolCall` events.
pub tool_calls: u32,
/// Number of `AgentEventData::ToolResult` events.
pub tool_results: u32,
/// Number of `AgentEventData::Error` events.
pub errors: u32,
}
/// Metadata for a single agent run, as stored by `SessionRegistry` and
/// returned from the HTTP handlers.
///
/// Serialized with camelCase field names (`createdAt`, `maxIterations`, ...).
/// Optional fields that are `None` are omitted from the output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentSession {
/// Unique session id; `SessionRegistry::generate_id` produces `sess_`-prefixed ids.
pub id: String,
/// Optional human-readable label.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// The objective the session was created with.
pub objective: String,
/// Current lifecycle state.
pub status: SessionStatus,
/// Tool names attached to the session (set via `with_tools`).
pub tools: Vec<String>,
/// Optional working directory. NOTE(review): not read anywhere in this file;
/// presumably where the agent's tools operate — confirm with the caller.
#[serde(skip_serializing_if = "Option::is_none")]
pub working_dir: Option<String>,
/// Creation time in milliseconds since the Unix epoch.
pub created_at: u64,
/// Time of the most recent mutation, in milliseconds since the Unix epoch.
pub updated_at: u64,
/// Current iteration number, updated via `SessionRegistry::set_iteration`.
pub iteration: u32,
/// Iteration budget (10 by default). Not enforced by any code in this file.
pub max_iterations: u32,
/// Counts of emitted agent events, by kind.
pub event_counts: EventCounts,
/// Optional client identifier. NOTE(review): not read in this file; presumably
/// identifies the client that started the session — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
}
impl AgentSession {
    /// Iteration budget assigned by [`AgentSession::new`] when none is given.
    pub const DEFAULT_MAX_ITERATIONS: u32 = 10;

    /// Creates a `Running` session with the given id and objective.
    ///
    /// `created_at` and `updated_at` are both set to the current time, the
    /// iteration counter starts at 0, `max_iterations` is
    /// [`Self::DEFAULT_MAX_ITERATIONS`], the tool list is empty and every
    /// optional field starts as `None`.
    pub fn new(id: String, objective: String) -> Self {
        let now = now_ms();
        Self {
            id,
            name: None,
            objective,
            status: SessionStatus::Running,
            tools: Vec::new(),
            working_dir: None,
            created_at: now,
            updated_at: now,
            iteration: 0,
            max_iterations: Self::DEFAULT_MAX_ITERATIONS,
            event_counts: EventCounts::default(),
            client_id: None,
        }
    }

    /// Sets a human-readable name for the session.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Replaces the list of tool names attached to the session.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<String>) -> Self {
        self.tools = tools;
        self
    }

    /// Overrides the iteration budget (default [`Self::DEFAULT_MAX_ITERATIONS`]).
    #[must_use]
    pub fn with_max_iterations(mut self, max: u32) -> Self {
        self.max_iterations = max;
        self
    }

    /// Sets the session's working directory.
    #[must_use]
    pub fn with_working_dir(mut self, working_dir: impl Into<String>) -> Self {
        self.working_dir = Some(working_dir.into());
        self
    }

    /// Records the identifier of the client associated with the session.
    #[must_use]
    pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
        self.client_id = Some(client_id.into());
        self
    }
}
/// A single step reported by an agent while a session runs.
///
/// Internally tagged: serialization writes the variant name to a `"type"`
/// field in snake_case (`"thought"`, `"tool_call"`, `"tool_result"`,
/// `"error"`) next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEventData {
/// Text content produced by the agent.
Thought {
content: String,
},
/// A tool invocation: call `id`, tool `name`, and its arbitrary JSON `input`.
ToolCall {
id: String,
name: String,
input: serde_json::Value,
},
/// Outcome of a tool invocation, with arbitrary JSON `output` and a
/// `success` flag. NOTE(review): `id` presumably matches the `id` of the
/// corresponding `ToolCall` — confirm with the emitter.
ToolResult {
id: String,
output: serde_json::Value,
success: bool,
},
/// An error reported by the agent, described by `message`.
Error {
message: String,
},
}
/// Notification broadcast by `SessionRegistry` to its subscribers and
/// forwarded to clients by the SSE handlers.
///
/// Internally tagged: serialization writes the variant name to an `"event"`
/// field in snake_case. `rename_all` only renames variants, so fields inside
/// the variants stay snake_case (e.g. `session_id`), whereas an embedded
/// `AgentSession` serializes its own fields in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event", rename_all = "snake_case")]
pub enum SessionEvent {
/// A session was registered; carries a full snapshot of it.
SessionStarted {
session: AgentSession,
},
/// Snapshot of a session's status, iteration and event counts, e.g. as sent
/// by `SessionRegistry::emit_event` after bumping a counter.
SessionUpdated {
session_id: String,
status: SessionStatus,
iteration: u32,
event_counts: EventCounts,
},
/// An agent event for `session_id`. The fields of `data`, including its
/// `"type"` tag, are flattened into this object next to `session_id`.
AgentEvent {
session_id: String,
#[serde(flatten)]
data: AgentEventData,
},
/// A session reached its final status. `duration_ms` is the saturating
/// difference `updated_at - created_at` at the moment it ended;
/// `final_answer` serializes as `null` when absent.
SessionEnded {
session_id: String,
status: SessionStatus,
duration_ms: u64,
final_answer: Option<String>,
},
}
/// In-memory store of agent sessions plus a broadcast channel that fans
/// `SessionEvent`s out to all subscribers (the SSE handlers in this file
/// attach via `subscribe`).
pub struct SessionRegistry {
/// Sessions keyed by session id. No method shown in this file removes
/// entries. NOTE(review): unless code elsewhere prunes it, this map grows for
/// the life of the process.
sessions: RwLock<HashMap<String, AgentSession>>,
/// Sending half of the event channel; receivers come from `subscribe`.
event_tx: broadcast::Sender<SessionEvent>,
}
impl Default for SessionRegistry {
fn default() -> Self {
Self::new()
}
}
impl SessionRegistry {
pub fn new() -> Self {
let (event_tx, _) = broadcast::channel(1024);
Self {
sessions: RwLock::new(HashMap::new()),
event_tx,
}
}
pub async fn register_session(&self, session: AgentSession) {
let id = session.id.clone();
self.sessions.write().await.insert(id, session.clone());
let _ = self.event_tx.send(SessionEvent::SessionStarted { session });
}
pub async fn emit_event(&self, session_id: &str, event: AgentEventData) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
match &event {
AgentEventData::Thought { .. } => session.event_counts.thoughts += 1,
AgentEventData::ToolCall { .. } => session.event_counts.tool_calls += 1,
AgentEventData::ToolResult { .. } => session.event_counts.tool_results += 1,
AgentEventData::Error { .. } => session.event_counts.errors += 1,
}
session.updated_at = now_ms();
let _ = self.event_tx.send(SessionEvent::SessionUpdated {
session_id: session_id.to_string(),
status: session.status,
iteration: session.iteration,
event_counts: session.event_counts.clone(),
});
}
let _ = self.event_tx.send(SessionEvent::AgentEvent {
session_id: session_id.to_string(),
data: event,
});
}
pub async fn set_iteration(&self, session_id: &str, iteration: u32) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.iteration = iteration;
session.updated_at = now_ms();
}
}
pub async fn set_status(&self, session_id: &str, status: SessionStatus) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.status = status;
session.updated_at = now_ms();
}
}
pub async fn end_session(
&self,
session_id: &str,
status: SessionStatus,
final_answer: Option<String>,
) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.status = status;
session.updated_at = now_ms();
let duration_ms = session.updated_at.saturating_sub(session.created_at);
let _ = self.event_tx.send(SessionEvent::SessionEnded {
session_id: session_id.to_string(),
status,
duration_ms,
final_answer,
});
}
}
pub async fn get_session(&self, session_id: &str) -> Option<AgentSession> {
self.sessions.read().await.get(session_id).cloned()
}
pub async fn list_sessions(&self) -> Vec<AgentSession> {
self.sessions.read().await.values().cloned().collect()
}
pub fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
self.event_tx.subscribe()
}
pub async fn cancel_session(&self, session_id: &str) -> Result<(), &'static str> {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
if session.status != SessionStatus::Running
&& session.status != SessionStatus::AwaitingApproval
{
return Err("Session is not running");
}
session.status = SessionStatus::Cancelled;
session.updated_at = now_ms();
let duration_ms = session.updated_at.saturating_sub(session.created_at);
let _ = self.event_tx.send(SessionEvent::SessionEnded {
session_id: session_id.to_string(),
status: SessionStatus::Cancelled,
duration_ms,
final_answer: None,
});
Ok(())
} else {
Err("Session not found")
}
}
pub fn generate_id() -> String {
format!(
"sess_{}",
uuid::Uuid::new_v4().simple().to_string()[..12].to_string()
)
}
}
/// Query parameters accepted by the `list_sessions` handler.
#[derive(Debug, Deserialize)]
pub struct ListSessionsQuery {
/// Optional status filter, spelled like the serialized `SessionStatus`
/// (e.g. `running`, `awaiting_approval`).
pub status: Option<String>,
/// Maximum number of sessions to return; defaults to `default_limit()`.
#[serde(default = "default_limit")]
pub limit: usize,
/// Number of matching sessions to skip; defaults to 0.
#[serde(default)]
pub offset: usize,
}
/// Page size used when a list request omits `limit`.
const DEFAULT_PAGE_SIZE: usize = 50;

/// Serde default for `ListSessionsQuery::limit`.
fn default_limit() -> usize {
    DEFAULT_PAGE_SIZE
}
/// Body returned by `list_sessions`.
#[derive(Debug, Serialize)]
pub struct ListSessionsResponse {
/// The requested page of sessions.
pub sessions: Vec<AgentSession>,
/// Number of sessions matching the status filter, before `offset`/`limit`.
pub total: usize,
}
/// Body returned by `get_session`; serialized with camelCase keys
/// (`session`, `recentEvents`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetSessionResponse {
/// Snapshot of the requested session.
pub session: AgentSession,
/// Recent events for the session. The `get_session` handler currently
/// always returns this empty.
pub recent_events: Vec<serde_json::Value>,
}
/// Body returned by a successful `cancel_session` request. Fields keep their
/// snake_case names (`status`, `session_id`).
#[derive(Debug, Serialize)]
pub struct CancelResponse {
/// Always `"cancelled"` when produced by the `cancel_session` handler.
pub status: String,
/// Id of the session that was cancelled.
pub session_id: String,
}
/// Query parameters for the multi-session SSE stream.
#[derive(Debug, Deserialize)]
pub struct StreamQuery {
/// Comma-separated session ids to include; whitespace around each id is
/// trimmed. Absent means all sessions.
pub session_ids: Option<String>,
/// Comma-separated event names to include (`session_started`,
/// `session_updated`, `agent_event`, `session_ended`). Absent means all.
pub event_types: Option<String>,
}
pub async fn list_sessions(
State(registry): State<Arc<SessionRegistry>>,
Query(query): Query<ListSessionsQuery>,
) -> impl IntoResponse {
let mut sessions = registry.list_sessions().await;
if let Some(status_str) = &query.status {
let filter_status = match status_str.as_str() {
"running" => Some(SessionStatus::Running),
"awaiting_approval" => Some(SessionStatus::AwaitingApproval),
"completed" => Some(SessionStatus::Completed),
"failed" => Some(SessionStatus::Failed),
"cancelled" => Some(SessionStatus::Cancelled),
_ => None,
};
if let Some(status) = filter_status {
sessions.retain(|s| s.status == status);
}
}
let total = sessions.len();
let sessions: Vec<_> = sessions
.into_iter()
.skip(query.offset)
.take(query.limit)
.collect();
Json(ListSessionsResponse { sessions, total })
}
pub async fn get_session(
State(registry): State<Arc<SessionRegistry>>,
Path(session_id): Path<String>,
) -> impl IntoResponse {
match registry.get_session(&session_id).await {
Some(session) => Json(GetSessionResponse {
session,
recent_events: Vec::new(), })
.into_response(),
None => (
StatusCode::NOT_FOUND,
Json(api_error(ErrorCode::NotFound, "Session not found")),
)
.into_response(),
}
}
pub async fn sessions_stream(
State(registry): State<Arc<SessionRegistry>>,
Query(query): Query<StreamQuery>,
) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
let rx = registry.subscribe();
let stream = BroadcastStream::new(rx);
let session_filter: Option<Vec<String>> = query
.session_ids
.map(|s| s.split(',').map(|s| s.trim().to_string()).collect());
let event_filter: Option<Vec<String>> = query
.event_types
.map(|s| s.split(',').map(|s| s.trim().to_string()).collect());
let sse_stream = stream.filter_map(move |result| {
let session_filter = session_filter.clone();
let event_filter = event_filter.clone();
match result {
Ok(event) => {
let session_id = match &event {
SessionEvent::SessionStarted { session } => session.id.clone(),
SessionEvent::SessionUpdated { session_id, .. } => session_id.clone(),
SessionEvent::AgentEvent { session_id, .. } => session_id.clone(),
SessionEvent::SessionEnded { session_id, .. } => session_id.clone(),
};
let event_type = match &event {
SessionEvent::SessionStarted { .. } => "session_started",
SessionEvent::SessionUpdated { .. } => "session_updated",
SessionEvent::AgentEvent { .. } => "agent_event",
SessionEvent::SessionEnded { .. } => "session_ended",
};
let session_ok = session_filter
.as_ref()
.map(|f| f.contains(&session_id))
.unwrap_or(true);
let event_ok = event_filter
.as_ref()
.map(|f| f.iter().any(|t| t == event_type))
.unwrap_or(true);
if session_ok && event_ok {
serde_json::to_string(&event).ok().map(|data| {
Ok::<_, std::convert::Infallible>(
Event::default().event(event_type).data(data),
)
})
} else {
None
}
},
Err(_) => None,
}
});
Sse::new(sse_stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("ping"),
)
}
pub async fn session_stream(
State(registry): State<Arc<SessionRegistry>>,
Path(session_id): Path<String>,
) -> impl IntoResponse {
if registry.get_session(&session_id).await.is_none() {
return (
StatusCode::NOT_FOUND,
Json(api_error(ErrorCode::NotFound, "Session not found")),
)
.into_response();
}
let rx = registry.subscribe();
let stream = BroadcastStream::new(rx);
let filter_id = session_id.clone();
let sse_stream = stream.filter_map(move |result| {
let filter_id = filter_id.clone();
match result {
Ok(event) => {
let event_session_id = match &event {
SessionEvent::SessionStarted { session } => session.id.clone(),
SessionEvent::SessionUpdated { session_id, .. } => session_id.clone(),
SessionEvent::AgentEvent { session_id, .. } => session_id.clone(),
SessionEvent::SessionEnded { session_id, .. } => session_id.clone(),
};
if event_session_id != filter_id {
None
} else {
let event_type = match &event {
SessionEvent::SessionStarted { .. } => "session_started",
SessionEvent::SessionUpdated { .. } => "session_updated",
SessionEvent::AgentEvent { .. } => "agent_event",
SessionEvent::SessionEnded { .. } => "session_ended",
};
serde_json::to_string(&event).ok().map(|data| {
Ok::<_, std::convert::Infallible>(
Event::default().event(event_type).data(data),
)
})
}
},
Err(_) => None,
}
});
Sse::new(sse_stream)
.keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("ping"),
)
.into_response()
}
/// `POST`-style handler that cancels a running session.
///
/// On success returns `{"status": "cancelled", "session_id": ...}`. Registry
/// errors map to HTTP responses as follows:
/// * `"Session not found"` → `404` with `NotFound`.
/// * `"Session is not running"` → `400` with `InvalidRequest`.
/// * Any other message → `500` with `InternalError`.
pub async fn cancel_session(
    State(registry): State<Arc<SessionRegistry>>,
    Path(session_id): Path<String>,
) -> impl IntoResponse {
    let failure = match registry.cancel_session(&session_id).await {
        Ok(()) => {
            let body = CancelResponse {
                status: "cancelled".to_string(),
                session_id,
            };
            return Json(body).into_response();
        }
        Err(failure) => failure,
    };
    let (http_status, code) = match failure {
        "Session not found" => (StatusCode::NOT_FOUND, ErrorCode::NotFound),
        "Session is not running" => (StatusCode::BAD_REQUEST, ErrorCode::InvalidRequest),
        _ => (StatusCode::INTERNAL_SERVER_ERROR, ErrorCode::InternalError),
    };
    (http_status, Json(api_error(code, failure))).into_response()
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a fresh `Running` session.
    fn new_session(id: &str, objective: &str) -> AgentSession {
        AgentSession::new(id.to_string(), objective.to_string())
    }

    /// Builds a registry pre-populated with `sessions`.
    async fn registry_with(sessions: Vec<AgentSession>) -> SessionRegistry {
        let registry = SessionRegistry::new();
        for session in sessions {
            registry.register_session(session).await;
        }
        registry
    }

    #[test]
    fn test_session_status_serialization() {
        let expectations = [
            (SessionStatus::Running, "\"running\""),
            (SessionStatus::AwaitingApproval, "\"awaiting_approval\""),
        ];
        for (status, json) in expectations.iter() {
            assert_eq!(serde_json::to_string(status).unwrap(), *json);
        }
    }

    #[test]
    fn test_agent_session_new() {
        let session = new_session("sess_test", "Test objective");
        assert_eq!(session.id, "sess_test");
        assert_eq!(session.objective, "Test objective");
        assert_eq!(session.status, SessionStatus::Running);
        assert_eq!(session.iteration, 0);
    }

    #[test]
    fn test_agent_session_builder() {
        let session = new_session("sess_test", "Test")
            .with_name("Test Session")
            .with_tools(vec!["file_read".to_string(), "grep".to_string()])
            .with_max_iterations(20);
        assert_eq!(session.name.as_deref(), Some("Test Session"));
        assert_eq!(session.tools.len(), 2);
        assert_eq!(session.max_iterations, 20);
    }

    #[test]
    fn test_event_counts_default() {
        let counts = EventCounts::default();
        let all = [
            counts.thoughts,
            counts.tool_calls,
            counts.tool_results,
            counts.errors,
        ];
        assert!(all.iter().all(|&n| n == 0));
    }

    #[test]
    fn test_session_registry_generate_id() {
        let first = SessionRegistry::generate_id();
        let second = SessionRegistry::generate_id();
        for id in [&first, &second].iter() {
            assert!(id.starts_with("sess_"));
        }
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn test_session_registry_register() {
        let registry = registry_with(vec![new_session("sess_test", "Test objective")]).await;
        match registry.get_session("sess_test").await {
            Some(found) => assert_eq!(found.objective, "Test objective"),
            None => panic!("session was not registered"),
        }
    }

    #[tokio::test]
    async fn test_session_registry_emit_event() {
        let registry = registry_with(vec![new_session("sess_test", "Test")]).await;
        let thought = AgentEventData::Thought {
            content: "Thinking...".to_string(),
        };
        registry.emit_event("sess_test", thought).await;
        let counts = registry.get_session("sess_test").await.unwrap().event_counts;
        assert_eq!(counts.thoughts, 1);
    }

    #[tokio::test]
    async fn test_session_registry_cancel() {
        let registry = registry_with(vec![new_session("sess_test", "Test")]).await;
        assert!(registry.cancel_session("sess_test").await.is_ok());
        let status = registry.get_session("sess_test").await.unwrap().status;
        assert_eq!(status, SessionStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_session_registry_cancel_not_found() {
        let registry = registry_with(Vec::new()).await;
        assert_eq!(
            registry.cancel_session("nonexistent").await,
            Err("Session not found")
        );
    }

    #[tokio::test]
    async fn test_session_registry_list() {
        let registry = registry_with(vec![
            new_session("sess_1", "Task 1"),
            new_session("sess_2", "Task 2"),
        ])
        .await;
        assert_eq!(registry.list_sessions().await.len(), 2);
    }
}