// agtrace-engine 0.7.1
//
// Internal core engine for the agtrace CLI. Not intended for direct use.
// Documentation
use crate::domain::model::SessionState;
use agtrace_types::ModelLimitResolver;

/// Token budget for a single model's context window.
///
/// Pairs the full window size with the share of it (as a percentage)
/// that is held back for compaction/management work.
#[derive(Debug, Clone)]
pub struct TokenLimit {
    /// Size of the whole context window, in tokens.
    pub total_limit: u64,
    /// Share of the window held back for compaction/management (0-100).
    pub compaction_buffer_pct: f64,
}

impl TokenLimit {
    /// Build a `TokenLimit`, checking the buffer percentage first.
    ///
    /// # Panics
    /// Panics when `compaction_buffer_pct` lies outside 0-100
    /// (a NaN value is outside the range and panics as well).
    pub fn new(total_limit: u64, compaction_buffer_pct: f64) -> Self {
        let pct_in_range = (0.0..=100.0).contains(&compaction_buffer_pct);
        assert!(
            pct_in_range,
            "compaction_buffer_pct must be in range 0-100, got: {}",
            compaction_buffer_pct
        );

        TokenLimit {
            total_limit,
            compaction_buffer_pct,
        }
    }

    /// Number of tokens left for use once the compaction buffer is set aside.
    ///
    /// A zero buffer yields `total_limit` unchanged. Any other buffer yields
    /// the usable share rounded down, and never less than 1.
    pub fn effective_limit(&self) -> u64 {
        let buffer_pct = self.compaction_buffer_pct;

        if buffer_pct != 0.0 {
            let usable_pct = 100.0 - buffer_pct;
            // Multiply before dividing so the rounding matches the stored percentage.
            let usable_tokens = self.total_limit as f64 * usable_pct / 100.0;
            std::cmp::max(usable_tokens.floor() as u64, 1)
        } else {
            self.total_limit
        }
    }
}

/// Service that answers token-limit questions about models.
///
/// Holds a `ModelLimitResolver`, which backs the per-model limit lookups
/// and the session usage percentage calculations.
pub struct TokenLimits<R>
where
    R: ModelLimitResolver,
{
    resolver: R,
}

impl<R: ModelLimitResolver> TokenLimits<R> {
    /// Create a new token limits service with the given resolver.
    pub fn new(resolver: R) -> Self {
        Self { resolver }
    }

    /// Get token limit for a model by name.
    ///
    /// Returns `None` if the model is not recognized by the resolver.
    ///
    /// # Panics
    /// Panics if the resolver reports a `compaction_buffer_pct` outside
    /// 0-100, because the spec is passed through [`TokenLimit::new`].
    pub fn get_limit(&self, model: &str) -> Option<TokenLimit> {
        self.resolver
            .resolve_model_limit(model)
            .map(|spec| TokenLimit::new(spec.max_tokens, spec.compaction_buffer_pct))
    }

