1use std::collections::BTreeMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use bytes::Bytes;
8use http::{HeaderMap, Method, StatusCode};
9use serde::{Deserialize, Serialize};
10use url::Url;
11use uuid::Uuid;
12
13use crate::errors::{DialectError, SignerError};
14use crate::traits::{Signer, UpstreamDialect};
15
/// Caller identity that [`shape_request`] hands to an [`UpstreamDialect`].
///
/// [`ObserveEvent::AuthnComplete`] reports the same `id`/`kind` pair, so this is presumably
/// the output of the authentication stage (TODO confirm against the authn code).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Principal {
    /// Principal identifier; its format is not constrained by this type.
    pub id: String,
    /// Category of the principal; see [`PrincipalKind`].
    pub kind: PrincipalKind,
    /// Arbitrary JSON claims attached to the principal. No schema is enforced here.
    pub claims: serde_json::Map<String, serde_json::Value>,
}

/// Category of a [`Principal`].
///
/// Serialized with serde's `snake_case` rename rule, which inserts `_` before every
/// uppercase letter after the first one. NOTE(review): that makes `OAuthSubject` travel as
/// `"o_auth_subject"` rather than `"oauth_subject"`; confirm this is the intended wire name
/// before other producers or consumers hard-code it.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrincipalKind {
    /// Wire name `"api_key"`.
    ApiKey,
    /// Wire name `"o_auth_subject"` (see the note on the enum).
    OAuthSubject,
    /// Wire name `"internal_key"`.
    InternalKey,
    /// Wire name `"workload_identity"`.
    WorkloadIdentity,
    /// Wire name `"subscription_bearer"`.
    SubscriptionBearer,
}
42
/// Upstream target a request can be routed to.
///
/// Internally tagged: the JSON form carries a `"kind"` field holding the snake_case variant
/// name, e.g. `{"kind": "anthropic_direct"}` when `base_url` is `None`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum Upstream {
    /// Anthropic upstream addressed directly (reading taken from the name; confirm).
    AnthropicDirect {
        /// Optional base-URL override. Omitted from output when `None` and read as `None`
        /// when absent. Presumably `None` means "use the dialect's default endpoint" — TODO
        /// confirm in the dialect implementation.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        base_url: Option<Url>,
    },
}
56
57#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum UpstreamKind {
61 AnthropicApiKey,
63 AnthropicOauth,
65}
66
67impl UpstreamKind {
68 pub fn as_str(self) -> &'static str {
70 match self {
71 Self::AnthropicApiKey => "anthropic_api_key",
72 Self::AnthropicOauth => "anthropic_oauth",
73 }
74 }
75}
76
77#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
79#[serde(rename_all = "snake_case")]
80pub enum RateLimitKind {
81 Requests,
83 Tokens,
85 InputTokens,
87 OutputTokens,
89}
90
91impl RateLimitKind {
92 pub fn as_str(self) -> &'static str {
94 match self {
95 Self::Requests => "requests",
96 Self::Tokens => "tokens",
97 Self::InputTokens => "input_tokens",
98 Self::OutputTokens => "output_tokens",
99 }
100 }
101}
102
/// One rate-limit reading for an upstream, as carried in
/// [`UpstreamCandidate::observed_rate_limits`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RateLimitObservation {
    /// Quantity the limit applies to.
    pub kind: RateLimitKind,
    /// Label of the window the limit applies to; free-form, not validated here.
    pub window: String,
    /// Limit for the window, if known.
    pub limit: Option<u64>,
    /// Remaining allowance in the window, if known.
    pub remaining: Option<u64>,
    /// Reset point as an unparsed string. NOTE(review): whether this is a timestamp or a
    /// duration is not fixed by this type — confirm with the producer.
    pub reset: Option<String>,
}
117
/// Freshness of a subscription-quota reading; wire names `"fresh"`, `"stale"`, `"missing"`.
///
/// The rule that picks a state is not in this file (presumably it compares data age against
/// [`SubscriptionQuotaCandidateSnapshot::max_staleness_secs`] — TODO confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionQuotaDataState {
    Fresh,
    Stale,
    Missing,
}

