rama_http/layer/traffic_writer/
request.rs1use super::WriterMode;
2use crate::dep::http_body;
3use crate::dep::http_body_util::BodyExt;
4use crate::io::write_http_request;
5use crate::{Body, Request};
6use bytes::Bytes;
7use rama_core::error::{BoxError, ErrorExt, OpaqueError};
8use rama_core::rt::Executor;
9use rama_core::{Context, Service};
10use std::fmt::Debug;
11use tokio::io::{AsyncWrite, AsyncWriteExt, stderr, stdout};
12use tokio::sync::mpsc::{Sender, UnboundedSender, channel, unbounded_channel};
13
14pub trait RequestWriter: Send + Sync + 'static {
16 fn write_request(&self, req: Request) -> impl Future<Output = ()> + Send + '_;
18}
19
/// Marker that opts a request out of being written.
///
/// When a value of this type is present in the request `Context`,
/// [`RequestWriterInspector`] passes the request through without buffering
/// its body or handing it to its [`RequestWriter`].
#[derive(Default, Clone, Debug)]
#[non_exhaustive]
pub struct DoNotWriteRequest;
24
25impl DoNotWriteRequest {
26 pub const fn new() -> Self {
28 Self
29 }
30}
31
/// Request inspector that hands a fully buffered copy of each request to a
/// [`RequestWriter`] and then returns an equivalent request to the caller.
pub struct RequestWriterInspector<W> {
    /// Destination that receives the inspected requests.
    writer: W,
}
38
39impl<W> RequestWriterInspector<W> {
40 pub const fn new(writer: W) -> Self {
42 Self { writer }
43 }
44}
45
46impl<W> Debug for RequestWriterInspector<W> {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("RequestWriterInspector")
49 .field("writer", &format_args!("{}", std::any::type_name::<W>()))
50 .finish()
51 }
52}
53
54impl<W: Clone> Clone for RequestWriterInspector<W> {
55 fn clone(&self) -> Self {
56 Self {
57 writer: self.writer.clone(),
58 }
59 }
60}
61
62impl RequestWriterInspector<UnboundedSender<Request>> {
63 pub fn writer_unbounded<W>(executor: &Executor, mut writer: W, mode: Option<WriterMode>) -> Self
66 where
67 W: AsyncWrite + Unpin + Send + Sync + 'static,
68 {
69 let (tx, mut rx) = unbounded_channel();
70 let (write_headers, write_body) = match mode {
71 Some(WriterMode::All) => (true, true),
72 Some(WriterMode::Headers) => (true, false),
73 Some(WriterMode::Body) => (false, true),
74 None => (false, false),
75 };
76 executor.spawn_task(async move {
77 while let Some(req) = rx.recv().await {
78 if let Err(err) =
79 write_http_request(&mut writer, req, write_headers, write_body).await
80 {
81 tracing::error!(err = %err, "failed to write http request to writer")
82 }
83 if let Err(err) = writer.write_all(b"\r\n").await {
84 tracing::error!(err = %err, "failed to write separator to writer")
85 }
86 }
87 });
88 Self { writer: tx }
89 }
90
91 pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
94 Self::writer_unbounded(executor, stdout(), mode)
95 }
96
97 pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
100 Self::writer_unbounded(executor, stderr(), mode)
101 }
102}
103
104impl RequestWriterInspector<Sender<Request>> {
105 pub fn writer<W>(
108 executor: &Executor,
109 mut writer: W,
110 buffer_size: usize,
111 mode: Option<WriterMode>,
112 ) -> Self
113 where
114 W: AsyncWrite + Unpin + Send + Sync + 'static,
115 {
116 let (tx, mut rx) = channel(buffer_size);
117 let (write_headers, write_body) = match mode {
118 Some(WriterMode::All) => (true, true),
119 Some(WriterMode::Headers) => (true, false),
120 Some(WriterMode::Body) => (false, true),
121 None => (false, false),
122 };
123 executor.spawn_task(async move {
124 while let Some(req) = rx.recv().await {
125 if let Err(err) =
126 write_http_request(&mut writer, req, write_headers, write_body).await
127 {
128 tracing::error!(err = %err, "failed to write http request to writer")
129 }
130 if let Err(err) = writer.write_all(b"\r\n").await {
131 tracing::error!(err = %err, "failed to write separator to writer")
132 }
133 }
134 });
135 Self { writer: tx }
136 }
137
138 pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
141 Self::writer(executor, stdout(), buffer_size, mode)
142 }
143
144 pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
147 Self::writer(executor, stderr(), buffer_size, mode)
148 }
149}
150
151impl<State, W, ReqBody> Service<State, Request<ReqBody>> for RequestWriterInspector<W>
152where
153 State: Clone + Send + Sync + 'static,
154 W: RequestWriter,
155 ReqBody: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
156{
157 type Error = BoxError;
158 type Response = (Context<State>, Request);
159
160 async fn serve(
161 &self,
162 ctx: Context<State>,
163 req: Request<ReqBody>,
164 ) -> Result<(Context<State>, Request), Self::Error> {
165 let req = match ctx.get::<DoNotWriteRequest>() {
166 Some(_) => req.map(Body::new),
167 None => {
168 let (parts, body) = req.into_parts();
169 let body_bytes = body
170 .collect()
171 .await
172 .map_err(|err| {
173 OpaqueError::from_boxed(err.into())
174 .context("printer prepare: collect request body")
175 })?
176 .to_bytes();
177 let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone()));
178 self.writer.write_request(req).await;
179 Request::from_parts(parts, Body::from(body_bytes))
180 }
181 };
182 Ok((ctx, req))
183 }
184}
185
186impl RequestWriter for Sender<Request> {
187 async fn write_request(&self, req: Request) {
188 if let Err(err) = self.send(req).await {
189 tracing::error!(err = %err, "failed to send request to channel")
190 }
191 }
192}
193
194impl RequestWriter for UnboundedSender<Request> {
195 async fn write_request(&self, req: Request) {
196 if let Err(err) = self.send(req) {
197 tracing::error!(err = %err, "failed to send request to unbounded channel")
198 }
199 }
200}
201
202impl<F, Fut> RequestWriter for F
203where
204 F: Fn(Request) -> Fut + Send + Sync + 'static,
205 Fut: Future<Output = ()> + Send + 'static,
206{
207 async fn write_request(&self, req: Request) {
208 self(req).await
209 }
210}