use crate::types::{AccessControl, ByteBuf, KeyManagerConfig, KeyName, TransportKey};
use candid::Principal;
use ic_stable_structures::memory_manager::VirtualMemory;
use ic_stable_structures::storable::Blob;
use ic_stable_structures::{DefaultMemoryImpl, StableBTreeMap, StableCell, Storable};
use std::future::Future;
use ic_cdk::management_canister::{VetKDDeriveKeyArgs, VetKDKeyId, VetKDPublicKeyArgs};
/// Public verification key bytes, built from the `public_key` field of a `vetkd_public_key` reply.
pub type VetKeyVerificationKey = ByteBuf;
/// Encrypted vetKey bytes, built from the `encrypted_key` field of a `vetkd_derive_key` reply.
pub type VetKey = ByteBuf;
/// Principal that owns a key (the first component of a [`KeyId`]).
pub type Owner = Principal;
/// Principal on whose behalf an operation is requested.
pub type Caller = Principal;
/// Identifies a key by its owner and its name.
pub type KeyId = (Owner, KeyName);
/// Stable-memory type backing every cell and map of the key manager.
type Memory = VirtualMemory<DefaultMemoryImpl>;
/// Stable-memory state of the key manager: its vetKD configuration plus the
/// access-control entries that record which users may use which keys.
pub struct KeyManager<T: AccessControl> {
/// Domain separator (used as vetKD context) and vetKD key id for all requests.
pub config: StableCell<KeyManagerConfig, Memory>,
/// Rights explicitly granted to a user for a key, keyed by `(user, key_id)`.
/// Key owners are not looked up here; their rights are implied by `key_id.0`.
pub access_control: StableBTreeMap<(Principal, KeyId), T, Memory>,
/// Index mirroring `access_control` with the components swapped, keyed by
/// `(key_id, user)`, so all users of one key can be listed with a range scan.
pub shared_keys: StableBTreeMap<(KeyId, Principal), (), Memory>,
}
impl<T: AccessControl> KeyManager<T> {
    /// Creates a key manager backed by the given stable memories.
    ///
    /// `domain_separator` and `key_id` are stored in the config cell; every vetKD request made
    /// by this manager uses them as its context and key id.
    ///
    /// NOTE(review): `StableCell::init` is expected to keep a config that already exists in
    /// `memory_key_manager_config` (e.g. after an upgrade) rather than overwrite it with the
    /// arguments — confirm against the ic-stable-structures version in use.
    ///
    /// # Panics
    ///
    /// Panics if the config cell cannot be initialized.
    pub fn init(
        domain_separator: &str,
        key_id: VetKDKeyId,
        memory_key_manager_config: Memory,
        memory_access_control: Memory,
        memory_shared_keys: Memory,
    ) -> Self {
        let config = StableCell::init(
            memory_key_manager_config,
            KeyManagerConfig {
                domain_separator: domain_separator.to_string(),
                key_id,
            },
        )
        .expect("failed to initialize key manager config");
        KeyManager {
            config,
            access_control: StableBTreeMap::init(memory_access_control),
            shared_keys: StableBTreeMap::init(memory_shared_keys),
        }
    }

    /// Returns the ids of all keys for which `caller` has an explicit access-control entry.
    ///
    /// Only stored grants are scanned. A key owned by `caller` appears only if an entry
    /// for it exists in `access_control`.
    pub fn get_accessible_shared_key_ids(&self, caller: Principal) -> Vec<KeyId> {
        // Scan from the lowest possible `(caller, _)` key and stop at the first other user.
        // NOTE(review): assumes the management-canister (empty) principal and an empty `Blob`
        // are the smallest values of their types — confirm.
        self.access_control
            .range((caller, (Principal::management_canister(), Blob::default()))..)
            .take_while(|((user, _), _)| user == &caller)
            .map(|((_, key_id), _)| key_id)
            .collect()
    }

