use base64::Engine as _;
use bhttp::{Message, Mode};
use ohttp::ClientRequest;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use crate::AuthenticatorError;
/// Settings needed to reach an OHTTP gateway through a relay.
///
/// Field names double as the serde (de)serialization keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OhttpClientConfig {
    /// URL of the relay that encapsulated requests are POSTed to.
    pub relay_url: String,
    /// Standard-alphabet base64 of the gateway's `application/ohttp-keys` payload.
    pub key_config_base64: String,
}

impl OhttpClientConfig {
    /// Bundles a relay URL with a base64-encoded key configuration.
    ///
    /// Nothing is validated here; decoding and parsing of the key configuration
    /// happen in `OhttpClient::new`.
    pub fn new(relay_url: String, key_config_base64: String) -> Self {
        Self {
            key_config_base64,
            relay_url,
        }
    }
}
/// A decapsulated response produced by the OHTTP gateway's target.
#[derive(Debug)]
pub struct OhttpResponse {
    /// Status line carried inside the encapsulated response (the target's status),
    /// as opposed to the status of the HTTP exchange with the relay.
    pub status: StatusCode,
    /// Raw response content bytes; no decoding is applied.
    pub body: Vec<u8>,
}
/// Client that wraps HTTP requests in Oblivious HTTP and sends them via a relay.
#[derive(Debug, Clone)]
pub struct OhttpClient {
    /// HTTP client used for the outer request to the relay.
    client: Client,
    /// Relay endpoint receiving `message/ohttp-req` bodies.
    relay_url: String,
    /// Scheme written into the inner bHTTP request's control data.
    target_scheme: String,
    /// Authority written into the inner bHTTP request's control data.
    target_authority: String,
    /// Decoded `application/ohttp-keys` bytes, already validated at construction.
    encoded_config_list: Vec<u8>,
}
impl OhttpClient {
    /// Creates a client that sends OHTTP-encapsulated requests for `target_url`
    /// through the relay configured in `config`.
    ///
    /// `config_scope` only prefixes the `attribute` reported in configuration errors
    /// (e.g. `"{config_scope}.target_url"`). The key configuration is decoded and
    /// parsed here so malformed settings fail at construction rather than on the
    /// first request.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticatorError::InvalidConfig`] if:
    /// - `target_url` is not `scheme://authority` (trailing `/` characters are
    ///   tolerated). An empty scheme or authority is rejected, and so is a path,
    ///   query or fragment after the authority, because it would otherwise be copied
    ///   verbatim into the bHTTP authority field.
    /// - `config.key_config_base64` is not valid standard base64.
    /// - The decoded bytes are not a valid `application/ohttp-keys` payload.
    pub fn new(
        client: Client,
        config_scope: &str,
        target_url: &str,
        config: OhttpClientConfig,
    ) -> Result<Self, AuthenticatorError> {
        let (target_scheme, target_authority) = Self::split_target_url(target_url)
            .ok_or_else(|| AuthenticatorError::InvalidConfig {
                attribute: format!("{config_scope}.target_url"),
                reason: format!("expected scheme://authority, got {target_url:?}"),
            })?;

        let attribute = format!("{config_scope}.key_config_base64");
        let encoded_config_list = base64::engine::general_purpose::STANDARD
            .decode(&config.key_config_base64)
            .map_err(|err| AuthenticatorError::InvalidConfig {
                attribute: attribute.clone(),
                reason: format!("invalid base64: {err}"),
            })?;
        // Parsed here only for validation; `request` builds a fresh `ClientRequest`
        // each time because `encapsulate` consumes it.
        ClientRequest::from_encoded_config_list(&encoded_config_list).map_err(|err| {
            AuthenticatorError::InvalidConfig {
                attribute,
                reason: format!("invalid application/ohttp-keys payload: {err}"),
            }
        })?;

        Ok(Self {
            client,
            relay_url: config.relay_url,
            target_scheme,
            target_authority,
            encoded_config_list,
        })
    }

    /// Splits `scheme://authority` into owned `(scheme, authority)` parts.
    ///
    /// Trailing `/` characters are stripped from the authority. Returns `None` when
    /// the `://` separator is missing, when either part is empty, or when anything
    /// (a path, query or fragment) follows the authority.
    fn split_target_url(target_url: &str) -> Option<(String, String)> {
        let (scheme, rest) = target_url.split_once("://")?;
        let authority = rest.trim_end_matches('/');
        if scheme.is_empty() || authority.is_empty() || authority.contains(['/', '?', '#']) {
            return None;
        }
        Some((scheme.to_owned(), authority.to_owned()))
    }

    /// Serializes `body` as JSON and POSTs it to `path` on the target via the relay.
    ///
    /// The body is sent with `content-type: application/json`.
    ///
    /// # Errors
    ///
    /// - [`AuthenticatorError::Generic`] if `body` cannot be serialized to JSON.
    /// - Any error listed for [`OhttpClient::get`].
    pub async fn post_json<T: serde::Serialize + ?Sized>(
        &self,
        path: &str,
        body: &T,
    ) -> Result<OhttpResponse, AuthenticatorError> {
        let body = serde_json::to_vec(body).map_err(|e| {
            AuthenticatorError::Generic(format!("failed to serialize request body: {e}"))
        })?;
        self.request(b"POST", path, Some(&body)).await
    }

    /// Sends a GET for `path` on the target via the relay.
    ///
    /// A non-success status from the *target* is not an error. It is returned in
    /// [`OhttpResponse::status`].
    ///
    /// # Errors
    ///
    /// - [`AuthenticatorError::OhttpRelayError`] if the outer response from the relay
    ///   has a non-success status. The relay's body is included on a best-effort
    ///   basis.
    /// - Transport, OHTTP encapsulation/decapsulation, or bHTTP encoding/decoding
    ///   failures, converted into [`AuthenticatorError`].
    /// - [`AuthenticatorError::Generic`] if the decapsulated response carries no
    ///   status line.
    pub async fn get(&self, path: &str) -> Result<OhttpResponse, AuthenticatorError> {
        self.request(b"GET", path, None).await
    }

