1use crate::compiler;
15use crate::db::Database;
16use crate::index::VectorIndex;
17use crate::types::{
18 CompilerConfig, Message, MessageRole, Result, Session, WorkingSet,
19};
20use serde::{Deserialize, Serialize};
21
22pub struct SessionManager {
24 db: Database,
25}
26
27impl SessionManager {
28 pub fn new(db: Database) -> Self {
38 Self { db }
39 }
40
41 pub fn start_session(&self, user_id: Option<&str>) -> Result<Session> {
51 self.db.create_session(user_id, None)
52 }
53
54 pub async fn add_user_message(
74 &self,
75 session_id: &str,
76 query: &str,
77 config: CompilerConfig,
78 index: &VectorIndex,
79 api_key: Option<&str>,
80 ) -> Result<(Message, WorkingSet)> {
81 let message = self
83 .db
84 .add_message(session_id, MessageRole::User, query, None)?;
85
86 let working_set = compiler::compile(query, config.clone(), &self.db, index, api_key).await?;
88
89 self.db.associate_working_set(
91 session_id,
92 Some(&message.id),
93 &working_set,
94 query,
95 &config,
96 )?;
97
98 Ok((message, working_set))
99 }
100
101 pub fn add_assistant_message(
113 &self,
114 session_id: &str,
115 content: &str,
116 metadata: Option<&serde_json::Value>,
117 ) -> Result<Message> {
118 self.db
119 .add_message(session_id, MessageRole::Assistant, content, metadata)
120 }
121
122 pub fn get_conversation_history(
146 &self,
147 session_id: &str,
148 max_tokens: Option<usize>,
149 ) -> Result<String> {
150 let messages = self.db.get_messages(session_id, None)?;
151
152 if messages.is_empty() {
153 return Ok(String::new());
154 }
155
156 let formatted_messages: Vec<String> = messages
158 .iter()
159 .map(|msg| {
160 let role = match msg.role {
161 MessageRole::User => "User",
162 MessageRole::Assistant => "Assistant",
163 MessageRole::System => "System",
164 MessageRole::Tool => "Tool",
165 };
166 format!("{}: {}", role, msg.content)
167 })
168 .collect();
169
170 if max_tokens.is_none() {
172 return Ok(formatted_messages.join("\n\n"));
173 }
174
175 let max_tokens = max_tokens.unwrap();
176
177 let mut selected_messages = Vec::new();
180 let mut total_tokens = 0;
181
182 for msg in formatted_messages.iter().rev() {
184 let msg_tokens = estimate_tokens(msg);
185
186 if total_tokens + msg_tokens <= max_tokens {
187 selected_messages.push(msg.clone());
188 total_tokens += msg_tokens;
189 } else {
190 break;
192 }
193 }
194
195 selected_messages.reverse();
197
198 Ok(selected_messages.join("\n\n"))
199 }
200
201 pub fn replay_session(&self, session_id: &str) -> Result<SessionReplay> {
214 let session_data = self.db.get_session_full(session_id)?;
215
216 if session_data.is_none() {
217 return Err(crate::types::Error::NotFound(format!(
218 "Session not found: {}",
219 session_id
220 )));
221 }
222
223 let session_data = session_data.unwrap();
224 let session = session_data.session;
225 let messages = session_data.messages;
226 let working_sets = session_data.working_sets;
227
228 let mut working_set_map = std::collections::HashMap::new();
230 for ws in working_sets {
231 if let Some(msg_id) = &ws.message_id {
232 working_set_map.insert(msg_id.clone(), ws.working_set);
233 }
234 }
235
236 let mut turns = Vec::new();
238 let mut i = 0;
239
240 while i < messages.len() {
241 let msg = &messages[i];
242
243 if matches!(msg.role, MessageRole::User) {
245 let user_message = msg.clone();
246 let working_set = working_set_map.get(&user_message.id).cloned();
247
248 let assistant_message = if i + 1 < messages.len()
250 && matches!(messages[i + 1].role, MessageRole::Assistant)
251 {
252 i += 1; Some(messages[i].clone())
254 } else {
255 None
256 };
257
258 turns.push(SessionTurn {
259 user_message,
260 working_set,
261 assistant_message,
262 });
263 }
264
265 i += 1;
266 }
267
268 Ok(SessionReplay { session, turns })
269 }
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct SessionReplay {
275 pub session: Session,
277 pub turns: Vec<SessionTurn>,
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct SessionTurn {
284 pub user_message: Message,
286 pub working_set: Option<WorkingSet>,
288 pub assistant_message: Option<Message>,
290}
291
/// Roughly estimates how many tokens `text` occupies, assuming about four
/// bytes per token.
///
/// The length is measured in UTF-8 bytes (not characters) and rounded up, so
/// any non-empty string costs at least one token and `""` costs zero.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    let whole_chunks = byte_len / BYTES_PER_TOKEN;
    let partial_chunk = usize::from(byte_len % BYTES_PER_TOKEN != 0);
    whole_chunks + partial_chunk
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Artifact;
    use crate::types::Span;
    use uuid::Uuid;

    /// Inserts one artifact holding `content` plus a single one-line span that
    /// covers it. The span gets a fixed dummy embedding, which gives the
    /// compiler something to retrieve.
    fn seed_document(db: &Database, path: &str, content: &str, token_count: usize) {
        let artifact = Artifact {
            id: Uuid::new_v4().to_string(),
            path: path.to_string(),
            content: content.to_string(),
            content_hash: format!("hash_{}", path),
            metadata: None,
            created_at: chrono::Utc::now(),
        };

        db.insert_artifact(&artifact).unwrap();

        let span = Span {
            id: Uuid::new_v4().to_string(),
            artifact_id: artifact.id.clone(),
            start_line: 1,
            end_line: 1,
            text: content.to_string(),
            embedding: Some(vec![0.1; 384]),
            embedding_model: Some("test".to_string()),
            token_count,
            metadata: None,
        };

        db.insert_spans(&[span]).unwrap();
    }

    #[test]
    fn test_session_manager_new() {
        let db = Database::new(":memory:").unwrap();
        let _manager = SessionManager::new(db);
    }

    #[test]
    fn test_start_session() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("test_user")).unwrap();

        assert!(!session.id.is_empty());
        assert_eq!(session.user_id, Some("test_user".to_string()));
    }

    #[tokio::test]
    async fn test_add_user_message() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        seed_document(
            &db,
            "test.txt",
            "This is a test document about Rust programming.",
            10,
        );

        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        let (message, working_set) = manager
            .add_user_message(&session.id, "What is Rust?", config, &index, None)
            .await
            .unwrap();

        assert_eq!(message.content, "What is Rust?");
        assert_eq!(message.role.as_str(), "user");
        assert!(!working_set.text.is_empty());
    }

    #[test]
    fn test_add_assistant_message() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("user1")).unwrap();

        let message = manager
            .add_assistant_message(&session.id, "Rust is a systems programming language.", None)
            .unwrap();

        assert_eq!(message.content, "Rust is a systems programming language.");
        assert_eq!(message.role.as_str(), "assistant");
    }

    #[test]
    fn test_get_conversation_history() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        db.add_message(&session.id, MessageRole::User, "Hello", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Hi there!", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::User, "How are you?", None)
            .unwrap();

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert!(history.contains("User: Hello"));
        assert!(history.contains("Assistant: Hi there!"));
        assert!(history.contains("User: How are you?"));

        let entries: Vec<&str> = history.split("\n\n").collect();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0], "User: Hello");
        assert_eq!(entries[1], "Assistant: Hi there!");
        assert_eq!(entries[2], "User: How are you?");
    }

    #[test]
    fn test_get_conversation_history_with_token_limit() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        db.add_message(&session.id, MessageRole::User, "Message 1", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Response 1", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::User, "Message 2", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Response 2", None)
            .unwrap();

        // Estimated costs: "User: Message N" = 4 tokens and
        // "Assistant: Response N" = 6 tokens, so the whole history costs 20.
        let history = manager
            .get_conversation_history(&session.id, Some(20))
            .unwrap();
        assert_eq!(
            history,
            "User: Message 1\n\nAssistant: Response 1\n\nUser: Message 2\n\nAssistant: Response 2"
        );

        // A budget of 10 keeps only the most recent exchange (6 + 4).
        let history = manager
            .get_conversation_history(&session.id, Some(10))
            .unwrap();
        assert_eq!(history, "User: Message 2\n\nAssistant: Response 2");

        // With 9, adding "User: Message 2" would reach 10, so only the newest survives.
        let history = manager
            .get_conversation_history(&session.id, Some(9))
            .unwrap();
        assert_eq!(history, "Assistant: Response 2");

        // If even the newest message does not fit, nothing is returned.
        let history = manager
            .get_conversation_history(&session.id, Some(5))
            .unwrap();
        assert_eq!(history, "");
    }

    #[test]
    fn test_get_conversation_history_empty() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("user1")).unwrap();

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert_eq!(history, "");
    }

    #[tokio::test]
    async fn test_replay_session() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        seed_document(&db, "test.txt", "Test content for replay.", 5);

        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        manager
            .add_user_message(&session.id, "First query", config.clone(), &index, None)
            .await
            .unwrap();
        manager
            .add_assistant_message(&session.id, "First response", None)
            .unwrap();
        manager
            .add_user_message(&session.id, "Second query", config, &index, None)
            .await
            .unwrap();
        manager
            .add_assistant_message(&session.id, "Second response", None)
            .unwrap();

        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.session.id, session.id);
        assert_eq!(replay.turns.len(), 2);

        let turn1 = &replay.turns[0];
        assert_eq!(turn1.user_message.content, "First query");
        assert!(turn1.working_set.is_some());
        assert!(turn1.assistant_message.is_some());
        assert_eq!(
            turn1.assistant_message.as_ref().unwrap().content,
            "First response"
        );

        let turn2 = &replay.turns[1];
        assert_eq!(turn2.user_message.content, "Second query");
        assert!(turn2.working_set.is_some());
        assert!(turn2.assistant_message.is_some());
        assert_eq!(
            turn2.assistant_message.as_ref().unwrap().content,
            "Second response"
        );
    }

    #[test]
    fn test_replay_session_not_found() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let result = manager.replay_session("nonexistent-id");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_replay_session_incomplete_turns() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        seed_document(&db, "test.txt", "Test content.", 3);

        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        manager
            .add_user_message(&session.id, "Query without response", config, &index, None)
            .await
            .unwrap();

        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.turns.len(), 1);
        let turn = &replay.turns[0];
        assert_eq!(turn.user_message.content, "Query without response");
        assert!(turn.working_set.is_some());
        assert!(turn.assistant_message.is_none());
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);

        // 11 bytes rounds up to 3 tokens.
        let text = "Hello world";
        let tokens = estimate_tokens(text);
        assert_eq!(tokens, 3);

        let longer_text = "This is a longer piece of text for testing token estimation.";
        let tokens = estimate_tokens(longer_text);
        assert!(tokens > 10);
        assert!(tokens < 20);
    }

    /// End-to-end: index documents, hold a two-turn conversation, then check
    /// history (full and budgeted) and replay.
    #[tokio::test]
    async fn test_full_session_workflow() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let docs = [
            ("rust_basics.md", "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety."),
            ("rust_ownership.md", "Ownership is Rust's most unique feature. It enables Rust to make memory safety guarantees without needing a garbage collector."),
            ("rust_concurrency.md", "Rust's type system and ownership model guarantee thread safety. You can't have data races in safe Rust code."),
        ];

        for (path, content) in &docs {
            seed_document(&db, path, content, content.split_whitespace().count());
        }

        let index = db.get_vector_index().unwrap();

        let session = manager.start_session(Some("alice")).unwrap();
        assert_eq!(session.user_id, Some("alice".to_string()));

        let config = CompilerConfig::default();
        let (msg1, ws1) = manager
            .add_user_message(&session.id, "What is Rust?", config.clone(), &index, None)
            .await
            .unwrap();

        assert_eq!(msg1.content, "What is Rust?");
        assert!(!ws1.text.is_empty());
        assert!(!ws1.citations.is_empty());

        let resp1 = manager
            .add_assistant_message(
                &session.id,
                "Rust is a systems programming language known for memory safety.",
                None,
            )
            .unwrap();

        assert!(resp1.content.contains("memory safety"));

        let (msg2, ws2) = manager
            .add_user_message(
                &session.id,
                "Tell me about ownership",
                config.clone(),
                &index,
                None,
            )
            .await
            .unwrap();

        assert_eq!(msg2.content, "Tell me about ownership");
        assert!(!ws2.text.is_empty());

        let resp2 = manager
            .add_assistant_message(
                &session.id,
                "Ownership is Rust's unique feature for memory management.",
                None,
            )
            .unwrap();

        assert!(resp2.content.contains("Ownership"));

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert!(history.contains("What is Rust?"));
        assert!(history.contains("memory safety"));
        assert!(history.contains("Tell me about ownership"));
        assert!(history.contains("Ownership is Rust's unique feature"));

        let limited_history = manager
            .get_conversation_history(&session.id, Some(100))
            .unwrap();

        assert!(!limited_history.is_empty());
        assert!(limited_history.contains("Ownership"));

        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.session.id, session.id);
        assert_eq!(replay.turns.len(), 2);

        let turn1 = &replay.turns[0];
        assert_eq!(turn1.user_message.content, "What is Rust?");
        assert!(turn1.working_set.is_some());
        assert!(turn1.assistant_message.is_some());

        let turn2 = &replay.turns[1];
        assert_eq!(turn2.user_message.content, "Tell me about ownership");
        assert!(turn2.working_set.is_some());
        assert!(turn2.assistant_message.is_some());

        assert!(turn1.user_message.sequence_number < turn2.user_message.sequence_number);
    }
}