use crate::encryption::{
decrypt_data, encrypt_data, prepare_key_context, EncryptionAlgorithm, PreparedKeyContext,
};
use crate::key::{resolve_key, KeyLookupParams, KeyMaterial, KeyRole, ValidateMode};
use crate::policy::AccessPolicy;
use crate::signing::{sign_data, verify_signature, SignatureAlgorithm};
use crate::tensor::{Metadata, TensorInfo};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use ring::rand::{self, SecureRandom};
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;
use zeroize::Zeroizing;
/// Format version string "1". When serializing with this version no chunk size is used.
pub const CRYPTOTENSORS_VERSION_V1: &str = "1";
/// Format version string "2". This is the default version when serializing and uses a chunk size.
pub const CRYPTOTENSORS_VERSION_V2: &str = "2";
// Number of cryptors at or above which `from_header_with_config` prepares contexts in parallel.
const PARALLEL_CONTEXT_THRESHOLD: usize = 16;
/// Chunk size in bytes (2 MiB) used for version "2" when no chunk size is configured or recorded.
pub const DEFAULT_CHUNK_SIZE: usize = 2 * 1024 * 1024;
/// Error type returned by the fallible operations in this file.
///
/// Each variant's `#[error(...)]` attribute defines its display message; variants carrying a
/// `String` embed a free-form detail message at `{0}`.
#[derive(Debug, Error)]
pub enum CryptoTensorsError {
// --- Key handling ---
#[error("failed to unwrap key: {0}")]
KeyUnwrap(String),
#[error("failed to create key: {0}")]
KeyCreation(String),
#[error("failed to load key: {0}")]
KeyLoad(String),
#[error("invalid key: {0}")]
InvalidKey(String),
#[error("invalid JWK URL: {0}")]
InvalidJwkUrl(String),
#[error("no suitable key found")]
NoSuitableKey,
#[error("multiple keys found without key ID (kid)")]
AmbiguousKeySet,
#[error("master key is required but not available")]
MissingMasterKey,
#[error("signing key is required but not available")]
MissingSigningKey,
#[error("verification key is required but not available")]
MissingVerificationKey,
// --- Encryption / decryption ---
#[error("encryption failed: {0}")]
Encryption(String),
#[error("decryption failed: {0}")]
Decryption(String),
#[error("invalid algorithm: {0}")]
InvalidAlgorithm(String),
#[error("invalid key length: expected {expected} bytes, got {actual} bytes")]
InvalidKeyLength {
expected: usize,
actual: usize,
},
#[error("invalid IV length: expected {expected} bytes, got {actual} bytes")]
InvalidIvLength {
expected: usize,
actual: usize,
},
#[error("invalid tag length: expected {expected} bytes, got {actual} bytes")]
InvalidTagLength {
expected: usize,
actual: usize,
},
#[error("random generation failed: {0}")]
RandomGeneration(String),
// --- Signing / verification ---
#[error("signing failed: {0}")]
Signing(String),
#[error("verification failed: {0}")]
Verification(String),
#[error("missing signature: {0}")]
MissingSignature(String),
#[error("invalid signature format")]
InvalidSignatureFormat,
// --- Policy, versioning, registry, JSON ---
#[error("policy error: {0}")]
Policy(String),
#[error("unsupported version: {0}")]
UnsupportedVersion(String),
#[error("version field is missing")]
MissingVersion,
#[error("registry error: {0}")]
Registry(String),
#[error("JSON error: {0}")]
Json(String),
}
impl From<serde_json::Error> for CryptoTensorsError {
fn from(error: serde_json::Error) -> Self {
CryptoTensorsError::Json(error.to_string())
}
}
/// Serde adapter for `OnceCell<String>` fields, referenced via `#[serde(with = "cryptor_serde")]`.
mod cryptor_serde {
    use super::*;

    /// Serializes the cell's string when it is set, and a `none` value otherwise.
    pub fn serialize<S>(cell: &OnceCell<String>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(text) = cell.get() {
            text.serialize(serializer)
        } else {
            serializer.serialize_none()
        }
    }

    /// Deserializes an optional string into a cell: set when a value is present, empty otherwise.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<OnceCell<String>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let cell = OnceCell::new();
        match Option::<String>::deserialize(deserializer)? {
            None => {}
            Some(text) => cell
                .set(text)
                .map_err(|_| D::Error::custom("Failed to set OnceCell value"))?,
        }
        Ok(cell)
    }
}
/// Returns `true` when the cell has not been set; used as a `skip_serializing_if` predicate.
fn is_empty_cell(cell: &OnceCell<String>) -> bool {
    matches!(cell.get(), None)
}
/// Options controlling how tensors are encrypted and signed when serializing.
///
/// All fields default to `None`.
#[derive(Default)]
pub struct SerializeCryptoConfig {
/// Encryption (master) key supplied directly. When set, it is passed as the direct key for
/// lookup and registry lookup is not allowed for the encryption key.
pub enc_key: Option<KeyMaterial>,
/// Signing key supplied directly. When set, it is passed as the direct key for lookup and
/// registry lookup is not allowed for the signing key.
pub sign_key: Option<KeyMaterial>,
/// Key ID passed to key lookup for the encryption key; when no direct key is given it is also
/// recorded on the resolved key.
pub enc_kid: Option<String>,
/// `jku` value passed to key lookup for the encryption key; when no direct key is given it is
/// also recorded on the resolved key.
pub enc_jku: Option<String>,
/// Key ID passed to key lookup for the signing key; when no direct key is given it is also
/// recorded on the resolved key.
pub sign_kid: Option<String>,
/// `jku` value passed to key lookup for the signing key; when no direct key is given it is
/// also recorded on the resolved key.
pub sign_jku: Option<String>,
/// Access policy to embed; the policy type's default is used when `None`.
pub policy: Option<AccessPolicy>,
/// Names of the tensors to encrypt. `None` selects every tensor; names that are not among the
/// tensors being serialized are ignored.
pub tensor_filter: Option<Vec<String>>,
/// Chunk size in bytes for versions other than "1"; defaults to `DEFAULT_CHUNK_SIZE` and must
/// not be zero.
pub chunk_size: Option<usize>,
/// Format version ("1" or "2"); defaults to `CRYPTOTENSORS_VERSION_V2`.
pub version: Option<String>,
}
impl SerializeCryptoConfig {
    /// Creates a configuration with every option unset.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a configuration that uses the given keys directly.
    pub fn with_keys(enc_key: KeyMaterial, sign_key: KeyMaterial) -> Self {
        Self {
            enc_key: Some(enc_key),
            sign_key: Some(sign_key),
            ..Self::default()
        }
    }

    /// Creates a configuration that identifies both keys by key ID.
    pub fn with_kids(enc_kid: &str, sign_kid: &str) -> Self {
        Self {
            enc_kid: Some(enc_kid.to_owned()),
            sign_kid: Some(sign_kid.to_owned()),
            ..Self::default()
        }
    }

