use crate::error::{TrazaeoError, TrazaeoResult};
use crate::utils::Hash;
use serde::{Deserialize, Serialize};
/// The mode in which provenance tracking was started.
///
/// NOTE(review): the meaning of each variant is only implied by its name in
/// this module — confirm against the code that constructs these values.
///
/// serde (de)serializes variants in `snake_case` (e.g. `SourceCapture` ⇄
/// `"source_capture"`), the same spelling returned by `as_str` and accepted by
/// `parse`. Renaming a variant therefore changes its serialized form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProvenanceStartMode {
SourceCapture,
TransportCapture,
DatasetBootstrap,
DatasetIncremental,
}
impl ProvenanceStartMode {
pub fn as_str(self) -> &'static str {
match self {
Self::SourceCapture => "source_capture",
Self::TransportCapture => "transport_capture",
Self::DatasetBootstrap => "dataset_bootstrap",
Self::DatasetIncremental => "dataset_incremental",
}
}
pub fn parse(value: &str) -> Option<Self> {
match value {
"source_capture" => Some(Self::SourceCapture),
"transport_capture" => Some(Self::TransportCapture),
"dataset_bootstrap" => Some(Self::DatasetBootstrap),
"dataset_incremental" => Some(Self::DatasetIncremental),
_ => None,
}
}
}
/// One file recorded in a source manifest.
///
/// All four fields feed `canonical_source_manifest_bytes`, which writes each
/// entry as the row `source_uri|content_hash|byte_length|observed_mtime`.
/// Because that encoding is delimiter-based, `source_uri`, `content_hash` and
/// `observed_mtime` must not contain `|` or control characters. This is
/// enforced by `validate_source_files_for_canonicalization`, not by this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SourceFileEntry {
    /// Location of the file. The tests use `s3://…` URIs. It is the primary
    /// sort key of the canonical encoding.
pub source_uri: String,
    /// Hash of the file content.
    /// NOTE(review): algorithm/encoding is not fixed by this module — confirm
    /// with the producer.
pub content_hash: String,
    /// Length of the file, presumably in bytes as the name suggests.
pub byte_length: u64,
    /// Observed modification time, if any.
    ///
    /// `None` is written as an empty trailing field, so `Some("")` is rejected
    /// by validation to keep the two cases distinguishable. The tests use
    /// RFC 3339-style UTC strings, but the format itself is not checked.
pub observed_mtime: Option<String>,
}
/// A set of source files together with a root hash committing to them.
///
/// `validate_source_manifest` checks three things:
/// - `source_file_count` equals `source_files.len()`;
/// - every entry is safe for canonical encoding;
/// - `source_root_hash` equals the hex-encoded root hash recomputed from
///   `source_files`.
///
/// The identifier and timestamp fields are not inspected by that function.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SourceManifest {
    /// Identifier of this manifest (not validated in this module).
pub manifest_id: String,
    /// Creation time of the manifest as a string (format not validated here;
    /// the tests use RFC 3339-style UTC strings).
pub manifest_created_at: String,
    /// Identifier of the dataset the files belong to (not validated here).
pub source_dataset_id: String,
    /// Files covered by the manifest. Their order does not affect
    /// `source_root_hash`, because the canonical encoding sorts them first.
pub source_files: Vec<SourceFileEntry>,
    /// Must equal `source_files.len()`.
pub source_file_count: usize,
    /// Hex encoding of the BLAKE3 hash over the canonical file list.
    /// It is compared by exact string equality with `hex::encode` output.
pub source_root_hash: String,
}
/// Characters that act as separators in the canonical manifest encoding.
///
/// `|` separates fields and `\n` separates rows. A field containing any of
/// these would make the encoding ambiguous.
const SOURCE_MANIFEST_UNSAFE_CHARS: [char; 3] = ['|', '\n', '\r'];

