1use std::collections::BTreeMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use bytes::Bytes;
8use http::{HeaderMap, Method, StatusCode};
9use serde::{Deserialize, Serialize};
10use url::Url;
11use uuid::Uuid;
12
13use crate::errors::{DialectError, SignerError};
14use crate::traits::{Signer, UpstreamDialect};
15
/// Identity of a downstream caller, handed to [`shape_request`] together with
/// the request context and the chosen upstream.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Principal {
    /// Identifier of the principal.
    pub id: String,
    /// What kind of identity / credential the principal was established from.
    pub kind: PrincipalKind,
    /// Arbitrary JSON claims attached to the principal.
    pub claims: serde_json::Map<String, serde_json::Value>,
}
26
/// Kind of identity a [`Principal`] represents.
///
/// Serialized in serde `snake_case`, which inserts `_` before every capital
/// letter: the wire names are `api_key`, `o_auth_subject` (note the split
/// `OAuth`), `internal_key`, `workload_identity` and `subscription_bearer`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrincipalKind {
    /// Principal identified by an API key.
    ApiKey,
    /// Principal identified as an OAuth subject.
    OAuthSubject,
    /// Principal identified by an internal key.
    InternalKey,
    /// Principal identified by a workload identity.
    WorkloadIdentity,
    /// Principal identified by a subscription bearer token.
    SubscriptionBearer,
}
42
/// Slot a plugin occupies in the request pipeline.
///
/// The serde names (`router`, `observability_hook`, `shape`) are the same
/// strings produced by [`PluginSlot::as_str`] and accepted by
/// [`PluginSlot::parse`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PluginSlot {
    /// Routing plugin slot.
    Router,
    /// Observability hook slot.
    ObservabilityHook,
    /// Request-shaping plugin slot.
    Shape,
}
54
55impl PluginSlot {
56 pub fn as_str(self) -> &'static str {
58 match self {
59 Self::Router => "router",
60 Self::ObservabilityHook => "observability_hook",
61 Self::Shape => "shape",
62 }
63 }
64
65 pub fn parse(value: &str) -> Option<Self> {
67 match value {
68 "router" => Some(Self::Router),
69 "observability_hook" => Some(Self::ObservabilityHook),
70 "shape" => Some(Self::Shape),
71 _ => None,
72 }
73 }
74}
75
/// Upstream endpoint a request is sent to.
///
/// Internally tagged on `kind`, so `AnthropicDirect { base_url: None }`
/// serializes as `{"kind":"anthropic_direct"}`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum Upstream {
    /// Anthropic's API, addressed directly.
    AnthropicDirect {
        /// Optional base URL override. Omitted from JSON when `None` and
        /// defaults to `None` when absent.
        // NOTE(review): what endpoint `None` resolves to is decided elsewhere — confirm.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        base_url: Option<Url>,
    },
}
89
/// Credential flavour of an upstream, as carried on [`UpstreamCandidate::kind`].
///
/// Wire names are `anthropic_api_key` and `anthropic_oauth`, matching
/// [`UpstreamKind::as_str`]. (Spelled `Oauth`, so serde does not split it
/// the way it splits `PrincipalKind::OAuthSubject`.)
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UpstreamKind {
    /// Anthropic upstream authenticated with an API key.
    AnthropicApiKey,
    /// Anthropic upstream authenticated with OAuth.
    AnthropicOauth,
}
99
100impl UpstreamKind {
101 pub fn as_str(self) -> &'static str {
103 match self {
104 Self::AnthropicApiKey => "anthropic_api_key",
105 Self::AnthropicOauth => "anthropic_oauth",
106 }
107 }
108}
109
/// Quantity a rate limit is measured in.
///
/// Wire names (`requests`, `tokens`, `input_tokens`, `output_tokens`) match
/// [`RateLimitKind::as_str`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitKind {
    /// Limit on request count.
    Requests,
    /// Limit on total tokens.
    Tokens,
    /// Limit on input tokens.
    InputTokens,
    /// Limit on output tokens.
    OutputTokens,
}
123
124impl RateLimitKind {
125 pub fn as_str(self) -> &'static str {
127 match self {
128 Self::Requests => "requests",
129 Self::Tokens => "tokens",
130 Self::InputTokens => "input_tokens",
131 Self::OutputTokens => "output_tokens",
132 }
133 }
134}
135
/// One rate-limit reading observed for an upstream.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RateLimitObservation {
    /// Quantity the limit applies to.
    pub kind: RateLimitKind,
    /// Label of the limit window.
    // NOTE(review): the format of `window` is not visible here — confirm with the producer.
    pub window: String,
    /// Limit for the window, if reported.
    pub limit: Option<u64>,
    /// Remaining allowance in the window, if reported.
    pub remaining: Option<u64>,
    /// Reset time, if reported.
    // NOTE(review): presumably the raw upstream reset value (e.g. a timestamp string) — confirm format.
    pub reset: Option<String>,
}
150
/// Freshness of the quota data in a [`SubscriptionQuotaCandidateSnapshot`].
// NOTE(review): presumably Fresh vs Stale is judged against the snapshot's
// `max_staleness_secs` — confirm where the state is computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionQuotaDataState {
    /// Quota data is current.
    Fresh,
    /// Quota data exists but is older than allowed.
    Stale,
    /// No quota data is available.
    Missing,
}
162
/// Snapshot of one subscription-quota window for an upstream candidate.
///
/// Derives `PartialEq` only (not `Eq`) because several fields are `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SubscriptionQuotaCandidateSnapshot {
    /// Label of the quota window.
    pub window: String,
    /// Freshness of this snapshot's data.
    pub state: SubscriptionQuotaDataState,
    /// Where the data came from, if known.
    pub source: Option<String>,
    /// Utilization of the window.
    // NOTE(review): unit (fraction 0.0–1.0 vs percent) not visible here — confirm.
    pub utilization: Option<f64>,
    /// Upstream-reported status string, if any.
    pub status: Option<String>,
    /// When the window resets, in Unix seconds.
    pub resets_at_unix_secs: Option<u64>,
    /// Threshold that has been surpassed, if reported.
    // NOTE(review): same unit as `utilization` presumably — confirm.
    pub surpassed_threshold: Option<f64>,
    /// Representative claim reported for this window, if any.
    pub representative_claim: Option<String>,
    /// Reason the quota is disabled, if it is.
    pub disabled_reason: Option<String>,
    /// Whether extra (overage) usage is enabled.
    pub extra_usage_enabled: Option<bool>,
    /// Monthly limit for extra usage.
    pub extra_usage_monthly_limit: Option<f64>,
    /// Credits already used for extra usage.
    pub extra_usage_used_credits: Option<f64>,
    /// When the data was observed, in Unix *milliseconds*
    /// (unlike the `_unix_secs` fields elsewhere).
    pub observed_at_unix_millis: Option<u64>,
    /// Maximum tolerated staleness of this data, in seconds.
    pub max_staleness_secs: u64,
    // The four fields below are omitted from JSON when `None` and default to
    // `None` when absent, so older payloads without them still deserialize.
    /// Whether a fallback is available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fallback_available: Option<bool>,
    /// Whether overage is currently in use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub overage_in_use: Option<bool>,
    /// Monthly utilization of the overage period.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub overage_period_monthly_utilization: Option<f64>,
    /// Available upgrade paths, if reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upgrade_paths: Option<Vec<String>>,
}
212
/// Cache time-to-live class. Defaults to `Ephemeral5m`.
///
/// Serde snake_case does not split before digits, so the wire names are
/// `ephemeral5m` and `ephemeral1h`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub enum TtlClass {
    /// Five-minute TTL (the default).
    #[default]
    Ephemeral5m,
    /// One-hour TTL.
    Ephemeral1h,
}
224
/// How a [`CacheBreakpoint`] was identified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub enum BreakpointOrigin {
    /// Breakpoint explicitly marked in the request.
    Explicit,
    /// Breakpoint inferred automatically.
    AutoCacheInferred,
}
235
/// Part of the request body a [`CacheBreakpoint`] sits in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub enum CacheBreakpointSource {
    /// Breakpoint in the tools section.
    Tools,
    /// Breakpoint in the system prompt.
    System,
    /// Breakpoint in a message (see `CacheBreakpoint::message_index`).
    Message,
}
248
/// A prompt-cache breakpoint found in a request body.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub struct CacheBreakpoint {
    /// Index of the content block carrying the breakpoint.
    pub block_index: u32,
    /// Section of the request the breakpoint is in.
    pub source: CacheBreakpointSource,
    /// Dotted path to the block, e.g. `messages.0.content.1` (as used in tests).
    pub path: String,
    /// Message index when the breakpoint is inside a message.
    pub message_index: Option<u32>,
    /// Hash identifying the prompt prefix up to this breakpoint.
    pub prefix_hash: String,
    /// Token count of that prefix.
    pub prefix_token_count: u64,
    /// TTL class requested for this breakpoint.
    pub requested_ttl: TtlClass,
    /// Whether the breakpoint was explicit or inferred.
    pub origin: BreakpointOrigin,
}
270
/// A cached prompt prefix believed to be warm until `expires_at_unix_secs`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub struct WarmCacheEntry {
    /// Hash of the cached prefix (compare `CacheBreakpoint::prefix_hash`).
    pub prefix_hash: String,
    /// Expiry time in Unix seconds.
    pub expires_at_unix_secs: u64,
    /// TTL class the entry was written with.
    pub ttl_class: TtlClass,
    /// Last time the entry was observed, in Unix seconds.
    pub last_observed_at_unix_secs: u64,
}
284
/// Predicted prompt-cache behaviour of routing a request to a candidate.
///
/// Derives `PartialEq` only (not `Eq`) because `confidence` is `f32`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub struct CacheScore {
    /// Predicted tokens served from cache.
    pub predicted_cache_read_tokens: u32,
    /// Predicted tokens written to cache with a 5-minute TTL.
    pub predicted_cache_creation_tokens_5m: u32,
    /// Predicted tokens written to cache with a 1-hour TTL.
    pub predicted_cache_creation_tokens_1h: u32,
    /// Predicted input tokens that miss the cache entirely.
    pub predicted_uncached_input_tokens: u32,
    /// Predicted expiry of the matched cache entry, in Unix seconds.
    pub predicted_expires_at_unix_secs: Option<u64>,
    /// Index of the breakpoint the prediction matched, if any.
    pub matched_breakpoint_index: Option<u32>,
    /// Confidence in the prediction.
    // NOTE(review): range presumably 0.0–1.0 (tests use 0.95) — confirm.
    pub confidence: f32,
    /// Why the prediction is ambiguous, if it is.
    pub ambiguity_reason: Option<String>,
}
306
/// An upstream under consideration during routing, with its observed state.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UpstreamCandidate {
    /// Stable identifier of the upstream.
    pub upstream_id: Uuid,
    /// Human-readable upstream name.
    pub name: String,
    /// Credential flavour of the upstream.
    pub kind: UpstreamKind,
    /// Rate-limit readings observed for this upstream.
    pub observed_rate_limits: Vec<RateLimitObservation>,
    /// Subscription-quota snapshots; defaults to empty when absent.
    #[serde(default)]
    pub subscription_quotas: Vec<SubscriptionQuotaCandidateSnapshot>,
    /// When this candidate's state was observed, in Unix seconds.
    pub observed_at_unix_secs: u64,
    /// Cache prediction for this candidate; omitted when `None`, `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_score: Option<CacheScore>,
    /// Base URL of the upstream; omitted when `None`, `None` when absent.
    // NOTE(review): plain `String` here vs `Url` in `Upstream::AnthropicDirect` — confirm intentional.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
}
336
/// Strategy for credentialing an outbound request.
///
/// Serde snake_case splits before every capital, so the wire names are
/// `api_key`, `o_auth` and `internal_forwarded`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CredentialStrategy {
    /// Use an API key.
    ApiKey,
    /// Use OAuth.
    OAuth,
    /// Forward an internal credential.
    InternalForwarded,
}
348
/// The downstream request as seen by dialects in [`shape_request`].
///
/// Not serde-serializable (holds `http` types directly).
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Identifier of this request.
    pub request_id: String,
    /// Headers received from the downstream client.
    pub downstream_headers: HeaderMap,
    /// HTTP method of the downstream request.
    pub method: Method,
    /// Request path, e.g. `/v1/messages`.
    pub path: String,
    /// Raw query string without `?`, if any (tests use `param=value`).
    pub query: Option<String>,
    /// Raw request body.
    pub body_bytes: Bytes,
    /// Cache breakpoints found in the body; may be empty.
    pub cache_breakpoints: Vec<CacheBreakpoint>,
    /// Canonical model identifier; may be empty.
    pub canonical_model_id: String,
}
369
/// Outbound request produced by an upstream dialect, before signing.
///
/// Fields are private and include a `crate::private::Seal`, so instances are
/// created via [`ShapedRequestBuilder::shaped_request`] rather than a struct
/// literal (presumably `crate::private` is not exported).
#[derive(Clone, Debug)]
pub struct ShapedRequest {
    url: Url,
    method: Method,
    headers: HeaderMap,
    body: Bytes,
    _seal: crate::private::Seal,
}
379
380impl ShapedRequest {
381 pub fn url(&self) -> &Url {
383 &self.url
384 }
385
386 pub fn set_url(&mut self, url: Url) {
388 self.url = url;
389 }
390
391 pub fn method(&self) -> &Method {
393 &self.method
394 }
395
396 pub fn set_method(&mut self, method: Method) {
398 self.method = method;
399 }
400
401 pub fn headers(&self) -> &HeaderMap {
403 &self.headers
404 }
405
406 pub fn headers_mut(&mut self) -> &mut HeaderMap {
408 &mut self.headers
409 }
410
411 pub fn body(&self) -> &Bytes {
413 &self.body
414 }
415
416 pub fn set_body(&mut self, body: Bytes) {
418 self.body = body;
419 }
420}
421
/// Sealed factory for [`ShapedRequest`], lent to dialects by [`shape_request`].
///
/// Holds a `crate::private::Seal`, so it can only be instantiated where that
/// type is reachable.
#[derive(Debug)]
pub struct ShapedRequestBuilder {
    _seal: crate::private::Seal,
}
432
433impl ShapedRequestBuilder {
434 pub fn shaped_request(
436 &mut self,
437 url: Url,
438 method: Method,
439 headers: HeaderMap,
440 body: Bytes,
441 ) -> ShapedRequest {
442 ShapedRequest {
443 url,
444 method,
445 headers,
446 body,
447 _seal: crate::private::Seal,
448 }
449 }
450}
451
452pub fn shape_request(
454 dialect: &dyn UpstreamDialect,
455 ctx: &RequestContext,
456 upstream: &Upstream,
457 principal: &Principal,
458) -> Result<ShapedRequest, DialectError> {
459 let mut builder = ShapedRequestBuilder {
460 _seal: crate::private::Seal,
461 };
462 dialect.shape(ctx, upstream, principal, &mut builder)
463}
464
/// A request that has passed through a signer and is ready to relay.
///
/// Created only via [`SignedRequest::from_shaped`], which requires a
/// [`SigningCapability`]; the private `_seal` field prevents struct literals
/// where `crate::private` is unreachable.
#[derive(Clone, Debug)]
pub struct SignedRequest {
    url: Url,
    method: Method,
    headers: HeaderMap,
    body: Bytes,
    _seal: crate::private::Seal,
}
474
475impl SignedRequest {
476 pub fn from_shaped(shaped: ShapedRequest, _capability: &mut SigningCapability) -> Self {
478 Self {
479 url: shaped.url,
480 method: shaped.method,
481 headers: shaped.headers,
482 body: shaped.body,
483 _seal: crate::private::Seal,
484 }
485 }
486
487 pub fn url(&self) -> &Url {
489 &self.url
490 }
491
492 pub fn method(&self) -> &Method {
494 &self.method
495 }
496
497 pub fn headers(&self) -> &HeaderMap {
499 &self.headers
500 }
501
502 pub fn body(&self) -> &Bytes {
504 &self.body
505 }
506
507 pub fn into_parts(self) -> (Url, Method, HeaderMap, Bytes) {
509 (self.url, self.method, self.headers, self.body)
510 }
511}
512
/// Sealed token required by [`SignedRequest::from_shaped`].
///
/// Minted by [`sign_request`] and lent to the signer by `&mut` for the
/// duration of the call.
#[derive(Debug)]
pub struct SigningCapability {
    _seal: crate::private::Seal,
}
523
524pub async fn sign_request(
526 signer: &dyn Signer,
527 shaped: ShapedRequest,
528) -> Result<SignedRequest, SignerError> {
529 let mut capability = SigningCapability {
530 _seal: crate::private::Seal,
531 };
532 signer.sign(shaped, &mut capability).await
533}
534
/// Outcome of routing: where to send the request and which dialect shapes it.
// NOTE(review): no derives; `Debug` would need `dyn UpstreamDialect: Debug` — confirm whether that is the reason.
pub struct RouteDecision {
    /// Identifier of the chosen upstream, if it has one.
    pub upstream_id: Option<Uuid>,
    /// The chosen upstream.
    pub upstream: Upstream,
    /// Dialect used to shape the request for this upstream.
    pub dialect: Arc<dyn UpstreamDialect>,
}
544
/// Per-principal usage limits over a fixed window.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PrincipalQuotas {
    /// Maximum requests per `window`.
    pub requests_per_window: u64,
    /// Maximum input tokens per `window`.
    pub input_tokens_per_window: u64,
    /// Maximum output tokens per `window`.
    pub output_tokens_per_window: u64,
    /// Length of the quota window.
    pub window: Duration,
    /// Models the principal may use.
    // NOTE(review): whether an empty list means "all" or "none" is not visible here — confirm.
    pub allowed_models: Vec<String>,
}
559
/// Request-lifecycle events, presumably delivered to observability hooks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ObserveEvent {
    /// A downstream request arrived.
    RequestStarted {
        /// Identifier of the request.
        request_id: String,
        /// Downstream `User-Agent`, if sent.
        downstream_user_agent: Option<String>,
    },
    /// Authentication finished for the request.
    AuthnComplete {
        /// Identifier of the authenticated principal.
        principal_id: String,
        /// Kind of the authenticated principal.
        kind: PrincipalKind,
    },
    /// An upstream was selected.
    UpstreamChosen {
        /// The selected upstream.
        upstream: Upstream,
    },
    /// A batch of response events was relayed.
    Chunk {
        /// Position of this batch in the stream.
        batch_index: u64,
        /// Number of events in the batch.
        event_count: usize,
        /// Total size of the batch in bytes.
        total_bytes: usize,
    },
    /// The request completed.
    RequestFinished {
        /// Final HTTP status.
        status: StatusCode,
        /// Input tokens, if known.
        input_tokens: Option<u64>,
        /// Output tokens, if known.
        output_tokens: Option<u64>,
        /// Tokens written to cache, if known.
        cache_creation_input_tokens: Option<u64>,
        /// Tokens read from cache, if known.
        cache_read_input_tokens: Option<u64>,
        /// Total duration in milliseconds.
        duration_ms: u64,
    },
    /// An error occurred.
    Error {
        /// Error code.
        code: String,
        /// Error message.
        // NOTE(review): tests use "redacted"; whether messages are redacted or capped
        // at `MAX_ERROR_MESSAGE_LEN` before emission is not visible here — confirm.
        message: String,
        /// Component that produced the error.
        source: String,
    },
}
616
/// Whether to retry a request with a refreshed signer or give up.
// NOTE(review): only `Clone` is derived, presumably because `dyn Signer` is not `Debug` — confirm.
#[derive(Clone)]
pub enum RetryDecision {
    /// Retry with a replacement signer.
    Refresh {
        /// Signer to use for the retry.
        new_signer: Arc<dyn Signer>,
    },
    /// Do not retry.
    Fail,
}
628
/// Declaration of a plugin to load.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PluginManifest {
    /// Plugin name.
    pub name: String,
    /// Plugin artifact reference (tests use a file name such as `plugin.wasm`).
    pub artifact: String,
    /// Wire-protocol version; omitted when `None`, `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub wire_version: Option<u8>,
    /// Plugin-specific configuration, passed through as JSON.
    pub config: serde_json::Value,
    /// Free-form metadata; defaults to empty when absent.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
