1use std::collections::BTreeMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use bytes::Bytes;
8use http::{HeaderMap, Method, StatusCode};
9use serde::{Deserialize, Serialize};
10use url::Url;
11use uuid::Uuid;
12
13use crate::errors::{DialectError, SignerError};
14use crate::traits::{Signer, UpstreamDialect};
15
/// A principal: an identifier, a [`PrincipalKind`], and a free-form JSON
/// claims object.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Principal {
    /// Identifier string of the principal.
    pub id: String,
    /// Category of the principal.
    pub kind: PrincipalKind,
    /// Arbitrary JSON object attached to the principal; its schema is not
    /// defined in this file.
    pub claims: serde_json::Map<String, serde_json::Value>,
}
26
/// Category of a [`Principal`].
///
/// Variants are (de)serialized under serde's `snake_case` renaming.
// NOTE(review): serde's `snake_case` rule puts `_` before every interior
// uppercase letter, so `OAuthSubject` goes on the wire as `o_auth_subject`
// (not `oauth_subject`) — confirm that this is the intended wire name.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrincipalKind {
    ApiKey,
    OAuthSubject,
    InternalKey,
    WorkloadIdentity,
    SubscriptionBearer,
}
42
/// Named slot a plugin can occupy.
///
/// The derived `Ord`/`PartialOrd` follow declaration order
/// (`Router < ObservabilityHook < Shape`), so reordering variants changes
/// comparisons. Variants are (de)serialized in `snake_case`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PluginSlot {
    Router,
    ObservabilityHook,
    Shape,
}
54
55impl PluginSlot {
56 pub fn as_str(self) -> &'static str {
58 match self {
59 Self::Router => "router",
60 Self::ObservabilityHook => "observability_hook",
61 Self::Shape => "shape",
62 }
63 }
64
65 pub fn parse(value: &str) -> Option<Self> {
67 match value {
68 "router" => Some(Self::Router),
69 "observability_hook" => Some(Self::ObservabilityHook),
70 "shape" => Some(Self::Shape),
71 _ => None,
72 }
73 }
74}
75
/// Reserved principal string that [`SlotKey::global`] stores in
/// [`SlotKey::principal`].
pub const GLOBAL_PRINCIPAL: &str = "__global__";
78
/// Key pairing a principal string with a plugin name.
///
/// Derives `Eq` and `Hash`, so it can be used as a hash-map or hash-set key.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct SlotKey {
    /// Principal part of the key; [`GLOBAL_PRINCIPAL`] for keys built with
    /// [`SlotKey::global`].
    pub principal: String,
    /// Plugin part of the key.
    pub plugin: String,
}
92
93impl SlotKey {
94 pub fn new(principal: impl Into<String>, plugin: impl Into<String>) -> Self {
96 Self {
97 principal: principal.into(),
98 plugin: plugin.into(),
99 }
100 }
101
102 pub fn global(plugin: impl Into<String>) -> Self {
104 Self::new(GLOBAL_PRINCIPAL, plugin)
105 }
106}
107
/// Upstream target description.
///
/// Serialized as an internally tagged enum: the variant name, in
/// `snake_case`, is stored under the `kind` key next to the variant's fields.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum Upstream {
    AnthropicDirect {
        /// Optional base URL. Omitted from output when `None` and defaulted
        /// to `None` when absent from input.
        // NOTE(review): presumably overrides a default endpoint chosen
        // elsewhere — confirm against the consumer of this field.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        base_url: Option<Url>,
    },
}
121
/// Kind tag carried by an [`UpstreamCandidate`].
///
/// Variants are (de)serialized in `snake_case`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UpstreamKind {
    AnthropicApiKey,
    AnthropicOauth,
}
131
132impl UpstreamKind {
133 pub fn as_str(self) -> &'static str {
135 match self {
136 Self::AnthropicApiKey => "anthropic_api_key",
137 Self::AnthropicOauth => "anthropic_oauth",
138 }
139 }
140}
141
/// Quantity a [`RateLimitObservation`] refers to.
///
/// Variants are (de)serialized in `snake_case`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitKind {
    Requests,
    Tokens,
    InputTokens,
    OutputTokens,
}
155
156impl RateLimitKind {
157 pub fn as_str(self) -> &'static str {
159 match self {
160 Self::Requests => "requests",
161 Self::Tokens => "tokens",
162 Self::InputTokens => "input_tokens",
163 Self::OutputTokens => "output_tokens",
164 }
165 }
166}
167
/// One observed rate-limit reading for a given [`RateLimitKind`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RateLimitObservation {
    /// Which quantity the reading is about.
    pub kind: RateLimitKind,
    /// Window label; the string format is not defined in this file.
    pub window: String,
    /// Limit value, when one was observed.
    pub limit: Option<u64>,
    /// Remaining value, when one was observed.
    pub remaining: Option<u64>,
    /// Reset indicator kept as an unparsed string, when one was observed.
    pub reset: Option<String>,
}
182
/// State tag stored in [`SubscriptionQuotaCandidateSnapshot::state`].
///
/// Variants are (de)serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionQuotaDataState {
    Fresh,
    Stale,
    Missing,
}
194
/// Snapshot of subscription-quota data for one window, as attached to an
/// [`UpstreamCandidate`].
///
/// Only `window`, `state` and `max_staleness_secs` are non-optional. The
/// last four fields are additionally omitted from output when `None`.
// NOTE(review): field meanings below are taken from the field names and
// types only; confirm against the producer of these snapshots.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SubscriptionQuotaCandidateSnapshot {
    /// Window label; the string format is not defined in this file.
    pub window: String,
    /// Freshness state of the snapshot.
    pub state: SubscriptionQuotaDataState,
    pub source: Option<String>,
    pub utilization: Option<f64>,
    pub status: Option<String>,
    /// Reset time as Unix seconds, per the field name.
    pub resets_at_unix_secs: Option<u64>,
    pub surpassed_threshold: Option<f64>,
    pub representative_claim: Option<String>,
    pub disabled_reason: Option<String>,
    pub extra_usage_enabled: Option<bool>,
    pub extra_usage_monthly_limit: Option<f64>,
    pub extra_usage_used_credits: Option<f64>,
    /// Observation time as Unix milliseconds, per the field name.
    pub observed_at_unix_millis: Option<u64>,
    /// Staleness bound in seconds, per the field name.
    pub max_staleness_secs: u64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fallback_available: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub overage_in_use: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub overage_period_monthly_utilization: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upgrade_paths: Option<Vec<String>>,
}
244
/// Cache TTL class; `Ephemeral5m` is the `Default`.
///
/// Variants are (de)serialized in `snake_case`.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm
// whether it is still needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum TtlClass {
    #[default]
    Ephemeral5m,
    Ephemeral1h,
}
256
/// Origin tag stored in [`CacheBreakpoint::origin`].
///
/// Variants are (de)serialized in `snake_case`.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm
// whether it is still needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum BreakpointOrigin {
    Explicit,
    AutoCacheInferred,
}
267
/// Source tag stored in [`CacheBreakpoint::source`].
///
/// Variants are (de)serialized in `snake_case`.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm
// whether it is still needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum CacheBreakpointSource {
    Tools,
    System,
    Message,
}
280
/// Description of one cache breakpoint, as carried in
/// [`RequestContext::cache_breakpoints`].
// NOTE(review): field meanings are taken from names and types only; the
// path syntax and hash algorithm are not defined in this file.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm
// whether it is still needed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct CacheBreakpoint {
    pub block_index: u32,
    /// Which source the breakpoint is tagged with.
    pub source: CacheBreakpointSource,
    /// Path string; the tests in this file use values such as
    /// `messages.0.content.1`.
    pub path: String,
    pub message_index: Option<u32>,
    pub prefix_hash: String,
    pub prefix_token_count: u64,
    /// TTL class requested for this breakpoint.
    pub requested_ttl: TtlClass,
    /// Origin tag of the breakpoint.
    pub origin: BreakpointOrigin,
}
302
/// Record of a cached prefix: its hash, TTL class and two Unix-second
/// timestamps.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm
// whether it is still needed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct WarmCacheEntry {
    pub prefix_hash: String,
    /// Expiry time as Unix seconds, per the field name.
    pub expires_at_unix_secs: u64,
    pub ttl_class: TtlClass,
    /// Last observation time as Unix seconds, per the field name.
    pub last_observed_at_unix_secs: u64,
}
316
/// Predicted cache figures optionally attached to an [`UpstreamCandidate`].
///
/// Derives `PartialEq` but not `Eq` because `confidence` is an `f32`.
// NOTE(review): the value range of `confidence` is not defined in this
// file; the tests use 0.95, so 0.0..=1.0 is presumably intended — confirm.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm
// whether it is still needed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct CacheScore {
    pub predicted_cache_read_tokens: u32,
    pub predicted_cache_creation_tokens_5m: u32,
    pub predicted_cache_creation_tokens_1h: u32,
    pub predicted_uncached_input_tokens: u32,
    /// Predicted expiry as Unix seconds, per the field name.
    pub predicted_expires_at_unix_secs: Option<u64>,
    pub matched_breakpoint_index: Option<u32>,
    pub confidence: f32,
    pub ambiguity_reason: Option<String>,
}
338
/// Candidate upstream together with the observations recorded for it.
///
/// On input, `subscription_quotas`, `cache_score` and `base_url` may be
/// absent and fall back to their defaults. `observed_rate_limits` has no
/// serde default and is therefore required.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UpstreamCandidate {
    /// Identifier of the upstream.
    pub upstream_id: Uuid,
    /// Name of the upstream.
    pub name: String,
    /// Kind tag of the upstream.
    pub kind: UpstreamKind,
    /// Rate-limit readings recorded for this upstream.
    pub observed_rate_limits: Vec<RateLimitObservation>,
    /// Subscription-quota snapshots; defaults to empty when absent.
    #[serde(default)]
    pub subscription_quotas: Vec<SubscriptionQuotaCandidateSnapshot>,
    /// Observation time as Unix seconds, per the field name.
    pub observed_at_unix_secs: u64,
    /// Optional cache prediction; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_score: Option<CacheScore>,
    /// Optional base URL kept as a plain string (unlike the `Url`-typed
    /// field on [`Upstream::AnthropicDirect`]); omitted from output when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
}
368
/// Credential strategy tag.
///
/// Variants are (de)serialized in `snake_case`.
// NOTE(review): under serde's `snake_case` rule `OAuth` goes on the wire as
// `o_auth` (not `oauth`) — confirm that this is the intended wire name.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CredentialStrategy {
    ApiKey,
    OAuth,
    InternalForwarded,
}
380
/// Data describing one request, passed by reference to [`shape_request`].
///
/// This type has no serde derives.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Identifier of the request.
    pub request_id: String,
    /// Headers of the downstream request.
    pub downstream_headers: HeaderMap,
    /// HTTP method of the request.
    pub method: Method,
    /// Request path.
    pub path: String,
    /// Query string, when present.
    // NOTE(review): the tests use `param=value`, so presumably stored
    // without a leading `?` — confirm.
    pub query: Option<String>,
    /// Request body bytes.
    pub body_bytes: Bytes,
    /// Cache breakpoints associated with the request; may be empty.
    pub cache_breakpoints: Vec<CacheBreakpoint>,
    /// Model identifier string; the tests also construct it empty.
    pub canonical_model_id: String,
}
401
/// Request parts (URL, method, headers, body) produced through
/// [`ShapedRequestBuilder::shaped_request`].
///
/// All fields are private; use the accessor and mutator methods.
#[derive(Clone, Debug)]
pub struct ShapedRequest {
    url: Url,
    method: Method,
    headers: HeaderMap,
    body: Bytes,
    // NOTE(review): presumably `crate::private::Seal` is not nameable
    // outside this crate, which would prevent external construction —
    // confirm against the `private` module.
    _seal: crate::private::Seal,
}
411
impl ShapedRequest {
    /// Borrows the target URL.
    pub fn url(&self) -> &Url {
        &self.url
    }

