use std::sync::Arc;
use std::time::Duration;
use osproxy_core::{Clock, IndexName, PartitionId, PrincipalId};
use osproxy_observe::{DiagLevel, DiagnosticsDirective, DirectiveMatch, DirectiveVerifier};
use serde_json::Value;
#[cfg(all(feature = "fips", feature = "non-fips"))]
compile_error!(
"features `fips` and `non-fips` are mutually exclusive; build with \
`--no-default-features --features fips` for a FIPS artifact"
);
#[cfg(not(any(feature = "fips", feature = "non-fips")))]
compile_error!("enable exactly one crypto provider feature: `fips` or `non-fips`");
#[cfg(feature = "fips")]
use aws_lc_rs::hmac;
#[cfg(feature = "non-fips")]
use ring::hmac;
/// A [`DirectiveVerifier`] that authenticates diagnostics-directive tokens with
/// HMAC-SHA256 and then decodes their JSON payload into a [`DiagnosticsDirective`].
pub struct HmacDirectiveVerifier {
    /// Time source used to reject expired tokens and to compute `expires_at`.
    clock: Arc<dyn Clock>,
    /// HMAC-SHA256 key built from the shared secret given to [`HmacDirectiveVerifier::new`].
    key: hmac::Key,
}
impl std::fmt::Debug for HmacDirectiveVerifier {
    /// Prints only the type name followed by `..`; no field (in particular the
    /// HMAC key) is ever written to the formatter.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut redacted = formatter.debug_struct("HmacDirectiveVerifier");
        redacted.finish_non_exhaustive()
    }
}
impl HmacDirectiveVerifier {
    /// Builds a verifier whose HMAC-SHA256 key is `secret`.
    ///
    /// `clock` supplies the current time, which is used to reject expired tokens
    /// and to compute each directive's `expires_at`.
    #[must_use]
    pub fn new(secret: &[u8], clock: Arc<dyn Clock>) -> Self {
        Self {
            key: hmac::Key::new(hmac::HMAC_SHA256, secret),
            clock,
        }
    }