/// Per-window subscription quota view attached to an [`UpstreamCandidate`].
///
/// Only `window`, `state` and `max_staleness_secs` are mandatory. serde reads any absent
/// `Option` field as `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SubscriptionQuotaCandidateSnapshot {
    /// Label of the quota window this snapshot describes.
    pub window: String,
    /// Freshness of the data in this snapshot.
    pub state: SubscriptionQuotaDataState,
    /// Where the reading came from, if recorded.
    pub source: Option<String>,
    /// Utilization reading. NOTE(review): the scale (fraction vs. percent) is not fixed here.
    pub utilization: Option<f64>,
    /// Status string reported for the window, if any.
    pub status: Option<String>,
    /// Reset time of the window, in Unix seconds.
    pub resets_at_unix_secs: Option<u64>,
    /// Whether a usage threshold was surpassed, when reported.
    pub surpassed_threshold: Option<bool>,
    /// Representative claim for the window, when reported (meaning not defined here).
    pub representative_claim: Option<String>,
    /// Reason the quota is disabled, when it is.
    pub disabled_reason: Option<String>,
    /// Whether extra usage is enabled, when reported.
    pub extra_usage_enabled: Option<bool>,
    /// Monthly extra-usage limit, when reported (unit not specified here).
    pub extra_usage_monthly_limit: Option<f64>,
    /// Extra-usage credits already used, when reported.
    pub extra_usage_used_credits: Option<f64>,
    /// Observation time in Unix *milliseconds*. Note the unit differs from the `_secs` fields.
    pub observed_at_unix_millis: Option<u64>,
    /// Maximum tolerated age of this data, in seconds.
    pub max_staleness_secs: u64,
}
162
/// Requested lifetime class of a prompt-cache entry; defaults to [`TtlClass::Ephemeral5m`].
///
/// Wire names are `"ephemeral5m"` and `"ephemeral1h"`. Digits do not trigger an `_` under
/// serde's `snake_case` rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
// NOTE(review): the `allow(dead_code)` on this and the other cache types is unexplained;
// presumably some of them are not consumed yet. Remove it once they are.
#[allow(dead_code)]
pub enum TtlClass {
    /// Presumably a 5-minute TTL (from the name); the default.
    #[default]
    Ephemeral5m,
    /// Presumably a 1-hour TTL (from the name).
    Ephemeral1h,
}

/// How a [`CacheBreakpoint`] came to exist: written explicitly or inferred automatically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum BreakpointOrigin {
    /// Wire name `"explicit"`.
    Explicit,
    /// Wire name `"auto_cache_inferred"`.
    AutoCacheInferred,
}

/// Part of the request a [`CacheBreakpoint`] sits in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum CacheBreakpointSource {
    /// Wire name `"tools"`.
    Tools,
    /// Wire name `"system"`.
    System,
    /// Wire name `"message"`.
    Message,
}
198
/// A prompt-cache breakpoint located in a request body (see
/// [`RequestContext::cache_breakpoints`]).
///
/// NOTE(review): field meanings are inferred from names and from the test fixtures in this
/// file (e.g. `path: "messages.0.content.1"`); confirm against the code that builds them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct CacheBreakpoint {
    /// Index of the block that carries the breakpoint.
    pub block_index: u32,
    /// Part of the request the breakpoint sits in.
    pub source: CacheBreakpointSource,
    /// Dotted path to the block, e.g. `"messages.0.content.1"`.
    pub path: String,
    /// Message index for message breakpoints; presumably `None` for other sources.
    pub message_index: Option<u32>,
    /// Hash identifying the prefix that ends at this breakpoint. The algorithm is not
    /// specified here.
    pub prefix_hash: String,
    /// Token count of that prefix.
    pub prefix_token_count: u64,
    /// TTL class requested for the cache entry.
    pub requested_ttl: TtlClass,
    /// Whether the breakpoint was explicit or inferred.
    pub origin: BreakpointOrigin,
}

/// A cached prefix and the time it is expected to expire.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct WarmCacheEntry {
    /// Prefix hash, comparable with [`CacheBreakpoint::prefix_hash`] (presumably).
    pub prefix_hash: String,
    /// Expected expiry, in Unix seconds.
    pub expires_at_unix_secs: u64,
    /// TTL class the entry was written with.
    pub ttl_class: TtlClass,
    /// Last time the entry was observed, in Unix seconds.
    pub last_observed_at_unix_secs: u64,
}

