use crate::zerokms::vitur_client::connection::ZeroKMSConnectionInit;
use async_mutex::Mutex;
use zerokms_protocol::{ViturRequest, ViturRequestError};
use super::connection::ZeroKMSConnection;
// Queue of side-effect callbacks, each tagged with the endpoint name it applies to.
// Each callback is one-shot (`FnOnce`) and is handed the JSON-serialised request
// body as a `&str` when a matching request is sent.
type EffectHandlers = Vec<(String, Box<dyn FnOnce(&str) + Send>)>;
// Queue of canned outcomes, each tagged with the endpoint name it answers.
// `Ok` carries the response already serialised to a JSON string; `Err` carries
// the error to hand back to the caller as-is.
type RequestHandlers = Vec<(String, Result<String, ViturRequestError>)>;
/// Builder that accumulates canned responses and request effects before they
/// are frozen into a [`TestConnection`] via `build`.
pub struct TestConnectionBuilder {
// Canned outcomes, stored in the order they were registered.
handlers: RequestHandlers,
// One-shot request callbacks, stored in the order they were registered.
effects: EffectHandlers,
}
impl TestConnectionBuilder {
    /// Creates a builder with no canned responses and no effects registered.
    pub fn new() -> Self {
        let handlers = Vec::new();
        let effects = Vec::new();
        Self { handlers, effects }
    }

    /// Queues a successful response for requests of type `R`.
    ///
    /// The response is serialised to JSON immediately and stored under
    /// `R::ENDPOINT`.
    ///
    /// # Panics
    ///
    /// Panics if `response` cannot be serialised to JSON.
    pub fn add_success_response<R: ViturRequest>(mut self, response: R::Response) -> Self {
        let serialised = serde_json::to_string(&response)
            .expect("Failed to serialise success response. This shouldn't happen.");
        let entry = (R::ENDPOINT.to_string(), Ok(serialised));
        self.handlers.push(entry);
        self
    }

    /// Queues `error` as the outcome for requests of type `R`, stored under
    /// `R::ENDPOINT`.
    pub fn add_failed_response<R: ViturRequest>(mut self, error: ViturRequestError) -> Self {
        let entry = (R::ENDPOINT.to_string(), Err(error));
        self.handlers.push(entry);
        self
    }

    /// Queues a one-shot callback for requests of type `R`, stored under
    /// `R::ENDPOINT`.
    ///
    /// The stored callback receives the request as a JSON string, parses it
    /// back into an `R` and passes that to `handler`.
    ///
    /// # Panics
    ///
    /// The stored callback panics if the JSON it receives cannot be parsed
    /// into an `R`.
    pub fn add_effect<R: ViturRequest, H: FnOnce(R) + Send + 'static>(
        mut self,
        handler: H,
    ) -> Self {
        // Erase the concrete request type so effects for different request
        // types can live in the same queue.
        let effect: Box<dyn FnOnce(&str) + Send> = Box::new(move |raw: &str| {
            let request: R = serde_json::from_str(raw).expect(
                "Failed to parse request from message in test effect. This shouldn't happen.",
            );
            handler(request)
        });
        self.effects.push((R::ENDPOINT.to_string(), effect));
        self
    }

    /// Consumes the builder and wraps both queues in mutexes to produce a
    /// [`TestConnection`].
    pub fn build(self) -> TestConnection {
        let Self { handlers, effects } = self;
        TestConnection {
            handlers: Mutex::new(handlers),
            effects: Mutex::new(effects),
        }
    }
}
impl Default for TestConnectionBuilder {
fn default() -> Self {
Self::new()
}
}
/// In-memory connection that answers requests from pre-registered canned
/// outcomes instead of talking to a server.
pub struct TestConnection {
// Canned outcomes not yet consumed; behind a mutex because `send` takes `&self`.
handlers: Mutex<RequestHandlers>,
// Effects not yet consumed; behind a mutex because `send` takes `&self`.
effects: Mutex<EffectHandlers>,
}
impl TestConnection {
    /// Starts a fresh [`TestConnectionBuilder`] with nothing registered.
    pub fn builder() -> TestConnectionBuilder {
        TestConnectionBuilder::new()
    }

    /// Creates a connection with no canned responses and no effects.
    pub fn empty() -> Self {
        let builder = Self::builder();
        builder.build()
    }
}
/// Lets a [`TestConnection`] be initialised from a [`TestConnectionBuilder`]
/// used as its connection options.
impl ZeroKMSConnectionInit for TestConnection {
    type ConnectionOpts = TestConnectionBuilder;
    type Error = std::convert::Infallible;

