1#![warn(missing_docs)]
2
3use std::{future::Future, task::Poll};
9
10use hyper::{body::HttpBody, Body, Request, Response};
11use pin_project::pin_project;
12use tower::Service;
13
14pub use make::MakeMultiplexer;
15mod make;
16
/// A [`Service`] that routes each request to one of two inner services.
///
/// Requests whose `content-type` header starts with `application/grpc` are
/// sent to the `grpc` service. Every other request, including one with no
/// `content-type` header, goes to the `web` service. Response bodies from
/// either side are wrapped in [`EncapsulatedBody`] and errors are boxed into a
/// `Box<dyn std::error::Error + Send + Sync>`.
pub struct Multiplexer<Grpc, Web> {
    /// Service that receives requests with a gRPC `content-type`.
    grpc: Grpc,
    /// Service that receives every other request.
    web: Web,
}
84impl<Grpc, Web> Multiplexer<Grpc, Web>
85where
86 Grpc: Service<Request<Body>>,
87 Web: Service<Request<Body>>,
88{
89 pub fn new(grpc: Grpc, web: Web) -> Self {
91 Multiplexer { grpc, web }
92 }
93}
/// Type-erased, thread-safe error used by the `Service`, `Future` and
/// `HttpBody` impls in this file.
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Converts any error that can become a [`BoxedError`].
///
/// It is a named function so it can be passed point-free to `map_err`.
fn to_boxed<E>(err: E) -> BoxedError
where
    E: Into<BoxedError>,
{
    Into::into(err)
}
98impl<Grpc, Web, GrpcBody, WebBody> Service<Request<Body>> for Multiplexer<Grpc, Web>
99where
100 Grpc: Service<Request<Body>, Response = Response<GrpcBody>>,
102 Web: Service<Request<Body>, Response = Response<WebBody>>,
103 GrpcBody: HttpBody,
104 WebBody: HttpBody,
105 Grpc::Error: Into<BoxedError>,
107 Web::Error: Into<BoxedError>,
108{
109 type Response = Response<EncapsulatedBody<GrpcBody, WebBody>>;
110 type Error = BoxedError;
112 type Future = EncapsulatedFuture<Grpc::Future, Web::Future>;
113
114 fn poll_ready(
117 &mut self,
118 cx: &mut std::task::Context<'_>,
119 ) -> std::task::Poll<Result<(), Self::Error>> {
120 let grpc = self.grpc.poll_ready(cx).map_err(to_boxed)?;
122 let web = self.web.poll_ready(cx).map_err(to_boxed)?;
123 match (grpc, web) {
124 (Poll::Ready(_), Poll::Ready(_)) => Poll::Ready(Ok(())),
125 _ => Poll::Pending,
126 }
127 }
128
129 fn call(&mut self, req: Request<Body>) -> Self::Future {
130 let is_grpc = req
131 .headers()
132 .get("content-type")
133 .map(|x| x.as_bytes().starts_with(b"application/grpc"))
134 .unwrap_or_default();
135 if is_grpc {
136 EncapsulatedFuture::Grpc(self.grpc.call(req))
137 } else {
138 EncapsulatedFuture::Web(self.web.call(req))
139 }
140 }
141}
142
143#[pin_project(project = EncapsulatedProj)]
149pub enum EncapsulatedFuture<GrpcFuture, WebFuture> {
150 Grpc(#[pin] GrpcFuture),
152 Web(#[pin] WebFuture),
154}
155impl<GrpcFuture, WebFuture, GrpcResponseBody, WebResponseBody, GrpcError, WebError> Future
160 for EncapsulatedFuture<GrpcFuture, WebFuture>
161where
162 GrpcFuture: Future<Output = Result<Response<GrpcResponseBody>, GrpcError>>,
163 WebFuture: Future<Output = Result<Response<WebResponseBody>, WebError>>,
164 GrpcError: Into<BoxedError>,
165 WebError: Into<BoxedError>,
166{
167 type Output = Result<Response<EncapsulatedBody<GrpcResponseBody, WebResponseBody>>, BoxedError>;
169
170 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
171 match self.project() {
172 EncapsulatedProj::Grpc(future) => future
173 .poll(cx)
174 .map_ok(EncapsulatedBody::map_grpc)
175 .map_err(to_boxed),
176 EncapsulatedProj::Web(future) => future
177 .poll(cx)
178 .map_ok(EncapsulatedBody::map_web)
179 .map_err(to_boxed),
180 }
181 }
182}
183
184#[pin_project(project = BodyProj)]
190pub enum EncapsulatedBody<GrpcBody, WebBody> {
191 Grpc(#[pin] GrpcBody),
193 Web(#[pin] WebBody),
195}
196impl<GrpcBody, WebBody> EncapsulatedBody<GrpcBody, WebBody> {
197 fn map_grpc(response: Response<GrpcBody>) -> Response<Self> {
198 response.map(EncapsulatedBody::Grpc)
199 }
200 fn map_web(response: Response<WebBody>) -> Response<Self> {
201 response.map(EncapsulatedBody::Web)
202 }
203}
204fn into_data<T: Into<hyper::body::Bytes>>(data: T) -> hyper::body::Bytes {
205 data.into()
206}
207impl<GrpcBody, WebBody, GrpcError, WebError> HttpBody for EncapsulatedBody<GrpcBody, WebBody>
208where
209 GrpcBody: HttpBody<Error = GrpcError>,
210 WebBody: HttpBody<Error = WebError>,
211 GrpcBody::Error: Into<BoxedError>,
212 WebBody::Error: Into<BoxedError>,
213 GrpcBody::Data: Into<hyper::body::Bytes>,
214 WebBody::Data: Into<hyper::body::Bytes>,
215{
216 type Data = hyper::body::Bytes;
217
218 type Error = BoxedError;
219
220 fn poll_data(
221 self: std::pin::Pin<&mut Self>,
222 cx: &mut std::task::Context<'_>,
223 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
224 match self.project() {
225 BodyProj::Grpc(body) => body.poll_data(cx).map_ok(into_data).map_err(to_boxed),
226 BodyProj::Web(body) => body.poll_data(cx).map_ok(into_data).map_err(to_boxed),
227 }
228 }
229
230 fn poll_trailers(
231 self: std::pin::Pin<&mut Self>,
232 cx: &mut std::task::Context<'_>,
233 ) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
234 match self.project() {
235 BodyProj::Grpc(body) => body.poll_trailers(cx).map_err(to_boxed),
236 BodyProj::Web(body) => body.poll_trailers(cx).map_err(to_boxed),
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use std::{convert::Infallible, future::ready};

    use crate::{EncapsulatedBody, Multiplexer};
    use hyper::{
        body::HttpBody, header::CONTENT_TYPE, service::service_fn, Body, HeaderMap, Request,
        Response,
    };
    use tower::{Service, ServiceExt};

    /// Builds a service that answers every request with `text` as the response body.
    fn text_service(
        text: &'static str,
    ) -> impl Service<Request<Body>, Response = Response<Body>, Error = Infallible> {
        service_fn(move |_req: Request<Body>| {
            ready(Ok::<_, Infallible>(Response::new(Body::from(text))))
        })
    }

    /// Builds a trailer map holding a single `From` entry with the given value.
    fn trailers_from(sender_name: &str) -> HeaderMap {
        let mut header_map = HeaderMap::new();
        header_map.insert("From", sender_name.parse().unwrap());
        header_map
    }

    #[test]
    fn new_multiplex_receives_two_services() {
        let _multiplex = Multiplexer::new(text_service("Service 1"), text_service("Service 2"));
    }

    #[tokio::test]
    async fn new_multiplex_is_ready() {
        let mut multiplex =
            Multiplexer::new(text_service("gRPC service"), text_service("web service"));

        multiplex.ready().await.unwrap();
    }

    #[tokio::test]
    async fn multiplexer_request_to_web() {
        let mut multiplex =
            Multiplexer::new(text_service("gRPC service"), text_service("web service"));

        // No content-type: the request must reach the web service.
        multiplex.ready().await.unwrap();
        {
            let request = Request::new(Body::empty());
            let response = multiplex.call(request).await.unwrap();
            let content = hyper::body::to_bytes(response.into_body()).await.unwrap();

            assert_eq!(content, "web service");
        }

        // gRPC content-type: the request must reach the gRPC service.
        multiplex.ready().await.unwrap();
        {
            let request = Request::builder()
                .header(CONTENT_TYPE, "application/grpc")
                .body(Body::empty())
                .unwrap();
            let response = multiplex.call(request).await.unwrap();
            let content = hyper::body::to_bytes(response.into_body()).await.unwrap();

            assert_eq!(content, "gRPC service");
        }
    }

    #[tokio::test]
    async fn encapsulated_body_poll_data_grpc() {
        let string = "body grpc";
        let body = EncapsulatedBody::<Body, Body>::Grpc(Body::from(string));

        let data = hyper::body::to_bytes(body).await.unwrap();
        assert_eq!(data, string);
    }

    #[tokio::test]
    async fn encapsulated_body_poll_data_web() {
        let string = "body web";
        // Use the `Web` variant so the web arm of `poll_data` is exercised.
        let body = EncapsulatedBody::<Body, Body>::Web(Body::from(string));

        let data = hyper::body::to_bytes(body).await.unwrap();
        assert_eq!(data, string);
    }

    #[tokio::test]
    async fn encapsulated_body_poll_trailers_grpc() {
        let (mut sender, body) = Body::channel();
        let header_map = trailers_from("grpc sender");
        sender.send_trailers(header_map.clone()).await.unwrap();

        let mut body = EncapsulatedBody::<Body, Body>::Grpc(body);

        let headers = body
            .trailers()
            .await
            .unwrap()
            .expect("Should return trailers!");
        assert_eq!(headers, header_map);
    }

    #[tokio::test]
    async fn encapsulated_body_poll_trailers_web() {
        let (mut sender, body) = Body::channel();
        let header_map = trailers_from("web sender");
        sender.send_trailers(header_map.clone()).await.unwrap();

        let mut body = EncapsulatedBody::<Body, Body>::Web(body);

        let headers = body
            .trailers()
            .await
            .unwrap()
            .expect("Should return trailers!");
        assert_eq!(headers, header_map);
    }
}