Skip to main content

nika_core/ast/
limits.rs

//! Agent Limits Configuration
//!
//! Defines resource limits for agent execution, including turn, token,
//! cost, and duration constraints, together with a configurable action
//! to take when a limit is reached.
//!
//! ## Features
//!
//! - **max_turns**: Maximum agentic loop iterations
//! - **max_tokens**: Total token budget (input + output)
//! - **max_cost_usd**: Cost ceiling per execution
//! - **max_duration_secs**: Wall-clock timeout
//!
//! ## Example
//!
//! ```yaml
//! agent:
//!   prompt: "Research {{topic}}"
//!   limits:
//!     max_turns: 20
//!     max_tokens: 50000
//!     max_cost_usd: 2.00
//!     max_duration_secs: 300
//!     on_limit_reached:
//!       action: complete_partial
//!       save_progress: true
//! ```
27
28use serde::Deserialize;
29
30use crate::error::CoreError;
31
32// ═══════════════════════════════════════════════════════════════════════════
33// LimitsConfig
34// ═══════════════════════════════════════════════════════════════════════════
35
36/// Configuration for agent execution limits.
37///
38/// Defines resource constraints that prevent runaway execution
39/// and enable cost control for LLM-based agents.
40#[derive(Debug, Clone, Default, Deserialize)]
41pub struct LimitsConfig {
42    /// Maximum number of agentic loop turns (0 = unlimited)
43    #[serde(default)]
44    pub max_turns: u32,
45
46    /// Maximum total tokens (input + output) (0 = unlimited)
47    #[serde(default)]
48    pub max_tokens: u64,
49
50    /// Maximum cost in USD (0.0 = unlimited)
51    #[serde(default)]
52    pub max_cost_usd: f64,
53
54    /// Maximum execution duration in seconds (0 = unlimited)
55    #[serde(default)]
56    pub max_duration_secs: u64,
57
58    /// Action to take when a limit is reached
59    #[serde(default)]
60    pub on_limit_reached: OnLimitReachedConfig,
61}
62
63impl LimitsConfig {
64    /// Check if any limits are configured.
65    pub fn has_limits(&self) -> bool {
66        self.max_turns > 0
67            || self.max_tokens > 0
68            || self.max_cost_usd > 0.0
69            || self.max_duration_secs > 0
70    }
71
72    /// Check if turns limit is configured.
73    pub fn has_turns_limit(&self) -> bool {
74        self.max_turns > 0
75    }
76
77    /// Check if tokens limit is configured.
78    pub fn has_tokens_limit(&self) -> bool {
79        self.max_tokens > 0
80    }
81
82    /// Check if cost limit is configured.
83    pub fn has_cost_limit(&self) -> bool {
84        self.max_cost_usd > 0.0
85    }
86
87    /// Check if duration limit is configured.
88    pub fn has_duration_limit(&self) -> bool {
89        self.max_duration_secs > 0
90    }
91
92    /// Validate the limits configuration.
93    pub fn validate(&self) -> Result<(), CoreError> {
94        // Cost must be non-negative
95        if self.max_cost_usd < 0.0 {
96            return Err(CoreError::ValidationError {
97                reason: format!(
98                    "limits.max_cost_usd must be non-negative, got {}",
99                    self.max_cost_usd
100                ),
101            });
102        }
103
104        // Validate on_limit_reached
105        self.on_limit_reached.validate()?;
106
107        Ok(())
108    }
109}
110
111// ═══════════════════════════════════════════════════════════════════════════
112// OnLimitReachedConfig
113// ═══════════════════════════════════════════════════════════════════════════
114
115/// Configuration for behavior when a limit is reached.
116#[derive(Debug, Clone, Deserialize)]
117pub struct OnLimitReachedConfig {
118    /// Action to take when limit is reached
119    #[serde(default)]
120    pub action: LimitAction,
121
122    /// Whether to save partial progress
123    #[serde(default = "default_save_progress")]
124    pub save_progress: bool,
125
126    /// Custom message for partial completion
127    #[serde(default)]
128    pub message: Option<String>,
129}
130
131impl Default for OnLimitReachedConfig {
132    fn default() -> Self {
133        Self {
134            action: LimitAction::default(),
135            save_progress: default_save_progress(),
136            message: None,
137        }
138    }
139}
140
141impl OnLimitReachedConfig {
142    /// Validate the on_limit_reached configuration.
143    pub fn validate(&self) -> Result<(), CoreError> {
144        // Escalate action requires save_progress to be useful
145        if self.action == LimitAction::Escalate && !self.save_progress {
146            // Just a warning, not an error
147        }
148        Ok(())
149    }
150}
151
/// Default for `OnLimitReachedConfig::save_progress`: progress is saved
/// unless the configuration explicitly opts out.
fn default_save_progress() -> bool {
    true
}
155
156// ═══════════════════════════════════════════════════════════════════════════
157// LimitAction
158// ═══════════════════════════════════════════════════════════════════════════
159
160/// Action to take when a limit is reached.
161#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
162#[serde(rename_all = "snake_case")]
163pub enum LimitAction {
164    /// Complete with partial results (recommended)
165    #[default]
166    CompletePartial,
167
168    /// Fail the task with an error
169    Fail,
170
171    /// Escalate to human or supervisor agent
172    Escalate,
173}
174
175impl LimitAction {
176    /// Get a human-readable description of the action.
177    pub fn description(&self) -> &'static str {
178        match self {
179            LimitAction::CompletePartial => "complete with partial results",
180            LimitAction::Fail => "fail the task",
181            LimitAction::Escalate => "escalate to human/supervisor",
182        }
183    }
184}
185
186// ═══════════════════════════════════════════════════════════════════════════
187// LimitType
188// ═══════════════════════════════════════════════════════════════════════════
189
/// Type of limit that was reached.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LimitType {
    /// Maximum turns exceeded.
    Turns,
    /// Maximum tokens exceeded.
    Tokens,
    /// Maximum cost exceeded.
    Cost,
    /// Maximum duration exceeded.
    Duration,
}