    /// Sends one request to the target and decodes the target's response.
    ///
    /// The request is encoded as known-length bHTTP and encapsulated with the
    /// gateway's key configuration. It is then POSTed to the relay as
    /// `message/ohttp-req`.
    ///
    /// `path` is sent verbatim as the bHTTP request path. When `body` is present, it
    /// is attached with `content-type: application/json`.
    async fn request(
        &self,
        method: &[u8],
        path: &str,
        body: Option<&[u8]>,
    ) -> Result<OhttpResponse, AuthenticatorError> {
        let mut msg = Message::request(
            method.to_vec(),
            self.target_scheme.as_bytes().to_vec(),
            self.target_authority.as_bytes().to_vec(),
            path.as_bytes().to_vec(),
        );
        if let Some(body) = body {
            msg.put_header("content-type", "application/json");
            msg.write_content(body);
        }
        let mut bhttp_buf = Vec::new();
        msg.write_bhttp(Mode::KnownLength, &mut bhttp_buf)?;

        // `encapsulate` consumes the `ClientRequest`, so a new one is built for every
        // request from the key configuration validated in `new`.
        let ohttp_req = ClientRequest::from_encoded_config_list(&self.encoded_config_list)?;
        let (enc_request, ohttp_resp_ctx) = ohttp_req.encapsulate(&bhttp_buf)?;

        let resp = self
            .client
            .post(&self.relay_url)
            .header("content-type", "message/ohttp-req")
            .body(enc_request)
            .send()
            .await?;
        // Outer (relay-level) status. The target's status is inside the encapsulated
        // response that is decoded below.
        let relay_status = resp.status();
        if !relay_status.is_success() {
            return Err(AuthenticatorError::OhttpRelayError {
                status: relay_status,
                // Best effort: the body is diagnostic only, so a read failure must not
                // mask the status error.
                body: resp.text().await.unwrap_or_default(),
            });
        }

        let enc_response = resp.bytes().await?;
        let response_buf = ohttp_resp_ctx.decapsulate(&enc_response)?;
        let response_msg = Message::read_bhttp(&mut std::io::Cursor::new(&response_buf))?;
        let status_code = response_msg
            .control()
            .status()
            .map(|s| s.code())
            .ok_or_else(|| {
                AuthenticatorError::Generic("OHTTP response missing HTTP status line".into())
            })?;
        let status = StatusCode::from_u16(status_code).map_err(|_| bhttp::Error::InvalidStatus)?;
        Ok(OhttpResponse {
            status,
            body: response_msg.content().to_vec(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthenticatorError;

    /// Encodes `raw` with the standard base64 alphabet for use as `key_config_base64`.
    fn b64(raw: &[u8]) -> String {
        base64::engine::general_purpose::STANDARD.encode(raw)
    }

    /// Builds an `OhttpClient` from plain inputs, using a fresh `reqwest::Client`.
    fn build(
        relay_url: &str,
        key_config_base64: String,
        config_scope: &str,
        target_url: &str,
    ) -> Result<OhttpClient, AuthenticatorError> {
        let config = OhttpClientConfig::new(relay_url.into(), key_config_base64);
        OhttpClient::new(reqwest::Client::new(), config_scope, target_url, config)
    }

    /// Asserts that `result` is `InvalidConfig` for `expected_attribute`, with a
    /// reason containing `expected_reason`.
    fn assert_invalid_config(
        result: Result<OhttpClient, AuthenticatorError>,
        expected_attribute: &str,
        expected_reason: &str,
    ) {
        match result {
            Err(AuthenticatorError::InvalidConfig { attribute, reason }) => {
                assert_eq!(attribute, expected_attribute);
                assert!(
                    reason.contains(expected_reason),
                    "unexpected reason: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn invalid_base64_key_config_returns_invalid_config() {
        let outcome = build(
            "http://localhost:1234",
            "not valid base64 !!!".to_string(),
            "test_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(outcome, "test_scope.key_config_base64", "invalid base64");
    }

    #[test]
    fn invalid_ohttp_keys_payload_returns_invalid_config() {
        let outcome = build(
            "http://localhost:1234",
            b64(b"definitely not an ohttp-keys payload"),
            "my_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(
            outcome,
            "my_scope.key_config_base64",
            "invalid application/ohttp-keys payload",
        );
    }

    #[test]
    fn garbage_ohttp_keys_bytes_returns_invalid_config() {
        let result = build(
            "http://127.0.0.1:0/does-not-exist",
            b64(b"not-a-valid-ohttp-keys"),
            "test",
            "http://localhost:1234",
        );
        assert!(
            matches!(result, Err(AuthenticatorError::InvalidConfig { .. })),
            "expected InvalidConfig for garbage key config, got: {result:?}"
        );
    }

    #[test]
    fn missing_scheme_in_target_url_returns_invalid_config() {
        let outcome = build(
            "http://localhost:1234",
            b64(b"irrelevant"),
            "test_scope",
            "localhost:9999",
        );
        assert_invalid_config(outcome, "test_scope.target_url", "expected scheme://authority");
    }

    #[test]
    fn empty_key_config_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            b64(b""),
            "test_scope",
            "https://localhost:9999",
        );
        assert!(
            matches!(result, Err(AuthenticatorError::InvalidConfig { .. })),
            "expected InvalidConfig for empty key config, got: {result:?}"
        );
    }
}