openid-client 0.2.7

OpenID client for Rust
Documentation
use std::{
    cell::RefCell,
    collections::{HashMap, VecDeque},
};

use josekit::Value;
use url::Url;

use crate::{
    helpers::{decode_jwt, form_url_encoded_to_string_map},
    types::{
        http_client::ClientCertificate, HttpMethod, HttpRequest, HttpResponse, OidcHttpClient,
    },
};

pub struct TestHttpReqRes {
    pub url: Url,
    pub method: HttpMethod,
    pub headers: HashMap<String, Vec<String>>,
    pub dpop: bool,
    pub dpop_ath: bool,
    pub dpop_nonce_not_present: bool,
    pub dpop_nonce: Option<String>,
    pub body: Option<String>,
    pub mtls: bool,

    pub response_body: Option<String>,
    pub response_status_code: u16,
    pub response_content_type: Option<String>,
    pub response_www_authenticate: Option<String>,
    pub response_dpop_nonce: Option<String>,
}

impl TestHttpReqRes {
    pub fn new(url: impl Into<String>) -> Self {
        TestHttpReqRes {
            url: Url::parse(&url.into()).unwrap(),
            method: HttpMethod::GET,
            headers: HashMap::new(),
            dpop: false,
            dpop_ath: false,
            dpop_nonce_not_present: false,
            dpop_nonce: None,
            body: None,
            mtls: false,
            response_body: None,
            response_status_code: 200,
            response_content_type: None,
            response_www_authenticate: None,
            response_dpop_nonce: None,
        }
    }

    pub fn assert_request_method(mut self, method: HttpMethod) -> Self {
        self.method = method;
        self
    }

    pub fn assert_request_header(mut self, key: impl Into<String>, value: Vec<String>) -> Self {
        self.headers.insert(key.into(), value);
        self
    }

    pub fn assert_dpop(mut self) -> Self {
        self.dpop = true;
        self
    }

    pub fn assert_dpop_ath(mut self) -> Self {
        self.dpop_ath = true;
        self
    }

    pub fn assert_dpop_nonce_not_present(mut self) -> Self {
        self.dpop_nonce_not_present = true;
        self
    }

    pub fn assert_dpop_nonce_value(mut self, val: impl Into<String>) -> Self {
        self.dpop_nonce = Some(val.into());
        self
    }

    pub fn assert_request_body(mut self, body: impl Into<String>) -> Self {
        self.body = Some(body.into());
        self
    }

    pub fn assert_request_mtls(mut self, mtls: bool) -> Self {
        self.mtls = mtls;
        self
    }

    pub fn set_response_body(mut self, response_body: impl Into<String>) -> Self {
        self.response_body = Some(response_body.into());
        self
    }

    pub fn set_response_status_code(mut self, response_status_code: u16) -> Self {
        self.response_status_code = response_status_code;
        self
    }

    pub fn set_response_content_type_header(mut self, ct: impl Into<String>) -> Self {
        self.response_content_type = Some(ct.into());
        self
    }

    pub fn set_response_www_authenticate_header(mut self, www: impl Into<String>) -> Self {
        self.response_www_authenticate = Some(www.into());
        self
    }

    pub fn set_response_dpop_nonce_header(mut self, nonce: impl Into<String>) -> Self {
        self.response_dpop_nonce = Some(nonce.into());
        self
    }

    pub fn build(self) -> TestHttpClient {
        let http_client = TestHttpClient::new();

        http_client.add(self)
    }
}

pub struct TestHttpClient {
    req_res: RefCell<VecDeque<TestHttpReqRes>>,
    pub return_client_cert: bool,
}

impl TestHttpClient {
    pub fn new() -> Self {
        Self {
            req_res: RefCell::new(VecDeque::with_capacity(5)),
            return_client_cert: false,
        }
    }

    pub fn add(mut self, req_res: TestHttpReqRes) -> Self {
        self.req_res.get_mut().push_back(req_res);

        self
    }

    pub fn return_client_cert(&mut self, v: bool) {
        self.return_client_cert = v;
    }

    pub fn assert(&self) {
        assert!(
            self.req_res.borrow().is_empty(),
            "All requests not fullfilled"
        );
    }
}

unsafe impl Sync for TestHttpClient {}

impl OidcHttpClient for TestHttpClient {
    async fn get_client_certificate(&self, _req: &HttpRequest) -> Option<ClientCertificate> {
        if self.return_client_cert {
            return Some(ClientCertificate {
                cert: "".to_string(),
                key: "".to_string(),
            });
        }
        None
    }

    async fn request(&self, mut req: HttpRequest) -> Result<HttpResponse, String> {
        let mut req_res_list = self.req_res.borrow_mut();

        let req_res = req_res_list.pop_front().unwrap();

        if req_res.dpop
            || req_res.dpop_ath
            || req_res.dpop_nonce.is_some()
            || req_res.dpop_nonce_not_present
        {
            if req_res.dpop && req.headers.get("DPoP").is_none() {
                assert!(false, "Expected DPoP header")
            }

            if let Some(Some(dpop)) = req
                .headers
                .get("DPoP")
                .map(|d| d.first().map(|v| v.to_owned()))
            {
                req.headers.remove("DPoP");

                let decoded = decode_jwt(&dpop).unwrap();

                if req_res.dpop_ath {
                    assert!(decoded.payload.claim("ath").is_some());
                }

                if req_res.dpop_nonce_not_present {
                    assert!(decoded.payload.claim("nonce").is_none())
                }

                if let Some(nonce) = req_res.dpop_nonce {
                    if nonce.is_empty() {
                        assert!(decoded.payload.claim("nonce").is_some());
                    } else {
                        assert_eq!(
                            decoded.payload.claim("nonce").unwrap().as_str().unwrap(),
                            nonce
                        );
                    }
                }
            }
        }

        assert_eq!(req.url, req_res.url);
        assert_eq!(req.method, req_res.method);
        assert_eq!(req.headers, req_res.headers);

        if req_res
            .headers
            .get("content-type")
            .is_some_and(|ct| ct.contains(&"application/json".to_string()))
        {
            assert_eq!(
                req.body.map(|b| serde_json::from_str::<Value>(&b).unwrap()),
                req_res
                    .body
                    .map(|b| serde_json::from_str::<Value>(&b).unwrap()),
            )
        } else if req_res
            .headers
            .get("content-type")
            .is_some_and(|ct| ct.contains(&"application/x-www-form-urlencoded".to_string()))
        {
            assert_eq!(
                req.body.map(|b| form_url_encoded_to_string_map(&b)),
                req_res.body.map(|b| form_url_encoded_to_string_map(&b)),
            )
        } else {
            assert_eq!(req.body, req_res.body);
        }

        assert_eq!(req.mtls, req_res.mtls);

        Ok(HttpResponse {
            body: req_res.response_body,
            status_code: req_res.response_status_code,
            content_type: req_res.response_content_type,
            dpop_nonce: req_res.response_dpop_nonce,
            www_authenticate: req_res.response_www_authenticate,
        })
    }
}