/// Predicted prompt-cache effect for a candidate, as carried in
/// [`UpstreamCandidate::cache_score`].
///
/// Only `PartialEq` is derived (no `Eq`) because `confidence` is an `f32`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct CacheScore {
    /// Predicted tokens served from cache. NOTE(review): `u32`, while
    /// `CacheBreakpoint::prefix_token_count` is `u64`.
    pub predicted_cache_read_tokens: u32,
    /// Predicted tokens written to the 5-minute cache class.
    pub predicted_cache_creation_tokens_5m: u32,
    /// Predicted tokens written to the 1-hour cache class.
    pub predicted_cache_creation_tokens_1h: u32,
    /// Predicted tokens that bypass the cache.
    pub predicted_uncached_input_tokens: u32,
    /// Predicted expiry of the matched entry, in Unix seconds.
    pub predicted_expires_at_unix_secs: Option<u64>,
    /// Index of the breakpoint the prediction matched, if any.
    pub matched_breakpoint_index: Option<u32>,
    /// Confidence of the prediction. The range is not enforced here (tests use 0.95).
    pub confidence: f32,
    /// Why the prediction is ambiguous, when it is.
    pub ambiguity_reason: Option<String>,
}
256
/// Routing candidate: an upstream plus the latest observations about it.
///
/// Input compatibility: `subscription_quotas` defaults to empty, and `cache_score` /
/// `base_url` default to `None` when absent. The latter two are also omitted from output when
/// `None`. `observed_rate_limits` has no serde default and must be present.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UpstreamCandidate {
    /// Identifier of the upstream.
    pub upstream_id: Uuid,
    /// Human-readable upstream name.
    pub name: String,
    /// Credential flavour of the upstream.
    pub kind: UpstreamKind,
    /// Rate-limit readings for this upstream.
    pub observed_rate_limits: Vec<RateLimitObservation>,
    /// Subscription quota snapshots; empty when absent on input.
    #[serde(default)]
    pub subscription_quotas: Vec<SubscriptionQuotaCandidateSnapshot>,
    /// Time the observations were taken, in Unix seconds.
    pub observed_at_unix_secs: u64,
    /// Predicted cache effect, when a prediction was made.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_score: Option<CacheScore>,
    /// Upstream base URL, if known. NOTE(review): a plain `String`, not validated here,
    /// unlike the `Url` carried by `Upstream::AnthropicDirect`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
}
286
/// Credential strategy (the name suggests how upstream credentials are obtained — confirm).
///
/// Wire names follow serde `snake_case`: `"api_key"`, `"o_auth"` (serde splits `OAuth` before
/// the `A`), `"internal_forwarded"`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CredentialStrategy {
    ApiKey,
    OAuth,
    InternalForwarded,
}

/// Downstream request data handed to [`shape_request`] and on to the dialect.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Identifier of this request.
    pub request_id: String,
    /// Headers as received from the downstream client.
    pub downstream_headers: HeaderMap,
    /// HTTP method of the downstream request.
    pub method: Method,
    /// Request path, e.g. `"/v1/messages"`.
    pub path: String,
    /// Query string, if any. Presumably without the leading `?` (tests use
    /// `"param=value"`).
    pub query: Option<String>,
    /// Raw request body.
    pub body_bytes: Bytes,
    /// Cache breakpoints found in the body; may be empty.
    pub cache_breakpoints: Vec<CacheBreakpoint>,
    /// Canonical model id for the request. The tests use `""` when it is unknown.
    pub canonical_model_id: String,
}
319
320#[derive(Clone, Debug)]
322pub struct ShapedRequest {
323 url: Url,
324 method: Method,
325 headers: HeaderMap,
326 body: Bytes,
327 _seal: crate::private::Seal,
328}
329
330impl ShapedRequest {
331 pub fn url(&self) -> &Url {
333 &self.url
334 }
335
336 pub fn set_url(&mut self, url: Url) {
338 self.url = url;
339 }
340
341 pub fn method(&self) -> &Method {
343 &self.method
344 }
345
346 pub fn set_method(&mut self, method: Method) {
348 self.method = method;
349 }
350
351 pub fn headers(&self) -> &HeaderMap {
353 &self.headers
354 }
355
356 pub fn headers_mut(&mut self) -> &mut HeaderMap {
358 &mut self.headers
359 }
360
361 pub fn body(&self) -> &Bytes {
363 &self.body
364 }
365
366 pub fn set_body(&mut self, body: Bytes) {
368 self.body = body;
369 }
370}
371
372#[derive(Debug)]
379pub struct ShapedRequestBuilder {
380 _seal: crate::private::Seal,
381}
382
383impl ShapedRequestBuilder {
384 pub fn shaped_request(
386 &mut self,
387 url: Url,
388 method: Method,
389 headers: HeaderMap,
390 body: Bytes,
391 ) -> ShapedRequest {
392 ShapedRequest {
393 url,
394 method,
395 headers,
396 body,
397 _seal: crate::private::Seal,
398 }
399 }
400}
401
402pub fn shape_request(
404 dialect: &dyn UpstreamDialect,
405 ctx: &RequestContext,
406 upstream: &Upstream,
407 principal: &Principal,
408) -> Result<ShapedRequest, DialectError> {
409 let mut builder = ShapedRequestBuilder {
410 _seal: crate::private::Seal,
411 };
412 dialect.shape(ctx, upstream, principal, &mut builder)
413}
414
415#[derive(Clone, Debug)]
417pub struct SignedRequest {
418 url: Url,
419 method: Method,
420 headers: HeaderMap,
421 body: Bytes,
422 _seal: crate::private::Seal,
423}
424
425impl SignedRequest {
426 pub fn from_shaped(shaped: ShapedRequest, _capability: &mut SigningCapability) -> Self {
428 Self {
429 url: shaped.url,
430 method: shaped.method,
431 headers: shaped.headers,
432 body: shaped.body,
433 _seal: crate::private::Seal,
434 }
435 }
436
437 pub fn url(&self) -> &Url {
439 &self.url
440 }
441
442 pub fn method(&self) -> &Method {
444 &self.method
445 }
446
447 pub fn headers(&self) -> &HeaderMap {
449 &self.headers
450 }
451
452 pub fn body(&self) -> &Bytes {
454 &self.body
455 }
456
457 pub fn into_parts(self) -> (Url, Method, HeaderMap, Bytes) {
459 (self.url, self.method, self.headers, self.body)
460 }
461}
462
463#[derive(Debug)]
470pub struct SigningCapability {
471 _seal: crate::private::Seal,
472}
473
474pub async fn sign_request(
476 signer: &dyn Signer,
477 shaped: ShapedRequest,
478) -> Result<SignedRequest, SignerError> {
479 let mut capability = SigningCapability {
480 _seal: crate::private::Seal,
481 };
482 signer.sign(shaped, &mut capability).await
483}
484
/// Routing outcome: the chosen upstream and the dialect associated with it.
///
/// No derives. It holds an `Arc<dyn UpstreamDialect>`, whose `Debug`/`PartialEq` support is
/// not visible here.
pub struct RouteDecision {
    /// Identifier of the chosen upstream, when it has one.
    pub upstream_id: Option<Uuid>,
    /// The chosen upstream.
    pub upstream: Upstream,
    /// Dialect for the chosen upstream (presumably what gets passed to [`shape_request`]).
    pub dialect: Arc<dyn UpstreamDialect>,
}

