agtrace_runtime/domain/
model.rs1use agtrace_engine::ContextWindowUsage;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use std::path::PathBuf;
5
6#[derive(Debug, Clone)]
7pub struct SessionState {
8 pub session_id: String,
9 pub project_root: Option<PathBuf>,
10 pub log_path: Option<PathBuf>,
11 pub start_time: DateTime<Utc>,
12 pub last_activity: DateTime<Utc>,
13 pub model: Option<String>,
14 pub context_window_limit: Option<u64>,
15 pub current_usage: ContextWindowUsage,
16 pub current_reasoning_tokens: i32,
17 pub error_count: u32,
18 pub event_count: usize,
19 pub turn_count: usize,
20}
21
22impl SessionState {
23 pub fn new(
24 session_id: String,
25 project_root: Option<PathBuf>,
26 log_path: Option<PathBuf>,
27 start_time: DateTime<Utc>,
28 ) -> Self {
29 Self {
30 session_id,
31 project_root,
32 log_path,
33 start_time,
34 last_activity: start_time,
35 model: None,
36 context_window_limit: None,
37 current_usage: ContextWindowUsage::default(),
38 current_reasoning_tokens: 0,
39 error_count: 0,
40 event_count: 0,
41 turn_count: 0,
42 }
43 }
44
45 pub fn total_input_side_tokens(&self) -> i32 {
46 self.current_usage.input_tokens()
47 }
48
49 pub fn total_output_side_tokens(&self) -> i32 {
50 self.current_usage.output_tokens()
51 }
52
53 pub fn total_tokens(&self) -> agtrace_engine::TokenCount {
55 self.current_usage.total_tokens()
56 }
57
58 pub fn context_limit(&self) -> Option<agtrace_engine::ContextLimit> {
60 self.context_window_limit
61 .map(agtrace_engine::ContextLimit::new)
62 }
63
64 pub fn validate_tokens(&self, model_limit: Option<u64>) -> Result<(), String> {
65 if self.current_usage.fresh_input.0 < 0
66 || self.current_usage.output.0 < 0
67 || self.current_usage.cache_creation.0 < 0
68 || self.current_usage.cache_read.0 < 0
69 {
70 return Err("Negative token count detected".to_string());
71 }
72
73 let total = self.total_tokens();
74 if let Some(limit) = model_limit
75 && total.as_u64() > limit
76 {
77 return Err(format!(
78 "Token count {} exceeds model limit {}",
79 total.as_u64(),
80 limit
81 ));
82 }
83
84 Ok(())
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh state with no project root or log path.
    fn new_state() -> SessionState {
        SessionState::new("test-id".to_string(), None, None, Utc::now())
    }

    /// Builds a fresh state whose usage snapshot is `usage`.
    fn state_with_usage(usage: ContextWindowUsage) -> SessionState {
        let mut state = new_state();
        state.current_usage = usage;
        state
    }

    #[test]
    fn test_session_state_initialization() {
        let state = new_state();

        assert_eq!(state.session_id, "test-id");
        assert!(state.current_usage.is_empty());
        assert_eq!(state.current_reasoning_tokens, 0);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.event_count, 0);
        assert_eq!(state.turn_count, 0);
    }

    #[test]
    fn test_session_state_initial_metadata() {
        let started = Utc::now();
        let state = SessionState::new("test-id".to_string(), None, None, started);

        // The start timestamp doubles as the first known activity.
        assert_eq!(state.start_time, started);
        assert_eq!(state.last_activity, started);
        assert!(state.project_root.is_none());
        assert!(state.log_path.is_none());
        assert!(state.model.is_none());
        assert!(state.context_window_limit.is_none());
    }

    #[test]
    fn test_session_state_token_snapshot() {
        let mut state = state_with_usage(ContextWindowUsage::from_raw(100, 0, 0, 50));
        assert_eq!(state.total_input_side_tokens(), 100);
        assert_eq!(state.total_output_side_tokens(), 50);
        assert_eq!(state.total_tokens(), agtrace_engine::TokenCount::new(150));

        state.current_usage = ContextWindowUsage::from_raw(10, 0, 1000, 60);
        assert_eq!(state.total_input_side_tokens(), 1010);
        assert_eq!(state.total_output_side_tokens(), 60);
        assert_eq!(state.total_tokens(), agtrace_engine::TokenCount::new(1070));
    }

    #[test]
    fn test_context_limit_follows_window_limit() {
        let mut state = new_state();
        assert!(state.context_limit().is_none());

        state.context_window_limit = Some(200_000);
        assert!(state.context_limit().is_some());
    }

    #[test]
    fn test_validate_tokens_success() {
        let state = state_with_usage(ContextWindowUsage::from_raw(1000, 2000, 10000, 500));
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_exceeds_limit() {
        let state = state_with_usage(ContextWindowUsage::from_raw(100_000, 0, 0, 150_000));
        let result = state.validate_tokens(Some(200_000));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("exceeds model limit"));
    }

    #[test]
    fn test_validate_tokens_without_limit() {
        // With no limit supplied, only the negative-count check applies.
        let state = state_with_usage(ContextWindowUsage::from_raw(100_000, 0, 0, 150_000));
        assert!(state.validate_tokens(None).is_ok());
    }

    #[test]
    fn test_validate_tokens_limit_boundary() {
        let state = state_with_usage(ContextWindowUsage::from_raw(100, 0, 0, 100));

        // The comparison is strict: a total equal to the limit passes.
        assert!(state.validate_tokens(Some(200)).is_ok());

        // One token over the limit fails and reports both numbers.
        assert_eq!(
            state.validate_tokens(Some(199)).unwrap_err(),
            "Token count 200 exceeds model limit 199"
        );
    }

    #[test]
    fn test_validate_tokens_negative() {
        let state = state_with_usage(ContextWindowUsage::from_raw(-100, 0, 0, 0));

        let result = state.validate_tokens(None);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Negative token count detected");
    }

    #[test]
    fn test_validate_tokens_negative_checked_before_limit() {
        // A negative count is reported even when a limit is also supplied.
        let state = state_with_usage(ContextWindowUsage::from_raw(-100, 0, 0, 0));
        assert_eq!(
            state.validate_tokens(Some(200_000)).unwrap_err(),
            "Negative token count detected"
        );
    }
}