agtrace_runtime/domain/
model.rs

1use agtrace_engine::ContextWindowUsage;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use std::path::PathBuf;
5
6#[derive(Debug, Clone)]
7pub struct SessionState {
8    pub session_id: String,
9    pub project_root: Option<PathBuf>,
10    pub start_time: DateTime<Utc>,
11    pub last_activity: DateTime<Utc>,
12    pub model: Option<String>,
13    pub context_window_limit: Option<u64>,
14    pub current_usage: ContextWindowUsage,
15    pub current_reasoning_tokens: i32,
16    pub error_count: u32,
17    pub event_count: usize,
18    pub turn_count: usize,
19}
20
21impl SessionState {
22    pub fn new(
23        session_id: String,
24        project_root: Option<PathBuf>,
25        start_time: DateTime<Utc>,
26    ) -> Self {
27        Self {
28            session_id,
29            project_root,
30            start_time,
31            last_activity: start_time,
32            model: None,
33            context_window_limit: None,
34            current_usage: ContextWindowUsage::default(),
35            current_reasoning_tokens: 0,
36            error_count: 0,
37            event_count: 0,
38            turn_count: 0,
39        }
40    }
41
42    pub fn total_input_side_tokens(&self) -> i32 {
43        self.current_usage.input_tokens()
44    }
45
46    pub fn total_output_side_tokens(&self) -> i32 {
47        self.current_usage.output_tokens()
48    }
49
50    pub fn total_context_window_tokens(&self) -> i32 {
51        self.current_usage.context_window_tokens()
52    }
53
54    /// Get total tokens as type-safe TokenCount
55    pub fn total_tokens(&self) -> agtrace_engine::TokenCount {
56        self.current_usage.total_tokens()
57    }
58
59    /// Get context limit as type-safe ContextLimit
60    pub fn context_limit(&self) -> Option<agtrace_engine::ContextLimit> {
61        self.context_window_limit
62            .map(agtrace_engine::ContextLimit::new)
63    }
64
65    pub fn validate_tokens(&self, model_limit: Option<u64>) -> Result<(), String> {
66        let total = self.total_context_window_tokens();
67
68        if total < 0
69            || self.current_usage.fresh_input.0 < 0
70            || self.current_usage.output.0 < 0
71            || self.current_usage.cache_creation.0 < 0
72            || self.current_usage.cache_read.0 < 0
73        {
74            return Err("Negative token count detected".to_string());
75        }
76
77        if let Some(limit) = model_limit
78            && total as u64 > limit
79        {
80            return Err(format!(
81                "Token count {} exceeds model limit {}",
82                total, limit
83            ));
84        }
85
86        Ok(())
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state with a fixed id, no project root and the current time.
    fn fresh_state() -> SessionState {
        SessionState::new("test-id".to_string(), None, Utc::now())
    }

    #[test]
    fn test_session_state_initialization() {
        let state = fresh_state();

        assert_eq!(state.session_id, "test-id");
        assert!(state.current_usage.is_empty());
        assert_eq!(state.current_reasoning_tokens, 0);
        assert_eq!(state.error_count, 0);
        assert_eq!(state.event_count, 0);
        assert_eq!(state.turn_count, 0);
    }

    #[test]
    fn test_session_state_token_snapshot() {
        let mut state = fresh_state();

        state.current_usage = ContextWindowUsage::from_raw(100, 0, 0, 50);
        assert_eq!(state.total_input_side_tokens(), 100);
        assert_eq!(state.total_output_side_tokens(), 50);
        assert_eq!(state.total_context_window_tokens(), 150);

        // Assigning a new snapshot replaces the totals rather than adding to them.
        state.current_usage = ContextWindowUsage::from_raw(10, 0, 1000, 60);
        assert_eq!(state.total_input_side_tokens(), 1010);
        assert_eq!(state.total_output_side_tokens(), 60);
        assert_eq!(state.total_context_window_tokens(), 1070);
    }

    #[test]
    fn test_validate_tokens_success() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_exceeds_limit() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 150_000);
        let result = state.validate_tokens(Some(200_000));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("exceeds model limit"));
    }

    #[test]
    fn test_validate_tokens_at_limit_is_ok() {
        // The comparison is strict (`>`), so a total equal to the limit passes.
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 100_000);
        assert_eq!(state.total_context_window_tokens(), 200_000);
        assert!(state.validate_tokens(Some(200_000)).is_ok());
    }

    #[test]
    fn test_validate_tokens_without_limit() {
        // With no limit supplied only the negativity check applies.
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 150_000);
        assert!(state.validate_tokens(None).is_ok());
    }

    #[test]
    fn test_validate_tokens_negative() {
        let mut state = fresh_state();
        state.current_usage = ContextWindowUsage::from_raw(-100, 0, 0, 0);

        let result = state.validate_tokens(None);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Negative token count detected");
    }
}