1use chrono::{DateTime, Utc};
25
26use katu_core::agent::AgentDefinition;
27use katu_core::compaction::{CompactionConfig, TokenBudgetState};
28use katu_core::message::{AssistantMessage, ContentBlock, Message, ToolResultMessage, UserContent};
29use katu_core::types::{ModelId, Role, SessionId};
30use katu_core::usage::Usage;
31use katu_core::CancellationToken;
32
33use crate::compaction::CompactionState;
34
/// Lifecycle state of a [`Session`].
///
/// A session starts out `Idle`, becomes `Running` when a run is started with
/// `Session::begin_run`, and becomes `Cancelled` if `Session::cancel` is called
/// while it is running.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// No run is in progress.
    Idle,
    /// A run is in progress.
    Running,
    /// The in-progress run was cancelled and the session has not been reset yet.
    Cancelled,
}

impl SessionStatus {
    /// Returns `true` when the status is [`SessionStatus::Running`].
    pub fn is_running(&self) -> bool {
        *self == SessionStatus::Running
    }

    /// Returns `true` when the status is [`SessionStatus::Cancelled`].
    pub fn is_cancelled(&self) -> bool {
        *self == SessionStatus::Cancelled
    }

    /// Returns `true` when the status is [`SessionStatus::Idle`].
    pub fn is_idle(&self) -> bool {
        *self == SessionStatus::Idle
    }
}
75
76pub struct Session {
100 id: SessionId,
102
103 agent: AgentDefinition,
105
106 model_id: ModelId,
108
109 messages: Vec<Message>,
111
112 status: SessionStatus,
114
115 cancel_token: CancellationToken,
117
118 step_count: u32,
120
121 total_usage: Usage,
123
124 context_tokens: u64,
126
127 context_window: u64,
129
130 compaction_config: CompactionConfig,
132
133 compaction_state: CompactionState,
135
136 created_at: DateTime<Utc>,
138
139 updated_at: DateTime<Utc>,
141}
142
143impl Session {
144 pub fn new(agent: AgentDefinition, model_id: ModelId) -> Self {
150 let now = Utc::now();
151 Self {
152 id: SessionId::new(),
153 agent,
154 model_id,
155 messages: Vec::new(),
156 status: SessionStatus::Idle,
157 cancel_token: CancellationToken::new(),
158 step_count: 0,
159 total_usage: Usage::default(),
160 context_tokens: 0,
161 context_window: 0,
162 compaction_config: CompactionConfig::default(),
163 compaction_state: CompactionState::new(),
164 created_at: now,
165 updated_at: now,
166 }
167 }
168
169 pub fn with_context_window(mut self, context_window: u64) -> Self {
171 self.context_window = context_window;
172 self
173 }
174
175 pub fn with_compaction_config(mut self, config: CompactionConfig) -> Self {
177 self.compaction_config = config;
178 self
179 }
180
181 pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
183 self.cancel_token = token;
184 self
185 }
186}
187
188impl Session {
193 pub fn id(&self) -> &SessionId {
195 &self.id
196 }
197
198 pub fn agent(&self) -> &AgentDefinition {
200 &self.agent
201 }
202
203 pub fn model_id(&self) -> &ModelId {
205 &self.model_id
206 }
207
208 pub fn set_model_id(&mut self, model_id: ModelId) {
210 self.model_id = model_id;
211 self.touch();
212 }
213}
214
215impl Session {
220 pub fn messages(&self) -> impl Iterator<Item = &Message> {
222 self.messages.iter()
223 }
224
225 pub fn message_count(&self) -> usize {
227 self.messages.len()
228 }
229
230 pub fn last_role(&self) -> Option<Role> {
232 self.messages.last().map(|m| m.role())
233 }
234
235 pub fn last_assistant(&self) -> Option<&AssistantMessage> {
237 self.messages.iter().rev().find_map(|m| match m {
238 Message::Assistant(a) => Some(a),
239 _ => None,
240 })
241 }
242
243 pub fn has_pending_tool_calls(&self) -> bool {
245 self.last_assistant()
246 .map(|a| a.has_tool_calls())
247 .unwrap_or(false)
248 }
249
250 pub fn push_user(&mut self, content: impl Into<UserContent>) {
252 self.messages.push(Message::user(content));
253 self.touch();
254 }
255
256 pub fn push_assistant(&mut self, message: AssistantMessage) {
258 if let Some(usage) = &message.usage {
259 self.accumulate_usage(usage);
260 }
261 self.messages.push(Message::Assistant(message));
262 self.touch();
263 }
264
265 pub fn push_tool_results(&mut self, results: Vec<ToolResultMessage>) {
267 for result in results {
268 self.messages.push(Message::ToolResult(result));
269 }
270 self.touch();
271 }
272
273 pub fn push_message(&mut self, message: Message) {
275 if let Message::Assistant(ref a) = message {
276 if let Some(usage) = &a.usage {
277 self.accumulate_usage(usage);
278 }
279 }
280 self.messages.push(message);
281 self.touch();
282 }
283
284 pub fn replace_messages(&mut self, messages: Vec<Message>) {
286 self.messages = messages;
287 self.touch();
288 }
289
290 pub fn message_slice(&self) -> &[Message] {
292 &self.messages
293 }
294
295 pub fn truncate_tool_result(&mut self, index: usize, truncation_msg: &str, max_chars: usize) {
300 if let Some(Message::ToolResult(tr)) = self.messages.get_mut(index) {
301 let total_chars: usize = tr.content.iter().map(|b| match b {
302 ContentBlock::Text { text } => text.len(),
303 ContentBlock::Image { .. } => 0,
304 }).sum();
305
306 if total_chars > max_chars {
307 tr.content = vec![ContentBlock::Text {
309 text: truncation_msg.to_string(),
310 }];
311 self.touch();
312 }
313 }
314 }
315}
316
317impl Session {
322 pub fn status(&self) -> SessionStatus {
324 self.status
325 }
326
327 pub fn cancel_token(&self) -> &CancellationToken {
329 &self.cancel_token
330 }
331
332 pub fn begin_run(&mut self) {
337 assert!(
338 self.status.is_idle(),
339 "cannot begin run: session is {:?}",
340 self.status
341 );
342 self.status = SessionStatus::Running;
343 self.cancel_token = CancellationToken::new();
345 self.step_count = 0;
347 self.touch();
348 }
349
350 pub fn end_run(&mut self) {
352 self.status = SessionStatus::Idle;
353 self.touch();
354 }
355
356 pub fn cancel(&mut self) {
358 if self.status.is_running() {
359 self.cancel_token.cancel();
360 self.status = SessionStatus::Cancelled;
361 self.touch();
362 }
363 }
364
365 pub fn reset_after_cancel(&mut self) {
367 if self.status.is_cancelled() {
368 self.status = SessionStatus::Idle;
369 self.cancel_token = CancellationToken::new();
370 self.touch();
371 }
372 }
373}
374
375impl Session {
380 pub fn step_count(&self) -> u32 {
382 self.step_count
383 }
384
385 pub fn increment_step(&mut self) -> u32 {
387 self.step_count += 1;
388 self.step_count
389 }
390
391 pub fn max_steps(&self) -> u32 {
393 self.agent.max_steps.unwrap_or(DEFAULT_MAX_STEPS)
394 }
395
396 pub fn is_over_step_limit(&self) -> bool {
398 self.step_count >= self.max_steps()
399 }
400}
401
/// Step budget used by `Session::max_steps` when the agent definition does
/// not specify `max_steps`.
const DEFAULT_MAX_STEPS: u32 = 100;
404
405impl Session {
410 pub fn total_usage(&self) -> &Usage {
412 &self.total_usage
413 }
414
415 fn accumulate_usage(&mut self, usage: &Usage) {
417 self.total_usage.input_tokens += usage.input_tokens;
418 self.total_usage.output_tokens += usage.output_tokens;
419 self.total_usage.cache_read_tokens += usage.cache_read_tokens;
420 self.total_usage.cache_write_tokens += usage.cache_write_tokens;
421 self.total_usage.total_tokens += usage.total_tokens;
422 if let Some(r) = usage.reasoning_tokens {
423 *self.total_usage.reasoning_tokens.get_or_insert(0) += r;
424 }
425 if let Some(cost) = &usage.cost {
426 let total_cost = self.total_usage.cost.get_or_insert(katu_core::usage::Cost {
427 input: 0.0,
428 output: 0.0,
429 cache_read: 0.0,
430 cache_write: 0.0,
431 total: 0.0,
432 });
433 total_cost.input += cost.input;
434 total_cost.output += cost.output;
435 total_cost.cache_read += cost.cache_read;
436 total_cost.cache_write += cost.cache_write;
437 total_cost.total += cost.total;
438 }
439 }
440}
441
442impl Session {
447 pub fn context_tokens(&self) -> u64 {
449 self.context_tokens
450 }
451
452 pub fn context_window(&self) -> u64 {
454 self.context_window
455 }
456
457 pub fn compaction_config(&self) -> &CompactionConfig {
459 &self.compaction_config
460 }
461
462 pub fn set_context_tokens(&mut self, tokens: u64) {
464 self.context_tokens = tokens;
465 }
466
467 pub fn budget_state(&self) -> TokenBudgetState {
469 TokenBudgetState::from_usage(
470 self.context_tokens,
471 self.context_window,
472 self.compaction_config.reserve_tokens as u64,
473 )
474 }
475
476 pub fn should_compact(&self) -> bool {
478 self.compaction_config.auto_enabled && self.budget_state().should_auto_compact()
479 }
480
481 pub fn compaction_state(&self) -> &CompactionState {
483 &self.compaction_state
484 }
485
486 pub fn compaction_state_mut(&mut self) -> &mut CompactionState {
488 &mut self.compaction_state
489 }
490}
491
492impl Session {
497 pub fn created_at(&self) -> DateTime<Utc> {
499 self.created_at
500 }
501
502 pub fn updated_at(&self) -> DateTime<Utc> {
504 self.updated_at
505 }
506
507 fn touch(&mut self) {
509 self.updated_at = Utc::now();
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use katu_core::agent::AgentRole;
    use katu_core::message::AssistantBlock;
    use katu_core::types::{FinishReason, MessageId, ToolCallId};
    use katu_core::usage::Cost;

    fn test_agent() -> AgentDefinition {
        AgentDefinition::new("test", AgentRole::Primary).with_max_steps(10)
    }

    fn test_session() -> Session {
        Session::new(test_agent(), ModelId::new("gpt-4o")).with_context_window(200_000)
    }

    fn make_assistant(text: &str, usage: Option<Usage>) -> AssistantMessage {
        AssistantMessage {
            id: MessageId::new(),
            content: vec![AssistantBlock::Text { text: text.into() }],
            model: "gpt-4o".into(),
            provider: "openai".into(),
            finish_reason: FinishReason::Stop,
            usage,
            timestamp: Utc::now(),
        }
    }

    fn make_assistant_with_tool_call() -> AssistantMessage {
        AssistantMessage {
            id: MessageId::new(),
            content: vec![
                AssistantBlock::Text {
                    text: "reading file".into(),
                },
                AssistantBlock::ToolCall {
                    id: ToolCallId::new("call_1"),
                    name: "read_file".into(),
                    arguments: serde_json::json!({"path": "src/main.rs"}),
                },
            ],
            model: "gpt-4o".into(),
            provider: "openai".into(),
            finish_reason: FinishReason::ToolCalls,
            usage: Some(Usage {
                input_tokens: 100,
                output_tokens: 50,
                total_tokens: 150,
                ..Default::default()
            }),
            timestamp: Utc::now(),
        }
    }

    /// Builds a successful, text-only tool result.
    fn make_tool_result(call_id: &str, text: &str) -> ToolResultMessage {
        ToolResultMessage {
            id: MessageId::new(),
            tool_call_id: ToolCallId::new(call_id),
            tool_name: "read_file".into(),
            content: vec![ContentBlock::Text { text: text.into() }],
            is_error: false,
            timestamp: Utc::now(),
        }
    }

    /// Text blocks of the tool result at `index`; panics if it is not a tool result.
    fn tool_result_texts(session: &Session, index: usize) -> Vec<String> {
        match &session.message_slice()[index] {
            Message::ToolResult(tr) => tr
                .content
                .iter()
                .filter_map(|b| match b {
                    ContentBlock::Text { text } => Some(text.clone()),
                    ContentBlock::Image { .. } => None,
                })
                .collect(),
            _ => panic!("message {} is not a tool result", index),
        }
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn test_new_session() {
        let session = test_session();
        assert!(session.status().is_idle());
        assert_eq!(session.step_count(), 0);
        assert_eq!(session.message_count(), 0);
        assert!(session.messages().next().is_none());
        assert_eq!(session.model_id().as_str(), "gpt-4o");
        assert_eq!(session.agent().name.as_str(), "test");
        assert_eq!(session.context_window(), 200_000);
    }

    #[test]
    fn test_session_id_unique() {
        let s1 = test_session();
        let s2 = test_session();
        assert_ne!(s1.id(), s2.id());
    }

    // --- messages ---------------------------------------------------------

    #[test]
    fn test_push_user() {
        let mut session = test_session();
        session.push_user("hello");
        assert_eq!(session.message_count(), 1);
        assert_eq!(session.last_role(), Some(Role::User));
    }

    #[test]
    fn test_push_assistant() {
        let mut session = test_session();
        session.push_user("hi");
        session.push_assistant(make_assistant("hello!", None));
        assert_eq!(session.message_count(), 2);
        assert_eq!(session.last_role(), Some(Role::Assistant));
    }

    #[test]
    fn test_push_tool_results() {
        let mut session = test_session();
        session.push_user("read file");
        session.push_assistant(make_assistant_with_tool_call());
        session.push_tool_results(vec![make_tool_result("call_1", "file contents")]);
        assert_eq!(session.message_count(), 3);
        assert_eq!(session.last_role(), Some(Role::Tool));
    }

    #[test]
    fn test_push_tool_results_preserves_order() {
        let mut session = test_session();
        session.push_tool_results(vec![
            make_tool_result("call_1", "first"),
            make_tool_result("call_2", "second"),
        ]);
        assert_eq!(session.message_count(), 2);
        assert_eq!(tool_result_texts(&session, 0), vec!["first".to_string()]);
        assert_eq!(tool_result_texts(&session, 1), vec!["second".to_string()]);
    }

    #[test]
    fn test_has_pending_tool_calls() {
        let mut session = test_session();
        assert!(!session.has_pending_tool_calls());

        session.push_assistant(make_assistant_with_tool_call());
        assert!(session.has_pending_tool_calls());

        session.push_assistant(make_assistant("done", None));
        assert!(!session.has_pending_tool_calls());
    }

    #[test]
    fn test_last_assistant() {
        let mut session = test_session();
        assert!(session.last_assistant().is_none());

        session.push_assistant(make_assistant("first", None));
        session.push_user("next");
        session.push_assistant(make_assistant("second", None));

        let last = session.last_assistant().unwrap();
        assert_eq!(last.text(), "second");
    }

    #[test]
    fn test_replace_messages() {
        let mut session = test_session();
        session.push_user("one");
        session.push_user("two");
        assert_eq!(session.message_count(), 2);

        session.replace_messages(vec![Message::user("compacted")]);
        assert_eq!(session.message_count(), 1);
    }

    #[test]
    fn test_message_slice() {
        let mut session = test_session();
        session.push_user("a");
        session.push_user("b");
        assert_eq!(session.message_slice().len(), 2);
    }

    // --- tool-result truncation -------------------------------------------

    #[test]
    fn test_truncate_tool_result_over_limit() {
        let mut session = test_session();
        session.push_tool_results(vec![make_tool_result("call_1", &"x".repeat(100))]);
        session.truncate_tool_result(0, "[truncated]", 10);
        assert_eq!(tool_result_texts(&session, 0), vec!["[truncated]".to_string()]);
    }

    #[test]
    fn test_truncate_tool_result_at_limit_is_noop() {
        let mut session = test_session();
        session.push_tool_results(vec![make_tool_result("call_1", "short")]);
        // Exactly `max_chars` characters is not "over" the limit.
        session.truncate_tool_result(0, "[truncated]", 5);
        assert_eq!(tool_result_texts(&session, 0), vec!["short".to_string()]);
    }

    #[test]
    fn test_truncate_tool_result_ignores_other_messages() {
        let mut session = test_session();
        session.push_user("hello");
        session.truncate_tool_result(0, "[truncated]", 0);
        session.truncate_tool_result(42, "[truncated]", 0);
        assert_eq!(session.message_count(), 1);
        assert_eq!(session.last_role(), Some(Role::User));
    }

    // --- run lifecycle ----------------------------------------------------

    #[test]
    fn test_status_lifecycle() {
        let mut session = test_session();

        assert!(session.status().is_idle());
        session.begin_run();
        assert!(session.status().is_running());
        session.end_run();
        assert!(session.status().is_idle());
    }

    #[test]
    fn test_cancel() {
        let mut session = test_session();
        session.begin_run();

        let token = session.cancel_token().clone();
        assert!(!token.is_cancelled());

        session.cancel();
        assert!(session.status().is_cancelled());
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancel_idle_is_noop() {
        let mut session = test_session();
        session.cancel();
        assert!(session.status().is_idle());
    }

    #[test]
    fn test_reset_after_cancel() {
        let mut session = test_session();
        session.begin_run();
        session.cancel();
        assert!(session.status().is_cancelled());

        session.reset_after_cancel();
        assert!(session.status().is_idle());
        assert!(!session.cancel_token().is_cancelled());
    }

    #[test]
    fn test_reset_after_cancel_noop_when_running() {
        let mut session = test_session();
        session.begin_run();
        let token = session.cancel_token().clone();

        session.reset_after_cancel();
        assert!(session.status().is_running());

        // The original token must still be installed.
        session.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_end_run_after_cancel_allows_new_run() {
        let mut session = test_session();
        session.begin_run();
        session.cancel();
        session.end_run();
        assert!(session.status().is_idle());

        session.begin_run();
        assert!(session.status().is_running());
        assert!(!session.cancel_token().is_cancelled());
    }

    #[test]
    fn test_begin_run_resets_cancel_token() {
        let mut session = test_session();
        session.begin_run();
        session.cancel();
        session.reset_after_cancel();

        session.begin_run();
        assert!(!session.cancel_token().is_cancelled());
    }

    #[test]
    #[should_panic(expected = "cannot begin run")]
    fn test_begin_run_while_running_panics() {
        let mut session = test_session();
        session.begin_run();
        session.begin_run();
    }

    #[test]
    #[should_panic(expected = "cannot begin run")]
    fn test_begin_run_while_cancelled_panics() {
        let mut session = test_session();
        session.begin_run();
        session.cancel();
        session.begin_run();
    }

    // --- steps ------------------------------------------------------------

    #[test]
    fn test_step_tracking() {
        let mut session = test_session();
        assert_eq!(session.step_count(), 0);
        assert_eq!(session.max_steps(), 10);

        for i in 1..=10 {
            assert!(!session.is_over_step_limit());
            assert_eq!(session.increment_step(), i);
        }
        assert!(session.is_over_step_limit());
    }

    #[test]
    fn test_step_count_resets_on_begin_run() {
        let mut session = test_session();
        session.begin_run();

        session.increment_step();
        session.increment_step();
        session.increment_step();
        assert_eq!(session.step_count(), 3);
        session.end_run();

        session.begin_run();
        assert_eq!(session.step_count(), 0);
        assert!(!session.is_over_step_limit());
    }

    #[test]
    fn test_default_max_steps() {
        let agent = AgentDefinition::new("no_limit", AgentRole::Primary);
        let session = Session::new(agent, ModelId::new("gpt-4o"));
        assert_eq!(session.max_steps(), DEFAULT_MAX_STEPS);
    }

    // --- usage ------------------------------------------------------------

    #[test]
    fn test_usage_accumulation() {
        let mut session = test_session();

        let usage1 = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
            reasoning_tokens: Some(5),
            total_tokens: 150,
            cost: None,
        };
        session.push_assistant(make_assistant("a", Some(usage1)));

        let usage2 = Usage {
            input_tokens: 200,
            output_tokens: 80,
            cache_read_tokens: 30,
            cache_write_tokens: 0,
            reasoning_tokens: Some(10),
            total_tokens: 280,
            cost: None,
        };
        session.push_assistant(make_assistant("b", Some(usage2)));

        let total = session.total_usage();
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 130);
        assert_eq!(total.cache_read_tokens, 50);
        assert_eq!(total.cache_write_tokens, 10);
        assert_eq!(total.reasoning_tokens, Some(15));
        assert_eq!(total.total_tokens, 430);
    }

    #[test]
    fn test_usage_accumulation_with_cost() {
        let mut session = test_session();

        let make_usage = || Usage {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
            cost: Some(Cost {
                input: 0.01,
                output: 0.03,
                cache_read: 0.001,
                cache_write: 0.002,
                total: 0.043,
            }),
            ..Default::default()
        };

        session.push_assistant(make_assistant("a", Some(make_usage())));
        let total_cost = session.total_usage().cost.as_ref().unwrap();
        assert!((total_cost.total - 0.043).abs() < 1e-12);

        // A second priced message must add to, not overwrite, the total.
        session.push_assistant(make_assistant("b", Some(make_usage())));
        let total_cost = session.total_usage().cost.as_ref().unwrap();
        assert!((total_cost.total - 0.086).abs() < 1e-12);
        assert!((total_cost.input - 0.02).abs() < 1e-12);
    }

    #[test]
    fn test_push_message_accumulates_assistant_usage() {
        let mut session = test_session();
        let msg = Message::Assistant(make_assistant(
            "x",
            Some(Usage {
                input_tokens: 10,
                output_tokens: 5,
                total_tokens: 15,
                ..Default::default()
            }),
        ));
        session.push_message(msg);
        assert_eq!(session.total_usage().input_tokens, 10);
    }

    #[test]
    fn test_push_message_no_usage_for_user() {
        let mut session = test_session();
        session.push_message(Message::user("hello"));
        assert_eq!(session.total_usage().input_tokens, 0);
    }

    // --- context & compaction ---------------------------------------------

    #[test]
    fn test_context_window() {
        let mut session = test_session();
        session.set_context_tokens(150_000);
        assert_eq!(session.context_tokens(), 150_000);
    }

    #[test]
    fn test_budget_state_normal() {
        let mut session = test_session();
        session.set_context_tokens(50_000);
        let state = session.budget_state();
        assert!(matches!(state, TokenBudgetState::Normal { .. }));
    }

    #[test]
    fn test_should_compact() {
        let mut session = Session::new(test_agent(), ModelId::new("gpt-4o"))
            .with_context_window(200_000)
            .with_compaction_config(CompactionConfig::default());

        session.set_context_tokens(50_000);
        assert!(!session.should_compact());

        session.set_context_tokens(195_000);
        assert!(session.should_compact());
    }

    #[test]
    fn test_should_compact_disabled() {
        let config = CompactionConfig::default().with_auto_enabled(false);
        let mut session = Session::new(test_agent(), ModelId::new("gpt-4o"))
            .with_context_window(200_000)
            .with_compaction_config(config);

        session.set_context_tokens(195_000);
        assert!(!session.should_compact());
    }

    // --- model & timestamps -----------------------------------------------

    #[test]
    fn test_set_model_id() {
        let mut session = test_session();
        assert_eq!(session.model_id().as_str(), "gpt-4o");
        session.set_model_id(ModelId::new("claude-sonnet-4-20250514"));
        assert_eq!(session.model_id().as_str(), "claude-sonnet-4-20250514");
    }

    #[test]
    fn test_timestamps() {
        let before = Utc::now();
        let session = test_session();
        let after = Utc::now();

        assert!(session.created_at() >= before);
        assert!(session.created_at() <= after);
        // Both stamps come from the same instant at construction.
        assert_eq!(session.created_at(), session.updated_at());
    }

    #[test]
    fn test_touch_updates_timestamp() {
        let mut session = test_session();
        let initial = session.updated_at();

        std::thread::sleep(std::time::Duration::from_millis(2));
        session.push_user("trigger touch");

        assert!(session.updated_at() > initial);
    }
}