rama_http/layer/traffic_writer/
request.rs

1use super::WriterMode;
2use crate::dep::http_body;
3use crate::dep::http_body_util::BodyExt;
4use crate::io::write_http_request;
5use crate::{Body, Request, Response};
6use bytes::Bytes;
7use rama_core::error::{BoxError, ErrorExt, OpaqueError};
8use rama_core::rt::Executor;
9use rama_core::{Context, Layer, Service};
10use rama_utils::macros::define_inner_service_accessors;
11use std::fmt::Debug;
12use std::future::Future;
13use tokio::io::{stderr, stdout, AsyncWrite, AsyncWriteExt};
14use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender};
15
/// Layer that applies [`RequestWriterService`], which hands each incoming http
/// request to a [`RequestWriter`] before forwarding it to the inner service.
///
/// Requests whose context contains a [`DoNotWriteRequest`] marker are forwarded
/// without being written. Use [`RequestWriterLayer::new`] to plug in a custom
/// [`RequestWriter`], or one of the channel-based constructors to write requests
/// to an [`AsyncWrite`]r from a background task.
#[derive(Clone)]
pub struct RequestWriterLayer<W> {
    writer: W,
}

impl<W> Debug for RequestWriterLayer<W> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `W` is not required to implement `Debug`, so only its type name is shown.
        f.debug_struct("RequestWriterLayer")
            .field("writer", &format_args!("{}", std::any::type_name::<W>()))
            .finish()
    }
}

impl<W> RequestWriterLayer<W> {
    /// Create a new [`RequestWriterLayer`] with a custom [`RequestWriter`].
    pub const fn new(writer: W) -> Self {
        Self { writer }
    }
}
43
/// A trait for writing http requests.
///
/// Implemented in this module for tokio mpsc [`Sender`] / [`UnboundedSender`]
/// channels of [`Request`], and for any `Fn(Request) -> Fut` closure whose
/// future resolves to `()`.
pub trait RequestWriter: Send + Sync + 'static {
    /// Write the http request.
    ///
    /// The returned future must be `Send` and may borrow `self`. It resolves to
    /// `()`, so there is no error channel back to the caller: failures have to be
    /// handled by the implementation itself (the channel implementations in this
    /// module log them via `tracing`).
    fn write_request(&self, req: Request) -> impl Future<Output = ()> + Send + '_;
}
49
/// Marker which, when present in the request [`Context`], makes the
/// [`RequestWriterService`] forward the request without writing it.
#[derive(Default, Clone, Debug)]
#[non_exhaustive]
pub struct DoNotWriteRequest;

