pub mod h2;
pub mod http;
pub mod sni;
pub mod spawn;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use cellos_core::{CdnProvider, CloudEventV1, ExecutionCellSpec};
use serde_json::{json, Map, Value};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
/// Size in bytes (16 KiB) of the buffer `handle_connection` allocates for its
/// initial `peek` of a freshly accepted connection.
pub const PEEK_BUF_LEN: usize = 16 * 1024;
/// Configuration for the L7 filtering proxy; cloned into every connection task.
#[derive(Debug, Clone)]
pub struct SniProxyConfig {
/// NOTE(review): not read by the code in this file section — presumably the
/// address the caller binds the listener to; confirm against the caller.
pub bind_addr: SocketAddr,
/// Address every allowed connection (or h2c stream) is forwarded to.
pub upstream_addr: SocketAddr,
/// Entries handed to `cellos_core::hostname_allowlist::matches_allowlist`
/// together with the extracted host name.
pub hostname_allowlist: Vec<String>,
/// NOTE(review): not read by the code in this file section — confirm usage.
pub cdn_providers: Vec<CdnProvider>,
/// Cell identifier; part of the event spec id and of each decision payload.
pub cell_id: String,
/// Run identifier; part of the event spec id and of each decision payload.
pub run_id: String,
/// Passed into each decision payload; `None` is sent as an empty string.
pub policy_digest: Option<String>,
/// Passed into each decision payload; `None` is sent as an empty string.
pub keyset_id: Option<String>,
/// Passed into each decision payload; `None` is sent as an empty string.
pub issuer_kid: Option<String>,
/// When set, copied into the event spec's `Correlation`.
pub correlation_id: Option<String>,
/// NOTE(review): not read by the code in this file section — confirm usage.
pub upstream_resolver_id: String,
/// Maximum time to wait for the first bytes of a new connection before it is
/// denied with the peek-timeout reason code.
pub peek_timeout: Duration,
}
/// Sink for the L7 egress decision events produced by the proxy.
///
/// Implementations are shared across connection tasks behind an `Arc`, hence
/// the `Send + Sync + 'static` bounds.
pub trait L7DecisionEmitter: Send + Sync + 'static {
/// Receives one fully built decision event; called synchronously from the
/// connection task that made the decision.
fn emit(&self, event: CloudEventV1);
}
/// Counter snapshot returned by `run_one_shot` when its accept loop exits.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ProxyStats {
/// Connections accepted by the listener.
pub connections_total: u64,
/// Allow decisions: one per TLS/HTTP1 connection, one per h2c stream.
pub connections_allowed: u64,
/// Deny decisions: one per TLS/HTTP1 connection, one per h2c stream.
pub connections_denied: u64,
/// Connections that sent no bytes within `peek_timeout` (also counted as denied).
pub peek_timeouts: u64,
/// Failed connects to the upstream address (also counted as denied).
pub upstream_failures: u64,
}
/// Reason-code strings placed in emitted decision events.
mod reason_code {
// TLS ClientHello SNI decisions.
pub const SNI_ALLOWLIST_MATCH: &str = "l7_sni_allowlist_match";
pub const SNI_ALLOWLIST_MISS: &str = "l7_sni_allowlist_miss";
pub const SNI_MISSING: &str = "l7_sni_missing";
// HTTP/1 Host header decisions.
pub const HTTP_HOST_ALLOWLIST_MATCH: &str = "l7_http_host_allowlist_match";
pub const HTTP_HOST_ALLOWLIST_MISS: &str = "l7_http_host_allowlist_miss";
pub const HTTP_HOST_MISSING: &str = "l7_http_host_missing";
// Protocol-independent denials.
pub const UNKNOWN_PROTOCOL: &str = "l7_unknown_protocol";
pub const PEEK_TIMEOUT: &str = "l7_peek_timeout";
// h2c `:authority` decisions.
pub const H2_AUTHORITY_ALLOWLIST_MATCH: &str = "l7_h2_authority_allowlist_match";
pub const H2_AUTHORITY_ALLOWLIST_MISS: &str = "l7_h2_authority_allowlist_miss";
pub const H2_AUTHORITY_MISSING: &str = "l7_h2_authority_missing";
pub const H2_UNPARSEABLE_HEADERS: &str = "l7_h2_unparseable_headers";
// h2c variants selected by `h2_reason_codes_for` for Huffman provenance.
pub const H2_AUTHORITY_ALLOWLIST_MATCH_HUFFMAN: &str =
"l7_h2_authority_allowlist_match_huffman";
pub const H2_AUTHORITY_ALLOWLIST_MISS_HUFFMAN: &str = "l7_h2_authority_allowlist_miss_huffman";
// h2c variants selected by `h2_reason_codes_for` for dynamic-indexed provenance.
pub const H2_AUTHORITY_ALLOWLIST_MATCH_DYNAMIC_INDEXED: &str =
"l7_h2_authority_allowlist_match_dynamic_indexed";
pub const H2_AUTHORITY_ALLOWLIST_MISS_DYNAMIC_INDEXED: &str =
"l7_h2_authority_allowlist_miss_dynamic_indexed";
}
fn h2_reason_codes_for(provenance: h2::AuthorityProvenance) -> (&'static str, &'static str) {
match provenance {
h2::AuthorityProvenance::StaticIndexed | h2::AuthorityProvenance::StaticLiteral => (
reason_code::H2_AUTHORITY_ALLOWLIST_MATCH,
reason_code::H2_AUTHORITY_ALLOWLIST_MISS,
),
h2::AuthorityProvenance::DynamicIndexed => (
reason_code::H2_AUTHORITY_ALLOWLIST_MATCH_DYNAMIC_INDEXED,
reason_code::H2_AUTHORITY_ALLOWLIST_MISS_DYNAMIC_INDEXED,
),
h2::AuthorityProvenance::Huffman => (
reason_code::H2_AUTHORITY_ALLOWLIST_MATCH_HUFFMAN,
reason_code::H2_AUTHORITY_ALLOWLIST_MISS_HUFFMAN,
),
}
}
/// Result of classifying the first bytes of a connection (see `guess_protocol`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProtocolGuess {
/// First byte is 22.
Tls,
/// Bytes are accepted by `h2::is_h2c_preface`.
H2c,
/// Bytes start with one of the recognised HTTP/1 method tokens.
Http1,
/// Empty input or anything not recognised above.
Unknown,
}
/// Classifies the first bytes of a connection.
///
/// Checks run in this order: empty input is `Unknown`, a leading byte of 22
/// is `Tls`, an h2c preface is `H2c`, a known HTTP method token followed by a
/// space is `Http1`; everything else is `Unknown`.
fn guess_protocol(buf: &[u8]) -> ProtocolGuess {
    const METHODS: &[&[u8]] = &[
        b"GET ",
        b"POST ",
        b"HEAD ",
        b"PUT ",
        b"DELETE ",
        b"OPTIONS ",
        b"PATCH ",
        b"CONNECT ",
    ];

    match buf.first() {
        None => return ProtocolGuess::Unknown,
        Some(&22) => return ProtocolGuess::Tls,
        Some(_) => {}
    }
    if h2::is_h2c_preface(buf) {
        return ProtocolGuess::H2c;
    }
    if METHODS.iter().any(|method| buf.starts_with(method)) {
        ProtocolGuess::Http1
    } else {
        ProtocolGuess::Unknown
    }
}
pub async fn run_one_shot(
cfg: &SniProxyConfig,
listener: TcpListener,
emitter: Arc<dyn L7DecisionEmitter>,
shutdown: Arc<AtomicBool>,
) -> std::io::Result<ProxyStats> {
let total = Arc::new(std::sync::atomic::AtomicU64::new(0));
let allowed = Arc::new(std::sync::atomic::AtomicU64::new(0));
let denied = Arc::new(std::sync::atomic::AtomicU64::new(0));
let timeouts = Arc::new(std::sync::atomic::AtomicU64::new(0));
let upstream_failures = Arc::new(std::sync::atomic::AtomicU64::new(0));
let event_spec = build_event_spec(cfg);
while !shutdown.load(Ordering::SeqCst) {
let (stream, peer) = match listener.accept().await {
Ok(t) => t,
Err(e) => {
if shutdown.load(Ordering::SeqCst) {
break;
}
tracing::warn!(
target: "cellos.supervisor.sni_proxy",
error = %e,
"accept() failed"
);
continue;
}
};
if shutdown.load(Ordering::SeqCst) {
drop(stream);
break;
}
total.fetch_add(1, Ordering::SeqCst);
let cfg = cfg.clone();
let emitter = emitter.clone();
let event_spec = event_spec.clone();
let allowed = allowed.clone();
let denied = denied.clone();
let timeouts = timeouts.clone();
let upstream_failures = upstream_failures.clone();
tokio::spawn(async move {
handle_connection(
stream,
peer,
cfg,
emitter,
event_spec,
allowed,
denied,
timeouts,
upstream_failures,
)
.await;
});
}
Ok(ProxyStats {
connections_total: total.load(Ordering::SeqCst),
connections_allowed: allowed.load(Ordering::SeqCst),
connections_denied: denied.load(Ordering::SeqCst),
peek_timeouts: timeouts.load(Ordering::SeqCst),
upstream_failures: upstream_failures.load(Ordering::SeqCst),
})
}
#[allow(clippy::too_many_arguments)]
async fn handle_connection(
mut stream: TcpStream,
_peer: SocketAddr,
cfg: SniProxyConfig,
emitter: Arc<dyn L7DecisionEmitter>,
event_spec: ExecutionCellSpec,
allowed: Arc<std::sync::atomic::AtomicU64>,
denied: Arc<std::sync::atomic::AtomicU64>,
timeouts: Arc<std::sync::atomic::AtomicU64>,
upstream_failures: Arc<std::sync::atomic::AtomicU64>,
) {
let mut buf = vec![0u8; PEEK_BUF_LEN];
let peek_result = tokio::time::timeout(cfg.peek_timeout, stream.peek(&mut buf)).await;
let n = match peek_result {
Err(_elapsed) => {
timeouts.fetch_add(1, Ordering::SeqCst);
denied.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
reason_code::PEEK_TIMEOUT,
None,
None,
);
return;
}
Ok(Err(e)) => {
tracing::debug!(
target: "cellos.supervisor.sni_proxy",
error = %e,
"peek() error before any bytes"
);
denied.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
reason_code::UNKNOWN_PROTOCOL,
None,
None,
);
return;
}
Ok(Ok(0)) => {
denied.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
reason_code::UNKNOWN_PROTOCOL,
None,
None,
);
return;
}
Ok(Ok(n)) => n,
};
let preamble = &buf[..n];
let guess = guess_protocol(preamble);
let (host_opt, allow_reason, miss_reason, missing_reason, deny_response): (
Option<String>,
&'static str,
&'static str,
&'static str,
DenyResponse,
) = match guess {
ProtocolGuess::Tls => match sni::extract_sni(preamble) {
Ok(opt) => (
opt,
reason_code::SNI_ALLOWLIST_MATCH,
reason_code::SNI_ALLOWLIST_MISS,
reason_code::SNI_MISSING,
DenyResponse::Drop,
),
Err(_e) => {
denied.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
reason_code::UNKNOWN_PROTOCOL,
None,
None,
);
return;
}
},
ProtocolGuess::H2c => {
handle_h2c_connection(
stream,
preamble.to_vec(),
cfg.clone(),
emitter.clone(),
event_spec.clone(),
allowed.clone(),
denied.clone(),
upstream_failures.clone(),
)
.await;
return;
}
ProtocolGuess::Http1 => match http::extract_http_host(preamble) {
Ok(opt) => (
opt,
reason_code::HTTP_HOST_ALLOWLIST_MATCH,
reason_code::HTTP_HOST_ALLOWLIST_MISS,
reason_code::HTTP_HOST_MISSING,
DenyResponse::Http403,
),
Err(_e) => {
denied.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
reason_code::UNKNOWN_PROTOCOL,
None,
None,
);
return;
}
},
ProtocolGuess::Unknown => {
denied.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
reason_code::UNKNOWN_PROTOCOL,
None,
None,
);
return;
}
};
let host = match host_opt {
Some(h) if !h.is_empty() => h,
_ => {
denied.fetch_add(1, Ordering::SeqCst);
send_deny_response(&mut stream, deny_response).await;
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
"",
missing_reason,
None,
None,
);
return;
}
};
if !cellos_core::hostname_allowlist::matches_allowlist(&host, &cfg.hostname_allowlist) {
denied.fetch_add(1, Ordering::SeqCst);
send_deny_response(&mut stream, deny_response).await;
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
host.as_str(),
miss_reason,
None,
None,
);
return;
}
let upstream = match TcpStream::connect(cfg.upstream_addr).await {
Ok(s) => s,
Err(e) => {
tracing::warn!(
target: "cellos.supervisor.sni_proxy",
error = %e,
upstream = %cfg.upstream_addr,
"upstream connect failed"
);
upstream_failures.fetch_add(1, Ordering::SeqCst);
denied.fetch_add(1, Ordering::SeqCst);
send_deny_response(&mut stream, deny_response).await;
emit_decision(
&emitter,
&cfg,
&event_spec,
"deny",
host.as_str(),
reason_code::UNKNOWN_PROTOCOL,
None,
None,
);
return;
}
};
allowed.fetch_add(1, Ordering::SeqCst);
emit_decision(
&emitter,
&cfg,
&event_spec,
"allow",
host.as_str(),
allow_reason,
None,
None,
);
let mut client = stream;
let mut up = upstream;
if let Err(e) = tokio::io::copy_bidirectional(&mut client, &mut up).await {
tracing::debug!(
target: "cellos.supervisor.sni_proxy",
error = %e,
host = %host,
"copy_bidirectional ended with error"
);
}
}
/// Upstream side of an h2c connection, connected lazily.
///
/// The sink starts out `Pending`: bytes written to it are only buffered. It
/// becomes `Open` when `commit` connects to `addr` and flushes the buffer.
enum UpstreamSink {
    /// No upstream connection yet; `buffer` holds everything written so far.
    Pending { addr: SocketAddr, buffer: Vec<u8> },
    /// Connected upstream, split into independently usable halves.
    Open {
        rd: tokio::net::tcp::OwnedReadHalf,
        wr: tokio::net::tcp::OwnedWriteHalf,
    },
}
impl UpstreamSink {
    /// Creates a sink that buffers writes until `commit` is called.
    fn pending(addr: SocketAddr) -> Self {
        let buffer = Vec::new();
        Self::Pending { addr, buffer }
    }

