// cortexai_agents/session.rs

//! # Structured Sessions
//!
//! Conversation memory and session management for agents.
//!
//! Inspired by OpenAI Agents SDK session patterns.
//!
//! ## Features
//!
//! - **Conversation History**: Store and retrieve message history
//! - **Session State**: Persist arbitrary session state
//! - **Turn Management**: Track conversation turns
//! - **Context Windows**: Manage context length limits
//!
//! ## Example
//!
//! ```rust,ignore
//! use cortex::session::{ConversationSession, ChatMessage, ChatRole};
//!
//! let mut session = ConversationSession::new("session_123")
//!     .with_user_id("user_123")
//!     .with_system_prompt("You are a helpful assistant");
//!
//! session.add_message(ChatRole::User, "Hello!");
//! session.add_message(ChatRole::Assistant, "Hi there! How can I help?");
//!
//! // Messages that fit in a 4096-token budget (system prompt first).
//! let context = session.get_context(Some(4096));
//! ```

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

/// Identifier that uniquely names a conversation session within a store.
pub type ConversationId = String;

40/// Role of a message sender in a conversation
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
42pub enum ChatRole {
43    /// System instructions
44    System,
45    /// User message
46    User,
47    /// Assistant response
48    Assistant,
49    /// Tool/function call
50    Tool,
51}
52
53impl ChatRole {
54    pub fn as_str(&self) -> &'static str {
55        match self {
56            ChatRole::System => "system",
57            ChatRole::User => "user",
58            ChatRole::Assistant => "assistant",
59            ChatRole::Tool => "tool",
60        }
61    }
62}
63
64/// A message in the conversation
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ChatMessage {
67    /// Role of the sender
68    pub role: ChatRole,
69    /// Message content
70    pub content: String,
71    /// When the message was created
72    pub timestamp: u64,
73    /// Optional name (for multi-agent scenarios)
74    pub name: Option<String>,
75    /// Optional tool call ID
76    pub tool_call_id: Option<String>,
77    /// Optional metadata
78    pub metadata: HashMap<String, String>,
79}
80
81impl ChatMessage {
82    pub fn new(role: ChatRole, content: impl Into<String>) -> Self {
83        Self {
84            role,
85            content: content.into(),
86            timestamp: SystemTime::now()
87                .duration_since(UNIX_EPOCH)
88                .unwrap_or_default()
89                .as_secs(),
90            name: None,
91            tool_call_id: None,
92            metadata: HashMap::new(),
93        }
94    }
95
96    pub fn system(content: impl Into<String>) -> Self {
97        Self::new(ChatRole::System, content)
98    }
99
100    pub fn user(content: impl Into<String>) -> Self {
101        Self::new(ChatRole::User, content)
102    }
103
104    pub fn assistant(content: impl Into<String>) -> Self {
105        Self::new(ChatRole::Assistant, content)
106    }
107
108    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
109        let mut msg = Self::new(ChatRole::Tool, content);
110        msg.tool_call_id = Some(tool_call_id.into());
111        msg
112    }
113
114    pub fn with_name(mut self, name: impl Into<String>) -> Self {
115        self.name = Some(name.into());
116        self
117    }
118
119    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
120        self.metadata.insert(key.into(), value.into());
121        self
122    }
123
124    /// Estimate token count (rough approximation: 4 chars per token)
125    pub fn estimated_tokens(&self) -> usize {
126        self.content.len() / 4 + 1
127    }
128}
129
130/// A conversation turn (user message + assistant response)
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct Turn {
133    /// Turn number (1-indexed)
134    pub number: u32,
135    /// User message
136    pub user_message: ChatMessage,
137    /// Assistant response (if available)
138    pub assistant_response: Option<ChatMessage>,
139    /// Tool calls made during this turn
140    pub tool_calls: Vec<ChatMessage>,
141    /// When the turn started
142    pub started_at: u64,
143    /// When the turn completed
144    pub completed_at: Option<u64>,
145}
146
147impl Turn {
148    pub fn new(number: u32, user_message: ChatMessage) -> Self {
149        Self {
150            number,
151            started_at: user_message.timestamp,
152            user_message,
153            assistant_response: None,
154            tool_calls: Vec::new(),
155            completed_at: None,
156        }
157    }
158
159    pub fn complete(&mut self, response: ChatMessage) {
160        self.assistant_response = Some(response);
161        self.completed_at = Some(
162            SystemTime::now()
163                .duration_since(UNIX_EPOCH)
164                .unwrap_or_default()
165                .as_secs(),
166        );
167    }
168
169    pub fn add_tool_call(&mut self, tool_message: ChatMessage) {
170        self.tool_calls.push(tool_message);
171    }
172
173    pub fn is_complete(&self) -> bool {
174        self.assistant_response.is_some()
175    }
176
177    pub fn all_messages(&self) -> Vec<&ChatMessage> {
178        let mut messages = vec![&self.user_message];
179        messages.extend(self.tool_calls.iter());
180        if let Some(ref response) = self.assistant_response {
181            messages.push(response);
182        }
183        messages
184    }
185}
186
187/// Session state for arbitrary data
188#[derive(Debug, Clone, Default, Serialize, Deserialize)]
189pub struct SessionState {
190    /// Key-value store for session variables
191    pub variables: HashMap<String, String>,
192    /// Structured data (JSON serialized)
193    pub data: HashMap<String, String>,
194}
195
196impl SessionState {
197    pub fn new() -> Self {
198        Self::default()
199    }
200
201    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
202        self.variables.insert(key.into(), value.into());
203    }
204
205    pub fn get(&self, key: &str) -> Option<&String> {
206        self.variables.get(key)
207    }
208
209    pub fn set_data<T: Serialize>(
210        &mut self,
211        key: impl Into<String>,
212        value: &T,
213    ) -> Result<(), serde_json::Error> {
214        let json = serde_json::to_string(value)?;
215        self.data.insert(key.into(), json);
216        Ok(())
217    }
218
219    pub fn get_data<T: for<'de> Deserialize<'de>>(
220        &self,
221        key: &str,
222    ) -> Option<Result<T, serde_json::Error>> {
223        self.data.get(key).map(|json| serde_json::from_str(json))
224    }
225
226    pub fn remove(&mut self, key: &str) -> Option<String> {
227        self.variables.remove(key)
228    }
229
230    pub fn clear(&mut self) {
231        self.variables.clear();
232        self.data.clear();
233    }
234}
235
/// Tunable limits and behaviour switches for a conversation session.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Upper bound on stored messages; older messages are evicted beyond it.
    pub max_messages: usize,
    /// Token budget used when building a context without an explicit limit.
    pub max_tokens: usize,
    /// Idle time after which the session counts as expired; `None` disables expiry.
    pub timeout: Option<Duration>,
    /// Whether setting a system prompt also records it as the first history message.
    pub persist_system_prompt: bool,
}

