rama_http/layer/traffic_writer/response.rs

1use super::WriterMode;
2use crate::dep::http_body;
3use crate::dep::http_body_util::BodyExt;
4use crate::io::write_http_response;
5use crate::{Body, Request, Response};
6use bytes::Bytes;
7use rama_core::error::{BoxError, ErrorContext, OpaqueError};
8use rama_core::rt::Executor;
9use rama_core::{Context, Layer, Service};
10use rama_utils::macros::define_inner_service_accessors;
11use std::fmt::Debug;
12use tokio::io::{AsyncWrite, stderr, stdout};
13use tokio::sync::mpsc::{Sender, UnboundedSender, channel, unbounded_channel};
14
/// Layer that wraps a service in a [`ResponseWriterService`], which hands a copy of
/// every response produced by that service to the configured writer `W`.
///
/// Ready-made constructors exist for writing responses to stdout, stderr or any
/// [`AsyncWrite`]r, either over a bounded ([`Sender`]) or an unbounded
/// ([`UnboundedSender`]) channel.
pub struct ResponseWriterLayer<W> {
    // Destination for responses; the resulting service is only usable as a
    // `Service` when this implements `ResponseWriter`.
    writer: W,
}
19
20impl<W> Debug for ResponseWriterLayer<W> {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("ResponseWriterLayer")
23            .field("writer", &format_args!("{}", std::any::type_name::<W>()))
24            .finish()
25    }
26}
27
28impl<W: Clone> Clone for ResponseWriterLayer<W> {
29    fn clone(&self) -> Self {
30        Self {
31            writer: self.writer.clone(),
32        }
33    }
34}
35
36impl<W> ResponseWriterLayer<W> {
37    /// Create a new [`ResponseWriterLayer`] with a custom [`ResponseWriter`].
38    pub const fn new(writer: W) -> Self {
39        Self { writer }
40    }
41}
42
/// A destination for http responses observed by [`ResponseWriterService`].
///
/// Implemented in this module for tokio's bounded and unbounded mpsc senders of
/// [`Response`], and for any `Fn(Response) -> Fut` closure whose `Fut` resolves to `()`.
pub trait ResponseWriter: Send + Sync + 'static {
    /// Write (or forward) the http response.
    ///
    /// No error can be returned; the implementations in this module log failures instead.
    fn write_response(&self, res: Response) -> impl Future<Output = ()> + Send + '_;
}
48
/// Marker that, when present in the request [`Context`], makes
/// [`ResponseWriterService`] skip writing the response.
///
/// In that case the response is passed through without its body being buffered.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct DoNotWriteResponse;
53
54impl DoNotWriteResponse {
55    /// Create a new [`DoNotWriteResponse`] marker.
56    pub const fn new() -> Self {
57        Self
58    }
59}
60
61impl ResponseWriterLayer<UnboundedSender<Response>> {
62    /// Create a new [`ResponseWriterLayer`] that prints responses to an [`AsyncWrite`]r
63    /// over an unbounded channel
64    pub fn writer_unbounded<W>(executor: &Executor, mut writer: W, mode: Option<WriterMode>) -> Self
65    where
66        W: AsyncWrite + Unpin + Send + Sync + 'static,
67    {
68        let (tx, mut rx) = unbounded_channel();
69        let (write_headers, write_body) = match mode {
70            Some(WriterMode::All) => (true, true),
71            Some(WriterMode::Headers) => (true, false),
72            Some(WriterMode::Body) => (false, true),
73            None => (false, false),
74        };
75        executor.spawn_task(async move {
76            while let Some(res) = rx.recv().await {
77                if let Err(err) =
78                    write_http_response(&mut writer, res, write_headers, write_body).await
79                {
80                    tracing::error!(err = %err, "failed to write http response to writer")
81                }
82            }
83        });
84        Self { writer: tx }
85    }
86
87    /// Create a new [`ResponseWriterLayer`] that prints responses to stdout
88    /// over an unbounded channel.
89    pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
90        Self::writer_unbounded(executor, stdout(), mode)
91    }
92
93    /// Create a new [`ResponseWriterLayer`] that prints responses to stderr
94    /// over an unbounded channel.
95    pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
96        Self::writer_unbounded(executor, stderr(), mode)
97    }
98}
99
100impl ResponseWriterLayer<Sender<Response>> {
101    /// Create a new [`ResponseWriterLayer`] that prints responses to an [`AsyncWrite`]r
102    /// over a bounded channel with a fixed buffer size.
103    pub fn writer<W>(
104        executor: &Executor,
105        mut writer: W,
106        buffer_size: usize,
107        mode: Option<WriterMode>,
108    ) -> Self
109    where
110        W: AsyncWrite + Unpin + Send + Sync + 'static,
111    {
112        let (tx, mut rx) = channel(buffer_size);
113        let (write_headers, write_body) = match mode {
114            Some(WriterMode::All) => (true, true),
115            Some(WriterMode::Headers) => (true, false),
116            Some(WriterMode::Body) => (false, true),
117            None => (false, false),
118        };
119        executor.spawn_task(async move {
120            while let Some(res) = rx.recv().await {
121                if let Err(err) =
122                    write_http_response(&mut writer, res, write_headers, write_body).await
123                {
124                    tracing::error!(err = %err, "failed to write http response to writer")
125                }
126            }
127        });
128        Self { writer: tx }
129    }
130
131    /// Create a new [`ResponseWriterLayer`] that prints responses to stdout
132    /// over a bounded channel with a fixed buffer size.
133    pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
134        Self::writer(executor, stdout(), buffer_size, mode)
135    }
136
137    /// Create a new [`ResponseWriterLayer`] that prints responses to stderr
138    /// over a bounded channel with a fixed buffer size.
139    pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
140        Self::writer(executor, stderr(), buffer_size, mode)
141    }
142}
143
144impl<S, W: Clone> Layer<S> for ResponseWriterLayer<W> {
145    type Service = ResponseWriterService<S, W>;
146
147    fn layer(&self, inner: S) -> Self::Service {
148        ResponseWriterService {
149            inner,
150            writer: self.writer.clone(),
151        }
152    }
153
154    fn into_layer(self, inner: S) -> Self::Service {
155        ResponseWriterService {
156            inner,
157            writer: self.writer,
158        }
159    }
160}
161
/// Middleware that hands a copy of each http response produced by the inner
/// service to a [`ResponseWriter`] (for example to print it to stdout or stderr).
///
/// See the [module docs](super) for more details.
pub struct ResponseWriterService<S, W> {
    // The wrapped service whose responses are observed.
    inner: S,
    // Destination that receives a copy of each response.
    writer: W,
}
169
170impl<S, W> ResponseWriterService<S, W> {
171    /// Create a new [`ResponseWriterService`] with a custom [`ResponseWriter`].
172    pub const fn new(writer: W, inner: S) -> Self {
173        Self { inner, writer }
174    }
175
176    define_inner_service_accessors!();
177}
178
179impl<S: Debug, W> Debug for ResponseWriterService<S, W> {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("ResponseWriterService")
182            .field("inner", &self.inner)
183            .field("writer", &format_args!("{}", std::any::type_name::<W>()))
184            .finish()
185    }
186}
187
188impl<S: Clone, W: Clone> Clone for ResponseWriterService<S, W> {
189    fn clone(&self) -> Self {
190        Self {
191            inner: self.inner.clone(),
192            writer: self.writer.clone(),
193        }
194    }
195}
196
197impl<S> ResponseWriterService<S, UnboundedSender<Response>> {
198    /// Create a new [`ResponseWriterService`] that prints responses to an [`AsyncWrite`]r
199    /// over an unbounded channel
200    pub fn writer_unbounded<W>(
201        executor: &Executor,
202        writer: W,
203        mode: Option<WriterMode>,
204        inner: S,
205    ) -> Self
206    where
207        W: AsyncWrite + Unpin + Send + Sync + 'static,
208    {
209        let layer = ResponseWriterLayer::writer_unbounded(executor, writer, mode);
210        layer.into_layer(inner)
211    }
212
213    /// Create a new [`ResponseWriterService`] that prints responses to stdout
214    /// over an unbounded channel.
215    pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
216        Self::writer_unbounded(executor, stdout(), mode, inner)
217    }
218
219    /// Create a new [`ResponseWriterService`] that prints responses to stderr
220    /// over an unbounded channel.
221    pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
222        Self::writer_unbounded(executor, stderr(), mode, inner)
223    }
224}
225
226impl<S> ResponseWriterService<S, Sender<Response>> {
227    /// Create a new [`ResponseWriterService`] that prints responses to an [`AsyncWrite`]r
228    /// over a bounded channel with a fixed buffer size.
229    pub fn writer<W>(
230        executor: &Executor,
231        writer: W,
232        buffer_size: usize,
233        mode: Option<WriterMode>,
234        inner: S,
235    ) -> Self
236    where
237        W: AsyncWrite + Unpin + Send + Sync + 'static,
238    {
239        let layer = ResponseWriterLayer::writer(executor, writer, buffer_size, mode);
240        layer.into_layer(inner)
241    }
242
243    /// Create a new [`ResponseWriterService`] that prints responses to stdout
244    /// over a bounded channel with a fixed buffer size.
245    pub fn stdout(
246        executor: &Executor,
247        buffer_size: usize,
248        mode: Option<WriterMode>,
249        inner: S,
250    ) -> Self {
251        Self::writer(executor, stdout(), buffer_size, mode, inner)
252    }
253
254    /// Create a new [`ResponseWriterService`] that prints responses to stderr
255    /// over a bounded channel with a fixed buffer size.
256    pub fn stderr(
257        executor: &Executor,
258        buffer_size: usize,
259        mode: Option<WriterMode>,
260        inner: S,
261    ) -> Self {
262        Self::writer(executor, stderr(), buffer_size, mode, inner)
263    }
264}
265
266impl<S, W> ResponseWriterService<S, W> {}
267
268impl<State, S, W, ReqBody, ResBody> Service<State, Request<ReqBody>> for ResponseWriterService<S, W>
269where
270    State: Clone + Send + Sync + 'static,
271    S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: Into<BoxError>>,
272    W: ResponseWriter,
273    ReqBody: Send + 'static,
274    ResBody: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
275{
276    type Response = Response;
277    type Error = BoxError;
278
279    async fn serve(
280        &self,
281        ctx: Context<State>,
282        req: Request<ReqBody>,
283    ) -> Result<Self::Response, Self::Error> {
284        let do_not_print_response: Option<DoNotWriteResponse> = ctx.get().cloned();
285        let resp = self.inner.serve(ctx, req).await.map_err(Into::into)?;
286        let resp = match do_not_print_response {
287            Some(_) => resp.map(Body::new),
288            None => {
289                let (parts, body) = resp.into_parts();
290                let body_bytes = body
291                    .collect()
292                    .await
293                    .map_err(|err| OpaqueError::from_boxed(err.into()))
294                    .context("printer prepare: collect response body")?
295                    .to_bytes();
296                let resp: rama_http_types::Response<Body> =
297                    Response::from_parts(parts.clone(), Body::from(body_bytes.clone()));
298                self.writer.write_response(resp).await;
299                Response::from_parts(parts, Body::from(body_bytes))
300            }
301        };
302        Ok(resp)
303    }
304}
305
306impl ResponseWriter for Sender<Response> {
307    async fn write_response(&self, res: Response) {
308        if let Err(err) = self.send(res).await {
309            tracing::error!(err = %err, "failed to send response to channel")
310        }
311    }
312}
313
314impl ResponseWriter for UnboundedSender<Response> {
315    async fn write_response(&self, res: Response) {
316        if let Err(err) = self.send(res) {
317            tracing::error!(err = %err, "failed to send response to unbounded channel")
318        }
319    }
320}
321
322impl<F, Fut> ResponseWriter for F
323where
324    F: Fn(Response) -> Fut + Send + Sync + 'static,
325    Fut: Future<Output = ()> + Send + 'static,
326{
327    async fn write_response(&self, res: Response) {
328        self(res).await
329    }
330}