use crate::cryptotensors::CryptoTensorsError;
use crate::key::KeyMaterial;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use ring::signature;
use serde_json::Value;
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::{OnceLock, RwLock};
/// Pinned Ed25519 public keys (standard base64) trusted to sign native provider libraries,
/// as `(provider name, public key)` pairs.
const PROVIDER_PUBLIC_KEYS: &[(&str, &str)] = &[(
    "koalavault-vllm",
    "vM5cRuHaIyKt3RAELcqc4+nXbSbCh53ABYt2/lOGqw8=",
)];

// Registry priorities: entries with a larger value are consulted first.

/// Priority of the built-in environment-variable provider (lowest).
pub const PRIORITY_ENV: i32 = 0;
/// Priority of the built-in `file://` provider.
pub const PRIORITY_FILE: i32 = 10;
/// Priority assigned to providers loaded through `load_provider_native`.
pub const PRIORITY_NATIVE: i32 = 50;
/// Highest built-in priority; presumably meant for directly supplied keys (confirm with callers).
pub const PRIORITY_DIRECT: i32 = 100;
/// Alias of [`PRIORITY_DIRECT`] (presumably for temporary keys).
pub const PRIORITY_TEMP: i32 = PRIORITY_DIRECT;
/// Filesystem location extracted from a `file://` JWK URL (see [`FileJwkPath::parse`]).
#[cfg(feature = "provider-file")]
#[derive(Debug, Clone)]
pub struct FileJwkPath {
    /// Resolved path: absolute, `~`-expanded, or relative exactly as written in the URL.
    pub path: String,
}

#[cfg(feature = "provider-file")]
impl FileJwkPath {
    /// Parses a `file://` URL into a filesystem path.
    ///
    /// Accepted forms (the scheme is matched case-insensitively):
    /// - `file:///abs/path` → `/abs/path`; on Windows `file:///C:/dir/f` → `C:/dir/f`
    ///   (backslashes after the drive are normalized to `/`).
    /// - `file://~/path` → `<home>/path`, where home is `HOME` (on Windows `USERPROFILE`,
    ///   falling back to `HOME`).
    /// - `file://relative/path` → `relative/path`.
    ///
    /// # Errors
    /// - `CryptoTensorsError::InvalidJwkUrl` if the URL has no `://` or the scheme is not `file`.
    /// - `CryptoTensorsError::KeyLoad` if `~` expansion is requested but no home variable is set.
    pub fn parse(url: &str) -> Result<Self, CryptoTensorsError> {
        let Some((scheme, rest)) = url.split_once("://") else {
            return Err(CryptoTensorsError::InvalidJwkUrl(
                "Missing URL scheme".to_string(),
            ));
        };
        if scheme.to_lowercase() != "file" {
            return Err(CryptoTensorsError::InvalidJwkUrl(url.to_string()));
        }
        let path = if let Some(path_str) = rest.strip_prefix('/') {
            Self::absolute_path(rest, path_str)
        } else if let Some(stripped) = rest.strip_prefix('~') {
            format!("{}{}", Self::home_dir()?, stripped)
        } else {
            rest.to_string()
        };
        Ok(Self { path })
    }

    /// Windows: turns `C:/dir/f` (the URL path without its leading `/`) into `C:/dir/f`,
    /// normalizing backslashes; paths without a drive letter keep the leading `/` (`rest`).
    #[cfg(windows)]
    fn absolute_path(rest: &str, path_str: &str) -> String {
        let bytes = path_str.as_bytes();
        let has_drive = bytes.len() >= 2 && bytes[0].is_ascii_alphabetic() && bytes[1] == b':';
        if has_drive {
            // `drive` already includes the ':' separator, so it is concatenated directly
            // (appending another ':' would yield an invalid `C::/...` path).
            let (drive, tail) = path_str.split_at(2);
            format!("{}{}", drive, tail.replace('\\', "/"))
        } else {
            rest.to_string()
        }
    }

    /// Non-Windows: the URL path is already absolute; restore its leading `/`.
    #[cfg(not(windows))]
    fn absolute_path(_rest: &str, path_str: &str) -> String {
        format!("/{}", path_str)
    }

