1use super::WriterMode;
2use crate::dep::http_body;
3use crate::dep::http_body_util::BodyExt;
4use crate::io::write_http_response;
5use crate::{Body, Request, Response};
6use bytes::Bytes;
7use rama_core::error::{BoxError, ErrorContext, OpaqueError};
8use rama_core::rt::Executor;
9use rama_core::{Context, Layer, Service};
10use rama_utils::macros::define_inner_service_accessors;
11use std::fmt::Debug;
12use tokio::io::{AsyncWrite, stderr, stdout};
13use tokio::sync::mpsc::{Sender, UnboundedSender, channel, unbounded_channel};
14
/// A [`Layer`] that wraps a service so that, unless the request context
/// contains [`DoNotWriteResponse`], a copy of every response it produces is
/// handed to a [`ResponseWriter`] before being returned to the caller.
///
/// The writer is cloned into each service built by [`Layer::layer`].
pub struct ResponseWriterLayer<W> {
    /// Destination that receives a copy of each response.
    writer: W,
}
19
20impl<W> Debug for ResponseWriterLayer<W> {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("ResponseWriterLayer")
23 .field("writer", &format_args!("{}", std::any::type_name::<W>()))
24 .finish()
25 }
26}
27
28impl<W: Clone> Clone for ResponseWriterLayer<W> {
29 fn clone(&self) -> Self {
30 Self {
31 writer: self.writer.clone(),
32 }
33 }
34}
35
36impl<W> ResponseWriterLayer<W> {
37 pub const fn new(writer: W) -> Self {
39 Self { writer }
40 }
41}
42
/// A destination that receives copies of responses observed by
/// [`ResponseWriterService`].
///
/// This module implements it for bounded and unbounded `tokio::sync::mpsc`
/// senders of [`Response`], and for any `Fn(Response) -> Fut` closure whose
/// future resolves to `()`.
pub trait ResponseWriter: Send + Sync + 'static {
    /// Consumes `res` and hands it to the writer.
    ///
    /// The future resolves to `()`, so failures cannot be reported back to the
    /// caller; any error handling must happen inside the implementation.
    fn write_response(&self, res: Response) -> impl Future<Output = ()> + Send + '_;
}
48
/// Marker value: when present in a request's context, the response for that
/// request is returned as-is and not handed to the [`ResponseWriter`].
#[non_exhaustive]
#[derive(Default, Clone, Debug)]
pub struct DoNotWriteResponse;
53
54impl DoNotWriteResponse {
55 pub const fn new() -> Self {
57 Self
58 }
59}
60
61impl ResponseWriterLayer<UnboundedSender<Response>> {
62 pub fn writer_unbounded<W>(executor: &Executor, mut writer: W, mode: Option<WriterMode>) -> Self
65 where
66 W: AsyncWrite + Unpin + Send + Sync + 'static,
67 {
68 let (tx, mut rx) = unbounded_channel();
69 let (write_headers, write_body) = match mode {
70 Some(WriterMode::All) => (true, true),
71 Some(WriterMode::Headers) => (true, false),
72 Some(WriterMode::Body) => (false, true),
73 None => (false, false),
74 };
75 executor.spawn_task(async move {
76 while let Some(res) = rx.recv().await {
77 if let Err(err) =
78 write_http_response(&mut writer, res, write_headers, write_body).await
79 {
80 tracing::error!(err = %err, "failed to write http response to writer")
81 }
82 }
83 });
84 Self { writer: tx }
85 }
86
87 pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
90 Self::writer_unbounded(executor, stdout(), mode)
91 }
92
93 pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>) -> Self {
96 Self::writer_unbounded(executor, stderr(), mode)
97 }
98}
99
100impl ResponseWriterLayer<Sender<Response>> {
101 pub fn writer<W>(
104 executor: &Executor,
105 mut writer: W,
106 buffer_size: usize,
107 mode: Option<WriterMode>,
108 ) -> Self
109 where
110 W: AsyncWrite + Unpin + Send + Sync + 'static,
111 {
112 let (tx, mut rx) = channel(buffer_size);
113 let (write_headers, write_body) = match mode {
114 Some(WriterMode::All) => (true, true),
115 Some(WriterMode::Headers) => (true, false),
116 Some(WriterMode::Body) => (false, true),
117 None => (false, false),
118 };
119 executor.spawn_task(async move {
120 while let Some(res) = rx.recv().await {
121 if let Err(err) =
122 write_http_response(&mut writer, res, write_headers, write_body).await
123 {
124 tracing::error!(err = %err, "failed to write http response to writer")
125 }
126 }
127 });
128 Self { writer: tx }
129 }
130
131 pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
134 Self::writer(executor, stdout(), buffer_size, mode)
135 }
136
137 pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option<WriterMode>) -> Self {
140 Self::writer(executor, stderr(), buffer_size, mode)
141 }
142}
143
144impl<S, W: Clone> Layer<S> for ResponseWriterLayer<W> {
145 type Service = ResponseWriterService<S, W>;
146
147 fn layer(&self, inner: S) -> Self::Service {
148 ResponseWriterService {
149 inner,
150 writer: self.writer.clone(),
151 }
152 }
153
154 fn into_layer(self, inner: S) -> Self::Service {
155 ResponseWriterService {
156 inner,
157 writer: self.writer,
158 }
159 }
160}
161
/// Service built by [`ResponseWriterLayer`]: it calls the inner service and,
/// unless the request context contains [`DoNotWriteResponse`], passes a copy
/// of the response to its writer before returning it.
pub struct ResponseWriterService<S, W> {
    /// The wrapped service.
    inner: S,
    /// Destination for response copies.
    writer: W,
}
169
170impl<S, W> ResponseWriterService<S, W> {
171 pub const fn new(writer: W, inner: S) -> Self {
173 Self { inner, writer }
174 }
175
176 define_inner_service_accessors!();
177}
178
179impl<S: Debug, W> Debug for ResponseWriterService<S, W> {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("ResponseWriterService")
182 .field("inner", &self.inner)
183 .field("writer", &format_args!("{}", std::any::type_name::<W>()))
184 .finish()
185 }
186}
187
188impl<S: Clone, W: Clone> Clone for ResponseWriterService<S, W> {
189 fn clone(&self) -> Self {
190 Self {
191 inner: self.inner.clone(),
192 writer: self.writer.clone(),
193 }
194 }
195}
196
197impl<S> ResponseWriterService<S, UnboundedSender<Response>> {
198 pub fn writer_unbounded<W>(
201 executor: &Executor,
202 writer: W,
203 mode: Option<WriterMode>,
204 inner: S,
205 ) -> Self
206 where
207 W: AsyncWrite + Unpin + Send + Sync + 'static,
208 {
209 let layer = ResponseWriterLayer::writer_unbounded(executor, writer, mode);
210 layer.into_layer(inner)
211 }
212
213 pub fn stdout_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
216 Self::writer_unbounded(executor, stdout(), mode, inner)
217 }
218
219 pub fn stderr_unbounded(executor: &Executor, mode: Option<WriterMode>, inner: S) -> Self {
222 Self::writer_unbounded(executor, stderr(), mode, inner)
223 }
224}
225
226impl<S> ResponseWriterService<S, Sender<Response>> {
227 pub fn writer<W>(
230 executor: &Executor,
231 writer: W,
232 buffer_size: usize,
233 mode: Option<WriterMode>,
234 inner: S,
235 ) -> Self
236 where
237 W: AsyncWrite + Unpin + Send + Sync + 'static,
238 {
239 let layer = ResponseWriterLayer::writer(executor, writer, buffer_size, mode);
240 layer.into_layer(inner)
241 }
242
243 pub fn stdout(
246 executor: &Executor,
247 buffer_size: usize,
248 mode: Option<WriterMode>,
249 inner: S,
250 ) -> Self {
251 Self::writer(executor, stdout(), buffer_size, mode, inner)
252 }
253
254 pub fn stderr(
257 executor: &Executor,
258 buffer_size: usize,
259 mode: Option<WriterMode>,
260 inner: S,
261 ) -> Self {
262 Self::writer(executor, stderr(), buffer_size, mode, inner)
263 }
264}
265
266impl<S, W> ResponseWriterService<S, W> {}
267
268impl<State, S, W, ReqBody, ResBody> Service<State, Request<ReqBody>> for ResponseWriterService<S, W>
269where
270 State: Clone + Send + Sync + 'static,
271 S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: Into<BoxError>>,
272 W: ResponseWriter,
273 ReqBody: Send + 'static,
274 ResBody: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
275{
276 type Response = Response;
277 type Error = BoxError;
278
279 async fn serve(
280 &self,
281 ctx: Context<State>,
282 req: Request<ReqBody>,
283 ) -> Result<Self::Response, Self::Error> {
284 let do_not_print_response: Option<DoNotWriteResponse> = ctx.get().cloned();
285 let resp = self.inner.serve(ctx, req).await.map_err(Into::into)?;
286 let resp = match do_not_print_response {
287 Some(_) => resp.map(Body::new),
288 None => {
289 let (parts, body) = resp.into_parts();
290 let body_bytes = body
291 .collect()
292 .await
293 .map_err(|err| OpaqueError::from_boxed(err.into()))
294 .context("printer prepare: collect response body")?
295 .to_bytes();
296 let resp: rama_http_types::Response<Body> =
297 Response::from_parts(parts.clone(), Body::from(body_bytes.clone()));
298 self.writer.write_response(resp).await;
299 Response::from_parts(parts, Body::from(body_bytes))
300 }
301 };
302 Ok(resp)
303 }
304}
305
306impl ResponseWriter for Sender<Response> {
307 async fn write_response(&self, res: Response) {
308 if let Err(err) = self.send(res).await {
309 tracing::error!(err = %err, "failed to send response to channel")
310 }
311 }
312}
313
314impl ResponseWriter for UnboundedSender<Response> {
315 async fn write_response(&self, res: Response) {
316 if let Err(err) = self.send(res) {
317 tracing::error!(err = %err, "failed to send response to unbounded channel")
318 }
319 }
320}
321
322impl<F, Fut> ResponseWriter for F
323where
324 F: Fn(Response) -> Fut + Send + Sync + 'static,
325 Fut: Future<Output = ()> + Send + 'static,
326{
327 async fn write_response(&self, res: Response) {
328 self(res).await
329 }
330}