use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use http::{HeaderMap, Method, StatusCode};
use serde::{Deserialize, Serialize};
use url::Url;
use uuid::Uuid;
use crate::errors::{DialectError, SignerError};
use crate::traits::{Signer, UpstreamDialect};
/// Caller identity that accompanies a request into shaping
/// (it is one of the inputs of [`shape_request`]).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Principal {
/// Identifier of the principal. Its format and uniqueness scope are not
/// defined in this module.
pub id: String,
/// Category of the principal; see [`PrincipalKind`].
pub kind: PrincipalKind,
/// Free-form JSON claims attached to the principal. The expected keys are not
/// defined in this module.
pub claims: serde_json::Map<String, serde_json::Value>,
}
/// Category of a [`Principal`].
///
/// Serialized with serde's `snake_case` rule, which inserts `_` before every
/// uppercase letter after the first. `OAuthSubject` therefore goes on the wire
/// as `"o_auth_subject"`, not `"oauth_subject"`. Renaming it would be a wire
/// format break.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrincipalKind {
ApiKey,
OAuthSubject,
InternalKey,
WorkloadIdentity,
SubscriptionBearer,
}
/// Extension point that a plugin can occupy.
///
/// The serde representation and [`PluginSlot::as_str`] use the same snake_case
/// names: `"router"`, `"observability_hook"` and `"shape"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PluginSlot {
    Router,
    ObservabilityHook,
    Shape,
}

impl PluginSlot {
    /// Returns the canonical snake_case name of this slot.
    pub fn as_str(self) -> &'static str {
        match self {
            PluginSlot::Router => "router",
            PluginSlot::ObservabilityHook => "observability_hook",
            PluginSlot::Shape => "shape",
        }
    }

    /// Parses a canonical slot name; this is the inverse of [`PluginSlot::as_str`].
    ///
    /// Matching is exact and case-sensitive. Unknown names yield `None`.
    pub fn parse(value: &str) -> Option<Self> {
        // Look the name up through `as_str` so the two mappings cannot drift apart.
        [PluginSlot::Router, PluginSlot::ObservabilityHook, PluginSlot::Shape]
            .iter()
            .copied()
            .find(|slot| slot.as_str() == value)
    }
}
/// Upstream target that a request is routed to (carried by [`RouteDecision`]).
///
/// The enum is internally tagged. The snake_case variant name is written to a
/// `"kind"` field, e.g. `{"kind":"anthropic_direct"}`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum Upstream {
AnthropicDirect {
/// Optional base URL. It is omitted from serialized output when `None`, and
/// an absent field deserializes as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
base_url: Option<Url>,
},
}
/// Kind of upstream that an [`UpstreamCandidate`] refers to.
///
/// Serialized in snake_case: `"anthropic_api_key"` and `"anthropic_oauth"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UpstreamKind {
    AnthropicApiKey,
    AnthropicOauth,
}

impl UpstreamKind {
    /// Returns the snake_case name of this kind, identical to its serde form.
    pub fn as_str(self) -> &'static str {
        match self {
            UpstreamKind::AnthropicApiKey => "anthropic_api_key",
            UpstreamKind::AnthropicOauth => "anthropic_oauth",
        }
    }
}
/// Quantity that a [`RateLimitObservation`] limits.
///
/// Serialized in snake_case. The strings match [`RateLimitKind::as_str`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitKind {
    Requests,
    Tokens,
    InputTokens,
    OutputTokens,
}

