httpclient/middleware/
recorder.rs1use std::sync::OnceLock;
2use async_trait::async_trait;
3
4use http::header::CONTENT_TYPE;
5use tracing::info;
6
7use crate::error::ProtocolResult;
8use crate::middleware::Next;
9use crate::middleware::ProtocolError;
10use crate::recorder::{HashableRequest, RequestRecorder};
11use crate::{Body, InMemoryRequest, InMemoryResponse, Middleware, Response};
12
/// Controls how the [`Recorder`] middleware uses recorded responses.
#[derive(PartialEq, Eq, Hash, Clone, Copy, Default, Debug)]
pub enum RecorderMode {
    /// Replay a recorded response when one exists; otherwise forward the
    /// request and record its response. This is the default mode.
    #[default]
    RecordOrRequest,
    /// Skip the recording lookup and always forward the request; the fresh
    /// response is still recorded.
    IgnoreRecordings,
    /// Only replay recorded responses; a request with no recording fails
    /// instead of being forwarded.
    ForceNoRequests,
}
23
24impl RecorderMode {
25 #[must_use]
26 pub fn should_lookup(self) -> bool {
27 match self {
28 RecorderMode::IgnoreRecordings => false,
29 RecorderMode::ForceNoRequests | RecorderMode::RecordOrRequest => true,
30 }
31 }
32
33 #[must_use]
34 pub fn should_request(self) -> bool {
35 match self {
36 RecorderMode::IgnoreRecordings | RecorderMode::RecordOrRequest => true,
37 RecorderMode::ForceNoRequests => false,
38 }
39 }
40}
41
42static SHARED_RECORDER: OnceLock<RequestRecorder> = OnceLock::new();
43
44pub fn shared_recorder() -> &'static RequestRecorder {
45 SHARED_RECORDER.get_or_init(RequestRecorder::new)
46}
47
48#[derive(Default, Copy, Clone, Debug)]
49pub struct Recorder {
59 pub mode: RecorderMode,
60}
61
62impl Recorder {
63 #[must_use]
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 #[must_use]
69 pub fn mode(mut self, mode: RecorderMode) -> Self {
70 self.mode = mode;
71 self
72 }
73
74 fn should_lookup(self) -> bool {
75 self.mode.should_lookup()
76 }
77
78 fn should_request(self) -> bool {
79 self.mode.should_request()
80 }
81}
82
83#[async_trait]
84impl Middleware for Recorder {
85 #[allow(clippy::similar_names)]
86 async fn handle(&self, request: InMemoryRequest, next: Next<'_>) -> ProtocolResult<Response> {
87 let recorder = shared_recorder();
88
89 let request = HashableRequest(request);
90 if self.should_lookup() {
91 let recorded = recorder.get_response(&request);
92
93 if let Some(recorded) = recorded {
94 info!(url = request.uri().to_string(), "Using recorded response");
95
96 let (parts, body) = recorded.into_parts();
97 return Ok(Response::from_parts(parts, Body::InMemory(body)));
98 }
99 }
100
101 if !self.should_request() {
102 return Err(ProtocolError::IoError(std::io::Error::new(std::io::ErrorKind::NotFound, "No recording found")));
103 }
104
105 let response = next.run(request.clone()).await?;
106 let (parts, body) = response.into_parts();
107 let content_type = parts.headers.get(CONTENT_TYPE);
108 let body = body.into_content_type(content_type).await?;
109 let response = InMemoryResponse::from_parts(parts, body);
110
111 recorder.record_response(request.0, response.clone())?;
112
113 let (parts, body) = response.into_parts();
114 Ok(Response::from_parts(parts, Body::InMemory(body)))
115 }
116}