Skip to main content

nika_engine/runtime/
partial.rs

//! Partial completion support for agent execution
//!
//! When an agent hits a limit (turns, tokens, cost, duration), this module
//! allows capturing and saving the partial progress for potential resumption.
//!
//! ## Usage
//!
//! ```rust,ignore
//! use nika::runtime::partial::{PartialCheckpoint, PartialResult, StopReason};
//!
//! // Create a checkpoint when a limit is reached
//! let checkpoint = PartialCheckpoint::new(
//!     "task-1",
//!     0.65,  // 65% complete
//!     "Generated 3 of 5 sections",
//!     StopReason::TurnsLimit,
//! );
//!
//! // Save progress
//! checkpoint.save_to_file("./checkpoints/task-1.json").await?;
//!
//! // Load and resume later
//! let restored = PartialCheckpoint::load_from_file("./checkpoints/task-1.json").await?;
//! ```
24
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::Path;
28use std::sync::Arc;
29use std::time::{SystemTime, UNIX_EPOCH};
30
31use crate::error::NikaError;
32
33// ═══════════════════════════════════════════════════════════════════════════
34// PARTIAL RESULT
35// ═══════════════════════════════════════════════════════════════════════════
36
37/// Result of a partially completed task
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct PartialResult {
40    /// Current output content (may be incomplete)
41    pub content: String,
42
43    /// Progress percentage (0.0-1.0)
44    pub progress: f64,
45
46    /// Preview of the result for display
47    pub preview: String,
48
49    /// Reason why execution stopped
50    pub stop_reason: StopReason,
51
52    /// Number of turns completed
53    pub turns_completed: u32,
54
55    /// Total tokens used
56    pub tokens_used: u64,
57
58    /// Cost incurred (USD)
59    pub cost_usd: f64,
60}
61
62impl PartialResult {
63    /// Create a new partial result
64    pub fn new(content: impl Into<String>, progress: f64, stop_reason: StopReason) -> Self {
65        let content = content.into();
66        let preview = Self::generate_preview(&content);
67
68        Self {
69            content,
70            progress: progress.clamp(0.0, 1.0),
71            preview,
72            stop_reason,
73            turns_completed: 0,
74            tokens_used: 0,
75            cost_usd: 0.0,
76        }
77    }
78
79    /// Set usage statistics
80    pub fn with_usage(mut self, turns: u32, tokens: u64, cost: f64) -> Self {
81        self.turns_completed = turns;
82        self.tokens_used = tokens;
83        self.cost_usd = cost;
84        self
85    }
86
87    /// Generate a preview from content (first 100 chars or first line)
88    fn generate_preview(content: &str) -> String {
89        let first_line = content.lines().next().unwrap_or("");
90        if first_line.len() > 100 {
91            format!("{}...", crate::util::truncate_str(first_line, 97))
92        } else if first_line.len() < content.len() {
93            format!("{}...", first_line)
94        } else {
95            first_line.to_string()
96        }
97    }
98
99    /// Check if this is a meaningful partial result
100    pub fn is_meaningful(&self) -> bool {
101        !self.content.is_empty() && self.progress > 0.0
102    }
103
104    /// Calculate progress percentage as integer (0-100)
105    pub fn progress_percent(&self) -> u32 {
106        (self.progress * 100.0).round() as u32
107    }
108}
109
110// ═══════════════════════════════════════════════════════════════════════════
111// STOP REASON
112// ═══════════════════════════════════════════════════════════════════════════
113
114/// Reason for partial completion
115#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
116pub enum StopReason {
117    /// Hit maximum turns limit
118    TurnsLimit,
119    /// Hit token budget limit
120    TokensLimit,
121    /// Hit cost budget limit
122    CostLimit,
123    /// Hit duration limit
124    DurationLimit,
125    /// User requested stop
126    UserRequested,
127    /// Error occurred
128    Error,
129}
130
131impl StopReason {
132    /// Get human-readable description
133    pub fn description(&self) -> &'static str {
134        match self {
135            Self::TurnsLimit => "Maximum turns reached",
136            Self::TokensLimit => "Token budget exhausted",
137            Self::CostLimit => "Cost budget exhausted",
138            Self::DurationLimit => "Duration limit reached",
139            Self::UserRequested => "User requested stop",
140            Self::Error => "Error occurred",
141        }
142    }
143
144    /// Get short label for display
145    pub fn label(&self) -> &'static str {
146        match self {
147            Self::TurnsLimit => "turns",
148            Self::TokensLimit => "tokens",
149            Self::CostLimit => "cost",
150            Self::DurationLimit => "duration",
151            Self::UserRequested => "user",
152            Self::Error => "error",
153        }
154    }
155
156    /// Check if this is a limit-based stop
157    pub fn is_limit(&self) -> bool {
158        matches!(
159            self,
160            Self::TurnsLimit | Self::TokensLimit | Self::CostLimit | Self::DurationLimit
161        )
162    }
163}
164
165// ═══════════════════════════════════════════════════════════════════════════
166// PARTIAL CHECKPOINT
167// ═══════════════════════════════════════════════════════════════════════════
168
169/// Checkpoint for resuming partial execution
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct PartialCheckpoint {
172    /// Version for compatibility checking
173    pub version: u32,
174
175    /// Task identifier
176    pub task_id: Arc<str>,
177
178    /// Timestamp when checkpoint was created (Unix epoch seconds)
179    pub created_at: u64,
180
181    /// Current progress (0.0-1.0)
182    pub progress: f64,
183
184    /// The partial result
185    pub result: PartialResult,
186
187    /// Conversation history for resumption
188    pub history: Vec<CheckpointMessage>,
189
190    /// Context data for resumption
191    pub context: HashMap<String, serde_json::Value>,
192
193    /// Provider that was being used
194    pub provider: Option<String>,
195
196    /// Model that was being used
197    pub model: Option<String>,
198}
199
200/// Message in checkpoint history
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct CheckpointMessage {
203    /// Role: "user", "assistant", or "system"
204    pub role: String,
205    /// Message content
206    pub content: String,
207}
208
209impl PartialCheckpoint {
210    /// Current checkpoint format version
211    pub const CURRENT_VERSION: u32 = 1;
212
213    /// Create a new checkpoint
214    pub fn new(
215        task_id: impl Into<Arc<str>>,
216        progress: f64,
217        content: impl Into<String>,
218        stop_reason: StopReason,
219    ) -> Self {
220        let now = SystemTime::now()
221            .duration_since(UNIX_EPOCH)
222            .map(|d| d.as_secs())
223            .unwrap_or(0);
224
225        Self {
226            version: Self::CURRENT_VERSION,
227            task_id: task_id.into(),
228            created_at: now,
229            progress: progress.clamp(0.0, 1.0),
230            result: PartialResult::new(content, progress, stop_reason),
231            history: Vec::new(),
232            context: HashMap::new(),
233            provider: None,
234            model: None,
235        }
236    }
237
238    /// Add a message to history
239    pub fn add_message(&mut self, role: impl Into<String>, content: impl Into<String>) {
240        self.history.push(CheckpointMessage {
241            role: role.into(),
242            content: content.into(),
243        });
244    }
245
246    /// Set context data
247    pub fn set_context(&mut self, key: impl Into<String>, value: serde_json::Value) {
248        self.context.insert(key.into(), value);
249    }
250
251    /// Set provider info
252    pub fn with_provider(mut self, provider: impl Into<String>, model: impl Into<String>) -> Self {
253        self.provider = Some(provider.into());
254        self.model = Some(model.into());
255        self
256    }
257
258    /// Set usage statistics
259    pub fn with_usage(mut self, turns: u32, tokens: u64, cost: f64) -> Self {
260        self.result = self.result.with_usage(turns, tokens, cost);
261        self
262    }
263
264    /// Save checkpoint to file
265    pub async fn save_to_file(&self, path: impl AsRef<Path>) -> Result<(), NikaError> {
266        let path = path.as_ref();
267
268        // Ensure parent directory exists
269        if let Some(parent) = path.parent() {
270            tokio::fs::create_dir_all(parent)
271                .await
272                .map_err(NikaError::IoError)?;
273        }
274
275        // Serialize to JSON
276        let json =
277            serde_json::to_string_pretty(self).map_err(|e| NikaError::SerializationError {
278                details: format!("Failed to serialize checkpoint: {}", e),
279            })?;
280
281        // Write atomically (write to temp, then rename)
282        let temp_path = path.with_extension("tmp");
283        tokio::fs::write(&temp_path, &json)
284            .await
285            .map_err(NikaError::IoError)?;
286        tokio::fs::rename(&temp_path, path)
287            .await
288            .map_err(NikaError::IoError)?;
289
290        Ok(())
291    }
292
293    /// Load checkpoint from file
294    pub async fn load_from_file(path: impl AsRef<Path>) -> Result<Self, NikaError> {
295        let path = path.as_ref();
296
297        let json = tokio::fs::read_to_string(path)
298            .await
299            .map_err(NikaError::IoError)?;
300
301        let checkpoint: Self =
302            serde_json::from_str(&json).map_err(|e| NikaError::SerializationError {
303                details: format!("Failed to parse checkpoint: {}", e),
304            })?;
305
306        // Check version compatibility
307        if checkpoint.version > Self::CURRENT_VERSION {
308            return Err(NikaError::ValidationError {
309                reason: format!(
310                    "Checkpoint version {} is newer than supported version {}",
311                    checkpoint.version,
312                    Self::CURRENT_VERSION
313                ),
314            });
315        }
316
317        Ok(checkpoint)
318    }
319
320    /// Check if checkpoint exists at path
321    pub async fn exists(path: impl AsRef<Path>) -> bool {
322        tokio::fs::metadata(path.as_ref()).await.is_ok()
323    }
324
325    /// Generate checkpoint filename for a task
326    pub fn filename(task_id: &str) -> String {
327        // Sanitize task_id for use as filename
328        let safe_id: String = task_id
329            .chars()
330            .map(|c| {
331                if c.is_alphanumeric() || c == '-' || c == '_' {
332                    c
333                } else {
334                    '_'
335                }
336            })
337            .collect();
338        format!("{}.checkpoint.json", safe_id)
339    }
340}
341
342// ═══════════════════════════════════════════════════════════════════════════
343// TESTS
344// ═══════════════════════════════════════════════════════════════════════════
345
#[cfg(test)]
mod tests {
    use super::*;

