use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use arrow_array::{
Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::{WriteMode, WriteParams};
use lance_file::version::LanceFileVersion;
use serde::{Deserialize, Serialize};
use crate::error::{OmniError, Result};
/// Directory name (relative to the graph root URI) of the Lance dataset that
/// stores the recovery audit log.
const RECOVERIES_DIR: &str = "_graph_commit_recoveries.lance";
/// Kind of recovery action that was applied to an interrupted graph commit.
///
/// Persisted as a string column (see `RecoveryKind::as_str` / `parse`).
// NOTE(review): `rename_all = "PascalCase"` is a no-op here — both variant
// names are already PascalCase — but it documents the intended wire form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub(crate) enum RecoveryKind {
    // The interrupted commit was completed (rolled forward).
    RolledForward,
    // The interrupted commit was undone (rolled back).
    RolledBack,
}
impl RecoveryKind {
fn as_str(self) -> &'static str {
match self {
RecoveryKind::RolledForward => "RolledForward",
RecoveryKind::RolledBack => "RolledBack",
}
}
fn parse(s: &str) -> Result<Self> {
match s {
"RolledForward" => Ok(RecoveryKind::RolledForward),
"RolledBack" => Ok(RecoveryKind::RolledBack),
other => Err(OmniError::manifest_internal(format!(
"unknown recovery_kind '{}' in _graph_commit_recoveries.lance",
other
))),
}
}
}
/// Version transition recorded for one table touched by a recovery action.
///
/// Serialized to JSON inside the `per_table_outcomes_json` column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct TableOutcome {
    // Table identifier, e.g. "node:Person" or "edge:Knows" (see tests).
    pub table_key: String,
    // Table version before the recovery action.
    pub from_version: u64,
    // Table version after the recovery action.
    pub to_version: u64,
}
/// One row of the recovery audit log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct RecoveryAuditRecord {
    // Graph commit the recovery acted on.
    pub graph_commit_id: String,
    // Whether the commit was rolled forward or rolled back.
    pub recovery_kind: RecoveryKind,
    // Actor on whose behalf the recovery ran; stored as a nullable column.
    pub recovery_for_actor: Option<String>,
    // Identifier of the recovery operation itself.
    pub operation_id: String,
    // Free-form writer kind, e.g. "Mutation" (see tests).
    pub sidecar_writer_kind: String,
    // Per-table version transitions applied by this recovery.
    pub per_table_outcomes: Vec<TableOutcome>,
    // Creation time in microseconds since the Unix epoch (stored as a
    // microsecond timestamp column; see `now_micros`).
    pub created_at: i64,
}
/// Append-only audit log of graph-commit recovery actions, persisted as a
/// Lance dataset at `<root_uri>/_graph_commit_recoveries.lance`.
pub(crate) struct RecoveryAudit {
    // Root URI with any trailing '/' removed (normalized in `open`).
    root_uri: String,
    // Backing dataset; `None` until it exists (created lazily on first append).
    dataset: Option<Dataset>,
}
impl RecoveryAudit {
    /// Opens the audit log rooted at `root_uri`.
    ///
    /// A missing dataset is not an error: the log is created lazily by the
    /// first [`RecoveryAudit::append`].
    // NOTE(review): `.ok()` also swallows genuine open failures (corruption,
    // permissions) and treats them as "no log yet" — confirm that is intended.
    pub(crate) async fn open(root_uri: &str) -> Result<Self> {
        let root = root_uri.trim_end_matches('/').to_string();
        let dataset = Dataset::open(&recoveries_uri(&root)).await.ok();
        Ok(Self {
            root_uri: root,
            dataset,
        })
    }

    /// Appends a single audit record, creating the backing dataset on first
    /// use.
    pub(crate) async fn append(&mut self, record: RecoveryAuditRecord) -> Result<()> {
        let batch = recovery_record_to_batch(&record)?;
        let reader = RecordBatchIterator::new(vec![Ok(batch)], recoveries_schema());
        let mut dataset = match self.dataset.take() {
            Some(dataset) => dataset,
            None => create_recoveries_dataset(&self.root_uri).await?,
        };
        let result = dataset.append(reader, None).await;
        // Restore the handle *before* propagating any append error. The
        // previous version returned early on failure after `take()`, leaving
        // `self.dataset` as `None`, so a later `list()` would silently report
        // an empty log even though the dataset exists with earlier records.
        self.dataset = Some(dataset);
        result.map_err(|e| OmniError::Lance(e.to_string()))?;
        Ok(())
    }