/// Returns `true` when `value` can be embedded in a canonical manifest row.
///
/// That means it contains none of [`SOURCE_MANIFEST_UNSAFE_CHARS`] and no
/// Unicode control character. The empty string is considered safe.
fn is_source_manifest_field_safe(value: &str) -> bool {
    value.chars().all(|ch| {
        let is_delimiter = SOURCE_MANIFEST_UNSAFE_CHARS.contains(&ch);
        !is_delimiter && !ch.is_control()
    })
}
/// Checks a single canonical-encoding field of `source_files[entry_index]`.
///
/// # Errors
///
/// Returns an invalid-input error naming the entry index and `field_name`
/// when `value` contains a canonical delimiter or a control character.
fn validate_source_manifest_field(
    entry_index: usize,
    field_name: &'static str,
    value: &str,
) -> TrazaeoResult<()> {
    if !is_source_manifest_field_safe(value) {
        let detail = format!(
            "source_files[{entry_index}].{field_name} contains unsupported delimiter or control character"
        );
        return Err(TrazaeoError::invalid_input(
            "validate source manifest",
            detail,
        ));
    }
    Ok(())
}
/// Rejects entries whose fields would make the canonical encoding ambiguous.
///
/// Checks run entry by entry, in this order:
/// 1. `source_uri` and `content_hash` must contain no delimiter or control
///    character.
/// 2. `observed_mtime`, when present, must be non-empty, because an empty
///    value would encode exactly like `None`.
/// 3. A present `observed_mtime` must also pass the delimiter/control check.
///
/// # Errors
///
/// Returns the first invalid-input error encountered, naming the entry index
/// and the offending field.
pub fn validate_source_files_for_canonicalization(
    source_files: &[SourceFileEntry],
) -> TrazaeoResult<()> {
    for (index, file) in source_files.iter().enumerate() {
        validate_source_manifest_field(index, "source_uri", &file.source_uri)?;
        validate_source_manifest_field(index, "content_hash", &file.content_hash)?;
        match file.observed_mtime.as_deref() {
            None => {}
            Some("") => {
                return Err(TrazaeoError::invalid_input(
                    "validate source manifest",
                    format!("source_files[{index}].observed_mtime must not be empty"),
                ));
            }
            Some(mtime) => validate_source_manifest_field(index, "observed_mtime", mtime)?,
        }
    }
    Ok(())
}
pub fn canonical_source_manifest_bytes(source_files: &[SourceFileEntry]) -> Vec<u8> {
let mut normalized = source_files.to_vec();
normalized.sort_by(|a, b| {
a.source_uri
.cmp(&b.source_uri)
.then(a.content_hash.cmp(&b.content_hash))
.then(a.byte_length.cmp(&b.byte_length))
.then(a.observed_mtime.cmp(&b.observed_mtime))
});
let rows: Vec<String> = normalized
.iter()
.map(|entry| {
format!(
"{}|{}|{}|{}",
entry.source_uri,
entry.content_hash,
entry.byte_length,
entry.observed_mtime.clone().unwrap_or_default()
)
})
.collect();
rows.join("\n").into_bytes()
}
/// Returns the BLAKE3 digest of the canonical encoding of `source_files`.
///
/// Performs no validation. Inputs containing delimiter or control characters
/// may collide with other inputs; `compute_source_root_hash_checked` rejects
/// them first.
pub fn compute_source_root_hash(source_files: &[SourceFileEntry]) -> Hash {
    let digest = blake3::hash(&canonical_source_manifest_bytes(source_files));
    Hash(*digest.as_bytes())
}
/// Validates `source_files` for canonical encoding, then hashes them.
///
/// # Errors
///
/// Propagates the first error from
/// `validate_source_files_for_canonicalization`. No hash is computed in that
/// case.
pub fn compute_source_root_hash_checked(source_files: &[SourceFileEntry]) -> TrazaeoResult<Hash> {
    validate_source_files_for_canonicalization(source_files)
        .map(|()| compute_source_root_hash(source_files))
}
pub fn validate_source_manifest(manifest: &SourceManifest) -> TrazaeoResult<()> {
if manifest.source_file_count != manifest.source_files.len() {
return Err(TrazaeoError::invalid_input(
"validate source manifest",
"source_file_count does not match source_files length",
));
}
validate_source_files_for_canonicalization(&manifest.source_files)?;
let expected_root = hex::encode(compute_source_root_hash(&manifest.source_files).0);
if expected_root != manifest.source_root_hash {
return Err(TrazaeoError::invalid_input(
"validate source manifest",
"source_root_hash does not match canonical source_files",
));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`SourceFileEntry`] from borrowed parts to keep fixtures short.
    fn entry(
        source_uri: &str,
        content_hash: &str,
        byte_length: u64,
        observed_mtime: Option<&str>,
    ) -> SourceFileEntry {
        SourceFileEntry {
            source_uri: source_uri.to_string(),
            content_hash: content_hash.to_string(),
            byte_length,
            observed_mtime: observed_mtime.map(str::to_string),
        }
    }

    /// Wraps `source_files` in a manifest whose count and root hash come from
    /// the unchecked hash.
    ///
    /// Any validation failure can then only come from the entries themselves,
    /// or from a field the test mutates afterwards.
    fn manifest_for(source_files: Vec<SourceFileEntry>) -> SourceManifest {
        SourceManifest {
            manifest_id: "manifest-1".to_string(),
            manifest_created_at: "2026-01-01T00:00:00Z".to_string(),
            source_dataset_id: "dataset-1".to_string(),
            source_file_count: source_files.len(),
            source_root_hash: hex::encode(compute_source_root_hash(&source_files).0),
            source_files,
        }
    }

    /// Two valid entries, deliberately listed out of canonical order.
    fn files_variant_a() -> Vec<SourceFileEntry> {
        vec![
            entry("s3://bucket/b.nc", "h2", 2, Some("2026-01-02T00:00:00Z")),
            entry("s3://bucket/a.nc", "h1", 1, Some("2026-01-01T00:00:00Z")),
        ]
    }

    #[test]
    fn provenance_start_mode_round_trips_through_str() {
        for mode in [
            ProvenanceStartMode::SourceCapture,
            ProvenanceStartMode::TransportCapture,
            ProvenanceStartMode::DatasetBootstrap,
            ProvenanceStartMode::DatasetIncremental,
        ] {
            assert_eq!(ProvenanceStartMode::parse(mode.as_str()), Some(mode));
        }
        assert_eq!(ProvenanceStartMode::parse("SourceCapture"), None);
        assert_eq!(ProvenanceStartMode::parse(""), None);
    }

    #[test]
    fn canonical_source_manifest_is_order_stable() {
        let mut b = files_variant_a();
        b.reverse();
        assert_eq!(
            canonical_source_manifest_bytes(&files_variant_a()),
            canonical_source_manifest_bytes(&b)
        );
    }

    #[test]
    fn source_root_hash_is_order_stable() {
        let mut b = files_variant_a();
        b.reverse();
        assert_eq!(
            compute_source_root_hash(&files_variant_a()),
            compute_source_root_hash(&b)
        );
    }

    #[test]
    fn canonical_source_manifest_preserves_legacy_safe_encoding() {
        assert_eq!(
            canonical_source_manifest_bytes(&files_variant_a()),
            b"s3://bucket/a.nc|h1|1|2026-01-01T00:00:00Z\ns3://bucket/b.nc|h2|2|2026-01-02T00:00:00Z"
                .to_vec()
        );
    }

    #[test]
    fn canonical_source_manifest_breaks_ties_on_every_field() {
        let source_files = vec![
            entry("u", "h", 5, None),
            entry("u", "h", 3, Some("t")),
            entry("u", "a", 9, None),
        ];
        assert_eq!(
            canonical_source_manifest_bytes(&source_files),
            b"u|a|9|\nu|h|3|t\nu|h|5|".to_vec()
        );
    }

    #[test]
    fn canonical_source_manifest_of_empty_list_is_empty() {
        assert!(canonical_source_manifest_bytes(&[]).is_empty());
    }

    #[test]
    fn checked_source_root_hash_matches_unchecked_for_valid_files() {
        let source_files = files_variant_a();
        assert_eq!(
            compute_source_root_hash_checked(&source_files).ok(),
            Some(compute_source_root_hash(&source_files))
        );
    }

    #[test]
    fn source_manifest_validation_accepts_consistent_manifest_in_any_order() {
        let mut manifest = manifest_for(files_variant_a());
        assert!(validate_source_manifest(&manifest).is_ok());
        // The root hash is order-independent, so reordering must stay valid.
        manifest.source_files.reverse();
        assert!(validate_source_manifest(&manifest).is_ok());
    }

    #[test]
    fn source_manifest_validation_rejects_count_mismatch() {
        let mut manifest = manifest_for(files_variant_a());
        manifest.source_file_count += 1;
        let err = validate_source_manifest(&manifest).expect_err("count mismatch must fail");
        assert!(err.to_string().contains("source_file_count does not match"));
    }

    #[test]
    fn source_manifest_validation_rejects_root_hash_mismatch() {
        let mut manifest = manifest_for(files_variant_a());
        manifest.source_files[0].byte_length += 1;
        let err = validate_source_manifest(&manifest).expect_err("stale root hash must fail");
        assert!(err.to_string().contains("source_root_hash does not match"));
    }

    #[test]
    fn source_manifest_validation_rejects_ambiguous_canonical_fields() {
        // The mtime smuggles in a second canonical row.
        let manifest = manifest_for(vec![entry(
            "s3://bucket/a.nc",
            "h1",
            1,
            Some("2026-01-01T00:00:00Z\ns3://bucket/b.nc|h2|2|2026-01-02T00:00:00Z"),
        )]);
        let err = validate_source_manifest(&manifest).expect_err("ambiguous manifest must fail");
        assert!(err.to_string().contains("unsupported delimiter"));
    }

    #[test]
    fn checked_source_root_hash_rejects_ambiguous_canonical_fields() {
        let source_files = vec![entry("s3://bucket/a.nc|extra", "h1", 1, None)];
        let err = compute_source_root_hash_checked(&source_files)
            .expect_err("ambiguous source files must fail");
        assert!(err.to_string().contains("unsupported delimiter"));
    }

    #[test]
    fn checked_source_root_hash_rejects_control_characters() {
        let source_files = vec![entry("s3://bucket/a.nc", "h1\t", 1, None)];
        let err = compute_source_root_hash_checked(&source_files)
            .expect_err("control characters must fail");
        assert!(err.to_string().contains("unsupported delimiter"));
    }

    #[test]
    fn checked_source_root_hash_rejects_empty_observed_mtime() {
        let source_files = vec![entry("s3://bucket/a.nc", "h1", 1, Some(""))];
        let err =
            compute_source_root_hash_checked(&source_files).expect_err("empty mtime must fail");
        assert!(err.to_string().contains("observed_mtime must not be empty"));
    }

    #[test]
    fn source_manifest_validation_rejects_empty_observed_mtime() {
        let manifest = manifest_for(vec![entry("s3://bucket/a.nc", "h1", 1, Some(""))]);
        let err = validate_source_manifest(&manifest).expect_err("empty mtime must fail");
        assert!(err.to_string().contains("observed_mtime must not be empty"));
    }
}