    /// Appends `bytes` to the buffer while pending; writes them straight to
    /// the upstream socket once open.
    async fn write(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        if let Self::Open { wr, .. } = self {
            return wr.write_all(bytes).await;
        }
        if let Self::Pending { buffer, .. } = self {
            buffer.extend_from_slice(bytes);
        }
        Ok(())
    }

    /// Connects to the upstream and flushes the buffered bytes. Does nothing
    /// when the sink is already open. On error the sink stays pending with
    /// its buffer intact.
    async fn commit(&mut self) -> std::io::Result<()> {
        let (addr, buffered) = match self {
            Self::Open { .. } => return Ok(()),
            Self::Pending { addr, buffer } => (*addr, buffer),
        };
        let (rd, mut wr) = TcpStream::connect(addr).await?.into_split();
        if !buffered.is_empty() {
            wr.write_all(buffered.as_slice()).await?;
        }
        *self = Self::Open { rd, wr };
        Ok(())
    }

    /// Shuts down the write half when open; the result is ignored. A pending
    /// sink has nothing to shut down.
    async fn shutdown(&mut self) {
        match self {
            Self::Open { wr, .. } => {
                let _ = wr.shutdown().await;
            }
            Self::Pending { .. } => {}
        }
    }
}
/// Reads from the upstream when the sink is open. While the sink is still
/// pending the returned future never resolves, so a `select!` over it simply
/// never picks this branch.
async fn read_upstream(sink: &mut UpstreamSink, buf: &mut [u8]) -> std::io::Result<usize> {
    use tokio::io::AsyncReadExt;
    if let UpstreamSink::Open { rd, .. } = sink {
        return rd.read(buf).await;
    }
    std::future::pending().await
}
/// Handles a cleartext HTTP/2 (h2c) connection, filtering per stream.
///
/// `preamble` is the byte run the caller peeked. It is first consumed from
/// the socket for real, then split into the client preface (buffered for the
/// upstream) and any frame bytes that followed it.
///
/// Frame handling:
/// * HEADERS / CONTINUATION frames are accumulated per stream and fed to the
///   header decoder. Once the decoder yields a decoded block, its authority is
///   checked against the allowlist. An allowed block triggers the deferred
///   upstream connect and is then forwarded; a denied or undecodable one is
///   answered with `RST_STREAM` and the stream is remembered as denied.
/// * Other frames on a denied stream are dropped.
/// * All remaining frames are forwarded, or buffered while no upstream
///   connection exists yet.
///
/// One decision event is emitted per decoded header block. `allowed` and
/// `denied` therefore count streams here, not connections.
// NOTE(review): slicing `preamble` by the preface length assumes
// `h2::is_h2c_preface` only accepts input at least that long — confirm.
// One argument per shared counter/handle; bundling them is left to a later refactor.
#[allow(clippy::too_many_arguments)]
async fn handle_h2c_connection(
    stream: TcpStream,
    preamble: Vec<u8>,
    cfg: SniProxyConfig,
    emitter: Arc<dyn L7DecisionEmitter>,
    event_spec: ExecutionCellSpec,
    allowed: Arc<std::sync::atomic::AtomicU64>,
    denied: Arc<std::sync::atomic::AtomicU64>,
    upstream_failures: Arc<std::sync::atomic::AtomicU64>,
) {
    use tokio::io::AsyncReadExt;

    let mut upstream = UpstreamSink::pending(cfg.upstream_addr);
    let mut client = stream;

    // The caller only peeked; consume the same number of bytes for real.
    let mut consumed = vec![0u8; preamble.len()];
    if let Err(e) = client.read_exact(&mut consumed).await {
        tracing::debug!(
            target: "cellos.supervisor.sni_proxy",
            error = %e,
            "h2c read_exact(preamble) failed"
        );
        return;
    }
    debug_assert_eq!(
        consumed, preamble,
        "kernel returned different bytes than peek"
    );

    let (mut client_rd, mut client_wr) = client.into_split();
    if let Err(e) = upstream.write(&preamble[..h2::HTTP2_PREFACE.len()]).await {
        tracing::debug!(
            target: "cellos.supervisor.sni_proxy",
            error = %e,
            "h2c buffer(preface) failed"
        );
        return;
    }

    let mut decoder = h2::H2ConnectionDecoder::new();
    let mut denied_streams: std::collections::HashSet<u32> = std::collections::HashSet::new();
    // Raw HEADERS/CONTINUATION frames per stream, held back until a verdict.
    let mut pending_headers: std::collections::HashMap<u32, Vec<u8>> =
        std::collections::HashMap::new();
    // Unparsed client bytes; starts with whatever followed the preface.
    let mut frame_buf: Vec<u8> = preamble[h2::HTTP2_PREFACE.len()..].to_vec();
    let mut read_chunk = vec![0u8; 16 * 1024];
    let mut up_read_chunk = vec![0u8; 16 * 1024];

    'outer: loop {
        // Drain every complete frame currently buffered.
        loop {
            match h2::frame::parse_one_frame(&frame_buf) {
                Ok(Some((header, payload, _rest))) => {
                    // 9 bytes of frame header plus the payload.
                    let frame_total = 9 + header.length as usize;
                    let frame_bytes = frame_buf[..frame_total].to_vec();
                    let payload_owned = payload.to_vec();
                    let frame_type = header.frame_type;
                    let stream_id = header.stream_id;
                    let is_stream_bound = stream_id != 0;
                    let is_headers_or_continuation = frame_type == h2::frame::FRAME_TYPE_HEADERS
                        || frame_type == h2::frame::FRAME_TYPE_CONTINUATION;

                    if is_headers_or_continuation {
                        pending_headers
                            .entry(stream_id)
                            .or_default()
                            .extend_from_slice(&frame_bytes);
                        match decoder.feed_frame(&header, &payload_owned) {
                            Ok(Some(decoded)) => {
                                let sid = decoded.stream_id;
                                let host_norm = decoded.authority.clone();
                                let provenance = if decoded.via_huffman {
                                    h2::AuthorityProvenance::Huffman
                                } else if decoded.via_dynamic_table {
                                    h2::AuthorityProvenance::DynamicIndexed
                                } else {
                                    h2::AuthorityProvenance::StaticLiteral
                                };
                                let (allow_r, miss_r) = h2_reason_codes_for(provenance);
                                let allow = match host_norm.as_deref() {
                                    Some(h) if !h.is_empty() => {
                                        cellos_core::hostname_allowlist::matches_allowlist(
                                            h,
                                            &cfg.hostname_allowlist,
                                        )
                                    }
                                    _ => false,
                                };
                                let pending = pending_headers.remove(&sid).unwrap_or_default();

                                if host_norm.is_none() {
                                    denied.fetch_add(1, Ordering::SeqCst);
                                    denied_streams.insert(sid);
                                    let rst = build_rst_stream_refused(sid);
                                    let _ = client_wr.write_all(&rst).await;
                                    emit_decision(
                                        &emitter,
                                        &cfg,
                                        &event_spec,
                                        "deny",
                                        "",
                                        reason_code::H2_AUTHORITY_MISSING,
                                        None,
                                        Some(sid),
                                    );
                                } else if allow {
                                    if let Err(e) = upstream.commit().await {
                                        tracing::warn!(
                                            target: "cellos.supervisor.sni_proxy",
                                            error = %e,
                                            upstream = %cfg.upstream_addr,
                                            "h2c upstream connect failed (deferred)"
                                        );
                                        upstream_failures.fetch_add(1, Ordering::SeqCst);
                                        denied.fetch_add(1, Ordering::SeqCst);
                                        let _ = client_wr.write_all(H2_GOAWAY_PROTOCOL_ERROR).await;
                                        let _ = client_wr.shutdown().await;
                                        emit_decision(
                                            &emitter,
                                            &cfg,
                                            &event_spec,
                                            "deny",
                                            "",
                                            reason_code::UNKNOWN_PROTOCOL,
                                            None,
                                            None,
                                        );
                                        break 'outer;
                                    }
                                    // Count the stream as allowed only once the
                                    // upstream is usable; a failed connect above
                                    // is a deny and must not also be counted here.
                                    allowed.fetch_add(1, Ordering::SeqCst);
                                    if let Err(e) = upstream.write(&pending).await {
                                        tracing::debug!(
                                            target: "cellos.supervisor.sni_proxy",
                                            error = %e,
                                            "h2c forward(allowed HEADERS block) failed"
                                        );
                                        break 'outer;
                                    }
                                    emit_decision(
                                        &emitter,
                                        &cfg,
                                        &event_spec,
                                        "allow",
                                        host_norm.as_deref().unwrap_or(""),
                                        allow_r,
                                        None,
                                        Some(sid),
                                    );
                                } else {
                                    denied.fetch_add(1, Ordering::SeqCst);
                                    denied_streams.insert(sid);
                                    let rst = build_rst_stream_refused(sid);
                                    let _ = client_wr.write_all(&rst).await;
                                    emit_decision(
                                        &emitter,
                                        &cfg,
                                        &event_spec,
                                        "deny",
                                        host_norm.as_deref().unwrap_or(""),
                                        miss_r,
                                        None,
                                        Some(sid),
                                    );
                                }
                            }
                            // No decoded block yet; keep the frames buffered.
                            Ok(None) => {}
                            Err(_e) => {
                                denied.fetch_add(1, Ordering::SeqCst);
                                denied_streams.insert(stream_id);
                                pending_headers.remove(&stream_id);
                                let rst = build_rst_stream_refused(stream_id);
                                let _ = client_wr.write_all(&rst).await;
                                emit_decision(
                                    &emitter,
                                    &cfg,
                                    &event_spec,
                                    "deny",
                                    "",
                                    reason_code::H2_UNPARSEABLE_HEADERS,
                                    None,
                                    Some(stream_id),
                                );
                            }
                        }
                    } else if is_stream_bound && denied_streams.contains(&stream_id) {
                        // Frame belongs to a denied stream: drop it.
                    } else if let Err(e) = upstream.write(&frame_bytes).await {
                        tracing::debug!(
                            target: "cellos.supervisor.sni_proxy",
                            error = %e,
                            "h2c forward(frame) failed"
                        );
                        break 'outer;
                    }
                    frame_buf.drain(..frame_total);
                    continue;
                }
                // Incomplete frame: go read more bytes.
                Ok(None) => break,
                Err(_e) => {
                    let _ = client_wr.write_all(H2_GOAWAY_PROTOCOL_ERROR).await;
                    upstream.shutdown().await;
                    break 'outer;
                }
            }
        }

        tokio::select! {
            biased;
            r = client_rd.read(&mut read_chunk) => match r {
                Ok(0) => break 'outer,
                Ok(n) => frame_buf.extend_from_slice(&read_chunk[..n]),
                Err(e) => {
                    tracing::debug!(
                        target: "cellos.supervisor.sni_proxy",
                        error = %e,
                        "h2c client read error"
                    );
                    break 'outer;
                }
            },
            // Never ready while the upstream is still pending.
            r = read_upstream(&mut upstream, &mut up_read_chunk) => match r {
                Ok(0) => break 'outer,
                Ok(n) => {
                    if let Err(e) = client_wr.write_all(&up_read_chunk[..n]).await {
                        tracing::debug!(
                            target: "cellos.supervisor.sni_proxy",
                            error = %e,
                            "h2c upstream→client write failed"
                        );
                        break 'outer;
                    }
                }
                Err(e) => {
                    tracing::debug!(
                        target: "cellos.supervisor.sni_proxy",
                        error = %e,
                        "h2c upstream read error"
                    );
                    break 'outer;
                }
            },
        }
    }

    upstream.shutdown().await;
    let _ = client_wr.shutdown().await;
}
/// Builds a 13-byte HTTP/2 `RST_STREAM` frame for `stream_id` with error code
/// 7 (`REFUSED_STREAM`). The top bit of `stream_id` is cleared.
fn build_rst_stream_refused(stream_id: u32) -> [u8; 13] {
    // 24-bit payload length (4), frame type 0x03, flags 0x00.
    const FRAME_PREFIX: [u8; 5] = [0x00, 0x00, 0x04, 0x03, 0x00];
    const ERROR_CODE: u32 = 7;

    let masked_id = stream_id & 0x7FFF_FFFF;
    let mut frame = [0u8; 13];
    frame[..5].copy_from_slice(&FRAME_PREFIX);
    frame[5..9].copy_from_slice(&masked_id.to_be_bytes());
    frame[9..].copy_from_slice(&ERROR_CODE.to_be_bytes());
    frame
}
/// What `send_deny_response` writes to the client before a denied connection ends.
#[derive(Debug, Clone, Copy)]
enum DenyResponse {
/// Write nothing.
Drop,
/// Write a bodiless `HTTP/1.1 403 Forbidden` response, then shut down the stream.
Http403,
/// Write `H2_GOAWAY_PROTOCOL_ERROR`, then shut down the stream. Not
/// constructed anywhere in the code visible here, hence the lint allowance.
#[allow(dead_code)]
H2Goaway,
}
/// Pre-encoded 17-byte HTTP/2 `GOAWAY` frame: a 9-byte frame header (payload
/// length 8, type 0x07, flags 0, stream id 0) followed by an 8-byte payload.
// NOTE(review): read against the GOAWAY layout in RFC 9113 §6.8, the payload
// bytes decode as last-stream-id = 1 and error code = 1 (PROTOCOL_ERROR);
// confirm that last-stream-id 1 (rather than 0) is intended.
const H2_GOAWAY_PROTOCOL_ERROR: &[u8; 17] = &[
0x00, 0x00, 0x08, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, ];
/// Writes the protocol-appropriate refusal for `mode` to the client,
/// best-effort: write failures are only logged at debug level and shutdown
/// errors are ignored. `DenyResponse::Drop` touches the stream not at all.
async fn send_deny_response(stream: &mut TcpStream, mode: DenyResponse) {
    const FORBIDDEN: &[u8] =
        b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";

    match mode {
        DenyResponse::Drop => return,
        DenyResponse::Http403 => {
            if let Err(e) = stream.write_all(FORBIDDEN).await {
                tracing::debug!(
                    target: "cellos.supervisor.sni_proxy",
                    error = %e,
                    "writing 403 response failed"
                );
            }
        }
        DenyResponse::H2Goaway => {
            if let Err(e) = stream.write_all(H2_GOAWAY_PROTOCOL_ERROR).await {
                tracing::debug!(
                    target: "cellos.supervisor.sni_proxy",
                    error = %e,
                    "writing h2 GOAWAY failed"
                );
            }
        }
    }
    // Both writing modes finish by shutting down the stream.
    let _ = stream.shutdown().await;
}
/// Builds the minimal `ExecutionCellSpec` attached to every decision event:
/// an id derived from the cell and run ids, the optional correlation id, a
/// default authority bundle, a zero TTL, and every other section unset.
fn build_event_spec(cfg: &SniProxyConfig) -> ExecutionCellSpec {
    use cellos_core::{AuthorityBundle, Correlation, Lifetime};

    let id = format!("sni-proxy/{}/{}", cfg.cell_id, cfg.run_id);
    let correlation = cfg.correlation_id.clone().map(|correlation_id| Correlation {
        correlation_id: Some(correlation_id),
        ..Default::default()
    });
    let lifetime = Lifetime { ttl_seconds: 0 };

    ExecutionCellSpec {
        id,
        correlation,
        ingress: None,
        environment: None,
        placement: None,
        policy: None,
        identity: None,
        run: None,
        authority: AuthorityBundle::default(),
        lifetime,
        export: None,
        telemetry: None,
    }
}
/// Builds one L7 egress decision CloudEvent and hands it to `emitter`.
///
/// An empty `sni_host` is reported as `"(unknown)"`. Unset policy digest,
/// keyset id and issuer kid are reported as empty strings. When the payload
/// builder fails, a warning is logged and nothing is emitted.
#[allow(clippy::too_many_arguments)]
fn emit_decision(
    emitter: &Arc<dyn L7DecisionEmitter>,
    cfg: &SniProxyConfig,
    spec: &ExecutionCellSpec,
    action: &str,
    sni_host: &str,
    reason_code_str: &str,
    rule_ref: Option<&str>,
    stream_id: Option<u32>,
) {
    let decision_id = uuid::Uuid::new_v4().to_string();
    let host_label = if sni_host.is_empty() {
        "(unknown)"
    } else {
        sni_host
    };

    let built = cellos_core::observability_l7_egress_decision_data_v1(
        spec,
        &cfg.cell_id,
        Some(cfg.run_id.as_str()),
        &decision_id,
        action,
        host_label,
        cfg.policy_digest.as_deref().unwrap_or(""),
        cfg.keyset_id.as_deref().unwrap_or(""),
        cfg.issuer_kid.as_deref().unwrap_or(""),
        reason_code_str,
        rule_ref,
        stream_id,
    );
    let data = match built {
        Ok(payload) => payload,
        Err(e) => {
            tracing::warn!(
                target: "cellos.supervisor.sni_proxy",
                error = %e,
                "build l7_egress_decision data failed"
            );
            return;
        }
    };

    let observed_at = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
    emitter.emit(CloudEventV1 {
        specversion: "1.0".into(),
        id: uuid::Uuid::new_v4().to_string(),
        source: "cellos-sni-proxy".into(),
        ty: "dev.cellos.events.cell.observability.v1.l7_egress_decision".into(),
        datacontenttype: Some("application/json".into()),
        data: Some(data),
        time: Some(observed_at),
        traceparent: None,
    });
}
/// Looks up `key` in an event's JSON-object payload. Returns `None` when the
/// event has no payload, the payload is not an object, or the key is absent.
#[allow(dead_code)]
fn event_data_get<'a>(event: &'a CloudEventV1, key: &str) -> Option<&'a Value> {
    let payload = event.data.as_ref()?;
    payload.as_object().and_then(|fields| fields.get(key))
}
#[allow(dead_code)]
fn empty_data_map() -> Map<String, Value> {
let mut m = Map::new();
m.insert("decisionId".into(), json!(uuid::Uuid::new_v4().to_string()));
m
}
#[cfg(test)]
pub(crate) mod test_helpers {
    /// Builds a minimal TLS record carrying a ClientHello handshake message.
    ///
    /// The hello has an all-zero random, an empty session id, one cipher suite
    /// (0x1301) and null compression. When `snis` is non-empty, a single
    /// `server_name` extension lists every entry as a host_name; otherwise the
    /// extensions block is empty.
    pub fn build_client_hello(snis: &[&str]) -> Vec<u8> {
        // server_name extension (type 0x0000), only when names were given.
        let mut extensions: Vec<u8> = Vec::new();
        if !snis.is_empty() {
            let mut name_list: Vec<u8> = Vec::new();
            for name in snis {
                name_list.push(0u8);
                name_list.extend_from_slice(&(name.len() as u16).to_be_bytes());
                name_list.extend_from_slice(name.as_bytes());
            }
            // Extension data = 2-byte list length + the list itself.
            let ext_data_len = (name_list.len() + 2) as u16;
            extensions.extend_from_slice(&[0x00, 0x00]);
            extensions.extend_from_slice(&ext_data_len.to_be_bytes());
            extensions.extend_from_slice(&(name_list.len() as u16).to_be_bytes());
            extensions.extend_from_slice(&name_list);
        }

        // ClientHello body.
        let mut hello: Vec<u8> = vec![0x03, 0x03];
        hello.extend_from_slice(&[0u8; 32]);
        hello.push(0);
        hello.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
        hello.extend_from_slice(&[0x01, 0x00]);
        hello.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        hello.extend_from_slice(&extensions);

        // Handshake header: type 1 + 24-bit length.
        let hello_len = (hello.len() as u32).to_be_bytes();
        let mut handshake: Vec<u8> = vec![1];
        handshake.extend_from_slice(&hello_len[1..]);
        handshake.extend_from_slice(&hello);

        // Record header: content type 22, version bytes 03 01, 16-bit length.
        let mut record: Vec<u8> = vec![22, 0x03, 0x01];
        record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
        record.extend_from_slice(&handshake);
        record
    }
}
#[cfg(test)]
mod tests {
use super::test_helpers::build_client_hello;
use super::*;
use std::sync::Mutex;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
/// Test emitter that records every event in memory for later assertions.
#[derive(Default)]
struct MemEmitter {
// Events in emission order.
events: Mutex<Vec<CloudEventV1>>,
}
impl L7DecisionEmitter for MemEmitter {
// Appends the event; panics if the mutex was poisoned by an earlier panic.
fn emit(&self, event: CloudEventV1) {
self.events.lock().unwrap().push(event);
}
}
/// Test configuration with fixed identifiers; only the allowlist, upstream
/// address and peek timeout (in milliseconds) vary between tests.
fn cfg_with(allowlist: &[&str], upstream: SocketAddr, peek_ms: u64) -> SniProxyConfig {
    let hostname_allowlist: Vec<String> =
        allowlist.iter().map(|entry| (*entry).to_owned()).collect();
    let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();

    SniProxyConfig {
        bind_addr,
        upstream_addr: upstream,
        hostname_allowlist,
        cdn_providers: Vec::new(),
        cell_id: String::from("test-cell"),
        run_id: String::from("test-run"),
        policy_digest: Some(String::from("digest-test")),
        keyset_id: Some(String::from("keyset-test")),
        issuer_kid: Some(String::from("kid-test")),
        correlation_id: None,
        upstream_resolver_id: String::from("sni-proxy-test"),
        peek_timeout: Duration::from_millis(peek_ms),
    }
}
/// Starts a loopback listener that accepts one connection and collects what
/// it receives. Despite the name nothing is written back: the task returns
/// the captured bytes after EOF, a read error, or 32 reads, whichever comes
/// first.
async fn spawn_echo_upstream() -> (SocketAddr, tokio::task::JoinHandle<Vec<u8>>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let capture = async move {
        let (mut conn, _) = listener.accept().await.unwrap();
        let mut received = Vec::new();
        let mut chunk = [0u8; 4096];
        let mut reads_left = 32;
        while reads_left > 0 {
            reads_left -= 1;
            match conn.read(&mut chunk).await {
                Ok(n) if n > 0 => received.extend_from_slice(&chunk[..n]),
                _ => break,
            }
        }
        received
    };

    (addr, tokio::spawn(capture))
}
/// Runs `run_one_shot` on a fresh loopback listener in a background task.
/// Returns the listen address, the shutdown flag and the task handle.
async fn spawn_proxy(
    cfg: SniProxyConfig,
    emitter: Arc<MemEmitter>,
) -> (
    SocketAddr,
    Arc<AtomicBool>,
    tokio::task::JoinHandle<std::io::Result<ProxyStats>>,
) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    let shutdown = Arc::new(AtomicBool::new(false));
    let flag_for_task = Arc::clone(&shutdown);
    let emitter: Arc<dyn L7DecisionEmitter> = emitter;