impl DoNotWriteRequest {
    /// Construct the [`DoNotWriteRequest`] marker.
    pub const fn new() -> Self {
        DoNotWriteRequest
    }
}
61
62impl RequestWriterLayer<UnboundedSender<Request>> {
63    /// Create a new [`RequestWriterLayer`] that prints requests to an [`AsyncWrite`]r
64    /// over an unbounded channel
65    pub fn writer_unbounded<W>(executor: &Executor, mut writer: W, mode: Option<WriterMode>) -> Self
66    where
67        W: AsyncWrite + Unpin + Send + Sync + 'static,
68    {
69        let (tx, mut rx) = unbounded_channel();
70        let (write_headers, write_body) = match mode {
71            Some(WriterMode::All) => (true, true),
72            Some(WriterMode::Headers) => (true, false),
73            Some(WriterMode::Body) => (false, true),
74            None => (false, false),
75        };
76        executor.spawn_task(async move {
77            while let Some(req) = rx.recv().await {
78                if let Err(err) =
79                    write_http_request(&mut writer, req, write_headers, write_body).await
80                {
81                    tracing::error!(err = %err, "failed to write http request to writer")
82                }
83                if let Err(err) = writer.write_all(b"\r\n").await {
84                    tracing::error!(err = %err, "failed to write separator to writer")
85                }
86            }
87        });
88        Self { writer: tx }
89    }
90
91    /// Create a new [`RequestWriterLayer`] that prints requests to stdout
92    /// over an unbounded channel.
93    pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
94        Self::writer_unbounded(executor, stdout(), mode)
95    }
96
97    /// Create a new [`RequestWriterLayer`] that prints requests to stderr
98    /// over an unbounded channel.
99    pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
100        Self::writer_unbounded(executor, stderr(), mode)
101    }
102}
103
104impl RequestWriterLayer<Sender<Request>> {
105    /// Create a new [`RequestWriterLayer`] that prints requests to an [`AsyncWrite`]r
106    /// over a bounded channel with a fixed buffer size.
107    pub fn writer<W>(
108        executor: &Executor,
109        mut writer: W,
110        buffer_size: usize,
111        mode: Option<WriterMode>,
112    ) -> Self
113    where
114        W: AsyncWrite + Unpin + Send + Sync + 'static,
115    {
116        let (tx, mut rx) = channel(buffer_size);
117        let (write_headers, write_body) = match mode {
118            Some(WriterMode::All) => (true, true),
119            Some(WriterMode::Headers) => (true, false),
120            Some(WriterMode::Body) => (false, true),
121            None => (false, false),
122        };
123        executor.spawn_task(async move {
124            while let Some(req) = rx.recv().await {
125                if let Err(err) =
126                    write_http_request(&mut writer, req, write_headers, write_body).await
127                {
128                    tracing::error!(err = %err, "failed to write http request to writer")
129                }
130                if let Err(err) = writer.write_all(b"\r\n").await {
131                    tracing::error!(err = %err, "failed to write separator to writer")
132                }
133            }
134        });
135        Self { writer: tx }
136    }
137
138    /// Create a new [`RequestWriterLayer`] that prints requests to stdout
139    /// over a bounded channel with a fixed buffer size.
140    pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
141        Self::writer(executor, stdout(), buffer_size, mode)
142    }
143
144    /// Create a new [`RequestWriterLayer`] that prints requests to stderr
145    /// over a bounded channel with a fixed buffer size.
146    pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
147        Self::writer(executor, stderr(), buffer_size, mode)
148    }
149}
150
151impl<S, W: Clone> Layer<S> for RequestWriterLayer<W> {
152    type Service = RequestWriterService<S, W>;
153
154    fn layer(&self, inner: S) -> Self::Service {
155        RequestWriterService {
156            inner,
157            writer: self.writer.clone(),
158        }
159    }
160}
161
162/// Middleware to print Http request in std format.
163///
164/// See the [module docs](super) for more details.
165pub struct RequestWriterService<S, W> {
166    inner: S,
167    writer: W,
168}
169
170impl<S, W> RequestWriterService<S, W> {
171    /// Create a new [`RequestWriterService`] with a custom [`RequestWriter`].
172    pub const fn new(writer: W, inner: S) -> Self {
173        Self { inner, writer }
174    }
175
176    define_inner_service_accessors!();
177}
178
179impl<S: Debug, W> Debug for RequestWriterService<S, W> {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("RequestWriterService")
182            .field("inner", &self.inner)
183            .field("writer", &format_args!("{}", std::any::type_name::<W>()))
184            .finish()
185    }
186}
187
188impl<S: Clone, W: Clone> Clone for RequestWriterService<S, W> {
189    fn clone(&self) -> Self {
190        Self {
191            inner: self.inner.clone(),
192            writer: self.writer.clone(),
193        }
194    }
195}
196
197impl<S> RequestWriterService<S, UnboundedSender<Request>> {
198    /// Create a new [`RequestWriterService`] that prints requests to an [`AsyncWrite`]r
199    /// over an unbounded channel
200    pub fn writer_unbounded<W>(
201        executor: &Executor,
202        writer: W,
203        mode: Option<WriterMode>,
204        inner: S,
205    ) -> Self
206    where
207        W: AsyncWrite + Unpin + Send + Sync + 'static,
208    {
209        let layer = RequestWriterLayer::writer_unbounded(executor, writer, mode);
210        layer.layer(inner)
211    }
212
213    /// Create a new [`RequestWriterService`] that prints requests to stdout
214    /// over an unbounded channel.
215    pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
216        Self::writer_unbounded(executor, stdout(), mode, inner)
217    }
218
219    /// Create a new [`RequestWriterService`] that prints requests to stderr
220    /// over an unbounded channel.
221    pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
222        Self::writer_unbounded(executor, stderr(), mode, inner)
223    }
224}
225
226impl<S> RequestWriterService<S, Sender<Request>> {
227    /// Create a new [`RequestWriterService`] that prints requests to an [`AsyncWrite`]r
228    /// over a bounded channel with a fixed buffer size.
229    pub fn writer<W>(
230        executor: &Executor,
231        writer: W,
232        buffer_size: usize,
233        mode: Option<WriterMode>,
234        inner: S,
235    ) -> Self
236    where
237        W: AsyncWrite + Unpin + Send + Sync + 'static,
238    {
239        let layer = RequestWriterLayer::writer(executor, writer, buffer_size, mode);
240        layer.layer(inner)
241    }
242
243    /// Create a new [`RequestWriterService`] that prints requests to stdout
244    /// over a bounded channel with a fixed buffer size.
245    pub fn stdout(
246        executor: &Executor,
247        buffer_size: usize,
248        mode: Option<WriterMode>,
249        inner: S,
250    ) -> Self {
251        Self::writer(executor, stdout(), buffer_size, mode, inner)
252    }
253
254    /// Create a new [`RequestWriterService`] that prints requests to stderr
255    /// over a bounded channel with a fixed buffer size.
256    pub fn stderr(
257        executor: &Executor,
258        buffer_size: usize,
259        mode: Option<WriterMode>,
260        inner: S,
261    ) -> Self {
262        Self::writer(executor, stderr(), buffer_size, mode, inner)
263    }
264}
265
266impl<S, W> RequestWriterService<S, W> {}
267
268impl<State, S, W, ReqBody, ResBody> Service<State, Request<ReqBody>> for RequestWriterService<S, W>
269where
270    State: Clone + Send + Sync + 'static,
271    S: Service<State, Request, Response = Response<ResBody>, Error: Into<BoxError>>,
272    W: RequestWriter,
273    ReqBody: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
274    ResBody: Send + 'static,
275{
276    type Response = Response<ResBody>;
277    type Error = BoxError;
278
279    async fn serve(
280        &self,
281        ctx: Context<State>,
282        req: Request<ReqBody>,
283    ) -> Result<Self::Response, Self::Error> {
284        let req = match ctx.get::<DoNotWriteRequest>() {
285            Some(_) => req.map(Body::new),
286            None => {
287                let (parts, body) = req.into_parts();
288                let body_bytes = body
289                    .collect()
290                    .await
291                    .map_err(|err| {
292                        OpaqueError::from_boxed(err.into())
293                            .context("printer prepare: collect request body")
294                    })?
295                    .to_bytes();
296                let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone()));
297                self.writer.write_request(req).await;
298                Request::from_parts(parts, Body::from(body_bytes))
299            }
300        };
301        self.inner.serve(ctx, req).await.map_err(Into::into)
302    }
303}
304
305impl RequestWriter for Sender<Request> {
306    async fn write_request(&self, req: Request) {
307        if let Err(err) = self.send(req).await {
308            tracing::error!(err = %err, "failed to send request to channel")
309        }
310    }
311}
312
313impl RequestWriter for UnboundedSender<Request> {
314    async fn write_request(&self, req: Request) {
315        if let Err(err) = self.send(req) {
316            tracing::error!(err = %err, "failed to send request to unbounded channel")
317        }
318    }
319}
320
321impl<F, Fut> RequestWriter for F
322where
323    F: Fn(Request) -> Fut + Send + Sync + 'static,
324    Fut: Future<Output = ()> + Send + 'static,
325{
326    async fn write_request(&self, req: Request) {
327        self(req).await
328    }
329}