impl LimitType {
    /// Lowercase human-readable name of the limit type.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Turns => "turns",
            Self::Tokens => "tokens",
            Self::Cost => "cost",
            Self::Duration => "duration",
        }
    }
}

/// Displays the same text as [`LimitType::name`].
impl std::fmt::Display for LimitType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/alignment flags such as `{:>10}`; plain `{}`
        // output is unchanged.
        f.pad(self.name())
    }
}
220
221// ═══════════════════════════════════════════════════════════════════════════
222// LimitStatus
223// ═══════════════════════════════════════════════════════════════════════════
224
225/// Current status of a limit check.
226#[derive(Debug, Clone)]
227pub struct LimitStatus {
228    /// Type of limit
229    pub limit_type: LimitType,
230    /// Current value
231    pub current: f64,
232    /// Maximum allowed value
233    pub maximum: f64,
234    /// Percentage used (0.0-1.0)
235    pub usage_pct: f64,
236    /// Whether the limit has been exceeded
237    pub exceeded: bool,
238}
239
240impl LimitStatus {
241    /// Create a new limit status.
242    pub fn new(limit_type: LimitType, current: f64, maximum: f64) -> Self {
243        let usage_pct = if maximum > 0.0 {
244            (current / maximum).min(1.0)
245        } else {
246            0.0
247        };
248        Self {
249            limit_type,
250            current,
251            maximum,
252            usage_pct,
253            exceeded: maximum > 0.0 && current >= maximum,
254        }
255    }
256
257    /// Get remaining capacity.
258    pub fn remaining(&self) -> f64 {
259        if self.maximum > 0.0 {
260            (self.maximum - self.current).max(0.0)
261        } else {
262            f64::INFINITY
263        }
264    }
265}
266
267// ═══════════════════════════════════════════════════════════════════════════
268// Tests
269// ═══════════════════════════════════════════════════════════════════════════
270
/// Unit tests covering YAML parsing, limit predicates, validation, and the
/// status/label helpers defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde_yaml;

    // ========================================================================
    // LimitsConfig parsing tests
    // ========================================================================

    /// Every field supplied explicitly.
    #[test]
    fn parse_limits_config_full() {
        let yaml = r#"
max_turns: 20
max_tokens: 50000
max_cost_usd: 2.00
max_duration_secs: 300
on_limit_reached:
  action: complete_partial
  save_progress: true
  message: "Limit reached, returning partial results"
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_turns, 20);
        assert_eq!(config.max_tokens, 50000);
        assert_eq!(config.max_cost_usd, 2.00);
        assert_eq!(config.max_duration_secs, 300);
        assert_eq!(config.on_limit_reached.action, LimitAction::CompletePartial);
        assert!(config.on_limit_reached.save_progress);
        assert!(config.on_limit_reached.message.is_some());
    }

    /// An empty document yields the all-unlimited defaults.
    #[test]
    fn parse_limits_config_defaults() {
        let yaml = "";
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_turns, 0); // 0 = unlimited
        assert_eq!(config.max_tokens, 0); // 0 = unlimited
        assert!((config.max_cost_usd - 0.0).abs() < f64::EPSILON); // 0.0 = unlimited
        assert_eq!(config.max_duration_secs, 0); // 0 = unlimited
        assert_eq!(config.on_limit_reached.action, LimitAction::CompletePartial);
        assert!(config.on_limit_reached.save_progress);
    }

    /// Fields that are omitted fall back to their defaults.
    #[test]
    fn parse_limits_config_partial() {
        let yaml = r#"
max_turns: 10
max_cost_usd: 1.50
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.max_turns, 10);
        assert_eq!(config.max_cost_usd, 1.50);
        assert_eq!(config.max_tokens, 0); // default
        assert_eq!(config.max_duration_secs, 0); // default
    }

    // ========================================================================
    // LimitAction parsing tests
    // ========================================================================

    #[test]
    fn parse_limit_action_complete_partial() {
        let yaml = r#"
on_limit_reached:
  action: complete_partial
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.on_limit_reached.action, LimitAction::CompletePartial);
    }

    #[test]
    fn parse_limit_action_fail() {
        let yaml = r#"
on_limit_reached:
  action: fail
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.on_limit_reached.action, LimitAction::Fail);
    }

    #[test]
    fn parse_limit_action_escalate() {
        let yaml = r#"
on_limit_reached:
  action: escalate
"#;
        let config: LimitsConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.on_limit_reached.action, LimitAction::Escalate);
    }

    // ========================================================================
    // has_limits tests
    // ========================================================================

    #[test]
    fn has_limits_false_when_all_zero() {
        let config = LimitsConfig::default();
        assert!(!config.has_limits());
    }

    #[test]
    fn has_limits_true_when_turns_set() {
        let config = LimitsConfig {
            max_turns: 10,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_turns_limit());
        assert!(!config.has_tokens_limit());
    }

    #[test]
    fn has_limits_true_when_tokens_set() {
        let config = LimitsConfig {
            max_tokens: 50000,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_tokens_limit());
    }

    #[test]
    fn has_limits_true_when_cost_set() {
        let config = LimitsConfig {
            max_cost_usd: 2.00,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_cost_limit());
    }

    #[test]
    fn has_limits_true_when_duration_set() {
        let config = LimitsConfig {
            max_duration_secs: 300,
            ..Default::default()
        };
        assert!(config.has_limits());
        assert!(config.has_duration_limit());
    }

    // ========================================================================
    // Validation tests
    // ========================================================================

    #[test]
    fn validate_config_valid() {
        let config = LimitsConfig {
            max_turns: 20,
            max_tokens: 50000,
            max_cost_usd: 2.00,
            max_duration_secs: 300,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn validate_negative_cost_invalid() {
        let config = LimitsConfig {
            max_cost_usd: -1.00,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("max_cost_usd"));
        assert!(err.to_string().contains("non-negative"));
    }

    #[test]
    fn validate_zero_values_valid() {
        let config = LimitsConfig::default();
        assert!(config.validate().is_ok());
    }

    // ========================================================================
    // LimitStatus tests
    // ========================================================================

    #[test]
    fn limit_status_not_exceeded() {
        let status = LimitStatus::new(LimitType::Turns, 5.0, 20.0);
        assert!(!status.exceeded);
        assert_eq!(status.usage_pct, 0.25);
        assert_eq!(status.remaining(), 15.0);
    }

    #[test]
    fn limit_status_exceeded() {
        let status = LimitStatus::new(LimitType::Tokens, 50000.0, 50000.0);
        assert!(status.exceeded);
        assert_eq!(status.usage_pct, 1.0);
        assert_eq!(status.remaining(), 0.0);
    }

    #[test]
    fn limit_status_over_exceeded() {
        let status = LimitStatus::new(LimitType::Cost, 3.50, 2.00);
        assert!(status.exceeded);
        assert_eq!(status.usage_pct, 1.0); // Capped at 1.0
        assert_eq!(status.remaining(), 0.0);
    }

    #[test]
    fn limit_status_unlimited() {
        let status = LimitStatus::new(LimitType::Duration, 100.0, 0.0);
        assert!(!status.exceeded);
        assert_eq!(status.usage_pct, 0.0);
        assert!(status.remaining().is_infinite());
    }

    // ========================================================================
    // LimitType tests
    // ========================================================================

    #[test]
    fn limit_type_names() {
        assert_eq!(LimitType::Turns.name(), "turns");
        assert_eq!(LimitType::Tokens.name(), "tokens");
        assert_eq!(LimitType::Cost.name(), "cost");
        assert_eq!(LimitType::Duration.name(), "duration");
    }

    #[test]
    fn limit_type_display() {
        assert_eq!(format!("{}", LimitType::Turns), "turns");
        assert_eq!(format!("{}", LimitType::Cost), "cost");
    }

    // ========================================================================
    // LimitAction description tests
    // ========================================================================

    #[test]
    fn limit_action_descriptions() {
        assert_eq!(
            LimitAction::CompletePartial.description(),
            "complete with partial results"
        );
        assert_eq!(LimitAction::Fail.description(), "fail the task");
        assert_eq!(
            LimitAction::Escalate.description(),
            "escalate to human/supervisor"
        );
    }
}