use std::collections::HashMap;
use std::time::Duration;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::ports::signing::{SigningError, SigningInput, SigningOutput, SigningPort};
#[cfg(test)]
use crate::ports::signing::ErasedSigningPort;
/// A [`SigningPort`] that performs no signing.
///
/// Every call succeeds with `SigningOutput::default()`: no headers, no query
/// params and no body override are produced, whatever the input.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopSigningAdapter;

impl SigningPort for NoopSigningAdapter {
    /// Ignores `_input` and returns an empty [`SigningOutput`]; never fails.
    async fn sign(&self, _input: SigningInput) -> Result<SigningOutput, SigningError> {
        Ok(SigningOutput::default())
    }
}
/// Configuration for [`HttpSigningAdapter`].
///
/// `Debug` is implemented by hand rather than derived so that credentials do
/// not leak into logs. The bearer token is shown only as `<redacted>`, and only
/// the *names* of `extra_headers` are printed, never their values.
#[derive(Clone)]
pub struct HttpSigningConfig {
    /// URL the sign request is POSTed to.
    pub endpoint: String,
    /// Timeout configured on the HTTP client that talks to the sidecar.
    pub timeout: Duration,
    /// Optional token sent to the sidecar as bearer authentication.
    pub bearer_token: Option<String>,
    /// Additional headers attached to every sidecar request.
    pub extra_headers: HashMap<String, String>,
}