    /// Replaces the target URL.
    pub fn set_url(&mut self, url: Url) {
        self.url = url;
    }

    /// Borrows the HTTP method.
    pub fn method(&self) -> &Method {
        &self.method
    }

    /// Replaces the HTTP method.
    pub fn set_method(&mut self, method: Method) {
        self.method = method;
    }

    /// Borrows the header map.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Mutably borrows the header map for in-place edits.
    pub fn headers_mut(&mut self) -> &mut HeaderMap {
        &mut self.headers
    }

    /// Borrows the body bytes.
    pub fn body(&self) -> &Bytes {
        &self.body
    }

    /// Replaces the body bytes.
    pub fn set_body(&mut self, body: Bytes) {
        self.body = body;
    }
}
453
/// Factory for [`ShapedRequest`] values.
///
/// [`shape_request`] creates one and lends it to
/// [`UpstreamDialect::shape`] as `&mut`.
#[derive(Debug)]
pub struct ShapedRequestBuilder {
    // NOTE(review): presumably the seal keeps code outside this crate from
    // constructing a builder — confirm against the `private` module.
    _seal: crate::private::Seal,
}
464
465impl ShapedRequestBuilder {
466 pub fn shaped_request(
468 &mut self,
469 url: Url,
470 method: Method,
471 headers: HeaderMap,
472 body: Bytes,
473 ) -> ShapedRequest {
474 ShapedRequest {
475 url,
476 method,
477 headers,
478 body,
479 _seal: crate::private::Seal,
480 }
481 }
482}
483
484pub fn shape_request(
486 dialect: &dyn UpstreamDialect,
487 ctx: &RequestContext,
488 upstream: &Upstream,
489 principal: &Principal,
490) -> Result<ShapedRequest, DialectError> {
491 let mut builder = ShapedRequestBuilder {
492 _seal: crate::private::Seal,
493 };
494 dialect.shape(ctx, upstream, principal, &mut builder)
495}
496
/// Request parts (URL, method, headers, body) obtained from a
/// [`ShapedRequest`] via [`SignedRequest::from_shaped`].
///
/// All fields are private and there are no mutators: the parts can only be
/// read, or taken out with [`SignedRequest::into_parts`].
#[derive(Clone, Debug)]
pub struct SignedRequest {
    url: Url,
    method: Method,
    headers: HeaderMap,
    body: Bytes,
    // NOTE(review): presumably the seal prevents construction outside this
    // crate — confirm against the `private` module.
    _seal: crate::private::Seal,
}
506
507impl SignedRequest {
508 pub fn from_shaped(shaped: ShapedRequest, _capability: &mut SigningCapability) -> Self {
510 Self {
511 url: shaped.url,
512 method: shaped.method,
513 headers: shaped.headers,
514 body: shaped.body,
515 _seal: crate::private::Seal,
516 }
517 }
518
519 pub fn url(&self) -> &Url {
521 &self.url
522 }
523
524 pub fn method(&self) -> &Method {
526 &self.method
527 }
528
529 pub fn headers(&self) -> &HeaderMap {
531 &self.headers
532 }
533
534 pub fn body(&self) -> &Bytes {
536 &self.body
537 }
538
539 pub fn into_parts(self) -> (Url, Method, HeaderMap, Bytes) {
541 (self.url, self.method, self.headers, self.body)
542 }
543}
544
/// Token required by [`SignedRequest::from_shaped`].
///
/// [`sign_request`] creates one and lends it to [`Signer::sign`] as `&mut`.
#[derive(Debug)]
pub struct SigningCapability {
    // NOTE(review): presumably the seal keeps code outside this crate from
    // constructing a capability — confirm against the `private` module.
    _seal: crate::private::Seal,
}
555
556pub async fn sign_request(
558 signer: &dyn Signer,
559 shaped: ShapedRequest,
560) -> Result<SignedRequest, SignerError> {
561 let mut capability = SigningCapability {
562 _seal: crate::private::Seal,
563 };
564 signer.sign(shaped, &mut capability).await
565}
566
/// Routing outcome: the chosen upstream plus the dialect to use with it.
///
/// Has no derives, so it is neither `Clone` nor `Debug`.
pub struct RouteDecision {
    /// Identifier of the chosen upstream, when there is one.
    pub upstream_id: Option<Uuid>,
    /// The chosen upstream.
    pub upstream: Upstream,
    /// Shared handle to the dialect implementation.
    pub dialect: Arc<dyn UpstreamDialect>,
}
576
/// Per-window quota figures plus a list of allowed model names.
///
/// This type has no serde derives.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PrincipalQuotas {
    pub requests_per_window: u64,
    pub input_tokens_per_window: u64,
    pub output_tokens_per_window: u64,
    /// Length of the window the `*_per_window` figures refer to.
    pub window: Duration,
    // NOTE(review): how an empty list is interpreted (allow all vs. allow
    // none) is not defined in this file — confirm against the enforcer.
    pub allowed_models: Vec<String>,
}
591
/// Event describing one step in the handling of a request.
///
/// This type has no serde derives; it derives `Eq`, so events can be
/// compared directly.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ObserveEvent {
    RequestStarted {
        request_id: String,
        /// Downstream user agent string, when one is available.
        downstream_user_agent: Option<String>,
    },
    AuthnComplete {
        principal_id: String,
        kind: PrincipalKind,
    },
    UpstreamChosen {
        upstream: Upstream,
    },
    Chunk {
        batch_index: u64,
        event_count: usize,
        total_bytes: usize,
    },
    RequestFinished {
        status: StatusCode,
        input_tokens: Option<u64>,
        output_tokens: Option<u64>,
        cache_creation_input_tokens: Option<u64>,
        cache_read_input_tokens: Option<u64>,
        /// Duration in milliseconds, per the field name.
        duration_ms: u64,
    },
    Error {
        code: String,
        message: String,
        source: String,
    },
}
648
/// Retry outcome: either refresh with a new signer, or fail.
///
/// Derives `Clone` only; cloning `Refresh` clones the `Arc` handle, not the
/// signer itself.
#[derive(Clone)]
pub enum RetryDecision {
    Refresh {
        /// Shared handle to the replacement signer.
        new_signer: Arc<dyn Signer>,
    },
    Fail,
}
660
/// Manifest describing one plugin.
///
/// On input, `name`, `artifact` and `config` are required; the other fields
/// fall back to defaults when absent.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PluginManifest {
    /// Plugin name.
    pub name: String,
    /// Artifact reference; the tests in this file use `plugin.wasm`.
    pub artifact: String,
    /// Optional wire version; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub wire_version: Option<u8>,
    /// Arbitrary JSON configuration; its schema is not defined in this file.
    pub config: serde_json::Value,
    /// Arbitrary JSON metadata keyed by string; empty when absent. A
    /// `BTreeMap` keeps the keys in sorted order.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Purity flag; takes the value of [`default_pure`] (`true`) when absent.
    #[serde(default = "default_pure")]
    pub pure: bool,
}
684
/// Serde default for [`PluginManifest::pure`]: always `true`.
pub fn default_pure() -> bool {
    true
}
690
// NOTE(review): these limits are only declared here; nothing in this file
// enforces them. Whether the two `*_LEN` limits count bytes or characters
// is not stated — confirm at the enforcement site.

