graph_flow/
context.rs

1//! Context and state management for workflows.
2//!
3//! This module provides thread-safe state management across workflow tasks,
4//! including regular data storage and specialized chat history management.
5//!
6//! # Examples
7//!
8//! ## Basic Context Usage
9//!
10//! ```rust
11//! use graph_flow::Context;
12//!
13//! # #[tokio::main]
14//! # async fn main() {
15//! let context = Context::new();
16//!
17//! // Store different types of data
18//! context.set("user_id", 12345).await;
19//! context.set("name", "Alice".to_string()).await;
20//! context.set("active", true).await;
21//!
22//! // Retrieve data with type safety
23//! let user_id: Option<i32> = context.get("user_id").await;
24//! let name: Option<String> = context.get("name").await;
25//! let active: Option<bool> = context.get("active").await;
26//!
27//! // Synchronous access (useful in edge conditions)
28//! let name_sync: Option<String> = context.get_sync("name");
29//! # }
30//! ```
31//!
32//! ## Chat History Management
33//!
34//! ```rust
35//! use graph_flow::Context;
36//!
37//! # #[tokio::main]
38//! # async fn main() {
39//! let context = Context::new();
40//!
41//! // Add messages to chat history
42//! context.add_user_message("Hello, assistant!".to_string()).await;
43//! context.add_assistant_message("Hello! How can I help you?".to_string()).await;
44//! context.add_system_message("User session started".to_string()).await;
45//!
46//! // Get chat history
47//! let history = context.get_chat_history().await;
48//! let all_messages = context.get_all_messages().await;
49//! let last_5 = context.get_last_messages(5).await;
50//!
51//! // Check history status
52//! let count = context.chat_history_len().await;
53//! let is_empty = context.is_chat_history_empty().await;
54//! # }
55//! ```
56//!
57//! ## Context with Message Limits
58//!
59//! ```rust
60//! use graph_flow::Context;
61//!
62//! # #[tokio::main]
63//! # async fn main() {
64//! // Create context with maximum 100 messages
65//! let context = Context::with_max_chat_messages(100);
66//!
67//! // Messages will be automatically pruned when limit is exceeded
68//! for i in 0..150 {
69//!     context.add_user_message(format!("Message {}", i)).await;
70//! }
71//!
72//! // Only the last 100 messages are kept
73//! assert_eq!(context.chat_history_len().await, 100);
74//! # }
75//! ```
76//!
77//! ## LLM Integration (with `rig` feature)
78//!
79//! ```rust
80//! # #[cfg(feature = "rig")]
81//! # {
82//! use graph_flow::Context;
83//!
84//! # #[tokio::main]
85//! # async fn main() {
86//! let context = Context::new();
87//!
88//! context.add_user_message("What is the capital of France?".to_string()).await;
89//! context.add_assistant_message("The capital of France is Paris.".to_string()).await;
90//!
91//! // Get messages in rig format for LLM calls
92//! let rig_messages = context.get_rig_messages().await;
93//! let recent_messages = context.get_last_rig_messages(10).await;
94//!
95//! // Use with rig's completion API
96//! // let response = agent.completion(&rig_messages).await?;
97//! # }
98//! # }
99//! ```
100
101use chrono::{DateTime, Utc};
102use dashmap::DashMap;
103use serde::{Deserialize, Serialize};
104use serde_json::Value;
105use std::sync::{Arc, RwLock};
106
107#[cfg(feature = "rig")]
108use rig::completion::Message;
109
110/// Represents the role of a message in a conversation.
111///
112/// Used in chat history to distinguish between different types of messages.
113#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
114pub enum MessageRole {
115    /// Message from a user/human
116    User,
117    /// Message from an assistant/AI
118    Assistant,
119    /// System message (instructions, status updates, etc.)
120    System,
121}
122
123/// A serializable message that can be converted to/from rig::completion::Message.
124///
125/// This struct provides a unified message format that can be stored, serialized,
126/// and optionally converted to other formats like rig's Message type.
127///
128/// # Examples
129///
130/// ```rust
131/// use graph_flow::{SerializableMessage, MessageRole};
132///
133/// // Create different types of messages
134/// let user_msg = SerializableMessage::user("Hello!".to_string());
135/// let assistant_msg = SerializableMessage::assistant("Hi there!".to_string());
136/// let system_msg = SerializableMessage::system("Session started".to_string());
137///
138/// // Access message properties
139/// assert_eq!(user_msg.role, MessageRole::User);
140/// assert_eq!(user_msg.content, "Hello!");
141/// ```
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct SerializableMessage {
144    /// The role of the message sender
145    pub role: MessageRole,
146    /// The content of the message
147    pub content: String,
148    /// When the message was created
149    pub timestamp: DateTime<Utc>,
150}
151
152impl SerializableMessage {
153    /// Create a new message with the specified role and content.
154    ///
155    /// The timestamp is automatically set to the current UTC time.
156    pub fn new(role: MessageRole, content: String) -> Self {
157        Self {
158            role,
159            content,
160            timestamp: Utc::now(),
161        }
162    }
163
164    /// Create a new user message.
165    ///
166    /// # Examples
167    ///
168    /// ```rust
169    /// use graph_flow::SerializableMessage;
170    ///
171    /// let msg = SerializableMessage::user("Hello, world!".to_string());
172    /// ```
173    pub fn user(content: String) -> Self {
174        Self::new(MessageRole::User, content)
175    }
176
177    /// Create a new assistant message.
178    ///
179    /// # Examples
180    ///
181    /// ```rust
182    /// use graph_flow::SerializableMessage;
183    ///
184    /// let msg = SerializableMessage::assistant("Hello! How can I help?".to_string());
185    /// ```
186    pub fn assistant(content: String) -> Self {
187        Self::new(MessageRole::Assistant, content)
188    }
189
190    /// Create a new system message.
191    ///
192    /// # Examples
193    ///
194    /// ```rust
195    /// use graph_flow::SerializableMessage;
196    ///
197    /// let msg = SerializableMessage::system("User logged in".to_string());
198    /// ```
199    pub fn system(content: String) -> Self {
200        Self::new(MessageRole::System, content)
201    }
202}
203
204/// Container for managing chat history with serialization support.
205///
206/// Provides automatic message limit management and convenient methods
207/// for adding and retrieving messages.
208///
209/// # Examples
210///
211/// ```rust
212/// use graph_flow::ChatHistory;
213///
214/// let mut history = ChatHistory::new();
215/// history.add_user_message("Hello".to_string());
216/// history.add_assistant_message("Hi there!".to_string());
217///
218/// assert_eq!(history.len(), 2);
219/// assert!(!history.is_empty());
220/// ```
221#[derive(Debug, Clone, Serialize, Deserialize, Default)]
222pub struct ChatHistory {
223    messages: Vec<SerializableMessage>,
224    max_messages: Option<usize>,
225}
226
227impl ChatHistory {
228    /// Create a new empty chat history with a default limit of 1000 messages.
229    pub fn new() -> Self {
230        Self {
231            messages: Vec::new(),
232            max_messages: Some(1000), // Default limit to prevent unbounded growth
233        }
234    }
235
236    /// Create a new chat history with a maximum message limit.
237    ///
238    /// When the limit is exceeded, older messages are automatically removed.
239    ///
240    /// # Examples
241    ///
242    /// ```rust
243    /// use graph_flow::ChatHistory;
244    ///
245    /// let mut history = ChatHistory::with_max_messages(10);
246    ///
247    /// // Add 15 messages
248    /// for i in 0..15 {
249    ///     history.add_user_message(format!("Message {}", i));
250    /// }
251    ///
252    /// // Only the last 10 are kept
253    /// assert_eq!(history.len(), 10);
254    /// ```
255    pub fn with_max_messages(max: usize) -> Self {
256        Self {
257            messages: Vec::new(),
258            max_messages: Some(max),
259        }
260    }
261
262    /// Add a user message to the chat history.
263    pub fn add_user_message(&mut self, content: String) {
264        self.add_message(SerializableMessage::user(content));
265    }
266
267    /// Add an assistant message to the chat history.
268    pub fn add_assistant_message(&mut self, content: String) {
269        self.add_message(SerializableMessage::assistant(content));
270    }
271
272    /// Add a system message to the chat history.
273    pub fn add_system_message(&mut self, content: String) {
274        self.add_message(SerializableMessage::system(content));
275    }
276
277    /// Add a message to the chat history, respecting max_messages limit.
278    fn add_message(&mut self, message: SerializableMessage) {
279        self.messages.push(message);
280
281        if let Some(max) = self.max_messages {
282            if self.messages.len() > max {
283                self.messages.drain(0..(self.messages.len() - max));
284            }
285        }
286    }
287
288    /// Clear all messages from the chat history.
289    pub fn clear(&mut self) {
290        self.messages.clear();
291    }
292
293    /// Get the number of messages in the chat history.
294    pub fn len(&self) -> usize {
295        self.messages.len()
296    }
297
298    /// Check if the chat history is empty.
299    pub fn is_empty(&self) -> bool {
300        self.messages.is_empty()
301    }
302
303    /// Get a reference to all messages.
304    pub fn messages(&self) -> &[SerializableMessage] {
305        &self.messages
306    }
307
308    /// Get the last N messages.
309    ///
310    /// If N is greater than the total number of messages, all messages are returned.
311    ///
312    /// # Examples
313    ///
314    /// ```rust
315    /// use graph_flow::ChatHistory;
316    ///
317    /// let mut history = ChatHistory::new();
318    /// history.add_user_message("Message 1".to_string());
319    /// history.add_user_message("Message 2".to_string());
320    /// history.add_user_message("Message 3".to_string());
321    ///
322    /// let last_two = history.last_messages(2);
323    /// assert_eq!(last_two.len(), 2);
324    /// assert_eq!(last_two[0].content, "Message 2");
325    /// assert_eq!(last_two[1].content, "Message 3");
326    /// ```
327    pub fn last_messages(&self, n: usize) -> &[SerializableMessage] {
328        let start = if self.messages.len() > n {
329            self.messages.len() - n
330        } else {
331            0
332        };
333        &self.messages[start..]
334    }
335}
336
337/// Helper struct for serializing/deserializing Context
338#[derive(Serialize, Deserialize)]
339struct ContextData {
340    data: std::collections::HashMap<String, Value>,
341    chat_history: ChatHistory,
342}
343
344/// Context for sharing data between tasks in a graph execution.
345///
346/// Provides thread-safe storage for workflow state and dedicated chat history
347/// management. The context is shared across all tasks in a workflow execution.
348///
349/// # Examples
350///
351/// ## Basic Usage
352///
353/// ```rust
354/// use graph_flow::Context;
355///
356/// # #[tokio::main]
357/// # async fn main() {
358/// let context = Context::new();
359///
360/// // Store different types of data
361/// context.set("user_id", 12345).await;
362/// context.set("name", "Alice".to_string()).await;
363/// context.set("settings", vec!["opt1", "opt2"]).await;
364///
365/// // Retrieve data
366/// let user_id: Option<i32> = context.get("user_id").await;
367/// let name: Option<String> = context.get("name").await;
368/// let settings: Option<Vec<String>> = context.get("settings").await;
369/// # }
370/// ```
371///
372/// ## Chat History
373///
374/// ```rust
375/// use graph_flow::Context;
376///
377/// # #[tokio::main]
378/// # async fn main() {
379/// let context = Context::new();
380///
381/// // Add messages
382/// context.add_user_message("Hello".to_string()).await;
383/// context.add_assistant_message("Hi there!".to_string()).await;
384///
385/// // Get message history
386/// let history = context.get_chat_history().await;
387/// let last_5 = context.get_last_messages(5).await;
388/// # }
389/// ```
390#[derive(Clone, Debug)]
391pub struct Context {
392    data: Arc<DashMap<String, Value>>,
393    chat_history: Arc<RwLock<ChatHistory>>,
394}
395
396impl Context {
397    /// Create a new empty context.
398    pub fn new() -> Self {
399        Self {
400            data: Arc::new(DashMap::new()),
401            chat_history: Arc::new(RwLock::new(ChatHistory::new())),
402        }
403    }
404
405    /// Create a new context with a maximum chat history size.
406    ///
407    /// When the chat history exceeds this size, older messages are automatically removed.
408    ///
409    /// # Examples
410    ///
411    /// ```rust
412    /// use graph_flow::Context;
413    ///
414    /// # #[tokio::main]
415    /// # async fn main() {
416    /// let context = Context::with_max_chat_messages(50);
417    ///
418    /// // Chat history will be limited to 50 messages
419    /// for i in 0..100 {
420    ///     context.add_user_message(format!("Message {}", i)).await;
421    /// }
422    ///
423    /// assert_eq!(context.chat_history_len().await, 50);
424    /// # }
425    /// ```
426    pub fn with_max_chat_messages(max: usize) -> Self {
427        Self {
428            data: Arc::new(DashMap::new()),
429            chat_history: Arc::new(RwLock::new(ChatHistory::with_max_messages(max))),
430        }
431    }
432
433    // Regular context methods (unchanged API)
434
435    /// Set a value in the context.
436    ///
437    /// The value must be serializable. Most common Rust types are supported.
438    ///
439    /// # Examples
440    ///
441    /// ```rust
442    /// use graph_flow::Context;
443    /// use serde::{Serialize, Deserialize};
444    ///
445    /// #[derive(Serialize, Deserialize)]
446    /// struct UserData {
447    ///     id: u32,
448    ///     name: String,
449    /// }
450    ///
451    /// # #[tokio::main]
452    /// # async fn main() {
453    /// let context = Context::new();
454    ///
455    /// // Store primitive types
456    /// context.set("count", 42).await;
457    /// context.set("name", "Alice".to_string()).await;
458    /// context.set("active", true).await;
459    ///
460    /// // Store complex types
461    /// let user = UserData { id: 1, name: "Bob".to_string() };
462    /// context.set("user", user).await;
463    /// # }
464    /// ```
465    pub async fn set(&self, key: impl Into<String>, value: impl serde::Serialize) {
466        let value = serde_json::to_value(value).expect("Failed to serialize value");
467        self.data.insert(key.into(), value);
468    }
469
470    /// Get a value from the context.
471    ///
472    /// Returns `None` if the key doesn't exist or if deserialization fails.
473    ///
474    /// # Examples
475    ///
476    /// ```rust
477    /// use graph_flow::Context;
478    ///
479    /// # #[tokio::main]
480    /// # async fn main() {
481    /// let context = Context::new();
482    /// context.set("count", 42).await;
483    ///
484    /// let count: Option<i32> = context.get("count").await;
485    /// assert_eq!(count, Some(42));
486    ///
487    /// let missing: Option<String> = context.get("missing").await;
488    /// assert_eq!(missing, None);
489    /// # }
490    /// ```
491    pub async fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
492        self.data
493            .get(key)
494            .and_then(|v| serde_json::from_value(v.clone()).ok())
495    }
496
497    /// Remove a value from the context.
498    ///
499    /// Returns the removed value if it existed.
500    ///
501    /// # Examples
502    ///
503    /// ```rust
504    /// use graph_flow::Context;
505    ///
506    /// # #[tokio::main]
507    /// # async fn main() {
508    /// let context = Context::new();
509    /// context.set("temp", "value".to_string()).await;
510    ///
511    /// let removed = context.remove("temp").await;
512    /// assert!(removed.is_some());
513    ///
514    /// let value: Option<String> = context.get("temp").await;
515    /// assert_eq!(value, None);
516    /// # }
517    /// ```
518    pub async fn remove(&self, key: &str) -> Option<Value> {
519        self.data.remove(key).map(|(_, v)| v)
520    }
521
522    /// Clear all regular context data (does not affect chat history).
523    ///
524    /// # Examples
525    ///
526    /// ```rust
527    /// use graph_flow::Context;
528    ///
529    /// # #[tokio::main]
530    /// # async fn main() {
531    /// let context = Context::new();
532    /// context.set("key1", "value1".to_string()).await;
533    /// context.set("key2", "value2".to_string()).await;
534    /// context.add_user_message("Hello".to_string()).await;
535    ///
536    /// context.clear().await;
537    ///
538    /// // Regular data is cleared
539    /// let value: Option<String> = context.get("key1").await;
540    /// assert_eq!(value, None);
541    ///
542    /// // Chat history is preserved
543    /// assert_eq!(context.chat_history_len().await, 1);
544    /// # }
545    /// ```
546    pub async fn clear(&self) {
547        self.data.clear();
548    }
549
550    /// Synchronous version of get for use in edge conditions.
551    ///
552    /// This method should only be used when you're certain the data exists
553    /// and when async is not available (e.g., in edge condition closures).
554    ///
555    /// # Examples
556    ///
557    /// ```rust
558    /// use graph_flow::{Context, GraphBuilder};
559    ///
560    /// # #[tokio::main]
561    /// # async fn main() {
562    /// let context = Context::new();
563    /// context.set("condition", true).await;
564    ///
565    /// // Used in edge conditions
566    /// let graph = GraphBuilder::new("test")
567    ///     .add_conditional_edge(
568    ///         "task1",
569    ///         |ctx| ctx.get_sync::<bool>("condition").unwrap_or(false),
570    ///         "task2",
571    ///         "task3"
572    ///     );
573    /// # }
574    /// ```
575    pub fn get_sync<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
576        self.data
577            .get(key)
578            .and_then(|v| serde_json::from_value(v.clone()).ok())
579    }
580
581    /// Synchronous version of set for use when async is not available.
582    ///
583    /// # Examples
584    ///
585    /// ```rust
586    /// use graph_flow::Context;
587    ///
588    /// let context = Context::new();
589    /// context.set_sync("key", "value".to_string());
590    ///
591    /// let value: Option<String> = context.get_sync("key");
592    /// assert_eq!(value, Some("value".to_string()));
593    /// ```
594    pub fn set_sync(&self, key: impl Into<String>, value: impl serde::Serialize) {
595        let value = serde_json::to_value(value).expect("Failed to serialize value");
596        self.data.insert(key.into(), value);
597    }
598
599    // Chat history methods
600
601    /// Add a user message to the chat history.
602    ///
603    /// # Examples
604    ///
605    /// ```rust
606    /// use graph_flow::Context;
607    ///
608    /// # #[tokio::main]
609    /// # async fn main() {
610    /// let context = Context::new();
611    /// context.add_user_message("Hello, assistant!".to_string()).await;
612    /// # }
613    /// ```
614    pub async fn add_user_message(&self, content: String) {
615        if let Ok(mut history) = self.chat_history.write() {
616            history.add_user_message(content);
617        }
618    }
619
620    /// Add an assistant message to the chat history.
621    ///
622    /// # Examples
623    ///
624    /// ```rust
625    /// use graph_flow::Context;
626    ///
627    /// # #[tokio::main]
628    /// # async fn main() {
629    /// let context = Context::new();
630    /// context.add_assistant_message("Hello! How can I help you?".to_string()).await;
631    /// # }
632    /// ```
633    pub async fn add_assistant_message(&self, content: String) {
634        if let Ok(mut history) = self.chat_history.write() {
635            history.add_assistant_message(content);
636        }
637    }
638
639    /// Add a system message to the chat history.
640    ///
641    /// # Examples
642    ///
643    /// ```rust
644    /// use graph_flow::Context;
645    ///
646    /// # #[tokio::main]
647    /// # async fn main() {
648    /// let context = Context::new();
649    /// context.add_system_message("Session started".to_string()).await;
650    /// # }
651    /// ```
652    pub async fn add_system_message(&self, content: String) {
653        if let Ok(mut history) = self.chat_history.write() {
654            history.add_system_message(content);
655        }
656    }
657
658    /// Get a clone of the current chat history.
659    ///
660    /// # Examples
661    ///
662    /// ```rust
663    /// use graph_flow::Context;
664    ///
665    /// # #[tokio::main]
666    /// # async fn main() {
667    /// let context = Context::new();
668    /// context.add_user_message("Hello".to_string()).await;
669    ///
670    /// let history = context.get_chat_history().await;
671    /// assert_eq!(history.len(), 1);
672    /// # }
673    /// ```
674    pub async fn get_chat_history(&self) -> ChatHistory {
675        if let Ok(history) = self.chat_history.read() {
676            history.clone()
677        } else {
678            ChatHistory::new()
679        }
680    }
681
682    /// Clear the chat history.
683    ///
684    /// # Examples
685    ///
686    /// ```rust
687    /// use graph_flow::Context;
688    ///
689    /// # #[tokio::main]
690    /// # async fn main() {
691    /// let context = Context::new();
692    /// context.add_user_message("Hello".to_string()).await;
693    /// assert_eq!(context.chat_history_len().await, 1);
694    ///
695    /// context.clear_chat_history().await;
696    /// assert_eq!(context.chat_history_len().await, 0);
697    /// # }
698    /// ```
699    pub async fn clear_chat_history(&self) {
700        if let Ok(mut history) = self.chat_history.write() {
701            history.clear();
702        }
703    }
704
705    /// Get the number of messages in the chat history.
706    pub async fn chat_history_len(&self) -> usize {
707        if let Ok(history) = self.chat_history.read() {
708            history.len()
709        } else {
710            0
711        }
712    }
713
714    /// Check if the chat history is empty.
715    pub async fn is_chat_history_empty(&self) -> bool {
716        if let Ok(history) = self.chat_history.read() {
717            history.is_empty()
718        } else {
719            true
720        }
721    }
722
723    /// Get the last N messages from chat history.
724    ///
725    /// # Examples
726    ///
727    /// ```rust
728    /// use graph_flow::Context;
729    ///
730    /// # #[tokio::main]
731    /// # async fn main() {
732    /// let context = Context::new();
733    /// context.add_user_message("Message 1".to_string()).await;
734    /// context.add_user_message("Message 2".to_string()).await;
735    /// context.add_user_message("Message 3".to_string()).await;
736    ///
737    /// let last_two = context.get_last_messages(2).await;
738    /// assert_eq!(last_two.len(), 2);
739    /// assert_eq!(last_two[0].content, "Message 2");
740    /// assert_eq!(last_two[1].content, "Message 3");
741    /// # }
742    /// ```
743    pub async fn get_last_messages(&self, n: usize) -> Vec<SerializableMessage> {
744        if let Ok(history) = self.chat_history.read() {
745            history.last_messages(n).to_vec()
746        } else {
747            Vec::new()
748        }
749    }
750
751    /// Get all messages from chat history as SerializableMessage.
752    ///
753    /// # Examples
754    ///
755    /// ```rust
756    /// use graph_flow::Context;
757    ///
758    /// # #[tokio::main]
759    /// # async fn main() {
760    /// let context = Context::new();
761    /// context.add_user_message("Hello".to_string()).await;
762    /// context.add_assistant_message("Hi there!".to_string()).await;
763    ///
764    /// let all_messages = context.get_all_messages().await;
765    /// assert_eq!(all_messages.len(), 2);
766    /// # }
767    /// ```
768    pub async fn get_all_messages(&self) -> Vec<SerializableMessage> {
769        if let Ok(history) = self.chat_history.read() {
770            history.messages().to_vec()
771        } else {
772            Vec::new()
773        }
774    }
775
776    // Rig integration methods (only available when rig feature is enabled)
777
778    #[cfg(feature = "rig")]
779    /// Get all chat history messages converted to rig::completion::Message format.
780    ///
781    /// This method is only available when the "rig" feature is enabled.
782    ///
783    /// # Examples
784    ///
785    /// ```rust
786    /// # #[cfg(feature = "rig")]
787    /// # {
788    /// use graph_flow::Context;
789    ///
790    /// # #[tokio::main]
791    /// # async fn main() {
792    /// let context = Context::new();
793    /// context.add_user_message("Hello".to_string()).await;
794    /// context.add_assistant_message("Hi there!".to_string()).await;
795    ///
796    /// let rig_messages = context.get_rig_messages().await;
797    /// assert_eq!(rig_messages.len(), 2);
798    /// # }
799    /// # }
800    /// ```
801    pub async fn get_rig_messages(&self) -> Vec<Message> {
802        let messages = self.get_all_messages().await;
803        messages
804            .iter()
805            .map(|msg| self.to_rig_message(msg))
806            .collect()
807    }
808
809    #[cfg(feature = "rig")]
810    /// Get the last N messages converted to rig::completion::Message format.
811    ///
812    /// This method is only available when the "rig" feature is enabled.
813    ///
814    /// # Examples
815    ///
816    /// ```rust
817    /// # #[cfg(feature = "rig")]
818    /// # {
819    /// use graph_flow::Context;
820    ///
821    /// # #[tokio::main]
822    /// # async fn main() {
823    /// let context = Context::new();
824    /// for i in 0..10 {
825    ///     context.add_user_message(format!("Message {}", i)).await;
826    /// }
827    ///
828    /// let last_5 = context.get_last_rig_messages(5).await;
829    /// assert_eq!(last_5.len(), 5);
830    /// # }
831    /// # }
832    /// ```
833    pub async fn get_last_rig_messages(&self, n: usize) -> Vec<Message> {
834        let messages = self.get_last_messages(n).await;
835        messages
836            .iter()
837            .map(|msg| self.to_rig_message(msg))
838            .collect()
839    }
840
841    #[cfg(feature = "rig")]
842    /// Convert a SerializableMessage to a rig::completion::Message.
843    ///
844    /// This method is only available when the "rig" feature is enabled.
845    fn to_rig_message(&self, msg: &SerializableMessage) -> Message {
846        match msg.role {
847            MessageRole::User => Message::user(msg.content.clone()),
848            MessageRole::Assistant => Message::assistant(msg.content.clone()),
849            // rig doesn't have a system message type, so we'll treat it as a user message
850            // with a system prefix
851            MessageRole::System => Message::user(format!("[SYSTEM] {}", msg.content)),
852        }
853    }
854}
855
856impl Default for Context {
857    fn default() -> Self {
858        Self::new()
859    }
860}
861
862// Serialization support for Context
863impl Serialize for Context {
864    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
865    where
866        S: serde::Serializer,
867    {
868        // Convert DashMap to HashMap for serialization
869        let data: std::collections::HashMap<String, Value> = self
870            .data
871            .iter()
872            .map(|entry| (entry.key().clone(), entry.value().clone()))
873            .collect();
874
875        let chat_history = if let Ok(history) = self.chat_history.read() {
876            history.clone()
877        } else {
878            ChatHistory::new()
879        };
880
881        let context_data = ContextData { data, chat_history };
882        context_data.serialize(serializer)
883    }
884}
885
886impl<'de> Deserialize<'de> for Context {
887    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
888    where
889        D: serde::Deserializer<'de>,
890    {
891        let context_data = ContextData::deserialize(deserializer)?;
892
893        let data = Arc::new(DashMap::new());
894        for (key, value) in context_data.data {
895            data.insert(key, value);
896        }
897
898        let chat_history = Arc::new(RwLock::new(context_data.chat_history));
899
900        Ok(Context { data, chat_history })
901    }
902}
903
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored value can be read back with its original type.
    #[tokio::test]
    async fn test_basic_context_operations() {
        let ctx = Context::new();
        ctx.set("key", "value").await;

        let stored: Option<String> = ctx.get("key").await;
        assert_eq!(stored, Some("value".to_string()));
    }

    /// Adding messages updates the length, emptiness and contents of the history.
    #[tokio::test]
    async fn test_chat_history_operations() {
        let ctx = Context::new();

        assert!(ctx.is_chat_history_empty().await);
        assert_eq!(ctx.chat_history_len().await, 0);

        ctx.add_user_message("Hello".to_string()).await;
        ctx.add_assistant_message("Hi there!".to_string()).await;

        assert!(!ctx.is_chat_history_empty().await);
        assert_eq!(ctx.chat_history_len().await, 2);

        let snapshot = ctx.get_chat_history().await;
        assert_eq!(snapshot.len(), 2);

        let stored = snapshot.messages();
        assert_eq!(stored[0].content, "Hello");
        assert_eq!(stored[0].role, MessageRole::User);
        assert_eq!(stored[1].content, "Hi there!");
        assert_eq!(stored[1].role, MessageRole::Assistant);
    }

    /// The oldest message is dropped once the configured limit is exceeded.
    #[tokio::test]
    async fn test_chat_history_max_messages() {
        let ctx = Context::with_max_chat_messages(2);

        ctx.add_user_message("Message 1".to_string()).await;
        ctx.add_assistant_message("Response 1".to_string()).await;
        ctx.add_user_message("Message 2".to_string()).await;

        let snapshot = ctx.get_chat_history().await;
        assert_eq!(snapshot.len(), 2);

        let stored = snapshot.messages();
        assert_eq!(stored[0].content, "Response 1");
        assert_eq!(stored[1].content, "Message 2");
    }

    /// `get_last_messages` returns the tail of the history in order.
    #[tokio::test]
    async fn test_last_messages() {
        let ctx = Context::new();

        ctx.add_user_message("Message 1".to_string()).await;
        ctx.add_assistant_message("Response 1".to_string()).await;
        ctx.add_user_message("Message 2".to_string()).await;
        ctx.add_assistant_message("Response 2".to_string()).await;

        let tail = ctx.get_last_messages(2).await;
        assert_eq!(tail.len(), 2);
        assert_eq!(tail[0].content, "Message 2");
        assert_eq!(tail[1].content, "Response 2");
    }

    /// A context survives a JSON round trip with data and history intact.
    #[tokio::test]
    async fn test_context_serialization() {
        let ctx = Context::new();
        ctx.set("key", "value").await;
        ctx.add_user_message("test message".to_string()).await;

        let json = serde_json::to_string(&ctx).unwrap();
        let restored: Context = serde_json::from_str(&json).unwrap();

        let stored: Option<String> = restored.get("key").await;
        assert_eq!(stored, Some("value".to_string()));

        assert_eq!(restored.chat_history_len().await, 1);
        let snapshot = restored.get_chat_history().await;
        let first = &snapshot.messages()[0];
        assert_eq!(first.content, "test message");
        assert_eq!(first.role, MessageRole::User);
    }

    /// A single message keeps its role and content across a JSON round trip.
    #[test]
    fn test_serializable_message() {
        let original = SerializableMessage::user("test content".to_string());
        assert_eq!(original.role, MessageRole::User);
        assert_eq!(original.content, "test content");

        let json = serde_json::to_string(&original).unwrap();
        let restored: SerializableMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(original.role, restored.role);
        assert_eq!(original.content, restored.content);
    }

    /// A chat history keeps its messages across a JSON round trip.
    #[test]
    fn test_chat_history_serialization() {
        let mut original = ChatHistory::new();
        original.add_user_message("Hello".to_string());
        original.add_assistant_message("Hi!".to_string());

        let json = serde_json::to_string(&original).unwrap();
        let restored: ChatHistory = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.len(), 2);
        let stored = restored.messages();
        assert_eq!(stored[0].content, "Hello");
        assert_eq!(stored[1].content, "Hi!");
    }

    /// Conversion to rig messages yields one rig message per stored message.
    #[cfg(feature = "rig")]
    #[tokio::test]
    async fn test_rig_integration() {
        let ctx = Context::new();

        ctx.add_user_message("Hello".to_string()).await;
        ctx.add_assistant_message("Hi there!".to_string()).await;
        ctx.add_system_message("System message".to_string()).await;

        let converted = ctx.get_rig_messages().await;
        assert_eq!(converted.len(), 3);

        let tail = ctx.get_last_rig_messages(2).await;
        assert_eq!(tail.len(), 2);

        // The content of a rig message is not inspected here; formatting the
        // converted messages only confirms the conversion completes without
        // panicking.
        let _debug_output = format!("{:?}", converted);
    }
}