    // ───────────────────────────────────────────────────────────────────────
    // PartialResult tests
    // ───────────────────────────────────────────────────────────────────────

    #[test]
    fn partial_result_new() {
        let result = PartialResult::new("Hello world", 0.5, StopReason::TurnsLimit);
        assert_eq!(result.content, "Hello world");
        assert!((result.progress - 0.5).abs() < 0.0001);
        assert_eq!(result.stop_reason, StopReason::TurnsLimit);
    }

    #[test]
    fn partial_result_clamps_progress() {
        let result = PartialResult::new("test", 1.5, StopReason::CostLimit);
        assert!((result.progress - 1.0).abs() < 0.0001);

        let result = PartialResult::new("test", -0.5, StopReason::CostLimit);
        assert!((result.progress - 0.0).abs() < 0.0001);
    }

    #[test]
    fn partial_result_with_usage() {
        let result =
            PartialResult::new("test", 0.75, StopReason::TokensLimit).with_usage(10, 5000, 0.05);

        assert_eq!(result.turns_completed, 10);
        assert_eq!(result.tokens_used, 5000);
        assert!((result.cost_usd - 0.05).abs() < 0.0001);
    }

    #[test]
    fn partial_result_preview_short() {
        let result = PartialResult::new("Short content", 0.5, StopReason::TurnsLimit);
        assert_eq!(result.preview, "Short content");
    }

