1use std::collections::HashMap;
29use std::sync::Arc;
30use std::time::{Duration, SystemTime, UNIX_EPOCH};
31
32use async_trait::async_trait;
33use parking_lot::RwLock;
34use serde::{Deserialize, Serialize};
35use tracing::{debug, info};
36
/// Identifier used to key conversation sessions (a plain owned string).
pub type ConversationId = String;
39
/// Author role of a [`ChatMessage`] within a conversation.
///
/// The lowercase name of each role is available through [`ChatRole::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChatRole {
    /// System/instruction message. `ConversationSession` treats a system message stored
    /// as the first message as the pinned prompt.
    System,
    /// Message from the end user; `ConversationSession` starts a new [`Turn`] for each one.
    User,
    /// Reply from the assistant; `ConversationSession` uses it to complete the latest [`Turn`].
    Assistant,
    /// Output of a tool call; `ConversationSession` attaches it to the latest [`Turn`].
    Tool,
}
52
53impl ChatRole {
54 pub fn as_str(&self) -> &'static str {
55 match self {
56 ChatRole::System => "system",
57 ChatRole::User => "user",
58 ChatRole::Assistant => "assistant",
59 ChatRole::Tool => "tool",
60 }
61 }
62}
63
/// A single message in a conversation: author role, text, and bookkeeping fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: ChatRole,
    /// Message text.
    pub content: String,
    /// Creation time in whole seconds since the Unix epoch
    /// (`ChatMessage::new` stores 0 if the system clock reads earlier than the epoch).
    pub timestamp: u64,
    /// Optional author name, set via `with_name`.
    pub name: Option<String>,
    /// Tool-call identifier; set by the `ChatMessage::tool` constructor.
    pub tool_call_id: Option<String>,
    /// Free-form string annotations, added via `with_metadata`.
    pub metadata: HashMap<String, String>,
}
80
81impl ChatMessage {
82 pub fn new(role: ChatRole, content: impl Into<String>) -> Self {
83 Self {
84 role,
85 content: content.into(),
86 timestamp: SystemTime::now()
87 .duration_since(UNIX_EPOCH)
88 .unwrap_or_default()
89 .as_secs(),
90 name: None,
91 tool_call_id: None,
92 metadata: HashMap::new(),
93 }
94 }
95
96 pub fn system(content: impl Into<String>) -> Self {
97 Self::new(ChatRole::System, content)
98 }
99
100 pub fn user(content: impl Into<String>) -> Self {
101 Self::new(ChatRole::User, content)
102 }
103
104 pub fn assistant(content: impl Into<String>) -> Self {
105 Self::new(ChatRole::Assistant, content)
106 }
107
108 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
109 let mut msg = Self::new(ChatRole::Tool, content);
110 msg.tool_call_id = Some(tool_call_id.into());
111 msg
112 }
113
114 pub fn with_name(mut self, name: impl Into<String>) -> Self {
115 self.name = Some(name.into());
116 self
117 }
118
119 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
120 self.metadata.insert(key.into(), value.into());
121 self
122 }
123
124 pub fn estimated_tokens(&self) -> usize {
126 self.content.len() / 4 + 1
127 }
128}
129
/// One user → assistant exchange, together with any tool messages recorded while it was open.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Turn {
    /// Turn number; `ConversationSession` numbers turns starting at 1.
    pub number: u32,
    /// The user message that opened this turn.
    pub user_message: ChatMessage,
    /// The assistant reply, set by `complete`.
    pub assistant_response: Option<ChatMessage>,
    /// Tool messages added via `add_tool_call`, in arrival order.
    pub tool_calls: Vec<ChatMessage>,
    /// Copied from `user_message.timestamp` (seconds since the Unix epoch).
    pub started_at: u64,
    /// Unix time in seconds at which `complete` was last called.
    pub completed_at: Option<u64>,
}
146
147impl Turn {
148 pub fn new(number: u32, user_message: ChatMessage) -> Self {
149 Self {
150 number,
151 started_at: user_message.timestamp,
152 user_message,
153 assistant_response: None,
154 tool_calls: Vec::new(),
155 completed_at: None,
156 }
157 }
158
159 pub fn complete(&mut self, response: ChatMessage) {
160 self.assistant_response = Some(response);
161 self.completed_at = Some(
162 SystemTime::now()
163 .duration_since(UNIX_EPOCH)
164 .unwrap_or_default()
165 .as_secs(),
166 );
167 }
168
169 pub fn add_tool_call(&mut self, tool_message: ChatMessage) {
170 self.tool_calls.push(tool_message);
171 }
172
173 pub fn is_complete(&self) -> bool {
174 self.assistant_response.is_some()
175 }
176
177 pub fn all_messages(&self) -> Vec<&ChatMessage> {
178 let mut messages = vec![&self.user_message];
179 messages.extend(self.tool_calls.iter());
180 if let Some(ref response) = self.assistant_response {
181 messages.push(response);
182 }
183 messages
184 }
185}
186
/// Per-session key/value storage: plain string variables plus JSON-encoded typed data.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionState {
    /// String variables accessed through `set` / `get` / `remove`.
    pub variables: HashMap<String, String>,
    /// JSON strings written by `set_data` and decoded by `get_data`.
    pub data: HashMap<String, String>,
}
195
196impl SessionState {
197 pub fn new() -> Self {
198 Self::default()
199 }
200
201 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
202 self.variables.insert(key.into(), value.into());
203 }
204
205 pub fn get(&self, key: &str) -> Option<&String> {
206 self.variables.get(key)
207 }
208
209 pub fn set_data<T: Serialize>(
210 &mut self,
211 key: impl Into<String>,
212 value: &T,
213 ) -> Result<(), serde_json::Error> {
214 let json = serde_json::to_string(value)?;
215 self.data.insert(key.into(), json);
216 Ok(())
217 }
218
219 pub fn get_data<T: for<'de> Deserialize<'de>>(
220 &self,
221 key: &str,
222 ) -> Option<Result<T, serde_json::Error>> {
223 self.data.get(key).map(|json| serde_json::from_str(json))
224 }
225
226 pub fn remove(&mut self, key: &str) -> Option<String> {
227 self.variables.remove(key)
228 }
229
230 pub fn clear(&mut self) {
231 self.variables.clear();
232 self.data.clear();
233 }
234}
235
/// Limits and behaviour settings applied to a [`ConversationSession`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Maximum number of stored messages. When exceeded, the oldest messages are evicted,
    /// except that a system message in first position is kept.
    pub max_messages: usize,
    /// Token budget used by `ConversationSession::get_context` when no explicit limit is given.
    pub max_tokens: usize,
    /// Time since the last update after which a session counts as expired; `None` disables expiry.
    pub timeout: Option<Duration>,
    /// Whether the system prompt is stored as the first message and kept by `clear_messages`.
    pub persist_system_prompt: bool,
}
248
249impl Default for SessionConfig {
250 fn default() -> Self {
251 Self {
252 max_messages: 100,
253 max_tokens: 8192,
254 timeout: Some(Duration::from_secs(3600)), persist_system_prompt: true,
256 }
257 }
258}
259
/// A conversation: message history, derived turns, per-session state and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationSession {
    /// Session identifier.
    pub id: ConversationId,
    /// Optional owning user; `MemorySessionStore::list_for_user` filters on it.
    pub user_id: Option<String>,
    /// System prompt set via `with_system_prompt`, if any.
    pub system_prompt: Option<String>,
    /// Message history in insertion order, trimmed to `config.max_messages`.
    pub messages: Vec<ChatMessage>,
    /// Turns derived from user, tool and assistant messages as they are added.
    pub turns: Vec<Turn>,
    /// Arbitrary per-session key/value state.
    pub state: SessionState,
    /// Creation time, seconds since the Unix epoch.
    pub created_at: u64,
    /// Last modification time, seconds since the Unix epoch; `is_expired` measures from here.
    pub updated_at: u64,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
    /// Runtime configuration. Skipped by serde, so a deserialized session gets
    /// `SessionConfig::default()` here rather than the config it was created with.
    #[serde(skip)]
    config: SessionConfig,
}
285
286impl ConversationSession {
287 pub fn new(id: impl Into<String>) -> Self {
288 let now = SystemTime::now()
289 .duration_since(UNIX_EPOCH)
290 .unwrap_or_default()
291 .as_secs();
292
293 Self {
294 id: id.into(),
295 user_id: None,
296 system_prompt: None,
297 messages: Vec::new(),
298 turns: Vec::new(),
299 state: SessionState::new(),
300 created_at: now,
301 updated_at: now,
302 metadata: HashMap::new(),
303 config: SessionConfig::default(),
304 }
305 }
306
307 pub fn with_config(mut self, config: SessionConfig) -> Self {
308 self.config = config;
309 self
310 }
311
312 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
313 self.user_id = Some(user_id.into());
314 self
315 }
316
317 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
318 let prompt = prompt.into();
319 self.system_prompt = Some(prompt.clone());
320 if self.config.persist_system_prompt {
321 self.messages.insert(0, ChatMessage::system(prompt));
322 }
323 self
324 }
325
326 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
327 self.metadata.insert(key.into(), value.into());
328 self
329 }
330
331 pub fn add_message(&mut self, role: ChatRole, content: impl Into<String>) {
333 let message = ChatMessage::new(role, content);
334 self.add_message_obj(message);
335 }
336
337 pub fn add_message_obj(&mut self, message: ChatMessage) {
339 match message.role {
341 ChatRole::User => {
342 let turn_num = self.turns.len() as u32 + 1;
343 self.turns.push(Turn::new(turn_num, message.clone()));
344 }
345 ChatRole::Assistant => {
346 if let Some(turn) = self.turns.last_mut() {
347 turn.complete(message.clone());
348 }
349 }
350 ChatRole::Tool => {
351 if let Some(turn) = self.turns.last_mut() {
352 turn.add_tool_call(message.clone());
353 }
354 }
355 ChatRole::System => {
356 }
358 }
359
360 self.messages.push(message);
361 self.touch();
362 self.enforce_limits();
363 }
364
365 pub fn get_context(&self, max_tokens: Option<usize>) -> Vec<&ChatMessage> {
367 let max_tokens = max_tokens.unwrap_or(self.config.max_tokens);
368 let mut result = Vec::new();
369 let mut token_count = 0;
370
371 if let Some(system_msg) = self.messages.first() {
373 if system_msg.role == ChatRole::System {
374 token_count += system_msg.estimated_tokens();
375 result.push(system_msg);
376 }
377 }
378
379 for message in self.messages.iter().rev() {
381 if message.role == ChatRole::System {
382 continue; }
384
385 let msg_tokens = message.estimated_tokens();
386 if token_count + msg_tokens > max_tokens {
387 break;
388 }
389
390 token_count += msg_tokens;
391 result.push(message);
392 }
393
394 let system_msg = if result.first().map(|m| m.role) == Some(ChatRole::System) {
396 Some(result.remove(0))
397 } else {
398 None
399 };
400
401 result.reverse();
402
403 if let Some(sys) = system_msg {
404 result.insert(0, sys);
405 }
406
407 result
408 }
409
410 pub fn get_all_messages(&self) -> &[ChatMessage] {
412 &self.messages
413 }
414
415 pub fn get_recent(&self, n: usize) -> Vec<&ChatMessage> {
417 self.messages.iter().rev().take(n).rev().collect()
418 }
419
420 pub fn get_by_role(&self, role: ChatRole) -> Vec<&ChatMessage> {
422 self.messages.iter().filter(|m| m.role == role).collect()
423 }
424
425 pub fn current_turn(&self) -> u32 {
427 self.turns.len() as u32
428 }
429
430 pub fn last_turn(&self) -> Option<&Turn> {
432 self.turns.last()
433 }
434
435 pub fn get_turn(&self, number: u32) -> Option<&Turn> {
437 if number == 0 || number as usize > self.turns.len() {
438 return None;
439 }
440 self.turns.get(number as usize - 1)
441 }
442
443 pub fn is_expired(&self) -> bool {
445 if let Some(timeout) = self.config.timeout {
446 let now = SystemTime::now()
447 .duration_since(UNIX_EPOCH)
448 .unwrap_or_default()
449 .as_secs();
450 return now - self.updated_at > timeout.as_secs();
451 }
452 false
453 }
454
455 pub fn age(&self) -> Duration {
457 let now = SystemTime::now()
458 .duration_since(UNIX_EPOCH)
459 .unwrap_or_default()
460 .as_secs();
461 Duration::from_secs(now - self.created_at)
462 }
463
464 pub fn clear_messages(&mut self) {
466 let system = if self.config.persist_system_prompt {
467 self.messages
468 .first()
469 .filter(|m| m.role == ChatRole::System)
470 .cloned()
471 } else {
472 None
473 };
474
475 self.messages.clear();
476 self.turns.clear();
477
478 if let Some(sys) = system {
479 self.messages.push(sys);
480 }
481
482 self.touch();
483 }
484
485 pub fn total_tokens(&self) -> usize {
487 self.messages.iter().map(|m| m.estimated_tokens()).sum()
488 }
489
490 fn touch(&mut self) {
491 self.updated_at = SystemTime::now()
492 .duration_since(UNIX_EPOCH)
493 .unwrap_or_default()
494 .as_secs();
495 }
496
497 fn enforce_limits(&mut self) {
498 while self.messages.len() > self.config.max_messages {
500 let remove_idx = if self.messages.first().map(|m| m.role) == Some(ChatRole::System) {
502 1
503 } else {
504 0
505 };
506 if remove_idx < self.messages.len() {
507 self.messages.remove(remove_idx);
508 }
509 }
510 }
511}
512
/// Asynchronous persistence backend used by [`ConversationManager`] to store sessions.
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Persists `session`; the manager later loads it back by `session.id`.
    async fn save(&self, session: &ConversationSession) -> Result<(), SessionError>;

    /// Loads the session stored under `session_id`. `Ok(None)` means there is no such session;
    /// the manager then creates a new one.
    async fn load(&self, session_id: &str) -> Result<Option<ConversationSession>, SessionError>;

    /// Removes the session stored under `session_id`.
    async fn delete(&self, session_id: &str) -> Result<(), SessionError>;

    /// Returns the ids of all stored sessions.
    async fn list(&self) -> Result<Vec<ConversationId>, SessionError>;

    /// Returns the ids of stored sessions belonging to `user_id`.
    async fn list_for_user(&self, user_id: &str) -> Result<Vec<ConversationId>, SessionError>;
}
531
/// Errors surfaced by [`SessionStore`] operations and [`ConversationManager`].
///
/// NOTE(review): `MemorySessionStore` never returns an error; these variants are presumably
/// intended for other store implementations. Confirm against callers.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session exists with the contained id.
    #[error("Session not found: {0}")]
    NotFound(ConversationId),

    /// The session with the contained id has expired.
    #[error("Session expired: {0}")]
    Expired(ConversationId),

    /// The storage backend failed; the payload is a backend-provided description.
    #[error("Storage error: {0}")]
    StorageError(String),

    /// A session could not be serialized or deserialized; the payload describes why.
    #[error("Serialization error: {0}")]
    SerializationError(String),
}
547
/// In-process [`SessionStore`] that keeps sessions in a lock-protected `HashMap`.
/// Nothing is persisted beyond the lifetime of the store.
#[derive(Default)]
pub struct MemorySessionStore {
    /// Sessions keyed by id; `save` and `load` clone whole sessions in and out.
    sessions: RwLock<HashMap<ConversationId, ConversationSession>>,
}
553
554impl MemorySessionStore {
555 pub fn new() -> Self {
556 Self::default()
557 }
558}
559
560#[async_trait]
561impl SessionStore for MemorySessionStore {
562 async fn save(&self, session: &ConversationSession) -> Result<(), SessionError> {
563 self.sessions
564 .write()
565 .insert(session.id.clone(), session.clone());
566 Ok(())
567 }
568
569 async fn load(&self, session_id: &str) -> Result<Option<ConversationSession>, SessionError> {
570 Ok(self.sessions.read().get(session_id).cloned())
571 }
572
573 async fn delete(&self, session_id: &str) -> Result<(), SessionError> {
574 self.sessions.write().remove(session_id);
575 Ok(())
576 }
577
578 async fn list(&self) -> Result<Vec<ConversationId>, SessionError> {
579 Ok(self.sessions.read().keys().cloned().collect())
580 }
581
582 async fn list_for_user(&self, user_id: &str) -> Result<Vec<ConversationId>, SessionError> {
583 Ok(self
584 .sessions
585 .read()
586 .values()
587 .filter(|s| s.user_id.as_deref() == Some(user_id))
588 .map(|s| s.id.clone())
589 .collect())
590 }
591}
592
/// Session lifecycle API (create, get-or-create, update, delete, expiry cleanup) built on
/// a [`SessionStore`].
pub struct ConversationManager<S: SessionStore> {
    /// Shared backing store.
    store: Arc<S>,
    /// Configuration given to sessions created through this manager.
    config: SessionConfig,
}
598
599impl ConversationManager<MemorySessionStore> {
600 pub fn in_memory() -> Self {
602 Self::new(Arc::new(MemorySessionStore::new()))
603 }
604}
605
606impl<S: SessionStore> ConversationManager<S> {
607 pub fn new(store: Arc<S>) -> Self {
608 Self {
609 store,
610 config: SessionConfig::default(),
611 }
612 }
613
614 pub fn with_config(mut self, config: SessionConfig) -> Self {
615 self.config = config;
616 self
617 }
618
619 pub async fn create(
621 &self,
622 session_id: impl Into<String>,
623 ) -> Result<ConversationSession, SessionError> {
624 let session = ConversationSession::new(session_id).with_config(self.config.clone());
625 self.store.save(&session).await?;
626 info!(session_id = %session.id, "Session created");
627 Ok(session)
628 }
629
630 pub async fn get_or_create(
632 &self,
633 session_id: impl Into<String>,
634 ) -> Result<ConversationSession, SessionError> {
635 let session_id = session_id.into();
636 match self.store.load(&session_id).await? {
637 Some(session) => {
638 if session.is_expired() {
639 debug!(session_id = %session_id, "Session expired, creating new");
640 self.store.delete(&session_id).await?;
641 self.create(session_id).await
642 } else {
643 Ok(session)
644 }
645 }
646 None => self.create(session_id).await,
647 }
648 }
649
650 pub async fn update(&self, session: &ConversationSession) -> Result<(), SessionError> {
652 self.store.save(session).await
653 }
654
655 pub async fn delete(&self, session_id: &str) -> Result<(), SessionError> {
657 self.store.delete(session_id).await?;
658 info!(session_id, "Session deleted");
659 Ok(())
660 }
661
662 pub async fn list(&self) -> Result<Vec<ConversationId>, SessionError> {
664 self.store.list().await
665 }
666
667 pub async fn list_for_user(&self, user_id: &str) -> Result<Vec<ConversationId>, SessionError> {
669 self.store.list_for_user(user_id).await
670 }
671
672 pub async fn cleanup_expired(&self) -> Result<usize, SessionError> {
674 let mut cleaned = 0;
675 let session_ids = self.store.list().await?;
676
677 for id in session_ids {
678 if let Some(session) = self.store.load(&id).await? {
679 if session.is_expired() {
680 self.store.delete(&id).await?;
681 cleaned += 1;
682 }
683 }
684 }
685
686 if cleaned > 0 {
687 info!(count = cleaned, "Cleaned up expired sessions");
688 }
689
690 Ok(cleaned)
691 }
692}
693
#[cfg(test)]
mod tests {
    // Unit tests for messages, turns, session bookkeeping, state, limits, and the
    // in-memory store/manager.
    use super::*;

