Skip to main content

reddb_server/storage/ml/
jobs.rs

//! ML job definitions — what a job is, what state it can be in, and how
//! it is serialized for persistence.
//!
//! Jobs are the unit of async work. Training, backfill, feature refresh,
//! and drift computation (the [`MlJobKind`] variants) all flow through the
//! same [`MlJob`] record, so the operator can run `SELECT * FROM ML_JOBS`
//! and see every long-running piece of ML work in one place.
8
9use std::time::{SystemTime, UNIX_EPOCH};
10
/// Identifier of an [`MlJob`].
///
/// Deliberately opaque. The 128-bit width is intended to keep ids
/// collision-free across restarts and replicas without any coordination
/// between the nodes minting them.
pub type MlJobId = u128;
14
/// Category of work an ML job performs.
///
/// The kind decides which worker handler gets dispatched and how the
/// job's `progress` value is to be read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlJobKind {
    /// `CREATE MODEL ... WITH (async = true)` — train a classifier,
    /// symbolic regression, etc.
    Train,
    /// `ALTER EMBEDDING COLUMN ... WITH BACKFILL = BACKGROUND` —
    /// re-embed existing rows under a new model.
    Backfill,
    /// `CREATE FEATURE ...` — materialise the bitemporal feature log
    /// from the source query.
    FeatureRefresh,
    /// Post-hoc drift computation over a window of recent writes.
    DriftCompute,
}

impl MlJobKind {
    /// Stable lowercase token identifying this kind, e.g. the value
    /// stored in the `"kind"` field of a serialised job.
    pub fn token(self) -> &'static str {
        match self {
            Self::Train => "train",
            Self::Backfill => "backfill",
            Self::FeatureRefresh => "feature_refresh",
            Self::DriftCompute => "drift_compute",
        }
    }

    /// Inverse of [`Self::token`].
    ///
    /// Matching is exact and case-sensitive. Any unrecognised token
    /// yields `None`.
    pub fn from_token(token: &str) -> Option<MlJobKind> {
        let kind = match token {
            "train" => Self::Train,
            "backfill" => Self::Backfill,
            "feature_refresh" => Self::FeatureRefresh,
            "drift_compute" => Self::DriftCompute,
            _ => return None,
        };
        Some(kind)
    }
}
52
/// Lifecycle state of an ML job.
///
/// `Queued` and `Running` are live states. `Completed`, `Failed` and
/// `Cancelled` are terminal: workers must not mutate a record once it
/// has reached one of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlJobStatus {
    Queued,
    Running,
    Completed,
    Failed,
    Cancelled,
}