    /// Sets the encryption key ID.
    pub fn enc_kid(self, kid: &str) -> Self {
        Self {
            enc_kid: Some(kid.to_owned()),
            ..self
        }
    }

    /// Sets the signing key ID.
    pub fn sign_kid(self, kid: &str) -> Self {
        Self {
            sign_kid: Some(kid.to_owned()),
            ..self
        }
    }

    /// Sets the encryption key `jku`.
    pub fn enc_jku(self, jku: &str) -> Self {
        Self {
            enc_jku: Some(jku.to_owned()),
            ..self
        }
    }

    /// Sets the signing key `jku`.
    pub fn sign_jku(self, jku: &str) -> Self {
        Self {
            sign_jku: Some(jku.to_owned()),
            ..self
        }
    }

    /// Sets the access policy.
    pub fn policy(self, policy: AccessPolicy) -> Self {
        Self {
            policy: Some(policy),
            ..self
        }
    }

    /// Restricts encryption to the named tensors.
    pub fn tensor_filter(self, filter: Vec<String>) -> Self {
        Self {
            tensor_filter: Some(filter),
            ..self
        }
    }

    /// Sets the chunk size in bytes.
    pub fn chunk_size(self, size: usize) -> Self {
        Self {
            chunk_size: Some(size),
            ..self
        }
    }
}
/// Options controlling key selection when loading an encrypted file.
///
/// Both fields default to `None`.
#[derive(Default)]
pub struct DeserializeCryptoConfig {
/// Encryption (master) key supplied directly. When set, it replaces the key described in the
/// header and registry lookup is not allowed for the encryption key.
pub enc_key: Option<KeyMaterial>,
/// Verification key supplied directly. When set, it replaces the key described in the header
/// and registry lookup is not allowed for the signing key.
pub sign_key: Option<KeyMaterial>,
}
impl DeserializeCryptoConfig {
    /// Creates a configuration with no keys supplied.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a configuration that uses the given keys directly.
    pub fn with_keys(enc_key: KeyMaterial, sign_key: KeyMaterial) -> Self {
        Self {
            enc_key: Some(enc_key),
            sign_key: Some(sign_key),
        }
    }
}
// Selects which key of a `SerializeCryptoConfig` is being resolved.
#[derive(Clone, Copy)]
enum SerializeKeyKind {
// The encryption key (resolved with `KeyRole::Master`).
Enc,
// The signing key (resolved with `KeyRole::Signing`).
Sign,
}
// Selects which key of a `DeserializeCryptoConfig` is being resolved.
#[derive(Clone, Copy)]
enum DeserializeKeyKind {
// The encryption key (resolved with `KeyRole::Master`).
Enc,
// The signature verification key (resolved with `KeyRole::Verify`).
Sign,
}
/// Per-tensor encryption state.
///
/// The serialized fields hold base64 strings; each is a write-once cell. The `iv`/`tag` pair is
/// written by non-chunked encryption and the `base_iv`/`tags` pair by chunked encryption; unset
/// cells of those four are omitted from the serialized form.
#[derive(Debug, Serialize, Deserialize)]
pub struct SingleCryptor {
// Encryption algorithm name. Not serialized; set from the encryption key's `alg`.
#[serde(skip)]
enc_algo: String,
// Base64 of the data key encrypted under the master key.
#[serde(with = "cryptor_serde")]
wrapped_key: OnceCell<String>,
// Base64 of the IV produced when wrapping the data key.
#[serde(with = "cryptor_serde")]
key_iv: OnceCell<String>,
// Base64 of the tag produced when wrapping the data key.
#[serde(with = "cryptor_serde")]
key_tag: OnceCell<String>,
// Base64 of the IV for non-chunked tensor encryption.
#[serde(default, skip_serializing_if = "is_empty_cell", with = "cryptor_serde")]
iv: OnceCell<String>,
// Base64 of the tag for non-chunked tensor encryption.
#[serde(default, skip_serializing_if = "is_empty_cell", with = "cryptor_serde")]
tag: OnceCell<String>,
// Base64 of the base IV from which per-chunk IVs are derived.
#[serde(default, skip_serializing_if = "is_empty_cell", with = "cryptor_serde")]
base_iv: OnceCell<String>,
// Base64 of the concatenated per-chunk tags.
#[serde(default, skip_serializing_if = "is_empty_cell", with = "cryptor_serde")]
tags: OnceCell<String>,
// Cached result bytes: ciphertext after `encrypt`, plaintext after a decrypt call.
#[serde(skip)]
buffer: OnceCell<Arc<Vec<u8>>>,
// Prepared decryption/encryption context (decoded IVs/tags and the data key).
#[serde(skip)]
ctx: OnceCell<Arc<CryptoContext>>,
}
/// Decoded, ready-to-use cryptographic state for one tensor.
pub(crate) struct CryptoContext {
/// Key context prepared from the data key and the cryptor's algorithm.
pub(crate) data_key_ctx: PreparedKeyContext,
/// Raw IV for non-chunked encryption, if present.
pub(crate) iv: Option<Vec<u8>>,
/// Raw tag for non-chunked encryption, if present.
pub(crate) tag: Option<Vec<u8>>,
/// Raw base IV for chunked encryption, if present.
pub(crate) base_iv: Option<Vec<u8>>,
/// Concatenated raw per-chunk tags for chunked encryption, if present.
pub(crate) tags: Option<Vec<u8>>,
/// The unwrapped data key; held in a `Zeroizing` wrapper.
pub(crate) data_key: Zeroizing<Vec<u8>>,
}
/// Debug output that reports presence/length of IVs and tags and never prints the data key.
impl fmt::Debug for CryptoContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tags_len = self.tags.as_ref().map(Vec::len);
        let mut out = f.debug_struct("CryptoContext");
        out.field("data_key_ctx", &self.data_key_ctx);
        out.field("iv", &self.iv.is_some());
        out.field("tag", &self.tag.is_some());
        out.field("base_iv", &self.base_iv.is_some());
        out.field("tags_len", &tags_len);
        out.field("data_key", &"[REDACTED]");
        out.finish()
    }
}
impl SingleCryptor {
/// Creates an empty cryptor whose algorithm is taken from `key_material.alg`.
///
/// Returns `InvalidAlgorithm` when the algorithm string does not parse.
fn new(key_material: &KeyMaterial) -> Result<Self, CryptoTensorsError> {
    let parsed = match key_material.alg.parse::<EncryptionAlgorithm>() {
        Ok(algo) => algo,
        Err(_) => {
            return Err(CryptoTensorsError::InvalidAlgorithm(
                key_material.alg.clone(),
            ))
        }
    };
    let cryptor = Self {
        enc_algo: parsed.to_string(),
        wrapped_key: OnceCell::new(),
        key_iv: OnceCell::new(),
        key_tag: OnceCell::new(),
        iv: OnceCell::new(),
        tag: OnceCell::new(),
        base_iv: OnceCell::new(),
        tags: OnceCell::new(),
        buffer: OnceCell::new(),
        ctx: OnceCell::new(),
    };
    Ok(cryptor)
}
/// Generates a fresh random key whose length matches this cryptor's algorithm.
fn random_key(&self) -> Result<Vec<u8>, CryptoTensorsError> {
    let key_len = match self.enc_algo.parse::<EncryptionAlgorithm>() {
        Ok(algo) => algo.key_len(),
        Err(_) => return Err(CryptoTensorsError::InvalidAlgorithm(self.enc_algo.clone())),
    };
    let mut key = vec![0u8; key_len];
    match rand::SystemRandom::new().fill(&mut key) {
        Ok(()) => Ok(key),
        Err(e) => Err(CryptoTensorsError::RandomGeneration(e.to_string())),
    }
}
fn wrap_key(
&self,
key: &[u8],
master_key_ctx: &PreparedKeyContext,
) -> Result<(), CryptoTensorsError> {
let mut key_buf = key.to_vec();
let (key_iv, key_tag) = encrypt_data(&mut key_buf, master_key_ctx)?;
self.wrapped_key
.set(BASE64.encode(&key_buf))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set wrapped key".to_string()))?;
self.key_iv
.set(BASE64.encode(&key_iv))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set key iv".to_string()))?;
self.key_tag
.set(BASE64.encode(&key_tag))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set key tag".to_string()))?;
Ok(())
}
fn prepare_context(
&mut self,
master_key_ctx: &PreparedKeyContext,
) -> Result<(), CryptoTensorsError> {
if self.ctx.get().is_some() {
return Ok(());
}
let iv = self
.iv
.get()
.map(|s| BASE64.decode(s))
.transpose()
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?;
let tag = self
.tag
.get()
.map(|s| BASE64.decode(s))
.transpose()
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?;
let base_iv = self
.base_iv
.get()
.map(|s| BASE64.decode(s))
.transpose()
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?;
let tags = self
.tags
.get()
.map(|s| BASE64.decode(s))
.transpose()
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?;
let mut data_key = Zeroizing::new(
BASE64
.decode(self.wrapped_key.get().ok_or_else(|| {
CryptoTensorsError::KeyUnwrap("wrapped_key is empty".to_string())
})?)
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?,
);
let key_iv = BASE64
.decode(
self.key_iv
.get()
.ok_or_else(|| CryptoTensorsError::KeyUnwrap("key_iv is empty".to_string()))?,
)
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?;
let key_tag = BASE64
.decode(
self.key_tag
.get()
.ok_or_else(|| CryptoTensorsError::KeyUnwrap("key_tag is empty".to_string()))?,
)
.map_err(|e| CryptoTensorsError::KeyUnwrap(e.to_string()))?;
decrypt_data(&mut data_key, master_key_ctx, &key_iv, &key_tag)?;
let data_key_ctx = prepare_key_context(&data_key, &self.enc_algo)?;
self.ctx
.set(Arc::new(CryptoContext {
data_key_ctx,
iv,
tag,
base_iv,
tags,
data_key,
}))
.map_err(|_| {
CryptoTensorsError::Decryption("Failed to set CryptoContext".to_string())
})?;
Ok(())
}
/// Reads `len` bytes of ciphertext from `file` at `file_offset`, decrypts them, and returns the
/// plaintext.
///
/// The result is cached in `self.buffer`: once the buffer is initialized, later calls return
/// the cached bytes without reading the file again. Requires the context to have been
/// prepared; otherwise a `Decryption` error is returned.
///
/// On Unix, when `chunk_size` is `Some` and the `CRYPTOTENSORS_PREAD_STRATEGY` environment
/// variable is unset or equal to "B", each chunk is read with a positional read and decrypted
/// inside the parallel iterator. In every other case the whole range is read first and then
/// passed to `perform_decryption`.
fn decrypt_from_file(
&self,
file: &std::fs::File,
file_offset: u64,
len: usize,
chunk_size: Option<usize>,
) -> Result<&[u8], CryptoTensorsError> {
self.buffer
.get_or_try_init(|| {
let ctx_arc = self.ctx.get().ok_or_else(|| {
CryptoTensorsError::Decryption("Context not prepared".to_string())
})?;
// Only called inside the `cfg(unix)` block below, hence the lint allowance.
#[allow(unused_variables)]
let get_strategy = || -> String {
std::env::var("CRYPTOTENSORS_PREAD_STRATEGY")
.unwrap_or_else(|_| "B".to_string())
};
#[cfg(unix)]
{
// Strategy "B": per-chunk read + decrypt, only for chunked data.
if let (Some(c_size), "B") = (chunk_size, get_strategy().as_str()) {
let mut buffer = vec![0u8; len];
let tags = ctx_arc.tags.as_ref().ok_or_else(|| {
CryptoTensorsError::Decryption(
"Missing tags for chunked decryption".to_string(),
)
})?;
let base_iv = ctx_arc.base_iv.as_ref().ok_or_else(|| {
CryptoTensorsError::Decryption(
"Missing base_iv for chunked decryption".to_string(),
)
})?;
let tag_len = ctx_arc.data_key_ctx.algo.tag_len();
if buffer.is_empty() {
// Empty tensor: nothing to read; check the first tag over empty input using
// the base IV directly.
if tags.len() < tag_len {
return Err(CryptoTensorsError::Decryption(
"Missing tag for empty tensor".to_string(),
));
}
crate::encryption::decrypt_data(
&mut [],
&ctx_arc.data_key_ctx,
base_iv,
&tags[0..tag_len],
)?;
} else {
use std::os::unix::fs::FileExt;
buffer.par_chunks_mut(c_size).enumerate().try_for_each(
|(i, chunk)| {
// Chunk `i` uses the `i`-th `tag_len`-sized slice of `tags`.
let expected_offset = i * tag_len;
if expected_offset + tag_len > tags.len() {
return Err(CryptoTensorsError::Decryption(format!(
"Missing tag for chunk {}",
i
)));
}
let tag = &tags[expected_offset..expected_offset + tag_len];
let iv = crate::encryption::derive_chunk_iv(base_iv, i)?;
let chunk_file_offset = file_offset + (i * c_size) as u64;
file.read_exact_at(chunk, chunk_file_offset).map_err(|e| {
CryptoTensorsError::Decryption(format!(
"pread failed: {}",
e
))
})?;
crate::encryption::decrypt_data(
chunk,
&ctx_arc.data_key_ctx,
&iv,
tag,
)
},
)?;
}
return Ok(Arc::new(buffer));
}
}
// Fallback: read the whole range with a platform-specific positional read, then decrypt.
let mut buffer = vec![0u8; len];
#[cfg(unix)]
{
use std::os::unix::fs::FileExt;
file.read_exact_at(&mut buffer, file_offset).map_err(|e| {
CryptoTensorsError::Decryption(format!("pread failed: {}", e))
})?;
}
#[cfg(windows)]
{
use std::os::windows::fs::FileExt;
// `seek_read` may return fewer bytes than requested, so loop until the buffer is full.
let mut bytes_read = 0;
while bytes_read < buffer.len() {
let n = file
.seek_read(&mut buffer[bytes_read..], file_offset + bytes_read as u64)
.map_err(|e| {
CryptoTensorsError::Decryption(format!("seek_read failed: {}", e))
})?;
if n == 0 {
return Err(CryptoTensorsError::Decryption(
"seek_read: unexpected EOF".to_string(),
));
}
bytes_read += n;
}
}
#[cfg(not(any(unix, windows)))]
{
use std::io::{Read, Seek, SeekFrom};
// No positional read here: seek and read on a cloned handle instead.
let mut file = file.try_clone().map_err(|e| {
CryptoTensorsError::Decryption(format!("file clone failed: {}", e))
})?;
file.seek(SeekFrom::Start(file_offset)).map_err(|e| {
CryptoTensorsError::Decryption(format!("seek failed: {}", e))
})?;
file.read_exact(&mut buffer).map_err(|e| {
CryptoTensorsError::Decryption(format!("read failed: {}", e))
})?;
}
self.perform_decryption(&mut buffer, ctx_arc, chunk_size)?;
Ok(Arc::new(buffer))
})
.map(|arc_ref| arc_ref.as_slice())
}
/// Decrypts a copy of `data` and returns the plaintext, caching it in `self.buffer`.
///
/// When the buffer is already initialized the cached bytes are returned and `data` is not
/// used. Fails with `Decryption` if the context has not been prepared.
fn decrypt(&self, data: &[u8], chunk_size: Option<usize>) -> Result<&[u8], CryptoTensorsError> {
    let cached = self
        .buffer
        .get_or_try_init(|| -> Result<Arc<Vec<u8>>, CryptoTensorsError> {
            let Some(ctx) = self.ctx.get() else {
                return Err(CryptoTensorsError::Decryption(
                    "Context not prepared".to_string(),
                ));
            };
            let mut plain = data.to_vec();
            self.perform_decryption(&mut plain, ctx, chunk_size)?;
            Ok(Arc::new(plain))
        })?;
    Ok(cached.as_slice())
}
/// Decrypts `buffer` in place using the prepared context.
///
/// With `chunk_size == None` the whole buffer is decrypted with the context's `iv`/`tag`.
/// Otherwise the buffer is split into `chunk_size` pieces that are decrypted in parallel, each
/// with an IV derived from `base_iv` and its chunk index and with its own slice of `tags`. An
/// empty buffer in chunked mode is checked against the first tag using `base_iv` directly.
fn perform_decryption(
    &self,
    buffer: &mut [u8],
    ctx_arc: &CryptoContext,
    chunk_size: Option<usize>,
) -> Result<(), CryptoTensorsError> {
    let key_ctx = &ctx_arc.data_key_ctx;

    // Non-chunked path: a single IV/tag pair covers the whole buffer.
    let Some(c_size) = chunk_size else {
        let iv = ctx_arc.iv.as_ref().ok_or_else(|| {
            CryptoTensorsError::Decryption("Missing iv for v1 decryption".to_string())
        })?;
        let tag = ctx_arc.tag.as_ref().ok_or_else(|| {
            CryptoTensorsError::Decryption("Missing tag for v1 decryption".to_string())
        })?;
        return crate::encryption::decrypt_data(buffer, key_ctx, iv, tag);
    };

    let base_iv = ctx_arc.base_iv.as_ref().ok_or_else(|| {
        CryptoTensorsError::Decryption("Missing base_iv for chunked decryption".to_string())
    })?;
    let tags = ctx_arc.tags.as_ref().ok_or_else(|| {
        CryptoTensorsError::Decryption("Missing tags for chunked decryption".to_string())
    })?;
    let tag_len = key_ctx.algo.tag_len();

    if buffer.is_empty() {
        let tag = tags.get(..tag_len).ok_or_else(|| {
            CryptoTensorsError::Decryption("Missing tag for empty tensor".to_string())
        })?;
        return crate::encryption::decrypt_data(&mut [], key_ctx, base_iv, tag);
    }

    buffer
        .par_chunks_mut(c_size)
        .enumerate()
        .try_for_each(|(index, chunk)| {
            // Chunk `index` owns the `index`-th `tag_len`-sized slice of `tags`.
            let start = index * tag_len;
            let tag = tags.get(start..start + tag_len).ok_or_else(|| {
                CryptoTensorsError::Decryption(format!("Missing tag for chunk {}", index))
            })?;
            let iv = crate::encryption::derive_chunk_iv(base_iv, index)?;
            crate::encryption::decrypt_data(chunk, key_ctx, &iv, tag)
        })
}
/// Encrypts a copy of `data` under a freshly generated data key and records everything needed
/// to decrypt it later.
///
/// With `chunk_size == Some(..)` the copy is encrypted chunk by chunk in parallel, each chunk
/// using an IV derived from a random base IV and its chunk index; the base IV and the
/// concatenated tags are stored base64-encoded in `base_iv`/`tags`. With `None` the whole copy
/// is encrypted at once and `iv`/`tag` are stored instead. The data key is then wrapped with
/// `master_key_ctx`, the ciphertext is stored in `self.buffer`, and the context in `self.ctx`.
///
/// All of these cells are write-once, so calling this on a cryptor that already holds any of
/// them returns an `Encryption` error.
fn encrypt(
&self,
data: &[u8],
master_key_ctx: &PreparedKeyContext,
chunk_size: Option<usize>,
) -> Result<(), CryptoTensorsError> {
let data_key = Zeroizing::new(self.random_key()?);
let data_key_ctx = prepare_key_context(&data_key, &self.enc_algo)?;
// Encryption happens in place on this copy; `data` itself is left untouched.
let mut buffer = data.to_vec();
let mut out_iv = None;
let mut out_tag = None;
let mut out_base_iv = None;
let mut out_tags = None;
if let Some(c_size) = chunk_size {
let aead_algo = data_key_ctx.algo.get_aead_algo();
let mut base_iv = vec![0u8; aead_algo.nonce_len()];
let rng = rand::SystemRandom::new();
rng.fill(&mut base_iv)
.map_err(|e| CryptoTensorsError::RandomGeneration(e.to_string()))?;
let chunk_count = buffer.len().div_ceil(c_size);
// Room for at least one tag so that an empty tensor still carries a tag.
let mut all_tags =
vec![0u8; std::cmp::max(1, chunk_count) * data_key_ctx.algo.tag_len()];
if buffer.is_empty() {
// Empty tensor: produce a tag over empty input using the base IV directly.
let tag =
crate::encryption::encrypt_data_with_iv(&mut [], &data_key_ctx, &base_iv)?;
all_tags[0..tag.len()].copy_from_slice(&tag);
} else {
let tags_res: Result<Vec<Vec<u8>>, CryptoTensorsError> = buffer
.par_chunks_mut(c_size)
.enumerate()
.map(|(i, chunk)| {
let chunk_iv = crate::encryption::derive_chunk_iv(&base_iv, i)?;
crate::encryption::encrypt_data_with_iv(chunk, &data_key_ctx, &chunk_iv)
})
.collect();
let tags_vec = tags_res?;
// Replace the zero-filled placeholder with the collected tags, concatenated.
all_tags.clear();
for tag in tags_vec {
all_tags.extend_from_slice(&tag);
}
}
self.base_iv
.set(BASE64.encode(&base_iv))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set base_iv".to_string()))?;
self.tags
.set(BASE64.encode(&all_tags))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set tags".to_string()))?;
out_base_iv = Some(base_iv);
out_tags = Some(all_tags);
} else {
let (iv, tag) = encrypt_data(&mut buffer, &data_key_ctx)?;
self.iv
.set(BASE64.encode(&iv))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set iv".to_string()))?;
self.tag
.set(BASE64.encode(&tag))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set tag".to_string()))?;
out_iv = Some(iv);
out_tag = Some(tag);
}
self.wrap_key(&data_key, master_key_ctx)?;
// After this point `buffer` holds ciphertext.
self.buffer
.set(Arc::new(buffer))
.map_err(|_| CryptoTensorsError::Encryption("Failed to set buffer".to_string()))?;
self.ctx
.set(Arc::new(CryptoContext {
data_key_ctx,
iv: out_iv,
tag: out_tag,
base_iv: out_base_iv,
tags: out_tags,
data_key,
}))
.map_err(|_| {
CryptoTensorsError::Encryption("Failed to set CryptoContext".to_string())
})?;
Ok(())
}
/// Returns a copy of this cryptor whose data key is wrapped under `new_key`.
///
/// The tensor IVs/tags and the prepared context are carried over; the cached buffer is not.
/// When this cryptor has no wrapped key, the copy has none either. Fails with
/// `InvalidAlgorithm` if `new_key.alg` is non-empty and differs from this cryptor's algorithm,
/// and with `Decryption` if a wrapped key exists but the context has not been prepared.
pub fn with_new_key(&self, new_key: &KeyMaterial) -> Result<Self, CryptoTensorsError> {
    let alg_compatible = new_key.alg.is_empty() || new_key.alg == self.enc_algo;
    if !alg_compatible {
        return Err(CryptoTensorsError::InvalidAlgorithm(format!(
            "Algorithm mismatch: current={}, new={}",
            self.enc_algo, new_key.alg
        )));
    }

    let new_master_key = Zeroizing::new(new_key.get_master_key_bytes()?);
    let new_master_key_ctx = prepare_key_context(&new_master_key, &self.enc_algo)?;

    // Re-encrypt the data key under the new master key, if there is one to re-encrypt.
    let rewrapped = match self.wrapped_key.get() {
        None => None,
        Some(_) => {
            let ctx = self.ctx.get().ok_or_else(|| {
                CryptoTensorsError::Decryption(
                    "Context not prepared, unable to rewrap key".to_string(),
                )
            })?;
            let mut dek = ctx.data_key.clone();
            let (key_iv, key_tag) = encrypt_data(&mut dek, &new_master_key_ctx)?;
            Some((
                BASE64.encode(&dek),
                BASE64.encode(&key_iv),
                BASE64.encode(&key_tag),
            ))
        }
    };

    let rewrapped_cryptor = Self {
        enc_algo: self.enc_algo.clone(),
        wrapped_key: OnceCell::new(),
        key_iv: OnceCell::new(),
        key_tag: OnceCell::new(),
        iv: self.iv.clone(),
        tag: self.tag.clone(),
        base_iv: self.base_iv.clone(),
        tags: self.tags.clone(),
        buffer: OnceCell::new(),
        ctx: self.ctx.clone(),
    };

    if let Some((wrapped, wrap_iv, wrap_tag)) = rewrapped {
        rewrapped_cryptor.wrapped_key.set(wrapped).map_err(|_| {
            CryptoTensorsError::Encryption("Failed to set wrapped key".to_string())
        })?;
        rewrapped_cryptor
            .key_iv
            .set(wrap_iv)
            .map_err(|_| CryptoTensorsError::Encryption("Failed to set key iv".to_string()))?;
        rewrapped_cryptor
            .key_tag
            .set(wrap_tag)
            .map_err(|_| CryptoTensorsError::Encryption("Failed to set key tag".to_string()))?;
    }
    Ok(rewrapped_cryptor)
}
}
/// Field-by-field clone; the cached buffer and context are shared through their `Arc`s.
impl Clone for SingleCryptor {
    fn clone(&self) -> Self {
        // Destructure so that adding a field without cloning it becomes a compile error.
        let Self {
            enc_algo,
            wrapped_key,
            key_iv,
            key_tag,
            iv,
            tag,
            base_iv,
            tags,
            buffer,
            ctx,
        } = self;
        Self {
            enc_algo: enc_algo.clone(),
            wrapped_key: wrapped_key.clone(),
            key_iv: key_iv.clone(),
            key_tag: key_tag.clone(),
            iv: iv.clone(),
            tag: tag.clone(),
            base_iv: base_iv.clone(),
            tags: tags.clone(),
            buffer: buffer.clone(),
            ctx: ctx.clone(),
        }
    }
}
#[derive(Debug, Clone)]
struct HeaderSigner {
alg: String,
priv_key: Option<Vec<u8>>,
pub_key: Option<Vec<u8>>,
signature: OnceCell<Vec<u8>>,
}
impl HeaderSigner {
    /// Builds a signer from `key_material`, decoding whichever of the private and public keys
    /// are present.
    ///
    /// Fails with `InvalidAlgorithm` if the algorithm does not parse and with `InvalidKey` if a
    /// key is not valid base64.
    fn new(key_material: &KeyMaterial) -> Result<Self, CryptoTensorsError> {
        let alg = match key_material.alg.parse::<SignatureAlgorithm>() {
            Ok(parsed) => parsed,
            Err(_) => {
                return Err(CryptoTensorsError::InvalidAlgorithm(
                    key_material.alg.clone(),
                ))
            }
        };

        let priv_key = match key_material.d_priv.get().and_then(|k| k.as_ref()) {
            Some(encoded) => Some(BASE64.decode(encoded).map_err(|e| {
                CryptoTensorsError::InvalidKey(format!("Invalid base64 private key: {}", e))
            })?),
            None => None,
        };

        let pub_key = match key_material.x_pub.get().and_then(|k| k.as_ref()) {
            Some(encoded) => Some(BASE64.decode(encoded).map_err(|e| {
                CryptoTensorsError::InvalidKey(format!("Invalid base64 public key: {}", e))
            })?),
            None => None,
        };

        Ok(Self {
            alg: alg.to_string(),
            priv_key,
            pub_key,
            signature: OnceCell::new(),
        })
    }