    /// Lists the users `key_id` is shared with, together with their rights.
    ///
    /// A user appears exactly when `get_user_rights` would return `Some` for them. Users whose
    /// stored rights do not permit reading are left out.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` if `caller` may not view user rights for `key_id`.
    pub fn get_shared_user_access_for_key(
        &self,
        caller: Principal,
        key_id: KeyId,
    ) -> Result<Vec<(Principal, T)>, String> {
        // Authorize once; the per-user lookups below need no further caller checks.
        self.ensure_user_can_get_user_rights(caller, key_id)?;
        let users_with_rights: Vec<(Principal, T)> = self
            .shared_keys
            .range((key_id, Principal::management_canister())..)
            .take_while(|((shared_key_id, _), _)| shared_key_id == &key_id)
            .filter_map(|((_, user), _)| {
                // Skip (rather than panic on) users whose rights do not allow reading:
                // `AccessControl` does not guarantee every stored right `can_read()`.
                self.ensure_user_can_read(user, key_id)
                    .ok()
                    .map(|rights| (user, rights))
            })
            .collect();
        Ok(users_with_rights)
    }

    /// Returns a future that fetches the vetKD public (verification) key for this manager's
    /// domain separator and key id.
    ///
    /// # Panics
    ///
    /// The returned future panics if the `vetkd_public_key` call fails.
    pub fn get_vetkey_verification_key(
        &self,
    ) -> impl Future<Output = VetKeyVerificationKey> + Send + Sync {
        let config = self.config.get();
        let request = VetKDPublicKeyArgs {
            canister_id: None,
            context: config.domain_separator.to_bytes().to_vec(),
            key_id: config.key_id.clone(),
        };
        async move {
            let reply = ic_cdk::management_canister::vetkd_public_key(&request)
                .await
                .expect("call to vetkd_public_key failed");
            VetKeyVerificationKey::from(reply.public_key)
        }
    }

    /// Returns a future that derives the vetKey for `subkey_key_id`, encrypted under
    /// `transport_key`.
    ///
    /// The owner and name of `subkey_key_id` are encoded with `key_id_to_vetkd_input` to form
    /// the derivation input.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` if `caller` may not read `subkey_key_id`. No call is made
    /// in that case.
    ///
    /// # Panics
    ///
    /// The returned future panics if the `vetkd_derive_key` call fails.
    pub fn get_encrypted_vetkey(
        &self,
        caller: Principal,
        subkey_key_id: KeyId,
        transport_key: TransportKey,
    ) -> Result<impl Future<Output = VetKey> + Send + Sync, String> {
        self.ensure_user_can_read(caller, subkey_key_id)?;
        let config = self.config.get();
        let request = VetKDDeriveKeyArgs {
            input: key_id_to_vetkd_input(subkey_key_id.0, subkey_key_id.1.as_ref()),
            context: config.domain_separator.to_bytes().to_vec(),
            key_id: config.key_id.clone(),
            transport_public_key: transport_key.into(),
        };
        Ok(async move {
            let reply = ic_cdk::management_canister::vetkd_derive_key(&request)
                .await
                .expect("call to vetkd_derive_key failed");
            VetKey::from(reply.encrypted_key)
        })
    }

    /// Returns the rights `user` has to `key_id`, as requested by `caller`.
    ///
    /// Yields `Some(T::owner_rights())` for the key owner and `Some(rights)` for a user whose
    /// stored rights permit reading. Any other user yields `None`.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` if `caller` may not view user rights for `key_id`.
    pub fn get_user_rights(
        &self,
        caller: Principal,
        key_id: KeyId,
        user: Principal,
    ) -> Result<Option<T>, String> {
        self.ensure_user_can_get_user_rights(caller, key_id)?;
        Ok(self.ensure_user_can_read(user, key_id).ok())
    }

    /// Grants `user` the given `access_rights` to `key_id`, replacing any earlier grant.
    ///
    /// Returns the previously stored rights of `user`, if any.
    ///
    /// # Errors
    ///
    /// - `"unauthorized"` if `caller` may not set user rights for `key_id`.
    /// - `"cannot change key owner's user rights"` if `user` is the owner of `key_id`.
    pub fn set_user_rights(
        &mut self,
        caller: Principal,
        key_id: KeyId,
        user: Principal,
        access_rights: T,
    ) -> Result<Option<T>, String> {
        self.ensure_user_can_set_user_rights(caller, key_id)?;
        // Owner rights are implicit (see `ensure_user_has_rights`). Storing an entry for the
        // owner would list the key as "shared" with its own owner. This must be rejected for
        // every authorized caller, not only when the owner targets themselves.
        if user == key_id.0 {
            return Err("cannot change key owner's user rights".to_string());
        }
        self.shared_keys.insert((key_id, user), ());
        Ok(self.access_control.insert((user, key_id), access_rights))
    }