    /// Builds the connection from `builder`; this always returns `Ok`.
    fn init(builder: Self::ConnectionOpts) -> Result<Self, Self::Error> {
        let connection = builder.build();
        Ok(connection)
    }
}
impl ZeroKMSConnection for TestConnection {
    /// Answers `request` from the registered queues.
    ///
    /// The first effect registered for `Request::ENDPOINT`, if any, is removed
    /// and run with the JSON-serialised request. Then the first canned outcome
    /// registered for the same endpoint is removed and returned: a stored
    /// response is parsed into `Request::Response`, a stored error is returned
    /// unchanged. The access token is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be serialised, if no outcome is registered
    /// for the endpoint, or if the stored response cannot be parsed. A panic
    /// raised inside an effect propagates to the caller.
    async fn send<Request: ViturRequest>(
        &self,
        request: Request,
        _access_token: &str,
    ) -> Result<Request::Response, ViturRequestError> {
        let endpoint = Request::ENDPOINT;

        // The effects lock is taken first and stays held until the function
        // returns, so effect and outcome are consumed together.
        let mut pending_effects = self.effects.lock().await;
        let effect_slot = pending_effects
            .iter()
            .position(|(registered, _)| registered == endpoint);

        let serialised_request = serde_json::to_string(&request)
            .expect("Failed to serialise request body in test connection");

        if let Some(slot) = effect_slot {
            let (_, run_effect) = pending_effects.remove(slot);
            run_effect(&serialised_request);
        }

        let mut pending_outcomes = self.handlers.lock().await;
        let outcome_slot = match pending_outcomes
            .iter()
            .position(|(registered, _)| registered == endpoint)
        {
            Some(slot) => slot,
            None => panic!("No handler defined for request: {endpoint}"),
        };
        let (_, outcome) = pending_outcomes.remove(outcome_slot);

        match outcome {
            Ok(json) => Ok(serde_json::from_str(&json)
                .expect("Failed to parse response body from handler in test connection")),
            Err(error) => Err(error),
        }
    }
}
// Unit tests for `TestConnection`. Gated on `cfg(test)` so the module and its
// test-only imports are not compiled into non-test builds.
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::uuid;
    use zerokms_protocol::{GenerateKeyRequest, GenerateKeyResponse};

    /// A registered success response is returned to the caller.
    #[tokio::test]
    async fn test_success_response() {
        let conn = TestConnection::builder()
            .add_success_response::<GenerateKeyRequest>(GenerateKeyResponse { keys: [].into() })
            .build();
        let res = conn
            .send(
                GenerateKeyRequest {
                    keys: vec![].into(),
                    client_id: uuid!("00000000-0000-0000-0000-000000000000"),
                    keyset_id: None,
                    unverified_context: Default::default(),
                },
                "access-token",
            )
            .await
            .unwrap();
        assert_eq!(res.keys.len(), 0);
    }

    /// A registered failure is returned to the caller as an error.
    #[tokio::test]
    async fn test_failure_response() {
        #[derive(Debug, thiserror::Error)]
        #[error("{0}")]
        struct MyError(&'static str);
        let conn = TestConnection::builder()
            .add_failed_response::<GenerateKeyRequest>(ViturRequestError::other(
                "Oh no",
                MyError("Oops"),
            ))
            .build();
        let err = conn
            .send(
                GenerateKeyRequest {
                    keys: vec![].into(),
                    client_id: uuid!("00000000-0000-0000-0000-000000000000"),
                    keyset_id: None,
                    unverified_context: Default::default(),
                },
                "access-token",
            )
            .await
            .expect_err("Expected request to fail");
        assert_eq!(err.to_string(), "Other: Oh no: Oops");
    }

    /// Sending a request with nothing registered panics with the "no handler"
    /// message. `expected` stops an unrelated panic from satisfying the test.
    #[tokio::test]
    #[should_panic(expected = "No handler defined for request")]
    async fn test_panic_if_no_handler() {
        let conn = TestConnection::empty();
        let _ = conn
            .send(
                GenerateKeyRequest {
                    keys: vec![].into(),
                    client_id: uuid!("00000000-0000-0000-0000-000000000000"),
                    keyset_id: None,
                    unverified_context: Default::default(),
                },
                "access-token",
            )
            .await;
    }

    /// A panic raised inside an effect propagates out of `send`. `expected`
    /// ties the test to the effect's own panic message.
    #[tokio::test]
    #[should_panic(expected = "Effect should fail!")]
    async fn test_assert_in_effect_fails_test() {
        let conn = TestConnection::builder()
            .add_success_response::<GenerateKeyRequest>(GenerateKeyResponse { keys: [].into() })
            .add_effect(|_: GenerateKeyRequest| panic!("Effect should fail!"))
            .build();
        let _ = conn
            .send(
                GenerateKeyRequest {
                    keys: vec![].into(),
                    client_id: uuid!("00000000-0000-0000-0000-000000000000"),
                    keyset_id: None,
                    unverified_context: Default::default(),
                },
                "access-token",
            )
            .await;
    }
}