/// Upper bound on the number of stages in a routing trace.
pub const MAX_ROUTING_TRACE_STAGES: usize = 100;
/// Upper bound on the length of a stage name.
pub const MAX_STAGE_NAME_LEN: usize = 256;
/// Upper bound on the length of an error message.
pub const MAX_ERROR_MESSAGE_LEN: usize = 1024;
698
/// Passthrough cause tag.
///
/// Variants are (de)serialized in `snake_case`. Nothing else in this file
/// refers to this type.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PassthroughCause {
    HealthyUpstream,
    NoAlternative,
    PluginDecision,
}
710
/// Per-candidate reason tag.
///
/// Variants are (de)serialized in `snake_case`. Nothing else in this file
/// refers to this type.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PerCandidateReason {
    RateLimited,
    InsufficientQuota,
    Unhealthy,
    RejectedByPlugin,
}
724
/// Strategy tag stored in [`TerminalDecision::strategy`]; `FirstPick` is
/// the `Default`.
// NOTE(review): this is the only enum in the file serialized in
// `kebab-case` (`first-pick`, `round-robin`, `least-connections`); every
// other enum uses `snake_case`. Confirm the difference is intentional
// before relying on either spelling.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TerminalStrategy {
    #[default]
    FirstPick,
    Random,
    RoundRobin,
    LeastConnections,
}
739
/// Record of one stage in a [`RoutingTrace`].
///
/// Only `stage_name` is always written; the other fields are omitted from
/// output when `None` or zero, and defaulted when absent from input.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StageDecision {
    /// Name of the stage.
    pub stage_name: String,
    /// Upstream identifier recorded for the stage, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_id: Option<Uuid>,
    /// Free-text reason recorded for the stage, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Duration in microseconds, per the field name; omitted when zero.
    #[serde(default, skip_serializing_if = "is_zero_u64")]
    pub duration_us: u64,
}
755
/// Serde `skip_serializing_if` predicate: true exactly when `value` is zero.
fn is_zero_u64(value: &u64) -> bool {
    matches!(*value, 0)
}
759
/// Final decision recorded in a [`RoutingTrace`].
///
/// `strategy` has no serde default, so it is required on input even though
/// the type derives `Default`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TerminalDecision {
    /// Upstream identifier recorded for the decision, if any; omitted from
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_id: Option<Uuid>,
    /// Strategy tag recorded for the decision.
    pub strategy: TerminalStrategy,
}
769
/// Trace of routing stages plus an optional terminal decision.
///
/// Both fields default when absent, so an empty JSON object deserializes to
/// `RoutingTrace::default()`.
// NOTE(review): `MAX_ROUTING_TRACE_STAGES` presumably caps `stages`, but
// nothing in this file enforces it — confirm where the cap is applied.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct RoutingTrace {
    /// Stage records, in the order they were stored.
    #[serde(default)]
    pub stages: Vec<StageDecision>,
    /// Terminal decision, if any; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub terminal_decision: Option<TerminalDecision>,
}
780
/// Stage tag stored in [`InternalError::stage`].
///
/// Variants are (de)serialized in `snake_case`. The `Default` is `Router`,
/// which is not the first variant.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorStage {
    Authn,
    #[default]
    Router,
    RouterFilter,
    Shape,
    Signer,
    Relay,
}
799
/// Kind tag stored in [`InternalError::kind`]; `PluginError` is the
/// `Default`.
///
/// Variants are (de)serialized in `snake_case`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorKind {
    #[default]
    PluginError,
    InvalidOutput,
    Trap,
    ConfigError,
    Timeout,
    Unavailable,
}
818
/// Internal error record: a stage tag, a kind tag and an optional message.
///
/// `stage` and `kind` have no serde default, so both are required on input
/// even though the type derives `Default`.
// NOTE(review): `MAX_ERROR_MESSAGE_LEN` presumably caps `message`, but
// nothing in this file enforces it — confirm where the cap is applied.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InternalError {
    /// Stage tag of the error.
    pub stage: InternalErrorStage,
    /// Kind tag of the error.
    pub kind: InternalErrorKind,
    /// Optional message; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