    /// Decodes an already-authenticated JSON payload into a [`DiagnosticsDirective`].
    ///
    /// Recognised fields:
    /// - `level`: required string, parsed by [`parse_level`].
    /// - `exp`: required unsigned integer, absolute Unix time in seconds.
    /// - `tenant`, `index`, `principal`: optional strings that narrow the match.
    /// - `sample_per_mille`: optional integer in `0..=1000` (default `1000`).
    /// - `ring_buffer`, `capture`: optional booleans (default `false`).
    ///
    /// Returns `None` in any of these cases:
    /// - the payload is not JSON;
    /// - `level` or `exp` is missing or mistyped;
    /// - the token has expired (or expires this second);
    /// - `sample_per_mille` is not an integer in range;
    /// - a scope field is present with a value that is neither a string nor `null`.
    fn to_directive(&self, payload: &[u8]) -> Option<DiagnosticsDirective> {
        let v: Value = serde_json::from_slice(payload).ok()?;
        let level = parse_level(v.get("level")?.as_str()?)?;
        let exp = v.get("exp")?.as_u64()?;

        // `exp` is absolute Unix seconds: turn it into a remaining TTL, then add
        // that TTL to `clock.now()` to obtain `expires_at`.
        let now_secs = self.clock.unix_nanos() / 1_000_000_000;
        let remaining = exp.checked_sub(now_secs)?;
        if remaining == 0 {
            return None;
        }
        let expires_at = self
            .clock
            .now()
            .saturating_add(Duration::from_secs(remaining));

        // Fail closed on mistyped scope fields. Skipping them (as `as_str` alone
        // would) silently widens the directive to every tenant/index/principal.
        // An explicit `null` is treated like an omitted key, which is no wider.
        for key in ["tenant", "index", "principal"] {
            if v
                .get(key)
                .is_some_and(|field| !(field.is_string() || field.is_null()))
            {
                return None;
            }
        }

        let mut match_ = DirectiveMatch::all();
        if let Some(t) = v.get("tenant").and_then(Value::as_str) {
            match_ = match_.for_tenant(PartitionId::from(t));
        }
        if let Some(i) = v.get("index").and_then(Value::as_str) {
            match_ = match_.for_index(IndexName::from(i));
        }
        if let Some(p) = v.get("principal").and_then(Value::as_str) {
            match_ = match_.for_principal(PrincipalId::from(p));
        }

        // An absent rate means "sample everything"; a present but invalid one
        // rejects the token rather than guessing.
        let sample_per_mille = match v.get("sample_per_mille") {
            None => 1000,
            Some(n) => match n.as_u64() {
                Some(n) if n <= 1000 => u16::try_from(n).unwrap_or(1000),
                _ => return None,
            },
        };

        Some(DiagnosticsDirective {
            id: "x-debug-header".to_owned(),
            match_,
            level,
            sample_per_mille,
            expires_at,
            ring_buffer: v
                .get("ring_buffer")
                .and_then(Value::as_bool)
                .unwrap_or(false),
            capture: v.get("capture").and_then(Value::as_bool).unwrap_or(false),
        })
    }
}
impl DirectiveVerifier for HmacDirectiveVerifier {
    /// Authenticates a `<hex payload>.<hex tag>` header value and decodes it.
    ///
    /// The input is split at the first `.`. Both halves are hex-decoded, and the
    /// HMAC-SHA256 tag is checked with the crypto provider's `hmac::verify`. Only
    /// then is the payload parsed into a [`DiagnosticsDirective`].
    ///
    /// Returns `None` when any of those steps fails.
    fn verify(&self, header_value: &str) -> Option<DiagnosticsDirective> {
        let (payload_hex, sig_hex) = header_value.split_once('.')?;
        let (payload, sig) = (decode_hex(payload_hex)?, decode_hex(sig_hex)?);
        hmac::verify(&self.key, &payload, &sig)
            .ok()
            .and_then(|()| self.to_directive(&payload))
    }
}
/// Maps the `level` string from a directive payload to a [`DiagLevel`].
///
/// Thin delegation to [`DiagLevel::from_name`]. It returns exactly what that
/// function returns (presumably `None` for an unrecognised name — confirm
/// against `DiagLevel::from_name`).
pub(crate) fn parse_level(name: &str) -> Option<DiagLevel> {
    DiagLevel::from_name(name)
}
/// Decodes a hexadecimal string into bytes.
///
/// Digits may be upper- or lowercase. An empty string decodes to an empty vector.
///
/// Returns `None` when the input has odd length (in bytes) or contains any
/// character outside `[0-9a-fA-F]`.
fn decode_hex(s: &str) -> Option<Vec<u8>> {
    let pairs = s.as_bytes().chunks_exact(2);
    // A leftover byte means the input had odd length.
    if !pairs.remainder().is_empty() {
        return None;
    }
    let mut out = Vec::with_capacity(pairs.len());
    for pair in pairs {
        // Bytes >= 0x80 map to chars >= U+0080, which `to_digit(16)` rejects,
        // so non-ASCII input is refused.
        let hi = char::from(pair[0]).to_digit(16)?;
        let lo = char::from(pair[1]).to_digit(16)?;
        // Both digits are < 16, so the value is < 256 and this never fails.
        out.push(u8::try_from(hi * 16 + lo).ok()?);
    }
    Some(out)
}
/// Test-only lowercase hex encoder: two characters per byte, high nibble first.
#[cfg(test)]
#[must_use]
pub(crate) fn encode_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0f])
        .map(|nibble| char::from_digit(u32::from(nibble), 16).unwrap_or('0'))
        .collect()
}
/// Test-only helper that produces a header token in the `<hex payload>.<hex tag>`
/// format, where the tag is HMAC-SHA256 of `payload` keyed with `secret`.
#[cfg(test)]
#[must_use]
pub(crate) fn sign_token(secret: &[u8], payload: &[u8]) -> String {
    let signing_key = hmac::Key::new(hmac::HMAC_SHA256, secret);
    let signature = hmac::sign(&signing_key, payload);
    let mut token = encode_hex(payload);
    token.push('.');
    token.push_str(&encode_hex(signature.as_ref()));
    token
}
#[cfg(test)]
#[path = "directive_tests.rs"]
mod tests;