use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::types::{AgentId, AgentState, Message, MessageRole, SessionId, ToolCall, ToolResult};
/// Settings from which a [`Session`] is built via `Session::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
/// Model identifier; copied into `Session::model`.
pub model: String,
/// Temperature value; copied into `Session::temperature`.
pub temperature: f32,
/// Iteration limit; copied into `Session::max_iterations` and compared against
/// `Session::iteration` by `Session::can_continue`.
pub max_iterations: i32,
/// Optional system prompt; `None` is stored as an empty string in `Session::system_prompt`.
pub system_prompt: Option<String>,
/// Cap on the number of messages a session keeps; falls back to `default_max_messages()`
/// (1000) when the field is missing from serialized input.
#[serde(default = "default_max_messages")]
pub max_messages: usize,
/// Cap on the number of registered tool names a session keeps; falls back to
/// `default_max_tools()` (100) when the field is missing from serialized input.
#[serde(default = "default_max_tools")]
pub max_tools: usize,
}
/// Default cap on stored messages.
///
/// Used as the serde fallback for `max_messages` (`#[serde(default = "default_max_messages")]`)
/// and by `SessionConfig::default()`, so both paths always agree.
fn default_max_messages() -> usize {
    1000
}

/// Default cap on registered tool names.
///
/// Used as the serde fallback for `max_tools` (`#[serde(default = "default_max_tools")]`)
/// and by `SessionConfig::default()`, so both paths always agree.
fn default_max_tools() -> usize {
    100
}