impl Default for SessionConfig {
    /// 100 messages, 8192 tokens, a one-hour idle timeout, persisted system prompt.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        SessionConfig {
            persist_system_prompt: true,
            timeout: Some(ONE_HOUR),
            max_tokens: 8192,
            max_messages: 100,
        }
    }
}

260/// A session representing a conversation
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct ConversationSession {
263    /// Session ID
264    pub id: ConversationId,
265    /// User ID (if applicable)
266    pub user_id: Option<String>,
267    /// System prompt
268    pub system_prompt: Option<String>,
269    /// Message history
270    pub messages: Vec<ChatMessage>,
271    /// Conversation turns
272    pub turns: Vec<Turn>,
273    /// Session state
274    pub state: SessionState,
275    /// When session was created
276    pub created_at: u64,
277    /// When session was last updated
278    pub updated_at: u64,
279    /// Session metadata
280    pub metadata: HashMap<String, String>,
281    /// Configuration (not serialized)
282    #[serde(skip)]
283    config: SessionConfig,
284}
285
286impl ConversationSession {
287    pub fn new(id: impl Into<String>) -> Self {
288        let now = SystemTime::now()
289            .duration_since(UNIX_EPOCH)
290            .unwrap_or_default()
291            .as_secs();
292
293        Self {
294            id: id.into(),
295            user_id: None,
296            system_prompt: None,
297            messages: Vec::new(),
298            turns: Vec::new(),
299            state: SessionState::new(),
300            created_at: now,
301            updated_at: now,
302            metadata: HashMap::new(),
303            config: SessionConfig::default(),
304        }
305    }
306
307    pub fn with_config(mut self, config: SessionConfig) -> Self {
308        self.config = config;
309        self
310    }
311
312    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
313        self.user_id = Some(user_id.into());
314        self
315    }
316
317    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
318        let prompt = prompt.into();
319        self.system_prompt = Some(prompt.clone());
320        if self.config.persist_system_prompt {
321            self.messages.insert(0, ChatMessage::system(prompt));
322        }
323        self
324    }
325
326    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
327        self.metadata.insert(key.into(), value.into());
328        self
329    }
330
331    /// Add a message to the session
332    pub fn add_message(&mut self, role: ChatRole, content: impl Into<String>) {
333        let message = ChatMessage::new(role, content);
334        self.add_message_obj(message);
335    }
336
337    /// Add a pre-constructed message
338    pub fn add_message_obj(&mut self, message: ChatMessage) {
339        // Handle turn management
340        match message.role {
341            ChatRole::User => {
342                let turn_num = self.turns.len() as u32 + 1;
343                self.turns.push(Turn::new(turn_num, message.clone()));
344            }
345            ChatRole::Assistant => {
346                if let Some(turn) = self.turns.last_mut() {
347                    turn.complete(message.clone());
348                }
349            }
350            ChatRole::Tool => {
351                if let Some(turn) = self.turns.last_mut() {
352                    turn.add_tool_call(message.clone());
353                }
354            }
355            ChatRole::System => {
356                // System messages don't affect turns
357            }
358        }
359
360        self.messages.push(message);
361        self.touch();
362        self.enforce_limits();
363    }
364
365    /// Get messages for context (within token limit)
366    pub fn get_context(&self, max_tokens: Option<usize>) -> Vec<&ChatMessage> {
367        let max_tokens = max_tokens.unwrap_or(self.config.max_tokens);
368        let mut result = Vec::new();
369        let mut token_count = 0;
370
371        // Always include system prompt if present
372        if let Some(system_msg) = self.messages.first() {
373            if system_msg.role == ChatRole::System {
374                token_count += system_msg.estimated_tokens();
375                result.push(system_msg);
376            }
377        }
378
379        // Add messages from most recent, respecting token limit
380        for message in self.messages.iter().rev() {
381            if message.role == ChatRole::System {
382                continue; // Already handled
383            }
384
385            let msg_tokens = message.estimated_tokens();
386            if token_count + msg_tokens > max_tokens {
387                break;
388            }
389
390            token_count += msg_tokens;
391            result.push(message);
392        }
393
394        // Reverse to get chronological order (except system which is first)
395        let system_msg = if result.first().map(|m| m.role) == Some(ChatRole::System) {
396            Some(result.remove(0))
397        } else {
398            None
399        };
400
401        result.reverse();
402
403        if let Some(sys) = system_msg {
404            result.insert(0, sys);
405        }
406
407        result
408    }
409
410    /// Get all messages
411    pub fn get_all_messages(&self) -> &[ChatMessage] {
412        &self.messages
413    }
414
415    /// Get recent N messages
416    pub fn get_recent(&self, n: usize) -> Vec<&ChatMessage> {
417        self.messages.iter().rev().take(n).rev().collect()
418    }
419
420    /// Get messages by role
421    pub fn get_by_role(&self, role: ChatRole) -> Vec<&ChatMessage> {
422        self.messages.iter().filter(|m| m.role == role).collect()
423    }
424
425    /// Get current turn number
426    pub fn current_turn(&self) -> u32 {
427        self.turns.len() as u32
428    }
429
430    /// Get the last turn
431    pub fn last_turn(&self) -> Option<&Turn> {
432        self.turns.last()
433    }
434
435    /// Get a specific turn
436    pub fn get_turn(&self, number: u32) -> Option<&Turn> {
437        if number == 0 || number as usize > self.turns.len() {
438            return None;
439        }
440        self.turns.get(number as usize - 1)
441    }
442
443    /// Check if session is expired
444    pub fn is_expired(&self) -> bool {
445        if let Some(timeout) = self.config.timeout {
446            let now = SystemTime::now()
447                .duration_since(UNIX_EPOCH)
448                .unwrap_or_default()
449                .as_secs();
450            return now - self.updated_at > timeout.as_secs();
451        }
452        false
453    }
454
455    /// Get session age
456    pub fn age(&self) -> Duration {
457        let now = SystemTime::now()
458            .duration_since(UNIX_EPOCH)
459            .unwrap_or_default()
460            .as_secs();
461        Duration::from_secs(now - self.created_at)
462    }
463
464    /// Clear all messages (keep system prompt)
465    pub fn clear_messages(&mut self) {
466        let system = if self.config.persist_system_prompt {
467            self.messages
468                .first()
469                .filter(|m| m.role == ChatRole::System)
470                .cloned()
471        } else {
472            None
473        };
474
475        self.messages.clear();
476        self.turns.clear();
477
478        if let Some(sys) = system {
479            self.messages.push(sys);
480        }
481
482        self.touch();
483    }
484
485    /// Get total estimated tokens
486    pub fn total_tokens(&self) -> usize {
487        self.messages.iter().map(|m| m.estimated_tokens()).sum()
488    }
489
490    fn touch(&mut self) {
491        self.updated_at = SystemTime::now()
492            .duration_since(UNIX_EPOCH)
493            .unwrap_or_default()
494            .as_secs();
495    }
496
497    fn enforce_limits(&mut self) {
498        // Enforce max messages
499        while self.messages.len() > self.config.max_messages {
500            // Keep system prompt at index 0 if present
501            let remove_idx = if self.messages.first().map(|m| m.role) == Some(ChatRole::System) {
502                1
503            } else {
504                0
505            };
506            if remove_idx < self.messages.len() {
507                self.messages.remove(remove_idx);
508            }
509        }
510    }
511}
512
513/// Trait for session storage
514#[async_trait]
515pub trait SessionStore: Send + Sync {
516    /// Save a session
517    async fn save(&self, session: &ConversationSession) -> Result<(), SessionError>;
518
519    /// Load a session
520    async fn load(&self, session_id: &str) -> Result<Option<ConversationSession>, SessionError>;
521
522    /// Delete a session
523    async fn delete(&self, session_id: &str) -> Result<(), SessionError>;
524
525    /// List all session IDs
526    async fn list(&self) -> Result<Vec<ConversationId>, SessionError>;
527
528    /// List sessions for a user
529    async fn list_for_user(&self, user_id: &str) -> Result<Vec<ConversationId>, SessionError>;
530}
531
532/// Error type for session operations
533#[derive(Debug, thiserror::Error)]
534pub enum SessionError {
535    #[error("Session not found: {0}")]
536    NotFound(ConversationId),
537
538    #[error("Session expired: {0}")]
539    Expired(ConversationId),
540
541    #[error("Storage error: {0}")]
542    StorageError(String),
543
544    #[error("Serialization error: {0}")]
545    SerializationError(String),
546}
547
548/// In-memory session store
549#[derive(Default)]
550pub struct MemorySessionStore {
551    sessions: RwLock<HashMap<ConversationId, ConversationSession>>,
552}
553
554impl MemorySessionStore {
555    pub fn new() -> Self {
556        Self::default()
557    }
558}
559
560#[async_trait]
561impl SessionStore for MemorySessionStore {
562    async fn save(&self, session: &ConversationSession) -> Result<(), SessionError> {
563        self.sessions
564            .write()
565            .insert(session.id.clone(), session.clone());
566        Ok(())
567    }
568
569    async fn load(&self, session_id: &str) -> Result<Option<ConversationSession>, SessionError> {
570        Ok(self.sessions.read().get(session_id).cloned())
571    }
572
573    async fn delete(&self, session_id: &str) -> Result<(), SessionError> {
574        self.sessions.write().remove(session_id);
575        Ok(())
576    }
577
578    async fn list(&self) -> Result<Vec<ConversationId>, SessionError> {
579        Ok(self.sessions.read().keys().cloned().collect())
580    }
581
582    async fn list_for_user(&self, user_id: &str) -> Result<Vec<ConversationId>, SessionError> {
583        Ok(self
584            .sessions
585            .read()
586            .values()
587            .filter(|s| s.user_id.as_deref() == Some(user_id))
588            .map(|s| s.id.clone())
589            .collect())
590    }
591}
592
593/// Session manager for handling multiple sessions
594pub struct ConversationManager<S: SessionStore> {
595    store: Arc<S>,
596    config: SessionConfig,
597}
598
599impl ConversationManager<MemorySessionStore> {
600    /// Create with in-memory store
601    pub fn in_memory() -> Self {
602        Self::new(Arc::new(MemorySessionStore::new()))
603    }
604}
605
606impl<S: SessionStore> ConversationManager<S> {
607    pub fn new(store: Arc<S>) -> Self {
608        Self {
609            store,
610            config: SessionConfig::default(),
611        }
612    }
613
614    pub fn with_config(mut self, config: SessionConfig) -> Self {
615        self.config = config;
616        self
617    }
618
619    /// Create a new session
620    pub async fn create(
621        &self,
622        session_id: impl Into<String>,
623    ) -> Result<ConversationSession, SessionError> {
624        let session = ConversationSession::new(session_id).with_config(self.config.clone());
625        self.store.save(&session).await?;
626        info!(session_id = %session.id, "Session created");
627        Ok(session)
628    }
629
630    /// Get or create a session
631    pub async fn get_or_create(
632        &self,
633        session_id: impl Into<String>,
634    ) -> Result<ConversationSession, SessionError> {
635        let session_id = session_id.into();
636        match self.store.load(&session_id).await? {
637            Some(session) => {
638                if session.is_expired() {
639                    debug!(session_id = %session_id, "Session expired, creating new");
640                    self.store.delete(&session_id).await?;
641                    self.create(session_id).await
642                } else {
643                    Ok(session)
644                }
645            }
646            None => self.create(session_id).await,
647        }
648    }
649
650    /// Update a session
651    pub async fn update(&self, session: &ConversationSession) -> Result<(), SessionError> {
652        self.store.save(session).await
653    }
654
655    /// Delete a session
656    pub async fn delete(&self, session_id: &str) -> Result<(), SessionError> {
657        self.store.delete(session_id).await?;
658        info!(session_id, "Session deleted");
659        Ok(())
660    }
661
662    /// List all sessions
663    pub async fn list(&self) -> Result<Vec<ConversationId>, SessionError> {
664        self.store.list().await
665    }
666
667    /// List sessions for a user
668    pub async fn list_for_user(&self, user_id: &str) -> Result<Vec<ConversationId>, SessionError> {
669        self.store.list_for_user(user_id).await
670    }
671
672    /// Cleanup expired sessions
673    pub async fn cleanup_expired(&self) -> Result<usize, SessionError> {
674        let mut cleaned = 0;
675        let session_ids = self.store.list().await?;
676
677        for id in session_ids {
678            if let Some(session) = self.store.load(&id).await? {
679                if session.is_expired() {
680                    self.store.delete(&id).await?;
681                    cleaned += 1;
682                }
683            }
684        }
685
686        if cleaned > 0 {
687            info!(count = cleaned, "Cleaned up expired sessions");
688        }
689
690        Ok(cleaned)
691    }
692}
693
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = ChatMessage::user("Hello!");
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "Hello!");
        assert!(msg.timestamp > 0);
    }

    #[test]
    fn test_message_builders() {
        let msg = ChatMessage::assistant("Hi there!")
            .with_name("Assistant")
            .with_metadata("source", "test");

        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.name, Some("Assistant".to_string()));
        assert_eq!(msg.metadata.get("source").unwrap(), "test");
    }

    #[test]
    fn test_session_basic() {
        let mut session =
            ConversationSession::new("test_session").with_system_prompt("You are helpful");

        session.add_message(ChatRole::User, "Hello!");
        session.add_message(ChatRole::Assistant, "Hi there!");

        assert_eq!(session.messages.len(), 3); // system + user + assistant
        assert_eq!(session.current_turn(), 1);
    }

    #[test]
    fn test_session_turns() {
        let mut session = ConversationSession::new("test");

        session.add_message(ChatRole::User, "First question");
        session.add_message(ChatRole::Assistant, "First answer");
        session.add_message(ChatRole::User, "Second question");
        session.add_message(ChatRole::Assistant, "Second answer");

        assert_eq!(session.current_turn(), 2);

        let turn1 = session.get_turn(1).unwrap();
        assert_eq!(turn1.user_message.content, "First question");
        assert!(turn1.is_complete());

        let turn2 = session.get_turn(2).unwrap();
        assert_eq!(turn2.user_message.content, "Second question");
    }

    #[test]
    fn test_get_context_with_limit() {
        let mut session = ConversationSession::new("test").with_system_prompt("System");

        for i in 0..20 {
            session.add_message(ChatRole::User, format!("Message {}", i));
            session.add_message(ChatRole::Assistant, format!("Response {}", i));
        }

        let context = session.get_context(Some(100));

        // "System" costs 2 estimated tokens and every other message costs 3,
        // so a 100-token budget fits the system prompt plus the 32 newest
        // messages, returned oldest-first after the system prompt.
        assert_eq!(context.len(), 33);
        assert_eq!(context[0].role, ChatRole::System);
        assert_eq!(context[1].content, "Message 4");
        assert_eq!(context[32].content, "Response 19");
        let used: usize = context.iter().map(|m| m.estimated_tokens()).sum();
        assert!(used <= 100);
    }

    #[test]
    fn test_session_state() {
        let mut session = ConversationSession::new("test");

        session.state.set("user_name", "Alice");
        session.state.set("preference", "dark_mode");

        assert_eq!(session.state.get("user_name").unwrap(), "Alice");
        assert_eq!(session.state.get("preference").unwrap(), "dark_mode");

        // Structured data round-trips through JSON.
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct UserPrefs {
            theme: String,
            language: String,
        }

        let prefs = UserPrefs {
            theme: "dark".to_string(),
            language: "en".to_string(),
        };

        session.state.set_data("prefs", &prefs).unwrap();
        let loaded: UserPrefs = session.state.get_data("prefs").unwrap().unwrap();
        assert_eq!(loaded, prefs);
    }

    #[test]
    fn test_message_limit() {
        let config = SessionConfig {
            max_messages: 5,
            ..Default::default()
        };

        let mut session = ConversationSession::new("test").with_config(config);

        for i in 0..10 {
            session.add_message(ChatRole::User, format!("Message {}", i));
        }

        // The oldest messages are evicted first.
        assert_eq!(session.messages.len(), 5);
        assert_eq!(session.messages[0].content, "Message 5");
        assert_eq!(session.messages[4].content, "Message 9");
    }

    #[test]
    fn test_clear_messages() {
        let mut session = ConversationSession::new("test").with_system_prompt("System prompt");

        session.add_message(ChatRole::User, "Hello");
        session.add_message(ChatRole::Assistant, "Hi");

        assert_eq!(session.messages.len(), 3);

        session.clear_messages();

        // Should keep the system prompt.
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, ChatRole::System);
        assert_eq!(session.turns.len(), 0);
    }

    #[test]
    fn test_get_by_role() {
        let mut session = ConversationSession::new("test");

        session.add_message(ChatRole::User, "Q1");
        session.add_message(ChatRole::Assistant, "A1");
        session.add_message(ChatRole::User, "Q2");
        session.add_message(ChatRole::Assistant, "A2");

        let user_messages = session.get_by_role(ChatRole::User);
        assert_eq!(user_messages.len(), 2);

        let assistant_messages = session.get_by_role(ChatRole::Assistant);
        assert_eq!(assistant_messages.len(), 2);
    }

    #[tokio::test]
    async fn test_session_manager() {
        let manager = ConversationManager::in_memory();

        let session = manager.create("session1").await.unwrap();
        assert_eq!(session.id, "session1");

        let loaded = manager.get_or_create("session1").await.unwrap();
        assert_eq!(loaded.id, "session1");

        let sessions = manager.list().await.unwrap();
        assert_eq!(sessions.len(), 1);

        manager.delete("session1").await.unwrap();
        let sessions = manager.list().await.unwrap();
        assert!(sessions.is_empty());
    }

    #[tokio::test]
    async fn test_session_store() {
        let store = MemorySessionStore::new();

        let mut session = ConversationSession::new("test").with_user_id("user1");
        session.add_message(ChatRole::User, "Hello");

        store.save(&session).await.unwrap();

        let loaded = store.load("test").await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 1);

        let user_sessions = store.list_for_user("user1").await.unwrap();
        assert_eq!(user_sessions.len(), 1);
    }

    #[test]
    fn test_tool_message() {
        let mut session = ConversationSession::new("test");

        session.add_message(ChatRole::User, "Calculate 2+2");
        session.add_message_obj(ChatMessage::tool("4", "call_123"));
        session.add_message(ChatRole::Assistant, "The result is 4");

        let turn = session.get_turn(1).unwrap();
        assert_eq!(turn.tool_calls.len(), 1);
        assert_eq!(
            turn.tool_calls[0].tool_call_id,
            Some("call_123".to_string())
        );
    }

    #[test]
    fn test_token_estimation() {
        let msg = ChatMessage::user("Hello world"); // 11 bytes
        assert_eq!(msg.estimated_tokens(), 3); // 11/4 + 1 = 3
    }

    #[test]
    fn test_total_tokens() {
        let mut session = ConversationSession::new("test");
        session.add_message(ChatRole::User, "Hello world"); // 11/4 + 1 = 3 tokens
        session.add_message(ChatRole::Assistant, "Hi there"); // 8/4 + 1 = 3 tokens

        assert_eq!(session.total_tokens(), 6);
    }
}