use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use serde::{de::DeserializeOwned, Serialize};
use super::core_io::{
compress_payload, crc32, decompress_payload, parse_and_validate_header, read_file_content,
verify_encrypted_flag, verify_file_checksum, verify_payload_boundary,
};
use super::{
Header, ModelType, SaveOptions, HEADER_SIZE, HKDF_INFO, KEY_SIZE, NONCE_SIZE,
RECIPIENT_HASH_SIZE, SALT_SIZE, X25519_PUBLIC_KEY_SIZE,
};
use super::{X25519PublicKey, X25519SecretKey};
use crate::error::{AprenderError, Result};
use aes_gcm::aead::rand_core::RngCore;
/// Serialize `model`, encrypt it with a password-derived key, and write it to `path`.
///
/// The model is serialized with `bincode`, compressed according to
/// `options.compression`, and sealed with AES-256-GCM. The key is derived from
/// `password` by Argon2 (default parameters) using a fresh random salt, and a fresh
/// random nonce is drawn on every call.
///
/// File layout written:
/// `[header][metadata (MessagePack)][salt][nonce][ciphertext][crc32 (LE)]`,
/// where the CRC32 covers every byte that precedes it.
///
/// # Errors
///
/// - [`AprenderError::Serialization`] if the model or the metadata cannot be serialized.
/// - [`AprenderError::FormatError`] if the metadata, encrypted payload, or uncompressed
///   model size does not fit in the header's 32-bit size fields.
/// - [`AprenderError::Other`] if key derivation, cipher construction, or encryption fails.
/// - Any error from `compress_payload`, and errors from creating or writing the file.
#[cfg(feature = "format-encryption")]
// `options` is taken by value to keep the public signature stable; its fields are only
// read here, which is what this lint flags.
#[allow(clippy::needless_pass_by_value)]
pub fn save_encrypted<M: Serialize>(
    model: &M,
    model_type: ModelType,
    path: impl AsRef<Path>,
    options: SaveOptions,
    password: &str,
) -> Result<()> {
    use aes_gcm::{
        aead::{Aead, KeyInit},
        Aes256Gcm, Nonce,
    };
    use argon2::Argon2;

    // The header stores section sizes as u32. Reject oversized sections instead of
    // truncating them with `as`, which would silently write an unreadable file.
    let size_field = |len: usize, what: &str| -> Result<u32> {
        u32::try_from(len).map_err(|_| AprenderError::FormatError {
            message: format!("{what} too large for encrypted format: {len} bytes"),
        })
    };

    let path = path.as_ref();
    let payload_uncompressed = bincode::serialize(model)
        .map_err(|e| AprenderError::Serialization(format!("Failed to serialize model: {e}")))?;
    let (payload_compressed, compression) =
        compress_payload(&payload_uncompressed, options.compression)?;

    // A fresh salt makes the derived key unique per file. A fresh nonce ensures a
    // (key, nonce) pair is never reused with AES-GCM.
    let mut salt = [0u8; SALT_SIZE];
    let mut nonce_bytes = [0u8; NONCE_SIZE];
    aes_gcm::aead::OsRng.fill_bytes(&mut salt);
    aes_gcm::aead::OsRng.fill_bytes(&mut nonce_bytes);

    let mut key = [0u8; KEY_SIZE];
    Argon2::default()
        .hash_password_into(password.as_bytes(), &salt, &mut key)
        .map_err(|e| AprenderError::Other(format!("Key derivation failed: {e}")))?;
    let cipher = Aes256Gcm::new_from_slice(&key)
        .map_err(|e| AprenderError::Other(format!("Failed to create cipher: {e}")))?;
    let ciphertext = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), payload_compressed.as_ref())
        .map_err(|e| AprenderError::Other(format!("Encryption failed: {e}")))?;

    let metadata_bytes = rmp_serde::to_vec_named(&options.metadata)
        .map_err(|e| AprenderError::Serialization(format!("Failed to serialize metadata: {e}")))?;

    let payload_len = SALT_SIZE + NONCE_SIZE + ciphertext.len();
    let mut header = Header::new(model_type);
    header.compression = compression;
    header.metadata_size = size_field(metadata_bytes.len(), "Metadata")?;
    header.payload_size = size_field(payload_len, "Encrypted payload")?;
    header.uncompressed_size = size_field(payload_uncompressed.len(), "Model payload")?;
    header.flags = header.flags.with_encrypted();

    // Exact final size: header + metadata + payload + 4-byte CRC32 trailer.
    let mut content = Vec::with_capacity(HEADER_SIZE + metadata_bytes.len() + payload_len + 4);
    content.extend_from_slice(&header.to_bytes());
    content.extend_from_slice(&metadata_bytes);
    content.extend_from_slice(&salt);
    content.extend_from_slice(&nonce_bytes);
    content.extend_from_slice(&ciphertext);
    let checksum = crc32(&content);
    content.extend_from_slice(&checksum.to_le_bytes());

    let file = File::create(path)?;
    let mut writer = BufWriter::new(file);
    writer.write_all(&content)?;
    // Flush explicitly: dropping a BufWriter swallows write errors.
    writer.flush()?;
    Ok(())
}
/// Load a password-encrypted model from `path`.
///
/// Steps, in order:
/// 1. Check the minimum file size and the CRC32 checksum.
/// 2. Parse the header, which must match `expected_type`, and confirm its encrypted flag.
/// 3. Split the payload into salt, nonce and ciphertext.
/// 4. Re-derive the key from `password` and decrypt.
/// 5. Decompress and `bincode`-deserialize the result into `M`.
///
/// # Errors
///
/// - Format, checksum, header and read errors from the validation helpers.
/// - [`AprenderError::DecryptionFailed`] when the password is wrong or the data was altered.
/// - [`AprenderError::Serialization`] when the decrypted payload is not a valid `M`.
#[cfg(feature = "format-encryption")]
pub fn load_encrypted<M: DeserializeOwned>(
    path: impl AsRef<Path>,
    expected_type: ModelType,
    password: &str,
) -> Result<M> {
    let bytes = read_file_content(path.as_ref())?;
    verify_password_encrypted_file_size(&bytes)?;
    verify_file_checksum(&bytes)?;
    let header = parse_and_validate_header(&bytes, expected_type)?;
    verify_encrypted_flag(&header)?;

    // Payload layout after the metadata: [salt][nonce][ciphertext ..payload_stop].
    let payload_start = HEADER_SIZE + header.metadata_size as usize;
    let payload_stop = payload_start + header.payload_size as usize;
    verify_payload_boundary(payload_stop, bytes.len())?;
    let salt_stop = payload_start + SALT_SIZE;
    let nonce_stop = salt_stop + NONCE_SIZE;

    let (salt, nonce, sealed) = extract_password_encryption_components(
        &bytes,
        payload_start,
        salt_stop,
        nonce_stop,
        payload_stop,
    )?;
    let compressed = decrypt_password_payload(password, &salt, &nonce, sealed)?;
    let raw = decompress_payload(&compressed, header.compression)?;
    bincode::deserialize(&raw)
        .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
}
/// Reject buffers too short to hold a header, a salt, a nonce and the 4-byte CRC32
/// trailer.
///
/// This is a coarse lower bound only; it does not account for metadata or
/// ciphertext length.
#[cfg(feature = "format-encryption")]
fn verify_password_encrypted_file_size(content: &[u8]) -> Result<()> {
    let minimum_len = HEADER_SIZE + SALT_SIZE + NONCE_SIZE + 4;
    let actual_len = content.len();
    if actual_len >= minimum_len {
        return Ok(());
    }
    Err(AprenderError::FormatError {
        message: format!("File too small for encrypted model: {actual_len} bytes"),
    })
}
/// Split the encrypted payload of a password-protected file into its parts.
///
/// The payload starts at `metadata_end` and is laid out as
/// `[salt: SALT_SIZE][nonce: NONCE_SIZE][ciphertext ..payload_end]`. The caller
/// supplies the precomputed boundaries.
///
/// Returns the salt and nonce as fixed-size arrays and the ciphertext as a slice
/// borrowed from `content`.
///
/// # Errors
///
/// Returns [`AprenderError::FormatError`] if any range is reversed or runs past the end
/// of `content`. This happens, for example, when a header's `payload_size` is smaller
/// than `SALT_SIZE + NONCE_SIZE`. Without this check, an out-of-range slice would panic.
#[cfg(feature = "format-encryption")]
fn extract_password_encryption_components(
    content: &[u8],
    metadata_end: usize,
    salt_end: usize,
    nonce_end: usize,
    payload_end: usize,
) -> Result<([u8; SALT_SIZE], [u8; NONCE_SIZE], &[u8])> {
    // `get` yields `None` for reversed or out-of-bounds ranges, so a corrupted or
    // crafted header becomes an error rather than a slice-index panic. The CRC32
    // checksum does not protect against deliberately crafted files.
    let salt = content
        .get(metadata_end..salt_end)
        .and_then(|bytes| <[u8; SALT_SIZE]>::try_from(bytes).ok())
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid salt size".to_string(),
        })?;
    let nonce_bytes = content
        .get(salt_end..nonce_end)
        .and_then(|bytes| <[u8; NONCE_SIZE]>::try_from(bytes).ok())
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid nonce size".to_string(),
        })?;
    let ciphertext = content
        .get(nonce_end..payload_end)
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid encrypted payload bounds".to_string(),
        })?;
    Ok((salt, nonce_bytes, ciphertext))
}
/// Derive the AES-256-GCM key from `password` and `salt` with Argon2 (default
/// parameters), then authenticate and decrypt `ciphertext` under `nonce_bytes`.
///
/// Returns the decrypted plaintext bytes.
///
/// # Errors
///
/// - [`AprenderError::Other`] if key derivation or cipher construction fails.
/// - [`AprenderError::DecryptionFailed`] if authentication fails. AES-GCM cannot
///   distinguish a wrong password from tampered or corrupted data.
#[cfg(feature = "format-encryption")]
fn decrypt_password_payload(
    password: &str,
    salt: &[u8; SALT_SIZE],
    nonce_bytes: &[u8; NONCE_SIZE],
    ciphertext: &[u8],
) -> Result<Vec<u8>> {
    use aes_gcm::aead::{Aead, KeyInit};
    use aes_gcm::{Aes256Gcm, Nonce};
    use argon2::Argon2;

    let mut derived_key = [0u8; KEY_SIZE];
    let kdf = Argon2::default();
    if let Err(e) = kdf.hash_password_into(password.as_bytes(), salt, &mut derived_key) {
        return Err(AprenderError::Other(format!("Key derivation failed: {e}")));
    }

    let aead = match Aes256Gcm::new_from_slice(&derived_key) {
        Ok(aead) => aead,
        Err(e) => return Err(AprenderError::Other(format!("Failed to create cipher: {e}"))),
    };

    aead.decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
        .map_err(|_| AprenderError::DecryptionFailed {
            message: "Decryption failed (wrong password or corrupted data)".to_string(),
        })
}
/// Serialize `model`, encrypt it for the holder of `recipient_public_key`, and write it
/// to `path`.
///
/// Key agreement is ephemeral-static X25519:
/// - A fresh ephemeral key pair is generated on every call.
/// - Its shared secret with `recipient_public_key` is expanded with HKDF-SHA256
///   (no salt, info = `HKDF_INFO`) into an AES-256-GCM key.
/// - The compressed `bincode` payload is sealed under a fresh random nonce.
///
/// File layout written:
/// `[header][metadata (MessagePack)][ephemeral public key][recipient hash][nonce]`
/// `[ciphertext][crc32 (LE)]`, where the CRC32 covers every byte that precedes it.
///
/// The "recipient hash" is the first `RECIPIENT_HASH_SIZE` bytes of the recipient's
/// public key. It is presumably used by the loader to check which key a file was
/// addressed to.
///
/// # Errors
///
/// - [`AprenderError::Serialization`] if the model or the metadata cannot be serialized.
/// - [`AprenderError::FormatError`] if the metadata, encrypted payload, or uncompressed
///   model size does not fit in the header's 32-bit size fields.
/// - [`AprenderError::Other`] in these cases:
///   - the recipient key yields an all-zero (non-contributory) shared secret;
///   - HKDF expansion fails;
///   - cipher construction or encryption fails.
/// - Any error from `compress_payload`, and errors from creating or writing the file.
///
/// # Panics
///
/// Panics only if `RECIPIENT_HASH_SIZE` exceeds the public-key length, which would be a
/// bug in the format constants.
#[cfg(feature = "format-encryption")]
// `options` is taken by value to keep the public signature stable; its fields are only
// read here, which is what this lint flags.
#[allow(clippy::needless_pass_by_value)]
pub fn save_for_recipient<M: Serialize>(
    model: &M,
    model_type: ModelType,
    path: impl AsRef<Path>,
    options: SaveOptions,
    recipient_public_key: &X25519PublicKey,
) -> Result<()> {
    use aes_gcm::{
        aead::{Aead, KeyInit},
        Aes256Gcm, Nonce,
    };
    use hkdf::Hkdf;
    use sha2::Sha256;

    // The header stores section sizes as u32. Reject oversized sections instead of
    // truncating them with `as`, which would silently write an unreadable file.
    let size_field = |len: usize, what: &str| -> Result<u32> {
        u32::try_from(len).map_err(|_| AprenderError::FormatError {
            message: format!("{what} too large for encrypted format: {len} bytes"),
        })
    };

    let path = path.as_ref();
    let payload_uncompressed = bincode::serialize(model)
        .map_err(|e| AprenderError::Serialization(format!("Failed to serialize model: {e}")))?;
    let (payload_compressed, compression) =
        compress_payload(&payload_uncompressed, options.compression)?;

    let ephemeral_secret = X25519SecretKey::random_from_rng(&mut aes_gcm::aead::OsRng);
    let ephemeral_public = X25519PublicKey::from(&ephemeral_secret);
    let shared_secret = ephemeral_secret.diffie_hellman(recipient_public_key);
    // RFC 7748 §6.1: a low-order recipient key produces an all-zero shared secret.
    // The derived AES key, and therefore the plaintext, would then be recoverable by
    // anyone, so abort instead of writing a file that only looks encrypted.
    if shared_secret.as_bytes().iter().all(|&b| b == 0) {
        return Err(AprenderError::Other(
            "Invalid recipient public key (non-contributory X25519 shared secret)".to_string(),
        ));
    }

    let hkdf = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
    let mut key = [0u8; KEY_SIZE];
    hkdf.expand(HKDF_INFO, &mut key)
        .map_err(|_| AprenderError::Other("HKDF expansion failed".to_string()))?;

    let mut nonce_bytes = [0u8; NONCE_SIZE];
    aes_gcm::aead::OsRng.fill_bytes(&mut nonce_bytes);
    let cipher = Aes256Gcm::new_from_slice(&key)
        .map_err(|e| AprenderError::Other(format!("Failed to create cipher: {e}")))?;
    let ciphertext = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), payload_compressed.as_ref())
        .map_err(|e| AprenderError::Other(format!("Encryption failed: {e}")))?;

    // A prefix of the recipient's public key, not a cryptographic hash.
    let recipient_hash: [u8; RECIPIENT_HASH_SIZE] = recipient_public_key.as_bytes()
        [..RECIPIENT_HASH_SIZE]
        .try_into()
        .expect("recipient hash size is correct");

    let metadata_bytes = rmp_serde::to_vec_named(&options.metadata)
        .map_err(|e| AprenderError::Serialization(format!("Failed to serialize metadata: {e}")))?;

    let payload_len = X25519_PUBLIC_KEY_SIZE + RECIPIENT_HASH_SIZE + NONCE_SIZE + ciphertext.len();
    let mut header = Header::new(model_type);
    header.compression = compression;
    header.metadata_size = size_field(metadata_bytes.len(), "Metadata")?;
    header.payload_size = size_field(payload_len, "Encrypted payload")?;
    header.uncompressed_size = size_field(payload_uncompressed.len(), "Model payload")?;
    header.flags = header.flags.with_encrypted();

    // Exact final size: header + metadata + payload + 4-byte CRC32 trailer.
    let mut content = Vec::with_capacity(HEADER_SIZE + metadata_bytes.len() + payload_len + 4);
    content.extend_from_slice(&header.to_bytes());
    content.extend_from_slice(&metadata_bytes);
    content.extend_from_slice(ephemeral_public.as_bytes());
    content.extend_from_slice(&recipient_hash);
    content.extend_from_slice(&nonce_bytes);
    content.extend_from_slice(&ciphertext);
    let checksum = crc32(&content);
    content.extend_from_slice(&checksum.to_le_bytes());

    let file = File::create(path)?;
    let mut writer = BufWriter::new(file);
    writer.write_all(&content)?;
    // Flush explicitly: dropping a BufWriter swallows write errors.
    writer.flush()?;
    Ok(())
}
/// Load a model that was encrypted for a specific X25519 recipient.
///
/// Steps, in order:
/// 1. Check the minimum file size and the CRC32 checksum.
/// 2. Parse the header, which must match `expected_type`, and confirm its encrypted flag.
/// 3. Read the ephemeral public key and stored recipient hash, and check them against
///    `recipient_secret_key`.
/// 4. Decrypt, decompress and `bincode`-deserialize the payload.
///
/// Payload layout after the metadata:
/// `[ephemeral public key][recipient hash][nonce][ciphertext ..payload end]`.
///
/// # Errors
///
/// Propagates errors from every validation, extraction and decryption helper.
/// Returns [`AprenderError::Serialization`] if the decrypted payload is not a valid `M`.
#[cfg(feature = "format-encryption")]
pub fn load_as_recipient<M: DeserializeOwned>(
    path: impl AsRef<Path>,
    expected_type: ModelType,
    recipient_secret_key: &X25519SecretKey,
) -> Result<M> {
    let bytes = read_file_content(path.as_ref())?;
    verify_x25519_encrypted_file_size(&bytes)?;
    verify_file_checksum(&bytes)?;
    let header = parse_and_validate_header(&bytes, expected_type)?;
    verify_encrypted_flag(&header)?;

    let start = HEADER_SIZE + header.metadata_size as usize;
    let end = start + header.payload_size as usize;
    verify_payload_boundary(end, bytes.len())?;
    let key_end = start + X25519_PUBLIC_KEY_SIZE;
    let hash_end = key_end + RECIPIENT_HASH_SIZE;
    let nonce_end = hash_end + NONCE_SIZE;

    let (ephemeral_public, stored_hash) =
        extract_x25519_recipient_info(&bytes, start, key_end, hash_end)?;
    // Presumably rejects files addressed to a different key before any decryption is
    // attempted. TODO(review): confirm against `verify_recipient`.
    verify_recipient(recipient_secret_key, stored_hash)?;

    let (nonce, sealed) = extract_nonce_and_ciphertext(&bytes, hash_end, nonce_end, end)?;
    let compressed =
        decrypt_x25519_payload(recipient_secret_key, &ephemeral_public, &nonce, sealed)?;
    let raw = decompress_payload(&compressed, header.compression)?;
    bincode::deserialize(&raw)
        .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
}
/// Reject buffers too short to hold the fixed parts of an X25519-encrypted file.
///
/// The fixed parts are a header, an ephemeral public key, a recipient hash, a nonce and
/// the 4-byte CRC32 trailer. This is a coarse lower bound only; it does not account for
/// metadata or ciphertext length.
#[cfg(feature = "format-encryption")]
fn verify_x25519_encrypted_file_size(content: &[u8]) -> Result<()> {
    const MIN_FILE_SIZE: usize =
        HEADER_SIZE + X25519_PUBLIC_KEY_SIZE + RECIPIENT_HASH_SIZE + NONCE_SIZE + 4;
    match content.len() {
        len if len < MIN_FILE_SIZE => Err(AprenderError::FormatError {
            message: format!("File too small for X25519 encrypted model: {len} bytes"),
        }),
        _ => Ok(()),
    }
}
/// Read the ephemeral public key and the stored recipient hash from an X25519-encrypted
/// payload.
///
/// The payload starts at `metadata_end` with
/// `[ephemeral public key: X25519_PUBLIC_KEY_SIZE][recipient hash: RECIPIENT_HASH_SIZE]`.
/// The caller supplies the precomputed boundaries.
///
/// # Errors
///
/// Returns [`AprenderError::FormatError`] if either range is reversed or runs past the
/// end of `content`. This happens, for example, with a crafted header whose
/// `payload_size` is too small. Without this check, an out-of-range slice would panic.
#[cfg(feature = "format-encryption")]
fn extract_x25519_recipient_info(
    content: &[u8],
    metadata_end: usize,
    ephemeral_pub_end: usize,
    recipient_hash_end: usize,
) -> Result<(X25519PublicKey, [u8; RECIPIENT_HASH_SIZE])> {
    // `get` yields `None` for reversed or out-of-bounds ranges, turning malformed input
    // into an error instead of a slice-index panic.
    let ephemeral_pub_bytes = content
        .get(metadata_end..ephemeral_pub_end)
        .and_then(|bytes| <[u8; X25519_PUBLIC_KEY_SIZE]>::try_from(bytes).ok())
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid ephemeral public key size".to_string(),
        })?;
    let stored_recipient_hash = content
        .get(ephemeral_pub_end..recipient_hash_end)
        .and_then(|bytes| <[u8; RECIPIENT_HASH_SIZE]>::try_from(bytes).ok())
        .ok_or_else(|| AprenderError::FormatError {
            message: "Invalid recipient hash size".to_string(),
        })?;
    Ok((X25519PublicKey::from(ephemeral_pub_bytes), stored_recipient_hash))
}
// Textually includes `encrypted.rs`, resolved relative to this file's directory.
// Several helpers called above are neither defined nor imported in this file and
// presumably come from it: `verify_recipient`, `extract_nonce_and_ciphertext` and
// `decrypt_x25519_payload`.
include!("encrypted.rs");