impl MlJobStatus {
    /// Stable lowercase token identifying this status for persistence.
    pub fn token(self) -> &'static str {
        match self {
            Self::Queued => "queued",
            Self::Running => "running",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::Cancelled => "cancelled",
        }
    }

    /// Inverse of [`Self::token`].
    ///
    /// Matching is exact and case-sensitive. Any unrecognised token
    /// yields `None`.
    pub fn from_token(token: &str) -> Option<MlJobStatus> {
        let status = match token {
            "queued" => Self::Queued,
            "running" => Self::Running,
            "completed" => Self::Completed,
            "failed" => Self::Failed,
            "cancelled" => Self::Cancelled,
            _ => return None,
        };
        Some(status)
    }

    /// `true` for `Completed`, `Failed` and `Cancelled`; `false` otherwise.
    pub fn is_terminal(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Queued | Self::Running => false,
            Self::Completed | Self::Failed | Self::Cancelled => true,
        }
    }
}
94
95/// Persistent state of a single ML job.
96///
97/// Everything the operator needs to inspect `SELECT * FROM ML_JOBS`
98/// lives here. `spec_json` carries kind-specific parameters (which
99/// algorithm, which features, which hyperparameters); workers parse
100/// it themselves so the registry stays schema-free.
101#[derive(Debug, Clone)]
102pub struct MlJob {
103    pub id: MlJobId,
104    pub kind: MlJobKind,
105    /// Name of the model / feature / embedding column the job mutates.
106    pub target_name: String,
107    pub status: MlJobStatus,
108    /// 0..=100.
109    pub progress: u8,
110    /// Epoch millis. `0` when not yet scheduled / finished.
111    pub created_at_ms: u64,
112    pub started_at_ms: u64,
113    pub finished_at_ms: u64,
114    /// Populated on `Failed`.
115    pub error_message: Option<String>,
116    /// Free-form payload describing the job — parsed by the worker.
117    pub spec_json: String,
118    /// Free-form metrics (accuracy, f1, etc.) — written by the worker
119    /// before it transitions to `Completed`.
120    pub metrics_json: Option<String>,
121}
122
123impl MlJob {
124    pub fn new(id: MlJobId, kind: MlJobKind, target_name: String, spec_json: String) -> Self {
125        Self {
126            id,
127            kind,
128            target_name,
129            status: MlJobStatus::Queued,
130            progress: 0,
131            created_at_ms: now_ms(),
132            started_at_ms: 0,
133            finished_at_ms: 0,
134            error_message: None,
135            spec_json,
136            metrics_json: None,
137        }
138    }
139
140    /// True once the job has reached a terminal status.
141    pub fn is_terminal(&self) -> bool {
142        self.status.is_terminal()
143    }
144
145    /// Duration between `started_at` and `finished_at`, if both are
146    /// set. `None` while the job is still running or never started.
147    pub fn duration_ms(&self) -> Option<u64> {
148        if self.started_at_ms == 0 || self.finished_at_ms == 0 {
149            return None;
150        }
151        self.finished_at_ms.checked_sub(self.started_at_ms)
152    }
153}
154
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a moment before the epoch.
pub(crate) fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // u128 -> u64 narrowing; millis overflow u64 only after ~584M years.
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
161
// ---- JSON serialisation --------------------------------------------------
//
// Jobs are persisted as a small JSON object per row. The schema is:
//
// {
//   "id":           "0000...deadbeef"  (32 lowercase hex digits, zero-padded, no "0x" prefix),
//   "kind":         "train" | "backfill" | ...,
//   "target":       "<name>",
//   "status":       "queued" | "running" | ...,
//   "progress":     0..=100,
//   "created_at":   <u64 ms>,
//   "started_at":   <u64 ms>,
//   "finished_at":  <u64 ms>,
//   "error":        "<msg>" | null,
//   "spec":         "<json string, opaque>",
//   "metrics":      "<json string, opaque>" | null
// }
//
// `spec` and `metrics` are quoted JSON strings (not nested objects)
// so the registry layer stays schema-free — the worker owns the
// payload shape.
183
184use crate::json::{Map, Value as JsonValue};
185
186impl MlJob {
187    /// Encode as a compact JSON object suitable for KV storage.
188    pub fn to_json(&self) -> String {
189        let mut obj = Map::new();
190        obj.insert(
191            "id".to_string(),
192            JsonValue::String(format!("{:032x}", self.id)),
193        );
194        obj.insert(
195            "kind".to_string(),
196            JsonValue::String(self.kind.token().to_string()),
197        );
198        obj.insert(
199            "target".to_string(),
200            JsonValue::String(self.target_name.clone()),
201        );
202        obj.insert(
203            "status".to_string(),
204            JsonValue::String(self.status.token().to_string()),
205        );
206        obj.insert(
207            "progress".to_string(),
208            JsonValue::Number(self.progress as f64),
209        );
210        obj.insert(
211            "created_at".to_string(),
212            JsonValue::Number(self.created_at_ms as f64),
213        );
214        obj.insert(
215            "started_at".to_string(),
216            JsonValue::Number(self.started_at_ms as f64),
217        );
218        obj.insert(
219            "finished_at".to_string(),
220            JsonValue::Number(self.finished_at_ms as f64),
221        );
222        obj.insert(
223            "error".to_string(),
224            match &self.error_message {
225                Some(s) => JsonValue::String(s.clone()),
226                None => JsonValue::Null,
227            },
228        );
229        obj.insert(
230            "spec".to_string(),
231            JsonValue::String(self.spec_json.clone()),
232        );
233        obj.insert(
234            "metrics".to_string(),
235            match &self.metrics_json {
236                Some(s) => JsonValue::String(s.clone()),
237                None => JsonValue::Null,
238            },
239        );
240        JsonValue::Object(obj).to_string_compact()
241    }
242
243    /// Inverse of [`Self::to_json`]. Returns `None` on any field
244    /// mismatch — callers either skip the record or surface a
245    /// persistence-corruption error.
246    pub fn from_json(raw: &str) -> Option<Self> {
247        let parsed = crate::json::parse_json(raw).ok()?;
248        let value = JsonValue::from(parsed);
249        let obj = value.as_object()?;
250        let id_hex = obj.get("id")?.as_str()?;
251        if id_hex.len() != 32 {
252            return None;
253        }
254        let id = u128::from_str_radix(id_hex, 16).ok()?;
255        let kind = MlJobKind::from_token(obj.get("kind")?.as_str()?)?;
256        let target = obj.get("target")?.as_str()?.to_string();
257        let status = MlJobStatus::from_token(obj.get("status")?.as_str()?)?;
258        let progress = obj.get("progress")?.as_i64()? as u8;
259        let created_at = obj.get("created_at")?.as_i64()? as u64;
260        let started_at = obj.get("started_at")?.as_i64()? as u64;
261        let finished_at = obj.get("finished_at")?.as_i64()? as u64;
262        let error_message = match obj.get("error") {
263            Some(JsonValue::String(s)) => Some(s.clone()),
264            _ => None,
265        };
266        let spec_json = obj.get("spec")?.as_str()?.to_string();
267        let metrics_json = match obj.get("metrics") {
268            Some(JsonValue::String(s)) => Some(s.clone()),
269            _ => None,
270        };
271        Some(MlJob {
272            id,
273            kind,
274            target_name: target,
275            status,
276            progress: progress.min(100),
277            created_at_ms: created_at,
278            started_at_ms: started_at,
279            finished_at_ms: finished_at,
280            error_message,
281            spec_json,
282            metrics_json,
283        })
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Every status variant, in declaration order.
    const ALL_STATUSES: [MlJobStatus; 5] = [
        MlJobStatus::Queued,
        MlJobStatus::Running,
        MlJobStatus::Completed,
        MlJobStatus::Failed,
        MlJobStatus::Cancelled,
    ];

    /// Every kind variant, in declaration order.
    const ALL_KINDS: [MlJobKind; 4] = [
        MlJobKind::Train,
        MlJobKind::Backfill,
        MlJobKind::FeatureRefresh,
        MlJobKind::DriftCompute,
    ];

    /// A freshly queued training job used by the timestamp tests.
    fn fresh_job() -> MlJob {
        MlJob::new(1, MlJobKind::Train, "spam".into(), "{}".into())
    }

    #[test]
    fn status_token_round_trips() {
        ALL_STATUSES.iter().for_each(|&status| {
            assert_eq!(MlJobStatus::from_token(status.token()), Some(status));
        });
    }

    #[test]
    fn kind_token_round_trips() {
        ALL_KINDS.iter().for_each(|&kind| {
            assert_eq!(MlJobKind::from_token(kind.token()), Some(kind));
        });
    }

    #[test]
    fn only_completed_failed_cancelled_are_terminal() {
        let expectations = [
            (MlJobStatus::Queued, false),
            (MlJobStatus::Running, false),
            (MlJobStatus::Completed, true),
            (MlJobStatus::Failed, true),
            (MlJobStatus::Cancelled, true),
        ];
        for &(status, terminal) in expectations.iter() {
            assert_eq!(status.is_terminal(), terminal, "status {:?}", status);
        }
    }

    #[test]
    fn new_job_is_queued_with_zero_timestamps() {
        let job = fresh_job();
        assert_eq!(job.status, MlJobStatus::Queued);
        assert_eq!(job.progress, 0);
        assert_eq!((job.started_at_ms, job.finished_at_ms), (0, 0));
        assert!(job.duration_ms().is_none());
    }

    #[test]
    fn duration_requires_both_timestamps() {
        let mut job = fresh_job();
        job.started_at_ms = 1000;
        assert!(job.duration_ms().is_none());
        job.finished_at_ms = 1250;
        assert_eq!(job.duration_ms(), Some(250));
    }
}