645
/// Maximum number of stages in a [`RoutingTrace`].
// NOTE(review): not enforced in this file; confirm where `RoutingTrace::stages` is capped.
pub const MAX_ROUTING_TRACE_STAGES: usize = 100;
/// Maximum length of [`StageDecision::stage_name`].
// NOTE(review): unit (bytes vs chars) and enforcement site not visible here — confirm.
pub const MAX_STAGE_NAME_LEN: usize = 256;
/// Maximum length of an error message.
// NOTE(review): presumably applies to `InternalError::message` and/or `ObserveEvent::Error`;
// unit and enforcement site not visible here — confirm.
pub const MAX_ERROR_MESSAGE_LEN: usize = 1024;
653
/// Why a request was passed through without changing its upstream choice.
// NOTE(review): exact semantics of each cause are defined by the router — confirm.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PassthroughCause {
    /// The upstream was healthy.
    HealthyUpstream,
    /// No alternative upstream was available.
    NoAlternative,
    /// A plugin decided to pass through.
    PluginDecision,
}
665
/// Reason a specific upstream candidate was excluded.
// NOTE(review): presumably attached per candidate during routing — confirm usage.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PerCandidateReason {
    /// The candidate is rate limited.
    RateLimited,
    /// The candidate lacks quota.
    InsufficientQuota,
    /// The candidate is unhealthy.
    Unhealthy,
    /// A plugin rejected the candidate.
    RejectedByPlugin,
}
679
/// Strategy for the final pick among remaining candidates. Defaults to `FirstPick`.
///
/// Unlike the sibling enums this uses **kebab-case** on the wire:
/// `first-pick`, `random`, `round-robin`, `least-connections`.
// NOTE(review): confirm kebab-case is intentional; switching to snake_case would break existing payloads.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TerminalStrategy {
    /// Take the first remaining candidate (the default).
    #[default]
    FirstPick,
    /// Pick at random.
    Random,
    /// Rotate between candidates.
    RoundRobin,
    /// Pick the candidate with the fewest connections.
    LeastConnections,
}
694
/// Record of one routing stage in a [`RoutingTrace`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StageDecision {
    /// Name of the stage.
    pub stage_name: String,
    /// Upstream selected by the stage, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_id: Option<Uuid>,
    /// Explanation for the decision, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Stage duration in microseconds; omitted when 0 and 0 when absent.
    #[serde(default, skip_serializing_if = "is_zero_u64")]
    pub duration_us: u64,
}
710
/// `skip_serializing_if` predicate: true when `value` is zero.
fn is_zero_u64(value: &u64) -> bool {
    matches!(*value, 0)
}
714
/// Final routing pick and the strategy that produced it.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TerminalDecision {
    /// Selected upstream, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_id: Option<Uuid>,
    /// Strategy used. Required in JSON: `Default` on the struct does not make
    /// serde default this field.
    pub strategy: TerminalStrategy,
}
724
/// Trace of the routing pipeline; `{}` deserializes to the default (empty) trace.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct RoutingTrace {
    /// Per-stage decisions in recorded order; empty when absent.
    #[serde(default)]
    pub stages: Vec<StageDecision>,
    /// Final decision, if one was made; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub terminal_decision: Option<TerminalDecision>,
}
735
/// Pipeline stage at which an [`InternalError`] occurred. Defaults to `Router`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorStage {
    /// Authentication.
    Authn,
    /// Routing (the default).
    #[default]
    Router,
    /// Router filtering.
    RouterFilter,
    /// Request shaping.
    Shape,
    /// Request signing.
    Signer,
    /// Relaying to/from the upstream.
    Relay,
}
754
/// Category of an [`InternalError`]. Defaults to `PluginError`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorKind {
    /// Plugin reported an error (the default).
    #[default]
    PluginError,
    /// Plugin output was invalid.
    InvalidOutput,
    /// Plugin trapped.
    Trap,
    /// Configuration error.
    ConfigError,
    /// Operation timed out.
    Timeout,
    /// Component unavailable.
    Unavailable,
}
773
/// An internal failure, tagged with stage and kind.
///
/// The default value serializes as `{"stage":"router","kind":"plugin_error"}`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InternalError {
    /// Stage where the error occurred.
    pub stage: InternalErrorStage,
    /// Category of the error.
    pub kind: InternalErrorKind,
    /// Optional message; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
