use super::{AprV2Reader, AprV2Writer, V2FormatError};
/// A set of optional provenance overrides for an APR v2 file's metadata.
///
/// Every field is independent: `Some(value)` replaces the matching metadata
/// field when the patch is stamped, while `None` leaves the existing value in
/// the file untouched.
#[derive(Debug, Clone, Default)]
pub struct ProvenancePatch {
    /// Replacement for the metadata `license` field.
    pub license: Option<String>,
    /// Replacement for the metadata `data_source` field.
    pub data_source: Option<String>,
    /// Replacement for the metadata `data_license` field.
    pub data_license: Option<String>,
}

impl ProvenancePatch {
    /// Reports whether this patch would change anything.
    ///
    /// Returns `true` as soon as any one of the three fields holds a value,
    /// and `false` only for a completely empty (default) patch.
    #[must_use]
    pub fn has_any(&self) -> bool {
        [&self.license, &self.data_source, &self.data_license]
            .iter()
            .any(|field| field.is_some())
    }
}
pub fn stamp_provenance_bytes(
input: &[u8],
patch: &ProvenancePatch,
) -> Result<Vec<u8>, V2FormatError> {
if !patch.has_any() {
return Err(V2FormatError::InvalidHeader(
"stamp_provenance_bytes: patch has no fields set — \
refusing to rewrite without changes"
.to_string(),
));
}
let reader = AprV2Reader::from_bytes(input)?;
let original_flags = reader.header().flags;
let mut new_metadata = reader.metadata().clone();
if let Some(ref lic) = patch.license {
new_metadata.license = Some(lic.clone());
}
if let Some(ref ds) = patch.data_source {
new_metadata.data_source = Some(ds.clone());
}
if let Some(ref dl) = patch.data_license {
new_metadata.data_license = Some(dl.clone());
}
let mut writer = AprV2Writer::new(new_metadata);
writer.set_header_flags(original_flags);
for name in reader.tensor_names() {
let entry = reader
.get_tensor(name)
.ok_or_else(|| V2FormatError::InvalidHeader(format!("tensor {name} vanished")))?;
let data = reader
.get_tensor_data(name)
.ok_or_else(|| V2FormatError::InvalidHeader(format!("tensor {name} has no data")))?;
writer.add_tensor(
name.to_string(),
entry.dtype,
entry.shape.clone(),
data.to_vec(),
);
}
writer.write()
}
#[cfg(test)]
mod tests {
    use super::super::{AprV2Flags, AprV2Metadata, TensorDType};
    use super::*;

    /// Builds a one-tensor APR v2 buffer carrying the given header flags.
    ///
    /// The `weight` tensor holds a non-repeating byte ramp (0, 1, ..., 23)
    /// instead of all zeros. With that payload, a byte-for-byte comparison
    /// after stamping can detect a rewrite that zeroes, reorders or truncates
    /// tensor data.
    fn minimal_apr_with_flags(flags: u16) -> Vec<u8> {
        let metadata = AprV2Metadata::new("stamp-test");
        let mut writer = AprV2Writer::new(metadata);
        writer.set_header_flags(AprV2Flags::from_bits(flags));
        // Shape [2, 3] of F32 = 6 elements x 4 bytes = 24 bytes.
        writer.add_tensor(
            "weight",
            TensorDType::F32,
            vec![2, 3],
            (0u8..24).collect::<Vec<u8>>(),
        );
        writer.write().expect("write test apr")
    }

    #[test]
    fn has_any_reports_any_populated_field() {
        assert!(!ProvenancePatch::default().has_any());
        assert!(ProvenancePatch {
            data_license: Some("MIT".into()),
            ..Default::default()
        }
        .has_any());
    }

    #[test]
    fn stamp_populates_all_three_fields_when_source_is_unpopulated() {
        let input = minimal_apr_with_flags(0);
        let patch = ProvenancePatch {
            license: Some("Apache-2.0".into()),
            data_source: Some("huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct".into()),
            data_license: Some("Qwen-License-Agreement-v1".into()),
        };
        let output = stamp_provenance_bytes(&input, &patch).expect("stamp must succeed");
        let reader = AprV2Reader::from_bytes(&output).expect("stamped buffer must parse");
        let md = reader.metadata();
        assert_eq!(md.license.as_deref(), Some("Apache-2.0"));
        assert_eq!(
            md.data_source.as_deref(),
            Some("huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct")
        );
        assert_eq!(
            md.data_license.as_deref(),
            Some("Qwen-License-Agreement-v1")
        );
    }

