use std::sync::Arc;
use cc_lb_plugin_api::{
FilterError, FilterOutput, FilterPlugin, PerCandidateReason, Principal, RequestContext,
SlotKey, UpstreamCandidate,
};
use cc_lb_plugin_wire::schema::{HookKind, WireVersion};
use cc_lb_plugin_wire::{
ArchivedFilterResponse, ClaimRef, FilterRequestRef, HeaderRef, PrincipalRef, QueryRef,
ShapeRequestRef, UpstreamCandidateRef, UpstreamRef,
};
use rkyv::rancor::Error as RkyvError;
use rkyv::util::AlignedVec;
use uuid::Uuid;
use crate::cache::call_filter_hook;
use crate::cell::{PluginCell, PluginSlot};
use crate::error::WasmtimeRuntimeError;
/// [`FilterPlugin`] adapter that forwards candidate filtering to a WebAssembly
/// plugin hosted by Wasmtime.
pub struct WasmtimeFilterPlugin {
    /// Slot identity reported through [`FilterPlugin::slot_key`].
    slot_key: SlotKey,
    /// Plugin cell loaded from `slot.current` in [`WasmtimeFilterPlugin::new`];
    /// this instance does not re-read the slot afterwards.
    cell: Arc<PluginCell>,
    /// Identifier reported through [`FilterPlugin::plugin_id`].
    plugin_id: Uuid,
    /// Display name reported through [`FilterPlugin::plugin_name`].
    plugin_name: String,
    /// Engine settings read by `filter` (cookie redaction, wire bounds).
    runtime_config: Arc<crate::HotEngineConfig>,
}
impl WasmtimeFilterPlugin {
    /// Builds an adapter bound to the plugin cell currently stored in `slot`.
    pub fn new(
        slot: Arc<PluginSlot>,
        slot_key: SlotKey,
        plugin_id: Uuid,
        plugin_name: impl Into<String>,
        runtime_config: Arc<crate::HotEngineConfig>,
    ) -> Self {
        // Struct-literal fields evaluate in the order written: the cell is loaded first.
        Self {
            cell: slot.current.load_full(),
            plugin_name: plugin_name.into(),
            slot_key,
            plugin_id,
            runtime_config,
        }
    }
}
impl FilterPlugin for WasmtimeFilterPlugin {
    /// Runs the guest `filter` hook over `candidates` and decodes its verdicts.
    ///
    /// Steps:
    /// 1. Encode the request with rkyv. Headers matched by
    ///    `is_stripped_downstream_header` are left out, and so is `cookie` when
    ///    `cookie_redaction` is set.
    /// 2. Dispatch by the hook's declared wire version.
    /// 3. Check the output size against `wire_bounds.output_body_bytes`.
    /// 4. Copy the output into a 16-byte-aligned buffer, validate it, and convert it.
    ///
    /// # Errors
    /// - `FilterError::Trap` when the guest traps.
    /// - `FilterError::Runtime` in these cases:
    ///   - encode or decode failures;
    ///   - oversized output;
    ///   - other runtime errors;
    ///   - plugin metadata that declares no supported wire version for the filter hook.
    fn filter(
        &self,
        ctx: &RequestContext,
        principal: &Principal,
        candidates: &[UpstreamCandidate],
    ) -> Result<FilterOutput, FilterError> {
        let in_bytes = host_to_wire_request(
            ctx,
            principal,
            candidates,
            self.runtime_config.cookie_redaction,
        )
        .map_err(|e| FilterError::Runtime {
            reason: format!("rkyv encode request: {e}"),
        })?;
        let wire_version = self
            .cell
            .metadata
            .hooks
            .get(HookKind::Filter.as_str())
            .and_then(|m| WireVersion::from_u8(m.wire_version));
        let out_bytes = match wire_version {
            Some(WireVersion::V1) => {
                call_filter_hook(&self.cell, in_bytes.as_slice()).map_err(runtime_error_to_filter)?
            }
            // The hook entry and its version byte come from the loaded plugin's
            // metadata. A missing entry or unknown version is reported as an error
            // rather than panicking on the request path.
            None => {
                return Err(FilterError::Runtime {
                    reason: format!(
                        "plugin `{}` declares no supported `{}` hook wire version",
                        self.plugin_name,
                        HookKind::Filter.as_str()
                    ),
                });
            }
        };
        let bound = self.runtime_config.wire_bounds.output_body_bytes;
        if out_bytes.len() as u64 > bound {
            return Err(FilterError::Runtime {
                reason: format!(
                    "filter output {} bytes exceeds wire_bounds.output_body_bytes ({})",
                    out_bytes.len(),
                    bound
                ),
            });
        }
        // The guest buffer has no alignment guarantee; rkyv access needs an aligned copy.
        let mut aligned = AlignedVec::<16>::with_capacity(out_bytes.len());
        aligned.extend_from_slice(&out_bytes);
        let archived =
            rkyv::access::<ArchivedFilterResponse, RkyvError>(&aligned).map_err(|e| {
                FilterError::Runtime {
                    reason: format!("rkyv access response: {e}"),
                }
            })?;
        wire_to_host_output(
            archived,
            self.runtime_config.wire_bounds.reason_bytes as usize,
        )
    }
    /// Identifier supplied to [`WasmtimeFilterPlugin::new`].
    fn plugin_id(&self) -> Uuid {
        self.plugin_id
    }
    /// Name supplied to [`WasmtimeFilterPlugin::new`].
    fn plugin_name(&self) -> &str {
        &self.plugin_name
    }
    /// Clone of the slot key supplied to [`WasmtimeFilterPlugin::new`].
    fn slot_key(&self) -> SlotKey {
        self.slot_key.clone()
    }
}
/// Maps a Wasmtime runtime error onto the filter error space.
///
/// Guest traps become `FilterError::Trap` with a `"{phase}: {source}"` reason. Every
/// other error becomes `FilterError::Runtime` carrying the error's `Display` text.
fn runtime_error_to_filter(err: WasmtimeRuntimeError) -> FilterError {
    if let WasmtimeRuntimeError::GuestTrap { phase, source } = err {
        return FilterError::Trap {
            reason: format!("{phase}: {source}"),
        };
    }
    FilterError::Runtime {
        reason: err.to_string(),
    }
}
/// Encodes a filter-hook request (request metadata, principal, candidates) as rkyv bytes.
///
/// Claim values are serialized to JSON bytes; a claim whose value fails to serialize
/// is skipped. Headers for which `is_stripped_downstream_header` returns `true` are
/// omitted. Each candidate's `upstream_id` is rendered with `to_string`. A missing
/// `cache_score` is sent as `predicted_cache_read_tokens = 0`.
fn host_to_wire_request(
    ctx: &RequestContext,
    principal: &Principal,
    candidates: &[UpstreamCandidate],
    cookie_redaction: bool,
) -> Result<AlignedVec<16>, RkyvError> {
    let kind_label = principal_kind_to_wire(principal);

    // Owned JSON buffers must outlive the borrowed `ClaimRef` views built from them.
    let encoded_claims: Vec<(&str, Vec<u8>)> = principal
        .claims
        .iter()
        .filter_map(|(name, value)| {
            let json = serde_json::to_vec(value).ok()?;
            Some((name.as_str(), json))
        })
        .collect();
    let claims: Vec<ClaimRef<'_>> = encoded_claims
        .iter()
        .map(|(name, json)| ClaimRef {
            key: name,
            value: json.as_slice(),
        })
        .collect();

    let headers: Vec<HeaderRef<'_>> = ctx
        .downstream_headers
        .iter()
        .filter(|&(name, _)| !is_stripped_downstream_header(name.as_str(), cookie_redaction))
        .map(|(name, value)| HeaderRef {
            name: name.as_str(),
            value: value.as_bytes(),
        })
        .collect();

    // Rendered ids are kept alive here so the candidate views can borrow them.
    let candidate_ids: Vec<String> = candidates
        .iter()
        .map(|candidate| candidate.upstream_id.to_string())
        .collect();
    let candidate_views: Vec<UpstreamCandidateRef<'_>> = candidates
        .iter()
        .zip(&candidate_ids)
        .map(|(candidate, id)| UpstreamCandidateRef {
            upstream_id: id.as_str(),
            name: candidate.name.as_str(),
            kind: candidate.kind.as_str(),
            observed_at_unix_secs: candidate.observed_at_unix_secs,
            predicted_cache_read_tokens: candidate
                .cache_score
                .as_ref()
                .map_or(0, |score| score.predicted_cache_read_tokens),
        })
        .collect();

    let request = FilterRequestRef {
        request_id: ctx.request_id.as_str(),
        method: ctx.method.as_str(),
        path: ctx.path.as_str(),
        query: ctx.query.as_deref().map(|value| QueryRef { value }),
        headers: &headers,
        body: ctx.body_bytes.as_ref(),
        principal: PrincipalRef {
            id: principal.id.as_str(),
            kind: kind_label.as_str(),
            claims: &claims,
        },
        candidates: &candidate_views,
    };
    rkyv::to_bytes::<RkyvError>(&request)
}
/// Returns the serde string form of `principal.kind`, or `"unknown"` when the kind
/// does not serialize to a JSON string.
fn principal_kind_to_wire(principal: &Principal) -> String {
    match serde_json::to_value(&principal.kind) {
        Ok(serde_json::Value::String(label)) => label,
        _ => "unknown".to_owned(),
    }
}
/// Reports whether a downstream request header must be withheld from plugins.
///
/// `authorization`, `x-api-key`, `host` and `proxy-authorization` are always withheld.
/// `cookie` is withheld only when `cookie_redaction` is set. Matching is ASCII
/// case-insensitive.
fn is_stripped_downstream_header(name: &str, cookie_redaction: bool) -> bool {
    const ALWAYS_STRIPPED: [&str; 4] = ["authorization", "x-api-key", "host", "proxy-authorization"];
    if ALWAYS_STRIPPED
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
    {
        return true;
    }
    cookie_redaction && name.eq_ignore_ascii_case("cookie")
}
/// Reports whether a header emitted by a shape plugin must be dropped.
///
/// Drops three groups of headers:
/// - hop-by-hop headers;
/// - `host` and `content-length`;
/// - the credential headers `authorization` and `x-api-key`.
///
/// Also drops any header whose name starts with `x-anthropic-`. Matching is ASCII
/// case-insensitive.
fn is_stripped_shape_output_header(name: &str) -> bool {
    const DROPPED: [&str; 12] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "host",
        "content-length",
        "authorization",
        "x-api-key",
    ];
    const RESERVED_PREFIX: &str = "x-anthropic-";
    DROPPED.iter().any(|dropped| name.eq_ignore_ascii_case(dropped))
        || name
            .as_bytes()
            .get(..RESERVED_PREFIX.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(RESERVED_PREFIX.as_bytes()))
}
/// Converts an archived filter response into a host [`FilterOutput`].
///
/// - An entry whose decision is exactly `"accept"` adds its upstream id to
///   `kept_upstream_ids`.
/// - Any other decision is classified with `per_candidate_reason_from_label`.
/// - Every non-empty reason is truncated to `reason_cap` bytes.
/// - Truncated reasons are joined as `"<id>: <reason>"` entries separated by `"; "`.
///
/// # Errors
/// Returns `FilterError::Runtime` on the first entry whose `upstream_id` is not a
/// valid UUID.
fn wire_to_host_output(
    archived: &ArchivedFilterResponse,
    reason_cap: usize,
) -> Result<FilterOutput, FilterError> {
    let mut kept_upstream_ids = Vec::new();
    let mut per_candidate_reasons = Vec::new();
    let mut reason_parts: Vec<String> = Vec::new();
    for entry in archived.results.iter() {
        let id_text: &str = &entry.upstream_id;
        let decision: &str = &entry.decision;
        let note: &str = &entry.reason;
        let upstream_id = match Uuid::parse_str(id_text) {
            Ok(parsed) => parsed,
            Err(source) => {
                return Err(FilterError::Runtime {
                    reason: format!("plugin returned invalid upstream_id `{id_text}`: {source}"),
                });
            }
        };
        match decision {
            "accept" => kept_upstream_ids.push(upstream_id),
            other => per_candidate_reasons.push(per_candidate_reason_from_label(other, note)),
        }
        if !note.is_empty() {
            reason_parts.push(format!("{id_text}: {}", truncate_reason(note, reason_cap)));
        }
    }
    Ok(FilterOutput {
        kept_upstream_ids,
        reason: reason_parts.join("; "),
        per_candidate_reasons,
    })
}
/// Caps `reason` at `cap` bytes without splitting a UTF-8 character.
///
/// Borrows `reason` when it already fits. Otherwise returns an owned copy cut at the
/// largest char boundary that is `<= cap`, which may be the empty string.
fn truncate_reason(reason: &str, cap: usize) -> std::borrow::Cow<'_, str> {
    if reason.len() > cap {
        // Index 0 is always a char boundary, so the fallback is never actually used.
        let cut = (0..=cap)
            .rev()
            .find(|&idx| reason.is_char_boundary(idx))
            .unwrap_or(0);
        std::borrow::Cow::Owned(reason[..cut].to_owned())
    } else {
        std::borrow::Cow::Borrowed(reason)
    }
}
/// Classifies a plugin decision label into a [`PerCandidateReason`].
///
/// When `decision` is `"accept"`, the free-text `reason` is classified instead. The
/// label is normalized by mapping `-` to `_` and ASCII-lowercasing. It is then checked
/// for `rate_limit`, `quota` and `unhealthy`, in that order. Anything else is
/// `RejectedByPlugin`.
fn per_candidate_reason_from_label(decision: &str, reason: &str) -> PerCandidateReason {
    let raw = if decision == "accept" { reason } else { decision };
    let normalized: String = raw
        .chars()
        .map(|ch| if ch == '-' { '_' } else { ch.to_ascii_lowercase() })
        .collect();
    match normalized.as_str() {
        label if label.contains("rate_limit") => PerCandidateReason::RateLimited,
        label if label.contains("quota") => PerCandidateReason::InsufficientQuota,
        label if label.contains("unhealthy") => PerCandidateReason::Unhealthy,
        _ => PerCandidateReason::RejectedByPlugin,
    }
}
/// [`cc_lb_plugin_api::UpstreamDialect`] adapter that delegates request shaping to a
/// WebAssembly plugin's `shape` hook.
pub struct WasmtimeUpstreamDialect {
    /// Plugin cell loaded from `slot.current` in [`WasmtimeUpstreamDialect::new`];
    /// this instance does not re-read the slot afterwards.
    cell: Arc<PluginCell>,
    /// Engine settings read by `shape` (cookie redaction, origin policy, wire bounds).
    runtime_config: Arc<crate::HotEngineConfig>,
}
impl WasmtimeUpstreamDialect {
    /// Builds a dialect bound to the plugin cell currently stored in `slot`.
    pub fn new(slot: Arc<PluginSlot>, runtime_config: Arc<crate::HotEngineConfig>) -> Self {
        Self {
            cell: slot.current.load_full(),
            runtime_config,
        }
    }
}
impl cc_lb_plugin_api::UpstreamDialect for WasmtimeUpstreamDialect {
    /// Asks the guest `shape` hook to build the outbound request for `upstream`.
    ///
    /// Steps:
    /// 1. Encode the request with rkyv, applying the same header stripping as the
    ///    filter path.
    /// 2. Dispatch by the hook's declared wire version.
    /// 3. Check the output size against `wire_bounds.output_body_bytes`.
    /// 4. Validate the archive and convert it through `wire_to_host_shaped_request`,
    ///    which enforces the origin policy and the header bounds.
    ///
    /// # Errors
    /// Failures are reported as `DialectError::UnsupportedRequest`. This covers:
    /// - encoding errors;
    /// - a missing or unsupported hook wire version;
    /// - guest traps and other runtime errors;
    /// - oversized output and archive validation failures;
    /// - invalid responses.
    ///
    /// The exception is URL parse errors, which `wire_to_host_shaped_request`
    /// converts with `?`.
    fn shape(
        &self,
        ctx: &RequestContext,
        upstream: &cc_lb_plugin_api::Upstream,
        principal: &Principal,
        builder: &mut cc_lb_plugin_api::ShapedRequestBuilder,
    ) -> Result<cc_lb_plugin_api::ShapedRequest, cc_lb_plugin_api::DialectError> {
        let in_bytes = host_to_wire_shape_request(
            ctx,
            upstream,
            principal,
            self.runtime_config.cookie_redaction,
        )
        .map_err(|e| cc_lb_plugin_api::DialectError::UnsupportedRequest {
            reason: format!("rkyv encode ShapeRequest: {e}"),
        })?;
        let wire_version = self
            .cell
            .metadata
            .hooks
            .get(HookKind::Shape.as_str())
            .and_then(|m| WireVersion::from_u8(m.wire_version));
        let out_bytes = match wire_version {
            Some(WireVersion::V1) => {
                crate::cache::call_shape_hook(&self.cell, in_bytes.as_slice())
                    .map_err(runtime_error_to_dialect)?
            }
            // The hook entry and its version byte come from the loaded plugin's
            // metadata. A missing entry or unknown version is reported as an error
            // rather than panicking on the request path.
            None => {
                return Err(cc_lb_plugin_api::DialectError::UnsupportedRequest {
                    reason: format!(
                        "plugin declares no supported `{}` hook wire version",
                        HookKind::Shape.as_str()
                    ),
                });
            }
        };
        let out_bound = self.runtime_config.wire_bounds.output_body_bytes;
        if out_bytes.len() as u64 > out_bound {
            return Err(cc_lb_plugin_api::DialectError::UnsupportedRequest {
                reason: format!(
                    "shape output {} bytes exceeds wire_bounds.output_body_bytes ({})",
                    out_bytes.len(),
                    out_bound
                ),
            });
        }
        // The guest buffer has no alignment guarantee; rkyv access needs an aligned copy.
        let mut aligned = AlignedVec::<16>::with_capacity(out_bytes.len());
        aligned.extend_from_slice(&out_bytes);
        let archived = rkyv::access::<cc_lb_plugin_wire::ArchivedShapeResponse, RkyvError>(
            &aligned,
        )
        .map_err(|e| cc_lb_plugin_api::DialectError::UnsupportedRequest {
            reason: format!("rkyv access ShapeResponse: {e}"),
        })?;
        wire_to_host_shaped_request(
            builder,
            archived,
            upstream,
            self.runtime_config.shape_origin_policy,
            &self.runtime_config.wire_bounds,
        )
    }
}
fn host_upstream_to_wire(upstream: &cc_lb_plugin_api::Upstream) -> cc_lb_plugin_wire::Upstream {
match upstream {
cc_lb_plugin_api::Upstream::AnthropicDirect { base_url } => {
cc_lb_plugin_wire::Upstream::AnthropicDirect {
base_url: base_url.as_ref().map(|u| u.to_string().into_boxed_str()),
}
}
}
}
/// Maps a Wasmtime runtime error onto `DialectError::UnsupportedRequest`.
///
/// Guest traps are described as `"{phase}: {source}"`. Other errors use their
/// `Display` text.
fn runtime_error_to_dialect(err: WasmtimeRuntimeError) -> cc_lb_plugin_api::DialectError {
    let reason = match err {
        WasmtimeRuntimeError::GuestTrap { phase, source } => format!("{phase}: {source}"),
        unexpected => unexpected.to_string(),
    };
    cc_lb_plugin_api::DialectError::UnsupportedRequest { reason }
}
/// Encodes a shape-hook request (request metadata, principal, selected upstream) as
/// rkyv bytes.
///
/// Claims and headers are prepared the same way as for the filter hook. A claim whose
/// value fails JSON serialization is skipped. Headers matched by
/// `is_stripped_downstream_header` are omitted. The upstream's optional base URL is
/// rendered with `to_string`.
fn host_to_wire_shape_request(
    ctx: &RequestContext,
    upstream: &cc_lb_plugin_api::Upstream,
    principal: &Principal,
    cookie_redaction: bool,
) -> Result<AlignedVec<16>, RkyvError> {
    let kind_label = principal_kind_to_wire(principal);

    // Owned JSON buffers must outlive the borrowed `ClaimRef` views built from them.
    let encoded_claims: Vec<(&str, Vec<u8>)> = principal
        .claims
        .iter()
        .filter_map(|(name, value)| {
            let json = serde_json::to_vec(value).ok()?;
            Some((name.as_str(), json))
        })
        .collect();
    let claims: Vec<ClaimRef<'_>> = encoded_claims
        .iter()
        .map(|(name, json)| ClaimRef {
            key: name,
            value: json.as_slice(),
        })
        .collect();

    let headers: Vec<HeaderRef<'_>> = ctx
        .downstream_headers
        .iter()
        .filter(|&(name, _)| !is_stripped_downstream_header(name.as_str(), cookie_redaction))
        .map(|(name, value)| HeaderRef {
            name: name.as_str(),
            value: value.as_bytes(),
        })
        .collect();

    // Rendered once so the archived upstream view can borrow it.
    let rendered_base_url: Option<String> = match upstream {
        cc_lb_plugin_api::Upstream::AnthropicDirect { base_url } => {
            base_url.as_ref().map(ToString::to_string)
        }
    };
    let upstream_view = UpstreamRef::AnthropicDirect {
        base_url: rendered_base_url.as_deref().map(|value| QueryRef { value }),
    };

    let request = ShapeRequestRef {
        request_id: ctx.request_id.as_str(),
        method: ctx.method.as_str(),
        path: ctx.path.as_str(),
        query: ctx.query.as_deref().map(|value| QueryRef { value }),
        headers: &headers,
        body: ctx.body_bytes.as_ref(),
        principal: PrincipalRef {
            id: principal.id.as_str(),
            kind: kind_label.as_str(),
            claims: &claims,
        },
        upstream: upstream_view,
    };
    rkyv::to_bytes::<RkyvError>(&request)
}
/// Returns an owned copy of the upstream's configured base URL, if it has one.
fn upstream_base_url(upstream: &cc_lb_plugin_api::Upstream) -> Option<url::Url> {
    let configured = match upstream {
        cc_lb_plugin_api::Upstream::AnthropicDirect { base_url } => base_url,
    };
    configured.as_ref().cloned()
}
/// Validates a shape plugin's archived response and turns it into a
/// `cc_lb_plugin_api::ShapedRequest` via `builder`.
///
/// Checks run in this order:
/// 1. the raw header count against `wire_bounds.max_headers`, counted before any
///    header is dropped;
/// 2. URL parsing;
/// 3. under `ShapeOriginPolicy::SelectedUpstreamOrigin`, that the URL origin equals the
///    origin of the selected upstream's `base_url`, when one is configured;
/// 4. the HTTP method;
/// 5. for each header:
///    - drop it when `is_stripped_shape_output_header` matches;
///    - otherwise enforce `wire_bounds.max_header_value_bytes` and validate the name
///      and value.
///
/// The body is copied out of the archive unchanged.
///
/// # Errors
/// Returns `DialectError::UnsupportedRequest` for:
/// - bound violations;
/// - origin mismatches;
/// - invalid methods or headers.
///
/// URL parse failures are converted with `?`.
fn wire_to_host_shaped_request(
    builder: &mut cc_lb_plugin_api::ShapedRequestBuilder,
    archived: &cc_lb_plugin_wire::ArchivedShapeResponse,
    upstream: &cc_lb_plugin_api::Upstream,
    origin_policy: crate::policy::ShapeOriginPolicy,
    wire_bounds: &crate::policy::PluginWireBounds,
) -> Result<cc_lb_plugin_api::ShapedRequest, cc_lb_plugin_api::DialectError> {
    // Compare lengths as `usize`. Widening the `u32` bounds is lossless, whereas a
    // `len() as u32` cast wraps above `u32::MAX` and could let an oversized response
    // pass the check.
    let max_headers = usize::try_from(wire_bounds.max_headers).unwrap_or(usize::MAX);
    let max_header_value_bytes =
        usize::try_from(wire_bounds.max_header_value_bytes).unwrap_or(usize::MAX);
    if archived.headers.len() > max_headers {
        return Err(cc_lb_plugin_api::DialectError::UnsupportedRequest {
            reason: format!(
                "shape plugin returned {} headers, exceeds wire_bounds.max_headers ({})",
                archived.headers.len(),
                wire_bounds.max_headers,
            ),
        });
    }
    let url_str: &str = &archived.url;
    let url = url::Url::parse(url_str)?;
    // NOTE(review): when the selected upstream has no explicit `base_url`, no origin is
    // enforced, even under `SelectedUpstreamOrigin`. If a default upstream origin
    // applies in that case, it likely needs to be checked here too — confirm.
    if matches!(
        origin_policy,
        crate::policy::ShapeOriginPolicy::SelectedUpstreamOrigin
    ) && let Some(expected) = upstream_base_url(upstream)
    {
        let expected_origin = expected.origin();
        let actual_origin = url.origin();
        if expected_origin != actual_origin {
            return Err(cc_lb_plugin_api::DialectError::UnsupportedRequest {
                reason: format!(
                    "shape plugin returned URL origin `{}` but selected upstream requires `{}`",
                    actual_origin.ascii_serialization(),
                    expected_origin.ascii_serialization(),
                ),
            });
        }
    }
    let method_str: &str = &archived.method;
    let method = http::Method::from_bytes(method_str.as_bytes()).map_err(|e| {
        cc_lb_plugin_api::DialectError::UnsupportedRequest {
            reason: format!("plugin returned invalid method `{method_str}`: {e}"),
        }
    })?;
    let mut headers = http::HeaderMap::new();
    for h in archived.headers.iter() {
        let h_name: &str = &h.name;
        let h_value: &[u8] = &h.value;
        // Dropped headers are skipped before the value-size check, so an oversized
        // value on a dropped header does not fail the request.
        if is_stripped_shape_output_header(h_name) {
            tracing::debug!(header = %h_name, "dropping shape-plugin output header per hop-by-hop/signer contract");
            continue;
        }
        if h_value.len() > max_header_value_bytes {
            return Err(cc_lb_plugin_api::DialectError::UnsupportedRequest {
                reason: format!(
                    "shape plugin header `{h_name}` value {} bytes exceeds wire_bounds.max_header_value_bytes ({})",
                    h_value.len(),
                    wire_bounds.max_header_value_bytes,
                ),
            });
        }
        let name = http::HeaderName::from_bytes(h_name.as_bytes()).map_err(|e| {
            cc_lb_plugin_api::DialectError::UnsupportedRequest {
                reason: format!("plugin returned invalid header name `{h_name}`: {e}"),
            }
        })?;
        let value = http::HeaderValue::from_bytes(h_value).map_err(|e| {
            cc_lb_plugin_api::DialectError::UnsupportedRequest {
                reason: format!("plugin returned invalid header value for `{h_name}`: {e}"),
            }
        })?;
        // `append` keeps repeated header names instead of overwriting earlier values.
        headers.append(name, value);
    }
    let body: &[u8] = &archived.body;
    Ok(builder.shaped_request(url, method, headers, bytes::Bytes::copy_from_slice(body)))
}
/// [`cc_lb_plugin_api::ObservabilityHook`] adapter that forwards events to a
/// WebAssembly plugin's `observe` hook.
pub struct WasmtimeObservabilityHookPlugin {
    /// Plugin cell loaded from `slot.current` in
    /// [`WasmtimeObservabilityHookPlugin::new`]; this instance does not re-read the
    /// slot afterwards.
    cell: Arc<PluginCell>,
}
impl WasmtimeObservabilityHookPlugin {
    /// Builds a hook bound to the plugin cell currently stored in `slot`.
    ///
    /// `runtime_config` is accepted so the constructor matches the other adapters'
    /// constructors, but it is not used here.
    pub fn new(slot: Arc<PluginSlot>, runtime_config: Arc<crate::HotEngineConfig>) -> Self {
        let _ = runtime_config;
        Self {
            cell: slot.current.load_full(),
        }
    }
}
impl cc_lb_plugin_api::ObservabilityHook for WasmtimeObservabilityHookPlugin {
    /// Forwards one observability event to the guest `observe` hook.
    ///
    /// The event is converted to its wire form, rkyv-encoded, and dispatched according
    /// to the hook's declared wire version. The hook's successful return value is
    /// discarded.
    ///
    /// # Errors
    /// Returns `ObservabilityError::Dropped` when:
    /// - the event cannot be encoded;
    /// - the plugin metadata declares no supported wire version for the observe hook;
    /// - the hook call fails.
    fn observe(
        &self,
        event: cc_lb_plugin_api::ObserveEvent,
    ) -> Result<(), cc_lb_plugin_api::ObservabilityError> {
        let wire = host_observe_event_to_wire(event);
        let in_bytes = rkyv::to_bytes::<RkyvError>(&wire).map_err(|e| {
            cc_lb_plugin_api::ObservabilityError::Dropped {
                reason: format!("rkyv encode ObserveEvent: {e}"),
            }
        })?;
        let wire_version = self
            .cell
            .metadata
            .hooks
            .get(HookKind::Observe.as_str())
            .and_then(|m| WireVersion::from_u8(m.wire_version));
        match wire_version {
            Some(WireVersion::V1) => {
                crate::cache::call_observe_hook(&self.cell, in_bytes.as_slice()).map_err(
                    |e| cc_lb_plugin_api::ObservabilityError::Dropped {
                        reason: e.to_string(),
                    },
                )?;
            }
            // The hook entry and its version byte come from the loaded plugin's
            // metadata. A missing entry or unknown version drops the event instead of
            // panicking.
            None => {
                return Err(cc_lb_plugin_api::ObservabilityError::Dropped {
                    reason: format!(
                        "plugin declares no supported `{}` hook wire version",
                        HookKind::Observe.as_str()
                    ),
                });
            }
        }
        Ok(())
    }
}
/// Converts a host observability event into its wire representation.
///
/// Strings become boxed strings. The principal kind is sent as its serde string form,
/// or `"unknown"` when it has none. Chunk counts are widened with `as u64`, and the
/// finish status is sent as its numeric code.
fn host_observe_event_to_wire(
    event: cc_lb_plugin_api::ObserveEvent,
) -> cc_lb_plugin_wire::ObserveEvent {
    use cc_lb_plugin_api::ObserveEvent as Host;
    use cc_lb_plugin_wire::ObserveEvent as Wire;
    match event {
        Host::RequestStarted { request_id, downstream_user_agent } => Wire::RequestStarted {
            request_id: request_id.into_boxed_str(),
            downstream_user_agent: downstream_user_agent.map(String::into_boxed_str),
        },
        Host::AuthnComplete { principal_id, kind } => {
            let principal_kind = match serde_json::to_value(&kind) {
                Ok(serde_json::Value::String(label)) => label,
                _ => "unknown".to_owned(),
            };
            Wire::AuthnComplete {
                principal_id: principal_id.into_boxed_str(),
                principal_kind: principal_kind.into_boxed_str(),
            }
        }
        Host::UpstreamChosen { upstream } => {
            let upstream = host_upstream_to_wire(&upstream);
            Wire::UpstreamChosen { upstream }
        }
        Host::Chunk { batch_index, event_count, total_bytes } => Wire::Chunk {
            batch_index,
            event_count: event_count as u64,
            total_bytes: total_bytes as u64,
        },
        Host::RequestFinished {
            status,
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
            duration_ms,
        } => {
            let status = status.as_u16();
            Wire::RequestFinished {
                status,
                input_tokens,
                output_tokens,
                cache_creation_input_tokens,
                cache_read_input_tokens,
                duration_ms,
            }
        }
        Host::Error { code, message, source } => Wire::Error {
            code: code.into_boxed_str(),
            message: message.into_boxed_str(),
            source: source.into_boxed_str(),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cc_lb_plugin_api::PrincipalKind;
    use cc_lb_plugin_wire::FilterResponse as WireFilterResponse;
    use cc_lb_plugin_wire::PerCandidateReason as WirePerCandidateReason;
    /// API-key principal with a single `scope` claim.
    fn fixture_principal() -> Principal {
        let mut claims = serde_json::Map::new();
        claims.insert("scope".to_owned(), serde_json::Value::from("inference"));
        Principal {
            id: "tenant-a".to_owned(),
            kind: PrincipalKind::ApiKey,
            claims,
        }
    }
    /// `POST /v1/messages` carrying a content-type header and a bearer token.
    fn fixture_request() -> RequestContext {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("application/json"),
        );
        headers.insert(
            http::header::AUTHORIZATION,
            http::HeaderValue::from_static("Bearer secret"),
        );
        RequestContext {
            request_id: "req-123".to_owned(),
            downstream_headers: headers,
            method: http::Method::POST,
            path: "/v1/messages".to_owned(),
            query: None,
            body_bytes: bytes::Bytes::from_static(b"{\"msg\":\"hi\"}"),
            cache_breakpoints: Vec::new(),
            canonical_model_id: "claude-fixture".to_owned(),
        }
    }
    #[test]
    fn host_to_wire_strips_auth_headers() {
        let principal = fixture_principal();
        let ctx = fixture_request();
        let bytes = host_to_wire_request(&ctx, &principal, &[], false).expect("encode");
        let archived = rkyv::access::<cc_lb_plugin_wire::ArchivedFilterRequest, RkyvError>(&bytes)
            .expect("archived");
        let request_id: &str = &archived.request_id;
        let method: &str = &archived.method;
        let path: &str = &archived.path;
        let principal_id: &str = &archived.principal.id;
        let principal_kind: &str = &archived.principal.kind;
        assert_eq!(request_id, "req-123");
        assert_eq!(method, "POST");
        assert_eq!(path, "/v1/messages");
        assert_eq!(
            archived.headers.len(),
            1,
            "authorization must be filtered out"
        );
        let header_name: &str = &archived.headers[0].name;
        assert_eq!(header_name, "content-type");
        assert_eq!(principal_id, "tenant-a");
        assert_eq!(principal_kind, "api_key");
        assert!(archived.principal.claims.iter().any(|entry| {
            let key: &str = &entry.key;
            key == "scope"
        }));
    }
    #[test]
    fn host_to_wire_redacts_cookie_only_when_enabled() {
        let principal = fixture_principal();
        let mut ctx = fixture_request();
        ctx.downstream_headers.insert(
            http::header::COOKIE,
            http::HeaderValue::from_static("session=abc"),
        );
        let redacted = host_to_wire_request(&ctx, &principal, &[], true).expect("encode");
        let archived =
            rkyv::access::<cc_lb_plugin_wire::ArchivedFilterRequest, RkyvError>(&redacted)
                .expect("archived");
        assert!(
            archived.headers.iter().all(|h| {
                let name: &str = &h.name;
                name != "cookie"
            }),
            "cookie must be removed when redaction is on"
        );
        let kept = host_to_wire_request(&ctx, &principal, &[], false).expect("encode");
        let archived = rkyv::access::<cc_lb_plugin_wire::ArchivedFilterRequest, RkyvError>(&kept)
            .expect("archived");
        assert!(
            archived.headers.iter().any(|h| {
                let name: &str = &h.name;
                name == "cookie"
            }),
            "cookie must pass through when redaction is off"
        );
    }
    #[test]
    fn wire_to_host_splits_kept_and_rejected() {
        let response = WireFilterResponse {
            results: Box::new([
                WirePerCandidateReason {
                    upstream_id: Box::from("11111111-1111-1111-1111-111111111111"),
                    decision: Box::from("accept"),
                    reason: Box::from("top-K"),
                },
                WirePerCandidateReason {
                    upstream_id: Box::from("22222222-2222-2222-2222-222222222222"),
                    decision: Box::from("rate-limit"),
                    reason: Box::from("burst exceeded"),
                },
            ]),
        };
        let bytes = rkyv::to_bytes::<RkyvError>(&response).expect("encode");
        let mut aligned = AlignedVec::<16>::with_capacity(bytes.len());
        aligned.extend_from_slice(&bytes);
        let archived =
            rkyv::access::<ArchivedFilterResponse, RkyvError>(&aligned).expect("archived view");
        let out = wire_to_host_output(archived, 256).expect("conversion must succeed");
        assert_eq!(out.kept_upstream_ids.len(), 1);
        assert_eq!(
            out.kept_upstream_ids[0],
            Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()
        );
        assert_eq!(out.per_candidate_reasons.len(), 1);
        assert_eq!(
            out.per_candidate_reasons[0],
            PerCandidateReason::RateLimited
        );
    }
    #[test]
    fn wire_to_host_truncates_and_joins_reasons() {
        let response = WireFilterResponse {
            results: Box::new([
                WirePerCandidateReason {
                    upstream_id: Box::from("11111111-1111-1111-1111-111111111111"),
                    decision: Box::from("accept"),
                    reason: Box::from("top-K"),
                },
                WirePerCandidateReason {
                    upstream_id: Box::from("22222222-2222-2222-2222-222222222222"),
                    decision: Box::from("reject"),
                    reason: Box::from("burst exceeded"),
                },
                WirePerCandidateReason {
                    upstream_id: Box::from("33333333-3333-3333-3333-333333333333"),
                    decision: Box::from("accept"),
                    reason: Box::from(""),
                },
            ]),
        };
        let bytes = rkyv::to_bytes::<RkyvError>(&response).expect("encode");
        let archived =
            rkyv::access::<ArchivedFilterResponse, RkyvError>(&bytes).expect("archived view");
        let out = wire_to_host_output(archived, 3).expect("conversion must succeed");
        assert_eq!(out.kept_upstream_ids.len(), 2);
        assert_eq!(
            out.per_candidate_reasons,
            vec![PerCandidateReason::RejectedByPlugin]
        );
        assert_eq!(
            out.reason,
            "11111111-1111-1111-1111-111111111111: top; 22222222-2222-2222-2222-222222222222: bur"
        );
    }
    #[test]
    fn wire_to_host_rejects_invalid_upstream_id() {
        let response = WireFilterResponse {
            results: Box::new([WirePerCandidateReason {
                upstream_id: Box::from("not-a-uuid"),
                decision: Box::from("accept"),
                reason: Box::from(""),
            }]),
        };
        let bytes = rkyv::to_bytes::<RkyvError>(&response).expect("encode");
        let archived =
            rkyv::access::<ArchivedFilterResponse, RkyvError>(&bytes).expect("archived view");
        assert!(wire_to_host_output(archived, 256).is_err());
    }
    #[test]
    fn truncate_reason_respects_char_boundaries() {
        assert_eq!(truncate_reason("short", 16), "short");
        assert_eq!(truncate_reason("héllo", 2), "h");
        assert_eq!(truncate_reason("héllo", 3), "hé");
        assert_eq!(truncate_reason("abc", 0), "");
    }
    #[test]
    fn header_predicates_are_case_insensitive() {
        assert!(is_stripped_downstream_header("Authorization", false));
        assert!(!is_stripped_downstream_header("Cookie", false));
        assert!(is_stripped_downstream_header("Cookie", true));
        assert!(is_stripped_shape_output_header("Transfer-Encoding"));
        assert!(is_stripped_shape_output_header("X-Anthropic-Beta"));
        assert!(!is_stripped_shape_output_header("content-type"));
    }
    #[test]
    fn per_candidate_reason_label_mapping() {
        assert_eq!(
            per_candidate_reason_from_label("rate-limit", ""),
            PerCandidateReason::RateLimited
        );
        assert_eq!(
            per_candidate_reason_from_label("Quota_Exceeded", ""),
            PerCandidateReason::InsufficientQuota
        );
        assert_eq!(
            per_candidate_reason_from_label("unhealthy", ""),
            PerCandidateReason::Unhealthy
        );
        assert_eq!(
            per_candidate_reason_from_label("reject", "rate-limit"),
            PerCandidateReason::RejectedByPlugin
        );
    }
}