agtrace_runtime/domain/model.rs

1use agtrace_engine::ContextWindowUsage;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use std::path::PathBuf;
5
6#[derive(Debug, Clone)]
7pub struct SessionState {
8    pub session_id: String,
9    pub project_root: Option<PathBuf>,
10    pub log_path: Option<PathBuf>,
11    pub start_time: DateTime<Utc>,
12    pub last_activity: DateTime<Utc>,
13    pub model: Option<String>,
14    pub context_window_limit: Option<u64>,
15    pub current_usage: ContextWindowUsage,
16    pub current_reasoning_tokens: i32,
17    pub error_count: u32,
18    pub event_count: usize,
19    pub turn_count: usize,
20}
21
22impl SessionState {
23    pub fn new(
24        session_id: String,
25        project_root: Option<PathBuf>,
26        log_path: Option<PathBuf>,
27        start_time: DateTime<Utc>,
28    ) -> Self {
29        Self {
30            session_id,
31            project_root,
32            log_path,
33            start_time,
34            last_activity: start_time,
35            model: None,
36            context_window_limit: None,
37            current_usage: ContextWindowUsage::default(),
38            current_reasoning_tokens: 0,
39            error_count: 0,
40            event_count: 0,
41            turn_count: 0,
42        }
43    }
44
45    pub fn total_input_side_tokens(&self) -> i32 {
46        self.current_usage.input_tokens()
47    }
48
49    pub fn total_output_side_tokens(&self) -> i32 {
50        self.current_usage.output_tokens()
51    }
52
53    pub fn total_context_window_tokens(&self) -> i32 {
54        self.current_usage.context_window_tokens()
55    }
56
57    /// Get total tokens as type-safe TokenCount
58    pub fn total_tokens(&self) -> agtrace_engine::TokenCount {
59        self.current_usage.total_tokens()
60    }
61
62    /// Get context limit as type-safe ContextLimit
63    pub fn context_limit(&self) -> Option<agtrace_engine::ContextLimit> {
64        self.context_window_limit
65            .map(agtrace_engine::ContextLimit::new)
66    }
67
68    pub fn validate_tokens(&self, model_limit: Option<u64>) -> Result<(), String> {
69        let total = self.total_context_window_tokens();
70
71        if total < 0
72            || self.current_usage.fresh_input.0 < 0
73            || self.current_usage.output.0 < 0
74            || self.current_usage.cache_creation.0 < 0
75            || self.current_usage.cache_read.0 < 0
76        {
77            return Err("Negative token count detected".to_string());
78        }
79
80        if let Some(limit) = model_limit
81            && total as u64 > limit
82        {
83            return Err(format!(
84                "Token count {} exceeds model limit {}",
85                total, limit
86            ));
87        }
88
89        Ok(())
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh state with a fixed id, no paths, starting now.
    fn fresh_state() -> SessionState {
        SessionState::new("test-id".to_string(), None, None, Utc::now())
    }

    #[test]
    fn test_session_state_initialization() {
        let start = Utc::now();
        let state = SessionState::new("test-id".to_string(), None, None, start);

        assert_eq!(state.session_id, "test-id");
        assert_eq!(state.start_time, start);
        // `new` seeds last_activity with the start time.
        assert_eq!(state.last_activity, start);
        assert!(state.model.is_none());
        assert!(state.context_limit().is_none());
        assert!(state.current_usage.is_empty());
        assert_eq!(state.current_reasoning_tokens, 0);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.event_count, 0);
        assert_eq!(state.turn_count, 0);
    }

    #[test]
    fn test_session_state_token_snapshot() {
        let mut state = fresh_state();

        state.current_usage = ContextWindowUsage::from_raw(100, 0, 0, 50);
        assert_eq!(state.total_input_side_tokens(), 100);
        assert_eq!(state.total_output_side_tokens(), 50);
        assert_eq!(state.total_context_window_tokens(), 150);

        state.current_usage = ContextWindowUsage::from_raw(10, 0, 1000, 60);
        assert_eq!(state.total_input_side_tokens(), 1010);
        assert_eq!(state.total_output_side_tokens(), 60);
        assert_eq!(state.total_context_window_tokens(), 1070);
    }

    #[test]
    fn test_validate_tokens_success() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_at_limit_is_ok() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 100_000);
        // The limit comparison is strict, so a total equal to the limit passes.
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_exceeds_limit() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 150_000);
        let result = state.validate_tokens(Some(200_000));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            "Token count 250000 exceeds model limit 200000"
        );
    }

    #[test]
    fn test_validate_tokens_without_limit_skips_limit_check() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 150_000);
        assert!(state.validate_tokens(None).is_ok());
    }

    #[test]
    fn test_validate_tokens_negative() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(-100, 0, 0, 0);

        let result = state.validate_tokens(None);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Negative token count detected");
    }

    #[test]
    fn test_validate_tokens_negative_component_with_positive_total() {
        let mut state = fresh_state();
        // One component is negative while the overall total stays positive;
        // the per-component check must still reject the snapshot.
        state.current_usage = ContextWindowUsage::from_raw(10, 0, -1, 0);
        assert_eq!(
            state.validate_tokens(None).unwrap_err(),
            "Negative token count detected"
        );
    }
}