/// Usage allowances for a principal over one fixed-length window.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PrincipalQuotas {
    /// Maximum requests per `window`.
    pub requests_per_window: u64,
    /// Maximum input tokens per `window`.
    pub input_tokens_per_window: u64,
    /// Maximum output tokens per `window`.
    pub output_tokens_per_window: u64,
    /// Length of the quota window.
    pub window: Duration,
    /// Models the principal may use. NOTE(review): whether an empty list means "none" or
    /// "any" is not decided here — confirm with the enforcing code.
    pub allowed_models: Vec<String>,
}
509
/// Lifecycle events describing the processing of a single request. Presumably these are
/// consumed by observability sinks/plugins — TODO confirm where they are emitted.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ObserveEvent {
    /// The request was received.
    RequestStarted {
        /// Identifier of the request.
        request_id: String,
        /// Downstream `User-Agent`, when one was present.
        downstream_user_agent: Option<String>,
    },
    /// Authentication identified the caller (same `id`/`kind` as [`Principal`]).
    AuthnComplete {
        principal_id: String,
        kind: PrincipalKind,
    },
    /// An upstream was selected.
    UpstreamChosen {
        upstream: Upstream,
    },
    /// A batch of relayed response data.
    Chunk {
        /// Sequence number of the batch.
        batch_index: u64,
        /// Number of events in the batch.
        event_count: usize,
        /// Size of the batch in bytes.
        total_bytes: usize,
    },
    /// The request finished. Each token count is `None` when it is unknown.
    RequestFinished {
        /// Final HTTP status.
        status: StatusCode,
        input_tokens: Option<u64>,
        output_tokens: Option<u64>,
        cache_creation_input_tokens: Option<u64>,
        cache_read_input_tokens: Option<u64>,
        /// Total duration in milliseconds.
        duration_ms: u64,
    },
    /// An error occurred.
    Error {
        /// Machine-readable error code.
        code: String,
        /// Human-readable message. NOTE(review): presumably already redacted (the tests use
        /// `"redacted"`); confirm before logging.
        message: String,
        /// Component that raised the error (e.g. `"plugin"` in the tests).
        source: String,
    },
}
566
/// Decision on what to do with a request: retry with a replacement signer, or give up.
#[derive(Clone)]
pub enum RetryDecision {
    /// Retry using `new_signer` (presumably refreshed credentials, going by the name).
    Refresh {
        new_signer: Arc<dyn Signer>,
    },
    /// Do not retry.
    Fail,
}

/// Declarative description of a plugin to load.
///
/// On input, `wire_version` defaults to `None` and `metadata` to an empty map. `name`,
/// `artifact` and `config` have no serde default. On output, `wire_version` is omitted when
/// `None`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PluginManifest {
    /// Plugin name.
    pub name: String,
    /// Plugin artifact reference (e.g. `"plugin.wasm"` in the tests).
    pub artifact: String,
    /// Plugin wire-protocol version; `None` when unspecified.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub wire_version: Option<u8>,
    /// Plugin-specific configuration, passed through as JSON.
    pub config: serde_json::Value,
    /// Free-form metadata, kept in sorted key order.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
595
/// Upper bound on the number of stages in a [`RoutingTrace`]. Nothing in this file enforces
/// it; presumably it is checked where traces are built or decoded.
pub const MAX_ROUTING_TRACE_STAGES: usize = 100;
/// Upper bound on [`StageDecision::stage_name`] length. The unit (bytes vs. chars) is not
/// specified here.
pub const MAX_STAGE_NAME_LEN: usize = 256;
/// Upper bound on error-message length. Presumably it applies to [`InternalError::message`];
/// TODO confirm.
pub const MAX_ERROR_MESSAGE_LEN: usize = 1024;
603
/// Why a request was passed through (variant meanings inferred from names; confirm).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PassthroughCause {
    HealthyUpstream,
    NoAlternative,
    PluginDecision,
}

