#[cfg(feature = "encryption")]
use aes_gcm::{
    aead::{Aead, AeadCore, KeyInit, OsRng},
    Aes256Gcm, Nonce,
};
use oxionnx_core::OnnxError;
use std::path::Path;
/// Reads the model file at `input_path`, encrypts its bytes with [`encrypt_bytes`]
/// under `key`, and writes the nonce-prefixed ciphertext to `output_path`.
///
/// # Errors
///
/// Returns [`OnnxError::Parse`] if `input_path` cannot be read, and
/// [`OnnxError::Internal`] if encryption fails or `output_path` cannot be written.
#[cfg(feature = "encryption")]
pub fn encrypt_model(
    input_path: &Path,
    output_path: &Path,
    key: &[u8; 32],
) -> Result<(), OnnxError> {
    let model_bytes = match std::fs::read(input_path) {
        Ok(bytes) => bytes,
        Err(err) => return Err(OnnxError::Parse(format!("Cannot read model file: {}", err))),
    };
    let sealed = encrypt_bytes(&model_bytes, key)?;
    std::fs::write(output_path, sealed)
        .map_err(|err| OnnxError::Internal(format!("Cannot write encrypted file: {}", err)))
}
/// Encrypts `plaintext` with AES-256-GCM under `key`.
///
/// The returned buffer is laid out as `nonce (12 bytes) || ciphertext || tag`, which is
/// the format [`decrypt_bytes`] consumes.
///
/// Every call draws a fresh 96-bit nonce from the operating system's CSPRNG. GCM loses
/// both confidentiality and authenticity if a nonce is ever repeated under the same key.
/// The nonce is therefore not derived from the clock or the plaintext length, because
/// either can collide across calls.
///
/// # Errors
///
/// Returns [`OnnxError::Internal`] if the cipher cannot be constructed or encryption fails.
#[cfg(feature = "encryption")]
pub fn encrypt_bytes(plaintext: &[u8], key: &[u8; 32]) -> Result<Vec<u8>, OnnxError> {
    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|e| OnnxError::Internal(format!("Failed to create cipher: {}", e)))?;
    // Random 96-bit nonces: NIST SP 800-38D bounds this to ~2^32 messages per key.
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
    let ciphertext = cipher
        .encrypt(&nonce, plaintext)
        .map_err(|e| OnnxError::Internal(format!("Encryption failed: {}", e)))?;
    let mut output = Vec::with_capacity(nonce.len() + ciphertext.len());
    output.extend_from_slice(&nonce);
    output.extend_from_slice(&ciphertext);
    Ok(output)
}
/// Reads the encrypted model file at `encrypted_path` and returns its decrypted bytes.
///
/// # Errors
///
/// Returns [`OnnxError::Parse`] if the file cannot be read; otherwise propagates any
/// error from [`decrypt_bytes`].
#[cfg(feature = "encryption")]
pub fn decrypt_model(encrypted_path: &Path, key: &[u8; 32]) -> Result<Vec<u8>, OnnxError> {
    std::fs::read(encrypted_path)
        .map_err(|err| OnnxError::Parse(format!("Cannot read encrypted file: {}", err)))
        .and_then(|sealed| decrypt_bytes(&sealed, key))
}
/// Decrypts a buffer produced by [`encrypt_bytes`] with AES-256-GCM under `key`.
///
/// `data` must be laid out as `nonce (12 bytes) || ciphertext || tag (16 bytes)`.
///
/// # Errors
///
/// - [`OnnxError::Parse`] if `data` is too short to hold a nonce and an authentication tag.
/// - [`OnnxError::Internal`] if the cipher cannot be constructed or authentication fails
///   (wrong key, or corrupted/tampered data).
#[cfg(feature = "encryption")]
pub fn decrypt_bytes(data: &[u8], key: &[u8; 32]) -> Result<Vec<u8>, OnnxError> {
    const NONCE_LEN: usize = 12;
    // AES-256-GCM (`Aes256Gcm`) appends a 16-byte authentication tag to every ciphertext.
    const TAG_LEN: usize = 16;
    // Anything shorter than nonce + tag is structurally invalid. Reject it as malformed
    // instead of letting the AEAD failure be misreported as a possible wrong key.
    if data.len() < NONCE_LEN + TAG_LEN {
        return Err(OnnxError::Parse("Encrypted data too short".into()));
    }
    let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);
    let nonce = Nonce::from_slice(nonce_bytes);
    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|e| OnnxError::Internal(format!("Failed to create cipher: {}", e)))?;
    cipher
        .decrypt(nonce, ciphertext)
        .map_err(|e| OnnxError::Internal(format!("Decryption failed (wrong key?): {}", e)))
}
/// Decrypts the model file at `path` with `key` and builds a
/// [`Session`](crate::session::Session) from the recovered bytes.
///
/// # Errors
///
/// Propagates errors from [`decrypt_model`] and from `Session::from_bytes`.
#[cfg(feature = "encryption")]
pub fn load_encrypted(path: &Path, key: &[u8; 32]) -> Result<crate::session::Session, OnnxError> {
    decrypt_model(path, key)
        .and_then(|model_bytes| crate::session::Session::from_bytes(&model_bytes))
}
#[cfg(test)]
#[cfg(feature = "encryption")]
mod tests {
    use super::*;