    /// Signs `data` with the private key and stores the signature.
    ///
    /// Fails with `MissingSigningKey` when no private key is available and with `Signing` when
    /// a signature has already been stored.
    fn sign(&self, data: &[u8]) -> Result<(), CryptoTensorsError> {
        let Some(key) = &self.priv_key else {
            return Err(CryptoTensorsError::MissingSigningKey);
        };
        let signature = sign_data(data, key, &self.alg)?;
        self.signature
            .set(signature)
            .map_err(|_| CryptoTensorsError::Signing("Signature already set".to_string()))
    }

    /// Verifies the stored signature over `data` with the public key.
    ///
    /// Fails with `MissingVerificationKey` when no public key is available and with
    /// `MissingSignature` when no signature has been stored.
    fn verify(&self, data: &[u8]) -> Result<bool, CryptoTensorsError> {
        let key = self
            .pub_key
            .as_ref()
            .ok_or(CryptoTensorsError::MissingVerificationKey)?;
        let signature = self.signature.get().ok_or_else(|| {
            CryptoTensorsError::MissingSignature("No signature to verify".to_string())
        })?;
        verify_signature(data, signature, key, &self.alg)
    }
}
/// Encryption and signing state for one file: per-tensor cryptors plus the keys, policy,
/// format version, and chunk size that apply to all of them.
#[derive(Debug)]
pub struct CryptoTensors {
// Per-tensor cryptors, keyed by tensor name.
cryptors: HashMap<String, SingleCryptor>,
// Signs the header when serializing and verifies it when loading.
signer: HeaderSigner,
// Key material for the master (encryption) key.
enc_key: KeyMaterial,
// Key material for the signing / verification key.
sign_key: KeyMaterial,
// Access policy stored in the `__policy__` metadata entry.
policy: AccessPolicy,
// Format version string ("1" or "2").
version: String,
// Chunk size in bytes for chunked encryption; `None` means whole-tensor encryption.
chunk_size: Option<usize>,
}
impl CryptoTensors {
/// Returns the cryptor for `tensor_name`, or `None` if there is no cryptor for that tensor.
pub fn get(&self, tensor_name: &str) -> Option<&SingleCryptor> {
self.cryptors.get(tensor_name)
}
/// Returns the encryption (master) key material.
pub fn enc_key(&self) -> &KeyMaterial {
&self.enc_key
}
/// Returns the signing / verification key material.
pub fn sign_key(&self) -> &KeyMaterial {
&self.sign_key
}
/// Returns the access policy.
pub fn policy(&self) -> &AccessPolicy {
&self.policy
}
/// Builds the encryption state for serializing `tensors` according to `config`.
///
/// Resolves and validates the encryption and signing keys, selects the tensors to encrypt
/// (all of them, or those named in `config.tensor_filter` that are also in `tensors`), and
/// creates an empty cryptor for each.
///
/// Returns `Ok(None)` when no tensor is selected.
///
/// # Errors
///
/// Returns `UnsupportedVersion` for a version other than "1" or "2", `InvalidKey` for a zero
/// chunk size, and any error from key resolution, key validation, or algorithm parsing.
pub fn from_serialize_config(
    tensors: Vec<String>,
    config: &SerializeCryptoConfig,
) -> Result<Option<Self>, CryptoTensorsError> {
    let enc_key = Self::resolve_key_from_serialize_config(config, SerializeKeyKind::Enc)?;
    let sign_key = Self::resolve_key_from_serialize_config(config, SerializeKeyKind::Sign)?;
    enc_key.validate(ValidateMode::ForCreation)?;
    sign_key.validate(ValidateMode::ForCreation)?;

    let matched_tensors: Vec<String> = match &config.tensor_filter {
        None => tensors,
        Some(names) => {
            // Build the lookup set once so the filter is linear rather than a scan of
            // `tensors` for every filter entry.
            let available: std::collections::HashSet<&str> =
                tensors.iter().map(String::as_str).collect();
            let selected: Vec<String> = names
                .iter()
                .filter(|name| available.contains(name.as_str()))
                .cloned()
                .collect();
            selected
        }
    };
    if matched_tensors.is_empty() {
        return Ok(None);
    }

    let cryptors = matched_tensors
        .iter()
        .map(|name| {
            let cryptor = SingleCryptor::new(&enc_key)?;
            Ok((name.clone(), cryptor))
        })
        .collect::<Result<HashMap<String, SingleCryptor>, CryptoTensorsError>>()?;
    let signer = HeaderSigner::new(&sign_key)?;

    let version = config
        .version
        .clone()
        .unwrap_or_else(|| CRYPTOTENSORS_VERSION_V2.to_string());
    if version != CRYPTOTENSORS_VERSION_V1 && version != CRYPTOTENSORS_VERSION_V2 {
        return Err(CryptoTensorsError::UnsupportedVersion(version.clone()));
    }

    // Version "1" encrypts each tensor as a whole; other versions use a non-zero chunk size.
    let chunk_size = if version == CRYPTOTENSORS_VERSION_V1 {
        None
    } else {
        let size = config.chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
        if size == 0 {
            return Err(CryptoTensorsError::InvalidKey(
                "chunk_size must be greater than 0".to_string(),
            ));
        }
        Some(size)
    };

    Ok(Some(Self {
        cryptors,
        enc_key,
        sign_key,
        signer,
        policy: config.policy.clone().unwrap_or_default(),
        version,
        chunk_size,
    }))
}
pub fn generate_metadata(
&self,
tensors: Vec<(String, TensorInfo)>,
metadata: Option<HashMap<String, String>>,
) -> Result<Option<HashMap<String, String>>, CryptoTensorsError> {
let mut new_metadata = metadata.unwrap_or_default();
new_metadata.remove("__signature__");
new_metadata.remove("__encryption__");
new_metadata.remove("__crypto_keys__");
new_metadata.remove("__policy__");
let mut key_material = serde_json::json!({
"version": self.version,
"enc": self.enc_key,
"sign": self.sign_key
});
if let Some(cs) = self.chunk_size {
key_material
.as_object_mut()
.unwrap()
.insert("chunk_size".to_string(), serde_json::json!(cs));
}
let key_material_json = serde_json::to_string(&key_material)
.map_err(|e| CryptoTensorsError::Encryption(e.to_string()))?;
new_metadata.insert("__crypto_keys__".to_string(), key_material_json);
let sorted_cryptors: std::collections::BTreeMap<_, _> = self.cryptors.iter().collect();
let crypto_json = serde_json::to_string(&sorted_cryptors)
.map_err(|e| CryptoTensorsError::Encryption(e.to_string()))?;
new_metadata.insert("__encryption__".to_string(), crypto_json);
let policy_json = serde_json::to_string(&self.policy)
.map_err(|e| CryptoTensorsError::Encryption(e.to_string()))?;
new_metadata.insert("__policy__".to_string(), policy_json);
let header = Metadata::new(Some(new_metadata.clone()), tensors).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Failed to create metadata: {}", e))
})?;
let header_json = serde_json::to_string(&header)?;
self.signer.sign(header_json.as_bytes())?;
let signature =
self.signer.signature.get().ok_or_else(|| {
CryptoTensorsError::Signing("Failed to get signature".to_string())
})?;
new_metadata.insert("__signature__".to_string(), BASE64.encode(signature));
Ok(Some(new_metadata))
}
/// Builds key lookup parameters for the key of `kind` from a serialize configuration.
///
/// Registry lookup is allowed only when no direct key is configured for that kind.
fn params_for_serialize<'a>(
    config: &'a SerializeCryptoConfig,
    kind: SerializeKeyKind,
) -> KeyLookupParams<'a> {
    let (direct, jku, kid) = match kind {
        SerializeKeyKind::Enc => (
            config.enc_key.as_ref(),
            config.enc_jku.as_deref(),
            config.enc_kid.as_deref(),
        ),
        SerializeKeyKind::Sign => (
            config.sign_key.as_ref(),
            config.sign_jku.as_deref(),
            config.sign_kid.as_deref(),
        ),
    };
    let registry_allowed = direct.is_none();
    KeyLookupParams {
        direct,
        jku,
        kid,
        registry_allowed,
    }
}
/// Builds key lookup parameters for the key of `kind` when loading.
///
/// `jku` and `kid` come from the header's `key`; the direct key, if any, comes from `config`.
/// Registry lookup is allowed exactly when there is no direct key (including when there is no
/// configuration at all).
fn params_for_deserialize<'a>(
    key: &'a KeyMaterial,
    config: Option<&'a DeserializeCryptoConfig>,
    kind: DeserializeKeyKind,
) -> KeyLookupParams<'a> {
    let direct = config.and_then(|c| match kind {
        DeserializeKeyKind::Enc => c.enc_key.as_ref(),
        DeserializeKeyKind::Sign => c.sign_key.as_ref(),
    });
    let registry_allowed = direct.is_none();
    KeyLookupParams {
        direct,
        jku: key.jku.as_deref(),
        kid: key.kid.as_deref(),
        registry_allowed,
    }
}
/// Resolves the key of `kind` for loading and updates `key` with the result.
///
/// When the configuration supplies a direct key, `key` is replaced by the resolved key.
/// Otherwise `key` is updated from the resolved key and takes over its algorithm if the
/// resolved key specifies one.
fn resolve_key_from_deserialize_config(
    key: &mut KeyMaterial,
    config: Option<&DeserializeCryptoConfig>,
    kind: DeserializeKeyKind,
) -> Result<(), CryptoTensorsError> {
    let role = match kind {
        DeserializeKeyKind::Enc => KeyRole::Master,
        DeserializeKeyKind::Sign => KeyRole::Verify,
    };
    // `params` borrows `key`, so finish using it before `key` is mutated below.
    let params = Self::params_for_deserialize(key, config, kind);
    let resolved = resolve_key(role, &params, false)?;
    let has_direct_key = params.direct.is_some();

    if has_direct_key {
        *key = resolved;
    } else {
        key.update_from_key(&resolved)?;
        if !resolved.alg.is_empty() {
            key.alg = resolved.alg.clone();
        }
    }
    Ok(())
}
/// Resolves the key of `kind` for serializing.
///
/// When no direct key is configured, the configured `kid` and `jku` (if any) are recorded on
/// the resolved key.
fn resolve_key_from_serialize_config(
    config: &SerializeCryptoConfig,
    kind: SerializeKeyKind,
) -> Result<KeyMaterial, CryptoTensorsError> {
    let params = Self::params_for_serialize(config, kind);
    let role = match kind {
        SerializeKeyKind::Enc => KeyRole::Master,
        SerializeKeyKind::Sign => KeyRole::Signing,
    };
    let mut out = resolve_key(role, &params, true)?;

    if params.direct.is_none() {
        let (kid, jku) = match kind {
            SerializeKeyKind::Enc => (config.enc_kid.clone(), config.enc_jku.clone()),
            SerializeKeyKind::Sign => (config.sign_kid.clone(), config.sign_jku.clone()),
        };
        if kid.is_some() {
            out.kid = kid;
        }
        if jku.is_some() {
            out.jku = jku;
        }
    }
    Ok(out)
}
/// Loads the encryption state from `header` without a deserialize configuration.
///
/// Equivalent to calling `from_header_with_config` with `None`.
pub fn from_header(header: &Metadata) -> Result<Option<Self>, CryptoTensorsError> {
Self::from_header_with_config(header, None)
}
pub fn from_header_with_config(
header: &Metadata,
config: Option<&DeserializeCryptoConfig>,
) -> Result<Option<Self>, CryptoTensorsError> {
let metadata = match header.metadata().as_ref() {
Some(m) => m,
None => return Ok(None),
};
let encryption_info = match metadata.get("__encryption__") {
Some(info) => info,
None => return Ok(None),
};
let key_materials = metadata.get("__crypto_keys__").ok_or_else(|| {
CryptoTensorsError::InvalidKey("Missing __crypto_keys__ in metadata".to_string())
})?;
let signature_hex = metadata.get("__signature__").ok_or_else(|| {
CryptoTensorsError::MissingSignature("Missing __signature__ in metadata".to_string())
})?;
let policy_str = metadata.get("__policy__").ok_or_else(|| {
CryptoTensorsError::Policy("Missing __policy__ in metadata".to_string())
})?;
let key_materials: serde_json::Value =
serde_json::from_str(key_materials).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Failed to parse key materials: {}", e))
})?;
let version = key_materials
.get("version")
.and_then(|v| v.as_str())
.ok_or(CryptoTensorsError::MissingVersion)?;
if version != CRYPTOTENSORS_VERSION_V1 && version != CRYPTOTENSORS_VERSION_V2 {
return Err(CryptoTensorsError::UnsupportedVersion(version.to_string()));
}
let chunk_size = match key_materials.get("chunk_size") {
Some(v) => {
let raw = v.as_u64().ok_or_else(|| {
CryptoTensorsError::InvalidKey(
"Invalid chunk_size in __crypto_keys__ header".to_string(),
)
})?;
if raw == 0 {
return Err(CryptoTensorsError::InvalidKey(
"chunk_size must be greater than 0 in __crypto_keys__ header".to_string(),
));
}
Some(raw as usize)
}
None => {
if version == CRYPTOTENSORS_VERSION_V2 {
Some(DEFAULT_CHUNK_SIZE)
} else {
None
}
}
};
let mut enc_key = KeyMaterial::from_header(&key_materials["enc"])?;
let mut sign_key = KeyMaterial::from_header(&key_materials["sign"])?;
Self::resolve_key_from_deserialize_config(&mut enc_key, config, DeserializeKeyKind::Enc)?;
Self::resolve_key_from_deserialize_config(&mut sign_key, config, DeserializeKeyKind::Sign)?;
let signer = HeaderSigner::new(&sign_key)?;
let signature = BASE64
.decode(signature_hex)
.map_err(|_| CryptoTensorsError::InvalidSignatureFormat)?;
signer
.signature
.set(signature)
.expect("Failed to set signature");
let mut metadata_map = header.metadata().clone().unwrap_or_default();
metadata_map.remove("__signature__");
let mut tensors_vec = Vec::new();
for key in header.offset_keys() {
let info = header
.info(&key)
.ok_or_else(|| {
CryptoTensorsError::Verification(format!("Tensor {} not found in header", key))
})?
.clone();
tensors_vec.push((key, info));
}
let header_for_verify = Metadata::new(Some(metadata_map), tensors_vec)
.map_err(|e| CryptoTensorsError::Verification(e.to_string()))?;
let header_for_verify_json = serde_json::to_string(&header_for_verify)
.map_err(|e| CryptoTensorsError::Verification(e.to_string()))?;
if !signer.verify(header_for_verify_json.as_bytes())? {
return Err(CryptoTensorsError::Verification(
"Signature verification failed".to_string(),
));
}
let policy: AccessPolicy = serde_json::from_str(policy_str)
.map_err(|e| CryptoTensorsError::Policy(format!("Failed to parse policy: {}", e)))?;
if !policy.evaluate(String::new())? {
return Err(CryptoTensorsError::Policy(
"Policy evaluation denied".to_string(),
));
}
if enc_key.k.get().and_then(|v| v.as_ref()).is_none() {
let params = Self::params_for_deserialize(&enc_key, config, DeserializeKeyKind::Enc);
if params.registry_allowed {
let resolved = resolve_key(KeyRole::Master, ¶ms, false)?;
enc_key.update_from_key(&resolved)?;
if !resolved.alg.is_empty() {
enc_key.alg = resolved.alg.clone();
}
} else {
return Err(CryptoTensorsError::KeyLoad(
"encryption key material missing: provide enc_key or provider when not using registry".to_string(),
));
}
}
let master_key = Zeroizing::new(enc_key.get_master_key_bytes()?);
let master_key_ctx = prepare_key_context(&master_key, &enc_key.alg)?;
let mut cryptors: HashMap<String, SingleCryptor> = serde_json::from_str(encryption_info)
.map_err(|e| CryptoTensorsError::Encryption(e.to_string()))?;
if cryptors.len() >= PARALLEL_CONTEXT_THRESHOLD {
let prep_result: Result<Vec<()>, CryptoTensorsError> = cryptors
.par_iter_mut()
.map(|(_, cryptor)| {
cryptor.enc_algo = enc_key.alg.clone();
cryptor.prepare_context(&master_key_ctx)
})
.collect();
prep_result?;
} else {
for (_, cryptor) in cryptors.iter_mut() {
cryptor.enc_algo = enc_key.alg.clone();
cryptor.prepare_context(&master_key_ctx)?;
}
}
Ok(Some(Self {
cryptors,
signer,
enc_key,
sign_key,
policy,
version: version.to_string(),
chunk_size,
}))
}
/// Decrypts the named tensor from `file` into its cached buffer.
///
/// Does nothing and returns `Ok(())` when there is no cryptor for `tensor_name`.
pub fn silent_decrypt_from_file(
    &self,
    tensor_name: &str,
    file: &std::fs::File,
    file_offset: u64,
    len: usize,
) -> Result<(), CryptoTensorsError> {
    if let Some(cryptor) = self.get(tensor_name) {
        cryptor.decrypt_from_file(file, file_offset, len, self.chunk_size)?;
    }
    Ok(())
}
/// Returns the plaintext for the named tensor.
///
/// When there is no cryptor for `tensor_name`, `data` is returned unchanged; otherwise the
/// decrypted bytes held in the cryptor's buffer are returned.
pub fn silent_decrypt<'a>(
    &'a self,
    tensor_name: &str,
    data: &'a [u8],
) -> Result<&'a [u8], CryptoTensorsError> {
    let Some(cryptor) = self.get(tensor_name) else {
        return Ok(data);
    };
    cryptor.decrypt(data, self.chunk_size)?;
    match cryptor.buffer.get() {
        Some(plain) => Ok(plain.as_slice()),
        None => Err(CryptoTensorsError::Decryption(
            "Decrypted buffer not available".to_string(),
        )),
    }
}
/// Encrypts every tensor that has a cryptor, fetching its bytes through `get_data`.
///
/// Tensors for which `get_data` returns `None` are skipped. Returns immediately when there
/// are no cryptors.
pub fn encrypt_tensors<'a>(
    &self,
    get_data: impl Fn(&str) -> Option<Cow<'a, [u8]>>,
) -> Result<(), CryptoTensorsError> {
    if self.cryptors.is_empty() {
        return Ok(());
    }
    let master_key = Zeroizing::new(self.enc_key.get_master_key_bytes()?);
    let master_key_ctx = prepare_key_context(&master_key, &self.enc_key.alg)?;
    for (name, cryptor) in self.cryptors.iter() {
        let Some(data) = get_data(name) else {
            continue;
        };
        cryptor.encrypt(&data, &master_key_ctx, self.chunk_size)?;
    }
    Ok(())
}
/// Returns a shared handle to the named tensor's cached buffer, if the tensor has a cryptor
/// and its buffer has been filled.
pub fn get_buffer(&self, tensor_name: &str) -> Option<Arc<Vec<u8>>> {
    self.get(tensor_name)
        .and_then(|cryptor| cryptor.buffer.get().cloned())
}
pub fn rewrap(&mut self, new_config: &SerializeCryptoConfig) -> Result<(), CryptoTensorsError> {
let new_enc_key =
Self::resolve_key_from_serialize_config(new_config, SerializeKeyKind::Enc)?;
let new_sign_key =
Self::resolve_key_from_serialize_config(new_config, SerializeKeyKind::Sign)?;
new_enc_key.validate(ValidateMode::ForCreation)?;
new_sign_key.validate(ValidateMode::ForCreation)?;
let cryptor_names: Vec<String> = self.cryptors.keys().cloned().collect();
for name in cryptor_names {
let old_cryptor = self.cryptors.get(&name).ok_or_else(|| {
CryptoTensorsError::KeyUnwrap(format!("Cryptor {} not found", name))
})?;
let new_cryptor = old_cryptor.with_new_key(&new_enc_key)?;
self.cryptors.insert(name, new_cryptor);
}
self.enc_key = new_enc_key;
self.sign_key = new_sign_key;
self.signer = HeaderSigner::new(&self.sign_key)?;
if let Some(policy) = &new_config.policy {
self.policy = policy.clone();
}
Ok(())
}
}