impl RateLimitKind {
    /// Returns the snake_case name of this kind, identical to its serde form.
    pub fn as_str(self) -> &'static str {
        match self {
            RateLimitKind::Requests => "requests",
            RateLimitKind::Tokens => "tokens",
            RateLimitKind::InputTokens => "input_tokens",
            RateLimitKind::OutputTokens => "output_tokens",
        }
    }
}
/// One rate-limit reading for an upstream, as carried in
/// [`UpstreamCandidate::observed_rate_limits`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RateLimitObservation {
/// Quantity this reading refers to.
pub kind: RateLimitKind,
/// Window label. Its format is not defined in this module.
pub window: String,
/// Reported limit for the window, if known.
pub limit: Option<u64>,
/// Reported remaining amount for the window, if known.
pub remaining: Option<u64>,
/// Reset time as an opaque string.
/// NOTE(review): the format (timestamp vs. relative duration) is not
/// established here; confirm with the producer.
pub reset: Option<String>,
}
/// Freshness of the data in a [`SubscriptionQuotaCandidateSnapshot`]
/// (its `state` field).
///
/// NOTE(review): the staleness threshold is presumably the snapshot's
/// `max_staleness_secs`; the classification code is not in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionQuotaDataState {
Fresh,
Stale,
Missing,
}
/// Per-window subscription quota snapshot attached to an [`UpstreamCandidate`].
///
/// Timestamps use mixed units: `resets_at_unix_secs` is in seconds, while
/// `observed_at_unix_millis` is in milliseconds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SubscriptionQuotaCandidateSnapshot {
/// Window label. Its format is not defined in this module.
pub window: String,
/// Freshness of this snapshot's data.
pub state: SubscriptionQuotaDataState,
/// Where the data came from, if recorded.
pub source: Option<String>,
/// Utilization figure. NOTE(review): the range (0..=1 vs. percent) is not
/// established here.
pub utilization: Option<f64>,
pub status: Option<String>,
/// Reset time in Unix seconds.
pub resets_at_unix_secs: Option<u64>,
pub surpassed_threshold: Option<bool>,
pub representative_claim: Option<String>,
pub disabled_reason: Option<String>,
pub extra_usage_enabled: Option<bool>,
pub extra_usage_monthly_limit: Option<f64>,
pub extra_usage_used_credits: Option<f64>,
/// Observation time in Unix milliseconds.
pub observed_at_unix_millis: Option<u64>,
/// Staleness budget in seconds. This field is required; it is not optional.
pub max_staleness_secs: u64,
}
/// Prompt-cache time-to-live class. The default is `Ephemeral5m`.
///
/// serde's snake_case rule inserts no `_` before digits, so the wire names are
/// `"ephemeral5m"` and `"ephemeral1h"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
// NOTE(review): no reason is given for allowing dead_code; presumably not every
// variant is constructed yet. Confirm and remove the allow when it is unnecessary.
#[allow(dead_code)]
pub enum TtlClass {
#[default]
Ephemeral5m,
Ephemeral1h,
}
/// How a [`CacheBreakpoint`] came to exist. Judging by the variant names, it
/// was either requested explicitly or inferred automatically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
// NOTE(review): no reason is given for allowing dead_code; confirm it is still needed.
#[allow(dead_code)]
pub enum BreakpointOrigin {
Explicit,
AutoCacheInferred,
}
/// Section of the request in which a [`CacheBreakpoint`] sits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
// NOTE(review): no reason is given for allowing dead_code; confirm it is still needed.
#[allow(dead_code)]
pub enum CacheBreakpointSource {
Tools,
System,
Message,
}
/// A prompt-cache breakpoint found in a request
/// (see [`RequestContext::cache_breakpoints`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// NOTE(review): no reason is given for allowing dead_code; confirm it is still needed.
#[allow(dead_code)]
pub struct CacheBreakpoint {
pub block_index: u32,
/// Request section containing the breakpoint.
pub source: CacheBreakpointSource,
/// Dotted location string, e.g. `"messages.0.content.1"` as used in this
/// module's tests.
pub path: String,
/// Message index. Presumably only set when `source` is `Message`; confirm.
pub message_index: Option<u32>,
/// Hash of the prompt prefix up to this breakpoint. The algorithm is not
/// defined here.
pub prefix_hash: String,
pub prefix_token_count: u64,
/// TTL class requested for this breakpoint.
pub requested_ttl: TtlClass,
pub origin: BreakpointOrigin,
}
/// A cached prefix believed to be warm, keyed by its prefix hash.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// NOTE(review): no reason is given for allowing dead_code; confirm it is still needed.
#[allow(dead_code)]
pub struct WarmCacheEntry {
pub prefix_hash: String,
/// Expiry time in Unix seconds.
pub expires_at_unix_secs: u64,
pub ttl_class: TtlClass,
/// Time the entry was last observed, in Unix seconds.
pub last_observed_at_unix_secs: u64,
}
/// Predicted prompt-cache outcome for routing a request to a candidate
/// (see [`UpstreamCandidate::cache_score`]).
///
/// The struct is not `Eq` because `confidence` is an `f32`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// NOTE(review): no reason is given for allowing dead_code; confirm it is still needed.
#[allow(dead_code)]
pub struct CacheScore {
pub predicted_cache_read_tokens: u32,
pub predicted_cache_creation_tokens_5m: u32,
pub predicted_cache_creation_tokens_1h: u32,
pub predicted_uncached_input_tokens: u32,
/// Predicted expiry in Unix seconds, if any.
pub predicted_expires_at_unix_secs: Option<u64>,
/// Index of the matched breakpoint, if any.
pub matched_breakpoint_index: Option<u32>,
/// Confidence of the prediction. The range is not enforced here; presumably
/// 0.0..=1.0, so confirm with the scorer.
pub confidence: f32,
pub ambiguity_reason: Option<String>,
}
/// An upstream considered during routing, with its latest observed state.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UpstreamCandidate {
pub upstream_id: Uuid,
pub name: String,
pub kind: UpstreamKind,
/// Required on input (no serde default).
pub observed_rate_limits: Vec<RateLimitObservation>,
/// Defaults to empty when absent on input.
#[serde(default)]
pub subscription_quotas: Vec<SubscriptionQuotaCandidateSnapshot>,
/// Observation time in Unix seconds.
pub observed_at_unix_secs: u64,
/// Omitted from output when `None`; absent input deserializes as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_score: Option<CacheScore>,
/// Omitted from output when `None`; absent input deserializes as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub base_url: Option<String>,
}
/// Strategy for supplying upstream credentials.
///
/// serde's `snake_case` rule inserts `_` before every uppercase letter after
/// the first. `OAuth` therefore serializes as `"o_auth"`, and renaming it
/// would be a wire format break.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CredentialStrategy {
ApiKey,
OAuth,
InternalForwarded,
}
/// The downstream request as seen by dialects during [`shape_request`].
///
/// This type is in-memory only: it derives neither `Serialize` nor
/// `Deserialize`.
#[derive(Clone, Debug)]
pub struct RequestContext {
pub request_id: String,
/// Headers received from the downstream client.
pub downstream_headers: HeaderMap,
pub method: Method,
pub path: String,
/// Raw query string without the leading `?`.
/// NOTE(review): the absence of `?` is assumed from the tests; confirm.
pub query: Option<String>,
pub body_bytes: Bytes,
/// Cache breakpoints found in the request; may be empty.
pub cache_breakpoints: Vec<CacheBreakpoint>,
/// Canonical model identifier; may be empty when unknown.
pub canonical_model_id: String,
}
/// An outbound request produced by an [`UpstreamDialect`], before signing.
///
/// The private `_seal` field prevents construction with a struct literal
/// outside this module. [`ShapedRequestBuilder::shaped_request`] is the
/// constructor. Once built, every part can be read and replaced.
#[derive(Clone, Debug)]
pub struct ShapedRequest {
    url: Url,
    method: Method,
    headers: HeaderMap,
    body: Bytes,
    _seal: crate::private::Seal,
}