impl std::fmt::Debug for HttpSigningConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values may carry API keys, so only the names are shown. They are
        // sorted so the output does not depend on HashMap iteration order.
        let mut header_names: Vec<&str> = self.extra_headers.keys().map(String::as_str).collect();
        header_names.sort_unstable();
        f.debug_struct("HttpSigningConfig")
            .field("endpoint", &self.endpoint)
            .field("timeout", &self.timeout)
            .field("bearer_token", &self.bearer_token.as_ref().map(|_| "<redacted>"))
            .field("extra_header_names", &header_names)
            .finish()
    }
}
impl Default for HttpSigningConfig {
    /// Returns the local-sidecar defaults:
    /// endpoint `http://localhost:27042/sign`, a 10-second timeout,
    /// no bearer token and no extra headers.
    fn default() -> Self {
        let endpoint = String::from("http://localhost:27042/sign");
        let timeout = Duration::from_secs(10);
        Self {
            endpoint,
            timeout,
            bearer_token: Option::None,
            extra_headers: HashMap::default(),
        }
    }
}
/// JSON payload POSTed to the signing sidecar by `HttpSigningAdapter::sign`.
///
/// Field names and order define the wire format, so do not rename or reorder
/// them without updating the sidecar.
#[derive(Debug, Serialize)]
struct SignRequest {
    /// HTTP method of the request being signed (taken from `SigningInput::method`).
    method: String,
    /// URL of the request being signed (taken from `SigningInput::url`).
    url: String,
    /// Headers of the request being signed (taken from `SigningInput::headers`).
    headers: HashMap<String, String>,
    /// Request body as padded, standard-alphabet (`+`/`/`) base64 produced by
    /// `base64_encode`. The key is left out of the JSON entirely when there is
    /// no body.
    #[serde(skip_serializing_if = "Option::is_none")]
    body_b64: Option<String>,
    /// Caller-supplied context (`SigningInput::context`), forwarded as-is.
    context: serde_json::Value,
}
/// JSON body expected back from the signing sidecar.
///
/// Every field has `#[serde(default)]`, so a missing key means "nothing to
/// apply". For example, `{}` deserializes to empty headers, no query params
/// and no body override.
#[derive(Debug, Deserialize)]
struct SignResponse {
    /// Passed through unchanged as `SigningOutput::headers`.
    #[serde(default)]
    headers: HashMap<String, String>,
    /// Passed through unchanged as `SigningOutput::query_params`.
    /// NOTE(review): serde represents `(String, String)` as a two-element array,
    /// so the sidecar presumably sends `[["name", "value"], ...]`. Confirm
    /// against the sidecar contract.
    #[serde(default)]
    query_params: Vec<(String, String)>,
    /// Optional replacement body, base64-encoded. It is decoded with
    /// `base64_decode` into `SigningOutput::body_override`; when absent, no
    /// override is applied.
    #[serde(default)]
    body_b64: Option<String>,
}
/// [`SigningPort`] implementation that delegates signing to an HTTP sidecar.
///
/// Each `sign` call POSTs a JSON `SignRequest` to `config.endpoint` and maps
/// the sidecar's `SignResponse` onto a [`SigningOutput`].
pub struct HttpSigningAdapter {
    /// Endpoint, timeout and authentication settings for the sidecar.
    config: HttpSigningConfig,
    /// HTTP client used for every sidecar call, built once in `new`.
    client: Client,
}
impl HttpSigningAdapter {
    /// Creates an adapter whose HTTP client is built with `config.timeout`.
    ///
    /// If building the configured client fails, this falls back to
    /// `Client::default()`, which is not built with `config.timeout`.
    ///
    /// NOTE(review): reqwest documents that `Client::new()`, which `Default`
    /// delegates to, panics when the TLS backend or resolver cannot be
    /// initialised. The fallback may therefore panic under the same conditions
    /// that made the builder fail. Consider whether a fallible constructor
    /// would be preferable.
    pub fn new(config: HttpSigningConfig) -> Self {
        let builder = Client::builder().timeout(config.timeout);
        let client = builder.build().unwrap_or_default();
        Self { config, client }
    }
}
impl SigningPort for HttpSigningAdapter {
async fn sign(&self, input: SigningInput) -> Result<SigningOutput, SigningError> {
let body_b64 = input.body.as_deref().map(base64_encode);
let req_body = SignRequest {
method: input.method,
url: input.url,
headers: input.headers,
body_b64,
context: input.context,
};
let mut req = self.client.post(&self.config.endpoint).json(&req_body);
if let Some(token) = &self.config.bearer_token {
req = req.bearer_auth(token);
}
for (k, v) in &self.config.extra_headers {
req = req.header(k, v);
}
let response = req.send().await.map_err(|e| {
if e.is_timeout() {
SigningError::Timeout(
self.config
.timeout
.as_millis()
.try_into()
.unwrap_or(u64::MAX),
)
} else {
SigningError::BackendUnavailable(e.to_string())
}
})?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
return Err(SigningError::InvalidResponse(format!(
"sidecar returned HTTP {status}: {body}"
)));
}
let sign_resp: SignResponse = response
.json()
.await
.map_err(|e| SigningError::InvalidResponse(e.to_string()))?;
let body_override = sign_resp
.body_b64
.map(|b64| base64_decode(&b64))
.transpose()
.map_err(|e| SigningError::InvalidResponse(format!("base64 decode failed: {e}")))?;
Ok(SigningOutput {
headers: sign_resp.headers,
query_params: sign_resp.query_params,
body_override,
})
}
}
/// Encodes `input` as standard-alphabet (`+`/`/`) base64 with `=` padding.
///
/// Each group of up to three input bytes becomes four output characters. A
/// final group of one byte yields two symbols plus `==`; a final group of two
/// bytes yields three symbols plus `=`. An empty input gives an empty string.
fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Missing bytes in a short final group are treated as zero; their
        // output positions are replaced by padding below.
        let byte_at = |i: usize| group.get(i).copied().unwrap_or(0);
        let (b0, b1, b2) = (byte_at(0), byte_at(1), byte_at(2));

        // Split the 24 bits into four 6-bit alphabet indices.
        let sextets = [
            b0 >> 2,
            ((b0 & 0x03) << 4) | (b1 >> 4),
            ((b1 & 0x0f) << 2) | (b2 >> 6),
            b2 & 0x3f,
        ];

        // n input bytes carry enough bits for n + 1 symbols; the rest is padding.
        let emitted = group.len() + 1;
        for &sextet in sextets.iter().take(emitted) {
            // Every sextet is < 64, so the lookup always succeeds.
            let symbol = ALPHABET.get(usize::from(sextet)).copied().unwrap_or_default();
            encoded.push(char::from(symbol));
        }
        for _ in emitted..4 {
            encoded.push('=');
        }
    }
    encoded
}
/// Decodes standard-alphabet (`+`/`/`) base64, such as the output of `base64_encode`.
///
/// Trailing `=` padding is optional and is stripped before decoding.
///
/// # Errors
///
/// Returns an error message when:
/// - any character is outside the base64 alphabet (this includes `=` anywhere
///   other than the end); the message names the offending byte value;
/// - the unpadded length leaves a single dangling character. One character
///   carries only 6 bits and cannot encode a whole byte, so such input is
///   malformed and is rejected rather than silently truncated.
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    /// Maps one base64 symbol to its 6-bit value.
    fn sextet(c: u8) -> Result<u8, String> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err(format!("invalid base64 char: {c}")),
        }
    }

    let data = input.trim_end_matches('=').as_bytes();
    if data.len() % 4 == 1 {
        return Err(format!(
            "invalid base64 length: {} data chars leaves a dangling character",
            data.len()
        ));
    }

    let mut out = Vec::with_capacity(data.len() / 4 * 3 + 2);
    for group in data.chunks(4) {
        // A short final group (2 or 3 chars) leaves the missing sextets at zero;
        // only the bytes it actually encodes are pushed below.
        let mut quad = [0u8; 4];
        for (slot, &c) in quad.iter_mut().zip(group) {
            *slot = sextet(c)?;
        }
        let [v0, v1, v2, v3] = quad;
        // The length check above guarantees every group has at least 2 chars.
        out.push((v0 << 2) | (v1 >> 4));
        if group.len() > 2 {
            out.push(((v1 & 0x0f) << 4) | (v2 >> 2));
        }
        if group.len() > 3 {
            out.push(((v2 & 0x03) << 6) | v3);
        }
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn noop_returns_empty_output() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let signer = NoopSigningAdapter;
        let output = signer
            .sign(SigningInput {
                method: "GET".to_string(),
                url: "https://example.com".to_string(),
                headers: HashMap::new(),
                body: None,
                context: json!({}),
            })
            .await?;
        assert!(output.headers.is_empty());
        assert!(output.query_params.is_empty());
        assert!(output.body_override.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn noop_is_erased_signing_port() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let signer: std::sync::Arc<dyn ErasedSigningPort> = std::sync::Arc::new(NoopSigningAdapter);
        let output = signer
            .erased_sign(SigningInput {
                method: "POST".to_string(),
                url: "https://api.example.com/data".to_string(),
                headers: HashMap::new(),
                body: Some(b"{\"key\":\"val\"}".to_vec()),
                context: json!({"session": "abc"}),
            })
            .await?;
        assert!(output.headers.is_empty());
        Ok(())
    }

    #[test]
    fn base64_roundtrip() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let input = b"Hello, Stygian signing!";
        let encoded = base64_encode(input);
        let decoded = base64_decode(&encoded)
            .map_err(|e| std::io::Error::other(format!("base64 decode failed: {e}")))?;
        assert_eq!(decoded, input);
        Ok(())
    }

    #[test]
    fn base64_encode_known_value() {
        assert_eq!(base64_encode(b"Man"), "TWFu");
        assert_eq!(base64_encode(b"Ma"), "TWE=");
        assert_eq!(base64_encode(b"M"), "TQ==");
    }

    /// RFC 4648 section 10 test vectors, checked in both directions.
    #[test]
    fn base64_rfc4648_vectors() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let vectors = [
            (b"".as_slice(), ""),
            (b"f".as_slice(), "Zg=="),
            (b"fo".as_slice(), "Zm8="),
            (b"foo".as_slice(), "Zm9v"),
            (b"foob".as_slice(), "Zm9vYg=="),
            (b"fooba".as_slice(), "Zm9vYmE="),
            (b"foobar".as_slice(), "Zm9vYmFy"),
        ];
        for (raw, encoded) in vectors {
            assert_eq!(base64_encode(raw), encoded);
            let decoded = base64_decode(encoded).map_err(std::io::Error::other)?;
            assert_eq!(decoded, raw);
        }
        Ok(())
    }

    /// Exercises the two non-alphanumeric alphabet symbols (indices 62 and 63).
    #[test]
    fn base64_uses_plus_and_slash() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let raw = [0xfb_u8, 0xff, 0xfe];
        let encoded = base64_encode(&raw);
        assert_eq!(encoded, "+//+");
        let decoded = base64_decode(&encoded).map_err(std::io::Error::other)?;
        assert_eq!(decoded, raw);
        Ok(())
    }

    #[test]
    fn base64_decode_accepts_unpadded_input() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let decoded = base64_decode("TWE").map_err(std::io::Error::other)?;
        assert_eq!(decoded, b"Ma");
        Ok(())
    }

    #[test]
    fn base64_decode_rejects_invalid_char() {
        assert!(base64_decode("TW@u").is_err());
    }

    /// Round-trips every prefix of 0..=255, covering all padding cases and all byte values.
    #[test]
    fn base64_roundtrip_every_length() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let data: Vec<u8> = (0..=u8::MAX).collect();
        for len in 0..=data.len() {
            let chunk: Vec<u8> = data.iter().copied().take(len).collect();
            let decoded = base64_decode(&base64_encode(&chunk)).map_err(std::io::Error::other)?;
            assert_eq!(decoded, chunk, "roundtrip failed for length {len}");
        }
        Ok(())
    }
}