//! cc-lb-plugin-conformance 0.1.3
//!
//! cc-lb plugin conformance suite — in-process protocol verification helpers for external
//! plugin authors.
use std::collections::{BTreeMap, BTreeSet};

use cc_lb_plugin_wire::augmented_metadata::AugmentedMetadata;
use cc_lb_plugin_wire::identity::{CC_LB_PLUGIN_MAGIC, PluginIdentity};
use cc_lb_plugin_wire::v2::common::{
    CandidateWire, HeaderWire, ObserveEventWire, Principal, UpstreamWire,
};
use cc_lb_plugin_wire::v2::observe::ObserveRequest;
use cc_lb_plugin_wire::v2::shape::ShapeRequest;
use cc_lb_plugin_wire::v3::filter::FilterRequest;
use cc_lb_plugin_wire::wire_function::WireFunction;

/// Returns a synthetic [`Principal`] for use in conformance tests.
///
/// This forwards directly to [`Principal::dry_run_sample`]; it exists so plugin authors can
/// obtain the fixture from this crate without importing the wire crate's sample helpers.
#[must_use]
pub fn synth_principal() -> Principal {
    Principal::dry_run_sample()
}

/// Builds a synthetic [`AugmentedMetadata`] for a plugin that has finished its handshake.
///
/// The returned value is fully deterministic:
/// - identity: [`CC_LB_PLUGIN_MAGIC`], ABI envelope `1`, name `"dry-run-plugin"`,
///   version `"0.0.0"`;
/// - `shape`, `filter` and `observe` are each negotiated at version `1`;
/// - no capabilities are negotiated;
/// - `self_check_passed` is `true`;
/// - `handshake_completed_at`, `self_check_completed_at` and `expires_at` are all `0`.
#[must_use]
pub fn synth_metadata() -> AugmentedMetadata {
    // Function name -> negotiated version. A BTreeMap is order-insensitive, so the
    // listing order here does not affect the result.
    let negotiated_functions: BTreeMap<_, _> = ["shape", "filter", "observe"]
        .iter()
        .map(|&name| (String::from(name), 1))
        .collect();

    AugmentedMetadata {
        identity: PluginIdentity {
            magic: CC_LB_PLUGIN_MAGIC,
            abi_envelope: 1,
            plugin_name: String::from("dry-run-plugin"),
            plugin_version: String::from("0.0.0"),
        },
        negotiated_functions,
        negotiated_capabilities: BTreeSet::new(),
        // Timestamps are pinned to zero so fixtures compare equal across runs.
        handshake_completed_at: 0,
        self_check_passed: true,
        self_check_completed_at: 0,
        expires_at: 0,
    }
}

/// Starts a [`ShapeRequestBuilder`] seeded with [`ShapeRequest::dry_run_sample`].
///
/// Fields that are never overridden through the builder keep their dry-run values.
pub fn shape_request_builder() -> ShapeRequestBuilder {
    let request = ShapeRequest::dry_run_sample();
    ShapeRequestBuilder { request }
}

/// Fluent builder for [`ShapeRequest`] fixtures.
///
/// Obtain one from [`shape_request_builder`], chain setters, and finish with
/// [`ShapeRequestBuilder::build`]. Every setter consumes the builder and returns it, so an
/// ignored return value silently drops that edit. `#[must_use]` turns such mistakes into
/// compiler warnings.
#[must_use = "builders do nothing unless `build()` is called"]
#[non_exhaustive]
pub struct ShapeRequestBuilder {
    // The request being assembled; `shape_request_builder` seeds it with the dry-run sample.
    request: ShapeRequest,
}

impl ShapeRequestBuilder {
    /// Runs `edit` against the wrapped request and hands the builder back for chaining.
    fn apply(mut self, edit: impl FnOnce(&mut ShapeRequest)) -> Self {
        edit(&mut self.request);
        self
    }

    /// Overrides the inner request's method, e.g. `"GET"`.
    pub fn method(self, method: impl Into<String>) -> Self {
        self.apply(|shape| shape.request.method = method.into())
    }

    /// Overrides the inner request's path, e.g. `"/v1/models"`.
    pub fn path(self, path: impl Into<String>) -> Self {
        self.apply(|shape| shape.request.path = path.into())
    }

