agtrace_runtime/domain/
model.rs1use agtrace_engine::ContextWindowUsage;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use std::path::PathBuf;
5
6#[derive(Debug, Clone)]
7pub struct SessionState {
8 pub session_id: String,
9 pub project_root: Option<PathBuf>,
10 pub log_path: Option<PathBuf>,
11 pub start_time: DateTime<Utc>,
12 pub last_activity: DateTime<Utc>,
13 pub model: Option<String>,
14 pub context_window_limit: Option<u64>,
15 pub current_usage: ContextWindowUsage,
16 pub current_reasoning_tokens: i32,
17 pub error_count: u32,
18 pub event_count: usize,
19 pub turn_count: usize,
20}
21
22impl SessionState {
23 pub fn new(
24 session_id: String,
25 project_root: Option<PathBuf>,
26 log_path: Option<PathBuf>,
27 start_time: DateTime<Utc>,
28 ) -> Self {
29 Self {
30 session_id,
31 project_root,
32 log_path,
33 start_time,
34 last_activity: start_time,
35 model: None,
36 context_window_limit: None,
37 current_usage: ContextWindowUsage::default(),
38 current_reasoning_tokens: 0,
39 error_count: 0,
40 event_count: 0,
41 turn_count: 0,
42 }
43 }
44
45 pub fn total_input_side_tokens(&self) -> i32 {
46 self.current_usage.input_tokens()
47 }
48
49 pub fn total_output_side_tokens(&self) -> i32 {
50 self.current_usage.output_tokens()
51 }
52
53 pub fn total_context_window_tokens(&self) -> i32 {
54 self.current_usage.context_window_tokens()
55 }
56
57 pub fn total_tokens(&self) -> agtrace_engine::TokenCount {
59 self.current_usage.total_tokens()
60 }
61
62 pub fn context_limit(&self) -> Option<agtrace_engine::ContextLimit> {
64 self.context_window_limit
65 .map(agtrace_engine::ContextLimit::new)
66 }
67
68 pub fn validate_tokens(&self, model_limit: Option<u64>) -> Result<(), String> {
69 let total = self.total_context_window_tokens();
70
71 if total < 0
72 || self.current_usage.fresh_input.0 < 0
73 || self.current_usage.output.0 < 0
74 || self.current_usage.cache_creation.0 < 0
75 || self.current_usage.cache_read.0 < 0
76 {
77 return Err("Negative token count detected".to_string());
78 }
79
80 if let Some(limit) = model_limit
81 && total as u64 > limit
82 {
83 return Err(format!(
84 "Token count {} exceeds model limit {}",
85 total, limit
86 ));
87 }
88
89 Ok(())
90 }
91}
92
/// Unit tests for `SessionState` construction, token accessors, and
/// `validate_tokens`, including limit boundaries and per-component
/// negative detection.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_state_initialization() {
        let start = Utc::now();
        let state = SessionState::new("test-id".to_string(), None, None, start);

        assert_eq!(state.session_id, "test-id");
        assert!(state.current_usage.is_empty());
        assert_eq!(state.current_reasoning_tokens, 0);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.event_count, 0);
        assert_eq!(state.turn_count, 0);

        // Timing and optional metadata start out in their documented defaults.
        assert_eq!(state.start_time, start);
        assert_eq!(state.last_activity, start);
        assert!(state.model.is_none());
        assert!(state.context_window_limit.is_none());
        assert!(state.context_limit().is_none());
    }

    #[test]
    fn test_session_state_token_snapshot() {
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());

        state.current_usage = ContextWindowUsage::from_raw(100, 0, 0, 50);
        assert_eq!(state.total_input_side_tokens(), 100);
        assert_eq!(state.total_output_side_tokens(), 50);
        assert_eq!(state.total_context_window_tokens(), 150);

        state.current_usage = ContextWindowUsage::from_raw(10, 0, 1000, 60);
        assert_eq!(state.total_input_side_tokens(), 1010);
        assert_eq!(state.total_output_side_tokens(), 60);
        assert_eq!(state.total_context_window_tokens(), 1070);
    }

    #[test]
    fn test_context_limit_present_when_configured() {
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.context_window_limit = Some(200_000);
        assert!(state.context_limit().is_some());
    }

    #[test]
    fn test_validate_tokens_success() {
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_exactly_at_limit_is_ok() {
        // 150_000 input + 50_000 output == 200_000: the limit itself is allowed.
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.current_usage = ContextWindowUsage::from_raw(150_000, 0, 0, 50_000);
        assert_eq!(state.total_context_window_tokens(), 200_000);
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_without_limit_skips_limit_check() {
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.current_usage = ContextWindowUsage::from_raw(1_000_000, 0, 0, 0);
        assert!(state.validate_tokens(None).is_ok());
    }

    #[test]
    fn test_validate_tokens_exceeds_limit() {
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 150_000);
        let result = state.validate_tokens(Some(200_000));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("exceeds model limit"));
    }

    #[test]
    fn test_validate_tokens_negative() {
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.current_usage = ContextWindowUsage::from_raw(-100, 0, 0, 0);

        let result = state.validate_tokens(None);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Negative token count detected");
    }

    #[test]
    fn test_validate_tokens_negative_component_with_positive_total() {
        // Total is 90 (non-negative), but the output component alone is negative.
        let mut state = SessionState::new("test-id".to_string(), None, None, Utc::now());
        state.current_usage = ContextWindowUsage::from_raw(100, 0, 0, -10);

        let result = state.validate_tokens(Some(200_000));
        assert_eq!(result.unwrap_err(), "Negative token count detected");
    }
}