use aes_gcm_siv::aead::{Aead, Payload};
use aes_gcm_siv::{Aes256GcmSiv, Key as AesKey, KeyInit, Nonce};
use itertools::Itertools;
use log::trace;
use recipher::key::Key;
use std::borrow::Cow;
use std::collections::HashMap;
use url::Url;
use uuid::Uuid;
use zerokms_protocol::{
Context, CreateClientRequest, CreateClientResponse, CreateClientSpec, CreateKeysetRequest,
CreateKeysetResponse, DeleteClientRequest, DeleteClientResponse, DisableKeysetRequest,
EnableKeysetRequest, GenerateKeyRequest, GenerateKeySpec, GeneratedKey, GrantKeysetRequest,
IdentifiedBy, KeyId, Keyset, KeysetClient, ListClientRequest, ListKeysetRequest,
LoadKeysetRequest, LoadKeysetResponse, ModifyKeysetRequest, RetrieveKeyRequest,
RetrieveKeyRequestFallible, RetrieveKeySpec, RetrievedKey, RevokeKeysetRequest,
UnverifiedContext,
};
pub mod connection;
pub mod encrypted_record;
mod encryption_target;
pub mod errors;
mod futures;
pub mod key;
use encrypted_record::{Decryptable, EncryptedRecord, WithContext};
use futures::map_async_chunked;
pub use connection::{HttpConnection, ZeroKMSConnection, ZeroKMSConnectionInit};
pub use encryption_target::EncryptionTarget;
pub use errors::*;
pub use key::{ClientKey, DataKey, DataKeyWithTag, IndexKey};
pub use recipher::key::{GenRandom, Iv};
use crate::zerokms::vitur_client::connection::HttpConnectionOpts;
/// Default value for `ClientOpts::max_keys_per_req`; used by the deprecated `Client::init`.
pub(super) const DEFAULT_KEYS_PER_REQ: usize = 500;
/// Default value for `ClientOpts::max_concurrent_reqs`; used by the deprecated `Client::init`.
pub(super) const DEFAULT_CONCURRENT_REQS: usize = 5;
#[cfg(test)]
pub mod test_connection;
/// Options consumed by `Client::init_opts` to build a [`Client`].
///
/// `CONNOPTS` is the connection-specific options type; it is handed unchanged to the
/// connection's `init` function.
pub struct ClientOpts<CONNOPTS> {
/// Copied onto the client and passed to the chunked request helper for key
/// generation/retrieval (presumably the maximum number of keys per request).
pub max_keys_per_req: usize,
/// Copied onto the client and passed to the chunked request helper for key
/// generation/retrieval (presumably the maximum number of in-flight requests).
pub max_concurrent_reqs: usize,
/// Options forwarded to the underlying connection when it is initialised.
pub connection_opts: CONNOPTS,
}
/// Client that issues keyset, client and data-key requests over a connection `C`
/// (an [`HttpConnection`] unless another type is named).
pub struct Client<C = HttpConnection> {
// Transport used to send every request.
connection: C,
// Limit passed to the chunked request helper alongside `max_concurrent_reqs`.
max_keys_per_req: usize,
// Limit passed to the chunked request helper alongside `max_keys_per_req`.
max_concurrent_reqs: usize,
}
/// A single message to encrypt, together with the metadata used when a data key is
/// generated for it.
#[derive(Debug, Clone)]
pub struct EncryptPayload<'a> {
/// The message bytes to encrypt.
pub msg: &'a [u8],
/// Descriptor forwarded to key generation for this message (empty by default).
pub descriptor: &'a str,
/// Context entries forwarded to key generation for this message.
pub context: Cow<'a, [Context]>,
}
impl<'a> EncryptPayload<'a> {
    /// Creates a payload for `msg` with an empty descriptor and the default (empty) context.
    pub fn new(msg: &'a [u8]) -> Self {
        Self::new_with_descriptor(msg, "")
    }

    /// Returns the payload with its context replaced by `context`.
    pub fn set_context(self, context: Cow<'a, [Context]>) -> Self {
        Self { context, ..self }
    }

