Skip to main content

harn_vm/
checkpoint.rs

//! Checkpoint system for resilient pipeline execution.
//!
//! Provides the `checkpoint`, `checkpoint_get`, `checkpoint_clear`, and
//! `checkpoint_list` builtins. Checkpoints are persisted to
//! `<base_dir>/.harn/checkpoints/<pipeline>.json` and survive pipeline
//! crashes/timeouts. On resume, a pipeline can skip already-processed items
//! by checking `checkpoint_get`.
7
8use std::cell::RefCell;
9use std::collections::BTreeMap;
10use std::path::{Path, PathBuf};
11use std::rc::Rc;
12
13use crate::value::{VmError, VmValue};
14use crate::vm::Vm;
15
16struct CheckpointState {
17    data: BTreeMap<String, serde_json::Value>,
18    path: PathBuf,
19    loaded: bool,
20}
21
22impl CheckpointState {
23    fn new(base_dir: &Path, pipeline_name: &str) -> Self {
24        Self {
25            data: BTreeMap::new(),
26            path: base_dir
27                .join(".harn")
28                .join("checkpoints")
29                .join(format!("{pipeline_name}.json")),
30            loaded: false,
31        }
32    }
33
34    fn ensure_loaded(&mut self) {
35        if self.loaded {
36            return;
37        }
38        self.loaded = true;
39        if let Ok(contents) = std::fs::read_to_string(&self.path) {
40            if let Ok(serde_json::Value::Object(map)) =
41                serde_json::from_str::<serde_json::Value>(&contents)
42            {
43                for (k, v) in map {
44                    self.data.insert(k, v);
45                }
46            }
47        }
48    }
49
50    fn save(&self) -> Result<(), String> {
51        let obj: serde_json::Map<String, serde_json::Value> = self
52            .data
53            .iter()
54            .map(|(k, v)| (k.clone(), v.clone()))
55            .collect();
56        let json = serde_json::to_string_pretty(&serde_json::Value::Object(obj))
57            .map_err(|e| format!("checkpoint save error: {e}"))?;
58        if let Some(parent) = self.path.parent() {
59            std::fs::create_dir_all(parent).map_err(|e| format!("checkpoint mkdir error: {e}"))?;
60        }
61        std::fs::write(&self.path, json).map_err(|e| format!("checkpoint write error: {e}"))?;
62        Ok(())
63    }
64
65    fn get(&mut self, key: &str) -> VmValue {
66        self.ensure_loaded();
67        match self.data.get(key) {
68            Some(v) => json_to_vm(v),
69            None => VmValue::Nil,
70        }
71    }
72
73    fn set(&mut self, key: String, value: serde_json::Value) -> Result<(), String> {
74        self.ensure_loaded();
75        self.data.insert(key, value);
76        self.save()
77    }
78
79    fn clear(&mut self) -> Result<(), String> {
80        self.data.clear();
81        if self.path.exists() {
82            std::fs::remove_file(&self.path).map_err(|e| format!("checkpoint clear error: {e}"))?;
83        }
84        Ok(())
85    }
86
87    fn list(&mut self) -> Vec<String> {
88        self.ensure_loaded();
89        self.data.keys().cloned().collect()
90    }
91}
92
93fn vm_to_json(val: &VmValue) -> serde_json::Value {
94    match val {
95        VmValue::String(s) => serde_json::Value::String(s.to_string()),
96        VmValue::Int(n) => serde_json::json!(*n),
97        VmValue::Float(n) => serde_json::json!(*n),
98        VmValue::Bool(b) => serde_json::Value::Bool(*b),
99        VmValue::Nil => serde_json::Value::Null,
100        VmValue::List(items) => serde_json::Value::Array(items.iter().map(vm_to_json).collect()),
101        VmValue::Dict(map) => {
102            let obj: serde_json::Map<String, serde_json::Value> = map
103                .iter()
104                .map(|(k, v)| (k.clone(), vm_to_json(v)))
105                .collect();
106            serde_json::Value::Object(obj)
107        }
108        _ => serde_json::Value::Null,
109    }
110}
111
112fn json_to_vm(jv: &serde_json::Value) -> VmValue {
113    match jv {
114        serde_json::Value::Null => VmValue::Nil,
115        serde_json::Value::Bool(b) => VmValue::Bool(*b),
116        serde_json::Value::Number(n) => {
117            if let Some(i) = n.as_i64() {
118                VmValue::Int(i)
119            } else {
120                VmValue::Float(n.as_f64().unwrap_or(0.0))
121            }
122        }
123        serde_json::Value::String(s) => VmValue::String(Rc::from(s.as_str())),
124        serde_json::Value::Array(arr) => {
125            VmValue::List(Rc::new(arr.iter().map(json_to_vm).collect()))
126        }
127        serde_json::Value::Object(map) => {
128            let mut m = BTreeMap::new();
129            for (k, v) in map {
130                m.insert(k.clone(), json_to_vm(v));
131            }
132            VmValue::Dict(Rc::new(m))
133        }
134    }
135}
136
137/// Sanitize a pipeline name for use as a filename.
138/// Rejects path traversal attempts and invalid characters.
139fn sanitize_pipeline_name(name: &str) -> String {
140    // Use only the filename component, stripping any directory parts
141    let base = std::path::Path::new(name)
142        .file_name()
143        .and_then(|f| f.to_str())
144        .unwrap_or("default");
145    // Reject empty or dot-only names
146    if base.is_empty() || base == "." || base == ".." {
147        return "default".to_string();
148    }
149    base.to_string()
150}
151
152/// Register checkpoint builtins on a VM.
153///
154/// The pipeline name is used to namespace checkpoint files. If not provided,
155/// defaults to "default".
156pub fn register_checkpoint_builtins(vm: &mut Vm, base_dir: &Path, pipeline_name: &str) {
157    let safe_name = sanitize_pipeline_name(pipeline_name);
158    let state = Rc::new(RefCell::new(CheckpointState::new(base_dir, &safe_name)));
159
160    // checkpoint(key, value) — persist a checkpoint immediately
161    let s = Rc::clone(&state);
162    vm.register_builtin("checkpoint", move |args, _out| {
163        let key = args.first().map(|a| a.display()).unwrap_or_default();
164        let value = args.get(1).unwrap_or(&VmValue::Nil);
165        let json_val = vm_to_json(value);
166        s.borrow_mut()
167            .set(key, json_val)
168            .map_err(VmError::Runtime)?;
169        Ok(VmValue::Nil)
170    });
171
172    // checkpoint_get(key) -> value | nil
173    let s = Rc::clone(&state);
174    vm.register_builtin("checkpoint_get", move |args, _out| {
175        let key = args.first().map(|a| a.display()).unwrap_or_default();
176        Ok(s.borrow_mut().get(&key))
177    });
178
179    // checkpoint_clear() — clear all checkpoints for this pipeline
180    let s = Rc::clone(&state);
181    vm.register_builtin("checkpoint_clear", move |_args, _out| {
182        s.borrow_mut().clear().map_err(VmError::Runtime)?;
183        Ok(VmValue::Nil)
184    });
185
186    // checkpoint_list() -> [key1, key2, ...]
187    let s = Rc::clone(&state);
188    vm.register_builtin("checkpoint_list", move |_args, _out| {
189        let keys = s.borrow_mut().list();
190        Ok(VmValue::List(Rc::new(
191            keys.into_iter()
192                .map(|k| VmValue::String(Rc::from(k)))
193                .collect(),
194        )))
195    });
196}