1use super::WriterMode;
2use crate::dep::http_body;
3use crate::dep::http_body_util::BodyExt;
4use crate::io::write_http_request;
5use crate::{Body, Request, Response};
6use bytes::Bytes;
7use rama_core::error::{BoxError, ErrorExt, OpaqueError};
8use rama_core::rt::Executor;
9use rama_core::{Context, Layer, Service};
10use rama_utils::macros::define_inner_service_accessors;
11use std::fmt::Debug;
12use std::future::Future;
13use tokio::io::{stderr, stdout, AsyncWrite, AsyncWriteExt};
14use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender};
15
/// Layer that wraps services in a [`RequestWriterService`], giving each of them
/// a clone of the configured request writer.
#[derive(Clone)]
pub struct RequestWriterLayer<W> {
    writer: W,
}

impl<W> Debug for RequestWriterLayer<W> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `W` is not required to implement `Debug`, so only its type name is shown.
        let writer_type = std::any::type_name::<W>();
        f.debug_struct("RequestWriterLayer")
            .field("writer", &format_args!("{}", writer_type))
            .finish()
    }
}

impl<W> RequestWriterLayer<W> {
    /// Creates a new [`RequestWriterLayer`] around the given `writer`.
    pub const fn new(writer: W) -> Self {
        Self { writer }
    }
}
43
44pub trait RequestWriter: Send + Sync + 'static {
46 fn write_request(&self, req: Request) -> impl Future<Output = ()> + Send + '_;
48}
49
/// Marker type: when it is present in the request `Context`,
/// [`RequestWriterService`] forwards the request without writing it.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct DoNotWriteRequest;

