use std::sync::LazyLock;
use serde::{Deserialize, Serialize};
use crate::actions::DomainMetadata;
use crate::engine_data::{GetData, RowVisitor, TypedGetData as _};
use crate::schema::{ColumnName, ColumnNamesAndTypes, DataType};
use crate::utils::require;
use crate::{DeltaResult, Engine, Error, Snapshot};
/// JSON configuration payload stored in the row tracking domain metadata.
///
/// It carries a single value, the row ID high water mark, written under the JSON key
/// `rowIdHighWaterMark`.
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct RowTrackingDomainMetadata {
    /// The highest row ID that has been handed out so far.
    #[serde(rename = "rowIdHighWaterMark")]
    row_id_high_water_mark: i64,
}

/// Domain name under which row tracking metadata is stored in a `DomainMetadata` action.
pub(crate) const ROW_TRACKING_DOMAIN_NAME: &str = "delta.rowTracking";
impl RowTrackingDomainMetadata {
    /// Builds row tracking domain metadata that records `row_id_high_water_mark`.
    pub(crate) fn new(row_id_high_water_mark: i64) -> Self {
        Self {
            row_id_high_water_mark,
        }
    }

    /// Reads the row ID high water mark from the snapshot's row tracking domain metadata.
    ///
    /// Returns `Ok(None)` when the snapshot has no row tracking domain metadata.
    ///
    /// # Errors
    /// Propagates any error from looking up the domain metadata. Returns an error if the stored
    /// configuration cannot be parsed as JSON into [`RowTrackingDomainMetadata`].
    pub(crate) fn get_high_water_mark(
        snapshot: &Snapshot,
        engine: &dyn Engine,
    ) -> DeltaResult<Option<i64>> {
        let Some(configuration) =
            snapshot.get_domain_metadata_internal(ROW_TRACKING_DOMAIN_NAME, engine)?
        else {
            return Ok(None);
        };
        let parsed: Self = serde_json::from_str(&configuration)?;
        Ok(Some(parsed.row_id_high_water_mark))
    }
}
/// Wraps row tracking metadata in a [`DomainMetadata`] action for the row tracking domain.
///
/// The metadata is serialized to JSON and used as the action's configuration string.
impl TryFrom<RowTrackingDomainMetadata> for DomainMetadata {
    type Error = crate::Error;

    fn try_from(metadata: RowTrackingDomainMetadata) -> DeltaResult<Self> {
        let configuration = serde_json::to_string(&metadata)?;
        Ok(Self::new(ROW_TRACKING_DOMAIN_NAME.to_owned(), configuration))
    }
}
/// Row visitor that computes base row IDs for Add actions from their record counts.
///
/// Every visited batch contributes one vector to `base_row_id_batches`. Within a batch, each row
/// is assigned a base row ID one past the running high water mark. The mark is then advanced by
/// that row's record count.
pub(crate) struct RowTrackingVisitor {
    /// Running high water mark: the last row ID assigned so far.
    pub(crate) row_id_high_water_mark: i64,
    /// Base row IDs, one vector per visited batch, in visit order.
    pub(crate) base_row_id_batches: Vec<Vec<i64>>,
}

impl RowTrackingVisitor {
    /// Starting mark used when no high water mark exists yet.
    ///
    /// The first base row ID is `mark + 1`, so -1 makes the first row ID 0.
    const DEFAULT_HIGH_WATER_MARK: i64 = -1;