/// Why a particular candidate was not chosen (meanings inferred from names; confirm).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PerCandidateReason {
    RateLimited,
    InsufficientQuota,
    Unhealthy,
    RejectedByPlugin,
}

/// Strategy used for the final pick among remaining candidates. Defaults to `FirstPick`.
///
/// Unlike the other enums in this file, this one uses *kebab-case*: `"first-pick"`,
/// `"random"`, `"round-robin"`, `"least-connections"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TerminalStrategy {
    #[default]
    FirstPick,
    Random,
    RoundRobin,
    LeastConnections,
}
644
/// Outcome of one routing stage, recorded in [`RoutingTrace::stages`].
///
/// Output omits `upstream_id`/`reason` when `None` and `duration_us` when 0. Input defaults
/// them the same way, so a minimal record round-trips.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StageDecision {
    /// Name of the stage.
    pub stage_name: String,
    /// Upstream the stage settled on, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_id: Option<Uuid>,
    /// Free-form reason, if the stage recorded one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Time spent in the stage, in microseconds; 0 is treated as "not recorded" on output.
    #[serde(default, skip_serializing_if = "is_zero_u64")]
    pub duration_us: u64,
}
660
/// `skip_serializing_if` predicate for `StageDecision::duration_us`: true exactly when the
/// value is zero. Takes `&u64` because serde hands the field over by reference.
fn is_zero_u64(value: &u64) -> bool {
    matches!(*value, 0)
}
664
/// Final pick of a routing run.
///
/// `upstream_id` is omitted from output when `None`. `strategy` has no serde default, so it
/// must be present on input even though `Default` yields `FirstPick`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TerminalDecision {
    /// Upstream that was picked, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_id: Option<Uuid>,
    /// Strategy used for the pick.
    pub strategy: TerminalStrategy,
}

/// Record of how a request was routed: per-stage decisions plus the terminal pick.
///
/// The default value serializes as `{"stages": []}`. `terminal_decision` is omitted when
/// `None`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct RoutingTrace {
    /// Stage decisions in recorded order; empty when absent on input. Not capped by this
    /// type (see `MAX_ROUTING_TRACE_STAGES`).
    #[serde(default)]
    pub stages: Vec<StageDecision>,
    /// Terminal pick, when routing reached one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub terminal_decision: Option<TerminalDecision>,
}
685
/// Pipeline stage an [`InternalError`] is attributed to; defaults to `Router`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorStage {
    Authn,
    #[default]
    Router,
    RouterFilter,
    Shape,
    Signer,
    Relay,
}

/// Kind of an [`InternalError`]; defaults to `PluginError`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InternalErrorKind {
    #[default]
    PluginError,
    InvalidOutput,
    Trap,
    ConfigError,
    Timeout,
    Unavailable,
}