    /// Calculate usage percentages from session state.
    ///
    /// The context window limit is taken from `state.context_window_limit`
    /// when it is set; otherwise it is resolved from `state.model` through
    /// [`get_limit`](Self::get_limit).
    ///
    /// Returns `(input_pct, output_pct, total_pct)` where each is
    /// a percentage (0-100+) relative to the context window limit.
    ///
    /// Returns `None` if no limit can be determined (no explicit limit and
    /// the model is missing or unrecognized), or if the limit is zero, in
    /// which case a percentage is undefined.
    pub fn get_usage_percentage_from_state(&self, state: &SessionState) -> Option<(f64, f64, f64)> {
        let limit_total = if let Some(explicit_limit) = state.context_window_limit {
            explicit_limit
        } else {
            let model = state.model.as_ref()?;
            self.get_limit(model)?.total_limit
        };

        // A zero limit would make every ratio below `inf` or `NaN`; treat it
        // as "limit information unavailable" instead.
        if limit_total == 0 {
            return None;
        }

        let limit = limit_total as f64;
        let pct_of_limit = |tokens: f64| (tokens / limit) * 100.0;

        let input_side = state.total_input_side_tokens() as u64;
        let output_side = state.total_output_side_tokens() as u64;
        let total = state.total_tokens().as_u64();

        Some((
            pct_of_limit(input_side as f64),
            pct_of_limit(output_side as f64),
            pct_of_limit(total as f64),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ContextWindowUsage;
    use agtrace_types::ModelSpec;
    use chrono::Utc;

    /// Dated model identifier that the mock resolver recognizes.
    const SONNET_DATED: &str = "claude-3-5-sonnet-20241022";

    /// Resolver stub: any model name starting with `claude-3-5-sonnet` gets a
    /// 200k-token window with a 22.5% buffer; every other name is unknown.
    struct MockResolver;

    impl ModelLimitResolver for MockResolver {
        fn resolve_model_limit(&self, model: &str) -> Option<ModelSpec> {
            if !model.starts_with("claude-3-5-sonnet") {
                return None;
            }

            Some(ModelSpec {
                max_tokens: 200_000,
                compaction_buffer_pct: 22.5,
            })
        }
    }

    /// Fresh session state whose model is set to [`SONNET_DATED`].
    fn sonnet_session() -> SessionState {
        let mut session = SessionState::new("test".to_string(), None, None, Utc::now());
        session.model = Some(SONNET_DATED.to_string());
        session
    }

    #[test]
    fn test_get_limit_exact_match() {
        let sonnet = TokenLimits::new(MockResolver)
            .get_limit(SONNET_DATED)
            .expect("dated sonnet model should resolve to a limit");

        assert_eq!(sonnet.total_limit, 200_000);
        assert_eq!(sonnet.compaction_buffer_pct, 22.5);
        assert_eq!(sonnet.effective_limit(), 155_000);
    }

    #[test]
    fn test_get_limit_prefix_match() {
        let service = TokenLimits::new(MockResolver);
        assert!(service.get_limit("claude-3-5-sonnet").is_some());
    }

    #[test]
    fn test_unknown_model() {
        let service = TokenLimits::new(MockResolver);
        assert!(service.get_limit("unknown-model").is_none());
    }

    #[test]
    fn test_get_usage_percentage_from_state() {
        // No explicit context window limit: the limit comes from the resolver.
        let mut session = sonnet_session();
        session.current_usage = ContextWindowUsage::from_raw(1000, 0, 12000, 500);

        let (input_pct, output_pct, total_pct) = TokenLimits::new(MockResolver)
            .get_usage_percentage_from_state(&session)
            .unwrap();

        let tolerance = 1e-6;
        for (actual, expected) in [(input_pct, 6.5), (output_pct, 0.25), (total_pct, 6.75)] {
            assert!(
                (actual - expected).abs() < tolerance,
                "expected {}, got {}",
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_get_usage_percentage_from_state_no_cache() {
        // Explicit context window limit set on the session state.
        let mut session = sonnet_session();
        session.context_window_limit = Some(200_000);
        session.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 4_000);

        let percentages = TokenLimits::new(MockResolver)
            .get_usage_percentage_from_state(&session)
            .unwrap();

        assert_eq!(percentages.0, 50.0);
        assert_eq!(percentages.1, 2.0);
        assert_eq!(percentages.2, 52.0);
    }

    #[test]
    fn test_effective_limit() {
        // (total_limit, compaction_buffer_pct, expected effective limit)
        let cases = [
            (200_000, 22.5, 155_000),
            (400_000, 0.0, 400_000),
            (1000, 99.0, 10),
        ];

        for (total, buffer_pct, expected) in cases {
            assert_eq!(TokenLimit::new(total, buffer_pct).effective_limit(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_negative() {
        let _ = TokenLimit::new(200_000, -10.0);
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_over_100() {
        let _ = TokenLimit::new(200_000, 150.0);
    }
}