/// Size in bytes of the AES-GCM authentication tag carried by every ciphertext.
#[cfg(feature = "encryption")]
const TAG_SIZE: usize = 16;
/// Size in bytes of the AES-GCM nonce carried by every ciphertext.
#[cfg(feature = "encryption")]
const NONCE_SIZE: usize = 12;
/// Fixed number of bytes a ciphertext carries beyond its plaintext (tag + nonce),
/// i.e. `ciphertext_len == plaintext_len + CIPHERTEXT_OVERHEAD`. Also the minimum
/// length of any ciphertext accepted by the decrypt entry point.
#[cfg(feature = "encryption")]
const CIPHERTEXT_OVERHEAD: usize = TAG_SIZE + NONCE_SIZE;
#[cfg(feature = "encryption")]
use crate::encryption::{derive_domain_key, ZeroKnowledgeEncryptor};
#[cfg(feature = "encryption")]
use crate::ffi::error::CachekitError;
#[cfg(feature = "encryption")]
use crate::ffi::handles::CachekitEncryptor;
#[cfg(feature = "encryption")]
use std::panic::catch_unwind;
#[cfg(feature = "encryption")]
use std::slice;
/// Creates a new encryptor and returns an opaque handle to it.
///
/// Returns null on failure. When `error_out` is non-null it receives the
/// status: `Ok` on success, the converted constructor error on failure, or
/// `InvalidInput` if construction panicked. A null `error_out` is allowed and
/// simply means the status is not reported.
///
/// # Safety
/// `error_out` must be null or valid for a write of one `CachekitError`.
/// The returned handle must eventually be released with `cachekit_encryptor_free`.
#[cfg(feature = "encryption")]
#[no_mangle]
pub unsafe extern "C" fn cachekit_encryptor_new(
    error_out: *mut CachekitError,
) -> *mut CachekitEncryptor {
    // Compute the (handle, status) pair inside the unwind guard, then report
    // the status once, outside it.
    let outcome = catch_unwind(|| match ZeroKnowledgeEncryptor::new() {
        Ok(encryptor) => (
            CachekitEncryptor::into_opaque_ptr(encryptor),
            CachekitError::Ok,
        ),
        Err(e) => (std::ptr::null_mut(), CachekitError::from(e)),
    });
    let (handle, status) =
        outcome.unwrap_or_else(|_| (std::ptr::null_mut(), CachekitError::InvalidInput));
    if !error_out.is_null() {
        unsafe { *error_out = status };
    }
    handle
}
/// Releases an encryptor handle previously returned by `cachekit_encryptor_new`.
///
/// Any panic raised during teardown is caught and discarded, because this
/// function has no channel through which to report an error.
///
/// # Safety
/// `handle` must not be used or freed again after this call.
/// NOTE(review): null handling is delegated to
/// `CachekitEncryptor::from_opaque_ptr`; confirm it tolerates null.
#[cfg(feature = "encryption")]
#[no_mangle]
pub unsafe extern "C" fn cachekit_encryptor_free(handle: *mut CachekitEncryptor) {
    let _ = catch_unwind(|| {
        // Take ownership back from the raw pointer and drop it right away;
        // presumably this is what deallocates the encryptor.
        let reclaimed = unsafe { CachekitEncryptor::from_opaque_ptr(handle) };
        drop(reclaimed);
    });
}
/// Returns the current nonce counter of the encryptor behind `handle`.
///
/// Returns 0 when `handle` does not resolve to an encryptor (e.g. null) or if
/// reading the counter panics. Note that 0 is therefore indistinguishable
/// from a freshly created encryptor.
///
/// # Safety
/// `handle` must be null or a live handle from `cachekit_encryptor_new`.
#[cfg(feature = "encryption")]
#[no_mangle]
pub unsafe extern "C" fn cachekit_encryptor_get_counter(handle: *const CachekitEncryptor) -> u64 {
    // Sentinel for "no counter available" (bad handle or caught panic).
    const UNAVAILABLE: u64 = 0;
    catch_unwind(|| {
        let resolved = unsafe { CachekitEncryptor::as_ref(handle) };
        resolved.map_or(UNAVAILABLE, |encryptor| encryptor.get_nonce_counter())
    })
    .unwrap_or(UNAVAILABLE)
}
/// Encrypts `plaintext` with AES-GCM under a caller-supplied 32-byte key,
/// authenticating `aad`, and writes the ciphertext into `output`.
///
/// `*output_len` is in/out. On entry it holds the capacity of `output` in
/// bytes. On `Ok` it receives the number of bytes written
/// (`plaintext_len + CIPHERTEXT_OVERHEAD`). On `BufferTooSmall` it receives the
/// required capacity and `output` is left untouched. The capacity is checked
/// *before* encrypting, so a size-probe call does no cryptographic work.
///
/// # Errors
/// - `NullPointer` if `key`, `aad`, `plaintext`, `output` or `output_len` is
///   null (even when the matching length is 0).
/// - `InvalidHandle` if `handle` does not resolve to an encryptor.
/// - `RotationNeeded` once the encryptor's nonce counter reaches 2^31.
/// - `InvalidKeyLength` if `key_len != 32`.
/// - `BufferTooSmall` as described above.
/// - `InvalidInput` if the required output size overflows `usize`, or if a
///   panic is caught.
/// - Otherwise, the error converted from `encrypt_aes_gcm`.
///
/// # Safety
/// `key`, `aad` and `plaintext` must be valid for reads of `key_len`,
/// `aad_len` and `plaintext_len` bytes. `output` must be valid for writes of
/// `*output_len` bytes. `output_len` must be valid for a read and a write.
/// `handle` must be null or a live handle from `cachekit_encryptor_new`.
#[cfg(feature = "encryption")]
#[no_mangle]
pub unsafe extern "C" fn cachekit_encrypt(
    handle: *mut CachekitEncryptor,
    key: *const u8,
    key_len: usize,
    aad: *const u8,
    aad_len: usize,
    plaintext: *const u8,
    plaintext_len: usize,
    output: *mut u8,
    output_len: *mut usize,
) -> CachekitError {
    let result = catch_unwind(|| {
        // `slice::from_raw_parts` requires non-null pointers even for empty
        // slices, so every buffer pointer is rejected up front.
        if key.is_null()
            || aad.is_null()
            || plaintext.is_null()
            || output.is_null()
            || output_len.is_null()
        {
            return CachekitError::NullPointer;
        }
        let encryptor = match unsafe { CachekitEncryptor::as_ref(handle) } {
            Some(enc) => enc,
            None => return CachekitError::InvalidHandle,
        };
        let available_size = unsafe { *output_len };
        // Refuse to encrypt once the nonce counter reaches the rotation
        // threshold; the caller must switch to a fresh encryptor/key.
        if encryptor.get_nonce_counter() >= (1u64 << 31) {
            return CachekitError::RotationNeeded;
        }
        if key_len != 32 {
            return CachekitError::InvalidKeyLength;
        }
        // The ciphertext size depends only on the plaintext size, so the
        // capacity can be checked before encrypting. Doing it afterwards would
        // make every "probe with a small buffer, then retry" call run the
        // cipher and (presumably) consume a nonce-counter value for output
        // that is thrown away.
        let required_size = match plaintext_len.checked_add(CIPHERTEXT_OVERHEAD) {
            Some(size) => size,
            None => return CachekitError::InvalidInput,
        };
        if required_size > available_size {
            unsafe { *output_len = required_size };
            return CachekitError::BufferTooSmall;
        }
        let key_slice = unsafe { slice::from_raw_parts(key, key_len) };
        let aad_slice = unsafe { slice::from_raw_parts(aad, aad_len) };
        let plaintext_slice = unsafe { slice::from_raw_parts(plaintext, plaintext_len) };
        let ciphertext = match encryptor.encrypt_aes_gcm(plaintext_slice, key_slice, aad_slice) {
            Ok(data) => data,
            Err(e) => return CachekitError::from(e),
        };
        // Defensive: never trust the precomputed size for the actual copy.
        if ciphertext.len() > available_size {
            unsafe { *output_len = ciphertext.len() };
            return CachekitError::BufferTooSmall;
        }
        unsafe {
            std::ptr::copy_nonoverlapping(ciphertext.as_ptr(), output, ciphertext.len());
            *output_len = ciphertext.len();
        }
        CachekitError::Ok
    });
    result.unwrap_or(CachekitError::InvalidInput)
}
/// Decrypts an AES-GCM `ciphertext` (as produced by `cachekit_encrypt` with the
/// same 32-byte key and `aad`) and writes the plaintext into `output`.
///
/// `*output_len` is in/out. It holds the capacity of `output` on entry and the
/// number of bytes written on `Ok`. On `BufferTooSmall` it holds the required
/// capacity, and `output` is left untouched. The capacity is compared only
/// after decryption has run.
///
/// # Errors
/// - `NullPointer` if any pointer argument other than `handle` is null.
/// - `InvalidHandle` if `handle` does not resolve to an encryptor.
/// - `InvalidKeyLength` if `key_len != 32`.
/// - `InvalidInput` if `ciphertext_len < CIPHERTEXT_OVERHEAD`.
/// - `BufferTooSmall` as described above.
/// - Otherwise, the error converted from `decrypt_aes_gcm`.
/// - A caught panic is reported as `DecryptionFailed`.
///
/// NOTE(review): the intermediate plaintext buffer is dropped here without an
/// explicit wipe; confirm whether `decrypt_aes_gcm`'s return type zeroizes on drop.
///
/// # Safety
/// `key`, `aad` and `ciphertext` must be valid for reads of their lengths.
/// `output` must be valid for writes of `*output_len` bytes. `output_len` must
/// be valid for a read and a write. `handle` must be null or a live handle
/// from `cachekit_encryptor_new`.
#[cfg(feature = "encryption")]
#[no_mangle]
pub unsafe extern "C" fn cachekit_decrypt(
    handle: *const CachekitEncryptor,
    key: *const u8,
    key_len: usize,
    aad: *const u8,
    aad_len: usize,
    ciphertext: *const u8,
    ciphertext_len: usize,
    output: *mut u8,
    output_len: *mut usize,
) -> CachekitError {
    let guarded = catch_unwind(|| {
        // Express each failure as `Err(code)` so the happy path reads top to
        // bottom; the outer closure flattens it back into a status code.
        let attempt = || -> Result<(), CachekitError> {
            let any_null = key.is_null()
                || aad.is_null()
                || ciphertext.is_null()
                || output.is_null()
                || output_len.is_null();
            if any_null {
                return Err(CachekitError::NullPointer);
            }
            let resolved = unsafe { CachekitEncryptor::as_ref(handle) };
            let enc = resolved.ok_or(CachekitError::InvalidHandle)?;
            let capacity = unsafe { *output_len };
            if key_len != 32 {
                return Err(CachekitError::InvalidKeyLength);
            }
            // Anything shorter than nonce + tag cannot be a valid ciphertext.
            if ciphertext_len < CIPHERTEXT_OVERHEAD {
                return Err(CachekitError::InvalidInput);
            }
            let key_bytes = unsafe { slice::from_raw_parts(key, key_len) };
            let aad_bytes = unsafe { slice::from_raw_parts(aad, aad_len) };
            let sealed = unsafe { slice::from_raw_parts(ciphertext, ciphertext_len) };
            let opened = enc
                .decrypt_aes_gcm(sealed, key_bytes, aad_bytes)
                .map_err(CachekitError::from)?;
            let produced = opened.len();
            if produced > capacity {
                unsafe { *output_len = produced };
                return Err(CachekitError::BufferTooSmall);
            }
            unsafe {
                std::ptr::copy_nonoverlapping(opened.as_ptr(), output, produced);
                *output_len = produced;
            }
            Ok(())
        };
        match attempt() {
            Ok(()) => CachekitError::Ok,
            Err(code) => code,
        }
    });
    guarded.unwrap_or(CachekitError::DecryptionFailed)
}
/// Derives a 32-byte domain-separated key from `master` and `salt` and writes
/// it to `out_key`.
///
/// `domain` must be valid UTF-8; it is passed to `derive_domain_key` as a `&str`.
/// `out_key` is written only on success.
///
/// # Errors
/// - `NullPointer` if any of `master`, `salt`, `domain` or `out_key` is null.
/// - `InvalidInput` if `domain` is not valid UTF-8, if the derived key is not
///   exactly 32 bytes long, or if a panic is caught.
/// - Otherwise, the error converted from `derive_domain_key`.
///
/// # Safety
/// `master`, `salt` and `domain` must be valid for reads of `master_len`,
/// `salt_len` and `domain_len` bytes. `out_key` must be valid for writes of
/// 32 bytes.
#[cfg(feature = "encryption")]
#[no_mangle]
pub unsafe extern "C" fn cachekit_derive_key(
    master: *const u8,
    master_len: usize,
    salt: *const u8,
    salt_len: usize,
    domain: *const u8,
    domain_len: usize,
    out_key: *mut u8,
) -> CachekitError {
    // Number of bytes written to `out_key`. This is the same 32-byte key length
    // that `cachekit_encrypt` / `cachekit_decrypt` require for their `key`.
    const DERIVED_KEY_LEN: usize = 32;

    let result = catch_unwind(|| {
        // `slice::from_raw_parts` requires non-null pointers even for empty slices.
        if master.is_null() || salt.is_null() || domain.is_null() || out_key.is_null() {
            return CachekitError::NullPointer;
        }
        let master_slice = unsafe { slice::from_raw_parts(master, master_len) };
        let salt_slice = unsafe { slice::from_raw_parts(salt, salt_len) };
        let domain_slice = unsafe { slice::from_raw_parts(domain, domain_len) };
        let domain_str = match std::str::from_utf8(domain_slice) {
            Ok(s) => s,
            Err(_) => return CachekitError::InvalidInput,
        };
        let derived_key = match derive_domain_key(master_slice, domain_str, salt_slice) {
            Ok(key) => key,
            Err(e) => return CachekitError::from(e),
        };
        // Never copy more bytes than the helper actually produced: a blind
        // fixed-size copy out of a shorter buffer would be an out-of-bounds
        // read (undefined behaviour).
        if derived_key.len() != DERIVED_KEY_LEN {
            return CachekitError::InvalidInput;
        }
        unsafe {
            std::ptr::copy_nonoverlapping(derived_key.as_ptr(), out_key, DERIVED_KEY_LEN);
        }
        CachekitError::Ok
    });
    result.unwrap_or(CachekitError::InvalidInput)
}
#[cfg(all(test, feature = "encryption"))]
mod tests {
    use super::*;