impl DoNotWriteRequest {
    /// Creates a new [`DoNotWriteRequest`] marker.
    pub const fn new() -> Self {
        DoNotWriteRequest {}
    }
}
61
62impl RequestWriterLayer<UnboundedSender<Request>> {
63 pub fn writer_unbounded<W>(executor: &Executor, mut writer: W, mode: Option<WriterMode>) -> Self
66 where
67 W: AsyncWrite + Unpin + Send + Sync + 'static,
68 {
69 let (tx, mut rx) = unbounded_channel();
70 let (write_headers, write_body) = match mode {
71 Some(WriterMode::All) => (true, true),
72 Some(WriterMode::Headers) => (true, false),
73 Some(WriterMode::Body) => (false, true),
74 None => (false, false),
75 };
76 executor.spawn_task(async move {
77 while let Some(req) = rx.recv().await {
78 if let Err(err) =
79 write_http_request(&mut writer, req, write_headers, write_body).await
80 {
81 tracing::error!(err = %err, "failed to write http request to writer")
82 }
83 if let Err(err) = writer.write_all(b"\r\n").await {
84 tracing::error!(err = %err, "failed to write separator to writer")
85 }
86 }
87 });
88 Self { writer: tx }
89 }
90
91 pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
94 Self::writer_unbounded(executor, stdout(), mode)
95 }
96
97 pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
100 Self::writer_unbounded(executor, stderr(), mode)
101 }
102}
103
104impl RequestWriterLayer<Sender<Request>> {
105 pub fn writer<W>(
108 executor: &Executor,
109 mut writer: W,
110 buffer_size: usize,
111 mode: Option<WriterMode>,
112 ) -> Self
113 where
114 W: AsyncWrite + Unpin + Send + Sync + 'static,
115 {
116 let (tx, mut rx) = channel(buffer_size);
117 let (write_headers, write_body) = match mode {
118 Some(WriterMode::All) => (true, true),
119 Some(WriterMode::Headers) => (true, false),
120 Some(WriterMode::Body) => (false, true),
121 None => (false, false),
122 };
123 executor.spawn_task(async move {
124 while let Some(req) = rx.recv().await {
125 if let Err(err) =
126 write_http_request(&mut writer, req, write_headers, write_body).await
127 {
128 tracing::error!(err = %err, "failed to write http request to writer")
129 }
130 if let Err(err) = writer.write_all(b"\r\n").await {
131 tracing::error!(err = %err, "failed to write separator to writer")
132 }
133 }
134 });
135 Self { writer: tx }
136 }
137
138 pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
141 Self::writer(executor, stdout(), buffer_size, mode)
142 }
143
144 pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
147 Self::writer(executor, stderr(), buffer_size, mode)
148 }
149}
150
151impl<S, W: Clone> Layer<S> for RequestWriterLayer<W> {
152 type Service = RequestWriterService<S, W>;
153
154 fn layer(&self, inner: S) -> Self::Service {
155 RequestWriterService {
156 inner,
157 writer: self.writer.clone(),
158 }
159 }
160}
161
162pub struct RequestWriterService<S, W> {
166 inner: S,
167 writer: W,
168}
169
170impl<S, W> RequestWriterService<S, W> {
171 pub const fn new(writer: W, inner: S) -> Self {
173 Self { inner, writer }
174 }
175
176 define_inner_service_accessors!();
177}
178
179impl<S: Debug, W> Debug for RequestWriterService<S, W> {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("RequestWriterService")
182 .field("inner", &self.inner)
183 .field("writer", &format_args!("{}", std::any::type_name::<W>()))
184 .finish()
185 }
186}
187
188impl<S: Clone, W: Clone> Clone for RequestWriterService<S, W> {
189 fn clone(&self) -> Self {
190 Self {
191 inner: self.inner.clone(),
192 writer: self.writer.clone(),
193 }
194 }
195}
196
197impl<S> RequestWriterService<S, UnboundedSender<Request>> {
198 pub fn writer_unbounded<W>(
201 executor: &Executor,
202 writer: W,
203 mode: Option<WriterMode>,
204 inner: S,
205 ) -> Self
206 where
207 W: AsyncWrite + Unpin + Send + Sync + 'static,
208 {
209 let layer = RequestWriterLayer::writer_unbounded(executor, writer, mode);
210 layer.layer(inner)
211 }
212
213 pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
216 Self::writer_unbounded(executor, stdout(), mode, inner)
217 }
218
219 pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
222 Self::writer_unbounded(executor, stderr(), mode, inner)
223 }
224}
225
226impl<S> RequestWriterService<S, Sender<Request>> {
227 pub fn writer<W>(
230 executor: &Executor,
231 writer: W,
232 buffer_size: usize,
233 mode: Option<WriterMode>,
234 inner: S,
235 ) -> Self
236 where
237 W: AsyncWrite + Unpin + Send + Sync + 'static,
238 {
239 let layer = RequestWriterLayer::writer(executor, writer, buffer_size, mode);
240 layer.layer(inner)
241 }
242
243 pub fn stdout(
246 executor: &Executor,
247 buffer_size: usize,
248 mode: Option<WriterMode>,
249 inner: S,
250 ) -> Self {
251 Self::writer(executor, stdout(), buffer_size, mode, inner)
252 }
253
254 pub fn stderr(
257 executor: &Executor,
258 buffer_size: usize,
259 mode: Option<WriterMode>,
260 inner: S,
261 ) -> Self {
262 Self::writer(executor, stderr(), buffer_size, mode, inner)
263 }
264}
265
266impl<S, W> RequestWriterService<S, W> {}
267
268impl<State, S, W, ReqBody, ResBody> Service<State, Request<ReqBody>> for RequestWriterService<S, W>
269where
270 State: Clone + Send + Sync + 'static,
271 S: Service<State, Request, Response = Response<ResBody>, Error: Into<BoxError>>,
272 W: RequestWriter,
273 ReqBody: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
274 ResBody: Send + 'static,
275{
276 type Response = Response<ResBody>;
277 type Error = BoxError;
278
279 async fn serve(
280 &self,
281 ctx: Context<State>,
282 req: Request<ReqBody>,
283 ) -> Result<Self::Response, Self::Error> {
284 let req = match ctx.get::<DoNotWriteRequest>() {
285 Some(_) => req.map(Body::new),
286 None => {
287 let (parts, body) = req.into_parts();
288 let body_bytes = body
289 .collect()
290 .await
291 .map_err(|err| {
292 OpaqueError::from_boxed(err.into())
293 .context("printer prepare: collect request body")
294 })?
295 .to_bytes();
296 let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone()));
297 self.writer.write_request(req).await;
298 Request::from_parts(parts, Body::from(body_bytes))
299 }
300 };
301 self.inner.serve(ctx, req).await.map_err(Into::into)
302 }
303}
304
305impl RequestWriter for Sender<Request> {
306 async fn write_request(&self, req: Request) {
307 if let Err(err) = self.send(req).await {
308 tracing::error!(err = %err, "failed to send request to channel")
309 }
310 }
311}
312
313impl RequestWriter for UnboundedSender<Request> {
314 async fn write_request(&self, req: Request) {
315 if let Err(err) = self.send(req) {
316 tracing::error!(err = %err, "failed to send request to unbounded channel")
317 }
318 }
319}
320
321impl<F, Fut> RequestWriter for F
322where
323 F: Fn(Request) -> Fut + Send + Sync + 'static,
324 Fut: Future<Output = ()> + Send + 'static,
325{
326 async fn write_request(&self, req: Request) {
327 self(req).await
328 }
329}