Skip to main content

cc_lb_plugin_wire/v2/
common.rs

//! v2 is a fork of v1 that exists so cache-related fields can be added additively. v1 is FROZEN.
2
3extern crate alloc;
4
5use alloc::{string::String, vec::Vec};
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
9#[serde(deny_unknown_fields)]
10pub struct HeaderWire {
11    pub name: String,
12    pub value_base64: String,
13}
14
15impl HeaderWire {
16    pub fn dry_run_sample() -> Self {
17        Self {
18            name: String::from("x-cc-lb-dry-run"),
19            value_base64: String::new(),
20        }
21    }
22}
23
24#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
25#[serde(deny_unknown_fields)]
26pub struct Principal {
27    pub id: String,
28    pub kind: String,
29    pub claims: serde_json::Map<String, serde_json::Value>,
30}
31
32impl Principal {
33    pub fn dry_run_sample() -> Self {
34        Self {
35            id: String::from("dry-run-principal"),
36            kind: String::from("api_key"),
37            claims: serde_json::Map::new(),
38        }
39    }
40}
41
42#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
43#[serde(deny_unknown_fields)]
44pub struct RequestWire {
45    pub request_id: String,
46    pub headers: Vec<HeaderWire>,
47    pub method: String,
48    pub path: String,
49    pub query: Option<String>,
50    pub body_base64: String,
51}
52
53impl RequestWire {
54    pub fn dry_run_sample() -> Self {
55        Self {
56            request_id: String::from("dry-run-request"),
57            headers: Vec::new(),
58            method: String::from("POST"),
59            path: String::from("/v1/messages"),
60            query: None,
61            body_base64: String::new(),
62        }
63    }
64}
65
66#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
67#[serde(deny_unknown_fields)]
68pub struct RateLimitObservationWire {
69    pub kind: String,
70    pub window: String,
71    pub limit: Option<u64>,
72    pub remaining: Option<u64>,
73    pub reset: Option<String>,
74}
75
76impl RateLimitObservationWire {
77    pub fn dry_run_sample() -> Self {
78        Self {
79            kind: String::from("requests"),
80            window: String::from("dry-run"),
81            limit: None,
82            remaining: None,
83            reset: None,
84        }
85    }
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
89#[serde(rename_all = "snake_case")]
90#[serde(deny_unknown_fields)]
91pub struct SubscriptionQuotaCandidateSnapshotWire {
92    pub window: String,
93    pub source: String,
94    pub data_state: String,
95    pub utilization: Option<f64>,
96    pub status: Option<String>,
97    pub resets_at_unix_secs: Option<u64>,
98    pub surpassed_threshold: Option<bool>,
99    pub representative_claim: Option<String>,
100    pub disabled_reason: Option<String>,
101    pub observed_at_unix_millis: Option<u64>,
102    pub age_secs: Option<u64>,
103}
104
105impl SubscriptionQuotaCandidateSnapshotWire {
106    pub fn dry_run_sample() -> Self {
107        Self {
108            window: String::from("5h"),
109            source: String::from("missing"),
110            data_state: String::from("missing"),
111            utilization: None,
112            status: None,
113            resets_at_unix_secs: None,
114            surpassed_threshold: None,
115            representative_claim: None,
116            disabled_reason: None,
117            observed_at_unix_millis: None,
118            age_secs: None,
119        }
120    }
121}
122
123#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
124#[serde(deny_unknown_fields)]
125pub struct CandidateWire {
126    pub upstream_id: String,
127    pub name: String,
128    pub kind: String,
129    pub observed_rate_limits: Vec<RateLimitObservationWire>,
130    pub subscription_quotas: Vec<SubscriptionQuotaCandidateSnapshotWire>,
131    pub observed_at_unix_secs: u64,
132    #[serde(default, skip_serializing_if = "Option::is_none")]
133    pub cache_score: Option<CacheScoreWire>,
134}
135
136impl CandidateWire {
137    pub fn dry_run_sample() -> Self {
138        Self {
139            upstream_id: String::from("00000000-0000-0000-0000-000000000000"),
140            name: String::from("dry-run-upstream"),
141            kind: String::from("anthropic_api_key"),
142            observed_rate_limits: Vec::new(),
143            subscription_quotas: Vec::new(),
144            observed_at_unix_secs: 0,
145            cache_score: None,
146        }
147    }
148}
149
150#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
151#[serde(rename_all = "snake_case", tag = "kind")]
152#[serde(deny_unknown_fields)]
153pub enum UpstreamWire {
154    AnthropicDirect,
155}
156
157impl UpstreamWire {
158    pub fn dry_run_sample() -> Self {
159        Self::AnthropicDirect
160    }
161}
162
163#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
164#[serde(rename_all = "snake_case", tag = "kind")]
165#[serde(deny_unknown_fields)]
166pub enum DialectBinding {
167    #[serde(rename = "self")]
168    SelfReferenced,
169}
170
171impl DialectBinding {
172    pub fn dry_run_sample() -> Self {
173        Self::SelfReferenced
174    }
175}
176
177#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
178#[serde(deny_unknown_fields)]
179pub struct ShapedRequestWire {
180    pub url: String,
181    pub method: String,
182    pub headers: Vec<HeaderWire>,
183    pub body_base64: String,
184}
185
186impl ShapedRequestWire {
187    pub fn dry_run_sample() -> Self {
188        Self {
189            url: String::from("https://api.anthropic.com/v1/messages"),
190            method: String::from("POST"),
191            headers: Vec::new(),
192            body_base64: String::new(),
193        }
194    }
195}
196
197#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
198#[serde(rename_all = "snake_case")]
199#[serde(deny_unknown_fields)]
200pub enum UpstreamErrorCategory {
201    Unauthorized,
202    Retryable,
203    Failed,
204}
205
206impl UpstreamErrorCategory {
207    pub fn dry_run_sample() -> Self {
208        Self::Unauthorized
209    }
210}
211
212#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
213#[serde(deny_unknown_fields)]
214pub struct UpstreamErrorWire {
215    pub status: u16,
216    pub body_base64: Option<String>,
217    pub category: UpstreamErrorCategory,
218}
219
220impl UpstreamErrorWire {
221    pub fn dry_run_sample() -> Self {
222        Self {
223            status: 401,
224            body_base64: None,
225            category: UpstreamErrorCategory::dry_run_sample(),
226        }
227    }
228}
229
230#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
231#[serde(rename_all = "snake_case", tag = "kind")]
232#[serde(deny_unknown_fields)]
233pub enum ObserveEventWire {
234    RequestStarted {
235        request_id: String,
236        downstream_user_agent: Option<String>,
237    },
238    AuthnComplete {
239        principal_id: String,
240        principal_kind: String,
241    },
242    UpstreamChosen {
243        upstream: UpstreamWire,
244    },
245    Chunk {
246        batch_index: u64,
247        event_count: usize,
248        total_bytes: usize,
249    },
250    RequestFinished {
251        status: u16,
252        input_tokens: Option<u64>,
253        output_tokens: Option<u64>,
254        cache_creation_input_tokens: Option<u64>,
255        cache_read_input_tokens: Option<u64>,
256        duration_ms: u64,
257    },
258    Error {
259        code: String,
260        message: String,
261        source: String,
262    },
263}
264
265impl ObserveEventWire {
266    pub fn dry_run_sample() -> Self {
267        Self::RequestStarted {
268            request_id: String::from("dry-run-request"),
269            downstream_user_agent: None,
270        }
271    }
272}
273
274// Cache-related wire types (v2 only)
275
276#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
277#[serde(deny_unknown_fields, rename_all = "snake_case")]
278pub enum TtlClassWire {
279    #[default]
280    Ephemeral5m,
281    Ephemeral1h,
282}
283
284#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
285#[serde(deny_unknown_fields, rename_all = "snake_case")]
286pub enum BreakpointOriginWire {
287    Explicit,
288    AutoCacheInferred,
289}
290
291#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
292#[serde(deny_unknown_fields, rename_all = "snake_case")]
293pub enum CacheBreakpointSourceWire {
294    Tools,
295    System,
296    Message,
297}
298
299#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
300#[serde(deny_unknown_fields)]
301pub struct CacheBreakpointWire {
302    pub block_index: u32,
303    pub source: CacheBreakpointSourceWire,
304    pub path: String,
305    #[serde(default, skip_serializing_if = "Option::is_none")]
306    pub message_index: Option<u32>,
307    pub prefix_hash: String,
308    pub prefix_token_count: u64,
309    pub requested_ttl: TtlClassWire,
310    pub origin: BreakpointOriginWire,
311}
312
313#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
314#[serde(deny_unknown_fields)]
315pub struct WarmCacheEntryWire {
316    pub prefix_hash: String,
317    pub expires_at_unix_secs: u64,
318    pub ttl_class: TtlClassWire,
319    pub last_observed_at_unix_secs: u64,
320}
321
322#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
323#[serde(deny_unknown_fields)]
324pub struct CacheScoreWire {
325    pub predicted_cache_read_tokens: u32,
326    pub predicted_cache_creation_tokens_5m: u32,
327    pub predicted_cache_creation_tokens_1h: u32,
328    pub predicted_uncached_input_tokens: u32,
329    #[serde(default, skip_serializing_if = "Option::is_none")]
330    pub predicted_expires_at_unix_secs: Option<u64>,
331    #[serde(default, skip_serializing_if = "Option::is_none")]
332    pub matched_breakpoint_index: Option<u32>,
333    pub confidence: f32,
334    #[serde(default, skip_serializing_if = "Option::is_none")]
335    pub ambiguity_reason: Option<String>,
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to JSON and parses the result back into `T`.
    fn roundtrip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let json = serde_json::to_string(value).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    /// Builds a candidate with the given cache score and otherwise fixed fields.
    fn candidate_with(cache_score: Option<CacheScoreWire>) -> CandidateWire {
        CandidateWire {
            upstream_id: String::from("00000000-0000-0000-0000-000000000000"),
            name: String::from("test-upstream"),
            kind: String::from("anthropic_api_key"),
            observed_rate_limits: Vec::new(),
            subscription_quotas: Vec::new(),
            observed_at_unix_secs: 1_717_171_717,
            cache_score,
        }
    }

    #[test]
    fn cache_score_wire_roundtrip() {
        let score = CacheScoreWire {
            predicted_cache_read_tokens: 1000,
            predicted_cache_creation_tokens_5m: 200,
            predicted_cache_creation_tokens_1h: 0,
            predicted_uncached_input_tokens: 50,
            predicted_expires_at_unix_secs: Some(1700000000),
            matched_breakpoint_index: Some(0),
            confidence: 0.85,
            ambiguity_reason: None,
        };
        assert_eq!(score, roundtrip(&score));
    }

    #[test]
    fn candidate_wire_with_cache_score_roundtrip() {
        let candidate = candidate_with(Some(CacheScoreWire {
            predicted_cache_read_tokens: 500,
            predicted_cache_creation_tokens_5m: 100,
            predicted_cache_creation_tokens_1h: 50,
            predicted_uncached_input_tokens: 25,
            predicted_expires_at_unix_secs: None,
            matched_breakpoint_index: Some(1),
            confidence: 0.95,
            ambiguity_reason: None,
        }));
        assert_eq!(candidate, roundtrip(&candidate));
    }

    #[test]
    fn candidate_wire_without_cache_score_deserializes() {
        let json = r#"{
            "upstream_id": "11111111-1111-1111-1111-111111111111",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1000000000
        }"#;
        let parsed: CandidateWire = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.cache_score, None);
    }

    // The additive-field contract has two halves: tolerate a missing
    // `cache_score` on input (covered above) and omit it on output when unset.
    #[test]
    fn candidate_wire_omits_absent_cache_score() {
        let json = serde_json::to_string(&candidate_with(None)).unwrap();
        assert!(!json.contains("cache_score"));
    }

    #[test]
    fn candidate_wire_rejects_unknown_fields() {
        let json = r#"{
            "upstream_id": "11111111-1111-1111-1111-111111111111",
            "name": "legacy-upstream",
            "kind": "anthropic_api_key",
            "observed_rate_limits": [],
            "subscription_quotas": [],
            "observed_at_unix_secs": 1000000000,
            "unexpected": true
        }"#;
        assert!(serde_json::from_str::<CandidateWire>(json).is_err());
    }

    #[test]
    fn cache_breakpoint_wire_roundtrip() {
        let bp = CacheBreakpointWire {
            block_index: 0,
            source: CacheBreakpointSourceWire::Tools,
            path: String::from("/v1/messages/create"),
            message_index: Some(5),
            prefix_hash: String::from("abc123def456"),
            prefix_token_count: 2048,
            requested_ttl: TtlClassWire::Ephemeral5m,
            origin: BreakpointOriginWire::Explicit,
        };
        assert_eq!(bp, roundtrip(&bp));
    }

    #[test]
    fn warm_cache_entry_wire_roundtrip() {
        let entry = WarmCacheEntryWire {
            prefix_hash: String::from("abc123def456"),
            expires_at_unix_secs: 1_700_000_300,
            ttl_class: TtlClassWire::Ephemeral1h,
            last_observed_at_unix_secs: 1_700_000_000,
        };
        assert_eq!(entry, roundtrip(&entry));
    }

    // Pins the exact strings produced by `rename_all = "snake_case"`: the rule
    // adds an underscore only before an uppercase letter, so a digit following
    // a word gets none.
    #[test]
    fn cache_enum_wire_names_are_stable() {
        assert_eq!(
            serde_json::to_string(&TtlClassWire::Ephemeral5m).unwrap(),
            r#""ephemeral5m""#
        );
        assert_eq!(
            serde_json::to_string(&TtlClassWire::Ephemeral1h).unwrap(),
            r#""ephemeral1h""#
        );
        assert_eq!(
            serde_json::to_string(&BreakpointOriginWire::AutoCacheInferred).unwrap(),
            r#""auto_cache_inferred""#
        );
        assert_eq!(
            serde_json::to_string(&CacheBreakpointSourceWire::Tools).unwrap(),
            r#""tools""#
        );
    }
}