    #[test]
    fn stamp_preserves_tensor_data_byte_for_byte() {
        let input = minimal_apr_with_flags(0);
        let input_reader = AprV2Reader::from_bytes(&input).unwrap();
        let original_bytes: Vec<u8> = input_reader
            .get_tensor_data("weight")
            .expect("input has weight")
            .to_vec();
        let patch = ProvenancePatch {
            license: Some("MIT".into()),
            ..Default::default()
        };
        let output = stamp_provenance_bytes(&input, &patch).unwrap();
        let out_reader = AprV2Reader::from_bytes(&output).unwrap();
        let round_tripped = out_reader
            .get_tensor_data("weight")
            .expect("output has weight");
        assert_eq!(
            original_bytes.as_slice(),
            round_tripped,
            "tensor bytes must survive stamp verbatim"
        );
    }

    #[test]
    fn stamp_preserves_header_flags() {
        let flags = AprV2Flags::QUANTIZED | AprV2Flags::HAS_VOCAB;
        let input = minimal_apr_with_flags(flags);
        let in_reader = AprV2Reader::from_bytes(&input).unwrap();
        assert!(in_reader.header().flags.contains(AprV2Flags::QUANTIZED));
        assert!(in_reader.header().flags.contains(AprV2Flags::HAS_VOCAB));
        let patch = ProvenancePatch {
            license: Some("Apache-2.0".into()),
            ..Default::default()
        };
        let output = stamp_provenance_bytes(&input, &patch).unwrap();
        let out_reader = AprV2Reader::from_bytes(&output).unwrap();
        assert!(
            out_reader.header().flags.contains(AprV2Flags::QUANTIZED),
            "QUANTIZED flag must survive stamp"
        );
        assert!(
            out_reader.header().flags.contains(AprV2Flags::HAS_VOCAB),
            "HAS_VOCAB flag must survive stamp"
        );
        assert!(
            out_reader
                .header()
                .flags
                .contains(AprV2Flags::LAYOUT_ROW_MAJOR),
            "LAYOUT_ROW_MAJOR must always be set"
        );
    }

    #[test]
    fn stamp_rejects_empty_patch() {
        let input = minimal_apr_with_flags(0);
        let empty = ProvenancePatch::default();
        let err = stamp_provenance_bytes(&input, &empty).unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.contains("patch has no fields"),
            "empty-patch error must be explicit: {msg}"
        );
    }

    #[test]
    fn stamp_allows_partial_patch_leaving_other_fields_unchanged() {
        let mut md = AprV2Metadata::new("partial-test");
        md.license = Some("Apache-2.0".into());
        let mut writer = AprV2Writer::new(md);
        writer.add_tensor("w", TensorDType::F32, vec![4], vec![0u8; 16]);
        let input = writer.write().unwrap();
        let patch = ProvenancePatch {
            data_source: Some("teacher-only".into()),
            ..Default::default()
        };
        let output = stamp_provenance_bytes(&input, &patch).unwrap();
        let out_reader = AprV2Reader::from_bytes(&output).unwrap();
        assert_eq!(
            out_reader.metadata().license.as_deref(),
            Some("Apache-2.0"),
            "unchanged license must survive"
        );
        assert_eq!(
            out_reader.metadata().data_source.as_deref(),
            Some("teacher-only"),
            "patched data_source must land"
        );
        assert!(
            out_reader.metadata().data_license.is_none(),
            "untouched data_license must remain None"
        );
    }

    #[test]
    fn stamp_overwrites_previously_populated_field() {
        // The partial-patch test only fills fields that start out as None;
        // this pins that a Some in the patch replaces an existing value.
        let mut md = AprV2Metadata::new("overwrite-test");
        md.license = Some("MIT".into());
        let mut writer = AprV2Writer::new(md);
        writer.add_tensor("w", TensorDType::F32, vec![4], vec![0u8; 16]);
        let input = writer.write().unwrap();
        let patch = ProvenancePatch {
            license: Some("Apache-2.0".into()),
            ..Default::default()
        };
        let output = stamp_provenance_bytes(&input, &patch).unwrap();
        let out_reader = AprV2Reader::from_bytes(&output).unwrap();
        assert_eq!(
            out_reader.metadata().license.as_deref(),
            Some("Apache-2.0"),
            "patched license must replace the previous value"
        );
    }

    #[test]
    fn stamp_is_idempotent_under_identical_patch() {
        let input = minimal_apr_with_flags(0);
        let patch = ProvenancePatch {
            license: Some("Apache-2.0".into()),
            data_source: Some("teacher-only".into()),
            data_license: Some("Apache-2.0".into()),
        };
        let first = stamp_provenance_bytes(&input, &patch).unwrap();
        let second = stamp_provenance_bytes(&first, &patch).unwrap();
        assert_eq!(
            first, second,
            "applying the same patch twice must be byte-identical (idempotent)"
        );
    }
}