    /// Revokes any stored rights `user` has to `key_id` and returns them.
    ///
    /// # Errors
    ///
    /// - `"unauthorized"` if `caller` may not set user rights for `key_id`.
    /// - `"cannot remove key owner"` if the owner tries to remove themselves.
    pub fn remove_user(
        &mut self,
        caller: Principal,
        key_id: KeyId,
        user: Principal,
    ) -> Result<Option<T>, String> {
        self.ensure_user_can_set_user_rights(caller, key_id)?;
        // A manager targeting the owner only deletes a stored entry; the owner's implicit
        // rights are unaffected. This remains allowed so that stale owner entries can be
        // cleaned up.
        if caller == user && caller == key_id.0 {
            return Err("cannot remove key owner".to_string());
        }
        self.shared_keys.remove(&(key_id, user));
        Ok(self.access_control.remove(&(user, key_id)))
    }

    /// Returns `user`'s rights to `key_id` if they permit reading.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` otherwise.
    pub fn ensure_user_can_read(&self, user: Principal, key_id: KeyId) -> Result<T, String> {
        self.ensure_user_has_rights(user, key_id, |rights| rights.can_read())
    }

    /// Returns `user`'s rights to `key_id` if they permit writing.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` otherwise.
    pub fn ensure_user_can_write(&self, user: Principal, key_id: KeyId) -> Result<T, String> {
        self.ensure_user_has_rights(user, key_id, |rights| rights.can_write())
    }

    /// Returns `user`'s rights to `key_id` if they permit viewing other users' rights.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` otherwise.
    pub fn ensure_user_can_get_user_rights(
        &self,
        user: Principal,
        key_id: KeyId,
    ) -> Result<T, String> {
        self.ensure_user_has_rights(user, key_id, |rights| rights.can_get_user_rights())
    }

    /// Returns `user`'s rights to `key_id` if they permit changing other users' rights.
    ///
    /// # Errors
    ///
    /// Returns `Err("unauthorized")` otherwise.
    pub fn ensure_user_can_set_user_rights(
        &self,
        user: Principal,
        key_id: KeyId,
    ) -> Result<T, String> {
        self.ensure_user_has_rights(user, key_id, |rights| rights.can_set_user_rights())
    }

    /// Shared implementation of the `ensure_user_can_*` checks.
    ///
    /// The owner of `key_id` (`key_id.0`) always receives `T::owner_rights()` without
    /// consulting `access_control`. Any other user must have a stored entry that satisfies
    /// `is_permitted`, and that entry is returned. Otherwise the result is `Err("unauthorized")`.
    fn ensure_user_has_rights(
        &self,
        user: Principal,
        key_id: KeyId,
        is_permitted: impl Fn(&T) -> bool,
    ) -> Result<T, String> {
        if user == key_id.0 {
            return Ok(T::owner_rights());
        }
        match self.access_control.get(&(user, key_id)) {
            Some(access_rights) if is_permitted(&access_rights) => Ok(access_rights),
            _ => Err("unauthorized".to_string()),
        }
    }
}
/// Encodes a key id as vetKD derivation input.
///
/// The layout is one byte holding the principal's length, then the principal bytes, then the
/// key-name bytes. The length prefix keeps the encoding unambiguous: without it, different
/// `(principal, key_name)` pairs could concatenate to the same byte string.
pub fn key_id_to_vetkd_input(principal: Principal, key_name: &[u8]) -> Vec<u8> {
    let principal_bytes = principal.as_slice();
    let total_len = 1 + principal_bytes.len() + key_name.len();
    let mut encoded = Vec::with_capacity(total_len);
    // NOTE(review): relies on principals being at most 29 bytes (candid limit), so the
    // length always fits in a single byte — confirm if the principal type changes.
    encoded.push(principal_bytes.len() as u8);
    encoded.extend_from_slice(principal_bytes);
    encoded.extend_from_slice(key_name);
    encoded
}