1use crate::core::SessionId;
83use chrono::{DateTime, Utc};
84use serde::{Deserialize, Serialize};
85use std::collections::HashMap;
86use std::io::Read as IoRead;
87
/// Tuning knobs for a [`SessionContext`]'s conversation history.
///
/// Thanks to the container-level `#[serde(default)]`, any field missing from
/// serialized input is filled from `SessionConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SessionConfig {
    /// Token budget; `SessionContext::new` copies it into the history's `max_tokens`.
    pub max_tokens: usize,
    /// Number of newest messages kept uncompressed; `SessionContext::new` copies it
    /// into the history's `keep_recent`.
    pub keep_recent_messages: usize,
    /// zstd compression level; `SessionContext::new` copies it into the history's
    /// `compression_level`. Nothing visible in this file re-syncs these three
    /// fields after construction.
    pub compression_level: i32,
    /// Fraction of the history's `max_tokens` above which
    /// `SessionContext::compress_context` compresses (read on every call).
    pub compression_threshold: f32,
}
101
102impl Default for SessionConfig {
103 fn default() -> Self {
104 Self {
105 max_tokens: 100_000, keep_recent_messages: 20, compression_level: 3, compression_threshold: 0.8, }
110 }
111}
112
/// Per-session state: the conversation history plus task, agent and workspace snapshots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionContext {
    /// Identifier echoed back by `summarize`.
    pub session_id: SessionId,
    /// Message log that compresses older entries when over budget.
    pub conversation_history: TokenEfficientHistory,
    /// Current task; its `name` is reported as `current_task` by `summarize`.
    pub task_context: TaskContext,
    /// Agent status; its `state` string is reported by `summarize`.
    pub agent_state: AgentState,
    /// Workspace snapshot; `summarize` reports how many files it tracks.
    pub workspace_state: WorkspaceState,
    /// Free-form extra data; not read by anything visible in this file.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Session settings. `new` copies the history-related ones into
    /// `conversation_history`; `compress_context` reads `compression_threshold`.
    pub config: SessionConfig,
}
131
132impl SessionContext {
133 pub fn new(session_id: SessionId) -> Self {
135 let config = SessionConfig::default();
136 let mut conversation_history = TokenEfficientHistory::new();
137 conversation_history.max_tokens = config.max_tokens;
138 conversation_history.keep_recent = config.keep_recent_messages;
139 conversation_history.compression_level = config.compression_level;
140
141 Self {
142 session_id,
143 conversation_history,
144 task_context: TaskContext::default(),
145 agent_state: AgentState::default(),
146 workspace_state: WorkspaceState::default(),
147 metadata: HashMap::new(),
148 config,
149 }
150 }
151
152 pub fn add_message(&mut self, message: Message) {
154 self.conversation_history.add_message_struct(message);
155 }
156
157 pub fn add_message_raw(&mut self, role: MessageRole, content: String) {
159 self.conversation_history.add_message(role, content);
160 }
161
162 pub fn get_message_count(&self) -> usize {
164 self.conversation_history.messages.len()
165 }
166
167 pub fn get_total_tokens(&self) -> usize {
169 self.conversation_history.current_tokens
170 }
171
172 pub fn get_recent_messages(&self, n: usize) -> Vec<&Message> {
174 let message_count = self.conversation_history.messages.len();
175 if n >= message_count {
176 self.conversation_history.messages.iter().collect()
177 } else {
178 self.conversation_history
179 .messages
180 .iter()
181 .skip(message_count - n)
182 .collect()
183 }
184 }
185
186 pub async fn compress_context(&mut self) -> bool {
188 let threshold = (self.conversation_history.max_tokens as f32
190 * self.config.compression_threshold) as usize;
191
192 if self.conversation_history.current_tokens > threshold {
193 self.conversation_history.compress_old_messages();
194 true
195 } else {
196 false
197 }
198 }
199
200 pub fn update_task(&mut self, task: TaskContext) {
202 self.task_context = task;
203 }
204
205 pub fn summarize(&self) -> ContextSummary {
207 ContextSummary {
208 session_id: self.session_id.clone(),
209 message_count: self.conversation_history.messages.len(),
210 current_task: self.task_context.name.clone(),
211 agent_state: self.agent_state.state.clone(),
212 workspace_files: self.workspace_state.tracked_files.len(),
213 }
214 }
215
216 pub fn get_compression_stats(&self) -> CompressionStats {
218 self.conversation_history.get_compression_stats()
219 }
220}
221
/// Conversation log that keeps the newest messages verbatim and archives older
/// ones as one zstd-compressed blob plus a textual summary.
///
/// Every field has a serde default, so partially-populated serialized state
/// still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEfficientHistory {
    /// Active (uncompressed) messages, oldest first.
    #[serde(default)]
    pub messages: Vec<Message>,
    /// Archive of messages moved out of `messages`; `None` until the first compression.
    #[serde(default)]
    pub compressed_history: Option<CompressedHistory>,
    /// Token budget; exceeding it after an append triggers compression.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Running tally: active messages' token counts plus a charge for the
    /// archive summary (see `compress_old_messages`).
    /// NOTE(review): defaults to 0 when absent from serialized input, even if
    /// `messages` is non-empty.
    #[serde(default)]
    pub current_tokens: usize,
    /// Number of newest messages `compress_old_messages` never archives.
    #[serde(default = "default_keep_recent")]
    pub keep_recent: usize,
    /// Level passed to `zstd::encode_all`.
    #[serde(default = "default_compression_level")]
    pub compression_level: i32,
    /// Lifetime count of appended messages; compression does not decrement it.
    #[serde(default)]
    pub total_messages_added: usize,
    /// Lifetime sum of token counts of archived messages.
    #[serde(default)]
    pub tokens_saved_by_compression: usize,
}
250
/// Serde default for `TokenEfficientHistory::max_tokens`: 100k tokens.
fn default_max_tokens() -> usize {
    const MAX_TOKENS: usize = 100_000;
    MAX_TOKENS
}
254
/// Serde default for `TokenEfficientHistory::keep_recent`: the 20 newest messages.
fn default_keep_recent() -> usize {
    const KEEP_RECENT: usize = 20;
    KEEP_RECENT
}
258
/// Serde default for `TokenEfficientHistory::compression_level` (zstd level 3).
fn default_compression_level() -> i32 {
    const LEVEL: i32 = 3;
    LEVEL
}
262
263impl Default for TokenEfficientHistory {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl TokenEfficientHistory {
270 pub fn new() -> Self {
272 Self {
273 messages: Vec::new(),
274 compressed_history: None,
275 max_tokens: 100_000,
276 current_tokens: 0,
277 keep_recent: 20,
278 compression_level: 3,
279 total_messages_added: 0,
280 tokens_saved_by_compression: 0,
281 }
282 }
283
284 pub fn add_message(&mut self, role: MessageRole, content: String) {
286 let token_estimate = estimate_tokens(&content);
287 let message = Message {
288 role,
289 content,
290 timestamp: Utc::now(),
291 token_count: token_estimate,
292 };
293
294 self.messages.push(message);
295 self.current_tokens += token_estimate;
296 self.total_messages_added += 1;
297
298 if self.current_tokens > self.max_tokens {
300 self.compress_old_messages();
301 }
302 }
303
304 pub fn add_message_struct(&mut self, message: Message) {
306 self.current_tokens += message.token_count;
307 self.messages.push(message);
308 self.total_messages_added += 1;
309
310 if self.current_tokens > self.max_tokens {
312 self.compress_old_messages();
313 }
314 }
315
316 pub fn compress_old_messages(&mut self) {
318 if self.messages.len() <= self.keep_recent {
320 return;
321 }
322
323 let split_point = self.messages.len() - self.keep_recent;
325 let messages_to_compress: Vec<Message> = self.messages.drain(..split_point).collect();
326
327 if messages_to_compress.is_empty() {
328 return;
329 }
330
331 let tokens_to_compress: usize = messages_to_compress.iter().map(|m| m.token_count).sum();
333
334 let json_data = match serde_json::to_vec(&messages_to_compress) {
336 Ok(data) => data,
337 Err(e) => {
338 tracing::warn!("Failed to serialize messages for compression: {}", e);
339 let mut restored = messages_to_compress;
341 restored.append(&mut self.messages);
342 self.messages = restored;
343 return;
344 }
345 };
346
347 let compressed_data = match zstd::encode_all(json_data.as_slice(), self.compression_level) {
349 Ok(data) => data,
350 Err(e) => {
351 tracing::warn!("Failed to compress messages: {}", e);
352 let mut restored = messages_to_compress;
354 restored.append(&mut self.messages);
355 self.messages = restored;
356 return;
357 }
358 };
359
360 let original_size = json_data.len();
362 let compressed_size = compressed_data.len();
363 let compression_ratio = if original_size > 0 {
364 1.0 - (compressed_size as f64 / original_size as f64)
365 } else {
366 0.0
367 };
368
369 let summary = create_compression_summary(&messages_to_compress);
371
372 let new_compressed = if let Some(existing) = self.compressed_history.take() {
374 CompressedHistory {
375 compressed_data: merge_compressed_data(
376 &existing.compressed_data,
377 &compressed_data,
378 self.compression_level,
379 ),
380 summary: format!("{}\n---\n{}", existing.summary, summary),
381 message_count: existing.message_count + messages_to_compress.len(),
382 original_tokens: existing.original_tokens + tokens_to_compress,
383 compressed_bytes: existing.compressed_bytes + compressed_size,
384 compression_ratio: (existing.compression_ratio + compression_ratio) / 2.0,
385 }
386 } else {
387 CompressedHistory {
388 compressed_data,
389 summary,
390 message_count: messages_to_compress.len(),
391 original_tokens: tokens_to_compress,
392 compressed_bytes: compressed_size,
393 compression_ratio,
394 }
395 };
396
397 self.compressed_history = Some(new_compressed);
399 self.current_tokens -= tokens_to_compress;
400 self.tokens_saved_by_compression += tokens_to_compress;
401
402 let summary_tokens = estimate_tokens(
404 self.compressed_history
405 .as_ref()
406 .map(|h| h.summary.as_str())
407 .unwrap_or(""),
408 );
409 self.current_tokens += summary_tokens.min(100); tracing::info!(
412 "Compressed {} messages ({} tokens) with {:.1}% ratio",
413 messages_to_compress.len(),
414 tokens_to_compress,
415 compression_ratio * 100.0
416 );
417 }
418
419 pub fn decompress_history(&self) -> Option<Vec<Message>> {
421 let compressed = self.compressed_history.as_ref()?;
422
423 let mut decompressed = Vec::new();
425 let mut decoder = match zstd::Decoder::new(compressed.compressed_data.as_slice()) {
426 Ok(d) => d,
427 Err(e) => {
428 tracing::error!("Failed to create zstd decoder: {}", e);
429 return None;
430 }
431 };
432
433 if let Err(e) = decoder.read_to_end(&mut decompressed) {
434 tracing::error!("Failed to decompress history: {}", e);
435 return None;
436 }
437
438 match serde_json::from_slice(&decompressed) {
440 Ok(messages) => Some(messages),
441 Err(e) => {
442 tracing::error!("Failed to deserialize decompressed messages: {}", e);
443 None
444 }
445 }
446 }
447
448 pub fn get_all_messages(&self) -> Vec<Message> {
450 let mut all_messages = self.decompress_history().unwrap_or_default();
451 all_messages.extend(self.messages.clone());
452 all_messages
453 }
454
455 pub fn get_messages_within_limit(&self, token_limit: usize) -> Vec<&Message> {
457 let mut messages = Vec::new();
458 let mut tokens = 0;
459
460 for message in self.messages.iter().rev() {
462 if tokens + message.token_count <= token_limit {
463 messages.push(message);
464 tokens += message.token_count;
465 } else {
466 break;
467 }
468 }
469
470 messages.reverse();
471 messages
472 }
473
474 pub fn get_compression_stats(&self) -> CompressionStats {
476 let compressed_stats = self.compressed_history.as_ref().map(|h| {
477 (
478 h.message_count,
479 h.original_tokens,
480 h.compressed_bytes,
481 h.compression_ratio,
482 )
483 });
484
485 CompressionStats {
486 total_messages_added: self.total_messages_added,
487 active_messages: self.messages.len(),
488 compressed_messages: compressed_stats.map(|(c, _, _, _)| c).unwrap_or(0),
489 active_tokens: self.current_tokens,
490 tokens_saved: self.tokens_saved_by_compression,
491 compressed_bytes: compressed_stats.map(|(_, _, b, _)| b).unwrap_or(0),
492 compression_ratio: compressed_stats.map(|(_, _, _, r)| r).unwrap_or(0.0),
493 }
494 }
495}
496
/// Heuristic token count for `content`; never returns less than 1.
///
/// Takes the larger of two estimates:
/// - word-based: 1.3 tokens per whitespace-separated word, plus one per
///   non-alphanumeric, non-whitespace character, plus 2;
/// - character-based: one token per 4 Unicode scalar values.
///
/// The empty string counts as 1.
fn estimate_tokens(content: &str) -> usize {
    if content.is_empty() {
        return 1;
    }

    let (scalars, punctuation) = content.chars().fold((0usize, 0usize), |(n, p), c| {
        let is_special = !c.is_alphanumeric() && !c.is_whitespace();
        (n + 1, p + usize::from(is_special))
    });
    let words = content.split_whitespace().count();

    let word_based = (words as f64 * 1.3) as usize + punctuation + 2;
    let char_based = scalars / 4;

    std::cmp::max(word_based, char_based).max(1)
}
527
528fn create_compression_summary(messages: &[Message]) -> String {
530 if messages.is_empty() {
531 return String::new();
532 }
533
534 let first = messages.first().unwrap();
535 let last = messages.last().unwrap();
536
537 let user_count = messages
538 .iter()
539 .filter(|m| m.role == MessageRole::User)
540 .count();
541 let assistant_count = messages
542 .iter()
543 .filter(|m| m.role == MessageRole::Assistant)
544 .count();
545
546 format!(
547 "[Compressed: {} messages ({} user, {} assistant) from {} to {}]",
548 messages.len(),
549 user_count,
550 assistant_count,
551 first.timestamp.format("%H:%M:%S"),
552 last.timestamp.format("%H:%M:%S")
553 )
554}
555
556fn merge_compressed_data(existing: &[u8], new: &[u8], level: i32) -> Vec<u8> {
558 let mut existing_decompressed = Vec::new();
562 if let Ok(mut decoder) = zstd::Decoder::new(existing) {
563 let _ = decoder.read_to_end(&mut existing_decompressed);
564 }
565
566 let mut new_decompressed = Vec::new();
567 if let Ok(mut decoder) = zstd::Decoder::new(new) {
568 let _ = decoder.read_to_end(&mut new_decompressed);
569 }
570
571 let existing_messages: Vec<Message> =
573 serde_json::from_slice(&existing_decompressed).unwrap_or_default();
574 let new_messages: Vec<Message> = serde_json::from_slice(&new_decompressed).unwrap_or_default();
575
576 let mut merged = existing_messages;
577 merged.extend(new_messages);
578
579 let json_data = serde_json::to_vec(&merged).unwrap_or_default();
581 zstd::encode_all(json_data.as_slice(), level).unwrap_or_else(|_| new.to_vec())
582}
583
/// A single conversation entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced the message.
    pub role: MessageRole,
    /// Message text.
    pub content: String,
    /// Creation time (UTC); `TokenEfficientHistory::add_message` stamps it with `Utc::now()`.
    pub timestamp: DateTime<Utc>,
    /// Token count: estimated by `estimate_tokens` in `add_message`, or supplied
    /// by the caller via `add_message_struct`.
    pub token_count: usize,
}
596
/// Author of a [`Message`].
///
/// `create_compression_summary` tallies only `User` and `Assistant` separately.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}
605
/// Archive of messages moved out of the active history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedHistory {
    /// zstd-compressed JSON array of [`Message`]s, oldest first; serialized as
    /// a base64 string via `base64_serde`.
    #[serde(with = "base64_serde")]
    pub compressed_data: Vec<u8>,
    /// Per-batch digests from `create_compression_summary`, joined with `"\n---\n"`.
    pub summary: String,
    /// Number of messages archived.
    pub message_count: usize,
    /// Sum of the archived messages' token counts.
    pub original_tokens: usize,
    /// Compressed size in bytes, as recorded by `TokenEfficientHistory::compress_old_messages`.
    pub compressed_bytes: usize,
    /// `1 - compressed/uncompressed` byte ratio. Once batches have been merged,
    /// this is an aggregate of per-batch ratios, not one recomputed over the
    /// merged payload.
    pub compression_ratio: f64,
}
623
624mod base64_serde {
626 use base64::{Engine, engine::general_purpose::STANDARD};
627 use serde::{Deserialize, Deserializer, Serialize, Serializer};
628
629 pub fn serialize<S>(data: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
630 where
631 S: Serializer,
632 {
633 let encoded = STANDARD.encode(data);
634 encoded.serialize(serializer)
635 }
636
637 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
638 where
639 D: Deserializer<'de>,
640 {
641 let encoded = String::deserialize(deserializer)?;
642 STANDARD.decode(&encoded).map_err(serde::de::Error::custom)
643 }
644}
645
/// Counters reported by `TokenEfficientHistory::get_compression_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Lifetime count of appended messages.
    pub total_messages_added: usize,
    /// Messages currently held uncompressed.
    pub active_messages: usize,
    /// Messages in the archive (0 when nothing is archived).
    pub compressed_messages: usize,
    /// The history's `current_tokens` tally.
    pub active_tokens: usize,
    /// Lifetime token count moved into the archive.
    pub tokens_saved: usize,
    /// Archive size in bytes (0 when nothing is archived).
    pub compressed_bytes: usize,
    /// Archive compression ratio (0.0 when nothing is archived).
    pub compression_ratio: f64,
}
664
/// Description of the task a session is working on; every field may be empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskContext {
    /// Task identifier. NOTE(review): format not established in this file.
    pub id: Option<String>,
    /// Human-readable name; surfaced as `ContextSummary::current_task`.
    pub name: Option<String>,
    /// Free-form description.
    pub description: Option<String>,
    /// Free-form category. NOTE(review): no fixed vocabulary is enforced here.
    pub task_type: Option<String>,
    /// Optional priority.
    pub priority: Option<TaskPriority>,
    /// Start time (UTC), if known.
    pub started_at: Option<DateTime<Utc>>,
    /// Arbitrary extra data.
    pub metadata: HashMap<String, serde_json::Value>,
}
683
/// Task priority levels, declared `Low` through `Critical`.
///
/// No ordering traits are derived, so compare by matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskPriority {
    Low,
    Medium,
    High,
    Critical,
}
692
/// Snapshot of the agent's status.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentState {
    /// Free-form state label; surfaced as `ContextSummary::agent_state`.
    pub state: String,
    /// Capability names. NOTE(review): vocabulary not defined in this file.
    pub capabilities: Vec<String>,
    /// Named numeric metrics. NOTE(review): units not established here.
    pub metrics: HashMap<String, f64>,
    /// Most recent error message, if any.
    pub last_error: Option<String>,
}
705
/// Snapshot of the session's workspace.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkspaceState {
    /// Working directory as a string (empty by default).
    pub working_directory: String,
    /// Tracked files; `summarize` reports their count.
    /// NOTE(review): keys are presumably file paths — verify against callers.
    pub tracked_files: HashMap<String, FileState>,
    /// Recorded file changes.
    pub recent_changes: Vec<FileChange>,
}
716
/// Tracked state of one workspace file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileState {
    /// File path as a string.
    pub path: String,
    /// Last modification time (UTC).
    pub last_modified: DateTime<Utc>,
    /// Content hash. NOTE(review): algorithm/encoding not specified in this file.
    pub hash: String,
    /// Whether the file is flagged as modified.
    pub is_modified: bool,
}
729
/// One recorded change to a workspace file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileChange {
    /// Affected path as a string.
    pub path: String,
    /// Kind of change.
    pub change_type: FileChangeType,
    /// When the change was recorded (UTC).
    pub timestamp: DateTime<Utc>,
}
740
/// Kind of change recorded in a [`FileChange`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FileChangeType {
    Created,
    Modified,
    Deleted,
    Renamed,
}
749
/// Lightweight session snapshot produced by `SessionContext::summarize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSummary {
    /// The session's identifier.
    pub session_id: SessionId,
    /// Active (uncompressed) message count.
    pub message_count: usize,
    /// The current task's `name`, if any.
    pub current_task: Option<String>,
    /// The agent's `state` label.
    pub agent_state: String,
    /// Number of tracked workspace files.
    pub workspace_files: usize,
}
764
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user message with an explicit token count, bypassing estimation.
    fn user_message(content: String, token_count: usize) -> Message {
        Message {
            role: MessageRole::User,
            content,
            timestamp: Utc::now(),
            token_count,
        }
    }

    #[test]
    fn test_token_efficient_history() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 200;
        history.keep_recent = 3;

        for i in 0..10 {
            history.add_message(MessageRole::User, format!("Test message number {}", i));
        }

        assert_eq!(history.total_messages_added, 10);
        assert!(history.messages.len() <= 10);
        assert!(
            history.compressed_history.is_some() || history.current_tokens <= history.max_tokens
        );
        assert!(
            history.current_tokens <= history.max_tokens + 150,
            "current_tokens {} exceeded max_tokens {} + 150",
            history.current_tokens,
            history.max_tokens
        );
    }

    #[test]
    fn test_zstd_compression() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 50;
        history.keep_recent = 2;

        for i in 0..20 {
            history.add_message(MessageRole::User, format!("Test message number {}", i));
        }

        assert!(history.compressed_history.is_some());

        let archived = history
            .decompress_history()
            .expect("archive written by compress_old_messages must decode");
        assert!(!archived.is_empty());
        // Nothing is lost: archived + active messages account for every append.
        assert_eq!(archived.len() + history.messages.len(), 20);
        // Merging batches keeps chronological order.
        assert_eq!(archived[0].content, "Test message number 0");
    }

    #[test]
    fn test_compression_stats() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 30;
        history.keep_recent = 2;

        for i in 0..10 {
            history.add_message(MessageRole::User, format!("Message {}", i));
        }

        let stats = history.get_compression_stats();
        assert_eq!(stats.total_messages_added, 10);
        // 10 messages of ~4 tokens exceed the 30-token budget, so something was archived.
        assert!(stats.compressed_messages > 0);
        assert_eq!(stats.compressed_messages + stats.active_messages, 10);
        assert!(stats.compressed_bytes > 0);
    }

    #[test]
    fn test_context_summary() {
        let session_id = SessionId::new();
        let mut context = SessionContext::new(session_id.clone());

        context.add_message_raw(MessageRole::User, "Hello".to_string());
        context.add_message_raw(MessageRole::Assistant, "Hi there!".to_string());

        let summary = context.summarize();
        assert_eq!(summary.session_id, session_id);
        assert_eq!(summary.message_count, 2);
    }

    #[test]
    fn test_new_api_methods() {
        let session_id = SessionId::new();
        let mut context = SessionContext::new(session_id.clone());

        assert_eq!(context.get_message_count(), 0);

        context.add_message(user_message("Test message".to_string(), 3));

        assert_eq!(context.get_message_count(), 1);
        assert_eq!(context.get_total_tokens(), 3);

        let recent = context.get_recent_messages(1);
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].content, "Test message");
        // Asking for more than exist returns everything.
        assert_eq!(context.get_recent_messages(5).len(), 1);

        assert_eq!(context.config.max_tokens, 100_000);
    }

    #[tokio::test]
    async fn test_compress_context() {
        let mut context = SessionContext::new(SessionId::new());
        context.config.compression_threshold = 0.5;
        context.conversation_history.max_tokens = 100;
        context.conversation_history.keep_recent = 3;

        for i in 0..6 {
            context.add_message(user_message(format!("Message {}", i), 10));
        }

        // 60 tokens: under the hard budget (100), so nothing was auto-compressed...
        assert_eq!(context.get_compression_stats().compressed_messages, 0);

        // ...but over the soft threshold (0.5 * 100), so an explicit request compresses.
        assert!(context.compress_context().await);
        let stats = context.get_compression_stats();
        assert_eq!(stats.compressed_messages, 3);
        assert_eq!(stats.active_messages, 3);
        assert_eq!(stats.total_messages_added, 6);
    }

    #[tokio::test]
    async fn test_compress_context_below_threshold() {
        let mut context = SessionContext::new(SessionId::new());
        context.add_message_raw(MessageRole::User, "Hello".to_string());

        assert!(!context.compress_context().await);
        assert_eq!(context.get_compression_stats().compressed_messages, 0);
    }

    #[test]
    fn test_estimate_tokens() {
        let tokens = estimate_tokens("Hello world");
        assert!(tokens >= 2);

        assert_eq!(estimate_tokens(""), 1);
        assert_eq!(estimate_tokens("   "), 2);

        let tokens = estimate_tokens("Hello, world! How are you?");
        assert!(tokens >= 5);

        let tokens = estimate_tokens(&"word ".repeat(100));
        assert!(tokens >= 100);
    }

    #[test]
    fn test_get_all_messages() {
        let mut history = TokenEfficientHistory::new();
        // Small enough that older messages are actually archived.
        history.max_tokens = 10;
        history.keep_recent = 2;

        for i in 0..5 {
            history.add_message(MessageRole::User, format!("Message {}", i));
        }

        assert!(history.compressed_history.is_some());

        // The full conversation is still recoverable, in order.
        let all = history.get_all_messages();
        let contents: Vec<&str> = all.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(
            contents,
            vec!["Message 0", "Message 1", "Message 2", "Message 3", "Message 4"]
        );
    }
}