    let task = tokio::spawn(async move {
        run_one_shot(&cfg, listener, emitter, flag_for_task).await
    });
    (addr, shutdown, task)
}
/// Opens (and immediately drops) a throwaway connection so a proxy blocked in
/// `accept()` wakes up and re-checks its shutdown flag. Connect errors are
/// ignored.
fn poke_shutdown(addr: SocketAddr) {
    let wait = Duration::from_millis(200);
    drop(std::net::TcpStream::connect_timeout(&addr, wait));
}
/// A ClientHello whose SNI is on the allowlist is forwarded to the upstream
/// byte for byte and yields exactly one `allow` event.
#[tokio::test]
async fn proxy_allows_tls_with_matching_sni() {
let (upstream, upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 500);
let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
let ch = build_client_hello(&["api.example.com"]);
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(&ch).await.unwrap();
// Half-close so the upstream capture task sees EOF and finishes.
s.shutdown().await.ok();
let upstream_bytes = tokio::time::timeout(Duration::from_secs(2), upstream_h)
.await
.expect("upstream task")
.expect("upstream join");
assert_eq!(
upstream_bytes, ch,
"upstream did not receive forwarded ClientHello bytes verbatim"
);
// Set the flag, then poke the listener so the accept loop notices it.
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 1, "expected exactly one event");
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["action"], "allow");
assert_eq!(data["reasonCode"], "l7_sni_allowlist_match");
assert_eq!(data["sniHost"], "api.example.com");
}
/// A ClientHello whose SNI is not on the allowlist gets its connection closed
/// promptly and yields exactly one `deny` event naming the offending host.
#[tokio::test]
async fn proxy_denies_tls_with_unmatched_sni() {
let (upstream, _upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 500);
let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
let ch = build_client_hello(&["evil.example.com"]);
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(&ch).await.unwrap();
let mut sink = Vec::new();
// The read must finish (connection closed) well within 800 ms.
let read_timeout =
tokio::time::timeout(Duration::from_millis(800), s.read_to_end(&mut sink)).await;
assert!(
read_timeout.is_ok(),
"TLS deny should close the stream promptly; got peek timeout"
);
// Set the flag, then poke the listener so the accept loop notices it.
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 1);
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["action"], "deny");
assert_eq!(data["reasonCode"], "l7_sni_allowlist_miss");
assert_eq!(data["sniHost"], "evil.example.com");
}
/// A ClientHello without any server name yields exactly one `deny` event with
/// the SNI-missing reason code.
#[tokio::test]
async fn proxy_denies_tls_without_sni() {
let (upstream, _upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 500);
let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
// An empty name list produces a hello with no server_name extension.
let ch = build_client_hello(&[]); let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(&ch).await.unwrap();
let mut sink = Vec::new();
// Wait (bounded) for the proxy to close the connection; the result is not asserted.
let _ = tokio::time::timeout(Duration::from_millis(800), s.read_to_end(&mut sink)).await;
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 1);
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["action"], "deny");
assert_eq!(data["reasonCode"], "l7_sni_missing");
}
/// An HTTP/1.1 request whose Host header is on the allowlist is forwarded to
/// the upstream unchanged and yields exactly one `allow` event.
#[tokio::test]
async fn proxy_allows_http_with_matching_host() {
let (upstream, upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 500);
let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
let req = b"GET / HTTP/1.1\r\nHost: api.example.com\r\nConnection: close\r\n\r\n";
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(req).await.unwrap();
// Half-close so the upstream capture task sees EOF and finishes.
s.shutdown().await.ok();
let upstream_bytes = tokio::time::timeout(Duration::from_secs(2), upstream_h)
.await
.expect("upstream task")
.expect("upstream join");
assert_eq!(upstream_bytes, req.to_vec());
// Set the flag, then poke the listener so the accept loop notices it.
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 1);
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["action"], "allow");
assert_eq!(data["reasonCode"], "l7_http_host_allowlist_match");
assert_eq!(data["sniHost"], "api.example.com");
}
/// An HTTP/1.1 request whose Host header is not on the allowlist receives a
/// `403 Forbidden` response and yields exactly one `deny` event.
#[tokio::test]
async fn proxy_denies_http_with_unmatched_host_and_returns_403() {
let (upstream, _upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 500);
let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
let req = b"GET / HTTP/1.1\r\nHost: evil.example.com\r\nConnection: close\r\n\r\n";
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(req).await.unwrap();
let mut response = Vec::new();
let read_result =
tokio::time::timeout(Duration::from_secs(2), s.read_to_end(&mut response))
.await
.expect("read deadline");
// A connection reset is tolerated; whatever was read so far is still checked below.
match read_result {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => {}
Err(e) => panic!("unexpected read error: {e}"),
}
let resp_str = String::from_utf8_lossy(&response);
assert!(
resp_str.starts_with("HTTP/1.1 403 Forbidden\r\n"),
"expected 403 response, got: {resp_str:?}"
);
// Set the flag, then poke the listener so the accept loop notices it.
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
assert_eq!(evs.len(), 1);
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["action"], "deny");
assert_eq!(data["reasonCode"], "l7_http_host_allowlist_miss");
assert_eq!(data["sniHost"], "evil.example.com");
}
/// Three connections (allowed TLS, denied TLS, denied HTTP) must produce
/// exactly three events in total.
#[tokio::test]
async fn proxy_emits_one_event_per_connection() {
let (upstream, _upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 500);
let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
// Connection 1: TLS with an allowed SNI.
for ch in [build_client_hello(&["api.example.com"])] {
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(&ch).await.unwrap();
s.shutdown().await.ok();
// Give the proxy task time to process the connection.
tokio::time::sleep(Duration::from_millis(80)).await;
}
// Connection 2: TLS with a denied SNI.
for ch in [build_client_hello(&["evil.example.com"])] {
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(&ch).await.unwrap();
let mut sink = Vec::new();
let _ =
tokio::time::timeout(Duration::from_millis(400), s.read_to_end(&mut sink)).await;
}
// Connection 3: HTTP/1.1 with a denied Host.
let req = b"GET / HTTP/1.1\r\nHost: blocked.example.com\r\n\r\n";
let mut s = TcpStream::connect(listen).await.unwrap();
s.write_all(req).await.unwrap();
let mut sink = Vec::new();
let _ = tokio::time::timeout(Duration::from_millis(400), s.read_to_end(&mut sink)).await;
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
assert_eq!(
evs.len(),
3,
"expected exactly one event per connection, got {}: {:#?}",
evs.len(),
evs.iter().map(|e| &e.data).collect::<Vec<_>>()
);
}
/// A client that connects but sends nothing is denied with the peek-timeout
/// reason code once the (50 ms) peek timeout elapses.
#[tokio::test]
async fn proxy_returns_peek_timeout_when_client_silent() {
let (upstream, _upstream_h) = spawn_echo_upstream().await;
let emitter = Arc::new(MemEmitter::default());
let cfg = cfg_with(&["api.example.com"], upstream, 50); let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;
let s = TcpStream::connect(listen).await.unwrap();
// Stay silent for longer than the peek timeout.
tokio::time::sleep(Duration::from_millis(250)).await;
drop(s);
shutdown.store(true, Ordering::SeqCst);
poke_shutdown(listen);
let _ = tokio::time::timeout(Duration::from_secs(2), h).await;
let evs = emitter.events.lock().unwrap();
// Only the first recorded event is inspected.
assert!(!evs.is_empty(), "expected at least one peek_timeout event");
let data = evs[0].data.as_ref().unwrap();
assert_eq!(data["action"], "deny");
assert_eq!(data["reasonCode"], "l7_peek_timeout");
}
/// Spot-checks `guess_protocol` for each outcome: TLS, HTTP/1, unknown
/// (garbage and empty input) and h2c.
#[test]
fn guess_protocol_classifies_correctly() {
assert_eq!(guess_protocol(&[22, 0x03, 0x03]), ProtocolGuess::Tls);
assert_eq!(guess_protocol(b"GET / HTTP/1.1\r\n"), ProtocolGuess::Http1);
assert_eq!(guess_protocol(b"POST / HTTP/1.1"), ProtocolGuess::Http1);
assert_eq!(guess_protocol(b"\x00\x00\x00\x00"), ProtocolGuess::Unknown);
assert_eq!(guess_protocol(b""), ProtocolGuess::Unknown);
assert_eq!(guess_protocol(h2::HTTP2_PREFACE), ProtocolGuess::H2c);
}
/// h2c client bytes: the connection preface followed by the frames produced
/// by `settings_then_headers` for a header block holding one literal field
/// (indexed name 1) with value `authority`.
fn build_h2c_stream(authority: &str) -> Vec<u8> {
    let header_block = h2::test_helpers::hpack_literal_indexed_name(1, authority);
    let frames = h2::test_helpers::settings_then_headers(&header_block);
    let mut bytes = h2::HTTP2_PREFACE.to_vec();
    bytes.extend_from_slice(&frames);
    bytes
}
/// h2c client bytes whose header block is the single indexed field 2, i.e. it
/// carries no literal authority value.
fn build_h2c_stream_no_authority() -> Vec<u8> {
    let header_block = h2::test_helpers::hpack_indexed(2);
    let frames = h2::test_helpers::settings_then_headers(&header_block);
    let mut bytes = h2::HTTP2_PREFACE.to_vec();
    bytes.extend_from_slice(&frames);
    bytes
}
/// h2c client bytes whose header block is cut in half: the first half goes
/// through `continuation_fragmented_headers`, the second through
/// `continuation_frame(.., true)`, after an empty SETTINGS frame.
fn build_h2c_stream_continuation_reassemblable(authority: &str) -> Vec<u8> {
    let header_block = h2::test_helpers::hpack_literal_indexed_name(1, authority);
    let (first_half, second_half) = header_block.split_at(header_block.len() / 2);

    let mut bytes = h2::HTTP2_PREFACE.to_vec();
    bytes.extend_from_slice(&h2::test_helpers::empty_settings_frame());
    bytes.extend_from_slice(&h2::test_helpers::continuation_fragmented_headers(first_half));
    bytes.extend_from_slice(&h2::test_helpers::continuation_frame(second_half, true));
    bytes
}
/// h2c client bytes whose header block is the single indexed field 200.
// NOTE(review): presumably an index the decoder cannot resolve, judging by
// the function name — confirm against the decoder.
fn build_h2c_stream_unparseable() -> Vec<u8> {
    let header_block = h2::test_helpers::hpack_indexed(200);
    let frames = h2::test_helpers::settings_then_headers(&header_block);
    let mut bytes = h2::HTTP2_PREFACE.to_vec();
    bytes.extend_from_slice(&frames);
    bytes
}
/// Same as the plain h2c stream builder, but the header block comes from
/// `hpack_literal_indexed_name_huffman`.
fn build_h2c_stream_huffman(authority: &str) -> Vec<u8> {
    let header_block = h2::test_helpers::hpack_literal_indexed_name_huffman(1, authority);
    let frames = h2::test_helpers::settings_then_headers(&header_block);
    let mut bytes = h2::HTTP2_PREFACE.to_vec();
    bytes.extend_from_slice(&frames);
    bytes
}
/// h2c client bytes whose header block holds, in order: a literal field with
/// indexed name 1 and an empty value, a literal field with indexed name 6 and
/// value `authority`, then the bytes `0x7F 0x00` followed by `authority`
/// encoded as a literal string.
// NOTE(review): the `0x7F 0x00` prefix presumably encodes a literal field
// whose name index points into the dynamic table — confirm against the HPACK
// integer encoding and the decoder's table layout.
fn build_h2c_stream_dynamic_indexed_name(authority: &str) -> Vec<u8> {
    let mut header_block: Vec<u8> = Vec::new();
    header_block.extend_from_slice(&h2::test_helpers::hpack_literal_indexed_name(1, ""));
    header_block.extend_from_slice(&h2::test_helpers::hpack_literal_indexed_name(6, authority));
    header_block.extend_from_slice(&[0x40 | 0x3F, 0x00]);
    h2::test_helpers::encode_literal_string(&mut header_block, authority);

    let frames = h2::test_helpers::settings_then_headers(&header_block);
    let mut bytes = h2::HTTP2_PREFACE.to_vec();
    bytes.extend_from_slice(&frames);
    bytes
}
/// An h2c stream naming `api.example.com`, which is in the configured host
/// list, must reach the upstream byte-for-byte and yield exactly one `allow`
/// event with reason `l7_h2_authority_allowlist_match`.
#[tokio::test]
async fn proxy_allows_h2c_with_matching_authority() {
    let (upstream_addr, upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream("api.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    // Close the write side before waiting on the upstream task's result.
    client.shutdown().await.ok();

    let forwarded = tokio::time::timeout(Duration::from_secs(2), upstream_task)
        .await
        .expect("upstream task")
        .expect("upstream join");
    assert_eq!(
        forwarded, request,
        "upstream did not receive the forwarded h2c bytes verbatim"
    );

    // Stop the proxy and give its task up to two seconds to finish.
    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1, "expected exactly one event");
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "allow");
    assert_eq!(payload["reasonCode"], "l7_h2_authority_allowlist_match");
    assert_eq!(payload["sniHost"], "api.example.com");
}
/// An h2c stream naming `evil.example.com`, which is not in the configured
/// host list, must yield exactly one `deny` event with reason
/// `l7_h2_authority_allowlist_miss` and the offending host in `sniHost`.
#[tokio::test]
async fn proxy_denies_h2c_with_unmatched_authority() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream("evil.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();

    // Drain whatever the proxy sends back for at most 800 ms; the bytes and the
    // outcome of the read are ignored.
    let mut discarded = Vec::new();
    let _ = tokio::time::timeout(
        Duration::from_millis(800),
        client.read_to_end(&mut discarded),
    )
    .await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "deny");
    assert_eq!(payload["reasonCode"], "l7_h2_authority_allowlist_miss");
    assert_eq!(payload["sniHost"], "evil.example.com");
}
/// A header block the proxy cannot decode (see `build_h2c_stream_unparseable`)
/// must yield exactly one `deny` event with reason `l7_h2_unparseable_headers`.
#[tokio::test]
async fn proxy_emits_event_for_h2_unparseable_headers() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_unparseable();
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();

    // Drain the proxy's reply for at most 800 ms; the result is ignored.
    let mut discarded = Vec::new();
    let _ = tokio::time::timeout(
        Duration::from_millis(800),
        client.read_to_end(&mut discarded),
    )
    .await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "deny");
    assert_eq!(payload["reasonCode"], "l7_h2_unparseable_headers");
}
/// A header block split across two frames (see
/// `build_h2c_stream_continuation_reassemblable`) that names an allowed host
/// must be forwarded unchanged and yield one `allow` event with reason
/// `l7_h2_authority_allowlist_match`.
#[tokio::test]
async fn proxy_allows_h2c_with_continuation_fragmented_authority() {
    let (upstream_addr, upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_continuation_reassemblable("api.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    let forwarded = tokio::time::timeout(Duration::from_secs(2), upstream_task)
        .await
        .expect("upstream task")
        .expect("upstream join");
    assert_eq!(forwarded, request);

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "allow");
    assert_eq!(payload["reasonCode"], "l7_h2_authority_allowlist_match");
    assert_eq!(payload["sniHost"], "api.example.com");
}
/// An allowed host sent through the Huffman header-block helper must yield one
/// `allow` event whose reason code is the Huffman-specific
/// `l7_h2_authority_allowlist_match_huffman`.
#[tokio::test]
async fn proxy_allows_h2c_with_huffman_encoded_authority() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_huffman("api.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    // Fixed 200 ms pause before asking the proxy to stop.
    tokio::time::sleep(Duration::from_millis(200)).await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "allow");
    assert_eq!(
        payload["reasonCode"], "l7_h2_authority_allowlist_match_huffman",
        "Huffman provenance must surface as the differentiated reason code"
    );
    assert_eq!(payload["sniHost"], "api.example.com");
}
/// A stream built by `build_h2c_stream_dynamic_indexed_name` for an allowed
/// host must yield one `allow` event whose reason code is
/// `l7_h2_authority_allowlist_match_dynamic_indexed`.
#[tokio::test]
async fn proxy_allows_h2c_with_dynamic_table_indexed_authority() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_dynamic_indexed_name("api.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    // Fixed 200 ms pause before asking the proxy to stop.
    tokio::time::sleep(Duration::from_millis(200)).await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1, "expected one event, got {evs:#?}");
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "allow");
    assert_eq!(
        payload["reasonCode"], "l7_h2_authority_allowlist_match_dynamic_indexed",
        "dynamic-table-name provenance must surface as the differentiated reason code"
    );
}
/// A host outside the configured list, sent through the Huffman header-block
/// helper, must yield one `deny` event with the Huffman-specific reason
/// `l7_h2_authority_allowlist_miss_huffman` and the host in `sniHost`.
#[tokio::test]
async fn proxy_denies_h2c_with_authority_extracted_from_huffman_literal_not_in_allowlist() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_huffman("evil.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();

    // Drain the proxy's reply for at most 800 ms; the result is ignored.
    let mut discarded = Vec::new();
    let _ = tokio::time::timeout(
        Duration::from_millis(800),
        client.read_to_end(&mut discarded),
    )
    .await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "deny");
    assert_eq!(
        payload["reasonCode"], "l7_h2_authority_allowlist_miss_huffman",
        "Huffman provenance must surface in the deny reason code"
    );
    assert_eq!(payload["sniHost"], "evil.example.com");
}
/// With `my-service.example.com` both configured and sent through the Huffman
/// helper, the single emitted event must carry that exact host in `sniHost`,
/// action `allow`, and reason `l7_h2_authority_allowlist_match_huffman`.
#[tokio::test]
async fn proxy_emits_event_with_extracted_authority_when_huffman_used() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["my-service.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_huffman("my-service.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    // Fixed 200 ms pause before asking the proxy to stop.
    tokio::time::sleep(Duration::from_millis(200)).await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["sniHost"], "my-service.example.com");
    assert_eq!(payload["action"], "allow");
    assert_eq!(
        payload["reasonCode"],
        "l7_h2_authority_allowlist_match_huffman",
    );
}
/// A header block with no authority field (see `build_h2c_stream_no_authority`)
/// must yield exactly one `deny` event with reason `l7_h2_authority_missing`.
#[tokio::test]
async fn proxy_emits_event_for_h2_missing_authority() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream_no_authority();
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();

    // Drain the proxy's reply for at most 800 ms; the result is ignored.
    let mut discarded = Vec::new();
    let _ = tokio::time::timeout(
        Duration::from_millis(800),
        client.read_to_end(&mut discarded),
    )
    .await;

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(evs.len(), 1);
    let payload = evs[0].data.as_ref().unwrap();
    assert_eq!(payload["action"], "deny");
    assert_eq!(payload["reasonCode"], "l7_h2_authority_missing");
}
/// A denied h2c stream must be answered on the client socket with an
/// RST_STREAM frame for stream 1 carrying error code 0x7 (REFUSED_STREAM), and
/// the first recorded event must be a `deny` tagged with `streamId` 1.
#[tokio::test]
async fn proxy_h2c_deny_responds_with_rst_stream_refused() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_stream("evil.example.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    // Read everything the proxy sends back. A connection reset is tolerated;
    // any other read error fails the test.
    let mut response = Vec::new();
    let read_result =
        tokio::time::timeout(Duration::from_secs(2), client.read_to_end(&mut response))
            .await
            .expect("read deadline");
    if let Err(e) = read_result {
        if e.kind() != std::io::ErrorKind::ConnectionReset {
            panic!("unexpected read error: {e}");
        }
    }

    // 9-byte frame header + 4-byte error code.
    assert!(
        response.len() >= 13,
        "expected at least one full RST_STREAM frame, got {} bytes: {:02x?}",
        response.len(),
        response
    );
    assert_eq!(&response[0..3], &[0x00, 0x00, 0x04], "RST_STREAM length");
    assert_eq!(response[3], 0x03, "RST_STREAM frame type");
    assert_eq!(response[4], 0x00, "RST_STREAM flags");
    assert_eq!(&response[5..9], &[0x00, 0x00, 0x00, 0x01], "stream id 1");
    assert_eq!(
        &response[9..13],
        &[0x00, 0x00, 0x00, 0x07],
        "RST_STREAM error code = REFUSED_STREAM"
    );

    // Scoped so the events lock is released before the proxy is stopped.
    {
        let evs = recorder.events.lock().unwrap();
        assert!(!evs.is_empty(), "expected at least one deny event");
        let payload = evs[0].data.as_ref().unwrap();
        assert_eq!(payload["action"], "deny");
        assert_eq!(payload["streamId"], 1);
    }

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;
}
/// Asserts that `response` begins with an HTTP/2 GOAWAY frame: payload length
/// 8, frame type 0x07, flags 0, stream id 0, and error code 0x1
/// (PROTOCOL_ERROR). The last-stream-id field (bytes 9..13) is not checked.
///
/// # Panics
///
/// Panics when `response` is shorter than the 17 bytes of a GOAWAY frame
/// (9-byte header + 4-byte last-stream-id + 4-byte error code), or when any of
/// the checked fields differs from the expected value.
// Nothing in the visible tests calls this helper, hence the lint allowance.
#[allow(dead_code)]
fn _legacy_goaway_assertions(response: &[u8]) {
    // Check the length first so a truncated reply fails with a readable
    // message instead of a bare slice-index panic.
    assert!(
        response.len() >= 17,
        "expected a full GOAWAY frame (17 bytes), got {} bytes: {:02x?}",
        response.len(),
        response
    );
    assert_eq!(&response[0..3], &[0x00, 0x00, 0x08], "GOAWAY length");
    assert_eq!(response[3], 0x07, "GOAWAY frame type");
    assert_eq!(response[4], 0x00, "GOAWAY flags");
    assert_eq!(&response[5..9], &[0x00, 0x00, 0x00, 0x00], "stream id 0");
    assert_eq!(
        &response[13..17],
        &[0x00, 0x00, 0x00, 0x01],
        "GOAWAY error code = PROTOCOL_ERROR"
    );
}
/// Builds an h2c byte stream with two HEADERS frames on one connection:
/// `authority_a` on stream 1, then `authority_b` on stream 3, preceded by the
/// client preface and an empty SETTINGS frame.
fn build_h2c_two_streams(authority_a: &str, authority_b: &str) -> Vec<u8> {
    let mut wire = h2::HTTP2_PREFACE.to_vec();
    wire.extend_from_slice(&h2::test_helpers::empty_settings_frame());

    let first_block = h2::test_helpers::hpack_literal_indexed_name(1, authority_a);
    let first_frame = h2::test_helpers::headers_frame_on_stream(&first_block, 1);
    wire.extend_from_slice(&first_frame);

    let second_block = h2::test_helpers::hpack_literal_indexed_name(1, authority_b);
    let second_frame = h2::test_helpers::headers_frame_on_stream(&second_block, 3);
    wire.extend_from_slice(&second_frame);

    wire
}
/// Builds an h2c byte stream in this order: client preface, empty SETTINGS,
/// HEADERS for `authority_denied` on stream 1, a DATA frame with
/// `data_payload` on stream 1, then HEADERS for `authority_allowed` on
/// stream 3.
fn build_h2c_denied_stream_then_data(
    authority_allowed: &str,
    authority_denied: &str,
    data_payload: &[u8],
) -> Vec<u8> {
    let mut wire = h2::HTTP2_PREFACE.to_vec();
    wire.extend_from_slice(&h2::test_helpers::empty_settings_frame());

    // Stream 1: the denied authority, immediately followed by its DATA.
    let denied_block = h2::test_helpers::hpack_literal_indexed_name(1, authority_denied);
    let denied_headers = h2::test_helpers::headers_frame_on_stream(&denied_block, 1);
    wire.extend_from_slice(&denied_headers);
    let denied_data = h2::test_helpers::data_frame_on_stream(data_payload, 1);
    wire.extend_from_slice(&denied_data);

    // Stream 3: the allowed authority.
    let allowed_block = h2::test_helpers::hpack_literal_indexed_name(1, authority_allowed);
    let allowed_headers = h2::test_helpers::headers_frame_on_stream(&allowed_block, 3);
    wire.extend_from_slice(&allowed_headers);

    wire
}
/// Two streams on one connection, each naming a configured host, must be
/// forwarded unchanged and yield two `allow` events carrying stream ids 1 and 3.
#[tokio::test]
async fn proxy_allows_two_streams_with_matching_authority_on_same_connection() {
    let (upstream_addr, upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com", "api.other.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let request = build_h2c_two_streams("api.example.com", "api.other.com");
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    let forwarded = tokio::time::timeout(Duration::from_secs(2), upstream_task)
        .await
        .expect("upstream task")
        .expect("upstream join");
    assert_eq!(forwarded, request);

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(
        evs.len(),
        2,
        "expected one allow event per stream, got {evs:#?}"
    );

    // Gather the stream ids first; sort them so the check below does not
    // depend on the order in which events were recorded.
    let mut stream_ids: Vec<i64> = Vec::with_capacity(evs.len());
    for recorded in evs.iter() {
        let payload = recorded.data.as_ref().unwrap();
        let stream_id = payload
            .get("streamId")
            .and_then(|v| v.as_i64())
            .expect("event must carry streamId");
        stream_ids.push(stream_id);
    }
    stream_ids.sort_unstable();
    assert_eq!(stream_ids, vec![1, 3]);

    for recorded in evs.iter() {
        let payload = recorded.data.as_ref().unwrap();
        assert_eq!(payload["action"], "allow");
    }
}
/// With stream 1 naming a host outside the configured list and stream 3 naming
/// an allowed one, the proxy must answer stream 1 with an RST_STREAM frame
/// (error code 0x7, REFUSED_STREAM) and still record a decision for stream 3:
/// `deny` for stream 1, `allow` for stream 3.
#[tokio::test]
async fn proxy_denies_one_stream_keeps_connection_open_for_other_streams() {
    let (upstream, _upstream_h) = spawn_echo_upstream().await;
    let emitter = Arc::new(MemEmitter::default());
    let cfg = cfg_with(&["api.example.com"], upstream, 500);
    let (listen, shutdown, h) = spawn_proxy(cfg, emitter.clone()).await;

    let stream_bytes = build_h2c_two_streams("evil.example.com", "api.example.com");
    let mut s = TcpStream::connect(listen).await.unwrap();
    s.write_all(&stream_bytes).await.unwrap();

    // Read exactly one RST_STREAM frame: 9-byte header + 4-byte error code.
    let mut response = vec![0u8; 13];
    let read_n = tokio::time::timeout(Duration::from_secs(2), s.read_exact(&mut response))
        .await
        .expect("read deadline");
    assert!(read_n.is_ok(), "expected to read 13 bytes (RST_STREAM)");
    assert_eq!(&response[0..3], &[0x00, 0x00, 0x04], "RST_STREAM length");
    assert_eq!(response[3], 0x03, "RST_STREAM type");
    // Check the flags byte too, matching the assertion in
    // `proxy_h2c_deny_responds_with_rst_stream_refused`.
    assert_eq!(response[4], 0x00, "RST_STREAM flags");
    assert_eq!(&response[5..9], &[0x00, 0x00, 0x00, 0x01], "stream id 1");
    assert_eq!(
        &response[9..13],
        &[0x00, 0x00, 0x00, 0x07],
        "REFUSED_STREAM"
    );

    s.shutdown().await.ok();
    shutdown.store(true, Ordering::SeqCst);
    poke_shutdown(listen);
    let _ = tokio::time::timeout(Duration::from_secs(2), h).await;

    let evs = emitter.events.lock().unwrap();
    assert_eq!(evs.len(), 2, "expected one event per stream");

    // Index the events by stream id so the checks below do not depend on the
    // order in which they were recorded.
    let mut by_stream: std::collections::HashMap<i64, &CloudEventV1> =
        std::collections::HashMap::new();
    for ev in evs.iter() {
        let sid = ev
            .data
            .as_ref()
            .unwrap()
            .get("streamId")
            .and_then(|v| v.as_i64())
            .unwrap();
        by_stream.insert(sid, ev);
    }
    assert_eq!(by_stream[&1].data.as_ref().unwrap()["action"], "deny");
    assert_eq!(by_stream[&3].data.as_ref().unwrap()["action"], "allow");
}
/// Three streams on one connection (one allowed host on stream 1, two
/// disallowed hosts on streams 3 and 5) must produce exactly three events,
/// each carrying a `streamId`: one `allow` and two `deny`.
#[tokio::test]
async fn proxy_emits_one_event_per_stream_decision() {
    let (upstream_addr, _upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    // Preface, empty SETTINGS, then one HEADERS frame per stream.
    let mut wire = h2::HTTP2_PREFACE.to_vec();
    wire.extend_from_slice(&h2::test_helpers::empty_settings_frame());
    for (authority, stream_id) in [
        ("api.example.com", 1),
        ("evil1.example.com", 3),
        ("evil2.example.com", 5),
    ] {
        let header_block = h2::test_helpers::hpack_literal_indexed_name(1, authority);
        wire.extend_from_slice(&h2::test_helpers::headers_frame_on_stream(
            &header_block,
            stream_id,
        ));
    }

    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&wire).await.unwrap();

    // One bounded read (at most 400 ms) of whatever the proxy replies with;
    // the bytes and the outcome are ignored.
    let mut scratch = vec![0u8; 64];
    let _ = tokio::time::timeout(Duration::from_millis(400), client.read(&mut scratch)).await;
    client.shutdown().await.ok();

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;

    let evs = recorder.events.lock().unwrap();
    assert_eq!(
        evs.len(),
        3,
        "expected exactly one event per stream, got {evs:#?}"
    );

    // Tally events by action and require a numeric `streamId` on each.
    let mut per_action: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    for ev in evs.iter() {
        let payload = ev.data.as_ref().unwrap();
        let action = payload["action"].as_str().unwrap().to_string();
        *per_action.entry(action).or_insert(0) += 1;

        let stream_id = ev
            .data
            .as_ref()
            .unwrap()
            .get("streamId")
            .and_then(|v| v.as_i64());
        assert!(stream_id.is_some(), "event missing streamId: {ev:?}");
    }
    assert_eq!(per_action.get("allow").copied(), Some(1));
    assert_eq!(per_action.get("deny").copied(), Some(2));
}
/// DATA sent on a denied stream must never reach the upstream: the payload
/// `do-not-exfiltrate` rides on stream 1 (denied host) and must be absent from
/// the bytes the upstream received, while the upstream must still have
/// received something.
#[tokio::test]
async fn proxy_drops_data_frames_on_denied_stream() {
    let (upstream_addr, upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let secret = b"do-not-exfiltrate";
    let request =
        build_h2c_denied_stream_then_data("api.example.com", "evil.example.com", secret);
    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&request).await.unwrap();
    client.shutdown().await.ok();

    let upstream_bytes = tokio::time::timeout(Duration::from_secs(2), upstream_task)
        .await
        .expect("upstream task")
        .expect("upstream join");

    // Lossy decoding keeps ASCII intact, so a substring search is enough to
    // detect the payload among binary frame bytes.
    let upstream_str = String::from_utf8_lossy(&upstream_bytes);
    let leaked = upstream_str.contains("do-not-exfiltrate");
    assert!(
        !leaked,
        "denied-stream DATA leaked to upstream: {upstream_str:?}"
    );
    assert!(!upstream_bytes.is_empty(), "upstream got no bytes");

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;
}
/// A connection carrying SETTINGS, an allowed HEADERS frame on stream 1, and a
/// frame of type 0x08 (WINDOW_UPDATE) on stream 0 with a 4-byte payload of
/// 1024 must arrive at the upstream byte-for-byte.
#[tokio::test]
async fn proxy_forwards_settings_and_window_update_verbatim() {
    let (upstream_addr, upstream_task) = spawn_echo_upstream().await;
    let recorder = Arc::new(MemEmitter::default());
    let proxy_cfg = cfg_with(&["api.example.com"], upstream_addr, 500);
    let (proxy_addr, stop_flag, proxy_task) = spawn_proxy(proxy_cfg, recorder.clone()).await;

    let mut wire = h2::HTTP2_PREFACE.to_vec();
    wire.extend_from_slice(&h2::test_helpers::empty_settings_frame());

    let header_block = h2::test_helpers::hpack_literal_indexed_name(1, "api.example.com");
    wire.extend_from_slice(&h2::test_helpers::headers_frame_on_stream(&header_block, 1));

    // Hand-built frame: header (length 4, type 0x08, flags 0, stream 0)
    // followed by the big-endian value 1024.
    wire.extend_from_slice(&h2::test_helpers::frame_header(4, 0x08, 0x00, 0));
    wire.extend_from_slice(&1024u32.to_be_bytes());

    let mut client = TcpStream::connect(proxy_addr).await.unwrap();
    client.write_all(&wire).await.unwrap();
    client.shutdown().await.ok();

    let forwarded = tokio::time::timeout(Duration::from_secs(2), upstream_task)
        .await
        .expect("upstream task")
        .expect("upstream join");
    assert_eq!(forwarded, wire);

    stop_flag.store(true, Ordering::SeqCst);
    poke_shutdown(proxy_addr);
    let _ = tokio::time::timeout(Duration::from_secs(2), proxy_task).await;
}
}