    #[test]
    fn partial_result_preview_multiline() {
        let result = PartialResult::new(
            "First line\nSecond line\nThird",
            0.5,
            StopReason::TurnsLimit,
        );
        assert_eq!(result.preview, "First line...");
    }

    #[test]
    fn partial_result_preview_long_line() {
        let long_content = "x".repeat(150);
        let result = PartialResult::new(&long_content, 0.5, StopReason::TurnsLimit);
        assert!(result.preview.ends_with("..."));
        assert!(result.preview.len() <= 103); // 97 + "..."
    }

    #[test]
    fn partial_result_is_meaningful() {
        let meaningful = PartialResult::new("content", 0.5, StopReason::TurnsLimit);
        assert!(meaningful.is_meaningful());

        let empty = PartialResult::new("", 0.5, StopReason::TurnsLimit);
        assert!(!empty.is_meaningful());

        let zero_progress = PartialResult::new("content", 0.0, StopReason::TurnsLimit);
        assert!(!zero_progress.is_meaningful());
    }

    #[test]
    fn partial_result_progress_percent() {
        let result = PartialResult::new("test", 0.654, StopReason::TurnsLimit);
        assert_eq!(result.progress_percent(), 65);
    }

    // ───────────────────────────────────────────────────────────────────────
    // StopReason tests
    // ───────────────────────────────────────────────────────────────────────

