1#![allow(dead_code)]
8
9use crate::agency::error::{AgencyError, AgencyResult};
10use crate::agency::models::{AgencyMessage, MessageRole, TokenUsage};
11use chrono::{DateTime, Utc};
12use rusqlite::{params, Connection, OptionalExtension};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::Path;
16use std::sync::{Arc, Mutex};
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
20pub struct SessionState {
21 pub data: HashMap<String, serde_json::Value>,
23}
24
25impl SessionState {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
31 self.data
32 .get(key)
33 .and_then(|v| serde_json::from_value(v.clone()).ok())
34 }
35
36 pub fn set<T: Serialize>(&mut self, key: impl Into<String>, value: T) {
37 if let Ok(v) = serde_json::to_value(value) {
38 self.data.insert(key.into(), v);
39 }
40 }
41
42 pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
43 self.data.remove(key)
44 }
45
46 pub fn contains(&self, key: &str) -> bool {
47 self.data.contains_key(key)
48 }
49
50 pub fn clear(&mut self) {
51 self.data.clear();
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct Session {
58 pub id: String,
60 pub agent_name: String,
62 #[serde(default)]
64 pub user_id: Option<String>,
65 #[serde(default)]
67 pub title: Option<String>,
68 pub messages: Vec<AgencyMessage>,
70 #[serde(default)]
72 pub state: SessionState,
73 #[serde(default)]
75 pub token_usage: TokenUsage,
76 pub created_at: DateTime<Utc>,
78 pub updated_at: DateTime<Utc>,
80 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
82 pub metadata: HashMap<String, serde_json::Value>,
83}
84
85impl Session {
86 pub fn new(agent_name: impl Into<String>, user_id: Option<String>) -> Self {
88 let now = Utc::now();
89 Self {
90 id: generate_session_id(),
91 agent_name: agent_name.into(),
92 user_id,
93 title: None,
94 messages: Vec::new(),
95 state: SessionState::new(),
96 token_usage: TokenUsage::default(),
97 created_at: now,
98 updated_at: now,
99 metadata: HashMap::new(),
100 }
101 }
102
103 pub fn add_message(&mut self, message: AgencyMessage) {
105 if let Some(tokens) = message.tokens {
106 self.token_usage.total_tokens += tokens;
107 match message.role {
108 MessageRole::User | MessageRole::System => {
109 self.token_usage.prompt_tokens += tokens;
110 }
111 MessageRole::Assistant | MessageRole::Tool => {
112 self.token_usage.completion_tokens += tokens;
113 }
114 }
115 }
116 self.messages.push(message);
117 self.updated_at = Utc::now();
118 }
119
120 pub fn to_api_messages(&self) -> Vec<serde_json::Value> {
122 self.messages
123 .iter()
124 .map(|m| {
125 serde_json::json!({
126 "role": m.role.to_string(),
127 "content": m.content
128 })
129 })
130 .collect()
131 }
132
133 pub fn last_messages(&self, n: usize) -> &[AgencyMessage] {
135 let start = self.messages.len().saturating_sub(n);
136 &self.messages[start..]
137 }
138
139 pub fn clear_messages(&mut self) {
141 self.messages.clear();
142 self.token_usage = TokenUsage::default();
143 self.updated_at = Utc::now();
144 }
145
146 pub fn rewind_to(&mut self, message_id: &str) -> Option<Vec<AgencyMessage>> {
148 if let Some(pos) = self.messages.iter().position(|m| m.id == message_id) {
149 let removed: Vec<_> = self.messages.drain(pos..).collect();
150 self.updated_at = Utc::now();
151 self.recalculate_tokens();
153 Some(removed)
154 } else {
155 None
156 }
157 }
158
159 fn recalculate_tokens(&mut self) {
160 let mut usage = TokenUsage::default();
161 for m in &self.messages {
162 if let Some(tokens) = m.tokens {
163 usage.total_tokens += tokens;
164 match m.role {
165 MessageRole::User | MessageRole::System => {
166 usage.prompt_tokens += tokens;
167 }
168 MessageRole::Assistant | MessageRole::Tool => {
169 usage.completion_tokens += tokens;
170 }
171 }
172 }
173 }
174 self.token_usage = usage;
175 }
176}
177
178fn generate_session_id() -> String {
180 format!(
181 "session-{}-{}",
182 Utc::now().timestamp_millis(),
183 &uuid::Uuid::new_v4().to_string()[..8]
184 )
185}
186
187pub fn generate_message_id() -> String {
189 format!(
190 "msg-{}-{}",
191 Utc::now().timestamp_millis(),
192 &uuid::Uuid::new_v4().to_string()[..8]
193 )
194}
195
196pub struct SessionManager {
198 conn: Arc<Mutex<Connection>>,
199}
200
201impl SessionManager {
202 pub fn new(db_path: impl AsRef<Path>) -> AgencyResult<Self> {
204 let conn = Connection::open(db_path)?;
205 let manager = Self {
206 conn: Arc::new(Mutex::new(conn)),
207 };
208 manager.init_schema()?;
209 Ok(manager)
210 }
211
212 pub fn in_memory() -> AgencyResult<Self> {
214 let conn = Connection::open_in_memory()?;
215 let manager = Self {
216 conn: Arc::new(Mutex::new(conn)),
217 };
218 manager.init_schema()?;
219 Ok(manager)
220 }
221
222 fn init_schema(&self) -> AgencyResult<()> {
224 let conn = self
225 .conn
226 .lock()
227 .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
228 conn.execute_batch(
229 r#"
230 CREATE TABLE IF NOT EXISTS Agency_sessions (
231 id TEXT PRIMARY KEY,
232 agent_name TEXT NOT NULL,
233 user_id TEXT,
234 title TEXT,
235 messages TEXT NOT NULL,
236 state TEXT NOT NULL,
237 token_usage TEXT NOT NULL,
238 metadata TEXT,
239 created_at TEXT NOT NULL,
240 updated_at TEXT NOT NULL
241 );
242
243 CREATE INDEX IF NOT EXISTS idx_Agency_sessions_agent ON Agency_sessions(agent_name);
244 CREATE INDEX IF NOT EXISTS idx_Agency_sessions_user ON Agency_sessions(user_id);
245 CREATE INDEX IF NOT EXISTS idx_Agency_sessions_updated ON Agency_sessions(updated_at DESC);
246 "#,
247 )?;
248 Ok(())
249 }
250
251 pub fn create(
253 &self,
254 agent_name: impl Into<String>,
255 user_id: Option<String>,
256 ) -> AgencyResult<Session> {
257 let session = Session::new(agent_name, user_id);
258 self.save(&session)?;
259 Ok(session)
260 }
261
262 pub fn save(&self, session: &Session) -> AgencyResult<()> {
264 let conn = self
265 .conn
266 .lock()
267 .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
268 conn.execute(
269 r#"
270 INSERT OR REPLACE INTO Agency_sessions
271 (id, agent_name, user_id, title, messages, state, token_usage, metadata, created_at, updated_at)
272 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
273 "#,
274 params![
275 session.id,
276 session.agent_name,
277 session.user_id,
278 session.title,
279 serde_json::to_string(&session.messages)?,
280 serde_json::to_string(&session.state)?,
281 serde_json::to_string(&session.token_usage)?,
282 serde_json::to_string(&session.metadata)?,
283 session.created_at.to_rfc3339(),
284 session.updated_at.to_rfc3339(),
285 ],
286 )?;
287 Ok(())
288 }
289
290 pub fn get(&self, id: &str) -> AgencyResult<Option<Session>> {
292 let conn = self
293 .conn
294 .lock()
295 .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
296 let session = conn
297 .query_row(
298 "SELECT * FROM Agency_sessions WHERE id = ?1",
299 params![id],
300 |row| {
301 Ok(Session {
302 id: row.get(0)?,
303 agent_name: row.get(1)?,
304 user_id: row.get(2)?,
305 title: row.get(3)?,
306 messages: serde_json::from_str(&row.get::<_, String>(4)?)
307 .unwrap_or_default(),
308 state: serde_json::from_str(&row.get::<_, String>(5)?).unwrap_or_default(),
309 token_usage: serde_json::from_str(&row.get::<_, String>(6)?)
310 .unwrap_or_default(),
311 metadata: serde_json::from_str(&row.get::<_, String>(7)?)
312 .unwrap_or_default(),
313 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
314 .map(|dt| dt.with_timezone(&Utc))
315 .unwrap_or_else(|_| Utc::now()),
316 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
317 .map(|dt| dt.with_timezone(&Utc))
318 .unwrap_or_else(|_| Utc::now()),
319 })
320 },
321 )
322 .optional()?;
323 Ok(session)
324 }
325
326 pub fn list_by_agent(
328 &self,
329 agent_name: &str,
330 limit: Option<u32>,
331 ) -> AgencyResult<Vec<Session>> {
332 let conn = self
333 .conn
334 .lock()
335 .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
336 let limit = limit.unwrap_or(100);
337 let mut stmt = conn.prepare(
338 "SELECT * FROM Agency_sessions WHERE agent_name = ?1 ORDER BY updated_at DESC LIMIT ?2",
339 )?;
340 let sessions = stmt
341 .query_map(params![agent_name, limit], |row| {
342 Ok(Session {
343 id: row.get(0)?,
344 agent_name: row.get(1)?,
345 user_id: row.get(2)?,
346 title: row.get(3)?,
347 messages: serde_json::from_str(&row.get::<_, String>(4)?).unwrap_or_default(),
348 state: serde_json::from_str(&row.get::<_, String>(5)?).unwrap_or_default(),
349 token_usage: serde_json::from_str(&row.get::<_, String>(6)?)
350 .unwrap_or_default(),
351 metadata: serde_json::from_str(&row.get::<_, String>(7)?).unwrap_or_default(),
352 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
353 .map(|dt| dt.with_timezone(&Utc))
354 .unwrap_or_else(|_| Utc::now()),
355 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
356 .map(|dt| dt.with_timezone(&Utc))
357 .unwrap_or_else(|_| Utc::now()),
358 })
359 })?
360 .filter_map(|r| r.ok())
361 .collect();
362 Ok(sessions)
363 }
364
365 pub fn list_by_user(&self, user_id: &str, limit: Option<u32>) -> AgencyResult<Vec<Session>> {
367 let conn = self
368 .conn
369 .lock()
370 .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
371 let limit = limit.unwrap_or(100);
372 let mut stmt = conn.prepare(
373 "SELECT * FROM Agency_sessions WHERE user_id = ?1 ORDER BY updated_at DESC LIMIT ?2",
374 )?;
375 let sessions = stmt
376 .query_map(params![user_id, limit], |row| {
377 Ok(Session {
378 id: row.get(0)?,
379 agent_name: row.get(1)?,
380 user_id: row.get(2)?,
381 title: row.get(3)?,
382 messages: serde_json::from_str(&row.get::<_, String>(4)?).unwrap_or_default(),
383 state: serde_json::from_str(&row.get::<_, String>(5)?).unwrap_or_default(),
384 token_usage: serde_json::from_str(&row.get::<_, String>(6)?)
385 .unwrap_or_default(),
386 metadata: serde_json::from_str(&row.get::<_, String>(7)?).unwrap_or_default(),
387 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
388 .map(|dt| dt.with_timezone(&Utc))
389 .unwrap_or_else(|_| Utc::now()),
390 updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
391 .map(|dt| dt.with_timezone(&Utc))
392 .unwrap_or_else(|_| Utc::now()),
393 })
394 })?
395 .filter_map(|r| r.ok())
396 .collect();
397 Ok(sessions)
398 }
399
400 pub fn delete(&self, id: &str) -> AgencyResult<bool> {
402 let conn = self
403 .conn
404 .lock()
405 .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
406 let rows = conn.execute("DELETE FROM Agency_sessions WHERE id = ?1", params![id])?;
407 Ok(rows > 0)
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with the given id and role, no token count, and empty extras.
    /// Override fields with struct-update syntax:
    /// `AgencyMessage { tokens: Some(5), ..msg("m1", MessageRole::User) }`.
    fn msg(id: &str, role: MessageRole) -> AgencyMessage {
        AgencyMessage {
            id: id.to_string(),
            role,
            content: "Hello".to_string(),
            tool_calls: vec![],
            tool_result: None,
            timestamp: Utc::now(),
            tokens: None,
            agent_name: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_session_state() {
        let mut state = SessionState::new();
        state.set("count", 42);
        state.set("name", "test");

        assert_eq!(state.get::<i32>("count"), Some(42));
        assert_eq!(state.get::<String>("name"), Some("test".to_string()));
        assert!(state.contains("count"));
        assert!(!state.contains("missing"));

        // A type mismatch reads as `None` rather than panicking.
        assert_eq!(state.get::<String>("count"), None);

        assert!(state.remove("count").is_some());
        assert!(!state.contains("count"));
        state.clear();
        assert!(!state.contains("name"));
    }

    #[test]
    fn test_session_messages() {
        let mut session = Session::new("test_agent", None);
        session.add_message(AgencyMessage {
            tokens: Some(5),
            ..msg("msg1", MessageRole::User)
        });

        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.token_usage.prompt_tokens, 5);
    }

    #[test]
    fn test_token_usage_by_role() {
        let mut session = Session::new("test_agent", None);
        session.add_message(AgencyMessage { tokens: Some(5), ..msg("u", MessageRole::User) });
        session.add_message(AgencyMessage { tokens: Some(2), ..msg("s", MessageRole::System) });
        session.add_message(AgencyMessage { tokens: Some(7), ..msg("a", MessageRole::Assistant) });
        session.add_message(AgencyMessage { tokens: Some(1), ..msg("t", MessageRole::Tool) });
        // Messages without a token count are stored but not counted.
        session.add_message(msg("n", MessageRole::User));

        assert_eq!(session.messages.len(), 5);
        assert_eq!(session.token_usage.prompt_tokens, 7);
        assert_eq!(session.token_usage.completion_tokens, 8);
        assert_eq!(session.token_usage.total_tokens, 15);

        session.clear_messages();
        assert!(session.messages.is_empty());
        assert_eq!(session.token_usage.total_tokens, 0);
    }

    #[test]
    fn test_rewind_to() {
        let mut session = Session::new("test_agent", None);
        session.add_message(AgencyMessage { tokens: Some(5), ..msg("m1", MessageRole::User) });
        session.add_message(AgencyMessage { tokens: Some(7), ..msg("m2", MessageRole::Assistant) });
        session.add_message(AgencyMessage { tokens: Some(3), ..msg("m3", MessageRole::User) });

        assert!(session.rewind_to("missing").is_none());
        assert_eq!(session.messages.len(), 3);

        let removed = session.rewind_to("m2").expect("m2 is in the history");
        assert_eq!(removed.len(), 2);
        assert_eq!(removed[0].id, "m2");
        assert_eq!(removed[1].id, "m3");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.token_usage.prompt_tokens, 5);
        assert_eq!(session.token_usage.completion_tokens, 0);
        assert_eq!(session.token_usage.total_tokens, 5);
    }

    #[test]
    fn test_last_messages() {
        let mut session = Session::new("test_agent", None);
        for id in ["m1", "m2", "m3"] {
            session.add_message(msg(id, MessageRole::User));
        }

        let tail: Vec<&str> = session.last_messages(2).iter().map(|m| m.id.as_str()).collect();
        assert_eq!(tail, vec!["m2", "m3"]);
        assert_eq!(session.last_messages(10).len(), 3);
        assert!(session.last_messages(0).is_empty());
    }

    #[test]
    fn test_session_manager() -> AgencyResult<()> {
        let manager = SessionManager::in_memory()?;
        let mut session = manager.create("test_agent", Some("user1".to_string()))?;

        let loaded = manager.get(&session.id)?;
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().agent_name, "test_agent");

        session.title = Some("Greeting".to_string());
        session.add_message(AgencyMessage { tokens: Some(4), ..msg("m1", MessageRole::User) });
        manager.save(&session)?;

        let reloaded = manager.get(&session.id)?.expect("session was saved");
        assert_eq!(reloaded.user_id.as_deref(), Some("user1"));
        assert_eq!(reloaded.title.as_deref(), Some("Greeting"));
        assert_eq!(reloaded.messages.len(), 1);
        assert_eq!(reloaded.token_usage.total_tokens, 4);

        let sessions = manager.list_by_agent("test_agent", None)?;
        assert_eq!(sessions.len(), 1);

        assert!(manager.delete(&session.id)?);
        assert!(!manager.delete(&session.id)?);
        let deleted = manager.get(&session.id)?;
        assert!(deleted.is_none());

        Ok(())
    }

    #[test]
    fn test_list_by_user_and_limit() -> AgencyResult<()> {
        let manager = SessionManager::in_memory()?;
        manager.create("agent_a", Some("user1".to_string()))?;
        manager.create("agent_b", Some("user1".to_string()))?;
        manager.create("agent_a", Some("user2".to_string()))?;

        assert_eq!(manager.list_by_user("user1", None)?.len(), 2);
        assert_eq!(manager.list_by_user("user1", Some(1))?.len(), 1);
        assert_eq!(manager.list_by_agent("agent_a", None)?.len(), 2);
        assert!(manager.list_by_user("nobody", None)?.is_empty());

        Ok(())
    }
}