    /// Creates a visitor that starts from `row_id_high_water_mark`.
    ///
    /// If `row_id_high_water_mark` is `None`, the default mark of -1 is used. Space for
    /// `num_batches` batch vectors is reserved up front (none when `None`).
    pub(crate) fn new(row_id_high_water_mark: Option<i64>, num_batches: Option<usize>) -> Self {
        let starting_mark = row_id_high_water_mark.unwrap_or(Self::DEFAULT_HIGH_WATER_MARK);
        let expected_batches = num_batches.unwrap_or_default();
        Self {
            base_row_id_batches: Vec::with_capacity(expected_batches),
            row_id_high_water_mark: starting_mark,
        }
    }
}
impl RowVisitor for RowTrackingVisitor {
fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]) {
static NAMES_AND_TYPES: LazyLock<ColumnNamesAndTypes> = LazyLock::new(|| {
(
vec![ColumnName::new(["stats", "numRecords"])],
vec![DataType::LONG],
)
.into()
});
NAMES_AND_TYPES.as_ref()
}
fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()> {
require!(
getters.len() == 1,
Error::generic(format!(
"Wrong number of RowTrackingVisitor getters: {}",
getters.len()
))
);
let mut batch_base_row_ids = Vec::with_capacity(row_count);
let mut current_hwm = self.row_id_high_water_mark;
for i in 0..row_count {
let num_records: i64 = getters[0].get_opt(i, "numRecords")?.ok_or_else(|| {
Error::InternalError(
"numRecords must be present in Add actions when row tracking is enabled."
.to_string(),
)
})?;
batch_base_row_ids.push(current_hwm + 1);
current_hwm += num_records;
}
self.base_row_id_batches.push(batch_base_row_ids);
self.row_id_high_water_mark = current_hwm;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine_data::GetData;
    use crate::utils::test_utils::assert_result_error_with_message;

    /// Minimal `GetData` stand-in that serves a fixed list of `numRecords` values.
    struct MockGetData {
        num_records_values: Vec<Option<i64>>,
    }

    impl MockGetData {
        fn new(num_records_values: Vec<Option<i64>>) -> Self {
            Self { num_records_values }
        }
    }

    impl<'a> GetData<'a> for MockGetData {
        fn get_long(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<i64>> {
            if field_name == "numRecords" {
                Ok(self.num_records_values.get(row_index).copied().flatten())
            } else {
                Ok(None)
            }
        }
    }

    fn create_getters(num_records_mock: &MockGetData) -> Vec<&dyn GetData<'_>> {
        vec![num_records_mock]
    }

    #[test]
    fn test_new_defaults() {
        let visitor = RowTrackingVisitor::new(None, None);
        assert_eq!(visitor.row_id_high_water_mark, -1);
        assert!(visitor.base_row_id_batches.is_empty());
    }

    #[test]
    fn test_visit_basic_functionality() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(None, Some(1));
        let num_records_mock = MockGetData::new(vec![Some(10), Some(5), Some(20)]);
        let getters = create_getters(&num_records_mock);
        visitor.visit(3, &getters)?;
        assert_eq!(visitor.base_row_id_batches.len(), 1);
        assert_eq!(visitor.base_row_id_batches[0], vec![0, 10, 15]);
        assert_eq!(visitor.row_id_high_water_mark, 34);
        Ok(())
    }

    #[test]
    fn test_visit_with_negative_high_water_mark() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(Some(-5), Some(1));
        let num_records_mock = MockGetData::new(vec![Some(3), Some(2)]);
        let getters = create_getters(&num_records_mock);
        visitor.visit(2, &getters)?;
        assert_eq!(visitor.base_row_id_batches.len(), 1);
        assert_eq!(visitor.base_row_id_batches[0], vec![-4, -1]);
        assert_eq!(visitor.row_id_high_water_mark, 0);
        Ok(())
    }

    #[test]
    fn test_visit_with_zero_records() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(Some(10), Some(1));
        let num_records_mock = MockGetData::new(vec![Some(0), Some(0), Some(5)]);
        let getters = create_getters(&num_records_mock);
        visitor.visit(3, &getters)?;
        assert_eq!(visitor.base_row_id_batches.len(), 1);
        assert_eq!(visitor.base_row_id_batches[0], vec![11, 11, 11]);
        assert_eq!(visitor.row_id_high_water_mark, 15);
        Ok(())
    }

    #[test]
    fn test_visit_empty_batch() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(Some(42), None);
        let num_records_mock = MockGetData::new(vec![]);
        let getters = create_getters(&num_records_mock);
        visitor.visit(0, &getters)?;
        assert_eq!(visitor.base_row_id_batches.len(), 1);
        assert!(visitor.base_row_id_batches[0].is_empty());
        assert_eq!(visitor.row_id_high_water_mark, 42);
        Ok(())
    }

    #[test]
    fn test_visit_multiple_batches() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(Some(0), Some(2));
        let num_records_mock1 = MockGetData::new(vec![Some(10), Some(5)]);
        let getters1 = create_getters(&num_records_mock1);
        visitor.visit(2, &getters1)?;
        let num_records_mock2 = MockGetData::new(vec![Some(3), Some(7), Some(2)]);
        let getters2 = create_getters(&num_records_mock2);
        visitor.visit(3, &getters2)?;
        assert_eq!(visitor.base_row_id_batches.len(), 2);
        assert_eq!(visitor.base_row_id_batches[0], vec![1, 11]);
        assert_eq!(visitor.base_row_id_batches[1], vec![16, 19, 26]);
        assert_eq!(visitor.row_id_high_water_mark, 27);
        Ok(())
    }

    #[test]
    fn test_visit_wrong_getter_count() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(Some(0), None);
        let wrong_getters: Vec<&dyn GetData<'_>> = vec![];
        let result = visitor.visit(1, &wrong_getters);
        assert_result_error_with_message(result, "Wrong number of RowTrackingVisitor getters");
        Ok(())
    }

    #[test]
    fn test_visit_missing_num_records() -> DeltaResult<()> {
        let mut visitor = RowTrackingVisitor::new(Some(0), None);
        let num_records_mock = MockGetData::new(vec![None]);
        let getters = create_getters(&num_records_mock);
        let result = visitor.visit(1, &getters);
        assert_result_error_with_message(
            result,
            "numRecords must be present in Add actions when row tracking is enabled",
        );
        Ok(())
    }

    #[test]
    fn test_selected_column_names_and_types() {
        let visitor = RowTrackingVisitor::new(Some(0), None);
        let (names, types) = visitor.selected_column_names_and_types();
        assert_eq!(names, vec![ColumnName::new(["stats", "numRecords"])]);
        assert_eq!(types, vec![DataType::LONG]);
    }

    #[test]
    fn test_serialization_roundtrip() -> DeltaResult<()> {
        let original = RowTrackingDomainMetadata::new(-42);
        let json = serde_json::to_string(&original)?;
        let deserialized: RowTrackingDomainMetadata = serde_json::from_str(&json)?;
        assert_eq!(
            original.row_id_high_water_mark,
            deserialized.row_id_high_water_mark
        );
        Ok(())
    }

    /// Pins the on-disk JSON key so a serde attribute change cannot silently alter the format.
    #[test]
    fn test_serialized_field_name() -> DeltaResult<()> {
        let json = serde_json::to_string(&RowTrackingDomainMetadata::new(7))?;
        assert_eq!(json, r#"{"rowIdHighWaterMark":7}"#);
        let parsed: RowTrackingDomainMetadata =
            serde_json::from_str(r#"{"rowIdHighWaterMark":-3}"#)?;
        assert_eq!(parsed.row_id_high_water_mark, -3);
        Ok(())
    }
}