    /// A path in the system temp directory that is removed when dropped.
    ///
    /// The file name embeds the current process id so concurrent test runs do not
    /// overwrite each other's files. Cleanup runs even if an assertion panics.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(name: &str) -> Self {
            let file_name = format!("oxionnx_test_{}_{}", std::process::id(), name);
            TempFile(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = [0x42u8; 32];
        let plaintext = b"ONNX model content for testing roundtrip encryption";
        let encrypted = encrypt_bytes(plaintext, &key).expect("encryption should succeed");
        let decrypted = decrypt_bytes(&encrypted, &key).expect("decryption should succeed");
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_empty_plaintext() {
        let key = [0x42u8; 32];
        let encrypted = encrypt_bytes(b"", &key).expect("encryption should succeed");
        let decrypted = decrypt_bytes(&encrypted, &key).expect("decryption should succeed");
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        let key = [0x42u8; 32];
        let plaintext = b"same plaintext encrypted twice";
        let first = encrypt_bytes(plaintext, &key).expect("encryption should succeed");
        let second = encrypt_bytes(plaintext, &key).expect("encryption should succeed");
        assert_ne!(
            &first[..12],
            &second[..12],
            "each encryption must use a distinct nonce"
        );
        assert_eq!(decrypt_bytes(&first, &key).expect("first decrypts"), plaintext);
        assert_eq!(decrypt_bytes(&second, &key).expect("second decrypts"), plaintext);
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let key = [0x42u8; 32];
        let wrong_key = [0x99u8; 32];
        let plaintext = b"secret model data";
        let encrypted = encrypt_bytes(plaintext, &key).expect("encryption should succeed");
        let result = decrypt_bytes(&encrypted, &wrong_key);
        assert!(result.is_err(), "decryption with wrong key should fail");
    }

    #[test]
    fn test_decrypt_rejects_truncated_input() {
        let key = [0x42u8; 32];
        let encrypted = encrypt_bytes(b"model", &key).expect("encryption should succeed");
        for &len in &[0usize, 11, 12, 27, encrypted.len() - 1] {
            assert!(
                decrypt_bytes(&encrypted[..len], &key).is_err(),
                "input truncated to {} bytes should be rejected",
                len
            );
        }
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let key = [0x42u8; 32];
        let mut encrypted = encrypt_bytes(b"model", &key).expect("encryption should succeed");
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        assert!(
            decrypt_bytes(&encrypted, &key).is_err(),
            "tampered ciphertext must fail authentication"
        );
    }

    #[test]
    fn test_encrypt_model_file() {
        let key = [0xABu8; 32];
        let plaintext = b"fake ONNX model bytes for file-based test";
        let input = TempFile::new("encrypt_input.onnx");
        let output = TempFile::new("encrypt_output.enc");
        std::fs::write(input.path(), plaintext).expect("should write test input");
        encrypt_model(input.path(), output.path(), &key).expect("encrypt_model should succeed");
        let decrypted = decrypt_model(output.path(), &key).expect("decrypt_model should succeed");
        assert_eq!(decrypted, plaintext);
    }
}