agtrace_runtime/domain/
model.rs1use agtrace_engine::ContextWindowUsage;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use std::path::PathBuf;
5
6#[derive(Debug, Clone)]
7pub struct SessionState {
8 pub session_id: String,
9 pub project_root: Option<PathBuf>,
10 pub start_time: DateTime<Utc>,
11 pub last_activity: DateTime<Utc>,
12 pub model: Option<String>,
13 pub context_window_limit: Option<u64>,
14 pub current_usage: ContextWindowUsage,
15 pub current_reasoning_tokens: i32,
16 pub error_count: u32,
17 pub event_count: usize,
18 pub turn_count: usize,
19}
20
21impl SessionState {
22 pub fn new(
23 session_id: String,
24 project_root: Option<PathBuf>,
25 start_time: DateTime<Utc>,
26 ) -> Self {
27 Self {
28 session_id,
29 project_root,
30 start_time,
31 last_activity: start_time,
32 model: None,
33 context_window_limit: None,
34 current_usage: ContextWindowUsage::default(),
35 current_reasoning_tokens: 0,
36 error_count: 0,
37 event_count: 0,
38 turn_count: 0,
39 }
40 }
41
42 pub fn total_input_side_tokens(&self) -> i32 {
43 self.current_usage.input_tokens()
44 }
45
46 pub fn total_output_side_tokens(&self) -> i32 {
47 self.current_usage.output_tokens()
48 }
49
50 pub fn total_context_window_tokens(&self) -> i32 {
51 self.current_usage.context_window_tokens()
52 }
53
54 pub fn total_tokens(&self) -> agtrace_engine::TokenCount {
56 self.current_usage.total_tokens()
57 }
58
59 pub fn context_limit(&self) -> Option<agtrace_engine::ContextLimit> {
61 self.context_window_limit
62 .map(agtrace_engine::ContextLimit::new)
63 }
64
65 pub fn validate_tokens(&self, model_limit: Option<u64>) -> Result<(), String> {
66 let total = self.total_context_window_tokens();
67
68 if total < 0
69 || self.current_usage.fresh_input.0 < 0
70 || self.current_usage.output.0 < 0
71 || self.current_usage.cache_creation.0 < 0
72 || self.current_usage.cache_read.0 < 0
73 {
74 return Err("Negative token count detected".to_string());
75 }
76
77 if let Some(limit) = model_limit
78 && total as u64 > limit
79 {
80 return Err(format!(
81 "Token count {} exceeds model limit {}",
82 total, limit
83 ));
84 }
85
86 Ok(())
87 }
88}
89
90#[cfg(test)]
91mod tests {
92 use super::*;
93
94 #[test]
95 fn test_session_state_initialization() {
96 let state = SessionState::new("test-id".to_string(), None, Utc::now());
97
98 assert_eq!(state.session_id, "test-id");
99 assert!(state.current_usage.is_empty());
100 assert_eq!(state.current_reasoning_tokens, 0);
101 assert_eq!(state.error_count, 0);
102 assert_eq!(state.event_count, 0);
103 assert_eq!(state.turn_count, 0);
104 }
105
106 #[test]
107 fn test_session_state_token_snapshot() {
108 let mut state = SessionState::new("test-id".to_string(), None, Utc::now());
109
110 state.current_usage = ContextWindowUsage::from_raw(100, 0, 0, 50);
111 assert_eq!(state.total_input_side_tokens(), 100);
112 assert_eq!(state.total_output_side_tokens(), 50);
113 assert_eq!(state.total_context_window_tokens(), 150);
114
115 state.current_usage = ContextWindowUsage::from_raw(10, 0, 1000, 60);
116 assert_eq!(state.total_input_side_tokens(), 1010);
117 assert_eq!(state.total_output_side_tokens(), 60);
118 assert_eq!(state.total_context_window_tokens(), 1070);
119 }
120
121 #[test]
122 fn test_validate_tokens_success() {
123 let mut state = SessionState::new("test-id".to_string(), None, Utc::now());
124 state.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);
125 assert!(state.validate_tokens(Some(200_000)).is_ok());
126 }
127
128 #[test]
129 fn test_validate_tokens_exceeds_limit() {
130 let mut state = SessionState::new("test-id".to_string(), None, Utc::now());
131 state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 150_000);
132 let result = state.validate_tokens(Some(200_000));
133 assert!(result.is_err());
134 assert!(result.unwrap_err().contains("exceeds model limit"));
135 }
136
137 #[test]
138 fn test_validate_tokens_negative() {
139 let mut state = SessionState::new("test-id".to_string(), None, Utc::now());
140 state.current_usage = ContextWindowUsage::from_raw(-100, 0, 0, 0);
141
142 let result = state.validate_tokens(None);
143 assert!(result.is_err());
144 assert_eq!(result.unwrap_err(), "Negative token count detected");
145 }
146}