1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use arrow_array::{
7 Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray, UInt64Array,
8};
9use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
10use futures::TryStreamExt;
11use lance::Dataset;
12use lance::dataset::{WriteMode, WriteParams};
13use lance_file::version::LanceFileVersion;
14
15use crate::error::{OmniError, Result};
16
17const GRAPH_RUNS_DIR: &str = "_graph_runs.lance";
18const GRAPH_RUN_ACTORS_DIR: &str = "_graph_run_actors.lance";
19pub(crate) const INTERNAL_RUN_BRANCH_PREFIX: &str = "__run__";
20
/// Opaque identifier for a single graph run.
///
/// In practice this wraps a ULID string (see `RunRecord::new`), but the type
/// treats it as an arbitrary opaque string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RunId(String);
23
24impl RunId {
25 pub fn new(id: impl Into<String>) -> Self {
26 Self(id.into())
27 }
28
29 pub fn as_str(&self) -> &str {
30 &self.0
31 }
32}
33
impl fmt::Display for RunId {
    // Delegate to the inner String so width/fill/precision flags on the
    // caller's formatter are honored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}
39
/// Lifecycle state of a run.
///
/// Serialized to/from a lowercase string by `as_str` / `parse` when stored in
/// the run registry's `status` column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunStatus {
    Running,
    Published,
    Failed,
    Aborted,
}
47
48impl RunStatus {
49 pub fn as_str(self) -> &'static str {
50 match self {
51 RunStatus::Running => "running",
52 RunStatus::Published => "published",
53 RunStatus::Failed => "failed",
54 RunStatus::Aborted => "aborted",
55 }
56 }
57
58 fn parse(value: &str) -> Result<Self> {
59 match value {
60 "running" => Ok(Self::Running),
61 "published" => Ok(Self::Published),
62 "failed" => Ok(Self::Failed),
63 "aborted" => Ok(Self::Aborted),
64 other => Err(OmniError::manifest(format!(
65 "invalid run status '{}'",
66 other
67 ))),
68 }
69 }
70}
71
/// One row of the run registry.
///
/// Rows are append-only; for a given `run_id`, the row with the newest
/// `updated_at` (then `created_at`) wins — see `merge_latest_run`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunRecord {
    pub run_id: RunId,
    /// Branch the run will publish into.
    pub target_branch: String,
    /// Hidden working branch for this run (`__run__<run_id>`).
    pub run_branch: String,
    pub base_snapshot_id: String,
    pub base_manifest_version: u64,
    pub operation_hash: Option<String>,
    /// Not stored in the runs dataset itself; joined in from the companion
    /// actor dataset when records are loaded.
    pub actor_id: Option<String>,
    pub status: RunStatus,
    /// Set once the run reaches `Published`.
    pub published_snapshot_id: Option<String>,
    /// Microseconds since the Unix epoch.
    pub created_at: i64,
    /// Microseconds since the Unix epoch.
    pub updated_at: i64,
}
86
87impl RunRecord {
88 pub fn new(
89 target_branch: impl Into<String>,
90 base_snapshot_id: impl Into<String>,
91 base_manifest_version: u64,
92 operation_hash: Option<String>,
93 actor_id: Option<String>,
94 ) -> Result<Self> {
95 let now = now_micros()?;
96 let run_id = RunId::new(ulid::Ulid::new().to_string());
97 Ok(Self {
98 run_branch: internal_run_branch_name(&run_id),
99 run_id,
100 target_branch: target_branch.into(),
101 base_snapshot_id: base_snapshot_id.into(),
102 base_manifest_version,
103 operation_hash,
104 actor_id,
105 status: RunStatus::Running,
106 published_snapshot_id: None,
107 created_at: now,
108 updated_at: now,
109 })
110 }
111
112 pub fn with_status(
113 &self,
114 status: RunStatus,
115 published_snapshot_id: Option<String>,
116 ) -> Result<Self> {
117 Ok(Self {
118 run_id: self.run_id.clone(),
119 target_branch: self.target_branch.clone(),
120 run_branch: self.run_branch.clone(),
121 base_snapshot_id: self.base_snapshot_id.clone(),
122 base_manifest_version: self.base_manifest_version,
123 operation_hash: self.operation_hash.clone(),
124 actor_id: self.actor_id.clone(),
125 status,
126 published_snapshot_id,
127 created_at: self.created_at,
128 updated_at: now_micros()?,
129 })
130 }
131}
132
/// Registry of graph runs backed by two Lance datasets plus in-memory caches.
pub struct RunRegistry {
    /// Append-only dataset of `RunRecord` rows.
    dataset: Dataset,
    /// Companion dataset mapping run ids to actor ids; `None` when it does
    /// not exist yet (it is created lazily on first actor write).
    actor_dataset: Option<Dataset>,
    /// Latest record per run id, maintained by `merge_latest_run`.
    latest_by_id: HashMap<String, RunRecord>,
    /// Cached run_id -> actor_id mapping loaded from `actor_dataset`.
    actor_by_run_id: HashMap<String, String>,
    /// Root URI under which both datasets live.
    root_uri: String,
}
140
141impl RunRegistry {
142 pub async fn init(root_uri: &str) -> Result<Self> {
143 let uri = graph_runs_uri(root_uri);
144 let batch = RecordBatch::new_empty(run_registry_schema());
145 let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
146 let params = WriteParams {
147 mode: WriteMode::Create,
148 enable_stable_row_ids: true,
149 data_storage_version: Some(LanceFileVersion::V2_2),
150 ..Default::default()
151 };
152 let dataset = Dataset::write(reader, &uri as &str, Some(params))
153 .await
154 .map_err(|e| OmniError::Lance(e.to_string()))?;
155 let actor_dataset = create_run_actor_dataset(root_uri).await?;
156 Ok(Self {
157 dataset,
158 actor_dataset: Some(actor_dataset),
159 latest_by_id: HashMap::new(),
160 actor_by_run_id: HashMap::new(),
161 root_uri: root_uri.to_string(),
162 })
163 }
164
165 pub async fn open(root_uri: &str) -> Result<Self> {
166 let dataset = Dataset::open(&graph_runs_uri(root_uri))
167 .await
168 .map_err(|e| OmniError::Lance(e.to_string()))?;
169 let actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
170 let actor_by_run_id = match &actor_dataset {
171 Some(dataset) => load_run_actor_cache(dataset).await?,
172 None => HashMap::new(),
173 };
174 let latest_by_id = load_run_cache(&dataset, &actor_by_run_id).await?;
175 Ok(Self {
176 dataset,
177 actor_dataset,
178 latest_by_id,
179 actor_by_run_id,
180 root_uri: root_uri.to_string(),
181 })
182 }
183
184 pub async fn refresh(&mut self, root_uri: &str) -> Result<()> {
185 self.dataset = Dataset::open(&graph_runs_uri(root_uri))
186 .await
187 .map_err(|e| OmniError::Lance(e.to_string()))?;
188 self.actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
189 self.actor_by_run_id = match &self.actor_dataset {
190 Some(dataset) => load_run_actor_cache(dataset).await?,
191 None => HashMap::new(),
192 };
193 self.latest_by_id = load_run_cache(&self.dataset, &self.actor_by_run_id).await?;
194 self.root_uri = root_uri.to_string();
195 Ok(())
196 }
197
198 pub async fn append_record(&mut self, record: &RunRecord) -> Result<()> {
199 let batch = runs_to_batch(&[record.clone()])?;
200 let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
201 let mut ds = self.dataset.clone();
202 ds.append(reader, None)
203 .await
204 .map_err(|e| OmniError::Lance(e.to_string()))?;
205 self.dataset = ds;
206 if let Some(actor_id) = &record.actor_id {
207 self.append_actor(record.run_id.as_str(), actor_id).await?;
208 }
209 let mut record = record.clone();
210 if record.actor_id.is_none() {
211 record.actor_id = self.actor_by_run_id.get(record.run_id.as_str()).cloned();
212 }
213 merge_latest_run(&mut self.latest_by_id, record);
214 Ok(())
215 }
216
217 pub async fn get_run(&self, run_id: &RunId) -> Result<Option<RunRecord>> {
218 Ok(self.latest_by_id.get(run_id.as_str()).cloned())
219 }
220
221 pub async fn list_runs(&self) -> Result<Vec<RunRecord>> {
222 self.load_runs().await
223 }
224
225 pub async fn load_runs(&self) -> Result<Vec<RunRecord>> {
226 let mut runs = self.latest_by_id.values().cloned().collect::<Vec<_>>();
227 runs.sort_by(|a, b| {
228 a.created_at
229 .cmp(&b.created_at)
230 .then_with(|| a.run_id.as_str().cmp(b.run_id.as_str()))
231 });
232 Ok(runs)
233 }
234
235 async fn append_actor(&mut self, run_id: &str, actor_id: &str) -> Result<()> {
236 if self
237 .actor_by_run_id
238 .get(run_id)
239 .is_some_and(|existing| existing == actor_id)
240 {
241 return Ok(());
242 }
243
244 let record = RunActorRecord {
245 run_id: run_id.to_string(),
246 actor_id: actor_id.to_string(),
247 created_at: now_micros()?,
248 };
249 let batch = run_actors_to_batch(&[record])?;
250 let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
251 let mut dataset = match self.actor_dataset.take() {
252 Some(dataset) => dataset,
253 None => create_run_actor_dataset(&self.root_uri).await?,
254 };
255 dataset
256 .append(reader, None)
257 .await
258 .map_err(|e| OmniError::Lance(e.to_string()))?;
259 self.actor_by_run_id
260 .insert(run_id.to_string(), actor_id.to_string());
261 self.actor_dataset = Some(dataset);
262 Ok(())
263 }
264}
265
266pub(crate) fn is_internal_run_branch(name: &str) -> bool {
267 name.trim_start_matches('/')
268 .starts_with(INTERNAL_RUN_BRANCH_PREFIX)
269}
270
271pub(crate) fn internal_run_branch_name(run_id: &RunId) -> String {
272 format!("{}{}", INTERNAL_RUN_BRANCH_PREFIX, run_id.as_str())
273}
274
275pub(crate) fn graph_runs_uri(root_uri: &str) -> String {
276 format!("{}/{}", root_uri.trim_end_matches('/'), GRAPH_RUNS_DIR)
277}
278
279fn graph_run_actors_uri(root_uri: &str) -> String {
280 format!(
281 "{}/{}",
282 root_uri.trim_end_matches('/'),
283 GRAPH_RUN_ACTORS_DIR
284 )
285}
286
287fn run_registry_schema() -> SchemaRef {
288 Arc::new(Schema::new(vec![
289 Field::new("run_id", DataType::Utf8, false),
290 Field::new("target_branch", DataType::Utf8, false),
291 Field::new("run_branch", DataType::Utf8, false),
292 Field::new("base_snapshot_id", DataType::Utf8, false),
293 Field::new("base_manifest_version", DataType::UInt64, false),
294 Field::new("operation_hash", DataType::Utf8, true),
295 Field::new("status", DataType::Utf8, false),
296 Field::new("published_snapshot_id", DataType::Utf8, true),
297 Field::new(
298 "created_at",
299 DataType::Timestamp(TimeUnit::Microsecond, None),
300 false,
301 ),
302 Field::new(
303 "updated_at",
304 DataType::Timestamp(TimeUnit::Microsecond, None),
305 false,
306 ),
307 ]))
308}
309
310fn run_actor_schema() -> SchemaRef {
311 Arc::new(Schema::new(vec![
312 Field::new("run_id", DataType::Utf8, false),
313 Field::new("actor_id", DataType::Utf8, false),
314 Field::new(
315 "created_at",
316 DataType::Timestamp(TimeUnit::Microsecond, None),
317 false,
318 ),
319 ]))
320}
321
322async fn create_run_actor_dataset(root_uri: &str) -> Result<Dataset> {
323 let batch = RecordBatch::new_empty(run_actor_schema());
324 let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
325 let params = WriteParams {
326 mode: WriteMode::Create,
327 enable_stable_row_ids: true,
328 data_storage_version: Some(LanceFileVersion::V2_2),
329 ..Default::default()
330 };
331 Dataset::write(
332 reader,
333 &graph_run_actors_uri(root_uri) as &str,
334 Some(params),
335 )
336 .await
337 .map_err(|e| OmniError::Lance(e.to_string()))
338}
339
340async fn load_run_cache(
341 dataset: &Dataset,
342 actor_by_run_id: &HashMap<String, String>,
343) -> Result<HashMap<String, RunRecord>> {
344 let batches: Vec<RecordBatch> = dataset
345 .scan()
346 .try_into_stream()
347 .await
348 .map_err(|e| OmniError::Lance(e.to_string()))?
349 .try_collect()
350 .await
351 .map_err(|e| OmniError::Lance(e.to_string()))?;
352
353 let mut latest_by_id = HashMap::new();
354 for mut record in load_runs_from_batches(&batches)? {
355 record.actor_id = actor_by_run_id.get(record.run_id.as_str()).cloned();
356 merge_latest_run(&mut latest_by_id, record);
357 }
358 Ok(latest_by_id)
359}
360
361async fn load_run_actor_cache(dataset: &Dataset) -> Result<HashMap<String, String>> {
362 let batches: Vec<RecordBatch> = dataset
363 .scan()
364 .try_into_stream()
365 .await
366 .map_err(|e| OmniError::Lance(e.to_string()))?
367 .try_collect()
368 .await
369 .map_err(|e| OmniError::Lance(e.to_string()))?;
370
371 let mut actors = HashMap::new();
372 for batch in batches {
373 let run_ids = string_column(&batch, "run_id", "run actor registry")?;
374 let actor_ids = string_column(&batch, "actor_id", "run actor registry")?;
375 for row in 0..batch.num_rows() {
376 actors.insert(
377 run_ids.value(row).to_string(),
378 actor_ids.value(row).to_string(),
379 );
380 }
381 }
382 Ok(actors)
383}
384
385fn load_runs_from_batches(batches: &[RecordBatch]) -> Result<Vec<RunRecord>> {
386 let mut runs = Vec::new();
387 for batch in batches {
388 let run_ids = string_column(batch, "run_id", "run registry")?;
389 let target_branches = string_column(batch, "target_branch", "run registry")?;
390 let run_branches = string_column(batch, "run_branch", "run registry")?;
391 let base_snapshot_ids = string_column(batch, "base_snapshot_id", "run registry")?;
392 let base_manifest_versions = u64_column(batch, "base_manifest_version", "run registry")?;
393 let operation_hashes = string_column(batch, "operation_hash", "run registry")?;
394 let statuses = string_column(batch, "status", "run registry")?;
395 let published_snapshot_ids = string_column(batch, "published_snapshot_id", "run registry")?;
396 let created_ats = timestamp_micros_column(batch, "created_at", "run registry")?;
397 let updated_ats = timestamp_micros_column(batch, "updated_at", "run registry")?;
398
399 for row in 0..batch.num_rows() {
400 runs.push(RunRecord {
401 run_id: RunId::new(run_ids.value(row)),
402 target_branch: target_branches.value(row).to_string(),
403 run_branch: run_branches.value(row).to_string(),
404 base_snapshot_id: base_snapshot_ids.value(row).to_string(),
405 base_manifest_version: base_manifest_versions.value(row),
406 operation_hash: if operation_hashes.is_null(row) {
407 None
408 } else {
409 Some(operation_hashes.value(row).to_string())
410 },
411 actor_id: None,
412 status: RunStatus::parse(statuses.value(row))?,
413 published_snapshot_id: if published_snapshot_ids.is_null(row) {
414 None
415 } else {
416 Some(published_snapshot_ids.value(row).to_string())
417 },
418 created_at: created_ats.value(row),
419 updated_at: updated_ats.value(row),
420 });
421 }
422 }
423 Ok(runs)
424}
425
426fn merge_latest_run(latest_by_id: &mut HashMap<String, RunRecord>, record: RunRecord) {
427 match latest_by_id.get(record.run_id.as_str()) {
428 Some(existing)
429 if existing.updated_at > record.updated_at
430 || (existing.updated_at == record.updated_at
431 && existing.created_at >= record.created_at) => {}
432 _ => {
433 latest_by_id.insert(record.run_id.as_str().to_string(), record);
434 }
435 }
436}
437
438fn string_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a StringArray> {
439 batch
440 .column_by_name(name)
441 .ok_or_else(|| {
442 OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
443 })?
444 .as_any()
445 .downcast_ref::<StringArray>()
446 .ok_or_else(|| {
447 OmniError::manifest_internal(format!("{context} column '{name}' is not Utf8"))
448 })
449}
450
451fn u64_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a UInt64Array> {
452 batch
453 .column_by_name(name)
454 .ok_or_else(|| {
455 OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
456 })?
457 .as_any()
458 .downcast_ref::<UInt64Array>()
459 .ok_or_else(|| {
460 OmniError::manifest_internal(format!("{context} column '{name}' is not UInt64"))
461 })
462}
463
464fn timestamp_micros_column<'a>(
465 batch: &'a RecordBatch,
466 name: &str,
467 context: &str,
468) -> Result<&'a TimestampMicrosecondArray> {
469 batch
470 .column_by_name(name)
471 .ok_or_else(|| {
472 OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
473 })?
474 .as_any()
475 .downcast_ref::<TimestampMicrosecondArray>()
476 .ok_or_else(|| {
477 OmniError::manifest_internal(format!(
478 "{context} column '{name}' is not Timestamp(Microsecond)"
479 ))
480 })
481}
482
483fn runs_to_batch(records: &[RunRecord]) -> Result<RecordBatch> {
484 let run_ids: Vec<&str> = records
485 .iter()
486 .map(|record| record.run_id.as_str())
487 .collect();
488 let target_branches: Vec<&str> = records
489 .iter()
490 .map(|record| record.target_branch.as_str())
491 .collect();
492 let run_branches: Vec<&str> = records
493 .iter()
494 .map(|record| record.run_branch.as_str())
495 .collect();
496 let base_snapshot_ids: Vec<&str> = records
497 .iter()
498 .map(|record| record.base_snapshot_id.as_str())
499 .collect();
500 let base_manifest_versions: Vec<u64> = records
501 .iter()
502 .map(|record| record.base_manifest_version)
503 .collect();
504 let operation_hashes: Vec<Option<&str>> = records
505 .iter()
506 .map(|record| record.operation_hash.as_deref())
507 .collect();
508 let statuses: Vec<&str> = records
509 .iter()
510 .map(|record| record.status.as_str())
511 .collect();
512 let published_snapshot_ids: Vec<Option<&str>> = records
513 .iter()
514 .map(|record| record.published_snapshot_id.as_deref())
515 .collect();
516 let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
517 let updated_ats: Vec<i64> = records.iter().map(|record| record.updated_at).collect();
518
519 RecordBatch::try_new(
520 run_registry_schema(),
521 vec![
522 Arc::new(StringArray::from(run_ids)),
523 Arc::new(StringArray::from(target_branches)),
524 Arc::new(StringArray::from(run_branches)),
525 Arc::new(StringArray::from(base_snapshot_ids)),
526 Arc::new(UInt64Array::from(base_manifest_versions)),
527 Arc::new(StringArray::from(operation_hashes)),
528 Arc::new(StringArray::from(statuses)),
529 Arc::new(StringArray::from(published_snapshot_ids)),
530 Arc::new(TimestampMicrosecondArray::from(created_ats)),
531 Arc::new(TimestampMicrosecondArray::from(updated_ats)),
532 ],
533 )
534 .map_err(|e| OmniError::Lance(e.to_string()))
535}
536
/// One row of the run-actor dataset: associates a run with the actor that
/// wrote to it.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RunActorRecord {
    run_id: String,
    actor_id: String,
    // Microseconds since the Unix epoch.
    created_at: i64,
}
543
544fn run_actors_to_batch(records: &[RunActorRecord]) -> Result<RecordBatch> {
545 let run_ids: Vec<&str> = records
546 .iter()
547 .map(|record| record.run_id.as_str())
548 .collect();
549 let actor_ids: Vec<&str> = records
550 .iter()
551 .map(|record| record.actor_id.as_str())
552 .collect();
553 let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
554
555 RecordBatch::try_new(
556 run_actor_schema(),
557 vec![
558 Arc::new(StringArray::from(run_ids)),
559 Arc::new(StringArray::from(actor_ids)),
560 Arc::new(TimestampMicrosecondArray::from(created_ats)),
561 ],
562 )
563 .map_err(|e| OmniError::Lance(e.to_string()))
564}
565
566fn now_micros() -> Result<i64> {
567 let duration = SystemTime::now()
568 .duration_since(UNIX_EPOCH)
569 .map_err(|e| OmniError::manifest(format!("system clock error: {}", e)))?;
570 Ok(duration.as_micros() as i64)
571}
572
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// `load_runs_from_batches` must surface a schema error (not panic) when
    /// a column has the wrong Arrow type.
    #[test]
    fn load_runs_from_batches_returns_error_for_bad_schema() {
        // Build a batch whose `run_id` column is deliberately UInt64 instead
        // of the Utf8 the loader expects; all other columns are well-formed.
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("run_id", DataType::UInt64, false),
                Field::new("target_branch", DataType::Utf8, false),
                Field::new("run_branch", DataType::Utf8, false),
                Field::new("base_snapshot_id", DataType::Utf8, false),
                Field::new("base_manifest_version", DataType::UInt64, false),
                Field::new("operation_hash", DataType::Utf8, true),
                Field::new("status", DataType::Utf8, false),
                Field::new("published_snapshot_id", DataType::Utf8, true),
                Field::new(
                    "created_at",
                    DataType::Timestamp(TimeUnit::Microsecond, None),
                    false,
                ),
                Field::new(
                    "updated_at",
                    DataType::Timestamp(TimeUnit::Microsecond, None),
                    false,
                ),
            ])),
            vec![
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(StringArray::from(vec!["main"])),
                Arc::new(StringArray::from(vec!["__run__1"])),
                Arc::new(StringArray::from(vec!["snap-1"])),
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(StringArray::from(vec!["running"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
                Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
            ],
        )
        .unwrap();

        // The error message should name the offending column.
        let err = load_runs_from_batches(&[batch]).unwrap_err();
        assert!(err.to_string().contains("run_id"));
    }
}
622}