use serde::Deserialize;
use crate::error::NikaError;
/// Resource ceilings that can be placed on a run.
///
/// Every field is optional when deserializing: a missing key is filled in
/// from this struct's `Default` value, which is zero for the numeric fields
/// and `OnLimitReachedConfig::default()` for `on_limit_reached`. A numeric
/// value of zero means the corresponding limit is not set.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct LimitsConfig {
    /// Maximum number of turns; `0` leaves turns unlimited.
    pub max_turns: u32,
    /// Maximum number of tokens; `0` leaves tokens unlimited.
    pub max_tokens: u64,
    /// Maximum cost in USD; `0.0` leaves cost unlimited.
    pub max_cost_usd: f64,
    /// Maximum duration in seconds; `0` leaves duration unlimited.
    pub max_duration_secs: u64,
    /// Settings consulted once a limit has been reached.
    pub on_limit_reached: OnLimitReachedConfig,
}
impl LimitsConfig {
    /// Returns `true` when at least one of the four limits is set.
    pub fn has_limits(&self) -> bool {
        self.has_turns_limit()
            || self.has_tokens_limit()
            || self.has_cost_limit()
            || self.has_duration_limit()
    }

    /// Returns `true` when `max_turns` is greater than zero.
    pub fn has_turns_limit(&self) -> bool {
        self.max_turns > 0
    }

    /// Returns `true` when `max_tokens` is greater than zero.
    pub fn has_tokens_limit(&self) -> bool {
        self.max_tokens > 0
    }

    /// Returns `true` when `max_cost_usd` is greater than zero.
    pub fn has_cost_limit(&self) -> bool {
        self.max_cost_usd > 0.0
    }

    /// Returns `true` when `max_duration_secs` is greater than zero.
    pub fn has_duration_limit(&self) -> bool {
        self.max_duration_secs > 0
    }

    /// Checks the configuration for values that cannot be honoured.
    ///
    /// # Errors
    ///
    /// Returns `NikaError::ValidationError` when `max_cost_usd` is negative
    /// or NaN, and forwards any error reported by
    /// `OnLimitReachedConfig::validate`.
    pub fn validate(&self) -> Result<(), NikaError> {
        // NaN compares false against every number, so a plain `< 0.0` check
        // would accept it and `has_cost_limit` would then report "no limit".
        // Positive infinity is still accepted: it is a cap that is never hit.
        if self.max_cost_usd.is_nan() || self.max_cost_usd < 0.0 {
            return Err(NikaError::ValidationError {
                reason: format!(
                    "limits.max_cost_usd must be non-negative, got {}",
                    self.max_cost_usd
                ),
            });
        }
        self.on_limit_reached.validate()
    }
}
/// Settings that apply once a limit has been reached.
///
/// Every field is optional when deserializing: a missing key is taken from
/// this type's `Default` implementation, i.e. `LimitAction::default()` for
/// `action`, the result of `default_save_progress()` for `save_progress`,
/// and `None` for `message`.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct OnLimitReachedConfig {
    /// Action selected for a reached limit.
    pub action: LimitAction,
    /// Save-progress flag; presumably controls whether progress is persisted
    /// when a limit is hit (usage is outside this file, confirm at call site).
    pub save_progress: bool,
    /// Optional message text.
    pub message: Option<String>,
}
impl Default for OnLimitReachedConfig {
    /// Builds the fallback settings: the default `LimitAction`, the value
    /// returned by `default_save_progress()`, and no message.
    fn default() -> Self {
        let action = LimitAction::default();
        let save_progress = default_save_progress();
        Self {
            action,
            save_progress,
            message: None,
        }
    }
}
impl OnLimitReachedConfig {
pub fn validate(&self) -> Result<(), NikaError> {
if self.action == LimitAction::Escalate && !self.save_progress {
}
Ok(())
}
}
/// Fallback for `OnLimitReachedConfig::save_progress` when the key is absent:
/// always `true`. Used both as the serde field default and by the manual
/// `Default` implementation.
fn default_save_progress() -> bool {
true
}
/// Action carried by `OnLimitReachedConfig::action`.
///
/// Variants deserialize from their snake_case names: `complete_partial`,
/// `fail` and `escalate`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LimitAction {
/// Complete with partial results. This is the default variant.
#[default]
CompletePartial,
/// Fail the task.
Fail,
/// Escalate to a human/supervisor.
Escalate,
}
impl LimitAction {
pub fn description(&self) -> &'static str {
match self {
LimitAction::CompletePartial => "complete with partial results",
LimitAction::Fail => "fail the task",
LimitAction::Escalate => "escalate to human/supervisor",
}
}
}
/// The kind of limit a `LimitStatus` refers to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LimitType {
    /// Limit on the number of turns.
    Turns,
    /// Limit on the number of tokens.
    Tokens,
    /// Limit on cost.
    Cost,
    /// Limit on duration.
    Duration,
}