    /// Sets the inner request's query string (stored as `Some(query)`).
    pub fn query(self, query: impl Into<String>) -> Self {
        self.apply(|shape| shape.request.query = Some(query.into()))
    }

    /// Overrides the inner request's body.
    ///
    /// The string is stored verbatim; no encoding is performed here, so callers must pass
    /// text that is already base64-encoded.
    pub fn body_base64(self, body_base64: impl Into<String>) -> Self {
        self.apply(|shape| shape.request.body_base64 = body_base64.into())
    }

    /// Replaces the inner request's entire header list.
    pub fn headers(self, headers: Vec<HeaderWire>) -> Self {
        self.apply(|shape| shape.request.headers = headers)
    }

    /// Replaces the principal attached to the shape request.
    pub fn principal(self, principal: Principal) -> Self {
        self.apply(|shape| shape.principal = principal)
    }

    /// Replaces the upstream attached to the shape request.
    pub fn upstream(self, upstream: UpstreamWire) -> Self {
        self.apply(|shape| shape.upstream = upstream)
    }

    /// Consumes the builder and yields the assembled [`ShapeRequest`].
    pub fn build(self) -> ShapeRequest {
        self.request
    }
}

/// Starts a [`FilterRequestBuilder`] seeded with [`FilterRequest::dry_run_sample`].
///
/// Fields that are never overridden through the builder keep their dry-run values.
pub fn filter_request_builder() -> FilterRequestBuilder {
    let request = FilterRequest::dry_run_sample();
    FilterRequestBuilder { request }
}

/// Fluent builder for [`FilterRequest`] fixtures.
///
/// Obtain one from [`filter_request_builder`], chain setters, and finish with
/// [`FilterRequestBuilder::build`]. Every setter consumes the builder and returns it, so an
/// ignored return value silently drops that edit. `#[must_use]` turns such mistakes into
/// compiler warnings.
#[must_use = "builders do nothing unless `build()` is called"]
#[non_exhaustive]
pub struct FilterRequestBuilder {
    // The request being assembled; `filter_request_builder` seeds it with the dry-run sample.
    request: FilterRequest,
}

impl FilterRequestBuilder {
    /// Runs `edit` against the wrapped request and hands the builder back for chaining.
    fn apply(mut self, edit: impl FnOnce(&mut FilterRequest)) -> Self {
        edit(&mut self.request);
        self
    }

    /// Overrides the request id, e.g. `"req-1"`.
    pub fn request_id(self, request_id: impl Into<String>) -> Self {
        self.apply(|filter| filter.request_id = request_id.into())
    }

    /// Replaces the entire header list.
    pub fn headers(self, headers: Vec<HeaderWire>) -> Self {
        self.apply(|filter| filter.headers = headers)
    }

    /// Overrides the request method, e.g. `"POST"`.
    pub fn method(self, method: impl Into<String>) -> Self {
        self.apply(|filter| filter.method = method.into())
    }

    /// Overrides the request path, e.g. `"/v1/messages"`.
    pub fn path(self, path: impl Into<String>) -> Self {
        self.apply(|filter| filter.path = path.into())
    }

    /// Sets the query string (stored as `Some(query)`).
    pub fn query(self, query: impl Into<String>) -> Self {
        self.apply(|filter| filter.query = Some(query.into()))
    }

    /// Overrides the body.
    ///
    /// The string is stored verbatim; no encoding is performed here, so callers must pass
    /// text that is already base64-encoded.
    pub fn body_base64(self, body_base64: impl Into<String>) -> Self {
        self.apply(|filter| filter.body_base64 = body_base64.into())
    }

    /// Replaces the principal attached to the filter request.
    pub fn principal(self, principal: Principal) -> Self {
        self.apply(|filter| filter.principal = principal)
    }

    /// Replaces the entire candidate list.
    pub fn candidates(self, candidates: Vec<CandidateWire>) -> Self {
        self.apply(|filter| filter.candidates = candidates)
    }

    /// Consumes the builder and yields the assembled [`FilterRequest`].
    pub fn build(self) -> FilterRequest {
        self.request
    }
}

