use crate::cryptotensors::CryptoTensorsError;
use crate::encryption::EncryptionAlgorithm;
use crate::registry;
use crate::signing::SignatureAlgorithm;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use once_cell::sync::OnceCell;
use ring::rand::{self, SecureRandom};
use ring::signature::{Ed25519KeyPair, KeyPair};
use serde::{de::Error, Deserialize, Deserializer, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidateMode {
ForCreation,
FromHeader,
FromJwk,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum JwkKeyType {
Oct,
Okp,
}
mod key_material_serde {
use super::*;
pub fn deserialize<'de, D>(deserializer: D) -> Result<OnceCell<Option<String>>, D::Error>
where
D: Deserializer<'de>,
{
let value: Option<String> = Option::deserialize(deserializer)?;
let cell = OnceCell::new();
if let Some(v) = value {
cell.set(Some(v))
.map_err(|_| D::Error::custom("Failed to set OnceCell value"))?;
}
Ok(cell)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyMaterial {
#[serde(rename = "kty")]
pub key_type: JwkKeyType,
pub alg: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub kid: Option<String>,
#[serde(skip_serializing, default)]
#[serde(with = "key_material_serde")]
pub k: OnceCell<Option<String>>,
#[serde(skip_serializing, default)]
#[serde(with = "key_material_serde")]
#[serde(rename = "x")]
pub x_pub: OnceCell<Option<String>>,
#[serde(skip_serializing, default)]
#[serde(with = "key_material_serde")]
#[serde(rename = "d")]
pub d_priv: OnceCell<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub jku: Option<String>,
}
impl KeyMaterial {
fn new_internal(
key_type: JwkKeyType,
alg: String,
kid: Option<String>,
jku: Option<String>,
k: Option<Vec<u8>>,
x_pub: Option<Vec<u8>>,
d_priv: Option<Vec<u8>>,
) -> Result<Self, CryptoTensorsError> {
let key_material = Self {
key_type,
alg,
kid,
k: OnceCell::new(),
x_pub: OnceCell::new(),
d_priv: OnceCell::new(),
jku,
};
if let Some(k) = k {
key_material
.k
.set(Some(BASE64.encode(&k)))
.map_err(|_| CryptoTensorsError::KeyCreation("Failed to set key".to_string()))?;
}
if let Some(x_pub) = x_pub {
key_material
.x_pub
.set(Some(BASE64.encode(&x_pub)))
.map_err(|_| {
CryptoTensorsError::KeyCreation("Failed to set public key".to_string())
})?;
}
if let Some(d_priv) = d_priv {
key_material
.d_priv
.set(Some(BASE64.encode(&d_priv)))
.map_err(|_| {
CryptoTensorsError::KeyCreation("Failed to set private key".to_string())
})?;
}
Ok(key_material)
}
pub fn load_key(&self) -> Result<(), CryptoTensorsError> {
let jwk = match self.key_type {
JwkKeyType::Oct => registry::get_master_key(self.jku.as_deref(), self.kid.as_deref())?,
JwkKeyType::Okp => registry::get_verify_key(self.jku.as_deref(), self.kid.as_deref())?,
};
let key = Self::from_jwk(&jwk, false)?;
self.update_from_key(&key)
}
pub(crate) fn update_from_key(&self, key: &KeyMaterial) -> Result<(), CryptoTensorsError> {
match self.key_type {
JwkKeyType::Oct => {
if let Some(Some(k)) = key.k.get() {
self.k
.set(Some(k.clone()))
.map_err(|_| CryptoTensorsError::MissingMasterKey)?;
} else {
return Err(CryptoTensorsError::MissingMasterKey);
}
}
JwkKeyType::Okp => {
if let Some(Some(x_pub)) = key.x_pub.get() {
self.x_pub
.set(Some(x_pub.clone()))
.map_err(|_| CryptoTensorsError::MissingVerificationKey)?;
} else {
return Err(CryptoTensorsError::MissingVerificationKey);
}
}
}
Ok(())
}
pub fn validate(&self, mode: ValidateMode) -> Result<(), CryptoTensorsError> {
if self.key_type != JwkKeyType::Oct && self.key_type != JwkKeyType::Okp {
return Err(CryptoTensorsError::InvalidKey(
"Invalid key type".to_string(),
));
}
match mode {
ValidateMode::ForCreation | ValidateMode::FromHeader => {
if self.alg.is_empty() {
return Err(CryptoTensorsError::InvalidAlgorithm(
"Missing alg field".to_string(),
));
}
}
ValidateMode::FromJwk => {
}
}
match mode {
ValidateMode::ForCreation => {
match self.key_type {
JwkKeyType::Oct => {
if self.k.get().and_then(|k| k.as_ref()).is_none() {
return Err(CryptoTensorsError::MissingMasterKey);
}
}
JwkKeyType::Okp => {
if self.d_priv.get().and_then(|k| k.as_ref()).is_none() {
return Err(CryptoTensorsError::MissingSigningKey);
}
if self.x_pub.get().and_then(|k| k.as_ref()).is_none() {
return Err(CryptoTensorsError::MissingVerificationKey);
}
}
}
}
ValidateMode::FromJwk => match self.key_type {
JwkKeyType::Oct => {
if self.k.get().and_then(|k| k.as_ref()).is_none() {
return Err(CryptoTensorsError::MissingMasterKey);
}
}
JwkKeyType::Okp => {
if self.x_pub.get().and_then(|k| k.as_ref()).is_none() {
return Err(CryptoTensorsError::MissingVerificationKey);
}
}
},
ValidateMode::FromHeader => {
}
}
if !self.alg.is_empty() {
match self.key_type {
JwkKeyType::Oct => {
if self.alg.parse::<EncryptionAlgorithm>().is_err() {
return Err(CryptoTensorsError::InvalidAlgorithm(
"Invalid encryption algorithm".to_string(),
));
}
}
JwkKeyType::Okp => {
if self.alg.parse::<SignatureAlgorithm>().is_err() {
return Err(CryptoTensorsError::InvalidAlgorithm(
"Invalid signature algorithm".to_string(),
));
}
}
}
}
Ok(())
}
pub fn new_enc_key(
key_b64: Option<String>,
alg: Option<String>,
kid: Option<String>,
jku: Option<String>,
) -> Result<Self, CryptoTensorsError> {
let alg = alg.unwrap_or_else(|| "aes256gcm".to_string());
let enc_alg = alg
.parse::<EncryptionAlgorithm>()
.map_err(|_| CryptoTensorsError::InvalidAlgorithm(alg.clone()))?;
let key_bytes = if let Some(ref b64_str) = key_b64 {
let bytes = BASE64.decode(b64_str).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Invalid base64 key: {}", e))
})?;
if bytes.len() != enc_alg.key_len() {
return Err(CryptoTensorsError::InvalidKeyLength {
expected: enc_alg.key_len(),
actual: bytes.len(),
});
}
bytes
} else {
let mut key = vec![0u8; enc_alg.key_len()];
let rng = rand::SystemRandom::new();
rng.fill(&mut key)
.map_err(|e| CryptoTensorsError::RandomGeneration(e.to_string()))?;
key
};
KeyMaterial::new_internal(JwkKeyType::Oct, alg, kid, jku, Some(key_bytes), None, None)
}
pub fn new_sign_key(
public_b64: Option<String>,
private_b64: Option<String>,
alg: Option<String>,
kid: Option<String>,
jku: Option<String>,
) -> Result<Self, CryptoTensorsError> {
let alg = alg.unwrap_or_else(|| "ed25519".to_string());
let sig_alg = alg
.parse::<SignatureAlgorithm>()
.map_err(|_| CryptoTensorsError::InvalidAlgorithm(alg.clone()))?;
match sig_alg {
SignatureAlgorithm::Ed25519 => {
let public = if let Some(pub_b64) = public_b64 {
let pub_bytes = BASE64.decode(&pub_b64).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Invalid base64 public key: {}", e))
})?;
if pub_bytes.len() != 32 {
return Err(CryptoTensorsError::InvalidKeyLength {
expected: 32,
actual: pub_bytes.len(),
});
}
Some(pub_bytes)
} else {
None
};
let private = if let Some(priv_b64) = private_b64 {
let priv_bytes = BASE64.decode(&priv_b64).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Invalid base64 private key: {}", e))
})?;
if priv_bytes.len() != 32 {
return Err(CryptoTensorsError::InvalidKeyLength {
expected: 32,
actual: priv_bytes.len(),
});
}
Some(priv_bytes)
} else {
None
};
let (public, private) = if public.is_none() && private.is_none() {
let rng = rand::SystemRandom::new();
let mut private_key = [0u8; 32];
rng.fill(&mut private_key)
.map_err(|e| CryptoTensorsError::RandomGeneration(e.to_string()))?;
let key_pair = Ed25519KeyPair::from_seed_unchecked(&private_key)
.map_err(|e| CryptoTensorsError::KeyCreation(e.to_string()))?;
(
Some(key_pair.public_key().as_ref().to_vec()),
Some(private_key.to_vec()),
)
} else if public.is_some() && private.is_some() {
Ed25519KeyPair::from_seed_and_public_key(
private.clone().unwrap().as_slice(),
public.clone().unwrap().as_slice(),
)
.map_err(|e| CryptoTensorsError::KeyCreation(e.to_string()))?;
(public, private)
} else if private.is_some() && public.is_none() {
let private_key = private.clone().unwrap();
let key_pair = Ed25519KeyPair::from_seed_unchecked(&private_key)
.map_err(|e| CryptoTensorsError::KeyCreation(e.to_string()))?;
(Some(key_pair.public_key().as_ref().to_vec()), private)
} else {
(public, private)
};
KeyMaterial::new_internal(JwkKeyType::Okp, alg, kid, jku, None, public, private)
}
}
}
pub fn to_jwk(&self) -> Result<String, CryptoTensorsError> {
#[derive(Serialize)]
struct JwkOut<'a> {
#[serde(rename = "kty")]
key_type: &'a JwkKeyType,
alg: &'a String,
#[serde(skip_serializing_if = "Option::is_none")]
kid: &'a Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
jku: &'a Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
k: Option<&'a String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "x")]
x_pub: Option<&'a String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "d")]
d_priv: Option<&'a String>,
}
let jwk = JwkOut {
key_type: &self.key_type,
alg: &self.alg,
kid: &self.kid,
jku: &self.jku,
k: self.k.get().and_then(|v| v.as_ref()),
x_pub: self.x_pub.get().and_then(|v| v.as_ref()),
d_priv: self.d_priv.get().and_then(|v| v.as_ref()),
};
serde_json::to_string(&jwk)
.map_err(|e| CryptoTensorsError::KeyCreation(format!("Failed to serialize JWK: {}", e)))
}
pub fn get_master_key_bytes(&self) -> Result<Vec<u8>, CryptoTensorsError> {
let k = self
.k
.get()
.and_then(|v| v.as_ref())
.ok_or(CryptoTensorsError::MissingMasterKey)?;
BASE64.decode(k).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Invalid base64 master key: {}", e))
})
}
pub fn get_public_key_bytes(&self) -> Result<Vec<u8>, CryptoTensorsError> {
let x = self
.x_pub
.get()
.and_then(|v| v.as_ref())
.ok_or(CryptoTensorsError::MissingVerificationKey)?;
BASE64.decode(x).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Invalid base64 public key: {}", e))
})
}
pub fn get_private_key_bytes(&self) -> Result<Vec<u8>, CryptoTensorsError> {
let d = self
.d_priv
.get()
.and_then(|v| v.as_ref())
.ok_or(CryptoTensorsError::MissingSigningKey)?;
BASE64.decode(d).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Invalid base64 private key: {}", e))
})
}
pub fn from_header(header: &serde_json::Value) -> Result<Self, CryptoTensorsError> {
let key: KeyMaterial = serde_json::from_value(header.clone()).map_err(|e| {
CryptoTensorsError::InvalidKey(format!("Failed to parse key material: {}", e))
})?;
key.validate(ValidateMode::FromHeader)?;
Ok(key)
}
pub fn from_jwk(
jwk_str: &serde_json::Value,
is_for_save: bool,
) -> Result<Self, CryptoTensorsError> {
let key_material: KeyMaterial = serde_json::from_value(jwk_str.clone())
.map_err(|e| CryptoTensorsError::InvalidKey(format!("Failed to parse JWK: {}", e)))?;
key_material.validate(if is_for_save {
ValidateMode::ForCreation
} else {
ValidateMode::FromJwk
})?;
Ok(key_material)
}
pub fn from_jwk_str(jwk_json: &str, is_for_save: bool) -> Result<Self, CryptoTensorsError> {
let jwk_value: serde_json::Value = serde_json::from_str(jwk_json)
.map_err(|e| CryptoTensorsError::InvalidKey(format!("Invalid JSON: {}", e)))?;
Self::from_jwk(&jwk_value, is_for_save)
}
pub fn kid(&self) -> Option<&str> {
self.kid.as_deref()
}
pub fn jku(&self) -> Option<&str> {
self.jku.as_deref()
}
pub fn set_kid(&mut self, kid: &str) {
self.kid = Some(kid.to_string());
}
pub fn set_jku(&mut self, jku: &str) {
self.jku = Some(jku.to_string());
}
}
#[derive(Clone, Copy)]
pub enum KeyRole {
Master,
Signing,
Verify,
}
impl KeyRole {
fn missing_key_message(&self) -> &'static str {
match self {
KeyRole::Master => "encryption key required: provide enc_key or use registry (enc_kid)",
KeyRole::Signing => "signing key required: provide sign_key or use registry (sign_kid)",
KeyRole::Verify => {
"verification key required: provide sign_key or use registry (sign_kid)"
}
}
}
}
pub struct KeyLookupParams<'a> {
pub direct: Option<&'a KeyMaterial>,
pub jku: Option<&'a str>,
pub kid: Option<&'a str>,
pub registry_allowed: bool,
}
pub fn resolve_key(
role: KeyRole,
params: &KeyLookupParams<'_>,
for_creation: bool,
) -> Result<KeyMaterial, CryptoTensorsError> {
if let Some(k) = params.direct {
return Ok(k.clone());
}
let jwk = if params.registry_allowed {
match role {
KeyRole::Master => registry::get_master_key(params.jku, params.kid)?,
KeyRole::Signing => registry::get_signing_key(params.jku, params.kid)?,
KeyRole::Verify => registry::get_verify_key(params.jku, params.kid)?,
}
} else {
return Err(CryptoTensorsError::KeyLoad(
role.missing_key_message().to_string(),
));
};
KeyMaterial::from_jwk(&jwk, for_creation)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_keymaterial_convenience_methods() {
let jwk_json = r#"{"kty":"oct","alg":"aes256gcm","kid":"test","k":"dGVzdC1rZXktMzItYnl0ZXMtbG9uZy1lbmNyeXB0aW9u"}"#;
let key = KeyMaterial::from_jwk_str(jwk_json, false).unwrap();
assert_eq!(key.kid(), Some("test"));
assert_eq!(key.alg, "aes256gcm");
let mut key = KeyMaterial::new_enc_key(None, None, None, None).unwrap();
key.set_kid("new-kid");
key.set_jku("file:///new/path.jwk");
assert_eq!(key.kid(), Some("new-kid"));
assert_eq!(key.jku(), Some("file:///new/path.jwk"));
}
}