    /// Creates a payload for `msg` labelled with `descriptor` and the default (empty) context.
    pub fn new_with_descriptor(msg: &'a [u8], descriptor: &'a str) -> Self {
        let context = Cow::default();
        Self {
            msg,
            descriptor,
            context,
        }
    }
}
impl<'a> From<&'a EncryptPayload<'a>> for GenerateKeyPayload<'a> {
fn from(
EncryptPayload {
descriptor,
context,
..
}: &'a EncryptPayload<'a>,
) -> Self {
Self::new(descriptor, context.clone())
}
}
/// Metadata describing one data key to generate (see `Client::generate_keys`).
#[derive(Clone)]
pub struct GenerateKeyPayload<'a> {
/// Descriptor passed into the key-generation spec.
pub descriptor: &'a str,
// Context entries passed into the key-generation spec.
context: Cow<'a, [Context]>,
}
impl<'a> GenerateKeyPayload<'a> {
pub fn new(descriptor: &'a str, context: Cow<'a, [Context]>) -> Self {
Self {
descriptor,
context,
}
}
}
/// Metadata identifying one data key to retrieve (see `Client::retrieve_keys`).
pub struct RetrieveKeyPayload<'a> {
/// Key identifier; every constructor in this file derives it from an IV.
pub iv: KeyId,
/// Descriptor passed into the key-retrieval spec.
pub descriptor: &'a str,
/// Tag bytes passed into the key-retrieval spec.
pub tag: &'a [u8],
/// Context entries passed into the key-retrieval spec.
pub context: Cow<'a, [Context]>,
}
impl<'a> RetrieveKeyPayload<'a> {
    /// Creates a retrieval payload from an IV, descriptor and tag, with the default
    /// (empty) context.
    pub fn new(iv: Iv, descriptor: &'a str, tag: &'a [u8]) -> Self {
        let iv = KeyId::from(iv);
        let context = Cow::default();
        Self {
            iv,
            descriptor,
            tag,
            context,
        }
    }

