1use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::Path;
28use std::sync::Arc;
29use std::time::{SystemTime, UNIX_EPOCH};
30
31use crate::error::NikaError;
32
/// Output captured from a task that stopped before finishing, plus usage counters.
///
/// Built with [`PartialResult::new`] and optionally [`PartialResult::with_usage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialResult {
    /// Content accumulated before the stop, exactly as passed to [`PartialResult::new`].
    pub content: String,

    /// Completion estimate; [`PartialResult::new`] clamps it into `0.0..=1.0`.
    pub progress: f64,

    /// One-line preview derived from `content` when the result is constructed.
    pub preview: String,

    /// Why execution stopped.
    pub stop_reason: StopReason,

    /// Number of turns completed; `0` unless set via [`PartialResult::with_usage`].
    pub turns_completed: u32,

    /// Number of tokens used; `0` unless set via [`PartialResult::with_usage`].
    pub tokens_used: u64,

    /// Cost incurred (presumably US dollars, per the field name); `0.0` unless set via
    /// [`PartialResult::with_usage`].
    pub cost_usd: f64,
}
61
62impl PartialResult {
63 pub fn new(content: impl Into<String>, progress: f64, stop_reason: StopReason) -> Self {
65 let content = content.into();
66 let preview = Self::generate_preview(&content);
67
68 Self {
69 content,
70 progress: progress.clamp(0.0, 1.0),
71 preview,
72 stop_reason,
73 turns_completed: 0,
74 tokens_used: 0,
75 cost_usd: 0.0,
76 }
77 }
78
79 pub fn with_usage(mut self, turns: u32, tokens: u64, cost: f64) -> Self {
81 self.turns_completed = turns;
82 self.tokens_used = tokens;
83 self.cost_usd = cost;
84 self
85 }
86
87 fn generate_preview(content: &str) -> String {
89 let first_line = content.lines().next().unwrap_or("");
90 if first_line.len() > 100 {
91 format!("{}...", crate::util::truncate_str(first_line, 97))
92 } else if first_line.len() < content.len() {
93 format!("{}...", first_line)
94 } else {
95 first_line.to_string()
96 }
97 }
98
99 pub fn is_meaningful(&self) -> bool {
101 !self.content.is_empty() && self.progress > 0.0
102 }
103
104 pub fn progress_percent(&self) -> u32 {
106 (self.progress * 100.0).round() as u32
107 }
108}
109
/// Why a task stopped before producing a complete result.
///
/// The first four variants are budget/limit stops (see [`StopReason::is_limit`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
    /// Maximum number of turns reached.
    TurnsLimit,
    /// Token budget exhausted.
    TokensLimit,
    /// Cost budget exhausted.
    CostLimit,
    /// Duration limit reached.
    DurationLimit,
    /// The user asked for the task to stop.
    UserRequested,
    /// An error ended the task.
    Error,
}
130
131impl StopReason {
132 pub fn description(&self) -> &'static str {
134 match self {
135 Self::TurnsLimit => "Maximum turns reached",
136 Self::TokensLimit => "Token budget exhausted",
137 Self::CostLimit => "Cost budget exhausted",
138 Self::DurationLimit => "Duration limit reached",
139 Self::UserRequested => "User requested stop",
140 Self::Error => "Error occurred",
141 }
142 }
143
144 pub fn label(&self) -> &'static str {
146 match self {
147 Self::TurnsLimit => "turns",
148 Self::TokensLimit => "tokens",
149 Self::CostLimit => "cost",
150 Self::DurationLimit => "duration",
151 Self::UserRequested => "user",
152 Self::Error => "error",
153 }
154 }
155
156 pub fn is_limit(&self) -> bool {
158 matches!(
159 self,
160 Self::TurnsLimit | Self::TokensLimit | Self::CostLimit | Self::DurationLimit
161 )
162 }
163}
164
/// Serializable snapshot of an interrupted task, suitable for saving to and loading from disk.
///
/// See [`PartialCheckpoint::save_to_file`] and [`PartialCheckpoint::load_from_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialCheckpoint {
    /// Checkpoint format version.
    ///
    /// [`PartialCheckpoint::load_from_file`] rejects values above
    /// [`PartialCheckpoint::CURRENT_VERSION`].
    pub version: u32,

    /// Identifier of the task this checkpoint belongs to.
    pub task_id: Arc<str>,

    /// Unix timestamp in seconds taken in [`PartialCheckpoint::new`] (`0` if the clock reads
    /// before 1970).
    pub created_at: u64,

    /// Completion estimate, clamped into `0.0..=1.0` by [`PartialCheckpoint::new`].
    pub progress: f64,

    /// Partial output and usage captured at checkpoint time.
    pub result: PartialResult,

    /// Conversation messages in the order they were added via
    /// [`PartialCheckpoint::add_message`].
    pub history: Vec<CheckpointMessage>,

    /// Arbitrary JSON values keyed by name, set via [`PartialCheckpoint::set_context`].
    pub context: HashMap<String, serde_json::Value>,

    /// Provider name, set together with `model` via [`PartialCheckpoint::with_provider`].
    pub provider: Option<String>,

    /// Model name, set together with `provider` via [`PartialCheckpoint::with_provider`].
    pub model: Option<String>,
}
199
/// One message of conversation history stored in a [`PartialCheckpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMessage {
    /// Speaker role, e.g. `"user"` or `"assistant"` (free-form; not validated here).
    pub role: String,
    /// Message text.
    pub content: String,
}
208
209impl PartialCheckpoint {
210 pub const CURRENT_VERSION: u32 = 1;
212
213 pub fn new(
215 task_id: impl Into<Arc<str>>,
216 progress: f64,
217 content: impl Into<String>,
218 stop_reason: StopReason,
219 ) -> Self {
220 let now = SystemTime::now()
221 .duration_since(UNIX_EPOCH)
222 .map(|d| d.as_secs())
223 .unwrap_or(0);
224
225 Self {
226 version: Self::CURRENT_VERSION,
227 task_id: task_id.into(),
228 created_at: now,
229 progress: progress.clamp(0.0, 1.0),
230 result: PartialResult::new(content, progress, stop_reason),
231 history: Vec::new(),
232 context: HashMap::new(),
233 provider: None,
234 model: None,
235 }
236 }
237
238 pub fn add_message(&mut self, role: impl Into<String>, content: impl Into<String>) {
240 self.history.push(CheckpointMessage {
241 role: role.into(),
242 content: content.into(),
243 });
244 }
245
246 pub fn set_context(&mut self, key: impl Into<String>, value: serde_json::Value) {
248 self.context.insert(key.into(), value);
249 }
250
251 pub fn with_provider(mut self, provider: impl Into<String>, model: impl Into<String>) -> Self {
253 self.provider = Some(provider.into());
254 self.model = Some(model.into());
255 self
256 }
257
258 pub fn with_usage(mut self, turns: u32, tokens: u64, cost: f64) -> Self {
260 self.result = self.result.with_usage(turns, tokens, cost);
261 self
262 }
263
264 pub async fn save_to_file(&self, path: impl AsRef<Path>) -> Result<(), NikaError> {
266 let path = path.as_ref();
267
268 if let Some(parent) = path.parent() {
270 tokio::fs::create_dir_all(parent)
271 .await
272 .map_err(NikaError::IoError)?;
273 }
274
275 let json =
277 serde_json::to_string_pretty(self).map_err(|e| NikaError::SerializationError {
278 details: format!("Failed to serialize checkpoint: {}", e),
279 })?;
280
281 let temp_path = path.with_extension("tmp");
283 tokio::fs::write(&temp_path, &json)
284 .await
285 .map_err(NikaError::IoError)?;
286 tokio::fs::rename(&temp_path, path)
287 .await
288 .map_err(NikaError::IoError)?;
289
290 Ok(())
291 }
292
293 pub async fn load_from_file(path: impl AsRef<Path>) -> Result<Self, NikaError> {
295 let path = path.as_ref();
296
297 let json = tokio::fs::read_to_string(path)
298 .await
299 .map_err(NikaError::IoError)?;
300
301 let checkpoint: Self =
302 serde_json::from_str(&json).map_err(|e| NikaError::SerializationError {
303 details: format!("Failed to parse checkpoint: {}", e),
304 })?;
305
306 if checkpoint.version > Self::CURRENT_VERSION {
308 return Err(NikaError::ValidationError {
309 reason: format!(
310 "Checkpoint version {} is newer than supported version {}",
311 checkpoint.version,
312 Self::CURRENT_VERSION
313 ),
314 });
315 }
316
317 Ok(checkpoint)
318 }
319
320 pub async fn exists(path: impl AsRef<Path>) -> bool {
322 tokio::fs::metadata(path.as_ref()).await.is_ok()
323 }
324
325 pub fn filename(task_id: &str) -> String {
327 let safe_id: String = task_id
329 .chars()
330 .map(|c| {
331 if c.is_alphanumeric() || c == '-' || c == '_' {
332 c
333 } else {
334 '_'
335 }
336 })
337 .collect();
338 format!("{}.checkpoint.json", safe_id)
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // PartialResult
    // ---------------------------------------------------------------------

    #[test]
    fn partial_result_new() {
        let result = PartialResult::new("Hello world", 0.5, StopReason::TurnsLimit);
        assert_eq!(result.content, "Hello world");
        assert!((result.progress - 0.5).abs() < 0.0001);
        assert_eq!(result.stop_reason, StopReason::TurnsLimit);
        assert_eq!(result.turns_completed, 0);
        assert_eq!(result.tokens_used, 0);
        assert_eq!(result.cost_usd, 0.0);
    }

    #[test]
    fn partial_result_clamps_progress() {
        let result = PartialResult::new("test", 1.5, StopReason::CostLimit);
        assert!((result.progress - 1.0).abs() < 0.0001);

        let result = PartialResult::new("test", -0.5, StopReason::CostLimit);
        assert!((result.progress - 0.0).abs() < 0.0001);
    }

    #[test]
    fn partial_result_with_usage() {
        let result =
            PartialResult::new("test", 0.75, StopReason::TokensLimit).with_usage(10, 5000, 0.05);

        assert_eq!(result.turns_completed, 10);
        assert_eq!(result.tokens_used, 5000);
        assert!((result.cost_usd - 0.05).abs() < 0.0001);
    }

    #[test]
    fn partial_result_preview_short() {
        let result = PartialResult::new("Short content", 0.5, StopReason::TurnsLimit);
        assert_eq!(result.preview, "Short content");
    }

    #[test]
    fn partial_result_preview_multiline() {
        let result = PartialResult::new(
            "First line\nSecond line\nThird",
            0.5,
            StopReason::TurnsLimit,
        );
        assert_eq!(result.preview, "First line...");
    }

    #[test]
    fn partial_result_preview_long_line() {
        let long_content = "x".repeat(150);
        let result = PartialResult::new(&long_content, 0.5, StopReason::TurnsLimit);
        assert!(result.preview.ends_with("..."));
        assert!(result.preview.len() <= 103);
    }

    #[test]
    fn partial_result_is_meaningful() {
        let meaningful = PartialResult::new("content", 0.5, StopReason::TurnsLimit);
        assert!(meaningful.is_meaningful());

        let empty = PartialResult::new("", 0.5, StopReason::TurnsLimit);
        assert!(!empty.is_meaningful());

        let zero_progress = PartialResult::new("content", 0.0, StopReason::TurnsLimit);
        assert!(!zero_progress.is_meaningful());
    }

    #[test]
    fn partial_result_progress_percent() {
        let result = PartialResult::new("test", 0.654, StopReason::TurnsLimit);
        assert_eq!(result.progress_percent(), 65);

        let full = PartialResult::new("test", 1.0, StopReason::TurnsLimit);
        assert_eq!(full.progress_percent(), 100);
    }

    // ---------------------------------------------------------------------
    // StopReason
    // ---------------------------------------------------------------------

    #[test]
    fn stop_reason_descriptions() {
        assert!(!StopReason::TurnsLimit.description().is_empty());
        assert!(!StopReason::TokensLimit.description().is_empty());
        assert!(!StopReason::CostLimit.description().is_empty());
        assert!(!StopReason::DurationLimit.description().is_empty());
        assert!(!StopReason::UserRequested.description().is_empty());
        assert!(!StopReason::Error.description().is_empty());
    }

    #[test]
    fn stop_reason_labels() {
        assert_eq!(StopReason::TurnsLimit.label(), "turns");
        assert_eq!(StopReason::TokensLimit.label(), "tokens");
        assert_eq!(StopReason::CostLimit.label(), "cost");
        assert_eq!(StopReason::DurationLimit.label(), "duration");
        assert_eq!(StopReason::UserRequested.label(), "user");
        assert_eq!(StopReason::Error.label(), "error");
    }

    #[test]
    fn stop_reason_is_limit() {
        assert!(StopReason::TurnsLimit.is_limit());
        assert!(StopReason::TokensLimit.is_limit());
        assert!(StopReason::CostLimit.is_limit());
        assert!(StopReason::DurationLimit.is_limit());
        assert!(!StopReason::UserRequested.is_limit());
        assert!(!StopReason::Error.is_limit());
    }

    // ---------------------------------------------------------------------
    // PartialCheckpoint
    // ---------------------------------------------------------------------

    #[test]
    fn checkpoint_new() {
        let checkpoint = PartialCheckpoint::new(
            "task-1",
            0.75,
            "Partial content here",
            StopReason::TurnsLimit,
        );

        assert_eq!(checkpoint.version, PartialCheckpoint::CURRENT_VERSION);
        assert_eq!(&*checkpoint.task_id, "task-1");
        assert!((checkpoint.progress - 0.75).abs() < 0.0001);
        assert!(checkpoint.created_at > 0);
        assert!(checkpoint.history.is_empty());
        assert!(checkpoint.context.is_empty());
        assert!(checkpoint.provider.is_none());
        assert!(checkpoint.model.is_none());
    }

    #[test]
    fn checkpoint_add_message() {
        let mut checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit);
        checkpoint.add_message("user", "Hello");
        checkpoint.add_message("assistant", "Hi there");

        assert_eq!(checkpoint.history.len(), 2);
        assert_eq!(checkpoint.history[0].role, "user");
        assert_eq!(checkpoint.history[1].content, "Hi there");
    }

    #[test]
    fn checkpoint_set_context() {
        let mut checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit);
        checkpoint.set_context("key1", serde_json::json!({"nested": "value"}));

        assert!(checkpoint.context.contains_key("key1"));
    }

    #[test]
    fn checkpoint_with_provider() {
        let checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit)
            .with_provider("claude", "claude-sonnet-4-6");

        assert_eq!(checkpoint.provider, Some("claude".to_string()));
        assert_eq!(checkpoint.model, Some("claude-sonnet-4-6".to_string()));
    }

    #[test]
    fn checkpoint_with_usage() {
        let checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit)
            .with_usage(5, 10000, 0.15);

        assert_eq!(checkpoint.result.turns_completed, 5);
        assert_eq!(checkpoint.result.tokens_used, 10000);
        assert!((checkpoint.result.cost_usd - 0.15).abs() < 0.0001);
    }

    #[test]
    fn checkpoint_filename() {
        assert_eq!(
            PartialCheckpoint::filename("task-1"),
            "task-1.checkpoint.json"
        );
        assert_eq!(
            PartialCheckpoint::filename("task/with/slashes"),
            "task_with_slashes.checkpoint.json"
        );
        assert_eq!(
            PartialCheckpoint::filename("task with spaces"),
            "task_with_spaces.checkpoint.json"
        );
        assert_eq!(
            PartialCheckpoint::filename("../escape"),
            "___escape.checkpoint.json"
        );
    }

    #[test]
    fn checkpoint_serialization_roundtrip() {
        let mut checkpoint =
            PartialCheckpoint::new("task-1", 0.75, "Content", StopReason::CostLimit)
                .with_provider("openai", "gpt-4o")
                .with_usage(3, 5000, 0.05);

        checkpoint.add_message("user", "Question");
        checkpoint.add_message("assistant", "Answer");
        checkpoint.set_context("data", serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&checkpoint).unwrap();
        let restored: PartialCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.version, checkpoint.version);
        assert_eq!(restored.task_id, checkpoint.task_id);
        assert_eq!(restored.created_at, checkpoint.created_at);
        assert_eq!(restored.progress, checkpoint.progress);
        assert_eq!(restored.result.content, "Content");
        assert_eq!(restored.result.stop_reason, StopReason::CostLimit);
        assert_eq!(restored.result.turns_completed, 3);
        assert_eq!(restored.result.tokens_used, 5000);
        assert_eq!(restored.history.len(), 2);
        assert_eq!(restored.history[1].role, "assistant");
        assert_eq!(restored.history[1].content, "Answer");
        assert_eq!(
            restored.context.get("data"),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(restored.provider.as_deref(), Some("openai"));
        assert_eq!(restored.model.as_deref(), Some("gpt-4o"));
    }

    // ---------------------------------------------------------------------
    // File persistence
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn checkpoint_save_and_load() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.checkpoint.json");

        let checkpoint =
            PartialCheckpoint::new("test-task", 0.8, "Test content", StopReason::TokensLimit)
                .with_usage(5, 8000, 0.12);

        checkpoint.save_to_file(&file_path).await.unwrap();
        assert!(file_path.exists());

        let loaded = PartialCheckpoint::load_from_file(&file_path).await.unwrap();
        assert_eq!(&*loaded.task_id, "test-task");
        assert!((loaded.progress - 0.8).abs() < 0.0001);
        assert_eq!(loaded.result.tokens_used, 8000);
    }

    #[tokio::test]
    async fn checkpoint_save_overwrites_existing_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("overwrite.checkpoint.json");

        PartialCheckpoint::new("task-1", 0.25, "first", StopReason::TurnsLimit)
            .save_to_file(&file_path)
            .await
            .unwrap();
        PartialCheckpoint::new("task-1", 0.5, "second", StopReason::CostLimit)
            .save_to_file(&file_path)
            .await
            .unwrap();

        let loaded = PartialCheckpoint::load_from_file(&file_path).await.unwrap();
        assert_eq!(loaded.result.content, "second");
        assert_eq!(loaded.result.stop_reason, StopReason::CostLimit);
    }

    #[tokio::test]
    async fn checkpoint_save_creates_parent_dirs_and_leaves_no_temp_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let nested_dir = temp_dir.path().join("a").join("b");
        let file_path = nested_dir.join("nested.checkpoint.json");

        PartialCheckpoint::new("task-1", 0.5, "content", StopReason::TurnsLimit)
            .save_to_file(&file_path)
            .await
            .unwrap();

        // Only the final checkpoint may remain; the temporary file must have been renamed away.
        let names: Vec<std::ffi::OsString> = std::fs::read_dir(&nested_dir)
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(
            names,
            vec![std::ffi::OsString::from("nested.checkpoint.json")]
        );
    }

    #[tokio::test]
    async fn checkpoint_exists() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("exists.json");

        assert!(!PartialCheckpoint::exists(&file_path).await);

        tokio::fs::write(&file_path, "{}").await.unwrap();
        assert!(PartialCheckpoint::exists(&file_path).await);
    }

    #[tokio::test]
    async fn checkpoint_load_nonexistent() {
        let result = PartialCheckpoint::load_from_file("/nonexistent/path.json").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn checkpoint_load_rejects_invalid_json() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("broken.checkpoint.json");
        tokio::fs::write(&file_path, "not json").await.unwrap();

        let result = PartialCheckpoint::load_from_file(&file_path).await;
        assert!(matches!(
            result,
            Err(NikaError::SerializationError { .. })
        ));
    }

    #[tokio::test]
    async fn checkpoint_load_rejects_newer_version() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("future.checkpoint.json");

        let mut checkpoint = PartialCheckpoint::new("task-1", 0.5, "test", StopReason::TurnsLimit);
        checkpoint.version = PartialCheckpoint::CURRENT_VERSION + 1;
        tokio::fs::write(&file_path, serde_json::to_string(&checkpoint).unwrap())
            .await
            .unwrap();

        let result = PartialCheckpoint::load_from_file(&file_path).await;
        assert!(matches!(result, Err(NikaError::ValidationError { .. })));
    }
}