use std::time::{SystemTime, UNIX_EPOCH};
/// Identifier of an ML job.
///
/// [`MlJob::to_json`] writes it as a 32-digit, zero-padded lowercase
/// hexadecimal string.
pub type MlJobId = u128;
/// The category of work an ML job performs.
///
/// Every variant maps to a lowercase text token; see [`MlJobKind::token`]
/// and [`MlJobKind::from_token`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlJobKind {
    /// Token `"train"`.
    Train,
    /// Token `"backfill"`.
    Backfill,
    /// Token `"feature_refresh"`.
    FeatureRefresh,
    /// Token `"drift_compute"`.
    DriftCompute,
}

impl MlJobKind {
    /// Every variant in declaration order. `from_token` searches this list,
    /// so a newly added variant must be appended here as well.
    const ALL: [MlJobKind; 4] = [
        MlJobKind::Train,
        MlJobKind::Backfill,
        MlJobKind::FeatureRefresh,
        MlJobKind::DriftCompute,
    ];

    /// Returns the text token for this kind.
    pub fn token(self) -> &'static str {
        match self {
            Self::Train => "train",
            Self::Backfill => "backfill",
            Self::FeatureRefresh => "feature_refresh",
            Self::DriftCompute => "drift_compute",
        }
    }

    /// Looks up the kind whose token equals `token` exactly (case-sensitive).
    ///
    /// Returns `None` for any string that is not one of the known tokens.
    pub fn from_token(token: &str) -> Option<MlJobKind> {
        // `token()` is the single source of truth for the spelling, so the
        // reverse lookup cannot drift from the forward mapping.
        Self::ALL.iter().copied().find(|kind| kind.token() == token)
    }
}
/// Lifecycle state of an ML job.
///
/// Every variant maps to a lowercase text token; see [`MlJobStatus::token`]
/// and [`MlJobStatus::from_token`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlJobStatus {
    /// Token `"queued"`; not terminal.
    Queued,
    /// Token `"running"`; not terminal.
    Running,
    /// Token `"completed"`; terminal.
    Completed,
    /// Token `"failed"`; terminal.
    Failed,
    /// Token `"cancelled"`; terminal.
    Cancelled,
}

impl MlJobStatus {
    /// Every variant in declaration order. `from_token` searches this list,
    /// so a newly added variant must be appended here as well.
    const ALL: [MlJobStatus; 5] = [
        MlJobStatus::Queued,
        MlJobStatus::Running,
        MlJobStatus::Completed,
        MlJobStatus::Failed,
        MlJobStatus::Cancelled,
    ];