    /// Returns the payload with its context replaced by `context`.
    pub fn with_context(self, context: Cow<'a, [Context]>) -> Self {
        Self { context, ..self }
    }
}
impl<'a> From<RetrieveKeyPayload<'a>> for RetrieveKeySpec<'a> {
fn from(
RetrieveKeyPayload {
iv,
descriptor,
tag,
context,
}: RetrieveKeyPayload<'a>,
) -> Self {
Self::new(iv, tag, descriptor).with_context(context)
}
}
impl<'a> From<&'a EncryptedRecord> for RetrieveKeyPayload<'a> {
    /// Builds the retrieval payload for an encrypted record, borrowing the record's
    /// descriptor and tag and using the default (empty) context.
    fn from(record: &'a EncryptedRecord) -> Self {
        let descriptor: &'a str = &record.descriptor;
        let tag: &'a [u8] = &record.tag;
        let iv = KeyId::from(record.iv);
        Self {
            iv,
            descriptor,
            tag,
            context: Default::default(),
        }
    }
}
impl<'a, 'context> From<&'a WithContext<'context>> for RetrieveKeyPayload<'a> {
    /// Builds the retrieval payload for a record wrapped with a context: descriptor and
    /// tag are borrowed from the wrapped record, and the wrapper's context is cloned in.
    fn from(with_context: &'a WithContext) -> Self {
        let record = &with_context.record;
        let descriptor: &'a str = &record.descriptor;
        let tag: &'a [u8] = &record.tag;
        Self {
            iv: KeyId::from(record.iv),
            descriptor,
            tag,
            context: with_context.context.clone(),
        }
    }
}
impl Client<HttpConnection> {
    /// Builds an HTTP-backed client for `host` using the default request limits.
    ///
    /// # Panics
    ///
    /// Panics if `host` is not a valid URL or if the HTTP connection cannot be
    /// initialised. Prefer `Client::init_opts`, which reports the latter as an error.
    #[deprecated(
        note = "Use Client::init_opts with ClientOpts instead",
        since = "0.32.2"
    )]
    pub fn init(host: String) -> Self {
        let connection_opts =
            HttpConnectionOpts::new(Some(Url::parse(&host).expect("Invalid Vitur host URL")));
        let opts = ClientOpts {
            max_keys_per_req: DEFAULT_KEYS_PER_REQ,
            max_concurrent_reqs: DEFAULT_CONCURRENT_REQS,
            connection_opts,
        };
        Self::init_opts(opts).expect("Failed to initialize HTTP connection")
    }
}
/// Per-key retrieval results: each entry is either the data key or the error for that key.
pub type FallibleDataKeyVec = Vec<Result<DataKey, RetrieveKeyError>>;
impl<C> Client<C> {
    /// Borrows the connection this client sends its requests over.
    pub(crate) fn connection(&self) -> &C {
        let Self { connection, .. } = self;
        connection
    }
}
impl<C: ZeroKMSConnection + Send + Sync> Client<C> {
pub fn init_opts(opts: ClientOpts<C::ConnectionOpts>) -> Result<Self, C::Error> {
let connection = C::init(opts.connection_opts)?;
Ok(Self {
connection,
max_keys_per_req: opts.max_keys_per_req,
max_concurrent_reqs: opts.max_concurrent_reqs,
})
}
pub async fn create_keyset(
&self,
name: &str,
description: &str,
access_token: &str,
) -> Result<Keyset, CreateKeysetError> {
let req = CreateKeysetRequest {
name: name.into(),
description: description.into(),
client: None,
};
let response = self.connection.send(req, access_token).await?;
Ok(response.keyset)
}
/// Creates a keyset and, in the same request, a client described by `client_spec`.
/// Returns the full response.
///
/// # Errors
///
/// Returns [`CreateKeysetError`] if the request fails.
pub async fn create_keyset_with_client(
    &self,
    name: &str,
    description: &str,
    client_spec: CreateClientSpec<'_>,
    access_token: &str,
) -> Result<CreateKeysetResponse, CreateKeysetError> {
    let response = self
        .connection
        .send(
            CreateKeysetRequest {
                name: name.into(),
                description: description.into(),
                client: Some(client_spec),
            },
            access_token,
        )
        .await?;
    Ok(response)
}
/// Sends a grant request for the client `client_id` on the keyset `keyset_id`.
///
/// # Errors
///
/// Returns [`GrantKeysetError`] if the request fails.
pub async fn grant_keyset(
    &self,
    client_id: Uuid,
    keyset_id: Uuid,
    access_token: &str,
) -> Result<(), GrantKeysetError> {
    self.connection
        .send(
            GrantKeysetRequest {
                client_id,
                keyset_id: keyset_id.into(),
            },
            access_token,
        )
        .await?;
    Ok(())
}
/// Sends a revoke request for the client `client_id` on the keyset `keyset_id`.
///
/// # Errors
///
/// Returns [`RevokeKeysetError`] if the request fails.
pub async fn revoke_keyset(
    &self,
    client_id: Uuid,
    keyset_id: Uuid,
    access_token: &str,
) -> Result<(), RevokeKeysetError> {
    self.connection
        .send(
            RevokeKeysetRequest {
                client_id,
                keyset_id: keyset_id.into(),
            },
            access_token,
        )
        .await?;
    Ok(())
}
/// Lists keysets; `show_disabled` is forwarded in the request.
///
/// # Errors
///
/// Returns [`ListKeysetError`] if the request fails.
pub async fn list_keysets(
    &self,
    access_token: &str,
    show_disabled: bool,
) -> Result<Vec<Keyset>, ListKeysetError> {
    let request = ListKeysetRequest { show_disabled };
    let keysets = self.connection.send(request, access_token).await?;
    Ok(keysets)
}
/// Sends an enable request for the keyset `keyset_id`.
///
/// # Errors
///
/// Returns [`EnableKeysetError`] if the request fails.
pub async fn enable_keyset(
    &self,
    keyset_id: Uuid,
    access_token: &str,
) -> Result<(), EnableKeysetError> {
    let request = EnableKeysetRequest {
        keyset_id: keyset_id.into(),
    };
    self.connection.send(request, access_token).await?;
    Ok(())
}
/// Sends a disable request for the keyset `keyset_id`.
///
/// # Errors
///
/// Returns [`DisableKeysetError`] if the request fails.
pub async fn disable_keyset(
    &self,
    keyset_id: Uuid,
    access_token: &str,
) -> Result<(), DisableKeysetError> {
    let request = DisableKeysetRequest {
        keyset_id: keyset_id.into(),
    };
    self.connection.send(request, access_token).await?;
    Ok(())
}
/// Sends a modify request for the keyset `keyset_id`, forwarding the optional new
/// `name` and `description` as given.
///
/// # Errors
///
/// Returns [`ModifyKeysetError`] if the request fails.
pub async fn modify_keyset(
    &self,
    keyset_id: Uuid,
    name: Option<&str>,
    description: Option<&str>,
    access_token: &str,
) -> Result<(), ModifyKeysetError> {
    let request = ModifyKeysetRequest {
        keyset_id: keyset_id.into(),
        name: name.map(Cow::Borrowed),
        description: description.map(Cow::Borrowed),
    };
    self.connection.send(request, access_token).await?;
    Ok(())
}
/// Creates a client with the given name and description, optionally tied to
/// `keyset_id`, and returns the full response.
///
/// # Errors
///
/// Returns [`CreateClientError`] if the request fails.
pub async fn create_client(
    &self,
    name: &str,
    description: &str,
    keyset_id: Option<Uuid>,
    access_token: &str,
) -> Result<CreateClientResponse, CreateClientError> {
    let created = self
        .connection
        .send(
            CreateClientRequest {
                name: name.into(),
                description: description.into(),
                keyset_id: keyset_id.map(IdentifiedBy::from),
            },
            access_token,
        )
        .await?;
    Ok(created)
}
/// Lists the clients returned by the service for this access token.
///
/// # Errors
///
/// Returns [`ListClientError`] if the request fails.
pub async fn list_clients(
    &self,
    access_token: &str,
) -> Result<Vec<KeysetClient>, ListClientError> {
    let clients = self.connection.send(ListClientRequest, access_token).await?;
    Ok(clients)
}
/// Sends a delete request for the client `client_id` and returns the response.
///
/// # Errors
///
/// Returns [`RevokeClientError`] if the request fails.
pub async fn delete_client(
    &self,
    client_id: Uuid,
    access_token: &str,
) -> Result<DeleteClientResponse, RevokeClientError> {
    let deleted = self
        .connection
        .send(DeleteClientRequest { client_id }, access_token)
        .await?;
    Ok(deleted)
}
/// Retrieves the data keys described by `keys`.
///
/// Each payload is converted into a [`RetrieveKeySpec`]; the specs are then sent through
/// `map_async_chunked` using the client's `max_keys_per_req` / `max_concurrent_reqs`
/// settings. For every response the returned key material is paired, in order, with the
/// IV of the spec it was requested for and combined with the client key `key`.
///
/// `keyset_id` is forwarded when present; a missing `unverified_context` is sent as the
/// default value.
///
/// # Errors
///
/// * [`RetrieveKeyError::RequestFailed`] if a request fails.
/// * [`RetrieveKeyError::InvalidNumberOfKeys`] if a response does not contain exactly
///   one key per requested spec.
pub async fn retrieve_keys(
    &self,
    keys: impl IntoIterator<Item = RetrieveKeyPayload<'_>>,
    key: &ClientKey,
    keyset_id: Option<Uuid>,
    access_token: &str,
    unverified_context: Option<&UnverifiedContext>,
) -> Result<Vec<DataKey>, RetrieveKeyError> {
    trace!(target: "vitur_client::retrieve_keys", "preparing payloads");
    let keys = keys
        .into_iter()
        .map(RetrieveKeySpec::from)
        .collect::<Vec<_>>();
    tracing::trace!(target: "vitur_client::retrieve_keys", max_keys_per_req = self.max_keys_per_req, max_parallel_reqs = self.max_concurrent_reqs);
    let result = map_async_chunked(
        &keys,
        |keys| async {
            let req = RetrieveKeyRequest {
                keys: keys.into(),
                keyset_id: keyset_id.map(Into::into),
                client_id: key.key_id,
                unverified_context: unverified_context.cloned().unwrap_or_default(),
            };
            trace!(target: "vitur_client::retrieve_keys", "sending request with {} keys", keys.len());
            self.connection
                .send(req, access_token)
                .await
                .map_err(RetrieveKeyError::RequestFailed)
                .and_then(|res| {
                    // The zip below would silently drop entries on a length mismatch,
                    // so reject such responses explicitly.
                    if res.keys.len() != keys.len() {
                        return Err(RetrieveKeyError::InvalidNumberOfKeys {
                            expected: keys.len(),
                            received: res.keys.len(),
                        });
                    }
                    trace!(target: "vitur_client::retrieve_keys", "retrieved keys - creating data keys");
                    Ok(keys
                        .iter()
                        .zip(res.keys)
                        .map(
                            |(RetrieveKeySpec { iv, .. }, RetrievedKey { key_material })| {
                                DataKey::from_key_material(key, iv.into_inner(), &key_material)
                            },
                        )
                        .collect())
                })
        },
        self.max_keys_per_req,
        self.max_concurrent_reqs,
    )
    .await;
    match &result {
        Err(x) => {
            trace!(target: "vitur_client::retrieve_keys", "failed with error: {x}");
        }
        Ok(x) => {
            // This path retrieves keys; the message used to say "generated".
            trace!(target: "vitur_client::retrieve_keys", "successfully retrieved {} keys", x.len());
        }
    }
    result
}
/// Retrieves the data keys described by `keys`, reporting failures per key.
///
/// Works like `retrieve_keys`, but sends a [`RetrieveKeyRequestFallible`]: each entry
/// of a response is itself a result, so one key failing does not fail the whole call.
/// Successful entries are paired, in order, with the IV of the spec they were requested
/// for and combined with `client_key`; failed entries become
/// [`RetrieveKeyError::FailedRetrieval`].
///
/// # Errors
///
/// The outer result is an error only when a whole request fails:
///
/// * [`RetrieveKeyError::RequestFailed`] if a request fails.
/// * [`RetrieveKeyError::InvalidNumberOfKeys`] if a response does not contain exactly
///   one entry per requested spec.
pub async fn retrieve_keys_fallible<'a>(
    &self,
    keys: impl IntoIterator<Item = RetrieveKeyPayload<'_>>,
    client_key: &ClientKey,
    keyset_id: Option<Uuid>,
    access_token: &str,
    unverified_context: Option<Cow<'a, UnverifiedContext>>,
) -> Result<FallibleDataKeyVec, RetrieveKeyError> {
    trace!(target: "vitur_client::retrieve_keys", "preparing payloads");
    let keys = keys
        .into_iter()
        .map(RetrieveKeySpec::from)
        .collect::<Vec<_>>();
    tracing::trace!(target: "vitur_client::retrieve_keys", max_keys_per_req = self.max_keys_per_req, max_parallel_reqs = self.max_concurrent_reqs);
    let result = map_async_chunked(
        &keys,
        |keys| async {
            let req = RetrieveKeyRequestFallible {
                keys: keys.into(),
                keyset_id: keyset_id.map(Into::into),
                client_id: client_key.key_id,
                unverified_context: unverified_context.clone().unwrap_or_default(),
            };
            trace!(target: "vitur_client::retrieve_keys", "sending request with {} keys", keys.len());
            self.connection
                .send(req, access_token)
                .await
                .map_err(RetrieveKeyError::RequestFailed)
                .and_then(|res| {
                    // The zip below would silently drop entries on a length mismatch,
                    // so reject such responses explicitly.
                    if res.keys.len() != keys.len() {
                        return Err(RetrieveKeyError::InvalidNumberOfKeys {
                            expected: keys.len(),
                            received: res.keys.len(),
                        });
                    }
                    trace!(target: "vitur_client::retrieve_keys", "retrieved keys - creating data keys");
                    Ok(keys
                        .iter()
                        .zip(res.keys)
                        .map(|(RetrieveKeySpec { iv, .. }, result)| {
                            result
                                .map(|key| {
                                    DataKey::from_key_material(
                                        client_key,
                                        iv.into_inner(),
                                        &key.key_material,
                                    )
                                })
                                .map_err(RetrieveKeyError::FailedRetrieval)
                        })
                        .collect())
                })
        },
        self.max_keys_per_req,
        self.max_concurrent_reqs,
    )
    .await;
    match &result {
        Err(x) => {
            trace!(target: "vitur_client::retrieve_keys", "failed with error: {x}");
        }
        Ok(x) => {
            // `x` holds one entry per requested key (successes and failures alike);
            // the message used to say "generated".
            trace!(target: "vitur_client::retrieve_keys", "successfully retrieved {} keys", x.len());
        }
    }
    result
}
/// Generates one data key (with tag) per payload in `keys`.
///
/// A random IV is generated locally for every payload and combined with the payload's
/// descriptor and context into a [`GenerateKeySpec`]. The specs are then sent through
/// `map_async_chunked` using the client's `max_keys_per_req` / `max_concurrent_reqs`
/// settings. For every response the returned key material and tag are paired, in order,
/// with the IV of the spec they were generated for and combined with `client_key`.
///
/// `keyset_id` is forwarded when present; a missing `unverified_context` is sent as the
/// default value.
///
/// # Errors
///
/// * [`GenerateKeyError::GenerateIv`] if generating a random IV fails.
/// * [`GenerateKeyError::InvalidNumberOfKeys`] if a response does not contain exactly
///   one key per requested spec.
/// * Any request failure, converted into [`GenerateKeyError`].
pub async fn generate_keys<'a>(
    &self,
    keys: impl IntoIterator<Item = GenerateKeyPayload<'_>>,
    client_key: &ClientKey,
    keyset_id: Option<Uuid>,
    access_token: &str,
    unverified_context: Option<Cow<'a, UnverifiedContext>>,
) -> Result<Vec<DataKeyWithTag>, GenerateKeyError> {
    let keys = {
        let mut rng = rand::thread_rng();
        keys.into_iter()
            .map(
                |GenerateKeyPayload {
                     descriptor,
                     context,
                 }| {
                    // `context` is owned here, so it is moved into the spec instead of
                    // being cloned.
                    GenRandom::gen_random(&mut rng)
                        .map(|iv: Iv| GenerateKeySpec::new_with_context(iv, descriptor, context))
                        .map_err(GenerateKeyError::GenerateIv)
                },
            )
            .collect::<Result<Vec<_>, _>>()?
    };
    trace!(target: "vitur_client::generate_keys", "generated {} key payloads", keys.len());
    tracing::trace!(target: "vitur_client::generate_keys", max_keys_per_req = self.max_keys_per_req, max_parallel_reqs = self.max_concurrent_reqs);
    let result = map_async_chunked(
        &keys,
        |keys| async {
            let req = GenerateKeyRequest {
                keys: keys.into(),
                keyset_id: keyset_id.map(Into::into),
                client_id: client_key.key_id,
                unverified_context: unverified_context.clone().unwrap_or_default(),
            };
            trace!(target: "vitur_client::generate_keys", "sending request with {} keys", keys.len());
            self.connection
                .send(req, access_token)
                .await
                .map_err(GenerateKeyError::from)
                .and_then(|res| {
                    // The zip below would silently drop entries on a length mismatch,
                    // so reject such responses explicitly.
                    if res.keys.len() != keys.len() {
                        return Err(GenerateKeyError::InvalidNumberOfKeys {
                            expected: keys.len(),
                            received: res.keys.len(),
                        });
                    }
                    trace!(target: "vitur_client::generate_keys", "generated keys - creating data keys");
                    Ok(keys
                        .iter()
                        .zip(res.keys)
                        .map(
                            |(
                                GenerateKeySpec { iv, .. },
                                GeneratedKey { key_material, tag },
                            )| {
                                DataKeyWithTag::from_key_material(
                                    client_key,
                                    iv.into_inner(),
                                    &key_material,
                                    tag,
                                )
                            },
                        )
                        .collect())
                })
        },
        self.max_keys_per_req,
        self.max_concurrent_reqs,
    )
    .await;
    match &result {
        Err(x) => {
            trace!(target: "vitur_client::generate_keys", "failed with error: {x}");
        }
        Ok(x) => {
            trace!(target: "vitur_client::generate_keys", "successfully generated {} keys", x.len());
        }
    }
    result
}
pub async fn encrypt(
&self,
payloads: impl IntoIterator<Item = EncryptPayload<'_>>,
key: &ClientKey,
keyset_id: Option<Uuid>,
access_token: &str,
) -> Result<Vec<EncryptedRecord>, EncryptError> {
let payloads = payloads.into_iter().collect::<Vec<_>>();
trace!(target: "vitur_client::encrypt", "generating {} keys", payloads.len());
let keys = self
.generate_keys(
payloads.iter().map(GenerateKeyPayload::from),
key,
keyset_id,
access_token,
None,
)
.await?;
trace!(target: "vitur_client::encrypt", "generated {} keys - encrypting records", keys.len());
let output = payloads
.into_iter()
.zip(keys)
.map(|(payload, datakey_with_tag)| {
EncryptionTarget::new(payload, datakey_with_tag, keyset_id)
})
.map(encrypt)
.collect::<Result<Vec<_>, _>>()?;
trace!(target: "vitur_client::encrypt", "success - encrypted {} records", output.len());
Ok(output)
}
/// Encrypts a single payload by delegating to `encrypt` with a one-element batch.
///
/// # Errors
///
/// Returns [`EncryptError`] if the batch encryption fails.
///
/// # Panics
///
/// Panics if the batch call returns no records.
pub async fn encrypt_single(
    &self,
    payload: EncryptPayload<'_>,
    key: &ClientKey,
    keyset_id: Option<Uuid>,
    access_token: &str,
) -> Result<EncryptedRecord, EncryptError> {
    let batch = [payload];
    let mut records = self.encrypt(batch, key, keyset_id, access_token).await?;
    debug_assert_eq!(records.len(), 1);
    let record = records.remove(0);
    Ok(record)
}
/// Decrypts every payload, returning the plaintexts in input order.
///
/// Payloads are grouped by the keyset id they report; keys are retrieved once per group
/// and each record is then decrypted locally. When `keyset_id` is `Some`, it is used for
/// every group instead of the id reported by the payloads.
///
/// # Errors
///
/// Returns [`DecryptError`] if a payload cannot be turned into a key request or an
/// encrypted record, if key retrieval fails, or if any record fails to decrypt. The
/// first failure aborts the call.
pub async fn decrypt<P>(
    &self,
    payloads: impl IntoIterator<Item = P>,
    key: &ClientKey,
    keyset_id: Option<Uuid>,
    access_token: &str,
    unverified_context: Option<&UnverifiedContext>,
) -> Result<Vec<Vec<u8>>, DecryptError>
where
    P: Decryptable,
{
    let payloads: Vec<P> = payloads.into_iter().collect();
    // One slot per input; filled by original position so the output keeps input order.
    let mut output: Vec<Option<Vec<u8>>> = vec![None; payloads.len()];
    trace!(target: "vitur_client::decrypt", "retrieving keys");

    let mut by_keyset: HashMap<Option<Uuid>, Vec<(usize, P)>> = HashMap::new();
    for (position, payload) in payloads.into_iter().enumerate() {
        let group = by_keyset.entry(payload.keyset_id()).or_default();
        group.push((position, payload));
    }

    for (record_keyset_id, group) in by_keyset {
        // An explicitly supplied keyset id takes precedence over the recorded one.
        let effective_keyset_id = keyset_id.or(record_keyset_id);
        let key_payloads: Vec<RetrieveKeyPayload<'_>> = group
            .iter()
            .map(|(_, payload)| payload.retrieve_key_payload())
            .try_collect()
            .map_err(|e| DecryptError::Internal(e.to_string()))?;
        let keys = self
            .retrieve_keys(
                key_payloads,
                key,
                effective_keyset_id,
                access_token,
                unverified_context,
            )
            .await
            .inspect_err(|_err| {
                trace!(target: "vitur_client::decrypt", "failed to retrieve keys");
            })?;
        trace!(target: "vitur_client::decrypt", "retrieved keys - decrypting records");
        let keys_len = keys.len();
        for ((position, payload), data_key) in group.into_iter().zip(keys) {
            let record = payload
                .into_encrypted_record()
                .map_err(|e| DecryptError::Internal(e.to_string()))?;
            let plaintext = decrypt(DecryptionTarget::new(record, data_key))?;
            output[position] = Some(plaintext);
        }
        trace!(target: "vitur_client::decrypt", "decrypted {keys_len} records");
    }

    let mut plaintexts = Vec::with_capacity(output.len());
    for slot in output {
        match slot {
            Some(plaintext) => plaintexts.push(plaintext),
            None => {
                return Err(DecryptError::Internal(
                    "Record was None but everything should be Some(_) as this point".to_string(),
                ))
            }
        }
    }
    Ok(plaintexts)
}
pub async fn decrypt_fallible<'a, P>(
&self,
payloads: impl IntoIterator<Item = P>,
key: &ClientKey,
access_token: &str,
unverified_context: Option<Cow<'a, UnverifiedContext>>,
) -> Result<Vec<Result<Vec<u8>, RecordDecryptError>>, DecryptError>
where
P: Decryptable,
{
let payloads = payloads.into_iter().collect::<Vec<_>>();
let mut output: Vec<Option<Result<Vec<u8>, RecordDecryptError>>> =
vec![None; payloads.len()];
trace!(target: "vitur_client::decrypt", "retrieving keys");
let mut grouped_payloads = HashMap::<Option<Uuid>, Vec<(usize, P)>>::new();
for (index, record) in payloads.into_iter().enumerate() {
grouped_payloads
.entry(record.keyset_id())
.or_default()
.push((index, record))
}
for (keyset_id, records) in grouped_payloads.into_iter() {
let payloads: Vec<RetrieveKeyPayload<'_>> = records
.iter()
.map(|(_, record)| record.retrieve_key_payload())
.try_collect()
.map_err(|e| DecryptError::Internal(e.to_string()))?;
let keys = self
.retrieve_keys_fallible(
payloads,
key,
keyset_id,
access_token,
unverified_context.clone(),
)
.await
.inspect_err(|_| {
trace!(target: "vitur_client::decrypt", "failed to retrieve keys");
})?;
trace!(target: "vitur_client::decrypt", "retrieved keys - decrypting records");
let keys_len = keys.len();
for (record, data_key) in records.into_iter().zip(keys) {
let (index, record) = record;
let plaintext = record
.into_encrypted_record()
.map_err(|e| RecordDecryptError {
reason: e.to_string(),
})
.and_then(|record| {
data_key
.map(|data_key| {
DecryptionTarget::new(record, data_key)
})
.map_err(|e| RecordDecryptError {
reason: e.to_string(),
})
})
.and_then(decrypt);
output[index] = Some(plaintext);
}
trace!(target: "vitur_client::decrypt", "decrypted {keys_len} records");
}
output
.into_iter()
.map(|record| {
record.ok_or_else(|| {
DecryptError::Internal(
"Record was None but everything should be Some(_) as this point"
.to_string(),
)
})
})
.collect()
}
/// Decrypts a single payload by delegating to `decrypt` with a one-element batch.
///
/// # Errors
///
/// Returns [`DecryptError`] if the batch decryption fails.
///
/// # Panics
///
/// Panics if the batch call returns no plaintexts.
pub async fn decrypt_single<P>(
    &self,
    payload: P,
    key: &ClientKey,
    keyset_id: Option<Uuid>,
    access_token: &str,
    unverified_context: Option<&UnverifiedContext>,
) -> Result<Vec<u8>, DecryptError>
where
    P: Decryptable,
{
    let batch = [payload];
    let mut plaintexts = self
        .decrypt(batch, key, keyset_id, access_token, unverified_context)
        .await?;
    debug_assert_eq!(plaintexts.len(), 1);
    let plaintext = plaintexts.remove(0);
    Ok(plaintext)
}
pub(crate) async fn load_keyset(
&self,
client_key: &ClientKey,
keyset_id: Option<IdentifiedBy>,
access_token: &str,
) -> Result<(Keyset, IndexKey), LoadKeysetError> {
let req = LoadKeysetRequest {
client_id: client_key.key_id,
keyset_id,
};
let LoadKeysetResponse {
keyset,
partial_index_key,
} = self.connection.send(req, access_token).await?;
let data_key = IndexKey::from_key_material(client_key, &partial_index_key.key_material);
Ok((keyset, data_key))
}
}
/// An encrypted record bundled with the data key needed to decrypt it.
pub struct DecryptionTarget {
// The record whose ciphertext will be decrypted.
record: EncryptedRecord,
// Data key supplying the cipher key and the IV the nonce is taken from.
key: DataKey,
// Additional authenticated data: the record's descriptor bytes followed by its tag.
aad: Vec<u8>,
}
impl DecryptionTarget {
    /// Bundles `record` with `key` and precomputes the additional authenticated data
    /// (the record's descriptor bytes followed by its tag).
    pub(super) fn new(record: EncryptedRecord, key: DataKey) -> Self {
        let tag: &[u8] = &record.tag;
        let aad = [record.descriptor.as_bytes(), tag].concat();
        Self { record, key, aad }
    }