785
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `value` serializes to the JSON string `expected` and that
    /// `expected` deserializes back to an equal value.
    ///
    /// Pins enum wire names so an accidental variant rename or `rename_all`
    /// change fails here instead of silently breaking persisted data or
    /// plugin peers.
    fn assert_wire_name<T>(value: &T, expected: &str)
    where
        T: Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_value(value).expect("value serializes to JSON");
        assert_eq!(encoded, serde_json::Value::String(expected.to_owned()));
        let decoded: T = serde_json::from_value(encoded).expect("wire name deserializes");
        assert_eq!(&decoded, value);
    }

    #[test]
    fn shaped_request_accessors_and_mutators_preserve_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("x-test", "one".parse().unwrap());
        let mut shaped = builder.shaped_request(
            "https://example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"first"),
        );

        assert_eq!(shaped.url().as_str(), "https://example.test/v1/messages");
        assert_eq!(shaped.method(), Method::POST);
        assert_eq!(shaped.headers()["x-test"], "one");
        assert_eq!(shaped.body(), &Bytes::from_static(b"first"));

        shaped.set_url("https://example.test/v1/complete".parse().unwrap());
        shaped.set_method(Method::PUT);
        shaped
            .headers_mut()
            .insert("x-test", "two".parse().unwrap());
        shaped.set_body(Bytes::from_static(b"second"));

        assert_eq!(shaped.url().path(), "/v1/complete");
        assert_eq!(shaped.method(), Method::PUT);
        assert_eq!(shaped.headers()["x-test"], "two");
        assert_eq!(shaped.body(), &Bytes::from_static(b"second"));
    }

    #[test]
    fn signed_request_exposes_and_consumes_signed_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer token".parse().unwrap());
        let shaped = builder.shaped_request(
            "https://api.example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"{}"),
        );
        let mut capability = SigningCapability {
            _seal: crate::private::Seal,
        };

        let signed = SignedRequest::from_shaped(shaped, &mut capability);
        assert_eq!(signed.url().host_str(), Some("api.example.test"));
        assert_eq!(signed.method(), Method::POST);
        assert_eq!(signed.headers()["authorization"], "Bearer token");
        assert_eq!(signed.body(), &Bytes::from_static(b"{}"));

        let (url, method, headers, body) = signed.into_parts();
        assert_eq!(url.as_str(), "https://api.example.test/v1/messages");
        assert_eq!(method, Method::POST);
        assert_eq!(headers["authorization"], "Bearer token");
        assert_eq!(body, Bytes::from_static(b"{}"));
    }

    #[test]
    fn upstream_and_manifest_serde_round_trip() {
        let upstreams = [
            Upstream::AnthropicDirect { base_url: None },
            Upstream::AnthropicDirect {
                base_url: Some("https://proxy.example.test/".parse().unwrap()),
            },
        ];

        for upstream in upstreams {
            let json = serde_json::to_string(&upstream).unwrap();
            let decoded: Upstream = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, upstream);
        }

        // Internally tagged on `kind`; a `None` base_url is omitted entirely.
        assert_eq!(
            serde_json::to_value(Upstream::AnthropicDirect { base_url: None }).unwrap(),
            serde_json::json!({"kind": "anthropic_direct"})
        );

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "authn",
            "artifact": "plugin.wasm",
            "config": {"enabled": true}
        }))
        .unwrap();
        assert_eq!(manifest.name, "authn");
        assert_eq!(manifest.wire_version, None);
        assert!(manifest.metadata.is_empty());

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "cache-aware",
            "artifact": "plugin.wasm",
            "wire_version": 2,
            "config": {}
        }))
        .unwrap();
        assert_eq!(manifest.wire_version, Some(2));
    }

    #[test]
    fn public_enums_cover_all_current_variants() {
        // serde's snake_case inserts `_` before every capital, so `OAuth`
        // becomes `o_auth`. These strings are the wire format: changing them
        // breaks anything that persisted or exchanged the old names.
        let principal_kinds = [
            (PrincipalKind::ApiKey, "api_key"),
            (PrincipalKind::OAuthSubject, "o_auth_subject"),
            (PrincipalKind::InternalKey, "internal_key"),
            (PrincipalKind::WorkloadIdentity, "workload_identity"),
            (PrincipalKind::SubscriptionBearer, "subscription_bearer"),
        ];
        for (kind, expected) in &principal_kinds {
            assert_wire_name(kind, expected);
        }

        let strategies = [
            (CredentialStrategy::ApiKey, "api_key"),
            (CredentialStrategy::OAuth, "o_auth"),
            (CredentialStrategy::InternalForwarded, "internal_forwarded"),
        ];
        for (strategy, expected) in &strategies {
            assert_wire_name(strategy, expected);
        }
    }

    #[test]
    fn as_str_names_match_serde_wire_names() {
        // The hand-written `as_str` tables must agree with the derived serde names.
        for slot in [
            PluginSlot::Router,
            PluginSlot::ObservabilityHook,
            PluginSlot::Shape,
        ] {
            assert_wire_name(&slot, slot.as_str());
            assert_eq!(PluginSlot::parse(slot.as_str()), Some(slot));
        }
        assert_eq!(PluginSlot::parse("Router"), None);
        assert_eq!(PluginSlot::parse(""), None);

        for kind in [UpstreamKind::AnthropicApiKey, UpstreamKind::AnthropicOauth] {
            assert_wire_name(&kind, kind.as_str());
        }

        for kind in [
            RateLimitKind::Requests,
            RateLimitKind::Tokens,
            RateLimitKind::InputTokens,
            RateLimitKind::OutputTokens,
        ] {
            assert_wire_name(&kind, kind.as_str());
        }
    }

    #[test]
    fn terminal_strategy_uses_kebab_case_wire_names() {
        // Unlike its sibling enums, `TerminalStrategy` is kebab-case on the wire.
        assert_eq!(TerminalStrategy::default(), TerminalStrategy::FirstPick);
        let strategies = [
            (TerminalStrategy::FirstPick, "first-pick"),
            (TerminalStrategy::Random, "random"),
            (TerminalStrategy::RoundRobin, "round-robin"),
            (TerminalStrategy::LeastConnections, "least-connections"),
        ];
        for (strategy, expected) in &strategies {
            assert_wire_name(strategy, expected);
        }
    }

    #[test]
    fn stage_decision_skips_zero_duration_and_trace_defaults() {
        let stage = StageDecision {
            stage_name: "filter".to_owned(),
            upstream_id: None,
            reason: None,
            duration_us: 0,
        };
        // `None` options and a zero duration are all omitted.
        let json = serde_json::to_value(&stage).unwrap();
        assert_eq!(json, serde_json::json!({"stage_name": "filter"}));
        let decoded: StageDecision = serde_json::from_value(json).unwrap();
        assert_eq!(decoded, stage);

        let timed = StageDecision {
            duration_us: 42,
            ..stage
        };
        let json = serde_json::to_value(&timed).unwrap();
        assert_eq!(json["duration_us"], serde_json::json!(42));
        let decoded: StageDecision = serde_json::from_value(json).unwrap();
        assert_eq!(decoded, timed);

        let trace: RoutingTrace = serde_json::from_str("{}").unwrap();
        assert_eq!(trace, RoutingTrace::default());
    }

    #[test]
    fn internal_error_defaults_and_omits_empty_message() {
        let error = InternalError::default();
        assert_eq!(error.stage, InternalErrorStage::Router);
        assert_eq!(error.kind, InternalErrorKind::PluginError);
        assert_eq!(error.message, None);
        assert_eq!(
            serde_json::to_value(&error).unwrap(),
            serde_json::json!({"stage": "router", "kind": "plugin_error"})
        );
    }

    #[test]
    fn observe_event_variants_are_equatable() {
        let events = vec![
            ObserveEvent::RequestStarted {
                request_id: "req".to_owned(),
                downstream_user_agent: Some("ua".to_owned()),
            },
            ObserveEvent::AuthnComplete {
                principal_id: "principal".to_owned(),
                kind: PrincipalKind::InternalKey,
            },
            ObserveEvent::UpstreamChosen {
                upstream: Upstream::AnthropicDirect { base_url: None },
            },
            ObserveEvent::Chunk {
                batch_index: 1,
                event_count: 2,
                total_bytes: 3,
            },
            ObserveEvent::RequestFinished {
                status: StatusCode::OK,
                input_tokens: Some(4),
                output_tokens: Some(5),
                cache_creation_input_tokens: Some(6),
                cache_read_input_tokens: Some(7),
                duration_ms: 8,
            },
            ObserveEvent::Error {
                code: "E".to_owned(),
                message: "redacted".to_owned(),
                source: "plugin".to_owned(),
            },
        ];

        assert_eq!(events, events.clone());
    }

    #[test]
    fn cache_types_roundtrip() {
        // serde's snake_case does not split before digits.
        assert_eq!(TtlClass::default(), TtlClass::Ephemeral5m);
        assert_wire_name(&TtlClass::Ephemeral5m, "ephemeral5m");
        assert_wire_name(&TtlClass::Ephemeral1h, "ephemeral1h");

        let ttl_class = TtlClass::Ephemeral1h;
        let json = serde_json::to_string(&ttl_class).unwrap();
        let decoded: TtlClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, ttl_class);

        let origin = BreakpointOrigin::AutoCacheInferred;
        let json = serde_json::to_string(&origin).unwrap();
        let decoded: BreakpointOrigin = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, origin);

        let source = CacheBreakpointSource::System;
        let json = serde_json::to_string(&source).unwrap();
        let decoded: CacheBreakpointSource = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, source);

        let breakpoint = CacheBreakpoint {
            block_index: 0,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.1".to_owned(),
            message_index: Some(0),
            prefix_hash: "abc123".to_owned(),
            prefix_token_count: 100,
            requested_ttl: TtlClass::Ephemeral5m,
            origin: BreakpointOrigin::Explicit,
        };
        let json = serde_json::to_string(&breakpoint).unwrap();
        let decoded: CacheBreakpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, breakpoint);

        let warm_entry = WarmCacheEntry {
            prefix_hash: "def456".to_owned(),
            expires_at_unix_secs: 1700000000,
            ttl_class: TtlClass::Ephemeral1h,
            last_observed_at_unix_secs: 1699999000,
        };
        let json = serde_json::to_string(&warm_entry).unwrap();
        let decoded: WarmCacheEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, warm_entry);

        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let json = serde_json::to_string(&cache_score).unwrap();
        let decoded: CacheScore = serde_json::from_str(&json).unwrap();
        assert_eq!(
            decoded.predicted_cache_read_tokens,
            cache_score.predicted_cache_read_tokens
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_5m,
            cache_score.predicted_cache_creation_tokens_5m
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_1h,
            cache_score.predicted_cache_creation_tokens_1h
        );
        assert_eq!(
            decoded.predicted_uncached_input_tokens,
            cache_score.predicted_uncached_input_tokens
        );
        assert_eq!(
            decoded.predicted_expires_at_unix_secs,
            cache_score.predicted_expires_at_unix_secs
        );
        assert_eq!(
            decoded.matched_breakpoint_index,
            cache_score.matched_breakpoint_index
        );
        // Float field compared with a tolerance rather than exact equality.
        assert!((decoded.confidence - cache_score.confidence).abs() < 0.0001);
        assert_eq!(decoded.ambiguity_reason, cache_score.ambiguity_reason);
    }

    #[test]
    fn upstream_candidate_cache_score_roundtrip() {
        let candidate_no_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: None,
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_no_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_no_cache.upstream_id);
        assert_eq!(decoded.name, candidate_no_cache.name);
        assert_eq!(decoded.cache_score, None);

        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let candidate_with_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream-cached".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: Some(cache_score),
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_with_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_with_cache.upstream_id);
        assert_eq!(decoded.name, candidate_with_cache.name);
        assert!(decoded.cache_score.is_some());
        assert_eq!(decoded.cache_score.unwrap().predicted_cache_read_tokens, 50);
    }

    // `RequestContext` has no serde impls; this only checks that both empty
    // and populated cache fields can be constructed and read back.
    #[test]
    fn request_context_cache_fields_roundtrip() {
        let ctx_empty = RequestContext {
            request_id: "req-1".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert_eq!(ctx_empty.cache_breakpoints.len(), 0);
        assert_eq!(ctx_empty.canonical_model_id, "");

        let breakpoint = CacheBreakpoint {
            block_index: 1,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.0".to_owned(),
            message_index: Some(0),
            prefix_hash: "hash123".to_owned(),
            prefix_token_count: 150,
            requested_ttl: TtlClass::Ephemeral1h,
            origin: BreakpointOrigin::Explicit,
        };
        let ctx_populated = RequestContext {
            request_id: "req-2".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: Some("param=value".to_owned()),
            body_bytes: Bytes::from_static(b"test"),
            cache_breakpoints: vec![breakpoint],
            canonical_model_id: "claude-sonnet-4-5-20250929".to_owned(),
        };
        assert_eq!(ctx_populated.cache_breakpoints.len(), 1);
        assert_eq!(
            ctx_populated.canonical_model_id,
            "claude-sonnet-4-5-20250929"
        );
    }

    #[test]
    fn upstream_candidate_deserialize_without_cache_score_field() {
        // Legacy payload: neither `cache_score` nor `base_url` is present.
        let json = r#"{
            "upstream_id": "00000000-0000-0000-0000-000000000001",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1700000000
        }"#;
        let candidate: UpstreamCandidate = serde_json::from_str(json).unwrap();
        assert!(candidate.cache_score.is_none());
        assert!(candidate.base_url.is_none());
        assert_eq!(candidate.name, "legacy-upstream");
    }

    // `RequestContext` is not deserializable, so "missing fields" here means
    // fields explicitly set to their empty values.
    #[test]
    fn request_context_cache_breakpoints_default_on_missing_fields() {
        let ctx = RequestContext {
            request_id: "test".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::GET,
            path: "/test".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert!(ctx.cache_breakpoints.is_empty());
        assert!(ctx.canonical_model_id.is_empty());
    }
}