impl ShapedRequest {
    // ---- read access -------------------------------------------------

    /// Target URL.
    pub fn url(&self) -> &Url {
        &self.url
    }

    /// HTTP method.
    pub fn method(&self) -> &Method {
        &self.method
    }

    /// Outbound headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Outbound body.
    pub fn body(&self) -> &Bytes {
        &self.body
    }

    // ---- write access ------------------------------------------------

    /// Mutable access to the headers, for in-place edits.
    pub fn headers_mut(&mut self) -> &mut HeaderMap {
        &mut self.headers
    }

    /// Replaces the target URL.
    pub fn set_url(&mut self, url: Url) {
        self.url = url;
    }

    /// Replaces the HTTP method.
    pub fn set_method(&mut self, method: Method) {
        self.method = method;
    }

    /// Replaces the body.
    pub fn set_body(&mut self, body: Bytes) {
        self.body = body;
    }
}
/// Handle given to [`UpstreamDialect::shape`] so a dialect can construct a
/// [`ShapedRequest`].
///
/// Its only field is private, so the handle can only be created inside this
/// module, as [`shape_request`] does.
#[derive(Debug)]
pub struct ShapedRequestBuilder {
    _seal: crate::private::Seal,
}

impl ShapedRequestBuilder {
    /// Assembles a [`ShapedRequest`] from its parts and stamps the seal.
    ///
    /// No builder state is read. Taking `&mut self` just means only a holder
    /// of the builder can call it.
    pub fn shaped_request(
        &mut self,
        url: Url,
        method: Method,
        headers: HeaderMap,
        body: Bytes,
    ) -> ShapedRequest {
        let seal = crate::private::Seal;
        ShapedRequest {
            _seal: seal,
            url,
            method,
            headers,
            body,
        }
    }
}
/// Runs `dialect` to turn the downstream request `ctx` into a
/// [`ShapedRequest`] for `upstream`, on behalf of `principal`.
///
/// This is where a [`ShapedRequestBuilder`] is created. The dialect only
/// receives it by mutable reference.
///
/// # Errors
/// Returns whatever [`DialectError`] the dialect's `shape` reports.
pub fn shape_request(
    dialect: &dyn UpstreamDialect,
    ctx: &RequestContext,
    upstream: &Upstream,
    principal: &Principal,
) -> Result<ShapedRequest, DialectError> {
    dialect.shape(
        ctx,
        upstream,
        principal,
        &mut ShapedRequestBuilder {
            _seal: crate::private::Seal,
        },
    )
}
#[derive(Clone, Debug)]
pub struct SignedRequest {
url: Url,
method: Method,
headers: HeaderMap,
body: Bytes,
_seal: crate::private::Seal,
}
impl SignedRequest {
pub fn from_shaped(shaped: ShapedRequest, _capability: &mut SigningCapability) -> Self {
Self {
url: shaped.url,
method: shaped.method,
headers: shaped.headers,
body: shaped.body,
_seal: crate::private::Seal,
}
}
pub fn url(&self) -> &Url {
&self.url
}
pub fn method(&self) -> &Method {
&self.method
}
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
pub fn body(&self) -> &Bytes {
&self.body
}
pub fn into_parts(self) -> (Url, Method, HeaderMap, Bytes) {
(self.url, self.method, self.headers, self.body)
}
}
/// Proof-of-authority token required by [`SignedRequest::from_shaped`].
///
/// Its only field is private, so it can only be constructed inside this module.
/// [`sign_request`] creates one and lends it to the signer.
#[derive(Debug)]
pub struct SigningCapability {
_seal: crate::private::Seal,
}
/// Hands `shaped` to `signer`, together with a newly created
/// [`SigningCapability`], and returns the [`SignedRequest`] the signer produces.
///
/// # Errors
/// Propagates any [`SignerError`] reported by the signer.
pub async fn sign_request(
    signer: &dyn Signer,
    shaped: ShapedRequest,
) -> Result<SignedRequest, SignerError> {
    let mut proof = SigningCapability {
        _seal: crate::private::Seal,
    };
    let pending = signer.sign(shaped, &mut proof);
    pending.await
}
/// Routing outcome: which upstream to use and the dialect that shapes
/// requests for it (the dialect is what [`shape_request`] consumes).
///
/// Not `Debug` or `Clone`-derived: it holds an `Arc<dyn UpstreamDialect>`.
pub struct RouteDecision {
/// Identifier of the chosen upstream. NOTE(review): the meaning of `None`
/// (e.g. a configured but un-registered upstream) is not established here.
pub upstream_id: Option<Uuid>,
pub upstream: Upstream,
pub dialect: Arc<dyn UpstreamDialect>,
}
/// Per-principal usage limits, each measured over `window`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PrincipalQuotas {
pub requests_per_window: u64,
pub input_tokens_per_window: u64,
pub output_tokens_per_window: u64,
/// Length of the quota window.
pub window: Duration,
/// Model identifiers the principal may use.
/// NOTE(review): whether an empty list means "none" or "all" is not
/// established here.
pub allowed_models: Vec<String>,
}
/// Lifecycle events emitted for observation, presumably delivered to
/// observability hooks (see [`PluginSlot::ObservabilityHook`]); confirm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ObserveEvent {
/// A request was received.
RequestStarted {
request_id: String,
downstream_user_agent: Option<String>,
},
/// Authentication finished for the given principal.
AuthnComplete {
principal_id: String,
kind: PrincipalKind,
},
/// Routing selected an upstream.
UpstreamChosen {
upstream: Upstream,
},
/// A batch of response chunks; `batch_index` numbers the batches.
Chunk {
batch_index: u64,
event_count: usize,
total_bytes: usize,
},
/// The request completed. Token counts are `None` when unknown.
RequestFinished {
status: StatusCode,
input_tokens: Option<u64>,
output_tokens: Option<u64>,
cache_creation_input_tokens: Option<u64>,
cache_read_input_tokens: Option<u64>,
duration_ms: u64,
},
/// An error occurred. `source` is a string label, not a `std::error::Error` source.
Error {
code: String,
message: String,
source: String,
},
}
/// Whether to retry with a replacement signer or give up.
/// NOTE(review): presumably produced after a signing or auth failure; confirm
/// with the caller.
#[derive(Clone)]
pub enum RetryDecision {
/// Retry using `new_signer`.
Refresh {
new_signer: Arc<dyn Signer>,
},
/// Do not retry.
Fail,
}
/// Declarative description of a plugin to load.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PluginManifest {
pub name: String,
/// Artifact reference, e.g. `"plugin.wasm"` in this module's tests.
pub artifact: String,
/// Wire protocol version. It is omitted from output when `None`, and an
/// absent field deserializes as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub wire_version: Option<u8>,
/// Plugin-specific configuration. Required on input.
pub config: serde_json::Value,
/// Extra metadata. Defaults to empty when absent.
#[serde(default)]
pub metadata: BTreeMap<String, serde_json::Value>,
}
/// Upper bound on the number of stages in a [`RoutingTrace`].
/// Nothing in this module enforces it; presumably it is checked where traces
/// are built.
pub const MAX_ROUTING_TRACE_STAGES: usize = 100;
/// Upper bound on [`StageDecision::stage_name`] length.
/// NOTE(review): the unit (bytes vs. chars) is not established here.
pub const MAX_STAGE_NAME_LEN: usize = 256;
/// Upper bound on error message length, presumably for
/// [`InternalError::message`]. The unit (bytes vs. chars) is not
/// established here.
pub const MAX_ERROR_MESSAGE_LEN: usize = 1024;
/// Reason a request was passed through, as the variant names suggest.
/// It is not used elsewhere in this module.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PassthroughCause {
HealthyUpstream,
NoAlternative,
PluginDecision,
}
/// Reason attached to an individual upstream candidate, presumably when it was
/// excluded. It is not used elsewhere in this module.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PerCandidateReason {
RateLimited,
InsufficientQuota,
Unhealthy,
RejectedByPlugin,
}
/// Strategy used for the final upstream pick. The default is `FirstPick`.
///
/// Unlike the other enums in this module, this one serializes in kebab-case:
/// `"first-pick"`, `"random"`, `"round-robin"`, `"least-connections"`.
/// Changing the rule would be a wire format break.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TerminalStrategy {
#[default]
FirstPick,
Random,
RoundRobin,
LeastConnections,
}
/// Outcome of one stage in a [`RoutingTrace`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StageDecision {
pub stage_name: String,
/// Upstream chosen by this stage, if any; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub upstream_id: Option<Uuid>,
/// Optional human-readable reason; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
/// Stage duration in microseconds. It is omitted from output when 0 and
/// defaults to 0 when absent, so a measured 0 is indistinguishable from
/// "not recorded".
#[serde(default, skip_serializing_if = "is_zero_u64")]
pub duration_us: u64,
}
/// serde `skip_serializing_if` predicate that reports whether `value` is zero.
///
/// The signature takes `&u64` because serde passes the field by reference.
fn is_zero_u64(value: &u64) -> bool {
    matches!(*value, 0)
}
/// Final upstream pick at the end of a [`RoutingTrace`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TerminalDecision {
/// Chosen upstream, if any; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub upstream_id: Option<Uuid>,
/// Strategy that made the pick. Required on input.
pub strategy: TerminalStrategy,
}
/// Record of routing: per-stage decisions followed by an optional terminal pick.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct RoutingTrace {
/// Stage decisions; defaults to empty when absent. The type does not itself
/// cap the length at [`MAX_ROUTING_TRACE_STAGES`].
#[serde(default)]
pub stages: Vec<StageDecision>,
/// Terminal pick; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub terminal_decision: Option<TerminalDecision>,
}
/// Pipeline stage at which an [`InternalError`] occurred. The default is `Router`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorStage {
Authn,
#[default]
Router,
RouterFilter,
Shape,
Signer,
Relay,
}
/// Category of an [`InternalError`]. The default is `PluginError`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorKind {
#[default]
PluginError,
InvalidOutput,
Trap,
ConfigError,
Timeout,
Unavailable,
}
/// Internal failure tagged with where it happened and what kind it was.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InternalError {
pub stage: InternalErrorStage,
pub kind: InternalErrorKind,
/// Optional detail; omitted from output when `None`. The type does not
/// itself enforce [`MAX_ERROR_MESSAGE_LEN`].
#[serde(default, skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
// Unit tests covering the request accessors, serde wire formats and the
// public enum inventory of this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shaped_request_accessors_and_mutators_preserve_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("x-test", "one".parse().unwrap());
        let mut shaped = builder.shaped_request(
            "https://example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"first"),
        );
        assert_eq!(shaped.url().as_str(), "https://example.test/v1/messages");
        assert_eq!(shaped.method(), Method::POST);
        assert_eq!(shaped.headers()["x-test"], "one");
        assert_eq!(shaped.body(), &Bytes::from_static(b"first"));
        shaped.set_url("https://example.test/v1/complete".parse().unwrap());
        shaped.set_method(Method::PUT);
        shaped
            .headers_mut()
            .insert("x-test", "two".parse().unwrap());
        shaped.set_body(Bytes::from_static(b"second"));
        assert_eq!(shaped.url().path(), "/v1/complete");
        assert_eq!(shaped.method(), Method::PUT);
        assert_eq!(shaped.headers()["x-test"], "two");
        assert_eq!(shaped.body(), &Bytes::from_static(b"second"));
    }

    #[test]
    fn signed_request_exposes_and_consumes_signed_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer token".parse().unwrap());
        let shaped = builder.shaped_request(
            "https://api.example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"{}"),
        );
        let mut capability = SigningCapability {
            _seal: crate::private::Seal,
        };
        let signed = SignedRequest::from_shaped(shaped, &mut capability);
        assert_eq!(signed.url().host_str(), Some("api.example.test"));
        assert_eq!(signed.method(), Method::POST);
        assert_eq!(signed.headers()["authorization"], "Bearer token");
        assert_eq!(signed.body(), &Bytes::from_static(b"{}"));
        let (url, method, headers, body) = signed.into_parts();
        assert_eq!(url.as_str(), "https://api.example.test/v1/messages");
        assert_eq!(method, Method::POST);
        assert_eq!(headers["authorization"], "Bearer token");
        assert_eq!(body, Bytes::from_static(b"{}"));
    }

    #[test]
    fn upstream_and_manifest_serde_round_trip() {
        let upstreams = vec![Upstream::AnthropicDirect { base_url: None }];
        for upstream in upstreams {
            let json = serde_json::to_string(&upstream).unwrap();
            let decoded: Upstream = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, upstream);
        }
        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "authn",
            "artifact": "plugin.wasm",
            "config": {"enabled": true}
        }))
        .unwrap();
        assert_eq!(manifest.name, "authn");
        assert_eq!(manifest.wire_version, None);
        assert!(manifest.metadata.is_empty());
        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "cache-aware",
            "artifact": "plugin.wasm",
            "wire_version": 2,
            "config": {}
        }))
        .unwrap();
        assert_eq!(manifest.wire_version, Some(2));
    }

    #[test]
    fn public_enums_cover_all_current_variants() {
        // These matches are exhaustive. Adding a variant makes this test fail
        // to compile until it (and any wire-format expectations) is revisited.
        // The previous array-length asserts could never fail.
        fn principal_kind_ordinal(kind: &PrincipalKind) -> usize {
            match kind {
                PrincipalKind::ApiKey => 0,
                PrincipalKind::OAuthSubject => 1,
                PrincipalKind::InternalKey => 2,
                PrincipalKind::WorkloadIdentity => 3,
                PrincipalKind::SubscriptionBearer => 4,
            }
        }
        fn credential_strategy_ordinal(strategy: &CredentialStrategy) -> usize {
            match strategy {
                CredentialStrategy::ApiKey => 0,
                CredentialStrategy::OAuth => 1,
                CredentialStrategy::InternalForwarded => 2,
            }
        }
        let principal_kinds = [
            PrincipalKind::ApiKey,
            PrincipalKind::OAuthSubject,
            PrincipalKind::InternalKey,
            PrincipalKind::WorkloadIdentity,
            PrincipalKind::SubscriptionBearer,
        ];
        let ordinals: Vec<usize> = principal_kinds.iter().map(principal_kind_ordinal).collect();
        assert_eq!(ordinals, (0..5).collect::<Vec<_>>());
        let strategies = [
            CredentialStrategy::ApiKey,
            CredentialStrategy::OAuth,
            CredentialStrategy::InternalForwarded,
        ];
        let ordinals: Vec<usize> = strategies.iter().map(credential_strategy_ordinal).collect();
        assert_eq!(ordinals, (0..3).collect::<Vec<_>>());
    }

    #[test]
    fn oauth_variants_use_serde_snake_case_wire_names() {
        // serde's snake_case rule inserts `_` before every uppercase letter
        // after the first, so the `OAuth` prefix becomes `o_auth`. The current
        // wire names are pinned here so any change is a deliberate format break.
        assert_eq!(
            serde_json::to_string(&CredentialStrategy::OAuth).unwrap(),
            "\"o_auth\""
        );
        assert_eq!(
            serde_json::to_string(&PrincipalKind::OAuthSubject).unwrap(),
            "\"o_auth_subject\""
        );
    }

    #[test]
    fn upstream_kind_as_str_matches_serde_wire_name() {
        for kind in [UpstreamKind::AnthropicApiKey, UpstreamKind::AnthropicOauth] {
            let wire = serde_json::Value::String(kind.as_str().to_owned());
            assert_eq!(serde_json::to_value(kind).unwrap(), wire);
            let decoded: UpstreamKind = serde_json::from_value(wire).unwrap();
            assert_eq!(decoded, kind);
        }
    }

    #[test]
    fn observe_event_variants_are_equatable() {
        let events = vec![
            ObserveEvent::RequestStarted {
                request_id: "req".to_owned(),
                downstream_user_agent: Some("ua".to_owned()),
            },
            ObserveEvent::AuthnComplete {
                principal_id: "principal".to_owned(),
                kind: PrincipalKind::InternalKey,
            },
            ObserveEvent::UpstreamChosen {
                upstream: Upstream::AnthropicDirect { base_url: None },
            },
            ObserveEvent::Chunk {
                batch_index: 1,
                event_count: 2,
                total_bytes: 3,
            },
            ObserveEvent::RequestFinished {
                status: StatusCode::OK,
                input_tokens: Some(4),
                output_tokens: Some(5),
                cache_creation_input_tokens: Some(6),
                cache_read_input_tokens: Some(7),
                duration_ms: 8,
            },
            ObserveEvent::Error {
                code: "E".to_owned(),
                message: "redacted".to_owned(),
                source: "plugin".to_owned(),
            },
        ];
        assert_eq!(events, events.clone());
    }

    #[test]
    fn cache_types_roundtrip() {
        let ttl_class = TtlClass::Ephemeral1h;
        let json = serde_json::to_string(&ttl_class).unwrap();
        let decoded: TtlClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, ttl_class);
        let origin = BreakpointOrigin::AutoCacheInferred;
        let json = serde_json::to_string(&origin).unwrap();
        let decoded: BreakpointOrigin = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, origin);
        let source = CacheBreakpointSource::System;
        let json = serde_json::to_string(&source).unwrap();
        let decoded: CacheBreakpointSource = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, source);
        let breakpoint = CacheBreakpoint {
            block_index: 0,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.1".to_owned(),
            message_index: Some(0),
            prefix_hash: "abc123".to_owned(),
            prefix_token_count: 100,
            requested_ttl: TtlClass::Ephemeral5m,
            origin: BreakpointOrigin::Explicit,
        };
        let json = serde_json::to_string(&breakpoint).unwrap();
        let decoded: CacheBreakpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, breakpoint);
        let warm_entry = WarmCacheEntry {
            prefix_hash: "def456".to_owned(),
            expires_at_unix_secs: 1700000000,
            ttl_class: TtlClass::Ephemeral1h,
            last_observed_at_unix_secs: 1699999000,
        };
        let json = serde_json::to_string(&warm_entry).unwrap();
        let decoded: WarmCacheEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, warm_entry);
        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let json = serde_json::to_string(&cache_score).unwrap();
        let decoded: CacheScore = serde_json::from_str(&json).unwrap();
        assert_eq!(
            decoded.predicted_cache_read_tokens,
            cache_score.predicted_cache_read_tokens
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_5m,
            cache_score.predicted_cache_creation_tokens_5m
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_1h,
            cache_score.predicted_cache_creation_tokens_1h
        );
        assert_eq!(
            decoded.predicted_uncached_input_tokens,
            cache_score.predicted_uncached_input_tokens
        );
        assert_eq!(
            decoded.predicted_expires_at_unix_secs,
            cache_score.predicted_expires_at_unix_secs
        );
        assert_eq!(
            decoded.matched_breakpoint_index,
            cache_score.matched_breakpoint_index
        );
        // `confidence` is an f32, so it is compared with a tolerance.
        assert!((decoded.confidence - cache_score.confidence).abs() < 0.0001);
        assert_eq!(decoded.ambiguity_reason, cache_score.ambiguity_reason);
    }

    #[test]
    fn upstream_candidate_cache_score_roundtrip() {
        let candidate_no_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: None,
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_no_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_no_cache.upstream_id);
        assert_eq!(decoded.name, candidate_no_cache.name);
        assert_eq!(decoded.cache_score, None);
        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let candidate_with_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream-cached".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: Some(cache_score),
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_with_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_with_cache.upstream_id);
        assert_eq!(decoded.name, candidate_with_cache.name);
        assert!(decoded.cache_score.is_some());
        assert_eq!(decoded.cache_score.unwrap().predicted_cache_read_tokens, 50);
    }

    // RequestContext has no serde support, so this checks that the cache
    // fields are stored as given rather than round-tripping anything.
    #[test]
    fn request_context_holds_cache_fields() {
        let ctx_empty = RequestContext {
            request_id: "req-1".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert_eq!(ctx_empty.cache_breakpoints.len(), 0);
        assert_eq!(ctx_empty.canonical_model_id, "");
        let breakpoint = CacheBreakpoint {
            block_index: 1,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.0".to_owned(),
            message_index: Some(0),
            prefix_hash: "hash123".to_owned(),
            prefix_token_count: 150,
            requested_ttl: TtlClass::Ephemeral1h,
            origin: BreakpointOrigin::Explicit,
        };
        let ctx_populated = RequestContext {
            request_id: "req-2".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: Some("param=value".to_owned()),
            body_bytes: Bytes::from_static(b"test"),
            cache_breakpoints: vec![breakpoint],
            canonical_model_id: "claude-sonnet-4-5-20250929".to_owned(),
        };
        assert_eq!(ctx_populated.cache_breakpoints.len(), 1);
        assert_eq!(
            ctx_populated.canonical_model_id,
            "claude-sonnet-4-5-20250929"
        );
    }

    #[test]
    fn upstream_candidate_deserialize_without_cache_score_field() {
        let json = r#"{
            "upstream_id": "00000000-0000-0000-0000-000000000001",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1700000000
        }"#;
        let candidate: UpstreamCandidate = serde_json::from_str(json).unwrap();
        assert!(candidate.cache_score.is_none());
        assert_eq!(candidate.name, "legacy-upstream");
    }

    // RequestContext has neither Default nor Deserialize. This only checks
    // that empty cache fields are accepted and stored as given.
    #[test]
    fn request_context_accepts_empty_cache_fields() {
        let ctx = RequestContext {
            request_id: "test".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::GET,
            path: "/test".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert!(ctx.cache_breakpoints.is_empty());
        assert!(ctx.canonical_model_id.is_empty());
    }
}