    // A constructor sets role and content and records a non-zero Unix timestamp.
    #[test]
    fn test_message_creation() {
        let msg = ChatMessage::user("Hello!");
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "Hello!");
        assert!(msg.timestamp > 0);
    }

    // Builder methods attach a name and a metadata entry.
    #[test]
    fn test_message_builders() {
        let msg = ChatMessage::assistant("Hi there!")
            .with_name("Assistant")
            .with_metadata("source", "test");

        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.name, Some("Assistant".to_string()));
        assert_eq!(msg.metadata.get("source").unwrap(), "test");
    }

    // A persisted system prompt plus one user/assistant exchange.
    #[test]
    fn test_session_basic() {
        let mut session =
            ConversationSession::new("test_session").with_system_prompt("You are helpful");

        session.add_message(ChatRole::User, "Hello!");
        session.add_message(ChatRole::Assistant, "Hi there!");

        // system + user + assistant
        assert_eq!(session.messages.len(), 3);
        assert_eq!(session.current_turn(), 1);
    }

    // Each user message opens a turn; the assistant reply completes it.
    #[test]
    fn test_session_turns() {
        let mut session = ConversationSession::new("test");

        session.add_message(ChatRole::User, "First question");
        session.add_message(ChatRole::Assistant, "First answer");
        session.add_message(ChatRole::User, "Second question");
        session.add_message(ChatRole::Assistant, "Second answer");

        assert_eq!(session.current_turn(), 2);

        let turn1 = session.get_turn(1).unwrap();
        assert_eq!(turn1.user_message.content, "First question");
        assert!(turn1.is_complete());

        let turn2 = session.get_turn(2).unwrap();
        assert_eq!(turn2.user_message.content, "Second question");
    }

    // Under a token budget, the leading system prompt is still returned first.
    #[test]
    fn test_get_context_with_limit() {
        let mut session = ConversationSession::new("test").with_system_prompt("System");

        // 40 messages at ~3 estimated tokens each: more than the 100-token budget below.
        for i in 0..20 {
            session.add_message(ChatRole::User, format!("Message {}", i));
            session.add_message(ChatRole::Assistant, format!("Response {}", i));
        }

        let context = session.get_context(Some(100));

        assert!(!context.is_empty());
        assert_eq!(context[0].role, ChatRole::System);
    }

    // String variables and JSON-encoded typed data round-trip through SessionState.
    #[test]
    fn test_session_state() {
        let mut session = ConversationSession::new("test");

        session.state.set("user_name", "Alice");
        session.state.set("preference", "dark_mode");

        assert_eq!(session.state.get("user_name").unwrap(), "Alice");
        assert_eq!(session.state.get("preference").unwrap(), "dark_mode");

        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct UserPrefs {
            theme: String,
            language: String,
        }

        let prefs = UserPrefs {
            theme: "dark".to_string(),
            language: "en".to_string(),
        };

        session.state.set_data("prefs", &prefs).unwrap();
        let loaded: UserPrefs = session.state.get_data("prefs").unwrap().unwrap();
        assert_eq!(loaded, prefs);
    }

    // With no system prompt, only the newest `max_messages` messages are retained.
    #[test]
    fn test_message_limit() {
        let config = SessionConfig {
            max_messages: 5,
            ..Default::default()
        };

        let mut session = ConversationSession::new("test").with_config(config);

        for i in 0..10 {
            session.add_message(ChatRole::User, format!("Message {}", i));
        }

        assert_eq!(session.messages.len(), 5);
    }

    // Clearing keeps the persisted system prompt and drops all turns.
    #[test]
    fn test_clear_messages() {
        let mut session = ConversationSession::new("test").with_system_prompt("System prompt");

        session.add_message(ChatRole::User, "Hello");
        session.add_message(ChatRole::Assistant, "Hi");

        assert_eq!(session.messages.len(), 3);

        session.clear_messages();

        // Only the system prompt survives.
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, ChatRole::System);
        assert_eq!(session.turns.len(), 0);
    }

    // Filtering by role returns only messages of that role.
    #[test]
    fn test_get_by_role() {
        let mut session = ConversationSession::new("test");

        session.add_message(ChatRole::User, "Q1");
        session.add_message(ChatRole::Assistant, "A1");
        session.add_message(ChatRole::User, "Q2");
        session.add_message(ChatRole::Assistant, "A2");

        let user_messages = session.get_by_role(ChatRole::User);
        assert_eq!(user_messages.len(), 2);

        let assistant_messages = session.get_by_role(ChatRole::Assistant);
        assert_eq!(assistant_messages.len(), 2);
    }

    // Manager lifecycle over the in-memory store: create, fetch existing, list, delete.
    #[tokio::test]
    async fn test_session_manager() {
        let manager = ConversationManager::in_memory();

        let session = manager.create("session1").await.unwrap();
        assert_eq!(session.id, "session1");

        let loaded = manager.get_or_create("session1").await.unwrap();
        assert_eq!(loaded.id, "session1");

        let sessions = manager.list().await.unwrap();
        assert_eq!(sessions.len(), 1);

        manager.delete("session1").await.unwrap();
        let sessions = manager.list().await.unwrap();
        assert!(sessions.is_empty());
    }

    // The in-memory store round-trips a session and indexes it by user id.
    #[tokio::test]
    async fn test_session_store() {
        let store = MemorySessionStore::new();

        let mut session = ConversationSession::new("test").with_user_id("user1");
        session.add_message(ChatRole::User, "Hello");

        store.save(&session).await.unwrap();

        let loaded = store.load("test").await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 1);

        let user_sessions = store.list_for_user("user1").await.unwrap();
        assert_eq!(user_sessions.len(), 1);
    }

    // A tool message added between user and assistant is attached to the open turn.
    #[test]
    fn test_tool_message() {
        let mut session = ConversationSession::new("test");

        session.add_message(ChatRole::User, "Calculate 2+2");
        session.add_message_obj(ChatMessage::tool("4", "call_123"));
        session.add_message(ChatRole::Assistant, "The result is 4");

        let turn = session.get_turn(1).unwrap();
        assert_eq!(turn.tool_calls.len(), 1);
        assert_eq!(
            turn.tool_calls[0].tool_call_id,
            Some("call_123".to_string())
        );
    }

    // "Hello world" is 11 bytes: 11 / 4 + 1 = 3 estimated tokens.
    #[test]
    fn test_token_estimation() {
        let msg = ChatMessage::user("Hello world");
        assert_eq!(msg.estimated_tokens(), 3);
    }

    // The session total is the sum of per-message estimates (3 + 3 here).
    #[test]
    fn test_total_tokens() {
        let mut session = ConversationSession::new("test");
        session.add_message(ChatRole::User, "Hello world");
        session.add_message(ChatRole::Assistant, "Hi there");
        assert!(session.total_tokens() >= 4);
    }
}