1use serde::Deserialize;
29
30use crate::error::CoreError;
31
32#[derive(Debug, Clone, Default, Deserialize)]
41pub struct LimitsConfig {
42 #[serde(default)]
44 pub max_turns: u32,
45
46 #[serde(default)]
48 pub max_tokens: u64,
49
50 #[serde(default)]
52 pub max_cost_usd: f64,
53
54 #[serde(default)]
56 pub max_duration_secs: u64,
57
58 #[serde(default)]
60 pub on_limit_reached: OnLimitReachedConfig,
61}
62
63impl LimitsConfig {
64 pub fn has_limits(&self) -> bool {
66 self.max_turns > 0
67 || self.max_tokens > 0
68 || self.max_cost_usd > 0.0
69 || self.max_duration_secs > 0
70 }
71
72 pub fn has_turns_limit(&self) -> bool {
74 self.max_turns > 0
75 }
76
77 pub fn has_tokens_limit(&self) -> bool {
79 self.max_tokens > 0
80 }
81
82 pub fn has_cost_limit(&self) -> bool {
84 self.max_cost_usd > 0.0
85 }
86
87 pub fn has_duration_limit(&self) -> bool {
89 self.max_duration_secs > 0
90 }
91
92 pub fn validate(&self) -> Result<(), CoreError> {
94 if self.max_cost_usd < 0.0 {
96 return Err(CoreError::ValidationError {
97 reason: format!(
98 "limits.max_cost_usd must be non-negative, got {}",
99 self.max_cost_usd
100 ),
101 });
102 }
103
104 self.on_limit_reached.validate()?;
106
107 Ok(())
108 }
109}
110
111#[derive(Debug, Clone, Deserialize)]
117pub struct OnLimitReachedConfig {
118 #[serde(default)]
120 pub action: LimitAction,
121
122 #[serde(default = "default_save_progress")]
124 pub save_progress: bool,
125
126 #[serde(default)]
128 pub message: Option<String>,
129}
130
131impl Default for OnLimitReachedConfig {
132 fn default() -> Self {
133 Self {
134 action: LimitAction::default(),
135 save_progress: default_save_progress(),
136 message: None,
137 }
138 }
139}
140
141impl OnLimitReachedConfig {
142 pub fn validate(&self) -> Result<(), CoreError> {
144 if self.action == LimitAction::Escalate && !self.save_progress {
146 }
148 Ok(())
149 }
150}
151
/// Serde default for `OnLimitReachedConfig::save_progress`: always `true`.
///
/// Also used by the hand-written `Default` impl so both paths agree.
fn default_save_progress() -> bool {
    true
}
155
156#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
162#[serde(rename_all = "snake_case")]
163pub enum LimitAction {
164 #[default]
166 CompletePartial,
167
168 Fail,
170
171 Escalate,
173}
174
175impl LimitAction {
176 pub fn description(&self) -> &'static str {
178 match self {
179 LimitAction::CompletePartial => "complete with partial results",
180 LimitAction::Fail => "fail the task",
181 LimitAction::Escalate => "escalate to human/supervisor",
182 }
183 }
184}
185
/// The kind of limit a [`LimitStatus`] refers to.
///
/// Each variant has a lowercase name available through `LimitType::name`
/// and the `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LimitType {
    /// Limit on the number of turns.
    Turns,
    /// Limit on the number of tokens.
    Tokens,
    /// Limit on cost.
    Cost,
    /// Limit on duration.
    Duration,
}
202
203impl LimitType {
204 pub fn name(&self) -> &'static str {
206 match self {
207 LimitType::Turns => "turns",
208 LimitType::Tokens => "tokens",
209 LimitType::Cost => "cost",
210 LimitType::Duration => "duration",
211 }
212 }
213}
214
215impl std::fmt::Display for LimitType {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 write!(f, "{}", self.name())
218 }
219}
220
221#[derive(Debug, Clone)]
227pub struct LimitStatus {
228 pub limit_type: LimitType,
230 pub current: f64,
232 pub maximum: f64,
234 pub usage_pct: f64,
236 pub exceeded: bool,
238}
239
240impl LimitStatus {
241 pub fn new(limit_type: LimitType, current: f64, maximum: f64) -> Self {
243 let usage_pct = if maximum > 0.0 {
244 (current / maximum).min(1.0)
245 } else {
246 0.0
247 };
248 Self {
249 limit_type,
250 current,
251 maximum,
252 usage_pct,
253 exceeded: maximum > 0.0 && current >= maximum,
254 }
255 }
256
257 pub fn remaining(&self) -> f64 {
259 if self.maximum > 0.0 {
260 (self.maximum - self.current).max(0.0)
261 } else {
262 f64::INFINITY
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde_yaml;

    /// Deserializes a YAML snippet into a `LimitsConfig`, panicking on failure.
    fn parse(yaml: &str) -> LimitsConfig {
        serde_yaml::from_str(yaml).unwrap()
    }

    // ---- YAML parsing ----

    #[test]
    fn parse_limits_config_full() {
        let limits = parse(
            r#"
max_turns: 20
max_tokens: 50000
max_cost_usd: 2.00
max_duration_secs: 300
on_limit_reached:
 action: complete_partial
 save_progress: true
 message: "Limit reached, returning partial results"
"#,
        );

        assert_eq!(limits.max_turns, 20);
        assert_eq!(limits.max_tokens, 50000);
        assert_eq!(limits.max_cost_usd, 2.00);
        assert_eq!(limits.max_duration_secs, 300);

        let on_limit = &limits.on_limit_reached;
        assert_eq!(on_limit.action, LimitAction::CompletePartial);
        assert!(on_limit.save_progress);
        assert!(on_limit.message.is_some());
    }

    #[test]
    fn parse_limits_config_defaults() {
        let limits = parse("");

        // Every numeric limit is disabled when nothing is configured.
        assert_eq!(limits.max_turns, 0);
        assert_eq!(limits.max_tokens, 0);
        assert!(limits.max_cost_usd.abs() < f64::EPSILON);
        assert_eq!(limits.max_duration_secs, 0);

        assert_eq!(limits.on_limit_reached.action, LimitAction::CompletePartial);
        assert!(limits.on_limit_reached.save_progress);
    }

    #[test]
    fn parse_limits_config_partial() {
        let limits = parse(
            r#"
max_turns: 10
max_cost_usd: 1.50
"#,
        );

        assert_eq!(limits.max_turns, 10);
        assert_eq!(limits.max_cost_usd, 1.50);
        // Fields missing from the YAML keep their zero defaults.
        assert_eq!(limits.max_tokens, 0);
        assert_eq!(limits.max_duration_secs, 0);
    }

    #[test]
    fn parse_limit_action_complete_partial() {
        let limits = parse(
            r#"
on_limit_reached:
 action: complete_partial
"#,
        );
        assert_eq!(limits.on_limit_reached.action, LimitAction::CompletePartial);
    }

    #[test]
    fn parse_limit_action_fail() {
        let limits = parse(
            r#"
on_limit_reached:
 action: fail
"#,
        );
        assert_eq!(limits.on_limit_reached.action, LimitAction::Fail);
    }

    #[test]
    fn parse_limit_action_escalate() {
        let limits = parse(
            r#"
on_limit_reached:
 action: escalate
"#,
        );
        assert_eq!(limits.on_limit_reached.action, LimitAction::Escalate);
    }

    // ---- has_* helpers ----

    #[test]
    fn has_limits_false_when_all_zero() {
        assert!(!LimitsConfig::default().has_limits());
    }

    #[test]
    fn has_limits_true_when_turns_set() {
        let limits = LimitsConfig {
            max_turns: 10,
            ..Default::default()
        };
        assert!(limits.has_limits());
        assert!(limits.has_turns_limit());
        assert!(!limits.has_tokens_limit());
    }

    #[test]
    fn has_limits_true_when_tokens_set() {
        let limits = LimitsConfig {
            max_tokens: 50000,
            ..Default::default()
        };
        assert!(limits.has_limits());
        assert!(limits.has_tokens_limit());
    }

    #[test]
    fn has_limits_true_when_cost_set() {
        let limits = LimitsConfig {
            max_cost_usd: 2.00,
            ..Default::default()
        };
        assert!(limits.has_limits());
        assert!(limits.has_cost_limit());
    }

    #[test]
    fn has_limits_true_when_duration_set() {
        let limits = LimitsConfig {
            max_duration_secs: 300,
            ..Default::default()
        };
        assert!(limits.has_limits());
        assert!(limits.has_duration_limit());
    }

    // ---- validation ----

    #[test]
    fn validate_config_valid() {
        let limits = LimitsConfig {
            max_turns: 20,
            max_tokens: 50000,
            max_cost_usd: 2.00,
            max_duration_secs: 300,
            ..Default::default()
        };
        assert!(limits.validate().is_ok());
    }

    #[test]
    fn validate_negative_cost_invalid() {
        let limits = LimitsConfig {
            max_cost_usd: -1.00,
            ..Default::default()
        };
        let message = limits.validate().unwrap_err().to_string();
        assert!(message.contains("max_cost_usd"));
        assert!(message.contains("non-negative"));
    }

    #[test]
    fn validate_zero_values_valid() {
        assert!(LimitsConfig::default().validate().is_ok());
    }

    // ---- LimitStatus ----

    #[test]
    fn limit_status_not_exceeded() {
        let status = LimitStatus::new(LimitType::Turns, 5.0, 20.0);
        assert!(!status.exceeded);
        assert_eq!(status.usage_pct, 0.25);
        assert_eq!(status.remaining(), 15.0);
    }

    #[test]
    fn limit_status_exceeded() {
        let status = LimitStatus::new(LimitType::Tokens, 50000.0, 50000.0);
        assert!(status.exceeded);
        assert_eq!(status.usage_pct, 1.0);
        assert_eq!(status.remaining(), 0.0);
    }

    #[test]
    fn limit_status_over_exceeded() {
        let status = LimitStatus::new(LimitType::Cost, 3.50, 2.00);
        assert!(status.exceeded);
        // Usage is capped at 1.0 and remaining never goes negative.
        assert_eq!(status.usage_pct, 1.0);
        assert_eq!(status.remaining(), 0.0);
    }

    #[test]
    fn limit_status_unlimited() {
        let status = LimitStatus::new(LimitType::Duration, 100.0, 0.0);
        assert!(!status.exceeded);
        assert_eq!(status.usage_pct, 0.0);
        assert!(status.remaining().is_infinite());
    }

    // ---- LimitType ----

    #[test]
    fn limit_type_names() {
        let expected = [
            (LimitType::Turns, "turns"),
            (LimitType::Tokens, "tokens"),
            (LimitType::Cost, "cost"),
            (LimitType::Duration, "duration"),
        ];
        for (limit_type, name) in expected.iter() {
            assert_eq!(limit_type.name(), *name);
        }
    }

    #[test]
    fn limit_type_display() {
        assert_eq!(LimitType::Turns.to_string(), "turns");
        assert_eq!(LimitType::Cost.to_string(), "cost");
    }

    // ---- LimitAction ----

    #[test]
    fn limit_action_descriptions() {
        let expected = [
            (LimitAction::CompletePartial, "complete with partial results"),
            (LimitAction::Fail, "fail the task"),
            (LimitAction::Escalate, "escalate to human/supervisor"),
        ];
        for (action, description) in expected.iter() {
            assert_eq!(action.description(), *description);
        }
    }
}