Skip to main content

cc_lb_plugin_conformance/
fixtures.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use cc_lb_plugin_wire::augmented_metadata::AugmentedMetadata;
4use cc_lb_plugin_wire::identity::{CC_LB_PLUGIN_MAGIC, PluginIdentity};
5use cc_lb_plugin_wire::v2::common::{
6    CandidateWire, HeaderWire, ObserveEventWire, Principal, UpstreamWire,
7};
8use cc_lb_plugin_wire::v2::observe::ObserveRequest;
9use cc_lb_plugin_wire::v2::shape::ShapeRequest;
10use cc_lb_plugin_wire::v3::filter::FilterRequest;
11use cc_lb_plugin_wire::wire_function::WireFunction;
12
13pub fn synth_principal() -> Principal {
14    Principal::dry_run_sample()
15}
16
17pub fn synth_metadata() -> AugmentedMetadata {
18    let mut negotiated_functions = BTreeMap::new();
19    negotiated_functions.insert(String::from("shape"), 1);
20    negotiated_functions.insert(String::from("filter"), 1);
21    negotiated_functions.insert(String::from("observe"), 1);
22
23    AugmentedMetadata {
24        identity: PluginIdentity {
25            magic: CC_LB_PLUGIN_MAGIC,
26            abi_envelope: 1,
27            plugin_name: String::from("dry-run-plugin"),
28            plugin_version: String::from("0.0.0"),
29        },
30        negotiated_functions,
31        negotiated_capabilities: BTreeSet::new(),
32        handshake_completed_at: 0,
33        self_check_passed: true,
34        self_check_completed_at: 0,
35        expires_at: 0,
36    }
37}
38
39pub fn shape_request_builder() -> ShapeRequestBuilder {
40    ShapeRequestBuilder {
41        request: ShapeRequest::dry_run_sample(),
42    }
43}
44
45#[non_exhaustive]
46pub struct ShapeRequestBuilder {
47    request: ShapeRequest,
48}
49
50impl ShapeRequestBuilder {
51    pub fn method(mut self, method: impl Into<String>) -> Self {
52        self.request.request.method = method.into();
53        self
54    }
55
56    pub fn path(mut self, path: impl Into<String>) -> Self {
57        self.request.request.path = path.into();
58        self
59    }
60
61    pub fn query(mut self, query: impl Into<String>) -> Self {
62        self.request.request.query = Some(query.into());
63        self
64    }
65
66    pub fn body_base64(mut self, body_base64: impl Into<String>) -> Self {
67        self.request.request.body_base64 = body_base64.into();
68        self
69    }
70
71    pub fn headers(mut self, headers: Vec<HeaderWire>) -> Self {
72        self.request.request.headers = headers;
73        self
74    }
75
76    pub fn principal(mut self, principal: Principal) -> Self {
77        self.request.principal = principal;
78        self
79    }
80
81    pub fn upstream(mut self, upstream: UpstreamWire) -> Self {
82        self.request.upstream = upstream;
83        self
84    }
85
86    pub fn build(self) -> ShapeRequest {
87        self.request
88    }
89}
90
91pub fn filter_request_builder() -> FilterRequestBuilder {
92    FilterRequestBuilder {
93        request: FilterRequest::dry_run_sample(),
94    }
95}
96
97#[non_exhaustive]
98pub struct FilterRequestBuilder {
99    request: FilterRequest,
100}
101
102impl FilterRequestBuilder {
103    pub fn request_id(mut self, request_id: impl Into<String>) -> Self {
104        self.request.request_id = request_id.into();
105        self
106    }
107
108    pub fn headers(mut self, headers: Vec<HeaderWire>) -> Self {
109        self.request.headers = headers;
110        self
111    }
112
113    pub fn method(mut self, method: impl Into<String>) -> Self {
114        self.request.method = method.into();
115        self
116    }
117
118    pub fn path(mut self, path: impl Into<String>) -> Self {
119        self.request.path = path.into();
120        self
121    }
122
123    pub fn query(mut self, query: impl Into<String>) -> Self {
124        self.request.query = Some(query.into());
125        self
126    }
127
128    pub fn body_base64(mut self, body_base64: impl Into<String>) -> Self {
129        self.request.body_base64 = body_base64.into();
130        self
131    }
132
133    pub fn principal(mut self, principal: Principal) -> Self {
134        self.request.principal = principal;
135        self
136    }
137
138    pub fn candidates(mut self, candidates: Vec<CandidateWire>) -> Self {
139        self.request.candidates = candidates;
140        self
141    }
142
143    pub fn build(self) -> FilterRequest {
144        self.request
145    }
146}
147
148pub fn observe_request_builder() -> ObserveRequestBuilder {
149    ObserveRequestBuilder {
150        request: ObserveRequest::dry_run_sample(),
151    }
152}
153
154#[non_exhaustive]
155pub struct ObserveRequestBuilder {
156    request: ObserveRequest,
157}
158
159impl ObserveRequestBuilder {
160    pub fn add_event(mut self, event: ObserveEventWire) -> Self {
161        self.request.events.push(event);
162        self
163    }
164
165    pub fn build(self) -> ObserveRequest {
166        self.request
167    }
168}
169
170pub fn dry_run_sample<F: WireFunction>() -> F::Request {
171    F::dry_run_request()
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for the fixture builders.
    //!
    //! Each builder test asserts every `String`/`Vec` field it overrides, so
    //! that a setter writing to the wrong field (or not at all) is caught.

    use super::*;
    use cc_lb_plugin_wire::v2::common::ObserveEventWire;
    use cc_lb_plugin_wire::v2::shape::ShapeFn;

    #[test]
    fn shape_builder_overrides_request_fields() {
        let request = shape_request_builder()
            .method("GET")
            .path("/v1/models")
            .query("limit=1")
            .body_base64("e30=")
            .headers(vec![HeaderWire {
                name: String::from("content-type"),
                value_base64: String::from("YXBwbGljYXRpb24vanNvbg=="),
            }])
            .principal(synth_principal())
            .upstream(UpstreamWire::dry_run_sample())
            .build();

        assert_eq!(request.request.method, "GET");
        assert_eq!(request.request.path, "/v1/models");
        assert_eq!(request.request.query.as_deref(), Some("limit=1"));
        assert_eq!(request.request.body_base64, "e30=");

        // Check the header contents, not just the count, so a replaced-by-default
        // header list of the same length cannot slip through.
        assert_eq!(request.request.headers.len(), 1);
        assert_eq!(request.request.headers[0].name, "content-type");
        assert_eq!(
            request.request.headers[0].value_base64,
            "YXBwbGljYXRpb24vanNvbg=="
        );
    }

    #[test]
    fn filter_builder_overrides_filter_fields() {
        let request = filter_request_builder()
            .request_id("req-1")
            .headers(Vec::new())
            .method("POST")
            .path("/v1/messages")
            .query("beta=true")
            .body_base64("")
            .principal(synth_principal())
            .candidates(vec![CandidateWire::dry_run_sample()])
            .build();

        assert_eq!(request.request_id, "req-1");
        assert!(request.headers.is_empty());
        assert_eq!(request.method, "POST");
        assert_eq!(request.path, "/v1/messages");
        assert_eq!(request.query.as_deref(), Some("beta=true"));
        assert!(request.body_base64.is_empty());
        assert_eq!(request.candidates.len(), 1);
    }

    #[test]
    fn observe_builder_appends_events() {
        // Measure against the sample's own event count so the test checks
        // "append" semantics regardless of how many events the sample ships with.
        let baseline = observe_request_builder().build().events.len();

        let request = observe_request_builder()
            .add_event(ObserveEventWire::dry_run_sample())
            .build();

        assert_eq!(request.events.len(), baseline + 1);
    }

    #[test]
    fn dry_run_sample_forwards_to_wire_function() {
        let request = dry_run_sample::<ShapeFn>();

        assert_eq!(request, ShapeRequest::dry_run_sample());
    }

    #[test]
    fn synth_metadata_has_protocol_defaults_and_zero_timestamps() {
        let metadata = synth_metadata();

        assert_eq!(metadata.identity.plugin_name, "dry-run-plugin");
        assert_eq!(metadata.negotiated_functions.get("shape"), Some(&1));
        assert_eq!(metadata.negotiated_functions.get("filter"), Some(&1));
        assert_eq!(metadata.negotiated_functions.get("observe"), Some(&1));
        assert_eq!(metadata.handshake_completed_at, 0);
        assert_eq!(metadata.self_check_completed_at, 0);
        assert_eq!(metadata.expires_at, 0);
    }
}