graph_flow/context.rs
1//! Context and state management for workflows.
2//!
3//! This module provides thread-safe state management across workflow tasks,
4//! including regular data storage and specialized chat history management.
5//!
6//! # Examples
7//!
8//! ## Basic Context Usage
9//!
10//! ```rust
11//! use graph_flow::Context;
12//!
13//! # #[tokio::main]
14//! # async fn main() {
15//! let context = Context::new();
16//!
17//! // Store different types of data
18//! context.set("user_id", 12345).await;
19//! context.set("name", "Alice".to_string()).await;
20//! context.set("active", true).await;
21//!
22//! // Retrieve data with type safety
23//! let user_id: Option<i32> = context.get("user_id").await;
24//! let name: Option<String> = context.get("name").await;
25//! let active: Option<bool> = context.get("active").await;
26//!
27//! // Synchronous access (useful in edge conditions)
28//! let name_sync: Option<String> = context.get_sync("name");
29//! # }
30//! ```
31//!
32//! ## Chat History Management
33//!
34//! ```rust
35//! use graph_flow::Context;
36//!
37//! # #[tokio::main]
38//! # async fn main() {
39//! let context = Context::new();
40//!
41//! // Add messages to chat history
42//! context.add_user_message("Hello, assistant!".to_string()).await;
43//! context.add_assistant_message("Hello! How can I help you?".to_string()).await;
44//! context.add_system_message("User session started".to_string()).await;
45//!
46//! // Get chat history
47//! let history = context.get_chat_history().await;
48//! let all_messages = context.get_all_messages().await;
49//! let last_5 = context.get_last_messages(5).await;
50//!
51//! // Check history status
52//! let count = context.chat_history_len().await;
53//! let is_empty = context.is_chat_history_empty().await;
54//! # }
55//! ```
56//!
57//! ## Context with Message Limits
58//!
59//! ```rust
60//! use graph_flow::Context;
61//!
62//! # #[tokio::main]
63//! # async fn main() {
64//! // Create context with maximum 100 messages
65//! let context = Context::with_max_chat_messages(100);
66//!
67//! // Messages will be automatically pruned when limit is exceeded
68//! for i in 0..150 {
69//! context.add_user_message(format!("Message {}", i)).await;
70//! }
71//!
72//! // Only the last 100 messages are kept
73//! assert_eq!(context.chat_history_len().await, 100);
74//! # }
75//! ```
76//!
77//! ## LLM Integration (with `rig` feature)
78//!
79//! ```rust
80//! # #[cfg(feature = "rig")]
81//! # {
82//! use graph_flow::Context;
83//!
84//! # #[tokio::main]
85//! # async fn main() {
86//! let context = Context::new();
87//!
88//! context.add_user_message("What is the capital of France?".to_string()).await;
89//! context.add_assistant_message("The capital of France is Paris.".to_string()).await;
90//!
91//! // Get messages in rig format for LLM calls
92//! let rig_messages = context.get_rig_messages().await;
93//! let recent_messages = context.get_last_rig_messages(10).await;
94//!
95//! // Use with rig's completion API
96//! // let response = agent.completion(&rig_messages).await?;
97//! # }
98//! # }
99//! ```
100
101use chrono::{DateTime, Utc};
102use dashmap::DashMap;
103use serde::{Deserialize, Serialize};
104use serde_json::Value;
105use std::sync::{Arc, RwLock};
106
107#[cfg(feature = "rig")]
108use rig::completion::Message;
109
110/// Represents the role of a message in a conversation.
111///
112/// Used in chat history to distinguish between different types of messages.
113#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
114pub enum MessageRole {
115 /// Message from a user/human
116 User,
117 /// Message from an assistant/AI
118 Assistant,
119 /// System message (instructions, status updates, etc.)
120 System,
121}
122
123/// A serializable message that can be converted to/from rig::completion::Message.
124///
125/// This struct provides a unified message format that can be stored, serialized,
126/// and optionally converted to other formats like rig's Message type.
127///
128/// # Examples
129///
130/// ```rust
131/// use graph_flow::{SerializableMessage, MessageRole};
132///
133/// // Create different types of messages
134/// let user_msg = SerializableMessage::user("Hello!".to_string());
135/// let assistant_msg = SerializableMessage::assistant("Hi there!".to_string());
136/// let system_msg = SerializableMessage::system("Session started".to_string());
137///
138/// // Access message properties
139/// assert_eq!(user_msg.role, MessageRole::User);
140/// assert_eq!(user_msg.content, "Hello!");
141/// ```
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct SerializableMessage {
144 /// The role of the message sender
145 pub role: MessageRole,
146 /// The content of the message
147 pub content: String,
148 /// When the message was created
149 pub timestamp: DateTime<Utc>,
150}
151
152impl SerializableMessage {
153 /// Create a new message with the specified role and content.
154 ///
155 /// The timestamp is automatically set to the current UTC time.
156 pub fn new(role: MessageRole, content: String) -> Self {
157 Self {
158 role,
159 content,
160 timestamp: Utc::now(),
161 }
162 }
163
164 /// Create a new user message.
165 ///
166 /// # Examples
167 ///
168 /// ```rust
169 /// use graph_flow::SerializableMessage;
170 ///
171 /// let msg = SerializableMessage::user("Hello, world!".to_string());
172 /// ```
173 pub fn user(content: String) -> Self {
174 Self::new(MessageRole::User, content)
175 }
176
177 /// Create a new assistant message.
178 ///
179 /// # Examples
180 ///
181 /// ```rust
182 /// use graph_flow::SerializableMessage;
183 ///
184 /// let msg = SerializableMessage::assistant("Hello! How can I help?".to_string());
185 /// ```
186 pub fn assistant(content: String) -> Self {
187 Self::new(MessageRole::Assistant, content)
188 }
189
190 /// Create a new system message.
191 ///
192 /// # Examples
193 ///
194 /// ```rust
195 /// use graph_flow::SerializableMessage;
196 ///
197 /// let msg = SerializableMessage::system("User logged in".to_string());
198 /// ```
199 pub fn system(content: String) -> Self {
200 Self::new(MessageRole::System, content)
201 }
202}
203
204/// Container for managing chat history with serialization support.
205///
206/// Provides automatic message limit management and convenient methods
207/// for adding and retrieving messages.
208///
209/// # Examples
210///
211/// ```rust
212/// use graph_flow::ChatHistory;
213///
214/// let mut history = ChatHistory::new();
215/// history.add_user_message("Hello".to_string());
216/// history.add_assistant_message("Hi there!".to_string());
217///
218/// assert_eq!(history.len(), 2);
219/// assert!(!history.is_empty());
220/// ```
221#[derive(Debug, Clone, Serialize, Deserialize, Default)]
222pub struct ChatHistory {
223 messages: Vec<SerializableMessage>,
224 max_messages: Option<usize>,
225}
226
227impl ChatHistory {
228 /// Create a new empty chat history with a default limit of 1000 messages.
229 pub fn new() -> Self {
230 Self {
231 messages: Vec::new(),
232 max_messages: Some(1000), // Default limit to prevent unbounded growth
233 }
234 }
235
236 /// Create a new chat history with a maximum message limit.
237 ///
238 /// When the limit is exceeded, older messages are automatically removed.
239 ///
240 /// # Examples
241 ///
242 /// ```rust
243 /// use graph_flow::ChatHistory;
244 ///
245 /// let mut history = ChatHistory::with_max_messages(10);
246 ///
247 /// // Add 15 messages
248 /// for i in 0..15 {
249 /// history.add_user_message(format!("Message {}", i));
250 /// }
251 ///
252 /// // Only the last 10 are kept
253 /// assert_eq!(history.len(), 10);
254 /// ```
255 pub fn with_max_messages(max: usize) -> Self {
256 Self {
257 messages: Vec::new(),
258 max_messages: Some(max),
259 }
260 }
261
262 /// Add a user message to the chat history.
263 pub fn add_user_message(&mut self, content: String) {
264 self.add_message(SerializableMessage::user(content));
265 }
266
267 /// Add an assistant message to the chat history.
268 pub fn add_assistant_message(&mut self, content: String) {
269 self.add_message(SerializableMessage::assistant(content));
270 }
271
272 /// Add a system message to the chat history.
273 pub fn add_system_message(&mut self, content: String) {
274 self.add_message(SerializableMessage::system(content));
275 }
276
277 /// Add a message to the chat history, respecting max_messages limit.
278 fn add_message(&mut self, message: SerializableMessage) {
279 self.messages.push(message);
280
281 if let Some(max) = self.max_messages {
282 if self.messages.len() > max {
283 self.messages.drain(0..(self.messages.len() - max));
284 }
285 }
286 }
287
288 /// Clear all messages from the chat history.
289 pub fn clear(&mut self) {
290 self.messages.clear();
291 }
292
293 /// Get the number of messages in the chat history.
294 pub fn len(&self) -> usize {
295 self.messages.len()
296 }
297
298 /// Check if the chat history is empty.
299 pub fn is_empty(&self) -> bool {
300 self.messages.is_empty()
301 }
302
303 /// Get a reference to all messages.
304 pub fn messages(&self) -> &[SerializableMessage] {
305 &self.messages
306 }
307
308 /// Get the last N messages.
309 ///
310 /// If N is greater than the total number of messages, all messages are returned.
311 ///
312 /// # Examples
313 ///
314 /// ```rust
315 /// use graph_flow::ChatHistory;
316 ///
317 /// let mut history = ChatHistory::new();
318 /// history.add_user_message("Message 1".to_string());
319 /// history.add_user_message("Message 2".to_string());
320 /// history.add_user_message("Message 3".to_string());
321 ///
322 /// let last_two = history.last_messages(2);
323 /// assert_eq!(last_two.len(), 2);
324 /// assert_eq!(last_two[0].content, "Message 2");
325 /// assert_eq!(last_two[1].content, "Message 3");
326 /// ```
327 pub fn last_messages(&self, n: usize) -> &[SerializableMessage] {
328 let start = if self.messages.len() > n {
329 self.messages.len() - n
330 } else {
331 0
332 };
333 &self.messages[start..]
334 }
335}
336
337/// Helper struct for serializing/deserializing Context
338#[derive(Serialize, Deserialize)]
339struct ContextData {
340 data: std::collections::HashMap<String, Value>,
341 chat_history: ChatHistory,
342}
343
344/// Context for sharing data between tasks in a graph execution.
345///
346/// Provides thread-safe storage for workflow state and dedicated chat history
347/// management. The context is shared across all tasks in a workflow execution.
348///
349/// # Examples
350///
351/// ## Basic Usage
352///
353/// ```rust
354/// use graph_flow::Context;
355///
356/// # #[tokio::main]
357/// # async fn main() {
358/// let context = Context::new();
359///
360/// // Store different types of data
361/// context.set("user_id", 12345).await;
362/// context.set("name", "Alice".to_string()).await;
363/// context.set("settings", vec!["opt1", "opt2"]).await;
364///
365/// // Retrieve data
366/// let user_id: Option<i32> = context.get("user_id").await;
367/// let name: Option<String> = context.get("name").await;
368/// let settings: Option<Vec<String>> = context.get("settings").await;
369/// # }
370/// ```
371///
372/// ## Chat History
373///
374/// ```rust
375/// use graph_flow::Context;
376///
377/// # #[tokio::main]
378/// # async fn main() {
379/// let context = Context::new();
380///
381/// // Add messages
382/// context.add_user_message("Hello".to_string()).await;
383/// context.add_assistant_message("Hi there!".to_string()).await;
384///
385/// // Get message history
386/// let history = context.get_chat_history().await;
387/// let last_5 = context.get_last_messages(5).await;
388/// # }
389/// ```
390#[derive(Clone, Debug)]
391pub struct Context {
392 data: Arc<DashMap<String, Value>>,
393 chat_history: Arc<RwLock<ChatHistory>>,
394}
395
396impl Context {
397 /// Create a new empty context.
398 pub fn new() -> Self {
399 Self {
400 data: Arc::new(DashMap::new()),
401 chat_history: Arc::new(RwLock::new(ChatHistory::new())),
402 }
403 }
404
405 /// Create a new context with a maximum chat history size.
406 ///
407 /// When the chat history exceeds this size, older messages are automatically removed.
408 ///
409 /// # Examples
410 ///
411 /// ```rust
412 /// use graph_flow::Context;
413 ///
414 /// # #[tokio::main]
415 /// # async fn main() {
416 /// let context = Context::with_max_chat_messages(50);
417 ///
418 /// // Chat history will be limited to 50 messages
419 /// for i in 0..100 {
420 /// context.add_user_message(format!("Message {}", i)).await;
421 /// }
422 ///
423 /// assert_eq!(context.chat_history_len().await, 50);
424 /// # }
425 /// ```
426 pub fn with_max_chat_messages(max: usize) -> Self {
427 Self {
428 data: Arc::new(DashMap::new()),
429 chat_history: Arc::new(RwLock::new(ChatHistory::with_max_messages(max))),
430 }
431 }
432
433 // Regular context methods (unchanged API)
434
435 /// Set a value in the context.
436 ///
437 /// The value must be serializable. Most common Rust types are supported.
438 ///
439 /// # Examples
440 ///
441 /// ```rust
442 /// use graph_flow::Context;
443 /// use serde::{Serialize, Deserialize};
444 ///
445 /// #[derive(Serialize, Deserialize)]
446 /// struct UserData {
447 /// id: u32,
448 /// name: String,
449 /// }
450 ///
451 /// # #[tokio::main]
452 /// # async fn main() {
453 /// let context = Context::new();
454 ///
455 /// // Store primitive types
456 /// context.set("count", 42).await;
457 /// context.set("name", "Alice".to_string()).await;
458 /// context.set("active", true).await;
459 ///
460 /// // Store complex types
461 /// let user = UserData { id: 1, name: "Bob".to_string() };
462 /// context.set("user", user).await;
463 /// # }
464 /// ```
465 pub async fn set(&self, key: impl Into<String>, value: impl serde::Serialize) {
466 let value = serde_json::to_value(value).expect("Failed to serialize value");
467 self.data.insert(key.into(), value);
468 }
469
470 /// Get a value from the context.
471 ///
472 /// Returns `None` if the key doesn't exist or if deserialization fails.
473 ///
474 /// # Examples
475 ///
476 /// ```rust
477 /// use graph_flow::Context;
478 ///
479 /// # #[tokio::main]
480 /// # async fn main() {
481 /// let context = Context::new();
482 /// context.set("count", 42).await;
483 ///
484 /// let count: Option<i32> = context.get("count").await;
485 /// assert_eq!(count, Some(42));
486 ///
487 /// let missing: Option<String> = context.get("missing").await;
488 /// assert_eq!(missing, None);
489 /// # }
490 /// ```
491 pub async fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
492 self.data
493 .get(key)
494 .and_then(|v| serde_json::from_value(v.clone()).ok())
495 }
496
497 /// Remove a value from the context.
498 ///
499 /// Returns the removed value if it existed.
500 ///
501 /// # Examples
502 ///
503 /// ```rust
504 /// use graph_flow::Context;
505 ///
506 /// # #[tokio::main]
507 /// # async fn main() {
508 /// let context = Context::new();
509 /// context.set("temp", "value".to_string()).await;
510 ///
511 /// let removed = context.remove("temp").await;
512 /// assert!(removed.is_some());
513 ///
514 /// let value: Option<String> = context.get("temp").await;
515 /// assert_eq!(value, None);
516 /// # }
517 /// ```
518 pub async fn remove(&self, key: &str) -> Option<Value> {
519 self.data.remove(key).map(|(_, v)| v)
520 }
521
522 /// Clear all regular context data (does not affect chat history).
523 ///
524 /// # Examples
525 ///
526 /// ```rust
527 /// use graph_flow::Context;
528 ///
529 /// # #[tokio::main]
530 /// # async fn main() {
531 /// let context = Context::new();
532 /// context.set("key1", "value1".to_string()).await;
533 /// context.set("key2", "value2".to_string()).await;
534 /// context.add_user_message("Hello".to_string()).await;
535 ///
536 /// context.clear().await;
537 ///
538 /// // Regular data is cleared
539 /// let value: Option<String> = context.get("key1").await;
540 /// assert_eq!(value, None);
541 ///
542 /// // Chat history is preserved
543 /// assert_eq!(context.chat_history_len().await, 1);
544 /// # }
545 /// ```
546 pub async fn clear(&self) {
547 self.data.clear();
548 }
549
550 /// Synchronous version of get for use in edge conditions.
551 ///
552 /// This method should only be used when you're certain the data exists
553 /// and when async is not available (e.g., in edge condition closures).
554 ///
555 /// # Examples
556 ///
557 /// ```rust
558 /// use graph_flow::{Context, GraphBuilder};
559 ///
560 /// # #[tokio::main]
561 /// # async fn main() {
562 /// let context = Context::new();
563 /// context.set("condition", true).await;
564 ///
565 /// // Used in edge conditions
566 /// let graph = GraphBuilder::new("test")
567 /// .add_conditional_edge(
568 /// "task1",
569 /// |ctx| ctx.get_sync::<bool>("condition").unwrap_or(false),
570 /// "task2",
571 /// "task3"
572 /// );
573 /// # }
574 /// ```
575 pub fn get_sync<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
576 self.data
577 .get(key)
578 .and_then(|v| serde_json::from_value(v.clone()).ok())
579 }
580
581 /// Synchronous version of set for use when async is not available.
582 ///
583 /// # Examples
584 ///
585 /// ```rust
586 /// use graph_flow::Context;
587 ///
588 /// let context = Context::new();
589 /// context.set_sync("key", "value".to_string());
590 ///
591 /// let value: Option<String> = context.get_sync("key");
592 /// assert_eq!(value, Some("value".to_string()));
593 /// ```
594 pub fn set_sync(&self, key: impl Into<String>, value: impl serde::Serialize) {
595 let value = serde_json::to_value(value).expect("Failed to serialize value");
596 self.data.insert(key.into(), value);
597 }
598
599 // Chat history methods
600
601 /// Add a user message to the chat history.
602 ///
603 /// # Examples
604 ///
605 /// ```rust
606 /// use graph_flow::Context;
607 ///
608 /// # #[tokio::main]
609 /// # async fn main() {
610 /// let context = Context::new();
611 /// context.add_user_message("Hello, assistant!".to_string()).await;
612 /// # }
613 /// ```
614 pub async fn add_user_message(&self, content: String) {
615 if let Ok(mut history) = self.chat_history.write() {
616 history.add_user_message(content);
617 }
618 }
619
620 /// Add an assistant message to the chat history.
621 ///
622 /// # Examples
623 ///
624 /// ```rust
625 /// use graph_flow::Context;
626 ///
627 /// # #[tokio::main]
628 /// # async fn main() {
629 /// let context = Context::new();
630 /// context.add_assistant_message("Hello! How can I help you?".to_string()).await;
631 /// # }
632 /// ```
633 pub async fn add_assistant_message(&self, content: String) {
634 if let Ok(mut history) = self.chat_history.write() {
635 history.add_assistant_message(content);
636 }
637 }
638
639 /// Add a system message to the chat history.
640 ///
641 /// # Examples
642 ///
643 /// ```rust
644 /// use graph_flow::Context;
645 ///
646 /// # #[tokio::main]
647 /// # async fn main() {
648 /// let context = Context::new();
649 /// context.add_system_message("Session started".to_string()).await;
650 /// # }
651 /// ```
652 pub async fn add_system_message(&self, content: String) {
653 if let Ok(mut history) = self.chat_history.write() {
654 history.add_system_message(content);
655 }
656 }
657
658 /// Get a clone of the current chat history.
659 ///
660 /// # Examples
661 ///
662 /// ```rust
663 /// use graph_flow::Context;
664 ///
665 /// # #[tokio::main]
666 /// # async fn main() {
667 /// let context = Context::new();
668 /// context.add_user_message("Hello".to_string()).await;
669 ///
670 /// let history = context.get_chat_history().await;
671 /// assert_eq!(history.len(), 1);
672 /// # }
673 /// ```
674 pub async fn get_chat_history(&self) -> ChatHistory {
675 if let Ok(history) = self.chat_history.read() {
676 history.clone()
677 } else {
678 ChatHistory::new()
679 }
680 }
681
682 /// Clear the chat history.
683 ///
684 /// # Examples
685 ///
686 /// ```rust
687 /// use graph_flow::Context;
688 ///
689 /// # #[tokio::main]
690 /// # async fn main() {
691 /// let context = Context::new();
692 /// context.add_user_message("Hello".to_string()).await;
693 /// assert_eq!(context.chat_history_len().await, 1);
694 ///
695 /// context.clear_chat_history().await;
696 /// assert_eq!(context.chat_history_len().await, 0);
697 /// # }
698 /// ```
699 pub async fn clear_chat_history(&self) {
700 if let Ok(mut history) = self.chat_history.write() {
701 history.clear();
702 }
703 }
704
705 /// Get the number of messages in the chat history.
706 pub async fn chat_history_len(&self) -> usize {
707 if let Ok(history) = self.chat_history.read() {
708 history.len()
709 } else {
710 0
711 }
712 }
713
714 /// Check if the chat history is empty.
715 pub async fn is_chat_history_empty(&self) -> bool {
716 if let Ok(history) = self.chat_history.read() {
717 history.is_empty()
718 } else {
719 true
720 }
721 }
722
723 /// Get the last N messages from chat history.
724 ///
725 /// # Examples
726 ///
727 /// ```rust
728 /// use graph_flow::Context;
729 ///
730 /// # #[tokio::main]
731 /// # async fn main() {
732 /// let context = Context::new();
733 /// context.add_user_message("Message 1".to_string()).await;
734 /// context.add_user_message("Message 2".to_string()).await;
735 /// context.add_user_message("Message 3".to_string()).await;
736 ///
737 /// let last_two = context.get_last_messages(2).await;
738 /// assert_eq!(last_two.len(), 2);
739 /// assert_eq!(last_two[0].content, "Message 2");
740 /// assert_eq!(last_two[1].content, "Message 3");
741 /// # }
742 /// ```
743 pub async fn get_last_messages(&self, n: usize) -> Vec<SerializableMessage> {
744 if let Ok(history) = self.chat_history.read() {
745 history.last_messages(n).to_vec()
746 } else {
747 Vec::new()
748 }
749 }
750
751 /// Get all messages from chat history as SerializableMessage.
752 ///
753 /// # Examples
754 ///
755 /// ```rust
756 /// use graph_flow::Context;
757 ///
758 /// # #[tokio::main]
759 /// # async fn main() {
760 /// let context = Context::new();
761 /// context.add_user_message("Hello".to_string()).await;
762 /// context.add_assistant_message("Hi there!".to_string()).await;
763 ///
764 /// let all_messages = context.get_all_messages().await;
765 /// assert_eq!(all_messages.len(), 2);
766 /// # }
767 /// ```
768 pub async fn get_all_messages(&self) -> Vec<SerializableMessage> {
769 if let Ok(history) = self.chat_history.read() {
770 history.messages().to_vec()
771 } else {
772 Vec::new()
773 }
774 }
775
776 // Rig integration methods (only available when rig feature is enabled)
777
778 #[cfg(feature = "rig")]
779 /// Get all chat history messages converted to rig::completion::Message format.
780 ///
781 /// This method is only available when the "rig" feature is enabled.
782 ///
783 /// # Examples
784 ///
785 /// ```rust
786 /// # #[cfg(feature = "rig")]
787 /// # {
788 /// use graph_flow::Context;
789 ///
790 /// # #[tokio::main]
791 /// # async fn main() {
792 /// let context = Context::new();
793 /// context.add_user_message("Hello".to_string()).await;
794 /// context.add_assistant_message("Hi there!".to_string()).await;
795 ///
796 /// let rig_messages = context.get_rig_messages().await;
797 /// assert_eq!(rig_messages.len(), 2);
798 /// # }
799 /// # }
800 /// ```
801 pub async fn get_rig_messages(&self) -> Vec<Message> {
802 let messages = self.get_all_messages().await;
803 messages
804 .iter()
805 .map(|msg| self.to_rig_message(msg))
806 .collect()
807 }
808
809 #[cfg(feature = "rig")]
810 /// Get the last N messages converted to rig::completion::Message format.
811 ///
812 /// This method is only available when the "rig" feature is enabled.
813 ///
814 /// # Examples
815 ///
816 /// ```rust
817 /// # #[cfg(feature = "rig")]
818 /// # {
819 /// use graph_flow::Context;
820 ///
821 /// # #[tokio::main]
822 /// # async fn main() {
823 /// let context = Context::new();
824 /// for i in 0..10 {
825 /// context.add_user_message(format!("Message {}", i)).await;
826 /// }
827 ///
828 /// let last_5 = context.get_last_rig_messages(5).await;
829 /// assert_eq!(last_5.len(), 5);
830 /// # }
831 /// # }
832 /// ```
833 pub async fn get_last_rig_messages(&self, n: usize) -> Vec<Message> {
834 let messages = self.get_last_messages(n).await;
835 messages
836 .iter()
837 .map(|msg| self.to_rig_message(msg))
838 .collect()
839 }
840
841 #[cfg(feature = "rig")]
842 /// Convert a SerializableMessage to a rig::completion::Message.
843 ///
844 /// This method is only available when the "rig" feature is enabled.
845 fn to_rig_message(&self, msg: &SerializableMessage) -> Message {
846 match msg.role {
847 MessageRole::User => Message::user(msg.content.clone()),
848 MessageRole::Assistant => Message::assistant(msg.content.clone()),
849 // rig doesn't have a system message type, so we'll treat it as a user message
850 // with a system prefix
851 MessageRole::System => Message::user(format!("[SYSTEM] {}", msg.content)),
852 }
853 }
854}
855
856impl Default for Context {
857 fn default() -> Self {
858 Self::new()
859 }
860}
861
862// Serialization support for Context
863impl Serialize for Context {
864 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
865 where
866 S: serde::Serializer,
867 {
868 // Convert DashMap to HashMap for serialization
869 let data: std::collections::HashMap<String, Value> = self
870 .data
871 .iter()
872 .map(|entry| (entry.key().clone(), entry.value().clone()))
873 .collect();
874
875 let chat_history = if let Ok(history) = self.chat_history.read() {
876 history.clone()
877 } else {
878 ChatHistory::new()
879 };
880
881 let context_data = ContextData { data, chat_history };
882 context_data.serialize(serializer)
883 }
884}
885
886impl<'de> Deserialize<'de> for Context {
887 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
888 where
889 D: serde::Deserializer<'de>,
890 {
891 let context_data = ContextData::deserialize(deserializer)?;
892
893 let data = Arc::new(DashMap::new());
894 for (key, value) in context_data.data {
895 data.insert(key, value);
896 }
897
898 let chat_history = Arc::new(RwLock::new(context_data.chat_history));
899
900 Ok(Context { data, chat_history })
901 }
902}
903
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_context_operations() {
        let context = Context::new();

        context.set("key", "value").await;
        let value: Option<String> = context.get("key").await;
        assert_eq!(value, Some("value".to_string()));
    }

    #[tokio::test]
    async fn test_get_missing_key_and_type_mismatch() {
        let context = Context::new();

        let missing: Option<String> = context.get("missing").await;
        assert_eq!(missing, None);

        // A stored value that does not deserialize into the requested type yields None.
        context.set("count", 42).await;
        let wrong_type: Option<String> = context.get("count").await;
        assert_eq!(wrong_type, None);
        let right_type: Option<i32> = context.get("count").await;
        assert_eq!(right_type, Some(42));
    }

    #[tokio::test]
    async fn test_remove_and_clear() {
        let context = Context::new();
        context.set("a", 1).await;
        context.set("b", 2).await;
        context.add_user_message("Hello".to_string()).await;

        assert_eq!(context.remove("a").await, Some(serde_json::json!(1)));
        assert_eq!(context.remove("a").await, None);

        context.clear().await;
        let b: Option<i32> = context.get("b").await;
        assert_eq!(b, None);
        // `clear` only touches key/value data, never the chat history.
        assert_eq!(context.chat_history_len().await, 1);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let context = Context::new();
        let handle = context.clone();

        handle.set("shared", true).await;
        handle.add_user_message("from clone".to_string()).await;

        let shared: Option<bool> = context.get("shared").await;
        assert_eq!(shared, Some(true));
        assert_eq!(context.chat_history_len().await, 1);
    }

    #[test]
    fn test_sync_accessors() {
        let context = Context::new();
        context.set_sync("key", vec![1, 2, 3]);

        let value: Option<Vec<i32>> = context.get_sync("key");
        assert_eq!(value, Some(vec![1, 2, 3]));

        let missing: Option<i32> = context.get_sync("other");
        assert_eq!(missing, None);
    }

    #[tokio::test]
    async fn test_chat_history_operations() {
        let context = Context::new();

        assert!(context.is_chat_history_empty().await);
        assert_eq!(context.chat_history_len().await, 0);

        context.add_user_message("Hello".to_string()).await;
        context.add_assistant_message("Hi there!".to_string()).await;

        assert!(!context.is_chat_history_empty().await);
        assert_eq!(context.chat_history_len().await, 2);

        let history = context.get_chat_history().await;
        assert_eq!(history.len(), 2);
        assert_eq!(history.messages()[0].content, "Hello");
        assert_eq!(history.messages()[0].role, MessageRole::User);
        assert_eq!(history.messages()[1].content, "Hi there!");
        assert_eq!(history.messages()[1].role, MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_chat_history_max_messages() {
        let context = Context::with_max_chat_messages(2);

        context.add_user_message("Message 1".to_string()).await;
        context
            .add_assistant_message("Response 1".to_string())
            .await;
        context.add_user_message("Message 2".to_string()).await;

        let history = context.get_chat_history().await;
        assert_eq!(history.len(), 2);
        assert_eq!(history.messages()[0].content, "Response 1");
        assert_eq!(history.messages()[1].content, "Message 2");
    }

    #[test]
    fn test_max_messages_zero_keeps_nothing() {
        let mut history = ChatHistory::with_max_messages(0);
        history.add_user_message("dropped".to_string());
        assert!(history.is_empty());
    }

    #[tokio::test]
    async fn test_last_messages() {
        let context = Context::new();

        context.add_user_message("Message 1".to_string()).await;
        context
            .add_assistant_message("Response 1".to_string())
            .await;
        context.add_user_message("Message 2".to_string()).await;
        context
            .add_assistant_message("Response 2".to_string())
            .await;

        let last_two = context.get_last_messages(2).await;
        assert_eq!(last_two.len(), 2);
        assert_eq!(last_two[0].content, "Message 2");
        assert_eq!(last_two[1].content, "Response 2");
    }

    #[test]
    fn test_last_messages_bounds() {
        let mut history = ChatHistory::with_max_messages(10);
        assert!(history.last_messages(3).is_empty());

        history.add_user_message("one".to_string());
        history.add_assistant_message("two".to_string());

        assert!(history.last_messages(0).is_empty());
        let all = history.last_messages(5);
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].content, "one");
        assert_eq!(all[1].content, "two");
    }

    #[tokio::test]
    async fn test_context_serialization() {
        let context = Context::new();
        context.set("key", "value").await;
        context.add_user_message("test message".to_string()).await;

        let serialized = serde_json::to_string(&context).unwrap();
        let deserialized: Context = serde_json::from_str(&serialized).unwrap();

        let value: Option<String> = deserialized.get("key").await;
        assert_eq!(value, Some("value".to_string()));

        assert_eq!(deserialized.chat_history_len().await, 1);
        let history = deserialized.get_chat_history().await;
        assert_eq!(history.messages()[0].content, "test message");
        assert_eq!(history.messages()[0].role, MessageRole::User);
    }

    #[tokio::test]
    async fn test_serialization_preserves_chat_limit() {
        let context = Context::with_max_chat_messages(2);

        let serialized = serde_json::to_string(&context).unwrap();
        let restored: Context = serde_json::from_str(&serialized).unwrap();

        for i in 0..3 {
            restored.add_user_message(format!("Message {}", i)).await;
        }
        assert_eq!(restored.chat_history_len().await, 2);
    }

    #[test]
    fn test_serializable_message() {
        let msg = SerializableMessage::user("test content".to_string());
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "test content");

        let serialized = serde_json::to_string(&msg).unwrap();
        let deserialized: SerializableMessage = serde_json::from_str(&serialized).unwrap();

        assert_eq!(msg.role, deserialized.role);
        assert_eq!(msg.content, deserialized.content);
    }

    #[test]
    fn test_chat_history_serialization() {
        let mut history = ChatHistory::new();
        history.add_user_message("Hello".to_string());
        history.add_assistant_message("Hi!".to_string());

        let serialized = serde_json::to_string(&history).unwrap();
        let deserialized: ChatHistory = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.len(), 2);
        assert_eq!(deserialized.messages()[0].content, "Hello");
        assert_eq!(deserialized.messages()[1].content, "Hi!");
    }

    #[cfg(feature = "rig")]
    #[tokio::test]
    async fn test_rig_integration() {
        let context = Context::new();

        context.add_user_message("Hello".to_string()).await;
        context.add_assistant_message("Hi there!".to_string()).await;
        context
            .add_system_message("System message".to_string())
            .await;

        let rig_messages = context.get_rig_messages().await;
        assert_eq!(rig_messages.len(), 3);

        let last_two = context.get_last_rig_messages(2).await;
        assert_eq!(last_two.len(), 2);

        // rig's Message content is not easily inspected here, so this only
        // checks that conversion and Debug formatting complete without panicking.
        let _debug_output = format!("{:?}", rig_messages);
    }
}