avocado_core/
session.rs

1//! High-level session management API
2//!
3//! This module provides the SessionManager - a high-level API for conversation
4//! management that integrates with the compiler and database layers.
5//!
6//! # Features
7//!
8//! - Create and manage conversation sessions
9//! - Add user and assistant messages
10//! - Automatic context compilation for user queries
11//! - Format conversation history for LLM consumption
12//! - Debug and replay conversations
13
14use crate::compiler;
15use crate::db::Database;
16use crate::index::VectorIndex;
17use crate::types::{
18    CompilerConfig, Message, MessageRole, Result, Session, WorkingSet,
19};
20use serde::{Deserialize, Serialize};
21
22/// High-level session management
23pub struct SessionManager {
24    db: Database,
25}
26
27impl SessionManager {
28    /// Create a new SessionManager
29    ///
30    /// # Arguments
31    ///
32    /// * `db` - Database instance
33    ///
34    /// # Returns
35    ///
36    /// A new SessionManager instance
37    pub fn new(db: Database) -> Self {
38        Self { db }
39    }
40
41    /// Start a new session
42    ///
43    /// # Arguments
44    ///
45    /// * `user_id` - Optional user identifier
46    ///
47    /// # Returns
48    ///
49    /// The newly created session
50    pub fn start_session(&self, user_id: Option<&str>) -> Result<Session> {
51        self.db.create_session(user_id, None)
52    }
53
54    /// Add a user message and compile context
55    ///
56    /// This method:
57    /// 1. Adds the user message to the database
58    /// 2. Calls the compiler to generate a WorkingSet from the query
59    /// 3. Associates the WorkingSet with the session
60    /// 4. Returns both the Message and WorkingSet
61    ///
62    /// # Arguments
63    ///
64    /// * `session_id` - The session ID
65    /// * `query` - The user's query
66    /// * `config` - Compiler configuration
67    /// * `index` - Vector index for search
68    /// * `api_key` - Optional OpenAI API key (for embeddings)
69    ///
70    /// # Returns
71    ///
72    /// Tuple of (Message, WorkingSet)
73    pub async fn add_user_message(
74        &self,
75        session_id: &str,
76        query: &str,
77        config: CompilerConfig,
78        index: &VectorIndex,
79        api_key: Option<&str>,
80    ) -> Result<(Message, WorkingSet)> {
81        // Add the message to the database
82        let message = self
83            .db
84            .add_message(session_id, MessageRole::User, query, None)?;
85
86        // Compile the context
87        let working_set = compiler::compile(query, config.clone(), &self.db, index, api_key).await?;
88
89        // Associate the working set with the session
90        self.db.associate_working_set(
91            session_id,
92            Some(&message.id),
93            &working_set,
94            query,
95            &config,
96        )?;
97
98        Ok((message, working_set))
99    }
100
101    /// Add an assistant response
102    ///
103    /// # Arguments
104    ///
105    /// * `session_id` - The session ID
106    /// * `content` - The assistant's response
107    /// * `metadata` - Optional metadata (e.g., model info, citations)
108    ///
109    /// # Returns
110    ///
111    /// The newly created message
112    pub fn add_assistant_message(
113        &self,
114        session_id: &str,
115        content: &str,
116        metadata: Option<&serde_json::Value>,
117    ) -> Result<Message> {
118        self.db
119            .add_message(session_id, MessageRole::Assistant, content, metadata)
120    }
121
122    /// Get conversation history formatted for LLM consumption
123    ///
124    /// Formats messages as:
125    /// ```text
126    /// User: <message>
127    ///
128    /// Assistant: <message>
129    ///
130    /// User: <message>
131    /// ...
132    /// ```
133    ///
134    /// If `max_tokens` is specified, older messages are truncated to stay within
135    /// the token budget. Most recent messages are always kept (they're most relevant).
136    ///
137    /// # Arguments
138    ///
139    /// * `session_id` - The session ID
140    /// * `max_tokens` - Optional token limit
141    ///
142    /// # Returns
143    ///
144    /// Formatted conversation history as a string
145    pub fn get_conversation_history(
146        &self,
147        session_id: &str,
148        max_tokens: Option<usize>,
149    ) -> Result<String> {
150        let messages = self.db.get_messages(session_id, None)?;
151
152        if messages.is_empty() {
153            return Ok(String::new());
154        }
155
156        // Format all messages first
157        let formatted_messages: Vec<String> = messages
158            .iter()
159            .map(|msg| {
160                let role = match msg.role {
161                    MessageRole::User => "User",
162                    MessageRole::Assistant => "Assistant",
163                    MessageRole::System => "System",
164                    MessageRole::Tool => "Tool",
165                };
166                format!("{}: {}", role, msg.content)
167            })
168            .collect();
169
170        // If no token limit, return all messages
171        if max_tokens.is_none() {
172            return Ok(formatted_messages.join("\n\n"));
173        }
174
175        let max_tokens = max_tokens.unwrap();
176
177        // Apply token limiting - keep most recent messages
178        // Token counting: simple approximation (chars / 4)
179        let mut selected_messages = Vec::new();
180        let mut total_tokens = 0;
181
182        // Iterate from most recent to oldest
183        for msg in formatted_messages.iter().rev() {
184            let msg_tokens = estimate_tokens(msg);
185
186            if total_tokens + msg_tokens <= max_tokens {
187                selected_messages.push(msg.clone());
188                total_tokens += msg_tokens;
189            } else {
190                // Can't fit any more messages
191                break;
192            }
193        }
194
195        // Reverse to restore chronological order
196        selected_messages.reverse();
197
198        Ok(selected_messages.join("\n\n"))
199    }
200
201    /// Replay a session for debugging
202    ///
203    /// Groups messages into conversation turns (user + assistant pairs)
204    /// and includes associated working sets for analysis.
205    ///
206    /// # Arguments
207    ///
208    /// * `session_id` - The session ID
209    ///
210    /// # Returns
211    ///
212    /// SessionReplay with structured debug data
213    pub fn replay_session(&self, session_id: &str) -> Result<SessionReplay> {
214        let session_data = self.db.get_session_full(session_id)?;
215
216        if session_data.is_none() {
217            return Err(crate::types::Error::NotFound(format!(
218                "Session not found: {}",
219                session_id
220            )));
221        }
222
223        let session_data = session_data.unwrap();
224        let session = session_data.session;
225        let messages = session_data.messages;
226        let working_sets = session_data.working_sets;
227
228        // Build a map of message_id -> working_set for quick lookup
229        let mut working_set_map = std::collections::HashMap::new();
230        for ws in working_sets {
231            if let Some(msg_id) = &ws.message_id {
232                working_set_map.insert(msg_id.clone(), ws.working_set);
233            }
234        }
235
236        // Group messages into turns
237        let mut turns = Vec::new();
238        let mut i = 0;
239
240        while i < messages.len() {
241            let msg = &messages[i];
242
243            // Only create turns for user messages
244            if matches!(msg.role, MessageRole::User) {
245                let user_message = msg.clone();
246                let working_set = working_set_map.get(&user_message.id).cloned();
247
248                // Look for the next assistant message (if any)
249                let assistant_message = if i + 1 < messages.len()
250                    && matches!(messages[i + 1].role, MessageRole::Assistant)
251                {
252                    i += 1; // Skip the assistant message in the next iteration
253                    Some(messages[i].clone())
254                } else {
255                    None
256                };
257
258                turns.push(SessionTurn {
259                    user_message,
260                    working_set,
261                    assistant_message,
262                });
263            }
264
265            i += 1;
266        }
267
268        Ok(SessionReplay { session, turns })
269    }
270}
271
272/// Replay data for debugging
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct SessionReplay {
275    /// The session
276    pub session: Session,
277    /// Conversation turns (user + assistant pairs)
278    pub turns: Vec<SessionTurn>,
279}
280
281/// A conversation turn (user query + optional assistant response)
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct SessionTurn {
284    /// User message
285    pub user_message: Message,
286    /// Working set compiled for this user message (if any)
287    pub working_set: Option<WorkingSet>,
288    /// Assistant response (if any)
289    pub assistant_message: Option<Message>,
290}
291
/// Cheap token-count estimate: one token per four bytes, rounded up.
///
/// The length used is `str::len`, i.e. UTF-8 bytes rather than characters, so
/// non-ASCII text counts more per character. An empty string estimates to 0.
///
/// For production, consider using tiktoken-rs for accurate counting.
fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    // Whole groups of four, plus one more for any leftover bytes.
    bytes / 4 + usize::from(bytes % 4 != 0)
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Artifact;
    use crate::types::Span;
    use uuid::Uuid;

    /// Inserts one artifact and a single one-line span covering its whole
    /// content, so the compiler has something to retrieve.
    ///
    /// The span carries a constant fake embedding (`0.1` repeated 384 times)
    /// and the embedding model name `"test"`.
    fn seed_document(
        db: &Database,
        path: &str,
        content: &str,
        content_hash: &str,
        token_count: usize,
    ) {
        let artifact = Artifact {
            id: Uuid::new_v4().to_string(),
            path: path.to_string(),
            content: content.to_string(),
            content_hash: content_hash.to_string(),
            metadata: None,
            created_at: chrono::Utc::now(),
        };

        db.insert_artifact(&artifact).unwrap();

        let span = Span {
            id: Uuid::new_v4().to_string(),
            artifact_id: artifact.id.clone(),
            start_line: 1,
            end_line: 1,
            text: content.to_string(),
            embedding: Some(vec![0.1; 384]), // Fake embedding
            embedding_model: Some("test".to_string()),
            token_count,
            metadata: None,
        };

        db.insert_spans(&[span]).unwrap();
    }

    #[test]
    fn test_session_manager_new() {
        let db = Database::new(":memory:").unwrap();
        let _manager = SessionManager::new(db);
    }

    #[test]
    fn test_start_session() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("test_user")).unwrap();

        assert!(!session.id.is_empty());
        assert_eq!(session.user_id, Some("test_user".to_string()));
    }

    #[tokio::test]
    async fn test_add_user_message() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        seed_document(
            &db,
            "test.txt",
            "This is a test document about Rust programming.",
            "hash123",
            10,
        );
        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        let (message, working_set) = manager
            .add_user_message(&session.id, "What is Rust?", config, &index, None)
            .await
            .unwrap();

        assert_eq!(message.content, "What is Rust?");
        assert_eq!(message.role.as_str(), "user");
        assert!(!working_set.text.is_empty());
    }

    #[test]
    fn test_add_assistant_message() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("user1")).unwrap();

        let message = manager
            .add_assistant_message(&session.id, "Rust is a systems programming language.", None)
            .unwrap();

        assert_eq!(message.content, "Rust is a systems programming language.");
        assert_eq!(message.role.as_str(), "assistant");
    }

    #[test]
    fn test_get_conversation_history() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        db.add_message(&session.id, MessageRole::User, "Hello", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Hi there!", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::User, "How are you?", None)
            .unwrap();

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert!(history.contains("User: Hello"));
        assert!(history.contains("Assistant: Hi there!"));
        assert!(history.contains("User: How are you?"));

        // Messages are separated by a blank line and kept in stored order.
        let lines: Vec<&str> = history.split("\n\n").collect();
        assert_eq!(lines.len(), 3);
        assert_eq!(lines[0], "User: Hello");
        assert_eq!(lines[1], "Assistant: Hi there!");
        assert_eq!(lines[2], "User: How are you?");
    }

    #[test]
    fn test_get_conversation_history_with_token_limit() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        db.add_message(&session.id, MessageRole::User, "Message 1", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Response 1", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::User, "Message 2", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Response 2", None)
            .unwrap();

        // Estimated cost per rendered message (bytes / 4, rounded up):
        //   "User: Message N"       -> 15 bytes -> 4 tokens
        //   "Assistant: Response N" -> 21 bytes -> 6 tokens
        // so the whole conversation costs exactly 20 tokens.

        // The budget is inclusive: 20 tokens keeps all four messages.
        let everything = manager
            .get_conversation_history(&session.id, Some(20))
            .unwrap();
        assert_eq!(everything.split("\n\n").count(), 4);

        // 10 tokens fits only the last two messages (4 + 6); adding
        // "Assistant: Response 1" would need 16.
        let history = manager
            .get_conversation_history(&session.id, Some(10))
            .unwrap();
        assert_eq!(history, "User: Message 2\n\nAssistant: Response 2");
        assert!(!history.contains("Message 1"));
        assert!(!history.contains("Response 1"));

        // A budget too small for even the newest message yields nothing.
        let nothing = manager
            .get_conversation_history(&session.id, Some(0))
            .unwrap();
        assert_eq!(nothing, "");
    }

    #[test]
    fn test_get_conversation_history_empty() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("user1")).unwrap();

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert_eq!(history, "");
    }

    #[tokio::test]
    async fn test_replay_session() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        seed_document(&db, "test.txt", "Test content for replay.", "hash123", 5);
        let index = db.get_vector_index().unwrap();

        // Two complete user/assistant turns.
        let config = CompilerConfig::default();
        manager
            .add_user_message(&session.id, "First query", config.clone(), &index, None)
            .await
            .unwrap();
        manager
            .add_assistant_message(&session.id, "First response", None)
            .unwrap();
        manager
            .add_user_message(&session.id, "Second query", config, &index, None)
            .await
            .unwrap();
        manager
            .add_assistant_message(&session.id, "Second response", None)
            .unwrap();

        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.session.id, session.id);
        assert_eq!(replay.turns.len(), 2);

        let turn1 = &replay.turns[0];
        assert_eq!(turn1.user_message.content, "First query");
        assert!(turn1.working_set.is_some());
        assert!(turn1.assistant_message.is_some());
        assert_eq!(
            turn1.assistant_message.as_ref().unwrap().content,
            "First response"
        );

        let turn2 = &replay.turns[1];
        assert_eq!(turn2.user_message.content, "Second query");
        assert!(turn2.working_set.is_some());
        assert!(turn2.assistant_message.is_some());
        assert_eq!(
            turn2.assistant_message.as_ref().unwrap().content,
            "Second response"
        );
    }

    #[test]
    fn test_replay_session_not_found() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let result = manager.replay_session("nonexistent-id");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_replay_session_incomplete_turns() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        seed_document(&db, "test.txt", "Test content.", "hash123", 3);
        let index = db.get_vector_index().unwrap();

        // A user message with no assistant response after it.
        let config = CompilerConfig::default();
        manager
            .add_user_message(&session.id, "Query without response", config, &index, None)
            .await
            .unwrap();

        // Replay should still produce the turn, just without a reply.
        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.turns.len(), 1);
        let turn = &replay.turns[0];
        assert_eq!(turn.user_message.content, "Query without response");
        assert!(turn.working_set.is_some());
        assert!(turn.assistant_message.is_none());
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "Hello world";
        let tokens = estimate_tokens(text);
        // "Hello world" = 11 bytes, so (11 + 3) / 4 = 3 tokens
        assert_eq!(tokens, 3);

        let longer_text = "This is a longer piece of text for testing token estimation.";
        let tokens = estimate_tokens(longer_text);
        // Should be roughly bytes / 4
        assert!(tokens > 10);
        assert!(tokens < 20);
    }

    /// Integration test demonstrating the full SessionManager workflow
    #[tokio::test]
    async fn test_full_session_workflow() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        // Ingest some test documents
        let docs = vec![
            ("rust_basics.md", "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety."),
            ("rust_ownership.md", "Ownership is Rust's most unique feature. It enables Rust to make memory safety guarantees without needing a garbage collector."),
            ("rust_concurrency.md", "Rust's type system and ownership model guarantee thread safety. You can't have data races in safe Rust code."),
        ];

        for (path, content) in docs {
            seed_document(
                &db,
                path,
                content,
                &format!("hash_{}", path),
                content.split_whitespace().count(),
            );
        }

        let index = db.get_vector_index().unwrap();

        let session = manager.start_session(Some("alice")).unwrap();
        assert_eq!(session.user_id, Some("alice".to_string()));

        // First turn: User asks about Rust
        let config = CompilerConfig::default();
        let (msg1, ws1) = manager
            .add_user_message(&session.id, "What is Rust?", config.clone(), &index, None)
            .await
            .unwrap();

        assert_eq!(msg1.content, "What is Rust?");
        assert!(!ws1.text.is_empty());
        assert!(!ws1.citations.is_empty());

        let resp1 = manager
            .add_assistant_message(
                &session.id,
                "Rust is a systems programming language known for memory safety.",
                None,
            )
            .unwrap();

        assert!(resp1.content.contains("memory safety"));

        // Second turn: User asks follow-up
        let (msg2, ws2) = manager
            .add_user_message(
                &session.id,
                "Tell me about ownership",
                config.clone(),
                &index,
                None,
            )
            .await
            .unwrap();

        assert_eq!(msg2.content, "Tell me about ownership");
        assert!(!ws2.text.is_empty());

        let resp2 = manager
            .add_assistant_message(
                &session.id,
                "Ownership is Rust's unique feature for memory management.",
                None,
            )
            .unwrap();

        assert!(resp2.content.contains("Ownership"));

        // Unlimited history contains every message.
        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert!(history.contains("What is Rust?"));
        assert!(history.contains("memory safety"));
        assert!(history.contains("Tell me about ownership"));
        assert!(history.contains("Ownership is Rust's unique feature"));

        // A 100-token budget must still keep the most recent messages.
        let limited_history = manager
            .get_conversation_history(&session.id, Some(100))
            .unwrap();

        assert!(!limited_history.is_empty());
        assert!(limited_history.contains("Ownership"));

        // Replay the session
        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.session.id, session.id);
        assert_eq!(replay.turns.len(), 2);

        let turn1 = &replay.turns[0];
        assert_eq!(turn1.user_message.content, "What is Rust?");
        assert!(turn1.working_set.is_some());
        assert!(turn1.assistant_message.is_some());

        // NOTE(review): per the original author's note, working sets returned
        // by replay may be placeholders because the database layer does not
        // store full WorkingSet data yet (see db.rs), so only their presence
        // is asserted here - confirm against db.rs.

        let turn2 = &replay.turns[1];
        assert_eq!(turn2.user_message.content, "Tell me about ownership");
        assert!(turn2.working_set.is_some());
        assert!(turn2.assistant_message.is_some());

        // Turns come back in message order.
        assert!(turn1.user_message.sequence_number < turn2.user_message.sequence_number);
    }
}