    /// The cipher nonce: the first 12 bytes of the data key's IV.
    fn nonce(&self) -> &Nonce {
        let iv_prefix = &self.key.iv[..12];
        Nonce::from_slice(iv_prefix)
    }

    /// The key material of the data key.
    fn key(&self) -> &Key {
        let data_key = &self.key;
        data_key.key()
    }
}
impl<'e> From<&'e DecryptionTarget> for Payload<'e, 'e> {
    /// Borrows the record's ciphertext as the message and the precomputed AAD as the
    /// associated data.
    fn from(target: &'e DecryptionTarget) -> Self {
        let msg: &'e [u8] = &target.record.ciphertext;
        let aad: &'e [u8] = &target.aad;
        Payload { msg, aad }
    }
}
/// Encrypts `target` with AES-256-GCM-SIV using the target's key and nonce, and
/// assembles the resulting [`EncryptedRecord`] from the ciphertext and the target's
/// metadata (IV, tag, descriptor, keyset id).
///
/// # Errors
///
/// Returns [`EncryptError::FailedToEncrypt`] if the cipher reports an error.
pub fn encrypt(target: EncryptionTarget<'_>) -> Result<EncryptedRecord, EncryptError> {
    let cipher = Aes256GcmSiv::new(AesKey::<Aes256GcmSiv>::from_slice(target.key()));
    let ciphertext = cipher
        .encrypt(target.nonce(), &target)
        .map_err(EncryptError::FailedToEncrypt)?;
    // The target is only consumed once encryption has succeeded.
    let (iv, tag, descriptor, keyset_id) = target.into_meta();
    Ok(EncryptedRecord {
        iv,
        ciphertext,
        tag,
        descriptor,
        keyset_id,
    })
}
/// Decrypts `target` with AES-256-GCM-SIV using the target's key and nonce, returning
/// the plaintext bytes.
///
/// # Errors
///
/// Returns a [`RecordDecryptError`] whose `reason` is the cipher error's message.
pub fn decrypt(target: DecryptionTarget) -> Result<Vec<u8>, RecordDecryptError> {
    let cipher = Aes256GcmSiv::new(AesKey::<Aes256GcmSiv>::from_slice(target.key()));
    match cipher.decrypt(target.nonce(), &target) {
        Ok(plaintext) => Ok(plaintext),
        Err(e) => Err(RecordDecryptError {
            reason: e.to_string(),
        }),
    }
}
// Unit tests for `Client::load_keyset`, run against the in-memory test connection.
#[cfg(test)]
mod tests {
use super::{key::V1KeySet, *};
use fake::{Fake, Faker};
use recipher::keyset::{EncryptionKeySet, ProxyKeySet};
use test_connection::*;
use thiserror::Error;
use uuid::uuid;
use zerokms_protocol::*;
/// Builds a client key with freshly generated key material and an all-zero key id.
fn random_client_key() -> ClientKey {
let domain_key = EncryptionKeySet::generate().unwrap();
let authority_key = EncryptionKeySet::generate().unwrap();
let keyset = ProxyKeySet::generate(&authority_key, &domain_key);
ClientKey {
key_id: uuid!("00000000-0000-0000-0000-000000000000"),
keyset: V1KeySet(keyset),
}
}
/// Builds a client over a `TestConnection`; `callback` configures the connection
/// builder (e.g. registers canned responses) before the client is initialised.
fn build_client(
callback: impl FnOnce(TestConnectionBuilder) -> TestConnectionBuilder,
) -> Client<TestConnection> {
let builder = callback(TestConnectionBuilder::new());
let client_opts = ClientOpts {
max_keys_per_req: 10,
max_concurrent_reqs: 5,
connection_opts: builder,
};
Client::init_opts(client_opts).expect("Failed to initialize test client")
}
/// Minimal error type used as the source of simulated request failures.
#[derive(Error, Debug)]
#[error("{0}")]
struct TestConnectionError(String);
// Test-only equality so index keys can be compared with `assert_eq!`.
impl PartialEq for IndexKey {
fn eq(&self, other: &Self) -> bool {
self.key() == other.key()
}
}
/// A successful response yields the keyset and an index key derived from the
/// returned key material and the client key.
#[tokio::test]
async fn test_load_keyset() {
let keyset: Keyset = Faker.fake();
let keyset_id = keyset.id;
let client_key = random_client_key();
let key_bytes = fake::vec![u8; 528];
let key_client = build_client(|builder| {
builder.add_success_response::<LoadKeysetRequest>(LoadKeysetResponse {
keyset,
partial_index_key: RetrievedKey {
key_material: key_bytes.clone().into(),
},
})
});
let (keyset, data_key) = key_client
.load_keyset(&client_key, Some(keyset_id.into()), "token")
.await
.expect("Failed to load keyset");
// Derive the expected index key independently from the same inputs.
let index_key_check = IndexKey::from_key_material(&client_key, &key_bytes.into());
assert_eq!(keyset.id, keyset_id);
assert_eq!(data_key, index_key_check);
}
/// A failed request surfaces as `LoadKeysetError::RequestFailed`.
#[tokio::test]
async fn test_load_keyset_missing_key() {
let keyset: Keyset = Faker.fake();
let keyset_id = keyset.id;
let client_key = random_client_key();
let key_client = build_client(|builder| {
builder.add_failed_response::<LoadKeysetRequest>(ViturRequestError::response(
"Keyset not found",
TestConnectionError("Status: 404, Body: Not Found".into()),
))
});
let result = key_client
.load_keyset(&client_key, Some(keyset_id.into()), "token")
.await;
assert!(matches!(result, Err(LoadKeysetError::RequestFailed(_))));
}
}