    /// Home directory used for `~` expansion.
    fn home_dir() -> Result<String, CryptoTensorsError> {
        if cfg!(windows) {
            std::env::var("USERPROFILE")
                .or_else(|_| std::env::var("HOME"))
                .map_err(|e| {
                    CryptoTensorsError::KeyLoad(format!("Failed to get home directory: {}", e))
                })
        } else {
            std::env::var("HOME")
                .map_err(|e| CryptoTensorsError::KeyLoad(format!("Failed to get HOME: {}", e)))
        }
    }
}
/// Reads a JWK or a JWK Set from the file at `path`.
///
/// A top-level object with a `"kty"` member is treated as a single JWK and returned as a
/// one-element vector. Otherwise the object's `"keys"` array (a JWK Set) is returned.
///
/// # Errors
/// `CryptoTensorsError::KeyLoad` if the file cannot be read, is not valid JSON (the parser's
/// message is included), or is neither a JWK nor a JWK Set.
#[cfg(feature = "provider-file")]
pub fn load_jwks_from_file(path: &str) -> Result<Vec<serde_json::Value>, CryptoTensorsError> {
    let content = std::fs::read_to_string(path)
        .map_err(|e| CryptoTensorsError::KeyLoad(format!("Failed to read file {}: {}", path, e)))?;
    let mut jwk: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
        CryptoTensorsError::KeyLoad(format!("Failed to parse JWK from {}: {}", path, e))
    })?;
    if jwk.get("kty").is_some() {
        return Ok(vec![jwk]);
    }
    // Move the array out instead of cloning it; `jwk` is discarded afterwards anyway.
    match jwk.get_mut("keys").map(serde_json::Value::take) {
        Some(serde_json::Value::Array(keys)) => Ok(keys),
        _ => Err(CryptoTensorsError::KeyLoad(format!(
            "Failed to parse JWK from {}: expected a JWK (\"kty\") or a JWK Set (\"keys\" array)",
            path
        ))),
    }
}
/// Picks one JWK from `keys`.
///
/// A key is a candidate when its `"kty"` equals `kty` exactly and, if `alg` is given, its
/// `"alg"` member is either absent or equal to `alg`. With `kid` given, the first candidate
/// whose `"kid"` equals it is returned (`None` if there is none). Without `kid`, the first
/// candidate in input order is returned. The result is a clone of the stored value.
pub fn select_jwk(
    keys: &[serde_json::Value],
    kty: &str,
    alg: Option<&str>,
    kid: Option<&str>,
) -> Option<serde_json::Value> {
    /// String value of member `name`, if present and a string.
    fn member<'a>(jwk: &'a serde_json::Value, name: &str) -> Option<&'a str> {
        jwk.get(name).and_then(serde_json::Value::as_str)
    }

    // A key without an "alg" member is compatible with any requested algorithm.
    let alg_ok = |jwk: &serde_json::Value| match (alg, member(jwk, "alg")) {
        (Some(wanted), Some(actual)) => actual == wanted,
        _ => true,
    };

    let mut candidates = keys
        .iter()
        .filter(|&jwk| member(jwk, "kty") == Some(kty))
        .filter(|&jwk| alg_ok(jwk));

    let chosen = match kid {
        Some(wanted) => candidates.find(|&jwk| member(jwk, "kid") == Some(wanted)),
        None => candidates.next(),
    };
    chosen.cloned()
}
#[derive(Clone)]
pub struct DirectKeyProvider {
enc_keys: HashMap<String, Value>,
sign_keys: HashMap<String, Value>,
default_enc_kid: Option<String>,
default_sign_kid: Option<String>,
}
impl DirectKeyProvider {
pub fn new() -> Self {
Self {
enc_keys: HashMap::new(),
sign_keys: HashMap::new(),
default_enc_kid: None,
default_sign_kid: None,
}
}
pub fn from_single_keys(enc_key: Value, sign_key: Value) -> Self {
let enc_kid = enc_key
.get("kid")
.and_then(|v| v.as_str())
.map(String::from);
let sign_kid = sign_key
.get("kid")
.and_then(|v| v.as_str())
.map(String::from);
let mut provider = Self::new();
if let Some(kid) = &enc_kid {
provider.enc_keys.insert(kid.clone(), enc_key);
provider.default_enc_kid = Some(kid.clone());
}
if let Some(kid) = &sign_kid {
provider.sign_keys.insert(kid.clone(), sign_key);
provider.default_sign_kid = Some(kid.clone());
}
provider
}
pub fn add_enc_key(&mut self, kid: &str, key: Value) -> &mut Self {
self.enc_keys.insert(kid.to_string(), key);
if self.default_enc_kid.is_none() {
self.default_enc_kid = Some(kid.to_string());
}
self
}
pub fn add_sign_key(&mut self, kid: &str, key: Value) -> &mut Self {
self.sign_keys.insert(kid.to_string(), key);
if self.default_sign_kid.is_none() {
self.default_sign_kid = Some(kid.to_string());
}
self
}
pub fn set_default_enc_kid(&mut self, kid: &str) -> &mut Self {
self.default_enc_kid = Some(kid.to_string());
self
}
pub fn set_default_sign_kid(&mut self, kid: &str) -> &mut Self {
self.default_sign_kid = Some(kid.to_string());
self
}
pub fn from_key_material(
enc_key: &KeyMaterial,
sign_key: &KeyMaterial,
) -> Result<Self, CryptoTensorsError> {
let enc_jwk = enc_key.to_jwk()?;
let sign_jwk = sign_key.to_jwk()?;
let enc_val: Value = serde_json::from_str(&enc_jwk)?;
let sign_val: Value = serde_json::from_str(&sign_jwk)?;
Ok(Self::from_single_keys(enc_val, sign_val))
}
}
impl Default for DirectKeyProvider {
fn default() -> Self {
Self::new()
}
}
impl KeyProvider for DirectKeyProvider {
fn name(&self) -> &str {
"DirectKeyProvider"
}
fn is_ready(&self) -> bool {
!self.enc_keys.is_empty() || !self.sign_keys.is_empty()
}
fn matches(&self, _jku: Option<&str>, kid: Option<&str>) -> bool {
match kid {
Some(k) => self.enc_keys.contains_key(k) || self.sign_keys.contains_key(k),
None => {
self.default_enc_kid.is_some() || self.default_sign_kid.is_some()
}
}
}
fn get_master_key(
&self,
_jku: Option<&str>,
kid: Option<&str>,
) -> Result<Value, CryptoTensorsError> {
let kid = kid.or(self.default_enc_kid.as_deref()).ok_or_else(|| {
CryptoTensorsError::InvalidKey("No enc_kid specified and no default".to_string())
})?;
self.enc_keys
.get(kid)
.cloned()
.ok_or_else(|| CryptoTensorsError::InvalidKey(format!("Key not found: {}", kid)))
}
fn get_verify_key(
&self,
_jku: Option<&str>,
kid: Option<&str>,
) -> Result<Value, CryptoTensorsError> {
let kid = kid.or(self.default_sign_kid.as_deref()).ok_or_else(|| {
CryptoTensorsError::InvalidKey("No sign_kid specified and no default".to_string())
})?;
self.sign_keys
.get(kid)
.cloned()
.ok_or_else(|| CryptoTensorsError::InvalidKey(format!("Key not found: {}", kid)))
}
fn get_signing_key(
&self,
_jku: Option<&str>,
kid: Option<&str>,
) -> Result<Value, CryptoTensorsError> {
let kid = kid.or(self.default_sign_kid.as_deref()).ok_or_else(|| {
CryptoTensorsError::InvalidKey("No sign_kid specified and no default".to_string())
})?;
self.sign_keys
.get(kid)
.cloned()
.ok_or_else(|| CryptoTensorsError::InvalidKey(format!("Key not found: {}", kid)))
}
}
/// Key provider that reads JWK files, resolving relative paths against search directories.
#[cfg(feature = "provider-file")]
pub struct FileKeyProvider {
    /// Normalized, de-duplicated search directories in first-occurrence order.
    search_paths: Vec<String>,
}

