// agtrace_runtime/domain/model.rs

1use agtrace_engine::ContextWindowUsage;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use std::path::PathBuf;
5
6#[derive(Debug, Clone)]
7pub struct SessionState {
8    pub session_id: String,
9    pub project_root: Option<PathBuf>,
10    pub log_path: Option<PathBuf>,
11    pub start_time: DateTime<Utc>,
12    pub last_activity: DateTime<Utc>,
13    pub model: Option<String>,
14    pub context_window_limit: Option<u64>,
15    pub current_usage: ContextWindowUsage,
16    pub current_reasoning_tokens: i32,
17    pub error_count: u32,
18    pub event_count: usize,
19    pub turn_count: usize,
20}
21
22impl SessionState {
23    pub fn new(
24        session_id: String,
25        project_root: Option<PathBuf>,
26        log_path: Option<PathBuf>,
27        start_time: DateTime<Utc>,
28    ) -> Self {
29        Self {
30            session_id,
31            project_root,
32            log_path,
33            start_time,
34            last_activity: start_time,
35            model: None,
36            context_window_limit: None,
37            current_usage: ContextWindowUsage::default(),
38            current_reasoning_tokens: 0,
39            error_count: 0,
40            event_count: 0,
41            turn_count: 0,
42        }
43    }
44
45    pub fn total_input_side_tokens(&self) -> i32 {
46        self.current_usage.input_tokens()
47    }
48
49    pub fn total_output_side_tokens(&self) -> i32 {
50        self.current_usage.output_tokens()
51    }
52
53    /// Get total tokens as type-safe TokenCount
54    pub fn total_tokens(&self) -> agtrace_engine::TokenCount {
55        self.current_usage.total_tokens()
56    }
57
58    /// Get context limit as type-safe ContextLimit
59    pub fn context_limit(&self) -> Option<agtrace_engine::ContextLimit> {
60        self.context_window_limit
61            .map(agtrace_engine::ContextLimit::new)
62    }
63
64    pub fn validate_tokens(&self, model_limit: Option<u64>) -> Result<(), String> {
65        if self.current_usage.fresh_input.0 < 0
66            || self.current_usage.output.0 < 0
67            || self.current_usage.cache_creation.0 < 0
68            || self.current_usage.cache_read.0 < 0
69        {
70            return Err("Negative token count detected".to_string());
71        }
72
73        let total = self.total_tokens();
74        if let Some(limit) = model_limit
75            && total.as_u64() > limit
76        {
77            return Err(format!(
78                "Token count {} exceeds model limit {}",
79                total.as_u64(),
80                limit
81            ));
82        }
83
84        Ok(())
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh state with the fixed id the assertions below rely on.
    fn blank_state() -> SessionState {
        SessionState::new("test-id".to_string(), None, None, Utc::now())
    }

    /// Fresh state whose usage has been replaced by `usage`.
    fn state_with_usage(usage: ContextWindowUsage) -> SessionState {
        let mut state = blank_state();
        state.current_usage = usage;
        state
    }

    /// A new state keeps its id and starts with empty usage and zeroed counters.
    #[test]
    fn test_session_state_initialization() {
        let state = blank_state();

        assert_eq!(state.session_id, "test-id");
        assert!(state.current_usage.is_empty());
        assert_eq!(state.current_reasoning_tokens, 0);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.event_count, 0);
        assert_eq!(state.turn_count, 0);
    }

    /// The token accessors reflect whatever usage is currently stored.
    #[test]
    fn test_session_state_token_snapshot() {
        let state = state_with_usage(ContextWindowUsage::from_raw(100, 0, 0, 50));
        assert_eq!(state.total_input_side_tokens(), 100);
        assert_eq!(state.total_output_side_tokens(), 50);
        assert_eq!(state.total_tokens(), agtrace_engine::TokenCount::new(150));

        let state = state_with_usage(ContextWindowUsage::from_raw(10, 0, 1000, 60));
        assert_eq!(state.total_input_side_tokens(), 1010);
        assert_eq!(state.total_output_side_tokens(), 60);
        assert_eq!(state.total_tokens(), agtrace_engine::TokenCount::new(1070));
    }

    /// Usage below the supplied limit validates cleanly.
    #[test]
    fn test_validate_tokens_success() {
        let state = state_with_usage(ContextWindowUsage::from_raw(1000, 2000, 10000, 500));
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    /// Usage above the supplied limit is rejected with the limit message.
    #[test]
    fn test_validate_tokens_exceeds_limit() {
        let state = state_with_usage(ContextWindowUsage::from_raw(100_000, 0, 0, 150_000));

        let message = state.validate_tokens(Some(200_000)).unwrap_err();
        assert!(message.contains("exceeds model limit"));
    }

    /// A negative component is rejected even when no limit is supplied.
    #[test]
    fn test_validate_tokens_negative() {
        let state = state_with_usage(ContextWindowUsage::from_raw(-100, 0, 0, 0));

        let message = state.validate_tokens(None).unwrap_err();
        assert_eq!(message, "Negative token count detected");
    }
}