// httpclient/middleware/recorder.rs

1use std::sync::OnceLock;
2use async_trait::async_trait;
3
4use http::header::CONTENT_TYPE;
5use tracing::info;
6
7use crate::error::ProtocolResult;
8use crate::middleware::Next;
9use crate::middleware::ProtocolError;
10use crate::recorder::{HashableRequest, RequestRecorder};
11use crate::{Body, InMemoryRequest, InMemoryResponse, Middleware, Response};
12
/// Selects how [`Recorder`] treats previously recorded responses.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RecorderMode {
    /// Default. Consult recordings first; fall back to a real request when none matches.
    #[default]
    RecordOrRequest,
    /// Skip recordings entirely and always make the real request.
    IgnoreRecordings,
    /// Only ever use recordings; a missing recording is an error.
    ForceNoRequests,
}

24impl RecorderMode {
25    #[must_use]
26    pub fn should_lookup(self) -> bool {
27        match self {
28            RecorderMode::IgnoreRecordings => false,
29            RecorderMode::ForceNoRequests | RecorderMode::RecordOrRequest => true,
30        }
31    }
32
33    #[must_use]
34    pub fn should_request(self) -> bool {
35        match self {
36            RecorderMode::IgnoreRecordings | RecorderMode::RecordOrRequest => true,
37            RecorderMode::ForceNoRequests => false,
38        }
39    }
40}
41
42static SHARED_RECORDER: OnceLock<RequestRecorder> = OnceLock::new();
43
44pub fn shared_recorder() -> &'static RequestRecorder {
45    SHARED_RECORDER.get_or_init(RequestRecorder::new)
46}
47
48#[derive(Default, Copy, Clone, Debug)]
49/// This middleware caches requests to the local filesystem. Subsequent requests will return results
50/// from the filesystem, and not touch the remote server.
51///
52/// The recordings are sanitized to hide secrets.
53///
54/// Use `.mode()` to configure the behavior:
55/// - `RecorderMode::RecordOrRequest` (default): Will check for recordings, but will make the request if no recording is found.
56/// - `RecorderMode::IgnoreRecordings`: Always make the request. (Use to force refresh recordings.)
57/// - `RecorderMode::ForceNoRequests`: Fail if no recording is found. (Use to run tests without hitting the remote server.)
58pub struct Recorder {
59    pub mode: RecorderMode,
60}
61
62impl Recorder {
63    #[must_use]
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    #[must_use]
69    pub fn mode(mut self, mode: RecorderMode) -> Self {
70        self.mode = mode;
71        self
72    }
73
74    fn should_lookup(self) -> bool {
75        self.mode.should_lookup()
76    }
77
78    fn should_request(self) -> bool {
79        self.mode.should_request()
80    }
81}
82
83#[async_trait]
84impl Middleware for Recorder {
85    #[allow(clippy::similar_names)]
86    async fn handle(&self, request: InMemoryRequest, next: Next<'_>) -> ProtocolResult<Response> {
87        let recorder = shared_recorder();
88
89        let request = HashableRequest(request);
90        if self.should_lookup() {
91            let recorded = recorder.get_response(&request);
92
93            if let Some(recorded) = recorded {
94                info!(url = request.uri().to_string(), "Using recorded response");
95
96                let (parts, body) = recorded.into_parts();
97                return Ok(Response::from_parts(parts, Body::InMemory(body)));
98            }
99        }
100
101        if !self.should_request() {
102            return Err(ProtocolError::IoError(std::io::Error::new(std::io::ErrorKind::NotFound, "No recording found")));
103        }
104
105        let response = next.run(request.clone()).await?;
106        let (parts, body) = response.into_parts();
107        let content_type = parts.headers.get(CONTENT_TYPE);
108        let body = body.into_content_type(content_type).await?;
109        let response = InMemoryResponse::from_parts(parts, body);
110
111        recorder.record_response(request.0, response.clone())?;
112
113        let (parts, body) = response.into_parts();
114        Ok(Response::from_parts(parts, Body::InMemory(body)))
115    }
116}