/// Serializable internal error report.
///
/// The default is `{stage: Router, kind: PluginError, message: None}`. `message` is omitted
/// from output when `None`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct InternalError {
    /// Stage the error is attributed to.
    pub stage: InternalErrorStage,
    /// Kind of error.
    pub kind: InternalErrorKind,
    /// Optional detail. Length is not capped by this type (see `MAX_ERROR_MESSAGE_LEN`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Setters on a shaped request replace exactly the part they target.
    #[test]
    fn shaped_request_accessors_and_mutators_preserve_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("x-test", "one".parse().unwrap());
        let mut shaped = builder.shaped_request(
            "https://example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"first"),
        );

        assert_eq!(shaped.url().as_str(), "https://example.test/v1/messages");
        assert_eq!(shaped.method(), Method::POST);
        assert_eq!(shaped.headers()["x-test"], "one");
        assert_eq!(shaped.body(), &Bytes::from_static(b"first"));

        shaped.set_url("https://example.test/v1/complete".parse().unwrap());
        shaped.set_method(Method::PUT);
        shaped
            .headers_mut()
            .insert("x-test", "two".parse().unwrap());
        shaped.set_body(Bytes::from_static(b"second"));

        assert_eq!(shaped.url().path(), "/v1/complete");
        assert_eq!(shaped.method(), Method::PUT);
        assert_eq!(shaped.headers()["x-test"], "two");
        assert_eq!(shaped.body(), &Bytes::from_static(b"second"));
    }

    /// Signing carries every part over unchanged and `into_parts` hands them back.
    #[test]
    fn signed_request_exposes_and_consumes_signed_parts() {
        let mut builder = ShapedRequestBuilder {
            _seal: crate::private::Seal,
        };
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer token".parse().unwrap());
        let shaped = builder.shaped_request(
            "https://api.example.test/v1/messages".parse().unwrap(),
            Method::POST,
            headers,
            Bytes::from_static(b"{}"),
        );
        let mut capability = SigningCapability {
            _seal: crate::private::Seal,
        };

        let signed = SignedRequest::from_shaped(shaped, &mut capability);
        assert_eq!(signed.url().host_str(), Some("api.example.test"));
        assert_eq!(signed.method(), Method::POST);
        assert_eq!(signed.headers()["authorization"], "Bearer token");
        assert_eq!(signed.body(), &Bytes::from_static(b"{}"));

        let (url, method, headers, body) = signed.into_parts();
        assert_eq!(url.as_str(), "https://api.example.test/v1/messages");
        assert_eq!(method, Method::POST);
        assert_eq!(headers["authorization"], "Bearer token");
        assert_eq!(body, Bytes::from_static(b"{}"));
    }

    #[test]
    fn upstream_and_manifest_serde_round_trip() {
        let upstreams = vec![Upstream::AnthropicDirect { base_url: None }];

        for upstream in upstreams {
            let json = serde_json::to_string(&upstream).unwrap();
            let decoded: Upstream = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, upstream);
        }

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "authn",
            "artifact": "plugin.wasm",
            "config": {"enabled": true}
        }))
        .unwrap();
        assert_eq!(manifest.name, "authn");
        assert_eq!(manifest.wire_version, None);
        assert!(manifest.metadata.is_empty());

        let manifest: PluginManifest = serde_json::from_value(serde_json::json!({
            "name": "cache-aware",
            "artifact": "plugin.wasm",
            "wire_version": 2,
            "config": {}
        }))
        .unwrap();
        assert_eq!(manifest.wire_version, Some(2));
    }

    #[test]
    fn public_enums_cover_all_current_variants() {
        // Exhaustive matches (no `_` arm): adding or removing a variant breaks compilation
        // here, forcing this test and the pinned wire names to be revisited. The array
        // length checks below could not catch that on their own.
        fn principal_kind_wire_name(kind: &PrincipalKind) -> &'static str {
            match kind {
                PrincipalKind::ApiKey => "api_key",
                // serde's snake_case splits before every capital: "OAuth" -> "o_auth".
                PrincipalKind::OAuthSubject => "o_auth_subject",
                PrincipalKind::InternalKey => "internal_key",
                PrincipalKind::WorkloadIdentity => "workload_identity",
                PrincipalKind::SubscriptionBearer => "subscription_bearer",
            }
        }
        fn credential_strategy_wire_name(strategy: &CredentialStrategy) -> &'static str {
            match strategy {
                CredentialStrategy::ApiKey => "api_key",
                CredentialStrategy::OAuth => "o_auth",
                CredentialStrategy::InternalForwarded => "internal_forwarded",
            }
        }

        let principal_kinds = [
            PrincipalKind::ApiKey,
            PrincipalKind::OAuthSubject,
            PrincipalKind::InternalKey,
            PrincipalKind::WorkloadIdentity,
            PrincipalKind::SubscriptionBearer,
        ];
        assert_eq!(principal_kinds.len(), 5);
        for kind in &principal_kinds {
            let json = serde_json::to_string(kind).unwrap();
            assert_eq!(json, format!("\"{}\"", principal_kind_wire_name(kind)));
            let decoded: PrincipalKind = serde_json::from_str(&json).unwrap();
            assert_eq!(&decoded, kind);
        }

        let strategies = [
            CredentialStrategy::ApiKey,
            CredentialStrategy::OAuth,
            CredentialStrategy::InternalForwarded,
        ];
        assert_eq!(strategies.len(), 3);
        for strategy in &strategies {
            let json = serde_json::to_string(strategy).unwrap();
            assert_eq!(json, format!("\"{}\"", credential_strategy_wire_name(strategy)));
            let decoded: CredentialStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(&decoded, strategy);
        }
    }

    /// Pins wire names that are easy to change by accident. The `as_str` helpers must agree
    /// with serde, `TerminalStrategy` is kebab-case, digits in `TtlClass` get no `_`, and
    /// `Upstream` is internally tagged.
    #[test]
    fn serde_wire_names_are_pinned() {
        for kind in [UpstreamKind::AnthropicApiKey, UpstreamKind::AnthropicOauth] {
            assert_eq!(
                serde_json::to_string(&kind).unwrap(),
                format!("\"{}\"", kind.as_str())
            );
        }
        for kind in [
            RateLimitKind::Requests,
            RateLimitKind::Tokens,
            RateLimitKind::InputTokens,
            RateLimitKind::OutputTokens,
        ] {
            assert_eq!(
                serde_json::to_string(&kind).unwrap(),
                format!("\"{}\"", kind.as_str())
            );
        }

        let terminal = |strategy: TerminalStrategy| serde_json::to_string(&strategy).unwrap();
        assert_eq!(terminal(TerminalStrategy::FirstPick), "\"first-pick\"");
        assert_eq!(terminal(TerminalStrategy::Random), "\"random\"");
        assert_eq!(terminal(TerminalStrategy::RoundRobin), "\"round-robin\"");
        assert_eq!(
            terminal(TerminalStrategy::LeastConnections),
            "\"least-connections\""
        );

        assert_eq!(
            serde_json::to_string(&TtlClass::Ephemeral5m).unwrap(),
            "\"ephemeral5m\""
        );
        assert_eq!(
            serde_json::to_string(&TtlClass::Ephemeral1h).unwrap(),
            "\"ephemeral1h\""
        );

        assert_eq!(
            serde_json::to_value(Upstream::AnthropicDirect { base_url: None }).unwrap(),
            serde_json::json!({"kind": "anthropic_direct"})
        );
    }

    /// `StageDecision` drops `None` options and a zero duration on output (via
    /// `is_zero_u64`) and restores them on input.
    #[test]
    fn stage_decision_omits_empty_fields() {
        let bare = StageDecision {
            stage_name: "filter".to_owned(),
            upstream_id: None,
            reason: None,
            duration_us: 0,
        };
        let value = serde_json::to_value(&bare).unwrap();
        assert_eq!(value, serde_json::json!({"stage_name": "filter"}));
        let decoded: StageDecision = serde_json::from_value(value).unwrap();
        assert_eq!(decoded, bare);

        let timed = StageDecision {
            duration_us: 42,
            ..bare
        };
        let value = serde_json::to_value(&timed).unwrap();
        assert_eq!(value["duration_us"].as_u64(), Some(42));
        let decoded: StageDecision = serde_json::from_value(value).unwrap();
        assert_eq!(decoded, timed);
    }

    /// The `#[default]` variants declared on the enums flow into the derived defaults.
    #[test]
    fn defaults_match_declared_variants() {
        assert_eq!(TtlClass::default(), TtlClass::Ephemeral5m);

        let terminal = TerminalDecision::default();
        assert_eq!(terminal.strategy, TerminalStrategy::FirstPick);
        assert_eq!(terminal.upstream_id, None);

        let error = InternalError::default();
        assert_eq!(error.stage, InternalErrorStage::Router);
        assert_eq!(error.kind, InternalErrorKind::PluginError);
        assert_eq!(error.message, None);

        assert_eq!(
            serde_json::to_value(RoutingTrace::default()).unwrap(),
            serde_json::json!({"stages": []})
        );
    }

    #[test]
    fn observe_event_variants_are_equatable() {
        let events = vec![
            ObserveEvent::RequestStarted {
                request_id: "req".to_owned(),
                downstream_user_agent: Some("ua".to_owned()),
            },
            ObserveEvent::AuthnComplete {
                principal_id: "principal".to_owned(),
                kind: PrincipalKind::InternalKey,
            },
            ObserveEvent::UpstreamChosen {
                upstream: Upstream::AnthropicDirect { base_url: None },
            },
            ObserveEvent::Chunk {
                batch_index: 1,
                event_count: 2,
                total_bytes: 3,
            },
            ObserveEvent::RequestFinished {
                status: StatusCode::OK,
                input_tokens: Some(4),
                output_tokens: Some(5),
                cache_creation_input_tokens: Some(6),
                cache_read_input_tokens: Some(7),
                duration_ms: 8,
            },
            ObserveEvent::Error {
                code: "E".to_owned(),
                message: "redacted".to_owned(),
                source: "plugin".to_owned(),
            },
        ];

        assert_eq!(events, events.clone());
    }

    #[test]
    fn cache_types_roundtrip() {
        let ttl_class = TtlClass::Ephemeral1h;
        let json = serde_json::to_string(&ttl_class).unwrap();
        let decoded: TtlClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, ttl_class);

        let origin = BreakpointOrigin::AutoCacheInferred;
        let json = serde_json::to_string(&origin).unwrap();
        let decoded: BreakpointOrigin = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, origin);

        let source = CacheBreakpointSource::System;
        let json = serde_json::to_string(&source).unwrap();
        let decoded: CacheBreakpointSource = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, source);

        let breakpoint = CacheBreakpoint {
            block_index: 0,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.1".to_owned(),
            message_index: Some(0),
            prefix_hash: "abc123".to_owned(),
            prefix_token_count: 100,
            requested_ttl: TtlClass::Ephemeral5m,
            origin: BreakpointOrigin::Explicit,
        };
        let json = serde_json::to_string(&breakpoint).unwrap();
        let decoded: CacheBreakpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, breakpoint);

        let warm_entry = WarmCacheEntry {
            prefix_hash: "def456".to_owned(),
            expires_at_unix_secs: 1700000000,
            ttl_class: TtlClass::Ephemeral1h,
            last_observed_at_unix_secs: 1699999000,
        };
        let json = serde_json::to_string(&warm_entry).unwrap();
        let decoded: WarmCacheEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, warm_entry);

        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let json = serde_json::to_string(&cache_score).unwrap();
        let decoded: CacheScore = serde_json::from_str(&json).unwrap();
        assert_eq!(
            decoded.predicted_cache_read_tokens,
            cache_score.predicted_cache_read_tokens
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_5m,
            cache_score.predicted_cache_creation_tokens_5m
        );
        assert_eq!(
            decoded.predicted_cache_creation_tokens_1h,
            cache_score.predicted_cache_creation_tokens_1h
        );
        assert_eq!(
            decoded.predicted_uncached_input_tokens,
            cache_score.predicted_uncached_input_tokens
        );
        assert_eq!(
            decoded.predicted_expires_at_unix_secs,
            cache_score.predicted_expires_at_unix_secs
        );
        assert_eq!(
            decoded.matched_breakpoint_index,
            cache_score.matched_breakpoint_index
        );
        // `confidence` is an f32, so compare with a tolerance rather than exact equality.
        assert!((decoded.confidence - cache_score.confidence).abs() < 0.0001);
        assert_eq!(decoded.ambiguity_reason, cache_score.ambiguity_reason);
    }

    #[test]
    fn upstream_candidate_cache_score_roundtrip() {
        let candidate_no_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: None,
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_no_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_no_cache.upstream_id);
        assert_eq!(decoded.name, candidate_no_cache.name);
        assert_eq!(decoded.cache_score, None);

        let cache_score = CacheScore {
            predicted_cache_read_tokens: 50,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 200,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.95,
            ambiguity_reason: None,
        };
        let candidate_with_cache = UpstreamCandidate {
            upstream_id: Uuid::new_v4(),
            name: "test-upstream-cached".to_owned(),
            kind: UpstreamKind::AnthropicApiKey,
            observed_rate_limits: vec![],
            subscription_quotas: vec![],
            observed_at_unix_secs: 1700000000,
            cache_score: Some(cache_score),
            base_url: None,
        };
        let json = serde_json::to_string(&candidate_with_cache).unwrap();
        let decoded: UpstreamCandidate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.upstream_id, candidate_with_cache.upstream_id);
        assert_eq!(decoded.name, candidate_with_cache.name);
        assert!(decoded.cache_score.is_some());
        assert_eq!(decoded.cache_score.unwrap().predicted_cache_read_tokens, 50);
    }

    /// `RequestContext` has no serde support. This only checks that the cache fields hold
    /// what they were built with, both empty and populated (no actual round trip).
    #[test]
    fn request_context_cache_fields_roundtrip() {
        let ctx_empty = RequestContext {
            request_id: "req-1".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert_eq!(ctx_empty.cache_breakpoints.len(), 0);
        assert_eq!(ctx_empty.canonical_model_id, "");

        let breakpoint = CacheBreakpoint {
            block_index: 1,
            source: CacheBreakpointSource::Message,
            path: "messages.0.content.0".to_owned(),
            message_index: Some(0),
            prefix_hash: "hash123".to_owned(),
            prefix_token_count: 150,
            requested_ttl: TtlClass::Ephemeral1h,
            origin: BreakpointOrigin::Explicit,
        };
        let ctx_populated = RequestContext {
            request_id: "req-2".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::POST,
            path: "/v1/messages".to_owned(),
            query: Some("param=value".to_owned()),
            body_bytes: Bytes::from_static(b"test"),
            cache_breakpoints: vec![breakpoint],
            canonical_model_id: "claude-sonnet-4-5-20250929".to_owned(),
        };
        assert_eq!(ctx_populated.cache_breakpoints.len(), 1);
        assert_eq!(
            ctx_populated.canonical_model_id,
            "claude-sonnet-4-5-20250929"
        );
    }

    /// Payloads written before `cache_score` existed must still decode.
    #[test]
    fn upstream_candidate_deserialize_without_cache_score_field() {
        let json = r#"{
            "upstream_id": "00000000-0000-0000-0000-000000000001",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1700000000
        }"#;
        let candidate: UpstreamCandidate = serde_json::from_str(json).unwrap();
        assert!(candidate.cache_score.is_none());
        assert_eq!(candidate.name, "legacy-upstream");
    }

    /// Despite the name, nothing is defaulted here: `RequestContext` has no `Default` or
    /// serde impl. This only checks that explicitly empty cache fields are accepted.
    #[test]
    fn request_context_cache_breakpoints_default_on_missing_fields() {
        let ctx = RequestContext {
            request_id: "test".to_owned(),
            downstream_headers: HeaderMap::new(),
            method: Method::GET,
            path: "/test".to_owned(),
            query: None,
            body_bytes: Bytes::new(),
            cache_breakpoints: Vec::new(),
            canonical_model_id: String::new(),
        };
        assert!(ctx.cache_breakpoints.is_empty());
        assert!(ctx.canonical_model_id.is_empty());
    }
}
1087}