#[cfg(feature = "provider-file")]
impl FileKeyProvider {
    /// Creates a provider over `search_paths`.
    ///
    /// Every entry is normalized (see `normalize_path`). Entries whose normalized form was
    /// already seen are dropped, so the first occurrence wins.
    pub fn new(search_paths: Vec<String>) -> Self {
        let mut seen = std::collections::HashSet::new();
        let search_paths = search_paths
            .iter()
            .map(|raw| Self::normalize_path(raw))
            .filter(|normalized| seen.insert(normalized.clone()))
            .collect();
        Self { search_paths }
    }

    /// Expands a leading `~` and canonicalizes absolute paths.
    ///
    /// The home directory comes from `HOME` (on Windows `USERPROFILE`, then `HOME`). If it is
    /// unset, a literal `~` is kept. Relative paths, and absolute paths that cannot be
    /// canonicalized (for example because they do not exist), are returned as expanded text.
    fn normalize_path(path: &str) -> String {
        let expanded = match path.strip_prefix('~') {
            None => path.to_string(),
            Some(tail) => {
                let home_lookup = if cfg!(windows) {
                    std::env::var("USERPROFILE").or_else(|_| std::env::var("HOME"))
                } else {
                    std::env::var("HOME")
                };
                let home = home_lookup.unwrap_or_else(|_| "~".to_string());
                format!("{}{}", home, tail)
            }
        };

        let canonical = if Path::new(&expanded).is_absolute() {
            std::fs::canonicalize(&expanded).ok()
        } else {
            None
        };
        match canonical {
            Some(resolved) => resolved.to_string_lossy().into_owned(),
            None => expanded,
        }
    }

    /// Candidate locations for `file_path`: the path itself when absolute, otherwise
    /// `file_path` joined onto each search directory (none when no directories are set).
    fn resolve_path(&self, file_path: &str) -> Vec<String> {
        if Path::new(file_path).is_absolute() {
            return vec![file_path.to_string()];
        }
        self.search_paths
            .iter()
            .map(|dir| Path::new(dir).join(file_path).to_string_lossy().into_owned())
            .collect()
    }
}
#[cfg(feature = "provider-file")]
impl KeyProvider for FileKeyProvider {
    /// Ready when no search directories are configured, or when at least one of them is an
    /// existing directory.
    fn is_ready(&self) -> bool {
        self.search_paths.is_empty() || self.search_paths.iter().any(|p| Path::new(p).is_dir())
    }

    /// Claims only `file://` JWK URLs.
    ///
    /// The scheme is compared ASCII case-insensitively so this agrees with
    /// `FileJwkPath::parse`, which lowercases the scheme. Previously `FILE://...` parsed fine
    /// but never reached this provider.
    fn matches(&self, jku: Option<&str>, _kid: Option<&str>) -> bool {
        const PREFIX: &str = "file://";
        jku.is_some_and(|url| {
            url.get(..PREFIX.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(PREFIX))
        })
    }

