use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use ed25519_dalek::{SigningKey, VerifyingKey};
use serde::{de::DeserializeOwned, Serialize};
use super::core_io::{
compress_payload, crc32, decompress_and_deserialize, parse_and_validate_header,
read_file_content, verify_file_checksum, verify_payload_boundary, verify_signed_flag,
};
use super::{Header, ModelType, SaveOptions, HEADER_SIZE, PUBLIC_KEY_SIZE, SIGNATURE_SIZE};
use crate::error::{AprenderError, Result};
#[allow(clippy::needless_pass_by_value)] pub fn save_signed<M: Serialize>(
model: &M,
model_type: ModelType,
path: impl AsRef<Path>,
options: SaveOptions,
signing_key: &SigningKey,
) -> Result<()> {
use ed25519_dalek::Signer;
let path = path.as_ref();
let payload_uncompressed = bincode::serialize(model)
.map_err(|e| AprenderError::Serialization(format!("Failed to serialize model: {e}")))?;
let (payload_compressed, compression) =
compress_payload(&payload_uncompressed, options.compression)?;
let metadata_bytes = rmp_serde::to_vec_named(&options.metadata)
.map_err(|e| AprenderError::Serialization(format!("Failed to serialize metadata: {e}")))?;
let mut header = Header::new(model_type);
header.compression = compression;
header.metadata_size = metadata_bytes.len() as u32;
header.payload_size = payload_compressed.len() as u32;
header.uncompressed_size = payload_uncompressed.len() as u32;
header.flags = header.flags.with_signed();
let mut signable_content = Vec::new();
signable_content.extend_from_slice(&header.to_bytes());
signable_content.extend_from_slice(&metadata_bytes);
signable_content.extend_from_slice(&payload_compressed);
let signature = signing_key.sign(&signable_content);
let verifying_key = signing_key.verifying_key();
let mut content = signable_content;
content.extend_from_slice(&signature.to_bytes()); content.extend_from_slice(verifying_key.as_bytes());
let checksum = crc32(&content);
content.extend_from_slice(&checksum.to_le_bytes());
let file = File::create(path)?;
let mut writer = BufWriter::new(file);
writer.write_all(&content)?;
writer.flush()?;
Ok(())
}
pub fn load_verified<M: DeserializeOwned>(
path: impl AsRef<Path>,
expected_type: ModelType,
trusted_key: Option<&VerifyingKey>,
) -> Result<M> {
let path = path.as_ref();
let content = read_file_content(path)?;
verify_signed_file_size(&content)?;
verify_file_checksum(&content)?;
let header = parse_and_validate_header(&content, expected_type)?;
verify_signed_flag(&header)?;
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let payload_end = metadata_end + header.payload_size as usize;
let signature_start = payload_end;
let pubkey_start = signature_start + SIGNATURE_SIZE;
let pubkey_end = pubkey_start + PUBLIC_KEY_SIZE;
verify_payload_boundary(pubkey_end, content.len())?;
let (signature, embedded_key) =
extract_signature_and_key(&content, signature_start, pubkey_start, pubkey_end)?;
let verifying_key = trusted_key.unwrap_or(&embedded_key);
let signable_content = &content[..payload_end];
verify_signature(verifying_key, signable_content, &signature)?;
decompress_and_deserialize(&content[metadata_end..payload_end], header.compression)
}
/// Rejects buffers too short to contain a header, the signature block (signature plus
/// public key) and the 4-byte checksum trailer.
///
/// # Errors
/// `FormatError` reporting the actual length when `content` is below that minimum.
fn verify_signed_file_size(content: &[u8]) -> Result<()> {
    // header + signature + embedded public key + 4-byte checksum
    const MIN_SIGNED_LEN: usize = HEADER_SIZE + SIGNATURE_SIZE + PUBLIC_KEY_SIZE + 4;

    let len = content.len();
    if len >= MIN_SIGNED_LEN {
        return Ok(());
    }
    Err(AprenderError::FormatError {
        message: format!("File too small for signed model: {len} bytes"),
    })
}
/// Reads the signature block out of `content`.
///
/// Returns the Ed25519 signature stored at `signature_start..pubkey_start` and the
/// verifying key stored at `pubkey_start..pubkey_end`.
///
/// # Errors
/// `FormatError` if either range lies outside `content` or has the wrong length, or if
/// the key bytes are rejected by `VerifyingKey::from_bytes`.
fn extract_signature_and_key(
    content: &[u8],
    signature_start: usize,
    pubkey_start: usize,
    pubkey_end: usize,
) -> Result<(ed25519_dalek::Signature, VerifyingKey)> {
    use ed25519_dalek::Signature;

    // `get` instead of indexing: an out-of-range or inverted range becomes a format
    // error rather than a panic on a malformed file.
    let signature_bytes: [u8; 64] = content
        .get(signature_start..pubkey_start)
        .and_then(|bytes| bytes.try_into().ok())
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid signature size".to_string(),
        })?;
    let signature = Signature::from_bytes(&signature_bytes);

    let pubkey_bytes: [u8; 32] = content
        .get(pubkey_start..pubkey_end)
        .and_then(|bytes| bytes.try_into().ok())
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid public key size".to_string(),
        })?;
    let embedded_key =
        VerifyingKey::from_bytes(&pubkey_bytes).map_err(|e| AprenderError::FormatError {
            message: format!("Invalid public key: {e}"),
        })?;

    Ok((signature, embedded_key))
}
/// Checks `signature` over `signable_content` using `verifying_key`.
///
/// # Errors
/// `SignatureInvalid`, whose `reason` includes the underlying Ed25519 error, when
/// verification fails.
fn verify_signature(
    verifying_key: &VerifyingKey,
    signable_content: &[u8],
    signature: &ed25519_dalek::Signature,
) -> Result<()> {
    use ed25519_dalek::Verifier;

    match verifying_key.verify(signable_content, signature) {
        Ok(()) => Ok(()),
        Err(err) => Err(AprenderError::SignatureInvalid {
            reason: format!("Signature verification failed: {err}"),
        }),
    }
}