    #[test]
    fn stop_reason_descriptions() {
        assert!(!StopReason::TurnsLimit.description().is_empty());
        assert!(!StopReason::TokensLimit.description().is_empty());
        assert!(!StopReason::CostLimit.description().is_empty());
        assert!(!StopReason::DurationLimit.description().is_empty());
        assert!(!StopReason::UserRequested.description().is_empty());
        assert!(!StopReason::Error.description().is_empty());
    }

    #[test]
    fn stop_reason_labels() {
        assert_eq!(StopReason::TurnsLimit.label(), "turns");
        assert_eq!(StopReason::TokensLimit.label(), "tokens");
        assert_eq!(StopReason::CostLimit.label(), "cost");
        assert_eq!(StopReason::DurationLimit.label(), "duration");
        assert_eq!(StopReason::UserRequested.label(), "user");
        assert_eq!(StopReason::Error.label(), "error");
    }

    #[test]
    fn stop_reason_is_limit() {
        assert!(StopReason::TurnsLimit.is_limit());
        assert!(StopReason::TokensLimit.is_limit());
        assert!(StopReason::CostLimit.is_limit());
        assert!(StopReason::DurationLimit.is_limit());
        assert!(!StopReason::UserRequested.is_limit());
        assert!(!StopReason::Error.is_limit());
    }

    #[test]
    fn stop_reason_serializes_as_variant_name() {
        // The variant name is the on-disk representation: renaming a variant would
        // break checkpoints that were already saved.
        assert_eq!(
            serde_json::to_string(&StopReason::TurnsLimit).unwrap(),
            "\"TurnsLimit\""
        );
        let parsed: StopReason = serde_json::from_str("\"UserRequested\"").unwrap();
        assert_eq!(parsed, StopReason::UserRequested);
    }

    // ───────────────────────────────────────────────────────────────────────
    // PartialCheckpoint tests
    // ───────────────────────────────────────────────────────────────────────

    #[test]
    fn checkpoint_new() {
        let checkpoint = PartialCheckpoint::new(
            "task-1",
            0.75,
            "Partial content here",
            StopReason::TurnsLimit,
        );

        assert_eq!(checkpoint.version, PartialCheckpoint::CURRENT_VERSION);
        assert_eq!(&*checkpoint.task_id, "task-1");
        assert!((checkpoint.progress - 0.75).abs() < 0.0001);
        assert!(checkpoint.created_at > 0);
    }

    #[test]
    fn checkpoint_add_message() {
        let mut checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit);
        checkpoint.add_message("user", "Hello");
        checkpoint.add_message("assistant", "Hi there");