830
#[cfg(test)]
mod tests {
    //! Unit tests for the types in this module: accessor behaviour of the
    //! sealed request types, serde round-trips and defaults, and agreement
    //! between the `as_str` names and the serde wire names.

    use super::*;

    /// Every accessor returns what the builder was given, and every mutator
    /// replaces exactly its own part.
    #[test]
    fn shaped_request_accessors_and_mutators_preserve_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("x-test", "one".parse().unwrap());
        let mut shaped = builder.shaped_request(
            "https://example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"first"),
        );

        assert_eq!(shaped.url().as_str(), "https://example.test/v1/messages");
        assert_eq!(shaped.method(), Method::POST);
        assert_eq!(shaped.headers()["x-test"], "one");
        assert_eq!(shaped.body(), &Bytes::from_static(b"first"));

        shaped.set_url("https://example.test/v1/complete".parse().unwrap());
        shaped.set_method(Method::PUT);
        shaped
            .headers_mut()
            .insert("x-test", "two".parse().unwrap());
        shaped.set_body(Bytes::from_static(b"second"));

        assert_eq!(shaped.url().path(), "/v1/complete");
        assert_eq!(shaped.method(), Method::PUT);
        assert_eq!(shaped.headers()["x-test"], "two");
        assert_eq!(shaped.body(), &Bytes::from_static(b"second"));
    }

    /// `from_shaped` carries all four parts across unchanged, and
    /// `into_parts` hands them back in `(url, method, headers, body)` order.
    #[test]
    fn signed_request_exposes_and_consumes_signed_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer token".parse().unwrap());
        let shaped = builder.shaped_request(
            "https://api.example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"{}"),
        );
        let mut capability = SigningCapability {
            _seal: crate::private::Seal,
        };

        let signed = SignedRequest::from_shaped(shaped, &mut capability);
        assert_eq!(signed.url().host_str(), Some("api.example.test"));
        assert_eq!(signed.method(), Method::POST);
        assert_eq!(signed.headers()["authorization"], "Bearer token");
        assert_eq!(signed.body(), &Bytes::from_static(b"{}"));

        let (url, method, headers, body) = signed.into_parts();
        assert_eq!(url.as_str(), "https://api.example.test/v1/messages");
        assert_eq!(method, Method::POST);
        assert_eq!(headers["authorization"], "Bearer token");
        assert_eq!(body, Bytes::from_static(b"{}"));
    }

    #[test]
    fn upstream_and_manifest_serde_round_trip() {
        // Cover both shapes of the optional field, not just the `None` case.
        let upstreams = vec![
            Upstream::AnthropicDirect { base_url: None },
            Upstream::AnthropicDirect {
                base_url: Some("https://proxy.example.test/".parse().unwrap()),
            },
        ];

        for upstream in upstreams {
            let json = serde_json::to_string(&upstream).unwrap();
            let decoded: Upstream = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, upstream);
        }

        // An absent `base_url` is skipped entirely, leaving only the tag.
        assert_eq!(
            serde_json::to_value(Upstream::AnthropicDirect { base_url: None }).unwrap(),
            serde_json::json!({"kind": "anthropic_direct"})
        );

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "authn",
            "artifact": "plugin.wasm",
            "config": {"enabled": true}
        }))
        .unwrap();
        assert_eq!(manifest.name, "authn");
        assert_eq!(manifest.wire_version, None);
        assert!(manifest.metadata.is_empty());
        assert!(manifest.pure, "`pure` must default to true when omitted");

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "cache-aware",
            "artifact": "plugin.wasm",
            "wire_version": 2,
            "config": {}
        }))
        .unwrap();
        assert_eq!(manifest.wire_version, Some(2));

        // An explicit `pure: false` must override the default.
        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "impure",
            "artifact": "plugin.wasm",
            "config": {},
            "pure": false
        }))
        .unwrap();
        assert!(!manifest.pure);
    }

    #[test]
    fn public_enums_cover_all_current_variants() {
        let principal_kinds = [
            PrincipalKind::ApiKey,
            PrincipalKind::OAuthSubject,
            PrincipalKind::InternalKey,
            PrincipalKind::WorkloadIdentity,
            PrincipalKind::SubscriptionBearer,
        ];
        assert_eq!(principal_kinds.len(), 5);
        // The length check alone cannot fail; also prove that every variant
        // survives a serde round-trip.
        for kind in principal_kinds.iter() {
            let json = serde_json::to_string(kind).unwrap();
            let decoded: PrincipalKind = serde_json::from_str(&json).unwrap();
            assert_eq!(&decoded, kind);
        }

        let strategies = [
            CredentialStrategy::ApiKey,
            CredentialStrategy::OAuth,
            CredentialStrategy::InternalForwarded,
        ];
        assert_eq!(strategies.len(), 3);
        for strategy in strategies.iter() {
            let json = serde_json::to_string(strategy).unwrap();
            let decoded: CredentialStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(&decoded, strategy);
        }
    }

    /// `as_str` must agree with the serde wire name for every variant, and
    /// `PluginSlot::parse` must be its exact inverse.
    #[test]
    fn as_str_names_match_serde_wire_names() {
        let slots = [
            PluginSlot::Router,
            PluginSlot::ObservabilityHook,
            PluginSlot::Shape,
        ];
        for slot in slots.iter().copied() {
            assert_eq!(
                serde_json::to_value(slot).unwrap(),
                serde_json::json!(slot.as_str())
            );
            assert_eq!(PluginSlot::parse(slot.as_str()), Some(slot));
        }
        assert_eq!(PluginSlot::parse("Router"), None);
        assert_eq!(PluginSlot::parse(""), None);

        let upstream_kinds = [UpstreamKind::AnthropicApiKey, UpstreamKind::AnthropicOauth];
        for kind in upstream_kinds.iter().copied() {
            assert_eq!(
                serde_json::to_value(kind).unwrap(),
                serde_json::json!(kind.as_str())
            );
        }

        let rate_limit_kinds = [
            RateLimitKind::Requests,
            RateLimitKind::Tokens,
            RateLimitKind::InputTokens,
            RateLimitKind::OutputTokens,
        ];
        for kind in rate_limit_kinds.iter().copied() {
            assert_eq!(
                serde_json::to_value(kind).unwrap(),
                serde_json::json!(kind.as_str())
            );
        }
    }

    #[test]
    fn slot_key_global_uses_reserved_principal() {
        let key = SlotKey::global("router-a");
        assert_eq!(key.principal, GLOBAL_PRINCIPAL);
        assert_eq!(key.plugin, "router-a");
        assert_eq!(key, SlotKey::new(GLOBAL_PRINCIPAL, "router-a"));
        assert_ne!(key, SlotKey::new("alice", "router-a"));
    }

    /// Zero durations and absent options are skipped on output and restored
    /// by the serde defaults on input.
    #[test]
    fn routing_trace_skips_and_defaults_optional_fields() {
        let stage = StageDecision {
            stage_name: "filter".to_owned(),
            upstream_id: None,
            reason: None,
            duration_us: 0,
        };
        let value = serde_json::to_value(&stage).unwrap();
        assert_eq!(value, serde_json::json!({"stage_name": "filter"}));
        let decoded: StageDecision = serde_json::from_value(value).unwrap();
        assert_eq!(decoded, stage);

        let timed = StageDecision {
            duration_us: 7,
            ..stage
        };
        assert_eq!(
            serde_json::to_value(&timed).unwrap(),
            serde_json::json!({"stage_name": "filter", "duration_us": 7})
        );

        let trace: RoutingTrace = serde_json::from_value(serde_json::json!({})).unwrap();
        assert_eq!(trace, RoutingTrace::default());

        // `TerminalStrategy` is the one enum here that uses kebab-case.
        assert_eq!(
            serde_json::to_value(TerminalStrategy::default()).unwrap(),
            serde_json::json!("first-pick")
        );
    }

    #[test]
    fn observe_event_variants_are_equatable() {
        let events = vec![
            ObserveEvent::RequestStarted {
                request_id: "req".to_owned(),
                downstream_user_agent: Some("ua".to_owned()),
            },
            ObserveEvent::AuthnComplete {
                principal_id: "principal".to_owned(),
                kind: PrincipalKind::InternalKey,
            },
            ObserveEvent::UpstreamChosen {
                upstream: Upstream::AnthropicDirect { base_url: None },
            },
            ObserveEvent::Chunk {
                batch_index: 1,
                event_count: 2,
                total_bytes: 3,
            },
            ObserveEvent::RequestFinished {
                status: StatusCode::OK,
                input_tokens: Some(4),
                output_tokens: Some(5),
                cache_creation_input_tokens: Some(6),
                cache_read_input_tokens: Some(7),
                duration_ms: 8,
            },
            ObserveEvent::Error {
                code: "E".to_owned(),
                message: "redacted".to_owned(),
                source: "plugin".to_owned(),
            },
        ];

        assert_eq!(events, events.clone());
    }

    #[test]
    fn cache_types_roundtrip() {
        let ttl_class = TtlClass::Ephemeral1h;
        let json = serde_json::to_string(&ttl_class).unwrap();
        let decoded: TtlClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, ttl_class);

        let origin = BreakpointOrigin::AutoCacheInferred;
        let json = serde_json::to_string(&origin).unwrap();
        let decoded: BreakpointOrigin = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, origin);

        let source = CacheBreakpointSource::System;
        let json = serde_json::to_string(&source).unwrap();
        let decoded: CacheBreakpointSource = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, source);

        let breakpoint = CacheBreakpoint {
            block_index: 0,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.1".to_owned(),
            message_index: Some(0),
            prefix_hash: "abc123".to_owned(),
            prefix_token_count: 100,
            requested_ttl: TtlClass::Ephemeral5m,
            origin: BreakpointOrigin::Explicit,
        };
        let json = serde_json::to_string(&breakpoint).unwrap();
        let decoded: CacheBreakpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, breakpoint);

        let warm_entry = WarmCacheEntry {
            prefix_hash: "def456".to_owned(),
            expires_at_unix_secs: 1700000000,
            ttl_class: TtlClass::Ephemeral1h,
            last_observed_at_unix_secs: 1699999000,
        };
        let json = serde_json::to_string(&warm_entry).unwrap();
        let decoded: WarmCacheEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, warm_entry);

        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let json = serde_json::to_string(&cache_score).unwrap();
        let decoded: CacheScore = serde_json::from_str(&json).unwrap();
        // Compared field by field so the float is checked with a tolerance.
        assert_eq!(
            decoded.predicted_cache_read_tokens,
            cache_score.predicted_cache_read_tokens
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_5m,
            cache_score.predicted_cache_creation_tokens_5m
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_1h,
            cache_score.predicted_cache_creation_tokens_1h
        );
        assert_eq!(
            decoded.predicted_uncached_input_tokens,
            cache_score.predicted_uncached_input_tokens
        );
        assert_eq!(
            decoded.predicted_expires_at_unix_secs,
            cache_score.predicted_expires_at_unix_secs
        );
        assert_eq!(
            decoded.matched_breakpoint_index,
            cache_score.matched_breakpoint_index
        );
        assert!((decoded.confidence - cache_score.confidence).abs() < 0.0001);
        assert_eq!(decoded.ambiguity_reason, cache_score.ambiguity_reason);
    }

    #[test]
    fn upstream_candidate_cache_score_roundtrip() {
        let candidate_no_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: None,
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_no_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_no_cache.upstream_id);
        assert_eq!(decoded.name, candidate_no_cache.name);
        assert_eq!(decoded.cache_score, None);

        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let candidate_with_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream-cached".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: Some(cache_score),
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_with_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_with_cache.upstream_id);
        assert_eq!(decoded.name, candidate_with_cache.name);
        assert!(decoded.cache_score.is_some());
        assert_eq!(decoded.cache_score.unwrap().predicted_cache_read_tokens, 50);
    }

    // `RequestContext` has no serde derives, so despite its name this is a
    // construction check rather than a serialization round-trip.
    #[test]
    fn request_context_cache_fields_roundtrip() {
        let ctx_empty = RequestContext {
            request_id: "req-1".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert_eq!(ctx_empty.cache_breakpoints.len(), 0);
        assert_eq!(ctx_empty.canonical_model_id, "");

        let breakpoint = CacheBreakpoint {
            block_index: 1,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.0".to_owned(),
            message_index: Some(0),
            prefix_hash: "hash123".to_owned(),
            prefix_token_count: 150,
            requested_ttl: TtlClass::Ephemeral1h,
            origin: BreakpointOrigin::Explicit,
        };
        let ctx_populated = RequestContext {
            request_id: "req-2".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: Some("param=value".to_owned()),
            body_bytes: Bytes::from_static(b"test"),
            cache_breakpoints: vec![breakpoint],
            canonical_model_id: "claude-sonnet-4-5-20250929".to_owned(),
        };
        assert_eq!(ctx_populated.cache_breakpoints.len(), 1);
        assert_eq!(
            ctx_populated.canonical_model_id,
            "claude-sonnet-4-5-20250929"
        );
    }

    /// Payloads that carry neither `cache_score` nor `base_url` must still
    /// deserialize, with both fields falling back to `None`.
    #[test]
    fn upstream_candidate_deserialize_without_cache_score_field() {
        let json = r#"{
            "upstream_id": "00000000-0000-0000-0000-000000000001",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1700000000
        }"#;
        let candidate: UpstreamCandidate = serde_json::from_str(json).unwrap();
        assert!(candidate.cache_score.is_none());
        assert!(candidate.base_url.is_none());
        assert_eq!(candidate.name, "legacy-upstream");
    }

    // Construction check with empty cache fields; `RequestContext` has no
    // serde defaults to exercise.
    #[test]
    fn request_context_cache_breakpoints_default_on_missing_fields() {
        let ctx = RequestContext {
            request_id: "test".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::GET,
            path: "/test".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert!(ctx.cache_breakpoints.is_empty());
        assert!(ctx.canonical_model_id.is_empty());
    }
}