    /// Creates an encryptor through the FFI constructor, asserting that it
    /// yields a non-null handle and reports `Ok` through a non-null `error_out`.
    fn new_handle() -> *mut CachekitEncryptor {
        let mut status = CachekitError::InvalidInput;
        let handle = unsafe { cachekit_encryptor_new(&mut status) };
        assert!(!handle.is_null());
        assert_eq!(status, CachekitError::Ok);
        handle
    }

    /// Encrypts `plaintext` into a buffer sized exactly for the expected
    /// ciphertext, asserts success, and returns the bytes written.
    fn encrypt_to_vec(
        handle: *mut CachekitEncryptor,
        key: &[u8],
        aad: &[u8],
        plaintext: &[u8],
    ) -> Vec<u8> {
        let mut ciphertext = vec![0u8; plaintext.len() + CIPHERTEXT_OVERHEAD];
        let mut ciphertext_len = ciphertext.len();
        let result = unsafe {
            cachekit_encrypt(
                handle,
                key.as_ptr(),
                key.len(),
                aad.as_ptr(),
                aad.len(),
                plaintext.as_ptr(),
                plaintext.len(),
                ciphertext.as_mut_ptr(),
                &mut ciphertext_len,
            )
        };
        assert_eq!(result, CachekitError::Ok);
        assert_eq!(ciphertext_len, plaintext.len() + CIPHERTEXT_OVERHEAD);
        ciphertext.truncate(ciphertext_len);
        ciphertext
    }