/// Starts an [`ObserveRequestBuilder`] seeded with [`ObserveRequest::dry_run_sample`].
///
/// Events added through the builder are appended to whatever the sample already holds.
pub fn observe_request_builder() -> ObserveRequestBuilder {
    let request = ObserveRequest::dry_run_sample();
    ObserveRequestBuilder { request }
}

/// Fluent builder for [`ObserveRequest`] fixtures.
///
/// Obtain one from [`observe_request_builder`], append events, and finish with
/// [`ObserveRequestBuilder::build`]. Every method consumes the builder and returns it, so an
/// ignored return value silently drops that event. `#[must_use]` turns such mistakes into
/// compiler warnings.
#[must_use = "builders do nothing unless `build()` is called"]
#[non_exhaustive]
pub struct ObserveRequestBuilder {
    // The request being assembled; `observe_request_builder` seeds it with the dry-run sample.
    request: ObserveRequest,
}

impl ObserveRequestBuilder {
    pub fn add_event(mut self, event: ObserveEventWire) -> Self {
        self.request.events.push(event);
        self
    }

    pub fn build(self) -> ObserveRequest {
        self.request
    }
}

pub fn dry_run_sample<F: WireFunction>() -> F::Request {
    F::dry_run_request()
}

#[cfg(test)]
mod tests {
    // `super::*` already brings in every wire type the parent module imports (including
    // `ObserveEventWire`), so only `ShapeFn` needs an explicit import here.
    use super::*;
    use cc_lb_plugin_wire::v2::shape::ShapeFn;

    #[test]
    fn shape_builder_overrides_request_fields() {
        let request = shape_request_builder()
            .method("GET")
            .path("/v1/models")
            .query("limit=1")
            .body_base64("e30=")
            .headers(vec![HeaderWire {
                name: String::from("content-type"),
                value_base64: String::from("YXBwbGljYXRpb24vanNvbg=="),
            }])
            .principal(synth_principal())
            .upstream(UpstreamWire::dry_run_sample())
            .build();

        assert_eq!(request.request.method, "GET");
        assert_eq!(request.request.path, "/v1/models");
        assert_eq!(request.request.query.as_deref(), Some("limit=1"));
        assert_eq!(request.request.body_base64, "e30=");
        assert_eq!(request.request.headers.len(), 1);
        assert_eq!(request.request.headers[0].name, "content-type");
        assert_eq!(
            request.request.headers[0].value_base64,
            "YXBwbGljYXRpb24vanNvbg=="
        );
    }

    #[test]
    fn filter_builder_overrides_filter_fields() {
        let request = filter_request_builder()
            .request_id("req-1")
            .headers(Vec::new())
            .method("POST")
            .path("/v1/messages")
            .query("beta=true")
            .body_base64("")
            .principal(synth_principal())
            .candidates(vec![CandidateWire::dry_run_sample()])
            .build();

        assert_eq!(request.request_id, "req-1");
        assert!(request.headers.is_empty());
        assert_eq!(request.method, "POST");
        assert_eq!(request.path, "/v1/messages");
        assert_eq!(request.query.as_deref(), Some("beta=true"));
        assert_eq!(request.body_base64, "");
        assert_eq!(request.candidates.len(), 1);
    }

    #[test]
    fn observe_builder_appends_events() {
        // Relies on the dry-run sample starting with an empty event list.
        let request = observe_request_builder()
            .add_event(ObserveEventWire::dry_run_sample())
            .build();

        assert_eq!(request.events.len(), 1);
    }

    #[test]
    fn dry_run_sample_forwards_to_wire_function() {
        let request = dry_run_sample::<ShapeFn>();

        assert_eq!(request, ShapeRequest::dry_run_sample());
    }

    #[test]
    fn synth_metadata_has_protocol_defaults_and_zero_timestamps() {
        let metadata = synth_metadata();

        assert_eq!(metadata.identity.plugin_name, "dry-run-plugin");
        assert_eq!(metadata.negotiated_functions.get("shape"), Some(&1));
        assert_eq!(metadata.negotiated_functions.get("filter"), Some(&1));
        assert_eq!(metadata.negotiated_functions.get("observe"), Some(&1));
        assert_eq!(metadata.handshake_completed_at, 0);
        assert_eq!(metadata.self_check_completed_at, 0);
        assert_eq!(metadata.expires_at, 0);
    }
}