1use crate::compiler;
17use crate::db::Database;
18use crate::index::VectorIndex;
19use crate::storage::StorageBackend;
20use crate::types::{
21 CompilerConfig, Message, MessageRole, Result, Session, WorkingSet,
22};
23use serde::{Deserialize, Serialize};
24use std::sync::Arc;
25
/// Manages conversation sessions (messages and their compiled working sets)
/// stored in a concrete [`Database`].
pub struct SessionManager {
    /// Store used for sessions, messages and working-set associations.
    db: Database,
}
30
31impl SessionManager {
32 pub fn new(db: Database) -> Self {
42 Self { db }
43 }
44
45 pub fn start_session(&self, user_id: Option<&str>) -> Result<Session> {
55 self.db.create_session(user_id, None)
56 }
57
58 pub async fn add_user_message(
78 &self,
79 session_id: &str,
80 query: &str,
81 config: CompilerConfig,
82 index: &VectorIndex,
83 api_key: Option<&str>,
84 ) -> Result<(Message, WorkingSet)> {
85 let message = self
87 .db
88 .add_message(session_id, MessageRole::User, query, None)?;
89
90 let working_set = compiler::compile(query, config.clone(), &self.db, index, api_key).await?;
92
93 self.db.associate_working_set(
95 session_id,
96 Some(&message.id),
97 &working_set,
98 query,
99 &config,
100 )?;
101
102 Ok((message, working_set))
103 }
104
105 pub fn add_assistant_message(
117 &self,
118 session_id: &str,
119 content: &str,
120 metadata: Option<&serde_json::Value>,
121 ) -> Result<Message> {
122 self.db
123 .add_message(session_id, MessageRole::Assistant, content, metadata)
124 }
125
126 pub fn get_conversation_history(
150 &self,
151 session_id: &str,
152 max_tokens: Option<usize>,
153 ) -> Result<String> {
154 let messages = self.db.get_messages(session_id, None)?;
155
156 if messages.is_empty() {
157 return Ok(String::new());
158 }
159
160 let formatted_messages: Vec<String> = messages
162 .iter()
163 .map(|msg| {
164 let role = match msg.role {
165 MessageRole::User => "User",
166 MessageRole::Assistant => "Assistant",
167 MessageRole::System => "System",
168 MessageRole::Tool => "Tool",
169 };
170 format!("{}: {}", role, msg.content)
171 })
172 .collect();
173
174 if max_tokens.is_none() {
176 return Ok(formatted_messages.join("\n\n"));
177 }
178
179 let max_tokens = max_tokens.unwrap();
180
181 let mut selected_messages = Vec::new();
184 let mut total_tokens = 0;
185
186 for msg in formatted_messages.iter().rev() {
188 let msg_tokens = estimate_tokens(msg);
189
190 if total_tokens + msg_tokens <= max_tokens {
191 selected_messages.push(msg.clone());
192 total_tokens += msg_tokens;
193 } else {
194 break;
196 }
197 }
198
199 selected_messages.reverse();
201
202 Ok(selected_messages.join("\n\n"))
203 }
204
205 pub fn replay_session(&self, session_id: &str) -> Result<SessionReplay> {
218 let session_data = self.db.get_session_full(session_id)?;
219
220 if session_data.is_none() {
221 return Err(crate::types::Error::NotFound(format!(
222 "Session not found: {}",
223 session_id
224 )));
225 }
226
227 let session_data = session_data.unwrap();
228 let session = session_data.session;
229 let messages = session_data.messages;
230 let working_sets = session_data.working_sets;
231
232 let mut working_set_map = std::collections::HashMap::new();
234 for ws in working_sets {
235 if let Some(msg_id) = &ws.message_id {
236 working_set_map.insert(msg_id.clone(), ws.working_set);
237 }
238 }
239
240 let mut turns = Vec::new();
242 let mut i = 0;
243
244 while i < messages.len() {
245 let msg = &messages[i];
246
247 if matches!(msg.role, MessageRole::User) {
249 let user_message = msg.clone();
250 let working_set = working_set_map.get(&user_message.id).cloned();
251
252 let assistant_message = if i + 1 < messages.len()
254 && matches!(messages[i + 1].role, MessageRole::Assistant)
255 {
256 i += 1; Some(messages[i].clone())
258 } else {
259 None
260 };
261
262 turns.push(SessionTurn {
263 user_message,
264 working_set,
265 assistant_message,
266 });
267 }
268
269 i += 1;
270 }
271
272 Ok(SessionReplay { session, turns })
273 }
274}
275
/// Session manager that works over any [`StorageBackend`].
///
/// The backend is held behind an `Arc`, so an already-shared backend can be
/// passed in via `SessionManagerGeneric::from_arc`.
pub struct SessionManagerGeneric<B: StorageBackend> {
    /// Shared handle to the storage backend all operations are delegated to.
    backend: Arc<B>,
}
298
299impl<B: StorageBackend> SessionManagerGeneric<B> {
300 pub fn new(backend: B) -> Self {
310 Self {
311 backend: Arc::new(backend),
312 }
313 }
314
315 pub fn from_arc(backend: Arc<B>) -> Self {
317 Self { backend }
318 }
319
320 pub fn backend(&self) -> &B {
322 &self.backend
323 }
324
325 pub async fn start_session(&self, user_id: Option<&str>) -> Result<Session> {
335 self.backend.create_session(user_id, None).await
336 }
337
338 pub async fn add_user_message(
357 &self,
358 session_id: &str,
359 query: &str,
360 config: CompilerConfig,
361 api_key: Option<&str>,
362 ) -> Result<(Message, WorkingSet)> {
363 let message = self
365 .backend
366 .add_message(session_id, MessageRole::User, query, None)
367 .await?;
368
369 let working_set = compiler::compile_with_backend(
371 query,
372 config.clone(),
373 self.backend.as_ref(),
374 api_key,
375 )
376 .await?;
377
378 self.backend
380 .associate_working_set(session_id, Some(&message.id), &working_set, query, &config)
381 .await?;
382
383 Ok((message, working_set))
384 }
385
386 pub async fn add_user_message_with_explain(
388 &self,
389 session_id: &str,
390 query: &str,
391 config: CompilerConfig,
392 api_key: Option<&str>,
393 explain: bool,
394 ) -> Result<(Message, WorkingSet)> {
395 let message = self
396 .backend
397 .add_message(session_id, MessageRole::User, query, None)
398 .await?;
399
400 let working_set = compiler::compile_with_backend_options(
401 query,
402 config.clone(),
403 self.backend.as_ref(),
404 api_key,
405 explain,
406 )
407 .await?;
408
409 self.backend
410 .associate_working_set(session_id, Some(&message.id), &working_set, query, &config)
411 .await?;
412
413 Ok((message, working_set))
414 }
415
416 pub async fn add_assistant_message(
428 &self,
429 session_id: &str,
430 content: &str,
431 metadata: Option<&serde_json::Value>,
432 ) -> Result<Message> {
433 self.backend
434 .add_message(session_id, MessageRole::Assistant, content, metadata)
435 .await
436 }
437
438 pub async fn get_conversation_history(
449 &self,
450 session_id: &str,
451 max_tokens: Option<usize>,
452 ) -> Result<String> {
453 let messages = self.backend.get_messages(session_id, None).await?;
454
455 if messages.is_empty() {
456 return Ok(String::new());
457 }
458
459 let formatted_messages: Vec<String> = messages
461 .iter()
462 .map(|msg| {
463 let role = match msg.role {
464 MessageRole::User => "User",
465 MessageRole::Assistant => "Assistant",
466 MessageRole::System => "System",
467 MessageRole::Tool => "Tool",
468 };
469 format!("{}: {}", role, msg.content)
470 })
471 .collect();
472
473 if max_tokens.is_none() {
475 return Ok(formatted_messages.join("\n\n"));
476 }
477
478 let max_tokens = max_tokens.unwrap();
479
480 let mut selected_messages = Vec::new();
482 let mut total_tokens = 0;
483
484 for msg in formatted_messages.iter().rev() {
485 let msg_tokens = estimate_tokens(msg);
486
487 if total_tokens + msg_tokens <= max_tokens {
488 selected_messages.push(msg.clone());
489 total_tokens += msg_tokens;
490 } else {
491 break;
492 }
493 }
494
495 selected_messages.reverse();
496 Ok(selected_messages.join("\n\n"))
497 }
498
499 pub async fn replay_session(&self, session_id: &str) -> Result<SessionReplay> {
509 let session_data = self.backend.get_session_full(session_id).await?;
510
511 if session_data.is_none() {
512 return Err(crate::types::Error::NotFound(format!(
513 "Session not found: {}",
514 session_id
515 )));
516 }
517
518 let session_data = session_data.unwrap();
519 let session = session_data.session;
520 let messages = session_data.messages;
521 let working_sets = session_data.working_sets;
522
523 let mut working_set_map = std::collections::HashMap::new();
525 for ws in working_sets {
526 if let Some(msg_id) = &ws.message_id {
527 working_set_map.insert(msg_id.clone(), ws.working_set);
528 }
529 }
530
531 let mut turns = Vec::new();
533 let mut i = 0;
534
535 while i < messages.len() {
536 let msg = &messages[i];
537
538 if matches!(msg.role, MessageRole::User) {
539 let user_message = msg.clone();
540 let working_set = working_set_map.get(&user_message.id).cloned();
541
542 let assistant_message = if i + 1 < messages.len()
543 && matches!(messages[i + 1].role, MessageRole::Assistant)
544 {
545 i += 1;
546 Some(messages[i].clone())
547 } else {
548 None
549 };
550
551 turns.push(SessionTurn {
552 user_message,
553 working_set,
554 assistant_message,
555 });
556 }
557
558 i += 1;
559 }
560
561 Ok(SessionReplay { session, turns })
562 }
563
564 pub async fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
566 self.backend.get_session(session_id).await
567 }
568
569 pub async fn list_sessions(
571 &self,
572 user_id: Option<&str>,
573 limit: Option<usize>,
574 ) -> Result<Vec<Session>> {
575 self.backend.list_sessions(user_id, limit).await
576 }
577
578 pub async fn delete_session(&self, session_id: &str) -> Result<()> {
580 self.backend.delete_session(session_id).await
581 }
582}
583
/// A session reconstructed as user/assistant turns, as produced by
/// `replay_session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionReplay {
    /// The session record itself.
    pub session: Session,
    /// Turns in the order their user messages were returned by storage.
    pub turns: Vec<SessionTurn>,
}
592
/// One conversational turn: a user message plus what was recorded for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// The user message that opened this turn.
    pub user_message: Message,
    /// Working set linked to `user_message`, if one was associated.
    pub working_set: Option<WorkingSet>,
    /// The assistant message immediately following `user_message`, if any.
    pub assistant_message: Option<Message>,
}
603
604fn estimate_tokens(text: &str) -> usize {
610 (text.len() + 3) / 4
611}
612
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Artifact;
    use crate::types::Span;
    use uuid::Uuid;

    #[test]
    fn test_session_manager_new() {
        let db = Database::new(":memory:").unwrap();
        let _manager = SessionManager::new(db);
    }

    #[test]
    fn test_start_session() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("test_user")).unwrap();

        assert!(!session.id.is_empty());
        assert_eq!(session.user_id, Some("test_user".to_string()));
    }

    #[tokio::test]
    async fn test_add_user_message() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        // Seed one artifact/span so compilation has something to retrieve.
        let artifact = Artifact {
            id: Uuid::new_v4().to_string(),
            path: "test.txt".to_string(),
            content: "This is a test document about Rust programming.".to_string(),
            content_hash: "hash123".to_string(),
            metadata: None,
            created_at: chrono::Utc::now(),
        };

        db.insert_artifact(&artifact).unwrap();

        let span = Span {
            id: Uuid::new_v4().to_string(),
            artifact_id: artifact.id.clone(),
            start_line: 1,
            end_line: 1,
            text: "This is a test document about Rust programming.".to_string(),
            embedding: Some(vec![0.1; 384]),
            embedding_model: Some("test".to_string()),
            token_count: 10,
            metadata: None,
        };

        db.insert_spans(&[span]).unwrap();

        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        let (message, working_set) = manager
            .add_user_message(&session.id, "What is Rust?", config, &index, None)
            .await
            .unwrap();

        assert_eq!(message.content, "What is Rust?");
        assert_eq!(message.role.as_str(), "user");
        assert!(!working_set.text.is_empty());
    }

    #[test]
    fn test_add_assistant_message() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("user1")).unwrap();

        let message = manager
            .add_assistant_message(&session.id, "Rust is a systems programming language.", None)
            .unwrap();

        assert_eq!(message.content, "Rust is a systems programming language.");
        assert_eq!(message.role.as_str(), "assistant");
    }

    #[test]
    fn test_get_conversation_history() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        db.add_message(&session.id, MessageRole::User, "Hello", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Hi there!", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::User, "How are you?", None)
            .unwrap();

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert!(history.contains("User: Hello"));
        assert!(history.contains("Assistant: Hi there!"));
        assert!(history.contains("User: How are you?"));

        // Entries are separated by a blank line and kept in chronological order.
        let lines: Vec<&str> = history.split("\n\n").collect();
        assert_eq!(lines.len(), 3);
        assert_eq!(lines[0], "User: Hello");
        assert_eq!(lines[1], "Assistant: Hi there!");
        assert_eq!(lines[2], "User: How are you?");
    }

    #[test]
    fn test_get_conversation_history_with_token_limit() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        db.add_message(&session.id, MessageRole::User, "Message 1", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Response 1", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::User, "Message 2", None)
            .unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Response 2", None)
            .unwrap();

        // Estimated costs (bytes / 4, rounded up): "User: Message N" (15 bytes)
        // = 4 tokens, "Assistant: Response N" (21 bytes) = 6 tokens, so all four
        // entries together cost exactly 20 and fit a budget of 20.
        let full = manager
            .get_conversation_history(&session.id, Some(20))
            .unwrap();
        assert_eq!(full.split("\n\n").count(), 4);

        // A budget of 10 only fits the newest pair; older entries are dropped.
        let history = manager
            .get_conversation_history(&session.id, Some(10))
            .unwrap();
        assert_eq!(history, "User: Message 2\n\nAssistant: Response 2");

        // A zero budget keeps nothing.
        let empty = manager
            .get_conversation_history(&session.id, Some(0))
            .unwrap();
        assert_eq!(empty, "");
    }

    #[test]
    fn test_get_conversation_history_empty() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let session = manager.start_session(Some("user1")).unwrap();

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert_eq!(history, "");
    }

    #[tokio::test]
    async fn test_replay_session() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        let artifact = Artifact {
            id: Uuid::new_v4().to_string(),
            path: "test.txt".to_string(),
            content: "Test content for replay.".to_string(),
            content_hash: "hash123".to_string(),
            metadata: None,
            created_at: chrono::Utc::now(),
        };

        db.insert_artifact(&artifact).unwrap();

        let span = Span {
            id: Uuid::new_v4().to_string(),
            artifact_id: artifact.id.clone(),
            start_line: 1,
            end_line: 1,
            text: "Test content for replay.".to_string(),
            embedding: Some(vec![0.1; 384]),
            embedding_model: Some("test".to_string()),
            token_count: 5,
            metadata: None,
        };

        db.insert_spans(&[span]).unwrap();

        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        manager
            .add_user_message(&session.id, "First query", config.clone(), &index, None)
            .await
            .unwrap();
        manager
            .add_assistant_message(&session.id, "First response", None)
            .unwrap();
        manager
            .add_user_message(&session.id, "Second query", config, &index, None)
            .await
            .unwrap();
        manager
            .add_assistant_message(&session.id, "Second response", None)
            .unwrap();

        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.session.id, session.id);
        assert_eq!(replay.turns.len(), 2);

        let turn1 = &replay.turns[0];
        assert_eq!(turn1.user_message.content, "First query");
        assert!(turn1.working_set.is_some());
        assert!(turn1.assistant_message.is_some());
        assert_eq!(
            turn1.assistant_message.as_ref().unwrap().content,
            "First response"
        );

        let turn2 = &replay.turns[1];
        assert_eq!(turn2.user_message.content, "Second query");
        assert!(turn2.working_set.is_some());
        assert!(turn2.assistant_message.is_some());
        assert_eq!(
            turn2.assistant_message.as_ref().unwrap().content,
            "Second response"
        );
    }

    #[test]
    fn test_replay_session_not_found() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db);

        let result = manager.replay_session("nonexistent-id");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_replay_session_incomplete_turns() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let session = manager.start_session(Some("user1")).unwrap();

        let artifact = Artifact {
            id: Uuid::new_v4().to_string(),
            path: "test.txt".to_string(),
            content: "Test content.".to_string(),
            content_hash: "hash123".to_string(),
            metadata: None,
            created_at: chrono::Utc::now(),
        };

        db.insert_artifact(&artifact).unwrap();

        let span = Span {
            id: Uuid::new_v4().to_string(),
            artifact_id: artifact.id.clone(),
            start_line: 1,
            end_line: 1,
            text: "Test content.".to_string(),
            embedding: Some(vec![0.1; 384]),
            embedding_model: Some("test".to_string()),
            token_count: 3,
            metadata: None,
        };

        db.insert_spans(&[span]).unwrap();

        let index = db.get_vector_index().unwrap();

        let config = CompilerConfig::default();
        manager
            .add_user_message(&session.id, "Query without response", config, &index, None)
            .await
            .unwrap();

        // A trailing user message with no reply still forms a turn.
        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.turns.len(), 1);
        let turn = &replay.turns[0];
        assert_eq!(turn.user_message.content, "Query without response");
        assert!(turn.working_set.is_some());
        assert!(turn.assistant_message.is_none());
    }

    #[test]
    fn test_estimate_tokens() {
        // 11 bytes -> ceil(11 / 4) = 3.
        let text = "Hello world";
        let tokens = estimate_tokens(text);
        assert_eq!(tokens, 3);

        let longer_text = "This is a longer piece of text for testing token estimation.";
        let tokens = estimate_tokens(longer_text);
        assert!(tokens > 10);
        assert!(tokens < 20);
    }

    #[tokio::test]
    async fn test_full_session_workflow() {
        let db = Database::new(":memory:").unwrap();
        let manager = SessionManager::new(db.clone());

        let docs = vec![
            ("rust_basics.md", "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety."),
            ("rust_ownership.md", "Ownership is Rust's most unique feature. It enables Rust to make memory safety guarantees without needing a garbage collector."),
            ("rust_concurrency.md", "Rust's type system and ownership model guarantee thread safety. You can't have data races in safe Rust code."),
        ];

        for (path, content) in &docs {
            let artifact = Artifact {
                id: Uuid::new_v4().to_string(),
                path: path.to_string(),
                content: content.to_string(),
                content_hash: format!("hash_{}", path),
                metadata: None,
                created_at: chrono::Utc::now(),
            };

            db.insert_artifact(&artifact).unwrap();

            let span = Span {
                id: Uuid::new_v4().to_string(),
                artifact_id: artifact.id.clone(),
                start_line: 1,
                end_line: 1,
                text: content.to_string(),
                embedding: Some(vec![0.1; 384]),
                embedding_model: Some("test".to_string()),
                token_count: content.split_whitespace().count(),
                metadata: None,
            };

            db.insert_spans(&[span]).unwrap();
        }

        let index = db.get_vector_index().unwrap();

        let session = manager.start_session(Some("alice")).unwrap();
        assert_eq!(session.user_id, Some("alice".to_string()));

        let config = CompilerConfig::default();
        let (msg1, ws1) = manager
            .add_user_message(&session.id, "What is Rust?", config.clone(), &index, None)
            .await
            .unwrap();

        assert_eq!(msg1.content, "What is Rust?");
        assert!(!ws1.text.is_empty());
        assert!(!ws1.citations.is_empty());

        let resp1 = manager
            .add_assistant_message(
                &session.id,
                "Rust is a systems programming language known for memory safety.",
                None,
            )
            .unwrap();

        assert!(resp1.content.contains("memory safety"));

        let (msg2, ws2) = manager
            .add_user_message(
                &session.id,
                "Tell me about ownership",
                config.clone(),
                &index,
                None,
            )
            .await
            .unwrap();

        assert_eq!(msg2.content, "Tell me about ownership");
        assert!(!ws2.text.is_empty());

        let resp2 = manager
            .add_assistant_message(
                &session.id,
                "Ownership is Rust's unique feature for memory management.",
                None,
            )
            .unwrap();

        assert!(resp2.content.contains("Ownership"));

        let history = manager
            .get_conversation_history(&session.id, None)
            .unwrap();

        assert!(history.contains("What is Rust?"));
        assert!(history.contains("memory safety"));
        assert!(history.contains("Tell me about ownership"));
        assert!(history.contains("Ownership is Rust's unique feature"));

        let limited_history = manager
            .get_conversation_history(&session.id, Some(100))
            .unwrap();

        assert!(!limited_history.is_empty());
        assert!(limited_history.contains("Ownership"));

        let replay = manager.replay_session(&session.id).unwrap();

        assert_eq!(replay.session.id, session.id);
        assert_eq!(replay.turns.len(), 2);

        let turn1 = &replay.turns[0];
        assert_eq!(turn1.user_message.content, "What is Rust?");
        assert!(turn1.working_set.is_some());
        assert!(turn1.assistant_message.is_some());

        let turn2 = &replay.turns[1];
        assert_eq!(turn2.user_message.content, "Tell me about ownership");
        assert!(turn2.working_set.is_some());
        assert!(turn2.assistant_message.is_some());

        assert!(turn1.user_message.sequence_number < turn2.user_message.sequence_number);
    }
}