    #[test]
    fn test_encryptor_lifecycle() {
        unsafe {
            // A null `error_out` is accepted; the status is simply not reported.
            let handle = cachekit_encryptor_new(std::ptr::null_mut());
            assert!(!handle.is_null());
            let counter = cachekit_encryptor_get_counter(handle);
            assert_eq!(counter, 0);
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_get_counter_null_handle() {
        unsafe {
            assert_eq!(cachekit_encryptor_get_counter(std::ptr::null()), 0);
        }
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        unsafe {
            let handle = new_handle();
            let key = [0u8; 32];
            let aad = b"test_context";
            let plaintext = b"Hello, FFI encryption!";
            let mut ciphertext = vec![0u8; plaintext.len() + 100];
            let mut ciphertext_len = ciphertext.len();
            let result = cachekit_encrypt(
                handle,
                key.as_ptr(),
                32,
                aad.as_ptr(),
                aad.len(),
                plaintext.as_ptr(),
                plaintext.len(),
                ciphertext.as_mut_ptr(),
                &mut ciphertext_len,
            );
            assert_eq!(result, CachekitError::Ok);
            assert_eq!(ciphertext_len, plaintext.len() + CIPHERTEXT_OVERHEAD);
            let mut decrypted = vec![0u8; plaintext.len() + 100];
            let mut decrypted_len = decrypted.len();
            let result = cachekit_decrypt(
                handle,
                key.as_ptr(),
                32,
                aad.as_ptr(),
                aad.len(),
                ciphertext.as_ptr(),
                ciphertext_len,
                decrypted.as_mut_ptr(),
                &mut decrypted_len,
            );
            assert_eq!(result, CachekitError::Ok);
            assert_eq!(decrypted_len, plaintext.len());
            assert_eq!(&decrypted[..decrypted_len], plaintext);
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_encrypt_null_checks() {
        unsafe {
            let handle = new_handle();
            let key = [0u8; 32];
            let aad = b"test";
            let plaintext = b"test";
            let mut output = vec![0u8; 100];
            let mut output_len = output.len();
            assert_eq!(
                cachekit_encrypt(
                    std::ptr::null_mut(),
                    key.as_ptr(),
                    32,
                    aad.as_ptr(),
                    aad.len(),
                    plaintext.as_ptr(),
                    plaintext.len(),
                    output.as_mut_ptr(),
                    &mut output_len,
                ),
                CachekitError::InvalidHandle
            );
            assert_eq!(
                cachekit_encrypt(
                    handle,
                    std::ptr::null(),
                    32,
                    aad.as_ptr(),
                    aad.len(),
                    plaintext.as_ptr(),
                    plaintext.len(),
                    output.as_mut_ptr(),
                    &mut output_len,
                ),
                CachekitError::NullPointer
            );
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_decrypt_null_checks() {
        unsafe {
            let handle = new_handle();
            let key = [0u8; 32];
            let aad = b"test";
            let ciphertext = encrypt_to_vec(handle, &key, aad, b"test");
            let mut output = vec![0u8; 100];
            let mut output_len = output.len();
            assert_eq!(
                cachekit_decrypt(
                    std::ptr::null(),
                    key.as_ptr(),
                    32,
                    aad.as_ptr(),
                    aad.len(),
                    ciphertext.as_ptr(),
                    ciphertext.len(),
                    output.as_mut_ptr(),
                    &mut output_len,
                ),
                CachekitError::InvalidHandle
            );
            assert_eq!(
                cachekit_decrypt(
                    handle,
                    key.as_ptr(),
                    32,
                    std::ptr::null(),
                    0,
                    ciphertext.as_ptr(),
                    ciphertext.len(),
                    output.as_mut_ptr(),
                    &mut output_len,
                ),
                CachekitError::NullPointer
            );
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_decrypt_wrong_key() {
        unsafe {
            let handle = new_handle();
            let key1 = [0u8; 32];
            let key2 = [1u8; 32];
            let aad = b"test";
            let plaintext = b"secret";
            let ciphertext = encrypt_to_vec(handle, &key1, aad, plaintext);
            let mut decrypted = vec![0u8; plaintext.len() + 100];
            let mut decrypted_len = decrypted.len();
            let result = cachekit_decrypt(
                handle,
                key2.as_ptr(),
                32,
                aad.as_ptr(),
                aad.len(),
                ciphertext.as_ptr(),
                ciphertext.len(),
                decrypted.as_mut_ptr(),
                &mut decrypted_len,
            );
            assert_eq!(result, CachekitError::DecryptionFailed);
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_decrypt_rejects_truncated_ciphertext() {
        unsafe {
            let handle = new_handle();
            let key = [0u8; 32];
            let aad = b"test";
            // One byte short of nonce + tag: cannot be a valid ciphertext.
            let truncated = [0u8; CIPHERTEXT_OVERHEAD - 1];
            let mut output = vec![0u8; 100];
            let mut output_len = output.len();
            let result = cachekit_decrypt(
                handle,
                key.as_ptr(),
                32,
                aad.as_ptr(),
                aad.len(),
                truncated.as_ptr(),
                truncated.len(),
                output.as_mut_ptr(),
                &mut output_len,
            );
            assert_eq!(result, CachekitError::InvalidInput);
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_decrypt_buffer_too_small() {
        unsafe {
            let handle = new_handle();
            let key = [0u8; 32];
            let aad = b"test";
            let plaintext = b"Hello, World!";
            let ciphertext = encrypt_to_vec(handle, &key, aad, plaintext);
            let mut output = vec![0u8; 4];
            let mut output_len = output.len();
            let result = cachekit_decrypt(
                handle,
                key.as_ptr(),
                32,
                aad.as_ptr(),
                aad.len(),
                ciphertext.as_ptr(),
                ciphertext.len(),
                output.as_mut_ptr(),
                &mut output_len,
            );
            assert_eq!(result, CachekitError::BufferTooSmall);
            // The required capacity is reported back to the caller.
            assert_eq!(output_len, plaintext.len());
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_derive_key_basic() {
        unsafe {
            let master = b"test_master_key_32_bytes_long!!!";
            let salt = b"tenant123";
            let domain = b"encryption";
            let mut out_key = [0u8; 32];
            let result = cachekit_derive_key(
                master.as_ptr(),
                master.len(),
                salt.as_ptr(),
                salt.len(),
                domain.as_ptr(),
                domain.len(),
                out_key.as_mut_ptr(),
            );
            assert_eq!(result, CachekitError::Ok);
            assert_ne!(out_key, [0u8; 32]);
        }
    }

    #[test]
    fn test_derive_key_deterministic() {
        unsafe {
            let master = b"test_master_key_32_bytes_long!!!";
            let salt = b"tenant123";
            let domain = b"encryption";
            let mut key1 = [0u8; 32];
            let mut key2 = [0u8; 32];
            // Both calls must succeed; otherwise two untouched all-zero
            // buffers would compare equal and the test would pass vacuously.
            let first = cachekit_derive_key(
                master.as_ptr(),
                master.len(),
                salt.as_ptr(),
                salt.len(),
                domain.as_ptr(),
                domain.len(),
                key1.as_mut_ptr(),
            );
            let second = cachekit_derive_key(
                master.as_ptr(),
                master.len(),
                salt.as_ptr(),
                salt.len(),
                domain.as_ptr(),
                domain.len(),
                key2.as_mut_ptr(),
            );
            assert_eq!(first, CachekitError::Ok);
            assert_eq!(second, CachekitError::Ok);
            assert_ne!(key1, [0u8; 32]);
            assert_eq!(key1, key2);
        }
    }

    #[test]
    fn test_derive_key_rejects_non_utf8_domain() {
        unsafe {
            let master = b"test_master_key_32_bytes_long!!!";
            let salt = b"tenant123";
            let domain = [0xffu8, 0xfe];
            let mut out_key = [0u8; 32];
            let result = cachekit_derive_key(
                master.as_ptr(),
                master.len(),
                salt.as_ptr(),
                salt.len(),
                domain.as_ptr(),
                domain.len(),
                out_key.as_mut_ptr(),
            );
            assert_eq!(result, CachekitError::InvalidInput);
            assert_eq!(out_key, [0u8; 32]);
        }
    }

    #[test]
    fn test_derive_key_null_master() {
        unsafe {
            let salt = b"tenant123";
            let domain = b"encryption";
            let mut out_key = [0u8; 32];
            let result = cachekit_derive_key(
                std::ptr::null(),
                32,
                salt.as_ptr(),
                salt.len(),
                domain.as_ptr(),
                domain.len(),
                out_key.as_mut_ptr(),
            );
            assert_eq!(result, CachekitError::NullPointer);
        }
    }

    #[test]
    fn test_invalid_key_length() {
        unsafe {
            let handle = new_handle();
            let short_key = [0u8; 16];
            let aad = b"test";
            let plaintext = b"test";
            let mut output = vec![0u8; 100];
            let mut output_len = output.len();
            let result = cachekit_encrypt(
                handle,
                short_key.as_ptr(),
                16,
                aad.as_ptr(),
                aad.len(),
                plaintext.as_ptr(),
                plaintext.len(),
                output.as_mut_ptr(),
                &mut output_len,
            );
            assert_eq!(result, CachekitError::InvalidKeyLength);
            cachekit_encryptor_free(handle);
        }
    }

    #[test]
    fn test_buffer_too_small() {
        unsafe {
            let handle = new_handle();
            let key = [0u8; 32];
            let aad = b"test";
            let plaintext = b"Hello, World!";
            let mut output = vec![0u8; 10];
            let mut output_len = output.len();
            let result = cachekit_encrypt(
                handle,
                key.as_ptr(),
                32,
                aad.as_ptr(),
                aad.len(),
                plaintext.as_ptr(),
                plaintext.len(),
                output.as_mut_ptr(),
                &mut output_len,
            );
            assert_eq!(result, CachekitError::BufferTooSmall);
            assert_eq!(output_len, plaintext.len() + CIPHERTEXT_OVERHEAD);
            cachekit_encryptor_free(handle);
        }
    }
}