impl LimitType {
    /// Returns the lowercase identifier of this limit kind
    /// (`"turns"`, `"tokens"`, `"cost"` or `"duration"`).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Turns => "turns",
            Self::Tokens => "tokens",
            Self::Cost => "cost",
            Self::Duration => "duration",
        }
    }
}

impl std::fmt::Display for LimitType {
    /// Writes exactly the string returned by [`LimitType::name`]; width and
    /// padding flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
/// Snapshot of how much of one limit has been consumed.
#[derive(Debug, Clone)]
pub struct LimitStatus {
    /// Which limit this status describes.
    pub limit_type: LimitType,
    /// Amount consumed so far, as passed to [`LimitStatus::new`].
    pub current: f64,
    /// Configured ceiling; a value that is not greater than zero means
    /// "unlimited".
    pub maximum: f64,
    /// `current / maximum` capped at `1.0` (a fraction, not a 0-100
    /// percentage); `0.0` when unlimited.
    pub usage_pct: f64,
    /// `true` when a ceiling is set and `current` has reached or passed it.
    pub exceeded: bool,
}

impl LimitStatus {
    /// Builds a status from the consumed amount and the ceiling.
    ///
    /// When `maximum` is not greater than zero the limit counts as unlimited:
    /// `usage_pct` is `0.0` and `exceeded` is `false`.
    pub fn new(limit_type: LimitType, current: f64, maximum: f64) -> Self {
        let limited = maximum > 0.0;
        let usage_pct = if limited {
            let ratio = current / maximum;
            ratio.min(1.0)
        } else {
            0.0
        };
        let exceeded = limited && current >= maximum;
        Self {
            limit_type,
            current,
            maximum,
            usage_pct,
            exceeded,
        }
    }

    /// Returns how much is left before the ceiling, never below `0.0`.
    ///
    /// Unlimited statuses (ceiling not greater than zero) report
    /// `f64::INFINITY`.
    pub fn remaining(&self) -> f64 {
        let headroom = (self.maximum - self.current).max(0.0);
        if self.maximum > 0.0 {
            headroom
        } else {
            f64::INFINITY
        }
    }
}
/// Unit tests for limits configuration parsing, validation and status math.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde_yaml;

    // Keys that belong to `on_limit_reached` must be indented beneath it in
    // the YAML fixtures; at column 0 they would be unrelated top-level keys.