    /// Returns every audit record, sorted by `created_at` ascending.
    pub(crate) async fn list(&self) -> Result<Vec<RecoveryAuditRecord>> {
        let dataset = match &self.dataset {
            Some(dataset) => dataset,
            None => return Ok(Vec::new()),
        };
        let batches: Vec<RecordBatch> = dataset
            .scan()
            .try_into_stream()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?
            .try_collect()
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
        let mut out = Vec::new();
        for batch in batches {
            for row in 0..batch.num_rows() {
                out.push(decode_row(&batch, row)?);
            }
        }
        // Stable sort: records sharing a timestamp keep scan order.
        out.sort_by_key(|r| r.created_at);
        Ok(out)
    }
}
/// Joins the audit dataset directory onto `root_uri`, normalizing away any
/// trailing slashes so the result never contains `//`.
fn recoveries_uri(root_uri: &str) -> String {
    let base = root_uri.trim_end_matches('/');
    format!("{base}/{RECOVERIES_DIR}")
}
/// Arrow schema of the recovery audit dataset (one row per recovery action).
///
/// All columns are non-nullable UTF-8 except `recovery_for_actor` (nullable)
/// and `created_at` (microsecond timestamp).
fn recoveries_schema() -> SchemaRef {
    // Small helper: every column except `created_at` is a UTF-8 string.
    let utf8 = |name: &str, nullable: bool| Field::new(name, DataType::Utf8, nullable);
    let created_at = Field::new(
        "created_at",
        DataType::Timestamp(TimeUnit::Microsecond, None),
        false,
    );
    Arc::new(Schema::new(vec![
        utf8("graph_commit_id", false),
        utf8("recovery_kind", false),
        utf8("recovery_for_actor", true),
        utf8("operation_id", false),
        utf8("sidecar_writer_kind", false),
        utf8("per_table_outcomes_json", false),
        created_at,
    ]))
}
/// Creates the recovery-audit dataset under `root_uri`, racing gracefully
/// with concurrent creators.
///
/// Writes an empty batch so the schema is established before any rows exist;
/// if another writer created the dataset first, falls back to opening it.
async fn create_recoveries_dataset(root_uri: &str) -> Result<Dataset> {
    let uri = recoveries_uri(root_uri);
    // The empty batch only carries the schema; real rows arrive via append.
    let batch = RecordBatch::new_empty(recoveries_schema());
    let reader = RecordBatchIterator::new(vec![Ok(batch)], recoveries_schema());
    let params = WriteParams {
        mode: WriteMode::Create,
        enable_stable_row_ids: true,
        data_storage_version: Some(LanceFileVersion::V2_2),
        ..Default::default()
    };
    match Dataset::write(reader, &uri as &str, Some(params)).await {
        Ok(dataset) => Ok(dataset),
        // Lost a create race: another writer made the dataset first, so open it.
        // NOTE(review): matching on the error's display text is brittle —
        // confirm whether lance exposes a typed "already exists" variant that
        // could be matched instead.
        Err(err) if err.to_string().contains("Dataset already exists") => Dataset::open(&uri)
            .await
            .map_err(|open_err| OmniError::Lance(open_err.to_string())),
        Err(err) => Err(OmniError::Lance(err.to_string())),
    }
}
fn recovery_record_to_batch(record: &RecoveryAuditRecord) -> Result<RecordBatch> {
let outcomes_json = serde_json::to_string(&record.per_table_outcomes).map_err(|e| {
OmniError::manifest_internal(format!(
"failed to serialize per_table_outcomes for recovery audit: {}",
e
))
})?;
RecordBatch::try_new(
recoveries_schema(),
vec![
Arc::new(StringArray::from(vec![record.graph_commit_id.clone()])),
Arc::new(StringArray::from(vec![record.recovery_kind.as_str()])),
Arc::new(StringArray::from(vec![record
.recovery_for_actor
.clone()])),
Arc::new(StringArray::from(vec![record.operation_id.clone()])),
Arc::new(StringArray::from(vec![record.sidecar_writer_kind.clone()])),
Arc::new(StringArray::from(vec![outcomes_json])),
Arc::new(TimestampMicrosecondArray::from(vec![record.created_at])),
],
)
.map_err(|e| OmniError::Lance(e.to_string()))
}
/// Decodes row `row` of `batch` back into a [`RecoveryAuditRecord`].
///
/// Fails if a column is missing or has the wrong Arrow type, if the outcomes
/// JSON is malformed, or if `recovery_kind` holds an unknown value.
fn decode_row(batch: &RecordBatch, row: usize) -> Result<RecoveryAuditRecord> {
    // Fetch a named UTF-8 column, reporting a descriptive error when the
    // column is absent or not a StringArray.
    fn utf8_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray> {
        let column = batch.column_by_name(name).ok_or_else(|| {
            OmniError::manifest_internal(format!("missing column '{}' in recovery audit", name))
        })?;
        column
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| OmniError::manifest_internal(format!("column '{}' has wrong type", name)))
    }

    // Column lookups happen in the same order as before so the first error
    // surfaced for a malformed batch is unchanged.
    let created_at_col = batch
        .column_by_name("created_at")
        .ok_or_else(|| OmniError::manifest_internal("missing 'created_at' column".to_string()))?
        .as_any()
        .downcast_ref::<TimestampMicrosecondArray>()
        .ok_or_else(|| {
            OmniError::manifest_internal("'created_at' column has wrong type".to_string())
        })?;
    let graph_commit_ids = utf8_column(batch, "graph_commit_id")?;
    let kinds = utf8_column(batch, "recovery_kind")?;
    let for_actors = utf8_column(batch, "recovery_for_actor")?;
    let op_ids = utf8_column(batch, "operation_id")?;
    let writers = utf8_column(batch, "sidecar_writer_kind")?;
    let outcomes_column = utf8_column(batch, "per_table_outcomes_json")?;
    let per_table_outcomes: Vec<TableOutcome> = serde_json::from_str(outcomes_column.value(row))
        .map_err(|e| {
            OmniError::manifest_internal(format!(
                "failed to deserialize per_table_outcomes_json from recovery audit: {}",
                e
            ))
        })?;
    // Null entry in the nullable actor column maps back to None.
    let recovery_for_actor = (!for_actors.is_null(row)).then(|| for_actors.value(row).to_string());
    Ok(RecoveryAuditRecord {
        graph_commit_id: graph_commit_ids.value(row).to_string(),
        recovery_kind: RecoveryKind::parse(kinds.value(row))?,
        recovery_for_actor,
        operation_id: op_ids.value(row).to_string(),
        sidecar_writer_kind: writers.value(row).to_string(),
        per_table_outcomes,
        created_at: created_at_col.value(row),
    })
}
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// # Errors
/// Returns an error if the system clock reports a time before the epoch, or
/// if the microsecond count no longer fits in `i64` (far-future clock).
pub(crate) fn now_micros() -> Result<i64> {
    let elapsed = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
        OmniError::manifest_internal(format!("system clock before unix epoch: {}", e))
    })?;
    // `as_micros()` returns u128; the previous `as i64` cast would silently
    // truncate on overflow instead of reporting it.
    i64::try_from(elapsed.as_micros()).map_err(|e| {
        OmniError::manifest_internal(format!("unix time overflows i64 microseconds: {}", e))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a fully-populated record (actor set, two table outcomes) used as
    // the baseline fixture by the round-trip tests below.
    fn sample_record() -> RecoveryAuditRecord {
        RecoveryAuditRecord {
            graph_commit_id: "01H000000000000000000000XX".to_string(),
            recovery_kind: RecoveryKind::RolledForward,
            recovery_for_actor: Some("act-alice".to_string()),
            operation_id: "01H000000000000000000000OP".to_string(),
            sidecar_writer_kind: "Mutation".to_string(),
            per_table_outcomes: vec![
                TableOutcome {
                    table_key: "node:Person".to_string(),
                    from_version: 5,
                    to_version: 6,
                },
                TableOutcome {
                    table_key: "edge:Knows".to_string(),
                    from_version: 12,
                    to_version: 13,
                },
            ],
            created_at: 1_700_000_000_000_000,
        }
    }

    // Appends two records (the second with the nullable actor cleared and a
    // later timestamp) and checks `list` returns them decoded, in
    // `created_at` order.
    #[tokio::test]
    async fn recovery_audit_round_trips_through_lance() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_str().unwrap();
        let mut audit = RecoveryAudit::open(root).await.unwrap();
        // Fresh root: no dataset yet, so the log must read as empty.
        assert!(audit.list().await.unwrap().is_empty());
        let record = sample_record();
        audit.append(record.clone()).await.unwrap();
        let listed = audit.list().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0], record);
        let mut second = sample_record();
        second.graph_commit_id = "01H000000000000000000000YY".to_string();
        second.recovery_kind = RecoveryKind::RolledBack;
        second.recovery_for_actor = None;
        second.created_at = record.created_at + 1;
        audit.append(second.clone()).await.unwrap();
        let listed = audit.list().await.unwrap();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0], record);
        assert_eq!(listed[1], second);
    }

    // Verifies durability: a record appended through one handle is visible
    // after dropping it and re-opening the log from the same root.
    #[tokio::test]
    async fn recovery_audit_persists_across_open_cycles() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_str().unwrap();
        {
            let mut audit = RecoveryAudit::open(root).await.unwrap();
            audit.append(sample_record()).await.unwrap();
        }
        let audit = RecoveryAudit::open(root).await.unwrap();
        let listed = audit.list().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0], sample_record());
    }

    // `as_str`/`parse` must be exact inverses, and unknown strings must fail.
    #[test]
    fn recovery_kind_round_trips_through_string() {
        assert_eq!(
            RecoveryKind::parse("RolledForward").unwrap(),
            RecoveryKind::RolledForward,
        );
        assert_eq!(
            RecoveryKind::parse("RolledBack").unwrap(),
            RecoveryKind::RolledBack,
        );
        assert!(RecoveryKind::parse("Garbage").is_err());
    }
}