xitca_web/middleware/
sync.rs1use core::mem;
4
5use std::sync::mpsc::Receiver;
6
7use tokio::sync::mpsc::UnboundedSender;
8
9use crate::{
10 context::WebContext,
11 http::{Request, RequestExt, Response},
12 service::Service,
13};
14
15pub struct SyncMiddleware<F>(F);
17
18impl<F> SyncMiddleware<F> {
19 pub fn new<C, E>(func: F) -> Self
26 where
27 F: Fn(&mut Next<E>, WebContext<'_, C>) -> Result<Response<()>, E> + Send + Sync + 'static,
28 C: Clone + Send + 'static,
29 E: Send + 'static,
30 {
31 Self(func)
32 }
33}
34
35pub struct Next<E> {
38 tx: UnboundedSender<Request<RequestExt<()>>>,
39 rx: Receiver<Result<Response<()>, E>>,
40}
41
42impl<E> Next<E> {
43 pub fn call<C>(&mut self, mut ctx: WebContext<'_, C>) -> Result<Response<()>, E> {
45 let req = mem::take(ctx.req_mut());
46 self.tx.send(req).unwrap();
47 self.rx.recv().unwrap()
48 }
49}
50
51impl<F, S, E> Service<Result<S, E>> for SyncMiddleware<F>
52where
53 F: Clone,
54{
55 type Response = service::SyncService<F, S>;
56 type Error = E;
57
58 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
59 res.map(|service| service::SyncService {
60 func: self.0.clone(),
61 service,
62 })
63 }
64}
65
66mod service {
67 use core::cell::RefCell;
68
69 use std::sync::mpsc::sync_channel;
70
71 use tokio::sync::mpsc::unbounded_channel;
72
73 use crate::{body::RequestBody, http::WebResponse, service::ready::ReadyService};
74
75 use super::*;
76
77 pub struct SyncService<F, S> {
78 pub(super) func: F,
79 pub(super) service: S,
80 }
81
82 impl<'r, F, C, S, B, ResB, Err> Service<WebContext<'r, C, B>> for SyncService<F, S>
83 where
84 F: Fn(&mut Next<Err>, WebContext<'_, C>) -> Result<Response<()>, Err> + Send + Clone + 'static,
85 C: Clone + Send + 'static,
86 S: for<'r2> Service<WebContext<'r, C, B>, Response = WebResponse<ResB>, Error = Err>,
87 Err: Send + 'static,
88 {
89 type Response = WebResponse<ResB>;
90 type Error = Err;
91
92 async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
93 let func = self.func.clone();
94 let state = ctx.state().clone();
95 let mut req = mem::take(ctx.req_mut());
96
97 let (tx, mut rx) = unbounded_channel();
98 let (tx2, rx2) = sync_channel(1);
99
100 let mut next = Next { tx, rx: rx2 };
101 let handle = tokio::task::spawn_blocking(move || {
102 let mut body = RefCell::new(RequestBody::None);
103 let ctx = WebContext::new(&mut req, &mut body, &state);
104 func(&mut next, ctx)
105 });
106
107 *ctx.req_mut() = match rx.recv().await {
108 Some(req) => req,
109 None => {
110 match handle.await.unwrap() {
112 Ok(_) => todo!("there is no support for body type yet"),
113 Err(e) => return Err(e),
114 }
115 }
116 };
117
118 match self.service.call(ctx).await {
119 Ok(res) => {
120 let (parts, body) = res.into_parts();
121 let _ = tx2.send(Ok(Response::from_parts(parts, ())));
122 let res = handle.await.unwrap()?;
123 Ok(res.map(|_| body))
124 }
125 Err(e) => {
126 let _ = tx2.send(Err(e));
127 let res = handle.await.unwrap()?;
128 Ok(res.map(|_| todo!("there is no support for body type yet")))
129 }
130 }
131 }
132 }
133
134 impl<F, S> ReadyService for SyncService<F, S>
135 where
136 S: ReadyService,
137 {
138 type Ready = S::Ready;
139
140 #[inline]
141 async fn ready(&self) -> Self::Ready {
142 self.service.ready().await
143 }
144 }
145}
146
#[cfg(test)]
mod test {
    use core::convert::Infallible;

    use crate::{
        App,
        body::ResponseBody,
        http::{StatusCode, WebResponse},
        service::fn_service,
    };

    use super::*;

    /// Inner handler: checks that the app state arrived intact and replies with an empty body.
    async fn handler(req: WebContext<'_, &'static str>) -> Result<WebResponse, Infallible> {
        assert_eq!(*req.state(), "996");
        Ok(req.into_response(ResponseBody::empty()))
    }

    /// Pass-through sync middleware: forwards the request straight to the inner service.
    fn middleware<E>(
        next: &mut Next<E>,
        ctx: WebContext<'_, &'static str>,
    ) -> Result<Response<()>, E> {
        next.call(ctx)
    }

    #[tokio::test]
    async fn sync_middleware() {
        let factory = App::new()
            .with_state("996")
            .at("/", fn_service(handler))
            .enclosed(SyncMiddleware::new(middleware))
            .finish();

        let service = factory.call(()).await.unwrap();
        let res = service.call(Request::default()).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }
}
185}