impl Default for SessionConfig {
    /// Returns the baseline configuration: model `claude-sonnet-4-6`, temperature `0.7`,
    /// `100` iterations, no system prompt, and the same message/tool caps that serde
    /// falls back to when those fields are absent.
    fn default() -> Self {
        Self {
            model: "claude-sonnet-4-6".to_string(),
            temperature: 0.7,
            max_iterations: 100,
            system_prompt: None,
            // Reuse the serde fallbacks instead of repeating the literals, so a config built
            // in code and one deserialized without these fields cannot drift apart.
            max_messages: default_max_messages(),
            max_tools: default_max_tools(),
        }
    }
}
/// Mutable, serializable state of one agent conversation, created by `Session::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
/// Session identifier, obtained from `SessionId::new()` at creation.
pub session_id: SessionId,
/// Agent identifier, obtained from `AgentId::new()` at creation.
pub agent_id: AgentId,
/// Current agent state; starts as `Idle`. `can_continue` only returns true for
/// `Running` or `Idle`.
pub state: AgentState,
/// Starts at 0 and is incremented once per `add_user_message` call.
pub iteration: i32,
/// Iteration limit copied from `SessionConfig::max_iterations`.
pub max_iterations: i32,
/// Conversation history. The `add_*_message` methods trim it to `max_messages`,
/// keeping a leading system-role message.
pub messages: Vec<Message>,
/// Registered tool names in insertion order. `register_tool` skips names already
/// present and evicts the oldest name when the list is at `max_tools`.
pub tools_registered: Vec<String>,
/// Pending tool calls; only initialized (empty) in `Session::new`.
/// NOTE(review): presumably maintained by callers — confirm.
pub tool_calls_pending: Vec<ToolCall>,
/// Cached tool results; only initialized (empty) in `Session::new`.
/// NOTE(review): presumably maintained by callers — confirm.
pub tool_results_cache: Vec<ToolResult>,
/// Model identifier copied from `SessionConfig::model`.
pub model: String,
/// Temperature copied from `SessionConfig::temperature`.
pub temperature: f32,
/// System prompt from `SessionConfig::system_prompt`, or empty when none was given.
pub system_prompt: String,
/// Token counters; initialized to 0 in `Session::new`.
pub tokens_total: i64,
pub tokens_prompt: i64,
pub tokens_completion: i64,
/// Cost estimate; initialized to 0.0 in `Session::new`. NOTE(review): unit/currency is
/// not established here — confirm with callers.
pub cost_estimate: f64,
/// `Utc::now()` at creation.
pub created_at: DateTime<Utc>,
/// Refreshed to `Utc::now()` by `touch`.
pub last_updated: DateTime<Utc>,
/// Checkpoint counter; initialized to 0 in `Session::new`.
pub checkpoint_count: i32,
/// Message cap enforced when messages are added; falls back to `default_max_messages()`
/// when absent from serialized data.
#[serde(default = "default_max_messages")]
pub max_messages: usize,
/// Tool-name cap enforced by `register_tool`; falls back to `default_max_tools()` when
/// absent from serialized data.
#[serde(default = "default_max_tools")]
pub max_tools: usize,
}
impl Session {
pub fn new(config: &SessionConfig) -> Self {
let now = Utc::now();
Self {
session_id: SessionId::new(),
agent_id: AgentId::new(),
state: AgentState::Idle,
iteration: 0,
max_iterations: config.max_iterations,
messages: Vec::new(),
tools_registered: Vec::new(),
tool_calls_pending: Vec::new(),
tool_results_cache: Vec::new(),
model: config.model.clone(),
temperature: config.temperature,
system_prompt: config.system_prompt.clone().unwrap_or_default(),
tokens_total: 0,
tokens_prompt: 0,
tokens_completion: 0,
cost_estimate: 0.0,
created_at: now,
last_updated: now,
checkpoint_count: 0,
max_messages: config.max_messages,
max_tools: config.max_tools,
}
}
pub fn add_user_message(&mut self, content: &str) {
self.messages.push(Message::user(content));
self.trim_messages();
self.iteration += 1;
self.touch();
}
pub fn add_assistant_message(&mut self, content: &str) {
self.messages.push(Message::assistant(content));
self.trim_messages();
self.touch();
}
pub fn add_system_message(&mut self, content: &str) {
self.messages.push(Message::system(content));
self.trim_messages();
self.touch();
}
fn trim_messages(&mut self) {
if self.messages.len() > self.max_messages {
let excess = self.messages.len() - self.max_messages;
let first_is_system = self
.messages
.first()
.map(|m| m.role == MessageRole::System)
.unwrap_or(false);
if first_is_system && excess > 0 {
self.messages.drain(1..=excess.min(self.messages.len() - 1));
} else {
self.messages.drain(0..excess);
}
}
}
pub fn register_tool(&mut self, tool_name: &str) {
if !self.tools_registered.contains(&tool_name.to_string()) {
if self.tools_registered.len() >= self.max_tools {
self.tools_registered.remove(0);
}
self.tools_registered.push(tool_name.to_string());
self.touch();
}
}
pub fn touch(&mut self) {
self.last_updated = Utc::now();
}
pub fn can_continue(&self) -> bool {
self.iteration < self.max_iterations
&& matches!(self.state, AgentState::Running | AgentState::Idle)
}
pub fn to_json(&self) -> serde_json::Result<String> {
serde_json::to_string_pretty(self)
}
pub fn from_json(json: &str) -> serde_json::Result<Self> {
serde_json::from_str(json)
}
pub fn to_dict(&self) -> serde_json::Value {
serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
}
}
impl Default for Session {
fn default() -> Self {
Self::new(&SessionConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session from the default configuration.
    fn default_session() -> Session {
        Session::new(&SessionConfig::default())
    }

    #[test]
    fn test_session_creation() {
        let session = default_session();
        assert_eq!(session.iteration, 0);
        assert_eq!(session.state, AgentState::Idle);
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_session_messages() {
        let mut session = default_session();
        session.add_user_message("Hello");
        assert_eq!((session.messages.len(), session.iteration), (1, 1));
        session.add_assistant_message("Hi there!");
        assert_eq!(session.messages.len(), 2);
    }

    #[test]
    fn test_session_can_continue() {
        let mut session = Session::new(&SessionConfig {
            max_iterations: 5,
            ..Default::default()
        });
        assert!(session.can_continue());
        for (state, expected) in [(AgentState::Running, true), (AgentState::Stopped, false)] {
            session.state = state;
            assert_eq!(session.can_continue(), expected);
        }
    }

    #[test]
    fn test_session_serialization() {
        let original = default_session();
        let json = original.to_json().expect("session should serialize");
        let restored = Session::from_json(&json).expect("session JSON should round-trip");
        assert_eq!(restored.session_id, original.session_id);
        assert_eq!(restored.state, original.state);
    }

    #[test]
    fn test_session_max_messages_limit() {
        let mut session = Session::new(&SessionConfig {
            max_messages: 5,
            ..Default::default()
        });
        (0..10).for_each(|i| session.add_user_message(&format!("Message {}", i)));
        assert_eq!(session.messages.len(), 5);
    }

    #[test]
    fn test_session_preserves_system_message() {
        let config = SessionConfig {
            max_messages: 3,
            system_prompt: Some("System prompt".to_string()),
            ..Default::default()
        };
        let mut session = Session::new(&config);
        session.add_system_message("System prompt");
        for n in 0..5 {
            session.add_user_message(&format!("User {}", n));
        }
        assert_eq!(session.messages.len(), 3);
        let leading_is_system = matches!(
            session.messages.first(),
            Some(m) if m.role == MessageRole::System
        );
        assert!(leading_is_system);
    }

    #[test]
    fn test_session_max_tools_limit() {
        let mut session = Session::new(&SessionConfig {
            max_tools: 3,
            ..Default::default()
        });
        for name in (0..5).map(|i| format!("tool_{}", i)) {
            session.register_tool(&name);
        }
        assert_eq!(session.tools_registered.len(), 3);
        for evicted in ["tool_0", "tool_1"] {
            assert!(!session.tools_registered.iter().any(|t| t == evicted));
        }
    }

    #[test]
    fn test_session_no_duplicate_tools() {
        let mut session = default_session();
        for _ in 0..3 {
            session.register_tool("tool_a");
        }
        assert_eq!(session.tools_registered.len(), 1);
    }
}