rama_http/layer/traffic_writer/
request.rs

1use super::WriterMode;
2use crate::dep::http_body;
3use crate::dep::http_body_util::BodyExt;
4use crate::io::write_http_request;
5use crate::{Body, Request};
6use bytes::Bytes;
7use rama_core::error::{BoxError, ErrorExt, OpaqueError};
8use rama_core::rt::Executor;
9use rama_core::{Context, Service};
10use std::fmt::Debug;
11use tokio::io::{AsyncWrite, AsyncWriteExt, stderr, stdout};
12use tokio::sync::mpsc::{Sender, UnboundedSender, channel, unbounded_channel};
13
14/// A trait for writing http requests.
15pub trait RequestWriter: Send + Sync + 'static {
16    /// Write the http request.
17    fn write_request(&self, req: Request) -> impl Future<Output = ()> + Send + '_;
18}
19
/// Context marker: when present, the inspector skips writing the request.
///
/// The request is still passed through, with its body type-erased.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct DoNotWriteRequest;

impl DoNotWriteRequest {
    /// Build the [`DoNotWriteRequest`] marker value.
    pub const fn new() -> Self {
        DoNotWriteRequest
    }
}
31
/// Middleware that hands a copy of every http request to a [`RequestWriter`]
/// (for example to print it in std format) before passing it on.
///
/// See the [module docs](super) for more details.
pub struct RequestWriterInspector<W> {
    // Destination for the captured requests.
    writer: W,
}
38
39impl<W> RequestWriterInspector<W> {
40    /// Create a new [`RequestWriterInspector`] with a custom [`RequestWriter`].
41    pub const fn new(writer: W) -> Self {
42        Self { writer }
43    }
44}
45
46impl<W> Debug for RequestWriterInspector<W> {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("RequestWriterInspector")
49            .field("writer", &format_args!("{}", std::any::type_name::<W>()))
50            .finish()
51    }
52}
53
54impl<W: Clone> Clone for RequestWriterInspector<W> {
55    fn clone(&self) -> Self {
56        Self {
57            writer: self.writer.clone(),
58        }
59    }
60}
61
62impl RequestWriterInspector<UnboundedSender<Request>> {
63    /// Create a new [`RequestWriterInspector`] that prints requests to an [`AsyncWrite`]r
64    /// over an unbounded channel
65    pub fn writer_unbounded<W>(executor: &Executor, mut writer: W, mode: Option<WriterMode>) -> Self
66    where
67        W: AsyncWrite + Unpin + Send + Sync + 'static,
68    {
69        let (tx, mut rx) = unbounded_channel();
70        let (write_headers, write_body) = match mode {
71            Some(WriterMode::All) => (true, true),
72            Some(WriterMode::Headers) => (true, false),
73            Some(WriterMode::Body) => (false, true),
74            None => (false, false),
75        };
76        executor.spawn_task(async move {
77            while let Some(req) = rx.recv().await {
78                if let Err(err) =
79                    write_http_request(&mut writer, req, write_headers, write_body).await
80                {
81                    tracing::error!(err = %err, "failed to write http request to writer")
82                }
83                if let Err(err) = writer.write_all(b"\r\n").await {
84                    tracing::error!(err = %err, "failed to write separator to writer")
85                }
86            }
87        });
88        Self { writer: tx }
89    }
90
91    /// Create a new [`RequestWriterInspector`] that prints requests to stdout
92    /// over an unbounded channel.
93    pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
94        Self::writer_unbounded(executor, stdout(), mode)
95    }
96
97    /// Create a new [`RequestWriterInspector`] that prints requests to stderr
98    /// over an unbounded channel.
99    pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
100        Self::writer_unbounded(executor, stderr(), mode)
101    }
102}
103
104impl RequestWriterInspector<Sender<Request>> {
105    /// Create a new [`RequestWriterInspector`] that prints requests to an [`AsyncWrite`]r
106    /// over a bounded channel with a fixed buffer size.
107    pub fn writer<W>(
108        executor: &Executor,
109        mut writer: W,
110        buffer_size: usize,
111        mode: Option<WriterMode>,
112    ) -> Self
113    where
114        W: AsyncWrite + Unpin + Send + Sync + 'static,
115    {
116        let (tx, mut rx) = channel(buffer_size);
117        let (write_headers, write_body) = match mode {
118            Some(WriterMode::All) => (true, true),
119            Some(WriterMode::Headers) => (true, false),
120            Some(WriterMode::Body) => (false, true),
121            None => (false, false),
122        };
123        executor.spawn_task(async move {
124            while let Some(req) = rx.recv().await {
125                if let Err(err) =
126                    write_http_request(&mut writer, req, write_headers, write_body).await
127                {
128                    tracing::error!(err = %err, "failed to write http request to writer")
129                }
130                if let Err(err) = writer.write_all(b"\r\n").await {
131                    tracing::error!(err = %err, "failed to write separator to writer")
132                }
133            }
134        });
135        Self { writer: tx }
136    }
137
138    /// Create a new [`RequestWriterInspector`] that prints requests to stdout
139    /// over a bounded channel with a fixed buffer size.
140    pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
141        Self::writer(executor, stdout(), buffer_size, mode)
142    }
143
144    /// Create a new [`RequestWriterInspector`] that prints requests to stderr
145    /// over a bounded channel with a fixed buffer size.
146    pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
147        Self::writer(executor, stderr(), buffer_size, mode)
148    }
149}
150
151impl<State, W, ReqBody> Service<State, Request<ReqBody>> for RequestWriterInspector<W>
152where
153    State: Clone + Send + Sync + 'static,
154    W: RequestWriter,
155    ReqBody: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
156{
157    type Error = BoxError;
158    type Response = (Context<State>, Request);
159
160    async fn serve(
161        &self,
162        ctx: Context<State>,
163        req: Request<ReqBody>,
164    ) -> Result<(Context<State>, Request), Self::Error> {
165        let req = match ctx.get::<DoNotWriteRequest>() {
166            Some(_) => req.map(Body::new),
167            None => {
168                let (parts, body) = req.into_parts();
169                let body_bytes = body
170                    .collect()
171                    .await
172                    .map_err(|err| {
173                        OpaqueError::from_boxed(err.into())
174                            .context("printer prepare: collect request body")
175                    })?
176                    .to_bytes();
177                let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone()));
178                self.writer.write_request(req).await;
179                Request::from_parts(parts, Body::from(body_bytes))
180            }
181        };
182        Ok((ctx, req))
183    }
184}
185
186impl RequestWriter for Sender<Request> {
187    async fn write_request(&self, req: Request) {
188        if let Err(err) = self.send(req).await {
189            tracing::error!(err = %err, "failed to send request to channel")
190        }
191    }
192}
193
194impl RequestWriter for UnboundedSender<Request> {
195    async fn write_request(&self, req: Request) {
196        if let Err(err) = self.send(req) {
197            tracing::error!(err = %err, "failed to send request to unbounded channel")
198        }
199    }
200}
201
202impl<F, Fut> RequestWriter for F
203where
204    F: Fn(Request) -> Fut + Send + Sync + 'static,
205    Fut: Future<Output = ()> + Send + 'static,
206{
207    async fn write_request(&self, req: Request) {
208        self(req).await
209    }
210}