    /// First `"oct"` JWK (matching `kid` when given) found in the candidate files.
    fn get_master_key(
        &self,
        jku: Option<&str>,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError> {
        self.find_key_in_files(jku, "oct", kid)
    }

    /// First `"okp"` JWK (matching `kid` when given) found in the candidate files.
    fn get_verify_key(
        &self,
        jku: Option<&str>,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError> {
        self.find_key_in_files(jku, "okp", kid)
    }

    fn name(&self) -> &str {
        "file"
    }
}

#[cfg(feature = "provider-file")]
impl FileKeyProvider {
    /// Files to search.
    ///
    /// With a `jku`, these are the locations of its path (see `resolve_path`). Without one,
    /// they are `keys.jwk` inside every search directory.
    ///
    /// # Errors
    /// Propagates `FileJwkPath::parse` failures for a malformed `jku`.
    fn candidate_paths(&self, jku: Option<&str>) -> Result<Vec<String>, CryptoTensorsError> {
        if let Some(url) = jku {
            let parsed = FileJwkPath::parse(url)?;
            return Ok(self.resolve_path(&parsed.path));
        }
        Ok(self
            .search_paths
            .iter()
            .map(|dir| Path::new(dir).join("keys.jwk").to_string_lossy().into_owned())
            .collect())
    }

    /// Scans the candidate files in order and returns the first JWK of type `kty` (and `kid`,
    /// if given).
    ///
    /// Files that cannot be read or parsed are skipped on purpose, so the remaining candidates
    /// are still tried.
    ///
    /// # Errors
    /// `CryptoTensorsError::NoSuitableKey` when no file yields a matching key; `jku` parse
    /// errors from `candidate_paths`.
    fn find_key_in_files(
        &self,
        jku: Option<&str>,
        kty: &str,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError> {
        self.candidate_paths(jku)?
            .into_iter()
            .filter_map(|path| load_jwks_from_file(&path).ok())
            .find_map(|keys| select_jwk(&keys, kty, None, kid))
            .ok_or(CryptoTensorsError::NoSuitableKey)
    }
}
/// Key provider that reads a JWK or JWK Set from the `CRYPTOTENSOR_KEYS` environment variable.
#[cfg(feature = "provider-env")]
pub struct EnvKeyProvider {
    /// Name of the environment variable holding the JSON document.
    env_var: String,
}

#[cfg(feature = "provider-env")]
impl Default for EnvKeyProvider {
    /// Same as [`EnvKeyProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "provider-env")]
impl EnvKeyProvider {
    /// Creates a provider bound to `CRYPTOTENSOR_KEYS`.
    pub fn new() -> Self {
        let env_var = String::from("CRYPTOTENSOR_KEYS");
        Self { env_var }
    }

    /// Parses the variable as a JWK Set (`"keys"` array, checked first) or a single JWK
    /// (object with `"kty"`).
    ///
    /// # Errors
    /// - `Registry` if the variable cannot be read (unset or not valid Unicode).
    /// - `KeyLoad` if the value is not JSON, or is JSON of neither shape.
    fn get_keys(&self) -> Result<Vec<serde_json::Value>, CryptoTensorsError> {
        let name = self.env_var.as_str();
        let raw = std::env::var(name)
            .map_err(|_| CryptoTensorsError::Registry(format!("{} not set", name)))?;
        let document: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
            CryptoTensorsError::KeyLoad(format!("Failed to parse JWK Set from {}: {}", name, e))
        })?;

        if let Some(keys) = document.get("keys").and_then(serde_json::Value::as_array) {
            return Ok(keys.clone());
        }
        if document.get("kty").is_none() {
            return Err(CryptoTensorsError::KeyLoad(format!(
                "Invalid JWK format in {}",
                name
            )));
        }
        Ok(vec![document])
    }
}
#[cfg(feature = "provider-env")]
impl KeyProvider for EnvKeyProvider {
    fn name(&self) -> &str {
        "env"
    }

    /// Ready whenever the environment variable can be read.
    fn is_ready(&self) -> bool {
        std::env::var(&self.env_var).is_ok()
    }

    /// Accepts every request regardless of `jku`/`kid`; key selection happens on lookup.
    fn matches(&self, _jku: Option<&str>, _kid: Option<&str>) -> bool {
        true
    }

    /// First `"oct"` JWK (matching `kid` when given) from the environment variable.
    fn get_master_key(
        &self,
        _jku: Option<&str>,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError> {
        self.get_keys().and_then(|keys| {
            select_jwk(&keys, "oct", None, kid).ok_or(CryptoTensorsError::NoSuitableKey)
        })
    }

    /// First `"okp"` JWK (matching `kid` when given) from the environment variable.
    fn get_verify_key(
        &self,
        _jku: Option<&str>,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError> {
        self.get_keys().and_then(|keys| {
            select_jwk(&keys, "okp", None, kid).ok_or(CryptoTensorsError::NoSuitableKey)
        })
    }
}
/// A pluggable source of JWKs consulted by the global provider registry.
///
/// Implementations are kept in a process-wide registry, hence the `Send + Sync` bound.
/// Keys are exchanged as JSON values.
pub trait KeyProvider: Send + Sync {
    /// Identifier used by the registry.
    ///
    /// Registering a provider replaces any entry with the same name, and
    /// `disable_provider`/`enable_provider` select entries by it.
    fn name(&self) -> &str;

    /// Whether the provider can serve keys right now; the registry skips providers that
    /// are not ready.
    fn is_ready(&self) -> bool;

    /// Whether this provider should be asked for the key identified by `jku`/`kid`; the
    /// registry skips providers that return `false`.
    fn matches(&self, jku: Option<&str>, kid: Option<&str>) -> bool;

    /// Returns the master-key JWK for `jku`/`kid`.
    fn get_master_key(
        &self,
        jku: Option<&str>,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError>;

    /// Returns the JWK used to verify signatures for `jku`/`kid`.
    fn get_verify_key(
        &self,
        jku: Option<&str>,
        kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError>;

    /// Applies provider-specific configuration supplied as a JSON string.
    ///
    /// `load_provider_native` calls this once before registering a native provider. The
    /// default does nothing.
    fn initialize(&mut self, _config_json: &str) -> Result<(), CryptoTensorsError> {
        Ok(())
    }

    /// Returns the JWK used to create signatures for `jku`/`kid`.
    ///
    /// The default implementation reports that the provider has no signing keys.
    fn get_signing_key(
        &self,
        _jku: Option<&str>,
        _kid: Option<&str>,
    ) -> Result<serde_json::Value, CryptoTensorsError> {
        Err(CryptoTensorsError::Registry(
            "get_signing_key not implemented for this provider".to_string(),
        ))
    }
}
/// One registered provider plus its registry metadata.
///
/// NOTE: field order matters. Struct fields are dropped in declaration order, so `provider`
/// is dropped before `_lib`. A provider created by a dynamically loaded library is therefore
/// destroyed while that library is still loaded. Do not reorder these fields.
struct ProviderEntry {
    /// The provider implementation.
    provider: Box<dyn KeyProvider>,
    /// Ordering key; the registry keeps entries sorted so larger values are consulted first.
    priority: i32,
    /// Entries with `enabled == false` are skipped by the key lookup functions.
    enabled: bool,
    /// Keeps the shared library that created `provider` loaded; `None` for in-process providers.
    _lib: Option<std::sync::Arc<libloading::Library>>,
}
/// Process-wide provider registry, initialized lazily by [`get_providers`].
static PROVIDERS: OnceLock<RwLock<Vec<ProviderEntry>>> = OnceLock::new();

/// Returns the global provider registry, creating it on first use.
///
/// The registry starts with the built-in providers enabled by Cargo features: environment
/// variable at `PRIORITY_ENV` and file at `PRIORITY_FILE`. It is sorted by descending
/// priority from the start, the same invariant `register_provider_full` maintains, so
/// lookups honour priorities even before any provider has been registered.
fn get_providers() -> &'static RwLock<Vec<ProviderEntry>> {
    PROVIDERS.get_or_init(|| {
        // Explicit element type: with no provider features enabled the literal is empty
        // and the sort below would otherwise leave the type uninferred.
        let mut entries: Vec<ProviderEntry> = vec![
            #[cfg(feature = "provider-env")]
            ProviderEntry {
                provider: Box::new(EnvKeyProvider::new()),
                priority: PRIORITY_ENV,
                enabled: true,
                _lib: None,
            },
            #[cfg(feature = "provider-file")]
            ProviderEntry {
                provider: Box::new(FileKeyProvider::new(Vec::new())),
                priority: PRIORITY_FILE,
                enabled: true,
                _lib: None,
            },
        ];
        // The literal lists env (0) before file (10). Without this sort, the match-everything
        // env provider would be consulted first until the first registration re-sorted.
        entries.sort_by(|a, b| b.priority.cmp(&a.priority));
        RwLock::new(entries)
    })
}
pub fn register_provider_with_priority(provider: Box<dyn KeyProvider>, priority: i32) {
register_provider_full(provider, priority, None);
}
/// Inserts `provider` into the global registry with `priority` and an optional handle to the
/// library that created it.
///
/// Any entry with the same `name()` is removed first, so re-registering replaces it. Entries
/// are then re-sorted by descending priority. The sort is stable, so equal priorities keep
/// registration order. If the registry lock is poisoned, the provider is discarded.
fn register_provider_full(
    provider: Box<dyn KeyProvider>,
    priority: i32,
    lib: Option<std::sync::Arc<libloading::Library>>,
) {
    let providers = get_providers();
    let mut guard = match providers.write() {
        Ok(guard) => guard,
        Err(poisoned) => {
            // Release the lock, then drop the provider BEFORE the library handle. Function
            // parameters are otherwise dropped in reverse declaration order, so `lib` would go
            // first. If it is the last reference, the library is unloaded before `provider`'s
            // destructor (code inside that library) runs.
            drop(poisoned);
            drop(provider);
            drop(lib);
            return;
        }
    };
    let provider_name = provider.name();
    guard.retain(|entry| entry.provider.name() != provider_name);
    guard.push(ProviderEntry {
        provider,
        priority,
        enabled: true,
        _lib: lib,
    });
    guard.sort_by(|a, b| b.priority.cmp(&a.priority));
}
/// Registers `provider` with priority `0`, replacing any provider with the same name.
pub fn register_provider(provider: Box<dyn KeyProvider>) {
    const DEFAULT_PRIORITY: i32 = 0;
    register_provider_with_priority(provider, DEFAULT_PRIORITY)
}
/// Removes every registry entry named `name`. Does nothing if the registry lock is poisoned.
///
/// NOTE(review): despite the name, this deletes the entry instead of clearing its `enabled`
/// flag, so `enable_provider` cannot bring it back. The module's `test_priority_and_disable`
/// asserts the registry shrinks, so the removal semantics are currently relied upon.
pub fn disable_provider(name: &str) {
    let Ok(mut entries) = get_providers().write() else {
        return;
    };
    entries.retain(|entry| entry.provider.name() != name);
}
/// Sets `enabled = true` on every registry entry named `name`.
///
/// This is a no-op if no entry matches or the registry lock is poisoned.
///
/// NOTE(review): nothing in this file sets `enabled` to `false` (`disable_provider` removes
/// entries), so this currently has no observable effect; confirm the intended semantics.
pub fn enable_provider(name: &str) {
    let Ok(mut entries) = get_providers().write() else {
        return;
    };
    entries
        .iter_mut()
        .filter(|entry| entry.provider.name() == name)
        .for_each(|entry| entry.enabled = true);
}
/// Removes all providers, including the built-in ones. Does nothing if the lock is poisoned.
pub fn clear_providers() {
    let Ok(mut entries) = get_providers().write() else {
        return;
    };
    entries.clear();
}
/// Signature of the `cryptotensors_create_provider` symbol exported by native provider
/// libraries.
///
/// The returned pointer is taken over with `Box::from_raw` by `load_provider_native`.
// `*mut dyn KeyProvider` is a Rust trait-object (fat) pointer with no C equivalent, hence the
// lint allowance. Presumably the provider library must be built with a compatible Rust
// toolchain and the same `KeyProvider` definition; TODO confirm the supported build constraints.
#[allow(improper_ctypes_definitions)]
pub type CreateProviderFn = unsafe extern "C" fn() -> *mut dyn KeyProvider;
fn verify_library_signature(provider_name: &str, lib_path: &str) -> Result<(), CryptoTensorsError> {
let sig_path = format!("{}.sig", lib_path);
let public_key_str = PROVIDER_PUBLIC_KEYS
.iter()
.find(|(name, _)| *name == provider_name)
.map(|(_, key)| *key)
.ok_or_else(|| {
CryptoTensorsError::Registry(format!(
"No trusted public key configured for provider: {}",
provider_name
))
})?;
let mut lib_file = File::open(lib_path)
.map_err(|e| CryptoTensorsError::Registry(format!("Failed to open library: {}", e)))?;
let mut lib_data = Vec::new();
lib_file
.read_to_end(&mut lib_data)
.map_err(|e| CryptoTensorsError::Registry(format!("Failed to read library: {}", e)))?;
let sig_path_obj = std::path::Path::new(&sig_path);
if !sig_path_obj.exists() {
return Err(CryptoTensorsError::Registry(format!(
"Signature file missing: {}",
sig_path
)));
}
let mut sig_file = File::open(&sig_path).map_err(|e| {
CryptoTensorsError::Registry(format!("Failed to open signature file: {}", e))
})?;
let mut sig_base64 = String::new();
sig_file
.read_to_string(&mut sig_base64)
.map_err(|e| CryptoTensorsError::Registry(format!("Failed to read signature: {}", e)))?;
let signature = BASE64
.decode(sig_base64.trim())
.map_err(|e| CryptoTensorsError::Registry(format!("Invalid signature encoding: {}", e)))?;
let public_key_bytes = BASE64
.decode(public_key_str)
.map_err(|e| CryptoTensorsError::Registry(format!("Invalid public key encoding: {}", e)))?;
let peer_public_key = signature::UnparsedPublicKey::new(&signature::ED25519, public_key_bytes);
peer_public_key.verify(&lib_data, &signature).map_err(|e| {
CryptoTensorsError::Registry(format!("Signature verification failed: {}", e))
})?;
Ok(())
}
pub fn load_provider_native(
name: &str,
lib_path: &str,
config_json: &str,
) -> Result<(), CryptoTensorsError> {
verify_library_signature(name, lib_path)?;
let lib = unsafe {
libloading::Library::new(lib_path)
.map_err(|e| CryptoTensorsError::Registry(format!("Failed to load library: {}", e)))?
};
let lib = std::sync::Arc::new(lib);
let create_fn: libloading::Symbol<CreateProviderFn> = unsafe {
lib.get(b"cryptotensors_create_provider").map_err(|e| {
CryptoTensorsError::Registry(format!("Missing symbol in provider library: {}", e))
})?
};
let mut provider = unsafe { Box::from_raw(create_fn()) };
provider.initialize(config_json)?;
register_provider_full(provider, PRIORITY_NATIVE, Some(lib));
Ok(())
}
/// Number of registered providers, or `0` if the registry lock is poisoned.
pub fn provider_count() -> usize {
    get_providers()
        .read()
        .map(|entries| entries.len())
        .unwrap_or(0)
}
/// Returns the first master-key JWK any registered provider can supply for `jku`/`kid`.
///
/// Providers are consulted in registry order (highest priority first). A provider is asked
/// only if it is enabled, reports `is_ready()`, and `matches(jku, kid)`. A provider that
/// returns an error is skipped so the next candidate can be tried.
///
/// # Errors
/// `CryptoTensorsError::Registry` if the registry lock is poisoned or no provider yields a key.
/// The message names `jku`/`kid`, consistent with `get_verify_key` and `get_signing_key`.
pub fn get_master_key(
    jku: Option<&str>,
    kid: Option<&str>,
) -> Result<serde_json::Value, CryptoTensorsError> {
    let entries = get_providers()
        .read()
        .map_err(|_| CryptoTensorsError::Registry("Failed to acquire read lock".to_string()))?;
    let found = entries
        .iter()
        .filter(|entry| {
            entry.enabled && entry.provider.is_ready() && entry.provider.matches(jku, kid)
        })
        .find_map(|entry| entry.provider.get_master_key(jku, kid).ok());
    found.ok_or_else(|| {
        CryptoTensorsError::Registry(format!(
            "No suitable provider found for master key (jku={:?}, kid={:?})",
            jku, kid
        ))
    })
}
/// Returns the first verification-key JWK any registered provider can supply for `jku`/`kid`.
///
/// Only enabled, ready providers whose `matches(jku, kid)` is true are asked, in registry
/// (descending priority) order. Provider errors are skipped in favour of the next candidate.
///
/// # Errors
/// `CryptoTensorsError::Registry` if the registry lock is poisoned or no provider yields a key.
pub fn get_verify_key(
    jku: Option<&str>,
    kid: Option<&str>,
) -> Result<serde_json::Value, CryptoTensorsError> {
    let registry = get_providers()
        .read()
        .map_err(|_| CryptoTensorsError::Registry("Failed to acquire read lock".to_string()))?;
    let key = registry
        .iter()
        .filter(|entry| {
            entry.enabled && entry.provider.is_ready() && entry.provider.matches(jku, kid)
        })
        .find_map(|entry| entry.provider.get_verify_key(jku, kid).ok());
    key.ok_or_else(|| {
        CryptoTensorsError::Registry(format!(
            "No suitable provider found for verify key (jku={:?}, kid={:?})",
            jku, kid
        ))
    })
}
/// Returns the first signing-key JWK any registered provider can supply for `jku`/`kid`.
///
/// Only enabled, ready providers whose `matches(jku, kid)` is true are asked, in registry
/// (descending priority) order. Providers that fail are skipped, including those relying on
/// the trait's default `get_signing_key`.
///
/// # Errors
/// `CryptoTensorsError::Registry` if the registry lock is poisoned or no provider yields a key.
pub fn get_signing_key(
    jku: Option<&str>,
    kid: Option<&str>,
) -> Result<serde_json::Value, CryptoTensorsError> {
    let registry = get_providers()
        .read()
        .map_err(|_| CryptoTensorsError::Registry("Failed to acquire read lock".to_string()))?;
    for entry in registry.iter() {
        let eligible =
            entry.enabled && entry.provider.is_ready() && entry.provider.matches(jku, kid);
        if !eligible {
            continue;
        }
        if let Ok(key) = entry.provider.get_signing_key(jku, kid) {
            return Ok(key);
        }
    }
    Err(CryptoTensorsError::Registry(format!(
        "No suitable provider found for signing key (jku={:?}, kid={:?})",
        jku, kid
    )))
}
/// Unit tests for the provider registry and `DirectKeyProvider`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Serializes tests that mutate the process-wide provider registry.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// Minimal provider that matches every request and returns fixed dummy JWKs.
    struct TestProvider {
        name: String,
        ready: bool,
    }

    impl KeyProvider for TestProvider {
        fn is_ready(&self) -> bool {
            self.ready
        }
        fn matches(&self, _jku: Option<&str>, _kid: Option<&str>) -> bool {
            true
        }
        fn name(&self) -> &str {
            &self.name
        }
        fn get_master_key(
            &self,
            _jku: Option<&str>,
            _kid: Option<&str>,
        ) -> Result<serde_json::Value, CryptoTensorsError> {
            Ok(serde_json::json!({"kty": "oct", "k": "AAA"}))
        }
        fn get_verify_key(
            &self,
            _jku: Option<&str>,
            _kid: Option<&str>,
        ) -> Result<serde_json::Value, CryptoTensorsError> {
            Ok(serde_json::json!({"kty": "okp", "x": "AAA"}))
        }
    }

    // Registration keeps entries sorted by descending priority, and `disable_provider`
    // removes the named entry from the registry.
    #[test]
    fn test_priority_and_disable() {
        let _guard = TEST_MUTEX.lock().unwrap();
        clear_providers();
        register_provider_with_priority(
            Box::new(TestProvider {
                name: "low".into(),
                ready: true,
            }),
            0,
        );
        register_provider_with_priority(
            Box::new(TestProvider {
                name: "high".into(),
                ready: true,
            }),
            10,
        );
        {
            // Scoped so the read guard is released before `disable_provider` takes the
            // write lock.
            let providers = get_providers();
            let guard = providers.read().unwrap();
            assert_eq!(guard[0].provider.name(), "high");
            assert_eq!(guard[1].provider.name(), "low");
        }
        disable_provider("high");
        let providers = get_providers();
        let guard = providers.read().unwrap();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].provider.name(), "low");
    }

    // `from_single_keys` stores both keys under their `kid`s and serves them back.
    // The kid arguments passed to `KeyMaterial::new_*_key` are assumed to end up in the
    // JWK's "kid" member; the assertions below rely on that.
    #[test]
    fn test_direct_key_provider_single_keys() {
        let enc_key =
            KeyMaterial::new_enc_key(None, None, Some("test-enc".to_string()), None).unwrap();
        let sign_key =
            KeyMaterial::new_sign_key(None, None, None, Some("test-sign".to_string()), None)
                .unwrap();
        let enc_jwk: serde_json::Value = serde_json::from_str(&enc_key.to_jwk().unwrap()).unwrap();
        let sign_jwk: serde_json::Value =
            serde_json::from_str(&sign_key.to_jwk().unwrap()).unwrap();
        let provider = DirectKeyProvider::from_single_keys(enc_jwk.clone(), sign_jwk.clone());
        assert!(provider.is_ready());
        assert!(provider.matches(None, Some("test-enc")));
        assert!(provider.matches(None, Some("test-sign")));
        let retrieved_enc = provider.get_master_key(None, Some("test-enc")).unwrap();
        assert_eq!(
            retrieved_enc.get("kid").and_then(|v| v.as_str()),
            Some("test-enc")
        );
        let retrieved_sign = provider.get_signing_key(None, Some("test-sign")).unwrap();
        assert_eq!(
            retrieved_sign.get("kid").and_then(|v| v.as_str()),
            Some("test-sign")
        );
    }

    // Several keys per map, added through the chained builder methods, are looked up by `kid`.
    #[test]
    fn test_direct_key_provider_multiple_keys() {
        let mut provider = DirectKeyProvider::new();
        let enc1 =
            KeyMaterial::new_enc_key(None, None, Some("user1-enc".to_string()), None).unwrap();
        let enc2 =
            KeyMaterial::new_enc_key(None, None, Some("user2-enc".to_string()), None).unwrap();
        let sign1 =
            KeyMaterial::new_sign_key(None, None, None, Some("user1-sign".to_string()), None)
                .unwrap();
        let sign2 =
            KeyMaterial::new_sign_key(None, None, None, Some("user2-sign".to_string()), None)
                .unwrap();
        provider
            .add_enc_key(
                "user1-enc",
                serde_json::from_str(&enc1.to_jwk().unwrap()).unwrap(),
            )
            .add_enc_key(
                "user2-enc",
                serde_json::from_str(&enc2.to_jwk().unwrap()).unwrap(),
            )
            .add_sign_key(
                "user1-sign",
                serde_json::from_str(&sign1.to_jwk().unwrap()).unwrap(),
            )
            .add_sign_key(
                "user2-sign",
                serde_json::from_str(&sign2.to_jwk().unwrap()).unwrap(),
            );
        assert!(provider.is_ready());
        let user1_enc = provider.get_master_key(None, Some("user1-enc")).unwrap();
        assert_eq!(
            user1_enc.get("kid").and_then(|v| v.as_str()),
            Some("user1-enc")
        );
        let user2_sign = provider.get_signing_key(None, Some("user2-sign")).unwrap();
        assert_eq!(
            user2_sign.get("kid").and_then(|v| v.as_str()),
            Some("user2-sign")
        );
    }

    // Explicit defaults override the "first added" default, apply when `kid` is `None`,
    // and an explicit `kid` still wins over the default.
    #[test]
    fn test_direct_key_provider_default_keys() {
        let enc1 = KeyMaterial::new_enc_key(None, None, Some("enc1".to_string()), None).unwrap();
        let enc2 = KeyMaterial::new_enc_key(None, None, Some("enc2".to_string()), None).unwrap();
        let sign1 =
            KeyMaterial::new_sign_key(None, None, None, Some("sign1".to_string()), None).unwrap();
        let mut provider = DirectKeyProvider::new();
        provider
            .add_enc_key(
                "enc1",
                serde_json::from_str(&enc1.to_jwk().unwrap()).unwrap(),
            )
            .add_enc_key(
                "enc2",
                serde_json::from_str(&enc2.to_jwk().unwrap()).unwrap(),
            )
            .add_sign_key(
                "sign1",
                serde_json::from_str(&sign1.to_jwk().unwrap()).unwrap(),
            )
            .set_default_enc_kid("enc2")
            .set_default_sign_kid("sign1");
        let key = provider.get_master_key(None, None).unwrap();
        assert_eq!(key.get("kid").and_then(|v| v.as_str()), Some("enc2"));
        let key = provider.get_signing_key(None, None).unwrap();
        assert_eq!(key.get("kid").and_then(|v| v.as_str()), Some("sign1"));
        let key = provider.get_master_key(None, Some("enc1")).unwrap();
        assert_eq!(key.get("kid").and_then(|v| v.as_str()), Some("enc1"));
    }
}