        assert_eq!(checkpoint.history.len(), 2);
        assert_eq!(checkpoint.history[0].role, "user");
        assert_eq!(checkpoint.history[1].content, "Hi there");
    }

    #[test]
    fn checkpoint_set_context() {
        let mut checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit);
        checkpoint.set_context("key1", serde_json::json!({"nested": "value"}));

        assert!(checkpoint.context.contains_key("key1"));
    }

    #[test]
    fn checkpoint_with_provider() {
        let checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit)
            .with_provider("claude", "claude-sonnet-4-6");

        assert_eq!(checkpoint.provider, Some("claude".to_string()));
        assert_eq!(checkpoint.model, Some("claude-sonnet-4-6".to_string()));
    }

    #[test]
    fn checkpoint_with_usage() {
        let checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit)
            .with_usage(5, 10000, 0.15);

        assert_eq!(checkpoint.result.turns_completed, 5);
        assert_eq!(checkpoint.result.tokens_used, 10000);
        assert!((checkpoint.result.cost_usd - 0.15).abs() < 0.0001);
    }

    #[test]
    fn checkpoint_filename() {
        assert_eq!(
            PartialCheckpoint::filename("task-1"),
            "task-1.checkpoint.json"
        );
        assert_eq!(
            PartialCheckpoint::filename("task/with/slashes"),
            "task_with_slashes.checkpoint.json"
        );
        assert_eq!(
            PartialCheckpoint::filename("task with spaces"),
            "task_with_spaces.checkpoint.json"
        );
    }

    #[test]
    fn checkpoint_serialization_roundtrip() {
        let mut checkpoint =
            PartialCheckpoint::new("task-1", 0.75, "Content", StopReason::CostLimit)
                .with_provider("openai", "gpt-4o")
                .with_usage(3, 5000, 0.05);

        checkpoint.add_message("user", "Question");
        checkpoint.add_message("assistant", "Answer");
        checkpoint.set_context("data", serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&checkpoint).unwrap();
        let restored: PartialCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.task_id, checkpoint.task_id);
        // serde_json's default float parsing is best-effort, so compare with a tolerance.
        assert!((restored.progress - checkpoint.progress).abs() < 0.0001);
        assert_eq!(restored.result.stop_reason, StopReason::CostLimit);
        assert_eq!(restored.result.turns_completed, 3);
        assert_eq!(restored.result.tokens_used, 5000);
        assert!((restored.result.cost_usd - 0.05).abs() < 0.0001);
        assert_eq!(restored.provider.as_deref(), Some("openai"));
        assert_eq!(restored.model.as_deref(), Some("gpt-4o"));
        assert_eq!(restored.history.len(), 2);
        assert_eq!(restored.history[1].content, "Answer");
        assert_eq!(restored.context["data"], serde_json::json!({"key": "value"}));
    }

    // ───────────────────────────────────────────────────────────────────────
    // Async file tests (require tokio runtime)
    // ───────────────────────────────────────────────────────────────────────

    #[tokio::test]
    async fn checkpoint_save_and_load() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.checkpoint.json");

        let checkpoint =
            PartialCheckpoint::new("test-task", 0.8, "Test content", StopReason::TokensLimit)
                .with_usage(5, 8000, 0.12);

        // Save
        checkpoint.save_to_file(&file_path).await.unwrap();
        assert!(file_path.exists());

        // Load
        let loaded = PartialCheckpoint::load_from_file(&file_path).await.unwrap();
        assert_eq!(&*loaded.task_id, "test-task");
        assert!((loaded.progress - 0.8).abs() < 0.0001);
        assert_eq!(loaded.result.tokens_used, 8000);
    }

    #[tokio::test]
    async fn checkpoint_save_creates_parent_dirs() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir
            .path()
            .join("nested")
            .join("dir")
            .join("task.checkpoint.json");

        PartialCheckpoint::new("task", 0.5, "content", StopReason::TurnsLimit)
            .save_to_file(&file_path)
            .await
            .unwrap();

        assert!(file_path.exists());
    }

    #[tokio::test]
    async fn checkpoint_save_overwrites_without_leftovers() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("overwrite.checkpoint.json");

        PartialCheckpoint::new("task", 0.2, "first", StopReason::TurnsLimit)
            .save_to_file(&file_path)
            .await
            .unwrap();
        PartialCheckpoint::new("task", 0.9, "second", StopReason::CostLimit)
            .save_to_file(&file_path)
            .await
            .unwrap();

        let loaded = PartialCheckpoint::load_from_file(&file_path).await.unwrap();
        assert_eq!(loaded.result.content, "second");
        assert_eq!(loaded.result.stop_reason, StopReason::CostLimit);

        // The temp-file-then-rename write must not leave its temporary file behind.
        let entries: Vec<_> = std::fs::read_dir(temp_dir.path()).unwrap().collect();
        assert_eq!(entries.len(), 1);
    }

    #[tokio::test]
    async fn checkpoint_exists() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("exists.json");

        assert!(!PartialCheckpoint::exists(&file_path).await);

        tokio::fs::write(&file_path, "{}").await.unwrap();
        assert!(PartialCheckpoint::exists(&file_path).await);
    }

    #[tokio::test]
    async fn checkpoint_load_nonexistent() {
        let temp_dir = tempfile::tempdir().unwrap();
        let result =
            PartialCheckpoint::load_from_file(temp_dir.path().join("missing.json")).await;
        assert!(matches!(result, Err(NikaError::IoError(_))));
    }

    #[tokio::test]
    async fn checkpoint_load_invalid_json() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("broken.checkpoint.json");
        tokio::fs::write(&file_path, "not json").await.unwrap();

        let result = PartialCheckpoint::load_from_file(&file_path).await;
        assert!(matches!(result, Err(NikaError::SerializationError { .. })));
    }

    #[tokio::test]
    async fn checkpoint_load_rejects_newer_version() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("future.checkpoint.json");

        let mut checkpoint =
            PartialCheckpoint::new("task", 0.5, "content", StopReason::TurnsLimit);
        checkpoint.version = PartialCheckpoint::CURRENT_VERSION + 1;
        checkpoint.save_to_file(&file_path).await.unwrap();

        let result = PartialCheckpoint::load_from_file(&file_path).await;
        assert!(matches!(result, Err(NikaError::ValidationError { .. })));
    }
}