1use std::collections::BTreeMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use bytes::Bytes;
8use http::{HeaderMap, Method, StatusCode};
9use serde::{Deserialize, Serialize};
10use url::Url;
11use uuid::Uuid;
12
13use crate::errors::{DialectError, SignerError};
14use crate::traits::{Signer, UpstreamDialect};
15
16#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
18pub struct Principal {
19 pub id: String,
21 pub kind: PrincipalKind,
23 pub claims: serde_json::Map<String, serde_json::Value>,
25}
26
27#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
29#[serde(rename_all = "snake_case")]
30pub enum PrincipalKind {
31 ApiKey,
33 OAuthSubject,
35 InternalKey,
37 WorkloadIdentity,
39 SubscriptionBearer,
41}
42
43#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46pub enum PluginSlot {
47 Router,
49 ObservabilityHook,
51 Shape,
53}
54
55impl PluginSlot {
56 pub fn as_str(self) -> &'static str {
58 match self {
59 Self::Router => "router",
60 Self::ObservabilityHook => "observability_hook",
61 Self::Shape => "shape",
62 }
63 }
64
65 pub fn parse(value: &str) -> Option<Self> {
67 match value {
68 "router" => Some(Self::Router),
69 "observability_hook" => Some(Self::ObservabilityHook),
70 "shape" => Some(Self::Shape),
71 _ => None,
72 }
73 }
74}
75
76#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
78#[serde(rename_all = "snake_case", tag = "kind")]
79pub enum Upstream {
80 AnthropicDirect {
82 #[serde(default, skip_serializing_if = "Option::is_none")]
86 base_url: Option<Url>,
87 },
88}
89
90#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
92#[serde(rename_all = "snake_case")]
93pub enum UpstreamKind {
94 AnthropicApiKey,
96 AnthropicOauth,
98}
99
100impl UpstreamKind {
101 pub fn as_str(self) -> &'static str {
103 match self {
104 Self::AnthropicApiKey => "anthropic_api_key",
105 Self::AnthropicOauth => "anthropic_oauth",
106 }
107 }
108}
109
110#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
112#[serde(rename_all = "snake_case")]
113pub enum RateLimitKind {
114 Requests,
116 Tokens,
118 InputTokens,
120 OutputTokens,
122}
123
124impl RateLimitKind {
125 pub fn as_str(self) -> &'static str {
127 match self {
128 Self::Requests => "requests",
129 Self::Tokens => "tokens",
130 Self::InputTokens => "input_tokens",
131 Self::OutputTokens => "output_tokens",
132 }
133 }
134}
135
136#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
138pub struct RateLimitObservation {
139 pub kind: RateLimitKind,
141 pub window: String,
143 pub limit: Option<u64>,
145 pub remaining: Option<u64>,
147 pub reset: Option<String>,
149}
150
151#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
153#[serde(rename_all = "snake_case")]
154pub enum SubscriptionQuotaDataState {
155 Fresh,
157 Stale,
159 Missing,
161}
162
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
165pub struct SubscriptionQuotaCandidateSnapshot {
166 pub window: String,
168 pub state: SubscriptionQuotaDataState,
170 pub source: Option<String>,
172 pub utilization: Option<f64>,
174 pub status: Option<String>,
176 pub resets_at_unix_secs: Option<u64>,
178 pub surpassed_threshold: Option<bool>,
180 pub representative_claim: Option<String>,
182 pub disabled_reason: Option<String>,
184 pub extra_usage_enabled: Option<bool>,
186 pub extra_usage_monthly_limit: Option<f64>,
188 pub extra_usage_used_credits: Option<f64>,
190 pub observed_at_unix_millis: Option<u64>,
192 pub max_staleness_secs: u64,
194}
195
196#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
198#[serde(rename_all = "snake_case")]
199#[allow(dead_code)]
200pub enum TtlClass {
201 #[default]
203 Ephemeral5m,
204 Ephemeral1h,
206}
207
208#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
210#[serde(rename_all = "snake_case")]
211#[allow(dead_code)]
212pub enum BreakpointOrigin {
213 Explicit,
215 AutoCacheInferred,
217}
218
219#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
221#[serde(rename_all = "snake_case")]
222#[allow(dead_code)]
223pub enum CacheBreakpointSource {
224 Tools,
226 System,
228 Message,
230}
231
232#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
234#[allow(dead_code)]
235pub struct CacheBreakpoint {
236 pub block_index: u32,
238 pub source: CacheBreakpointSource,
240 pub path: String,
242 pub message_index: Option<u32>,
244 pub prefix_hash: String,
246 pub prefix_token_count: u64,
248 pub requested_ttl: TtlClass,
250 pub origin: BreakpointOrigin,
252}
253
254#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
256#[allow(dead_code)]
257pub struct WarmCacheEntry {
258 pub prefix_hash: String,
260 pub expires_at_unix_secs: u64,
262 pub ttl_class: TtlClass,
264 pub last_observed_at_unix_secs: u64,
266}
267
268#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
270#[allow(dead_code)]
271pub struct CacheScore {
272 pub predicted_cache_read_tokens: u32,
274 pub predicted_cache_creation_tokens_5m: u32,
276 pub predicted_cache_creation_tokens_1h: u32,
278 pub predicted_uncached_input_tokens: u32,
280 pub predicted_expires_at_unix_secs: Option<u64>,
282 pub matched_breakpoint_index: Option<u32>,
284 pub confidence: f32,
286 pub ambiguity_reason: Option<String>,
288}
289
290#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
296pub struct UpstreamCandidate {
297 pub upstream_id: Uuid,
299 pub name: String,
301 pub kind: UpstreamKind,
303 pub observed_rate_limits: Vec<RateLimitObservation>,
305 #[serde(default)]
307 pub subscription_quotas: Vec<SubscriptionQuotaCandidateSnapshot>,
308 pub observed_at_unix_secs: u64,
310 #[serde(default, skip_serializing_if = "Option::is_none")]
312 pub cache_score: Option<CacheScore>,
313 #[serde(default, skip_serializing_if = "Option::is_none")]
317 pub base_url: Option<String>,
318}
319
320#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
322#[serde(rename_all = "snake_case")]
323pub enum CredentialStrategy {
324 ApiKey,
326 OAuth,
328 InternalForwarded,
330}
331
332#[derive(Clone, Debug)]
334pub struct RequestContext {
335 pub request_id: String,
337 pub downstream_headers: HeaderMap,
339 pub method: Method,
341 pub path: String,
343 pub query: Option<String>,
345 pub body_bytes: Bytes,
347 pub cache_breakpoints: Vec<CacheBreakpoint>,
349 pub canonical_model_id: String,
351}
352
353#[derive(Clone, Debug)]
355pub struct ShapedRequest {
356 url: Url,
357 method: Method,
358 headers: HeaderMap,
359 body: Bytes,
360 _seal: crate::private::Seal,
361}
362
363impl ShapedRequest {
364 pub fn url(&self) -> &Url {
366 &self.url
367 }
368
369 pub fn set_url(&mut self, url: Url) {
371 self.url = url;
372 }
373
374 pub fn method(&self) -> &Method {
376 &self.method
377 }
378
379 pub fn set_method(&mut self, method: Method) {
381 self.method = method;
382 }
383
384 pub fn headers(&self) -> &HeaderMap {
386 &self.headers
387 }
388
389 pub fn headers_mut(&mut self) -> &mut HeaderMap {
391 &mut self.headers
392 }
393
394 pub fn body(&self) -> &Bytes {
396 &self.body
397 }
398
399 pub fn set_body(&mut self, body: Bytes) {
401 self.body = body;
402 }
403}
404
405#[derive(Debug)]
412pub struct ShapedRequestBuilder {
413 _seal: crate::private::Seal,
414}
415
416impl ShapedRequestBuilder {
417 pub fn shaped_request(
419 &mut self,
420 url: Url,
421 method: Method,
422 headers: HeaderMap,
423 body: Bytes,
424 ) -> ShapedRequest {
425 ShapedRequest {
426 url,
427 method,
428 headers,
429 body,
430 _seal: crate::private::Seal,
431 }
432 }
433}
434
435pub fn shape_request(
437 dialect: &dyn UpstreamDialect,
438 ctx: &RequestContext,
439 upstream: &Upstream,
440 principal: &Principal,
441) -> Result<ShapedRequest, DialectError> {
442 let mut builder = ShapedRequestBuilder {
443 _seal: crate::private::Seal,
444 };
445 dialect.shape(ctx, upstream, principal, &mut builder)
446}
447
448#[derive(Clone, Debug)]
450pub struct SignedRequest {
451 url: Url,
452 method: Method,
453 headers: HeaderMap,
454 body: Bytes,
455 _seal: crate::private::Seal,
456}
457
458impl SignedRequest {
459 pub fn from_shaped(shaped: ShapedRequest, _capability: &mut SigningCapability) -> Self {
461 Self {
462 url: shaped.url,
463 method: shaped.method,
464 headers: shaped.headers,
465 body: shaped.body,
466 _seal: crate::private::Seal,
467 }
468 }
469
470 pub fn url(&self) -> &Url {
472 &self.url
473 }
474
475 pub fn method(&self) -> &Method {
477 &self.method
478 }
479
480 pub fn headers(&self) -> &HeaderMap {
482 &self.headers
483 }
484
485 pub fn body(&self) -> &Bytes {
487 &self.body
488 }
489
490 pub fn into_parts(self) -> (Url, Method, HeaderMap, Bytes) {
492 (self.url, self.method, self.headers, self.body)
493 }
494}
495
496#[derive(Debug)]
503pub struct SigningCapability {
504 _seal: crate::private::Seal,
505}
506
507pub async fn sign_request(
509 signer: &dyn Signer,
510 shaped: ShapedRequest,
511) -> Result<SignedRequest, SignerError> {
512 let mut capability = SigningCapability {
513 _seal: crate::private::Seal,
514 };
515 signer.sign(shaped, &mut capability).await
516}
517
518pub struct RouteDecision {
520 pub upstream_id: Option<Uuid>,
522 pub upstream: Upstream,
524 pub dialect: Arc<dyn UpstreamDialect>,
526}
527
/// Per-principal usage allowances over a fixed window.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrincipalQuotas {
    /// Requests allowed per `window`.
    pub requests_per_window: u64,
    /// Input tokens allowed per `window`.
    pub input_tokens_per_window: u64,
    /// Output tokens allowed per `window`.
    pub output_tokens_per_window: u64,
    /// Length of the quota window.
    pub window: Duration,
    /// Model identifiers the principal may use.
    pub allowed_models: Vec<String>,
}
542
543#[derive(Clone, Debug, Eq, PartialEq)]
545pub enum ObserveEvent {
546 RequestStarted {
548 request_id: String,
550 downstream_user_agent: Option<String>,
552 },
553 AuthnComplete {
555 principal_id: String,
557 kind: PrincipalKind,
559 },
560 UpstreamChosen {
562 upstream: Upstream,
564 },
565 Chunk {
567 batch_index: u64,
569 event_count: usize,
571 total_bytes: usize,
573 },
574 RequestFinished {
576 status: StatusCode,
578 input_tokens: Option<u64>,
580 output_tokens: Option<u64>,
582 cache_creation_input_tokens: Option<u64>,
584 cache_read_input_tokens: Option<u64>,
586 duration_ms: u64,
588 },
589 Error {
591 code: String,
593 message: String,
595 source: String,
597 },
598}
599
600#[derive(Clone)]
602pub enum RetryDecision {
603 Refresh {
605 new_signer: Arc<dyn Signer>,
607 },
608 Fail,
610}
611
612#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
614pub struct PluginManifest {
615 pub name: String,
617 pub artifact: String,
619 #[serde(default, skip_serializing_if = "Option::is_none")]
621 pub wire_version: Option<u8>,
622 pub config: serde_json::Value,
624 #[serde(default)]
626 pub metadata: BTreeMap<String, serde_json::Value>,
627}
628
/// Cap (per its name) on the number of [`StageDecision`]s in a [`RoutingTrace`].
/// NOTE(review): not enforced in this module.
pub const MAX_ROUTING_TRACE_STAGES: usize = 100;
/// Cap (per its name) on [`StageDecision::stage_name`] length.
/// NOTE(review): not enforced in this module.
pub const MAX_STAGE_NAME_LEN: usize = 256;
/// Cap (per its name) on error-message length.
/// NOTE(review): not enforced in this module.
pub const MAX_ERROR_MESSAGE_LEN: usize = 1024;
636
637#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
639#[serde(rename_all = "snake_case")]
640pub enum PassthroughCause {
641 HealthyUpstream,
643 NoAlternative,
645 PluginDecision,
647}
648
649#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
651#[serde(rename_all = "snake_case")]
652pub enum PerCandidateReason {
653 RateLimited,
655 InsufficientQuota,
657 Unhealthy,
659 RejectedByPlugin,
661}
662
663#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
665#[serde(rename_all = "kebab-case")]
666pub enum TerminalStrategy {
667 #[default]
669 FirstPick,
670 Random,
672 RoundRobin,
674 LeastConnections,
676}
677
678#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
680pub struct StageDecision {
681 pub stage_name: String,
683 #[serde(default, skip_serializing_if = "Option::is_none")]
685 pub upstream_id: Option<Uuid>,
686 #[serde(default, skip_serializing_if = "Option::is_none")]
688 pub reason: Option<String>,
689 #[serde(default, skip_serializing_if = "is_zero_u64")]
691 pub duration_us: u64,
692}
693
/// `skip_serializing_if` predicate: `true` when the referenced value is zero.
///
/// Takes `&u64` because serde calls the predicate with a reference to the field.
fn is_zero_u64(value: &u64) -> bool {
    matches!(*value, 0)
}
697
698#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
700pub struct TerminalDecision {
701 #[serde(default, skip_serializing_if = "Option::is_none")]
703 pub upstream_id: Option<Uuid>,
704 pub strategy: TerminalStrategy,
706}
707
708#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
710pub struct RoutingTrace {
711 #[serde(default)]
713 pub stages: Vec<StageDecision>,
714 #[serde(default, skip_serializing_if = "Option::is_none")]
716 pub terminal_decision: Option<TerminalDecision>,
717}
718
719#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
721#[serde(rename_all = "snake_case")]
722pub enum InternalErrorStage {
723 Authn,
725 #[default]
727 Router,
728 RouterFilter,
730 Shape,
732 Signer,
734 Relay,
736}
737
738#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
740#[serde(rename_all = "snake_case")]
741pub enum InternalErrorKind {
742 #[default]
744 PluginError,
745 InvalidOutput,
747 Trap,
749 ConfigError,
751 Timeout,
753 Unavailable,
755}
756
757#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
759pub struct InternalError {
760 pub stage: InternalErrorStage,
762 pub kind: InternalErrorKind,
764 #[serde(default, skip_serializing_if = "Option::is_none")]
766 pub message: Option<String>,
767}
768
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache-score fixture shared by the cache round-trip tests.
    fn sample_cache_score() -> CacheScore {
        CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        }
    }

    #[test]
    fn shaped_request_accessors_and_mutators_preserve_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("x-test", "one".parse().unwrap());
        let mut shaped = builder.shaped_request(
            "https://example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"first"),
        );

        assert_eq!(shaped.url().as_str(), "https://example.test/v1/messages");
        assert_eq!(shaped.method(), Method::POST);
        assert_eq!(shaped.headers()["x-test"], "one");
        assert_eq!(shaped.body(), &Bytes::from_static(b"first"));

        shaped.set_url("https://example.test/v1/complete".parse().unwrap());
        shaped.set_method(Method::PUT);
        shaped
            .headers_mut()
            .insert("x-test", "two".parse().unwrap());
        shaped.set_body(Bytes::from_static(b"second"));

        assert_eq!(shaped.url().path(), "/v1/complete");
        assert_eq!(shaped.method(), Method::PUT);
        assert_eq!(shaped.headers()["x-test"], "two");
        assert_eq!(shaped.body(), &Bytes::from_static(b"second"));
    }

    #[test]
    fn signed_request_exposes_and_consumes_signed_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer token".parse().unwrap());
        let shaped = builder.shaped_request(
            "https://api.example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"{}"),
        );
        let mut capability = SigningCapability {
            _seal: crate::private::Seal,
        };

        let signed = SignedRequest::from_shaped(shaped, &mut capability);
        assert_eq!(signed.url().host_str(), Some("api.example.test"));
        assert_eq!(signed.method(), Method::POST);
        assert_eq!(signed.headers()["authorization"], "Bearer token");
        assert_eq!(signed.body(), &Bytes::from_static(b"{}"));

        let (url, method, headers, body) = signed.into_parts();
        assert_eq!(url.as_str(), "https://api.example.test/v1/messages");
        assert_eq!(method, Method::POST);
        assert_eq!(headers["authorization"], "Bearer token");
        assert_eq!(body, Bytes::from_static(b"{}"));
    }

    #[test]
    fn upstream_and_manifest_serde_round_trip() {
        let upstreams = vec![
            Upstream::AnthropicDirect { base_url: None },
            Upstream::AnthropicDirect {
                base_url: Some("https://proxy.example.test/".parse().unwrap()),
            },
        ];

        for upstream in upstreams {
            let json = serde_json::to_string(&upstream).unwrap();
            let decoded: Upstream = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, upstream);
        }

        // Internally tagged by `kind`; a `None` base URL is omitted entirely.
        assert_eq!(
            serde_json::to_value(Upstream::AnthropicDirect { base_url: None }).unwrap(),
            serde_json::json!({"kind": "anthropic_direct"})
        );

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "authn",
            "artifact": "plugin.wasm",
            "config": {"enabled": true}
        }))
        .unwrap();
        assert_eq!(manifest.name, "authn");
        assert_eq!(manifest.wire_version, None);
        assert!(manifest.metadata.is_empty());

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "cache-aware",
            "artifact": "plugin.wasm",
            "wire_version": 2,
            "config": {}
        }))
        .unwrap();
        assert_eq!(manifest.wire_version, Some(2));
    }

    #[test]
    fn public_enums_cover_all_current_variants() {
        let principal_kinds = [
            PrincipalKind::ApiKey,
            PrincipalKind::OAuthSubject,
            PrincipalKind::InternalKey,
            PrincipalKind::WorkloadIdentity,
            PrincipalKind::SubscriptionBearer,
        ];
        for kind in &principal_kinds {
            // Exhaustive match: adding a variant breaks compilation here until
            // this test (and the list above) is updated.
            match kind {
                PrincipalKind::ApiKey
                | PrincipalKind::OAuthSubject
                | PrincipalKind::InternalKey
                | PrincipalKind::WorkloadIdentity
                | PrincipalKind::SubscriptionBearer => {}
            }
            let json = serde_json::to_string(kind).unwrap();
            let decoded: PrincipalKind = serde_json::from_str(&json).unwrap();
            assert_eq!(&decoded, kind);
        }
        let principal_names: std::collections::BTreeSet<String> = principal_kinds
            .iter()
            .map(|kind| serde_json::to_string(kind).unwrap())
            .collect();
        assert_eq!(principal_names.len(), principal_kinds.len());

        let strategies = [
            CredentialStrategy::ApiKey,
            CredentialStrategy::OAuth,
            CredentialStrategy::InternalForwarded,
        ];
        for strategy in &strategies {
            match strategy {
                CredentialStrategy::ApiKey
                | CredentialStrategy::OAuth
                | CredentialStrategy::InternalForwarded => {}
            }
            let json = serde_json::to_string(strategy).unwrap();
            let decoded: CredentialStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(&decoded, strategy);
        }
        let strategy_names: std::collections::BTreeSet<String> = strategies
            .iter()
            .map(|strategy| serde_json::to_string(strategy).unwrap())
            .collect();
        assert_eq!(strategy_names.len(), strategies.len());
    }

    #[test]
    fn plugin_slot_names_round_trip_and_match_serde() {
        for slot in [
            PluginSlot::Router,
            PluginSlot::ObservabilityHook,
            PluginSlot::Shape,
        ] {
            assert_eq!(PluginSlot::parse(slot.as_str()), Some(slot));
            assert_eq!(serde_json::to_value(slot).unwrap(), slot.as_str());
        }
        assert_eq!(PluginSlot::parse("Router"), None);
        assert_eq!(PluginSlot::parse(""), None);
    }

    #[test]
    fn kind_as_str_matches_serde_representation() {
        for kind in [UpstreamKind::AnthropicApiKey, UpstreamKind::AnthropicOauth] {
            assert_eq!(serde_json::to_value(kind).unwrap(), kind.as_str());
        }
        for kind in [
            RateLimitKind::Requests,
            RateLimitKind::Tokens,
            RateLimitKind::InputTokens,
            RateLimitKind::OutputTokens,
        ] {
            assert_eq!(serde_json::to_value(kind).unwrap(), kind.as_str());
        }
    }

    #[test]
    fn terminal_strategy_uses_kebab_case_wire_names() {
        let cases = [
            (TerminalStrategy::FirstPick, "first-pick"),
            (TerminalStrategy::Random, "random"),
            (TerminalStrategy::RoundRobin, "round-robin"),
            (TerminalStrategy::LeastConnections, "least-connections"),
        ];
        for (strategy, wire) in cases {
            assert_eq!(serde_json::to_value(&strategy).unwrap(), wire);
            let decoded: TerminalStrategy =
                serde_json::from_value(serde_json::Value::from(wire)).unwrap();
            assert_eq!(decoded, strategy);
        }
        assert_eq!(TerminalStrategy::default(), TerminalStrategy::FirstPick);
    }

    #[test]
    fn stage_decision_omits_default_fields() {
        let stage = StageDecision {
            stage_name: "filter".to_owned(),
            upstream_id: None,
            reason: None,
            duration_us: 0,
        };
        let value = serde_json::to_value(&stage).unwrap();
        assert_eq!(value, serde_json::json!({"stage_name": "filter"}));
        let decoded: StageDecision = serde_json::from_value(value).unwrap();
        assert_eq!(decoded, stage);
    }

    #[test]
    fn routing_trace_and_internal_error_defaults() {
        let trace: RoutingTrace = serde_json::from_str("{}").unwrap();
        assert_eq!(trace, RoutingTrace::default());

        let error = InternalError::default();
        assert_eq!(error.stage, InternalErrorStage::Router);
        assert_eq!(error.kind, InternalErrorKind::PluginError);
        assert_eq!(
            serde_json::to_value(&error).unwrap(),
            serde_json::json!({"stage": "router", "kind": "plugin_error"})
        );
    }

    #[test]
    fn observe_event_variants_are_equatable() {
        let events = vec![
            ObserveEvent::RequestStarted {
                request_id: "req".to_owned(),
                downstream_user_agent: Some("ua".to_owned()),
            },
            ObserveEvent::AuthnComplete {
                principal_id: "principal".to_owned(),
                kind: PrincipalKind::InternalKey,
            },
            ObserveEvent::UpstreamChosen {
                upstream: Upstream::AnthropicDirect { base_url: None },
            },
            ObserveEvent::Chunk {
                batch_index: 1,
                event_count: 2,
                total_bytes: 3,
            },
            ObserveEvent::RequestFinished {
                status: StatusCode::OK,
                input_tokens: Some(4),
                output_tokens: Some(5),
                cache_creation_input_tokens: Some(6),
                cache_read_input_tokens: Some(7),
                duration_ms: 8,
            },
            ObserveEvent::Error {
                code: "E".to_owned(),
                message: "redacted".to_owned(),
                source: "plugin".to_owned(),
            },
        ];

        assert_eq!(events, events.clone());
    }

    #[test]
    fn cache_types_roundtrip() {
        assert_eq!(TtlClass::default(), TtlClass::Ephemeral5m);

        let ttl_class = TtlClass::Ephemeral1h;
        let json = serde_json::to_string(&ttl_class).unwrap();
        let decoded: TtlClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, ttl_class);

        let origin = BreakpointOrigin::AutoCacheInferred;
        let json = serde_json::to_string(&origin).unwrap();
        let decoded: BreakpointOrigin = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, origin);

        let source = CacheBreakpointSource::System;
        let json = serde_json::to_string(&source).unwrap();
        let decoded: CacheBreakpointSource = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, source);

        let breakpoint = CacheBreakpoint {
            block_index: 0,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.1".to_owned(),
            message_index: Some(0),
            prefix_hash: "abc123".to_owned(),
            prefix_token_count: 100,
            requested_ttl: TtlClass::Ephemeral5m,
            origin: BreakpointOrigin::Explicit,
        };
        let json = serde_json::to_string(&breakpoint).unwrap();
        let decoded: CacheBreakpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, breakpoint);

        let warm_entry = WarmCacheEntry {
            prefix_hash: "def456".to_owned(),
            expires_at_unix_secs: 1700000000,
            ttl_class: TtlClass::Ephemeral1h,
            last_observed_at_unix_secs: 1699999000,
        };
        let json = serde_json::to_string(&warm_entry).unwrap();
        let decoded: WarmCacheEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, warm_entry);

        let cache_score = sample_cache_score();
        let json = serde_json::to_string(&cache_score).unwrap();
        let decoded: CacheScore = serde_json::from_str(&json).unwrap();
        assert_eq!(
            decoded.predicted_cache_read_tokens,
            cache_score.predicted_cache_read_tokens
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_5m,
            cache_score.predicted_cache_creation_tokens_5m
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_1h,
            cache_score.predicted_cache_creation_tokens_1h
        );
        assert_eq!(
            decoded.predicted_uncached_input_tokens,
            cache_score.predicted_uncached_input_tokens
        );
        assert_eq!(
            decoded.predicted_expires_at_unix_secs,
            cache_score.predicted_expires_at_unix_secs
        );
        assert_eq!(
            decoded.matched_breakpoint_index,
            cache_score.matched_breakpoint_index
        );
        // Compared with a tolerance: f32 text round-trips need not be bit-exact.
        assert!((decoded.confidence - cache_score.confidence).abs() < 0.0001);
        assert_eq!(decoded.ambiguity_reason, cache_score.ambiguity_reason);
    }

    #[test]
    fn upstream_candidate_cache_score_roundtrip() {
        let candidate_no_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: None,
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_no_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_no_cache.upstream_id);
        assert_eq!(decoded.name, candidate_no_cache.name);
        assert_eq!(decoded.cache_score, None);

        let candidate_with_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream-cached".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: Some(sample_cache_score()),
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_with_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_with_cache.upstream_id);
        assert_eq!(decoded.name, candidate_with_cache.name);
        assert!(decoded.cache_score.is_some());
        assert_eq!(decoded.cache_score.unwrap().predicted_cache_read_tokens, 50);
    }

    // RequestContext is not (de)serializable; this only checks the cache
    // fields hold what was stored.
    #[test]
    fn request_context_holds_cache_fields() {
        let ctx_empty = RequestContext {
            request_id: "req-1".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert_eq!(ctx_empty.cache_breakpoints.len(), 0);
        assert_eq!(ctx_empty.canonical_model_id, "");

        let breakpoint = CacheBreakpoint {
            block_index: 1,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.0".to_owned(),
            message_index: Some(0),
            prefix_hash: "hash123".to_owned(),
            prefix_token_count: 150,
            requested_ttl: TtlClass::Ephemeral1h,
            origin: BreakpointOrigin::Explicit,
        };
        let ctx_populated = RequestContext {
            request_id: "req-2".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: Some("param=value".to_owned()),
            body_bytes: Bytes::from_static(b"test"),
            cache_breakpoints: vec![breakpoint],
            canonical_model_id: "claude-sonnet-4-5-20250929".to_owned(),
        };
        assert_eq!(ctx_populated.cache_breakpoints.len(), 1);
        assert_eq!(
            ctx_populated.canonical_model_id,
            "claude-sonnet-4-5-20250929"
        );
    }

    #[test]
    fn upstream_candidate_deserialize_without_cache_score_field() {
        let json = r#"{
            "upstream_id": "00000000-0000-0000-0000-000000000001",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1700000000
        }"#;
        let candidate: UpstreamCandidate = serde_json::from_str(json).unwrap();
        assert!(candidate.cache_score.is_none());
        assert!(candidate.base_url.is_none());
        assert_eq!(candidate.name, "legacy-upstream");
    }

    #[test]
    fn upstream_candidate_deserialize_without_subscription_quotas_field() {
        let json = r#"{
            "upstream_id": "00000000-0000-0000-0000-000000000002",
            "name": "older-upstream",
            "kind": "anthropic_oauth",
            "observed_rate_limits": [],
            "observed_at_unix_secs": 1700000000
        }"#;
        let candidate: UpstreamCandidate = serde_json::from_str(json).unwrap();
        assert!(candidate.subscription_quotas.is_empty());
        assert_eq!(candidate.kind, UpstreamKind::AnthropicOauth);
    }

    #[test]
    fn request_context_accepts_empty_cache_fields() {
        let ctx = RequestContext {
            request_id: "test".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::GET,
            path: "/test".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert!(ctx.cache_breakpoints.is_empty());
        assert!(ctx.canonical_model_id.is_empty());
    }
}