    #[test]
    fn parse_limits_config_full() {
        let yaml = r#"
max_turns: 20
max_tokens: 50000
max_cost_usd: 2.00
max_duration_secs: 300
on_limit_reached:
  action: complete_partial
  save_progress: true
  message: "Limit reached, returning partial results"
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_turns, 20);
        assert_eq!(config.max_tokens, 50000);
        assert_eq!(config.max_cost_usd, 2.00);
        assert_eq!(config.max_duration_secs, 300);
        assert_eq!(config.on_limit_reached.action, LimitAction::CompletePartial);
        assert!(config.on_limit_reached.save_progress);
        assert!(config.on_limit_reached.message.is_some());
    }

    #[test]
    fn parse_limits_config_defaults() {
        let yaml = "";
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_turns, 0);
        assert_eq!(config.max_tokens, 0);
        assert!((config.max_cost_usd - 0.0).abs() < f64::EPSILON);
        assert_eq!(config.max_duration_secs, 0);
        assert_eq!(config.on_limit_reached.action, LimitAction::CompletePartial);
        assert!(config.on_limit_reached.save_progress);
    }

    #[test]
    fn parse_limits_config_partial() {
        let yaml = r#"
max_turns: 10
max_cost_usd: 1.50
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_turns, 10);
        assert_eq!(config.max_cost_usd, 1.50);
        assert_eq!(config.max_tokens, 0);
        assert_eq!(config.max_duration_secs, 0);
    }

    #[test]
    fn parse_limit_action_complete_partial() {
        let yaml = r#"
on_limit_reached:
  action: complete_partial
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.on_limit_reached.action, LimitAction::CompletePartial);
    }

    #[test]
    fn parse_limit_action_fail() {
        let yaml = r#"
on_limit_reached:
  action: fail
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.on_limit_reached.action, LimitAction::Fail);
    }

    #[test]
    fn parse_limit_action_escalate() {
        let yaml = r#"
on_limit_reached:
  action: escalate
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.on_limit_reached.action, LimitAction::Escalate);
    }

    #[test]
    fn has_limits_false_when_all_zero() {
        let config = LimitsConfig::default();
        assert!(!config.has_limits());
    }

    #[test]
    fn has_limits_true_when_turns_set() {
        let config = LimitsConfig {
            max_turns: 10,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_turns_limit());
        assert!(!config.has_tokens_limit());
    }

    #[test]
    fn has_limits_true_when_tokens_set() {
        let config = LimitsConfig {
            max_tokens: 50000,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_tokens_limit());
    }

    #[test]
    fn has_limits_true_when_cost_set() {
        let config = LimitsConfig {
            max_cost_usd: 2.00,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_cost_limit());
    }

    #[test]
    fn has_limits_true_when_duration_set() {
        let config = LimitsConfig {
            max_duration_secs: 300,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_duration_limit());
    }

    #[test]
    fn validate_config_valid() {
        let config = LimitsConfig {
            max_turns: 20,
            max_tokens: 50000,
            max_cost_usd: 2.00,
            max_duration_secs: 300,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn validate_negative_cost_invalid() {
        let config = LimitsConfig {
            max_cost_usd: -1.00,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("max_cost_usd"));
        assert!(err.to_string().contains("non-negative"));
    }

    #[test]
    fn validate_zero_values_valid() {
        let config = LimitsConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn limit_status_not_exceeded() {
        let status = LimitStatus::new(LimitType::Turns, 5.0, 20.0);
        assert!(!status.exceeded);
        assert_eq!(status.usage_pct, 0.25);
        assert_eq!(status.remaining(), 15.0);
    }

    #[test]
    fn limit_status_exceeded() {
        let status = LimitStatus::new(LimitType::Tokens, 50000.0, 50000.0);
        assert!(status.exceeded);
        assert_eq!(status.usage_pct, 1.0);
        assert_eq!(status.remaining(), 0.0);
    }

    #[test]
    fn limit_status_over_exceeded() {
        let status = LimitStatus::new(LimitType::Cost, 3.50, 2.00);
        assert!(status.exceeded);
        assert_eq!(status.usage_pct, 1.0);
        assert_eq!(status.remaining(), 0.0);
    }

    #[test]
    fn limit_status_unlimited() {
        let status = LimitStatus::new(LimitType::Duration, 100.0, 0.0);
        assert!(!status.exceeded);
        assert_eq!(status.usage_pct, 0.0);
        assert!(status.remaining().is_infinite());
    }

    #[test]
    fn limit_type_names() {
        assert_eq!(LimitType::Turns.name(), "turns");
        assert_eq!(LimitType::Tokens.name(), "tokens");
        assert_eq!(LimitType::Cost.name(), "cost");
        assert_eq!(LimitType::Duration.name(), "duration");
    }

    #[test]
    fn limit_type_display() {
        assert_eq!(format!("{}", LimitType::Turns), "turns");
        assert_eq!(format!("{}", LimitType::Cost), "cost");
    }

    #[test]
    fn limit_action_descriptions() {
        assert_eq!(
            LimitAction::CompletePartial.description(),
            "complete with partial results"
        );
        assert_eq!(LimitAction::Fail.description(), "fail the task");
        assert_eq!(
            LimitAction::Escalate.description(),
            "escalate to human/supervisor"
        );
    }
}