    /// Returns the text token for this status.
    pub fn token(self) -> &'static str {
        match self {
            Self::Queued => "queued",
            Self::Running => "running",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::Cancelled => "cancelled",
        }
    }

    /// Looks up the status whose token equals `token` exactly (case-sensitive).
    ///
    /// Returns `None` for any string that is not one of the known tokens.
    pub fn from_token(token: &str) -> Option<MlJobStatus> {
        Self::ALL
            .iter()
            .copied()
            .find(|status| status.token() == token)
    }

    /// Returns `true` for `Completed`, `Failed` and `Cancelled`, and `false`
    /// for `Queued` and `Running`.
    pub fn is_terminal(self) -> bool {
        match self {
            Self::Queued | Self::Running => false,
            Self::Completed | Self::Failed | Self::Cancelled => true,
        }
    }
}
/// A single ML job record: identity, lifecycle state, timestamps and payloads.
#[derive(Debug, Clone)]
pub struct MlJob {
/// Job identifier; `to_json` writes it as 32 zero-padded hex digits.
pub id: MlJobId,
/// Which kind of work this job performs.
pub kind: MlJobKind,
/// Name of the job's target; serialized under the JSON key `"target"`.
// NOTE(review): presumably the model or dataset the job operates on — confirm.
pub target_name: String,
/// Current lifecycle state; `new` starts every job as `Queued`.
pub status: MlJobStatus,
/// Progress counter; `new` starts it at 0 and `from_json` caps it at 100.
// NOTE(review): presumably a percentage — confirm.
pub progress: u8,
/// Creation time in milliseconds since the Unix epoch, taken from
/// `now_ms()` when the job is built by `new`.
pub created_at_ms: u64,
/// Start timestamp; `new` sets it to `0`, which `duration_ms` treats as
/// "not set".
// NOTE(review): presumably milliseconds on the same clock as
// `created_at_ms` — confirm against whoever sets it.
pub started_at_ms: u64,
/// Finish timestamp; `new` sets it to `0`, which `duration_ms` treats as
/// "not set".
// NOTE(review): presumably milliseconds on the same clock as
// `created_at_ms` — confirm against whoever sets it.
pub finished_at_ms: u64,
/// Optional error text; `None` for a new job and serialized as JSON `null`
/// when absent.
pub error_message: Option<String>,
/// Job specification text supplied to `new`; `to_json` embeds it as a JSON
/// string value, not as a nested object.
// NOTE(review): the name suggests the content is JSON, but nothing in this
// file validates it.
pub spec_json: String,
/// Optional metrics text; `None` for a new job and serialized as JSON
/// `null` when absent, otherwise embedded as a JSON string value.
pub metrics_json: Option<String>,
}
impl MlJob {
pub fn new(id: MlJobId, kind: MlJobKind, target_name: String, spec_json: String) -> Self {
Self {
id,
kind,
target_name,
status: MlJobStatus::Queued,
progress: 0,
created_at_ms: now_ms(),
started_at_ms: 0,
finished_at_ms: 0,
error_message: None,
spec_json,
metrics_json: None,
}
}
pub fn is_terminal(&self) -> bool {
self.status.is_terminal()
}
pub fn duration_ms(&self) -> Option<u64> {
if self.started_at_ms == 0 || self.finished_at_ms == 0 {
return None;
}
self.finished_at_ms.checked_sub(self.started_at_ms)
}
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
pub(crate) fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
use crate::json::{Map, Value as JsonValue};
impl MlJob {
/// Serializes this job into a compact JSON object string.
///
/// The object carries the keys `id`, `kind`, `target`, `status`, `progress`,
/// `created_at`, `started_at`, `finished_at`, `error`, `spec` and `metrics`.
/// `id` is written as 32 zero-padded lowercase hex digits, `kind` and
/// `status` as their text tokens, and the numeric fields as JSON numbers.
/// `spec` and `metrics` are embedded as JSON strings rather than nested
/// objects; an absent `error` or `metrics` becomes `null`.
pub fn to_json(&self) -> String {
    let text = |s: &str| JsonValue::String(s.to_string());
    let text_or_null = |maybe: &Option<String>| match maybe {
        Some(inner) => JsonValue::String(inner.clone()),
        None => JsonValue::Null,
    };

    // Listed in the order the keys are inserted into the map.
    let fields = [
        ("id", JsonValue::String(format!("{:032x}", self.id))),
        ("kind", text(self.kind.token())),
        ("target", text(self.target_name.as_str())),
        ("status", text(self.status.token())),
        ("progress", JsonValue::Number(self.progress as f64)),
        ("created_at", JsonValue::Number(self.created_at_ms as f64)),
        ("started_at", JsonValue::Number(self.started_at_ms as f64)),
        ("finished_at", JsonValue::Number(self.finished_at_ms as f64)),
        ("error", text_or_null(&self.error_message)),
        ("spec", text(self.spec_json.as_str())),
        ("metrics", text_or_null(&self.metrics_json)),
    ];

    let mut obj = Map::new();
    for (key, value) in fields {
        obj.insert(key.to_string(), value);
    }
    JsonValue::Object(obj).to_string_compact()
}
/// Parses a job from JSON text in the shape written by [`MlJob::to_json`].
///
/// Returns `None` if `raw` does not parse, is not a JSON object, is missing
/// any of `id`, `kind`, `target`, `status`, `progress`, `created_at`,
/// `started_at`, `finished_at` or `spec`, or holds an invalid value for one
/// of them:
///
/// * `id` must be exactly 32 ASCII hex digits (no sign, no whitespace);
/// * `kind` / `status` must be known tokens;
/// * the three timestamps must be non-negative integers.
///
/// `progress` is clamped into `0..=100`. `error` and `metrics` are optional:
/// anything other than a JSON string (including a missing key) yields `None`
/// for that field.
pub fn from_json(raw: &str) -> Option<Self> {
    let parsed = crate::json::parse_json(raw).ok()?;
    let value = JsonValue::from(parsed);
    let obj = value.as_object()?;

    // Reads a required timestamp. A negative number is corrupt input; casting
    // it with `as u64` would silently wrap it into a huge bogus value.
    let timestamp_ms = |key: &str| -> Option<u64> {
        let millis = obj.get(key)?.as_i64()?;
        if millis < 0 {
            None
        } else {
            Some(millis as u64)
        }
    };

    let id_hex = obj.get("id")?.as_str()?;
    // `from_str_radix` tolerates a leading `+`, so the length check alone
    // would let "+" followed by 31 digits through. Require pure hex digits.
    if id_hex.len() != 32 || !id_hex.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    let id = u128::from_str_radix(id_hex, 16).ok()?;

    let kind = MlJobKind::from_token(obj.get("kind")?.as_str()?)?;
    let target_name = obj.get("target")?.as_str()?.to_string();
    let status = MlJobStatus::from_token(obj.get("status")?.as_str()?)?;

    // Clamp while still a wide integer, *before* narrowing to `u8`; clamping
    // after the cast lets out-of-range values wrap first (300 became 44).
    let progress = obj.get("progress")?.as_i64()?.clamp(0, 100) as u8;

    let created_at_ms = timestamp_ms("created_at")?;
    let started_at_ms = timestamp_ms("started_at")?;
    let finished_at_ms = timestamp_ms("finished_at")?;

    let error_message = match obj.get("error") {
        Some(JsonValue::String(s)) => Some(s.clone()),
        _ => None,
    };
    let spec_json = obj.get("spec")?.as_str()?.to_string();
    let metrics_json = match obj.get("metrics") {
        Some(JsonValue::String(s)) => Some(s.clone()),
        _ => None,
    };

    Some(MlJob {
        id,
        kind,
        target_name,
        status,
        progress,
        created_at_ms,
        started_at_ms,
        finished_at_ms,
        error_message,
        spec_json,
        metrics_json,
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the freshly queued training job shared by the `MlJob` tests.
    fn sample_job() -> MlJob {
        MlJob::new(1, MlJobKind::Train, String::from("spam"), String::from("{}"))
    }

    /// Every status survives a `token` -> `from_token` round trip.
    #[test]
    fn status_token_round_trips() {
        let statuses = [
            MlJobStatus::Queued,
            MlJobStatus::Running,
            MlJobStatus::Completed,
            MlJobStatus::Failed,
            MlJobStatus::Cancelled,
        ];
        for status in statuses.iter().copied() {
            let token = status.token();
            assert_eq!(MlJobStatus::from_token(token), Some(status));
        }
    }

    /// Every kind survives a `token` -> `from_token` round trip.
    #[test]
    fn kind_token_round_trips() {
        let kinds = [
            MlJobKind::Train,
            MlJobKind::Backfill,
            MlJobKind::FeatureRefresh,
            MlJobKind::DriftCompute,
        ];
        for kind in kinds.iter().copied() {
            let token = kind.token();
            assert_eq!(MlJobKind::from_token(token), Some(kind));
        }
    }

    /// `Queued` and `Running` are live states; the other three are terminal.
    #[test]
    fn only_completed_failed_cancelled_are_terminal() {
        for live in [MlJobStatus::Queued, MlJobStatus::Running] {
            assert!(!live.is_terminal());
        }
        for done in [
            MlJobStatus::Completed,
            MlJobStatus::Failed,
            MlJobStatus::Cancelled,
        ] {
            assert!(done.is_terminal());
        }
    }

    /// A new job is queued, has made no progress and has no run timestamps.
    #[test]
    fn new_job_is_queued_with_zero_timestamps() {
        let job = sample_job();
        assert_eq!(job.status, MlJobStatus::Queued);
        assert_eq!(job.progress, 0);
        assert_eq!(job.started_at_ms, 0);
        assert_eq!(job.finished_at_ms, 0);
        assert!(job.duration_ms().is_none());
    }

    /// A duration exists only once both start and finish are recorded.
    #[test]
    fn duration_requires_both_timestamps() {
        let mut job = sample_job();

        job.started_at_ms = 1000;
        assert!(job.duration_ms().is_none());

        job.finished_at_ms = 1250;
        assert_eq!(job.duration_ms(), Some(250));
    }
}