// omnigraph/db/run_registry.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use arrow_array::{
7    Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray, UInt64Array,
8};
9use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
10use futures::TryStreamExt;
11use lance::Dataset;
12use lance::dataset::{WriteMode, WriteParams};
13use lance_file::version::LanceFileVersion;
14
15use crate::error::{OmniError, Result};
16
/// Directory name, relative to the graph root, of the Lance dataset holding run records.
const GRAPH_RUNS_DIR: &str = "_graph_runs.lance";
/// Directory name, relative to the graph root, of the Lance dataset holding run-to-actor rows.
const GRAPH_RUN_ACTORS_DIR: &str = "_graph_run_actors.lance";
/// Prefix placed in front of a run id to form the name of that run's internal branch.
pub(crate) const INTERNAL_RUN_BRANCH_PREFIX: &str = "__run__";
20
/// Identifier of a single run, kept as an opaque string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RunId(String);

impl RunId {
    /// Wraps `id` as a run identifier. The value is stored verbatim; no validation is done.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for RunId {
    /// Formats the identifier exactly like the wrapped string, so width, fill and
    /// precision flags behave as they do for `String`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fully-qualified call on purpose: method syntax (`self.0.fmt(f)`) only resolves
        // when a formatting trait is in scope, and this file imports just the `fmt` module.
        fmt::Display::fmt(&self.0, f)
    }
}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq)]
41pub enum RunStatus {
42    Running,
43    Published,
44    Failed,
45    Aborted,
46}
47
48impl RunStatus {
49    pub fn as_str(self) -> &'static str {
50        match self {
51            RunStatus::Running => "running",
52            RunStatus::Published => "published",
53            RunStatus::Failed => "failed",
54            RunStatus::Aborted => "aborted",
55        }
56    }
57
58    fn parse(value: &str) -> Result<Self> {
59        match value {
60            "running" => Ok(Self::Running),
61            "published" => Ok(Self::Published),
62            "failed" => Ok(Self::Failed),
63            "aborted" => Ok(Self::Aborted),
64            other => Err(OmniError::manifest(format!(
65                "invalid run status '{}'",
66                other
67            ))),
68        }
69    }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq)]
73pub struct RunRecord {
74    pub run_id: RunId,
75    pub target_branch: String,
76    pub run_branch: String,
77    pub base_snapshot_id: String,
78    pub base_manifest_version: u64,
79    pub operation_hash: Option<String>,
80    pub actor_id: Option<String>,
81    pub status: RunStatus,
82    pub published_snapshot_id: Option<String>,
83    pub created_at: i64,
84    pub updated_at: i64,
85}
86
87impl RunRecord {
88    pub fn new(
89        target_branch: impl Into<String>,
90        base_snapshot_id: impl Into<String>,
91        base_manifest_version: u64,
92        operation_hash: Option<String>,
93        actor_id: Option<String>,
94    ) -> Result<Self> {
95        let now = now_micros()?;
96        let run_id = RunId::new(ulid::Ulid::new().to_string());
97        Ok(Self {
98            run_branch: internal_run_branch_name(&run_id),
99            run_id,
100            target_branch: target_branch.into(),
101            base_snapshot_id: base_snapshot_id.into(),
102            base_manifest_version,
103            operation_hash,
104            actor_id,
105            status: RunStatus::Running,
106            published_snapshot_id: None,
107            created_at: now,
108            updated_at: now,
109        })
110    }
111
112    pub fn with_status(
113        &self,
114        status: RunStatus,
115        published_snapshot_id: Option<String>,
116    ) -> Result<Self> {
117        Ok(Self {
118            run_id: self.run_id.clone(),
119            target_branch: self.target_branch.clone(),
120            run_branch: self.run_branch.clone(),
121            base_snapshot_id: self.base_snapshot_id.clone(),
122            base_manifest_version: self.base_manifest_version,
123            operation_hash: self.operation_hash.clone(),
124            actor_id: self.actor_id.clone(),
125            status,
126            published_snapshot_id,
127            created_at: self.created_at,
128            updated_at: now_micros()?,
129        })
130    }
131}
132
133pub struct RunRegistry {
134    dataset: Dataset,
135    actor_dataset: Option<Dataset>,
136    latest_by_id: HashMap<String, RunRecord>,
137    actor_by_run_id: HashMap<String, String>,
138    root_uri: String,
139}
140
141impl RunRegistry {
142    pub async fn init(root_uri: &str) -> Result<Self> {
143        let uri = graph_runs_uri(root_uri);
144        let batch = RecordBatch::new_empty(run_registry_schema());
145        let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
146        let params = WriteParams {
147            mode: WriteMode::Create,
148            enable_stable_row_ids: true,
149            data_storage_version: Some(LanceFileVersion::V2_2),
150            ..Default::default()
151        };
152        let dataset = Dataset::write(reader, &uri as &str, Some(params))
153            .await
154            .map_err(|e| OmniError::Lance(e.to_string()))?;
155        let actor_dataset = create_run_actor_dataset(root_uri).await?;
156        Ok(Self {
157            dataset,
158            actor_dataset: Some(actor_dataset),
159            latest_by_id: HashMap::new(),
160            actor_by_run_id: HashMap::new(),
161            root_uri: root_uri.to_string(),
162        })
163    }
164
165    pub async fn open(root_uri: &str) -> Result<Self> {
166        let dataset = Dataset::open(&graph_runs_uri(root_uri))
167            .await
168            .map_err(|e| OmniError::Lance(e.to_string()))?;
169        let actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
170        let actor_by_run_id = match &actor_dataset {
171            Some(dataset) => load_run_actor_cache(dataset).await?,
172            None => HashMap::new(),
173        };
174        let latest_by_id = load_run_cache(&dataset, &actor_by_run_id).await?;
175        Ok(Self {
176            dataset,
177            actor_dataset,
178            latest_by_id,
179            actor_by_run_id,
180            root_uri: root_uri.to_string(),
181        })
182    }
183
184    pub async fn refresh(&mut self, root_uri: &str) -> Result<()> {
185        self.dataset = Dataset::open(&graph_runs_uri(root_uri))
186            .await
187            .map_err(|e| OmniError::Lance(e.to_string()))?;
188        self.actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
189        self.actor_by_run_id = match &self.actor_dataset {
190            Some(dataset) => load_run_actor_cache(dataset).await?,
191            None => HashMap::new(),
192        };
193        self.latest_by_id = load_run_cache(&self.dataset, &self.actor_by_run_id).await?;
194        self.root_uri = root_uri.to_string();
195        Ok(())
196    }
197
198    pub async fn append_record(&mut self, record: &RunRecord) -> Result<()> {
199        let batch = runs_to_batch(&[record.clone()])?;
200        let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
201        let mut ds = self.dataset.clone();
202        ds.append(reader, None)
203            .await
204            .map_err(|e| OmniError::Lance(e.to_string()))?;
205        self.dataset = ds;
206        if let Some(actor_id) = &record.actor_id {
207            self.append_actor(record.run_id.as_str(), actor_id).await?;
208        }
209        let mut record = record.clone();
210        if record.actor_id.is_none() {
211            record.actor_id = self.actor_by_run_id.get(record.run_id.as_str()).cloned();
212        }
213        merge_latest_run(&mut self.latest_by_id, record);
214        Ok(())
215    }
216
217    pub async fn get_run(&self, run_id: &RunId) -> Result<Option<RunRecord>> {
218        Ok(self.latest_by_id.get(run_id.as_str()).cloned())
219    }
220
221    pub async fn list_runs(&self) -> Result<Vec<RunRecord>> {
222        self.load_runs().await
223    }
224
225    pub async fn load_runs(&self) -> Result<Vec<RunRecord>> {
226        let mut runs = self.latest_by_id.values().cloned().collect::<Vec<_>>();
227        runs.sort_by(|a, b| {
228            a.created_at
229                .cmp(&b.created_at)
230                .then_with(|| a.run_id.as_str().cmp(b.run_id.as_str()))
231        });
232        Ok(runs)
233    }
234
235    async fn append_actor(&mut self, run_id: &str, actor_id: &str) -> Result<()> {
236        if self
237            .actor_by_run_id
238            .get(run_id)
239            .is_some_and(|existing| existing == actor_id)
240        {
241            return Ok(());
242        }
243
244        let record = RunActorRecord {
245            run_id: run_id.to_string(),
246            actor_id: actor_id.to_string(),
247            created_at: now_micros()?,
248        };
249        let batch = run_actors_to_batch(&[record])?;
250        let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
251        let mut dataset = match self.actor_dataset.take() {
252            Some(dataset) => dataset,
253            None => create_run_actor_dataset(&self.root_uri).await?,
254        };
255        dataset
256            .append(reader, None)
257            .await
258            .map_err(|e| OmniError::Lance(e.to_string()))?;
259        self.actor_by_run_id
260            .insert(run_id.to_string(), actor_id.to_string());
261        self.actor_dataset = Some(dataset);
262        Ok(())
263    }
264}
265
266pub(crate) fn is_internal_run_branch(name: &str) -> bool {
267    name.trim_start_matches('/')
268        .starts_with(INTERNAL_RUN_BRANCH_PREFIX)
269}
270
271pub(crate) fn internal_run_branch_name(run_id: &RunId) -> String {
272    format!("{}{}", INTERNAL_RUN_BRANCH_PREFIX, run_id.as_str())
273}
274
275pub(crate) fn graph_runs_uri(root_uri: &str) -> String {
276    format!("{}/{}", root_uri.trim_end_matches('/'), GRAPH_RUNS_DIR)
277}
278
279fn graph_run_actors_uri(root_uri: &str) -> String {
280    format!(
281        "{}/{}",
282        root_uri.trim_end_matches('/'),
283        GRAPH_RUN_ACTORS_DIR
284    )
285}
286
287fn run_registry_schema() -> SchemaRef {
288    Arc::new(Schema::new(vec![
289        Field::new("run_id", DataType::Utf8, false),
290        Field::new("target_branch", DataType::Utf8, false),
291        Field::new("run_branch", DataType::Utf8, false),
292        Field::new("base_snapshot_id", DataType::Utf8, false),
293        Field::new("base_manifest_version", DataType::UInt64, false),
294        Field::new("operation_hash", DataType::Utf8, true),
295        Field::new("status", DataType::Utf8, false),
296        Field::new("published_snapshot_id", DataType::Utf8, true),
297        Field::new(
298            "created_at",
299            DataType::Timestamp(TimeUnit::Microsecond, None),
300            false,
301        ),
302        Field::new(
303            "updated_at",
304            DataType::Timestamp(TimeUnit::Microsecond, None),
305            false,
306        ),
307    ]))
308}
309
310fn run_actor_schema() -> SchemaRef {
311    Arc::new(Schema::new(vec![
312        Field::new("run_id", DataType::Utf8, false),
313        Field::new("actor_id", DataType::Utf8, false),
314        Field::new(
315            "created_at",
316            DataType::Timestamp(TimeUnit::Microsecond, None),
317            false,
318        ),
319    ]))
320}
321
322async fn create_run_actor_dataset(root_uri: &str) -> Result<Dataset> {
323    let batch = RecordBatch::new_empty(run_actor_schema());
324    let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
325    let params = WriteParams {
326        mode: WriteMode::Create,
327        enable_stable_row_ids: true,
328        data_storage_version: Some(LanceFileVersion::V2_2),
329        ..Default::default()
330    };
331    Dataset::write(
332        reader,
333        &graph_run_actors_uri(root_uri) as &str,
334        Some(params),
335    )
336    .await
337    .map_err(|e| OmniError::Lance(e.to_string()))
338}
339
340async fn load_run_cache(
341    dataset: &Dataset,
342    actor_by_run_id: &HashMap<String, String>,
343) -> Result<HashMap<String, RunRecord>> {
344    let batches: Vec<RecordBatch> = dataset
345        .scan()
346        .try_into_stream()
347        .await
348        .map_err(|e| OmniError::Lance(e.to_string()))?
349        .try_collect()
350        .await
351        .map_err(|e| OmniError::Lance(e.to_string()))?;
352
353    let mut latest_by_id = HashMap::new();
354    for mut record in load_runs_from_batches(&batches)? {
355        record.actor_id = actor_by_run_id.get(record.run_id.as_str()).cloned();
356        merge_latest_run(&mut latest_by_id, record);
357    }
358    Ok(latest_by_id)
359}
360
361async fn load_run_actor_cache(dataset: &Dataset) -> Result<HashMap<String, String>> {
362    let batches: Vec<RecordBatch> = dataset
363        .scan()
364        .try_into_stream()
365        .await
366        .map_err(|e| OmniError::Lance(e.to_string()))?
367        .try_collect()
368        .await
369        .map_err(|e| OmniError::Lance(e.to_string()))?;
370
371    let mut actors = HashMap::new();
372    for batch in batches {
373        let run_ids = string_column(&batch, "run_id", "run actor registry")?;
374        let actor_ids = string_column(&batch, "actor_id", "run actor registry")?;
375        for row in 0..batch.num_rows() {
376            actors.insert(
377                run_ids.value(row).to_string(),
378                actor_ids.value(row).to_string(),
379            );
380        }
381    }
382    Ok(actors)
383}
384
385fn load_runs_from_batches(batches: &[RecordBatch]) -> Result<Vec<RunRecord>> {
386    let mut runs = Vec::new();
387    for batch in batches {
388        let run_ids = string_column(batch, "run_id", "run registry")?;
389        let target_branches = string_column(batch, "target_branch", "run registry")?;
390        let run_branches = string_column(batch, "run_branch", "run registry")?;
391        let base_snapshot_ids = string_column(batch, "base_snapshot_id", "run registry")?;
392        let base_manifest_versions = u64_column(batch, "base_manifest_version", "run registry")?;
393        let operation_hashes = string_column(batch, "operation_hash", "run registry")?;
394        let statuses = string_column(batch, "status", "run registry")?;
395        let published_snapshot_ids = string_column(batch, "published_snapshot_id", "run registry")?;
396        let created_ats = timestamp_micros_column(batch, "created_at", "run registry")?;
397        let updated_ats = timestamp_micros_column(batch, "updated_at", "run registry")?;
398
399        for row in 0..batch.num_rows() {
400            runs.push(RunRecord {
401                run_id: RunId::new(run_ids.value(row)),
402                target_branch: target_branches.value(row).to_string(),
403                run_branch: run_branches.value(row).to_string(),
404                base_snapshot_id: base_snapshot_ids.value(row).to_string(),
405                base_manifest_version: base_manifest_versions.value(row),
406                operation_hash: if operation_hashes.is_null(row) {
407                    None
408                } else {
409                    Some(operation_hashes.value(row).to_string())
410                },
411                actor_id: None,
412                status: RunStatus::parse(statuses.value(row))?,
413                published_snapshot_id: if published_snapshot_ids.is_null(row) {
414                    None
415                } else {
416                    Some(published_snapshot_ids.value(row).to_string())
417                },
418                created_at: created_ats.value(row),
419                updated_at: updated_ats.value(row),
420            });
421        }
422    }
423    Ok(runs)
424}
425
426fn merge_latest_run(latest_by_id: &mut HashMap<String, RunRecord>, record: RunRecord) {
427    match latest_by_id.get(record.run_id.as_str()) {
428        Some(existing)
429            if existing.updated_at > record.updated_at
430                || (existing.updated_at == record.updated_at
431                    && existing.created_at >= record.created_at) => {}
432        _ => {
433            latest_by_id.insert(record.run_id.as_str().to_string(), record);
434        }
435    }
436}
437
438fn string_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a StringArray> {
439    batch
440        .column_by_name(name)
441        .ok_or_else(|| {
442            OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
443        })?
444        .as_any()
445        .downcast_ref::<StringArray>()
446        .ok_or_else(|| {
447            OmniError::manifest_internal(format!("{context} column '{name}' is not Utf8"))
448        })
449}
450
451fn u64_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a UInt64Array> {
452    batch
453        .column_by_name(name)
454        .ok_or_else(|| {
455            OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
456        })?
457        .as_any()
458        .downcast_ref::<UInt64Array>()
459        .ok_or_else(|| {
460            OmniError::manifest_internal(format!("{context} column '{name}' is not UInt64"))
461        })
462}
463
464fn timestamp_micros_column<'a>(
465    batch: &'a RecordBatch,
466    name: &str,
467    context: &str,
468) -> Result<&'a TimestampMicrosecondArray> {
469    batch
470        .column_by_name(name)
471        .ok_or_else(|| {
472            OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
473        })?
474        .as_any()
475        .downcast_ref::<TimestampMicrosecondArray>()
476        .ok_or_else(|| {
477            OmniError::manifest_internal(format!(
478                "{context} column '{name}' is not Timestamp(Microsecond)"
479            ))
480        })
481}
482
483fn runs_to_batch(records: &[RunRecord]) -> Result<RecordBatch> {
484    let run_ids: Vec<&str> = records
485        .iter()
486        .map(|record| record.run_id.as_str())
487        .collect();
488    let target_branches: Vec<&str> = records
489        .iter()
490        .map(|record| record.target_branch.as_str())
491        .collect();
492    let run_branches: Vec<&str> = records
493        .iter()
494        .map(|record| record.run_branch.as_str())
495        .collect();
496    let base_snapshot_ids: Vec<&str> = records
497        .iter()
498        .map(|record| record.base_snapshot_id.as_str())
499        .collect();
500    let base_manifest_versions: Vec<u64> = records
501        .iter()
502        .map(|record| record.base_manifest_version)
503        .collect();
504    let operation_hashes: Vec<Option<&str>> = records
505        .iter()
506        .map(|record| record.operation_hash.as_deref())
507        .collect();
508    let statuses: Vec<&str> = records
509        .iter()
510        .map(|record| record.status.as_str())
511        .collect();
512    let published_snapshot_ids: Vec<Option<&str>> = records
513        .iter()
514        .map(|record| record.published_snapshot_id.as_deref())
515        .collect();
516    let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
517    let updated_ats: Vec<i64> = records.iter().map(|record| record.updated_at).collect();
518
519    RecordBatch::try_new(
520        run_registry_schema(),
521        vec![
522            Arc::new(StringArray::from(run_ids)),
523            Arc::new(StringArray::from(target_branches)),
524            Arc::new(StringArray::from(run_branches)),
525            Arc::new(StringArray::from(base_snapshot_ids)),
526            Arc::new(UInt64Array::from(base_manifest_versions)),
527            Arc::new(StringArray::from(operation_hashes)),
528            Arc::new(StringArray::from(statuses)),
529            Arc::new(StringArray::from(published_snapshot_ids)),
530            Arc::new(TimestampMicrosecondArray::from(created_ats)),
531            Arc::new(TimestampMicrosecondArray::from(updated_ats)),
532        ],
533    )
534    .map_err(|e| OmniError::Lance(e.to_string()))
535}
536
537#[derive(Debug, Clone, PartialEq, Eq)]
538struct RunActorRecord {
539    run_id: String,
540    actor_id: String,
541    created_at: i64,
542}
543
544fn run_actors_to_batch(records: &[RunActorRecord]) -> Result<RecordBatch> {
545    let run_ids: Vec<&str> = records
546        .iter()
547        .map(|record| record.run_id.as_str())
548        .collect();
549    let actor_ids: Vec<&str> = records
550        .iter()
551        .map(|record| record.actor_id.as_str())
552        .collect();
553    let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
554
555    RecordBatch::try_new(
556        run_actor_schema(),
557        vec![
558            Arc::new(StringArray::from(run_ids)),
559            Arc::new(StringArray::from(actor_ids)),
560            Arc::new(TimestampMicrosecondArray::from(created_ats)),
561        ],
562    )
563    .map_err(|e| OmniError::Lance(e.to_string()))
564}
565
566fn now_micros() -> Result<i64> {
567    let duration = SystemTime::now()
568        .duration_since(UNIX_EPOCH)
569        .map_err(|e| OmniError::manifest(format!("system clock error: {}", e)))?;
570    Ok(duration.as_micros() as i64)
571}
572
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// A batch whose `run_id` column is UInt64 instead of Utf8 must be rejected
    /// with an error that names the offending column.
    #[test]
    fn load_runs_from_batches_returns_error_for_bad_schema() {
        let micros = || DataType::Timestamp(TimeUnit::Microsecond, None);
        // Same layout as the registry schema except for the type of `run_id`.
        let schema = Schema::new(vec![
            Field::new("run_id", DataType::UInt64, false),
            Field::new("target_branch", DataType::Utf8, false),
            Field::new("run_branch", DataType::Utf8, false),
            Field::new("base_snapshot_id", DataType::Utf8, false),
            Field::new("base_manifest_version", DataType::UInt64, false),
            Field::new("operation_hash", DataType::Utf8, true),
            Field::new("status", DataType::Utf8, false),
            Field::new("published_snapshot_id", DataType::Utf8, true),
            Field::new("created_at", micros(), false),
            Field::new("updated_at", micros(), false),
        ]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(UInt64Array::from(vec![1_u64])),
            Arc::new(StringArray::from(vec!["main"])),
            Arc::new(StringArray::from(vec!["__run__1"])),
            Arc::new(StringArray::from(vec!["snap-1"])),
            Arc::new(UInt64Array::from(vec![1_u64])),
            Arc::new(StringArray::from(vec![None::<&str>])),
            Arc::new(StringArray::from(vec!["running"])),
            Arc::new(StringArray::from(vec![None::<&str>])),
            Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
            Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
        ];
        let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();

        let err = load_runs_from_batches(&[batch]).unwrap_err();
        assert!(err.to_string().contains("run_id"));
    }
}