/// Round-trips a larger, highly repetitive payload through password-based
/// encryption (`save_encrypted` / `load_encrypted`) and checks that the
/// decrypted model equals the original.
///
/// NOTE(review): despite the name, this test passes `SaveOptions::default()`
/// and never enables compression explicitly. Confirm that the defaults compress,
/// or opt in to compression here so the test exercises what its name says.
#[test]
#[cfg(feature = "format-encryption")]
fn test_encrypted_with_compression() {
    use tempfile::tempdir;

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct LargeModel {
        data: Vec<f32>,
    }

    // Values cycle through 0..10, so the 1000-element payload is very repetitive.
    let original = LargeModel {
        data: (0..1000).map(|i| (i % 10) as f32).collect(),
    };
    let secret = "compress_and_encrypt";

    let workdir = tempdir().expect("create temp dir");
    let file = workdir.path().join("encrypted_compressed.apr");

    save_encrypted(&original, ModelType::Custom, &file, SaveOptions::default(), secret)
        .expect("save_encrypted should succeed");

    let roundtripped: LargeModel =
        load_encrypted(&file, ModelType::Custom, secret).expect("load_encrypted should succeed");
    assert_eq!(roundtripped, original);
}
/// Tests the X25519 recipient round trip:
/// - encrypts a model to a recipient public key with `save_for_recipient`;
/// - checks that `inspect` (which takes no key) reports the file as encrypted
///   with the expected model type;
/// - decrypts with the matching secret key via `load_as_recipient` and compares
///   the result against the original.
#[test]
#[cfg(feature = "format-encryption")]
fn test_x25519_recipient_roundtrip() {
    use tempfile::tempdir;
    use x25519_dalek::{PublicKey, StaticSecret};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestModel {
        weights: Vec<f32>,
        bias: f32,
    }

    let original = TestModel {
        weights: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        bias: 0.5,
    };

    // Fresh recipient keypair per run; only the public half is used when saving.
    let secret_key = StaticSecret::random_from_rng(aes_gcm::aead::OsRng);
    let public_key = PublicKey::from(&secret_key);

    let workdir = tempdir().expect("create temp dir");
    let file = workdir.path().join("recipient_encrypted.apr");

    save_for_recipient(
        &original,
        ModelType::Custom,
        &file,
        SaveOptions::default(),
        &public_key,
    )
    .expect("save_for_recipient should succeed");

    let summary = inspect(&file).expect("inspect should succeed");
    assert!(summary.encrypted, "Model should be marked as encrypted");
    assert_eq!(summary.model_type, ModelType::Custom);

    let decrypted: TestModel = load_as_recipient(&file, ModelType::Custom, &secret_key)
        .expect("load_as_recipient should succeed");
    assert_eq!(decrypted, original);
}
/// Decrypting an X25519-recipient file with a secret key that does not
/// correspond to the public key it was saved for must fail with an error
/// mentioning "Decryption failed".
#[test]
#[cfg(feature = "format-encryption")]
fn test_x25519_wrong_key_fails() {
    use tempfile::tempdir;
    use x25519_dalek::{PublicKey, StaticSecret};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestModel {
        value: i32,
    }

    let payload = TestModel { value: 42 };

    let intended_secret = StaticSecret::random_from_rng(aes_gcm::aead::OsRng);
    let intended_public = PublicKey::from(&intended_secret);
    // Independently generated, so it is not the secret behind `intended_public`.
    let impostor_secret = StaticSecret::random_from_rng(aes_gcm::aead::OsRng);

    let workdir = tempdir().expect("create temp dir");
    let file = workdir.path().join("x25519_wrong_key.apr");

    save_for_recipient(
        &payload,
        ModelType::Custom,
        &file,
        SaveOptions::default(),
        &intended_public,
    )
    .expect("save_for_recipient should succeed");

    let outcome: Result<TestModel> = load_as_recipient(&file, ModelType::Custom, &impostor_secret);
    assert!(outcome.is_err());

    let reason = outcome.unwrap_err().to_string();
    assert!(reason.contains("Decryption failed"));
}
/// A file written with password-based `save_encrypted` must not be loadable
/// through the X25519 `load_as_recipient` path, even when given an arbitrary
/// secret key.
#[test]
#[cfg(feature = "format-encryption")]
fn test_x25519_rejects_password_encrypted_file() {
    use tempfile::tempdir;
    use x25519_dalek::StaticSecret;

    #[derive(Debug, Serialize, Deserialize)]
    struct TestModel {
        value: i32,
    }

    let workdir = tempdir().expect("create temp dir");
    let file = workdir.path().join("password_not_x25519.apr");

    let payload = TestModel { value: 42 };
    save_encrypted(&payload, ModelType::Custom, &file, SaveOptions::default(), "some_password")
        .expect("save_encrypted should succeed");

    let arbitrary_key = StaticSecret::random_from_rng(aes_gcm::aead::OsRng);
    let outcome: Result<TestModel> = load_as_recipient(&file, ModelType::Custom, &arbitrary_key);
    assert!(outcome.is_err());
}
/// `load_as_recipient` must refuse a file saved with plain `save`.
/// The error may come from either the ENCRYPTED-flag check or the file-size
/// check; both are accepted.
#[test]
#[cfg(feature = "format-encryption")]
fn test_x25519_load_rejects_unencrypted_file() {
    use tempfile::tempdir;
    use x25519_dalek::StaticSecret;

    #[derive(Debug, Serialize, Deserialize)]
    struct TestModel {
        value: i32,
    }

    let workdir = tempdir().expect("create temp dir");
    let file = workdir.path().join("unencrypted_for_x25519.apr");

    let payload = TestModel { value: 42 };
    save(&payload, ModelType::Custom, &file, SaveOptions::default())
        .expect("save should succeed");

    let key = StaticSecret::random_from_rng(aes_gcm::aead::OsRng);
    let outcome: Result<TestModel> = load_as_recipient(&file, ModelType::Custom, &key);
    assert!(outcome.is_err());

    let err_msg = outcome.unwrap_err().to_string();
    let recognised = ["ENCRYPTED flag not set", "File too small"]
        .iter()
        .any(|needle| err_msg.contains(*needle));
    assert!(
        recognised,
        "Expected ENCRYPTED flag error or size error, got: {err_msg}"
    );
}
/// Records a teacher identifier ("abc123") in the free-form `description`
/// metadata and checks that it survives a `save` / `inspect` round trip.
///
/// NOTE(review): this exercises the `description` field only, not a dedicated
/// teacher-hash field.
#[test]
fn test_distillation_teacher_hash() {
    use tempfile::tempdir;

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestModel {
        value: i32,
    }

    let workdir = tempdir().expect("create temp dir");
    let file = workdir.path().join("distilled.apr");

    let opts = SaveOptions::default()
        .with_name("student_model")
        .with_description("Distilled from teacher abc123");
    let payload = TestModel { value: 42 };
    save(&payload, ModelType::Custom, &file, opts).expect("save should succeed");

    let info = inspect(&file).expect("inspect should succeed");
    let description = info
        .metadata
        .description
        .as_deref()
        .expect("description should be set");
    assert!(description.contains("abc123"));
}
/// Checks that two metadata fields each persist through `save` / `inspect`:
/// - `description`, set via the builder;
/// - `distillation`, set directly on `SaveOptions::metadata`.
#[test]
fn test_distillation_dedicated_field() {
    use tempfile::tempdir;

    #[derive(Debug, Serialize, Deserialize)]
    struct TestModel {
        value: i32,
    }

    let model = TestModel { value: 123 };
    let workdir = tempdir().expect("create temp dir");

    // 1) `description` via the builder method.
    let described = workdir.path().join("distilled2.apr");
    let described_opts = SaveOptions::default().with_description("test description");
    save(&model, ModelType::Custom, &described, described_opts).expect("save should succeed");
    let described_info = inspect(&described).expect("inspect should succeed");
    assert_eq!(
        described_info.metadata.description.as_deref(),
        Some("test description")
    );

    // 2) `distillation` assigned directly on the options' metadata.
    let mut tagged_opts = SaveOptions::default();
    tagged_opts.metadata.distillation = Some("teacher_abc123".to_string());
    let tagged = workdir.path().join("distilled2b.apr");
    save(&model, ModelType::Custom, &tagged, tagged_opts).expect("save should succeed");
    let tagged_info = inspect(&tagged).expect("inspect should succeed");
    assert_eq!(
        tagged_info.metadata.distillation.as_deref(),
        Some("teacher_abc123")
    );
}
/// `Metadata` must survive a MessagePack encode (`rmp_serde::to_vec_named`)
/// and decode (`rmp_serde::from_slice`) with its `description` and
/// `distillation` fields intact.
#[test]
fn test_metadata_msgpack_roundtrip() {
    let original = Metadata {
        description: Some("test description".to_string()),
        distillation: Some("teacher_abc123".to_string()),
        ..Default::default()
    };

    let encoded = rmp_serde::to_vec_named(&original).expect("serialize");
    let decoded: Metadata = rmp_serde::from_slice(&encoded).expect("deserialize");

    assert_eq!(
        (&decoded.description, &decoded.distillation),
        (&original.description, &original.distillation)
    );
}
/// Round-trips a standard-distillation `DistillationInfo` through `save` /
/// `inspect` and checks that every field comes back unchanged, including the
/// fields deliberately left as `None`.
#[test]
fn test_distillation_info_struct() {
    use tempfile::tempdir;

    #[derive(Debug, Serialize, Deserialize)]
    struct TestModel {
        value: i32,
    }

    let model = TestModel { value: 42 };
    let dir = tempdir().expect("create temp dir");
    let path = dir.path().join("distilled3.apr");

    let distill_info = DistillationInfo {
        method: DistillMethod::Standard,
        teacher: TeacherProvenance {
            hash: "sha256:abc123def456".to_string(),
            signature: None,
            model_type: ModelType::NeuralSequential,
            param_count: 7_000_000_000,
            ensemble_teachers: None,
        },
        params: DistillationParams {
            temperature: 3.0,
            alpha: 0.7,
            beta: None,
            epochs: 10,
            final_loss: Some(0.42),
        },
        layer_mapping: None,
    };
    let options = SaveOptions::default().with_distillation_info(distill_info);
    save(&model, ModelType::Custom, &path, options).expect("save should succeed");

    let info = inspect(&path).expect("inspect should succeed");
    let restored = info
        .metadata
        .distillation_info
        .expect("should have distillation_info");

    // Method and teacher provenance.
    assert!(matches!(restored.method, DistillMethod::Standard));
    assert_eq!(restored.teacher.hash, "sha256:abc123def456");
    assert!(restored.teacher.signature.is_none());
    assert_eq!(restored.teacher.model_type, ModelType::NeuralSequential);
    assert_eq!(restored.teacher.param_count, 7_000_000_000);
    assert!(restored.teacher.ensemble_teachers.is_none());

    // Hyper-parameters. Floats are compared against an epsilon rather than
    // with `==` to avoid a strict float-equality assertion.
    assert!((restored.params.temperature - 3.0).abs() < f32::EPSILON);
    assert!((restored.params.alpha - 0.7).abs() < f32::EPSILON);
    assert!(restored.params.beta.is_none());
    assert_eq!(restored.params.epochs, 10);
    assert!(
        (restored.params.final_loss.expect("should have final_loss") - 0.42).abs() < f32::EPSILON
    );

    // No layer mapping was provided for standard distillation.
    assert!(restored.layer_mapping.is_none());
}
/// Round-trips a progressive-distillation `DistillationInfo` (signed teacher,
/// `beta` set, four-entry layer mapping) through `save` / `inspect`.
/// It checks that every field, and every layer-mapping entry, comes back
/// unchanged.
#[test]
fn test_distillation_progressive_with_layer_mapping() {
    use tempfile::tempdir;

    #[derive(Debug, Serialize, Deserialize)]
    struct TestModel {
        value: i32,
    }

    let model = TestModel { value: 42 };
    let dir = tempdir().expect("create temp dir");
    let path = dir.path().join("progressive.apr");

    // Single source of truth: (student_layer, teacher_layer, weight).
    // It is used both to build the mapping and to verify it after the round trip.
    let expected_mapping = [(0, 0, 0.5), (1, 2, 0.3), (2, 5, 0.15), (3, 7, 0.05)];
    let layer_mapping: Vec<LayerMapping> = expected_mapping
        .iter()
        .map(|&(student_layer, teacher_layer, weight)| LayerMapping {
            student_layer,
            teacher_layer,
            weight,
        })
        .collect();

    let distill_info = DistillationInfo {
        method: DistillMethod::Progressive,
        teacher: TeacherProvenance {
            hash: "sha256:teacher_8layer".to_string(),
            signature: Some("sig_abc123".to_string()),
            model_type: ModelType::NeuralSequential,
            param_count: 1_000_000_000,
            ensemble_teachers: None,
        },
        params: DistillationParams {
            temperature: 4.0,
            alpha: 0.8,
            beta: Some(0.5),
            epochs: 20,
            final_loss: Some(0.31),
        },
        layer_mapping: Some(layer_mapping),
    };
    let options = SaveOptions::default().with_distillation_info(distill_info);
    save(&model, ModelType::Custom, &path, options).expect("save should succeed");

    let info = inspect(&path).expect("inspect should succeed");
    let restored = info
        .metadata
        .distillation_info
        .expect("should have distillation_info");

    // Method and teacher provenance.
    assert!(matches!(restored.method, DistillMethod::Progressive));
    assert_eq!(restored.teacher.hash, "sha256:teacher_8layer");
    assert_eq!(restored.teacher.signature, Some("sig_abc123".to_string()));
    assert_eq!(restored.teacher.model_type, ModelType::NeuralSequential);
    assert_eq!(restored.teacher.param_count, 1_000_000_000);
    assert!(restored.teacher.ensemble_teachers.is_none());

    // Hyper-parameters (floats compared against an epsilon).
    assert!((restored.params.temperature - 4.0).abs() < f32::EPSILON);
    assert!((restored.params.alpha - 0.8).abs() < f32::EPSILON);
    assert!((restored.params.beta.expect("should have beta") - 0.5).abs() < f32::EPSILON);
    assert_eq!(restored.params.epochs, 20);
    assert!(
        (restored.params.final_loss.expect("should have final_loss") - 0.31).abs() < f32::EPSILON
    );

    // Every layer-mapping entry, in order. The length is checked first
    // because `zip` would otherwise silently stop at the shorter side.
    let mapping = restored.layer_mapping.expect("should have layer_mapping");
    assert_eq!(mapping.len(), expected_mapping.len());
    for (got, &(student_layer, teacher_layer, weight)) in mapping.iter().zip(expected_mapping.iter())
    {
        assert_eq!(got.student_layer, student_layer);
        assert_eq!(got.teacher_layer, teacher_layer);
        assert!((got.weight - weight).abs() < f32::EPSILON);
    }
}