multiplex_tonic_hyper/
lib.rs

1#![warn(missing_docs)]
2
//! Crate to route requests between a tonic gRPC service and some other service.
//!
//! The [Multiplexer] struct implements `Service<Request<Body>>` and routes
//! requests based on their `Content-Type` header.
7
8use std::{future::Future, task::Poll};
9
10use hyper::{body::HttpBody, Body, Request, Response};
11use pin_project::pin_project;
12use tower::Service;
13
14pub use make::MakeMultiplexer;
15mod make;
16
/// Service that routes each request either to a gRPC service or to another (web) service.
///
/// The `Content-Type` header decides the destination. A request whose content
/// type starts with `application/grpc` (e.g. `application/grpc+proto`) goes to the
/// gRPC service. Every other request, including one without the header, goes to
/// the web service.
///
/// # Examples
///
/// Routing to the web service:
/// ```
/// # async fn run() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
/// # use std::convert::Infallible;
/// # use multiplex_tonic_hyper::Multiplexer;
/// use hyper::{service::service_fn, Body, Request, Response};
/// use tower::{Service, ServiceExt};
/// async fn str_to_res(str: &'static str) -> Result<Response<Body>, Infallible> {
///     Ok(Response::new(Body::from(str)))
/// }
///
/// // Services that answer every request with a single word
/// let grpc = service_fn(|_| str_to_res("gRPC"));
/// let web = service_fn(|_| str_to_res("web"));
///
/// let mut multiplex = Multiplexer::new(grpc, web);
/// // A tower service must report readiness before `call`; see [tower::Service]
/// multiplex.ready().await?;
/// // Request without a content-type header: handled by the web service
/// let response = multiplex.call(Request::new(Body::empty())).await?;
/// let content = hyper::body::to_bytes(response.into_body()).await?;
/// assert_eq!(content, "web");
/// # Ok(())
/// # }
/// # tokio_test::block_on(run()).unwrap();
/// ```
///
/// Routing to the gRPC service:
/// ```
/// # async fn run() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
/// # use std::convert::Infallible;
/// # use hyper::{header::CONTENT_TYPE, service::service_fn, Body, Request, Response};
/// # use multiplex_tonic_hyper::Multiplexer;
/// # use tower::{Service, ServiceExt};
/// # async fn str_to_res(str: &'static str) -> Result<Response<Body>, Infallible> {
/// #     Ok(Response::new(Body::from(str)))
/// # }
/// // ...
/// let grpc = service_fn(|_| str_to_res("gRPC"));
/// let web = service_fn(|_| str_to_res("web"));
///
/// let mut multiplex = Multiplexer::new(grpc, web);
/// // A tower service must report readiness before `call`; see [tower::Service]
/// multiplex.ready().await?;
/// // Request with a gRPC content-type header: handled by the gRPC service
/// let request = Request::builder()
///     .header(CONTENT_TYPE, "application/grpc")
///     .body(Body::empty())?;
/// let response = multiplex.call(request).await?;
/// let content = hyper::body::to_bytes(response.into_body()).await?;
/// assert_eq!(content, "gRPC");
/// # Ok(())
/// # }
/// # tokio_test::block_on(run()).unwrap();
/// ```
pub struct Multiplexer<Grpc, Web> {
	/// Receives requests whose `Content-Type` starts with `application/grpc`.
	grpc: Grpc,
	/// Receives every other request.
	web: Web,
}
84impl<Grpc, Web> Multiplexer<Grpc, Web>
85where
86	Grpc: Service<Request<Body>>,
87	Web: Service<Request<Body>>,
88{
89	///This function consumes two Services, and returns a Multiplexer
90	pub fn new(grpc: Grpc, web: Web) -> Self {
91		Multiplexer { grpc, web }
92	}
93}
/// Type-erased error used by the multiplexer, its future and its body.
///
/// It is `Send + Sync + 'static`, so it can be moved between threads.
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Converts any error that is convertible into [`BoxedError`].
///
/// This is a named function so it can be passed directly to `map_err`.
fn to_boxed<E>(err: E) -> BoxedError
where
	E: Into<BoxedError>,
{
	Into::<BoxedError>::into(err)
}
98impl<Grpc, Web, GrpcBody, WebBody> Service<Request<Body>> for Multiplexer<Grpc, Web>
99where
100	//Each type is a Service<> with its own Body type
101	Grpc: Service<Request<Body>, Response = Response<GrpcBody>>,
102	Web: Service<Request<Body>, Response = Response<WebBody>>,
103	GrpcBody: HttpBody,
104	WebBody: HttpBody,
105	//Inner errors can be converted to our error type
106	Grpc::Error: Into<BoxedError>,
107	Web::Error: Into<BoxedError>,
108{
109	type Response = Response<EncapsulatedBody<GrpcBody, WebBody>>;
110	///Generic error that can be moved between threads
111	type Error = BoxedError;
112	type Future = EncapsulatedFuture<Grpc::Future, Web::Future>;
113
114	///Call inner services poll_ready, and propagate errors.
115	/// Only is ready if both are ready.
116	fn poll_ready(
117		&mut self,
118		cx: &mut std::task::Context<'_>,
119	) -> std::task::Poll<Result<(), Self::Error>> {
120		//There is no problem in calling poll_ready if is Ready, and the docs don't have any limitation on pending
121		let grpc = self.grpc.poll_ready(cx).map_err(to_boxed)?;
122		let web = self.web.poll_ready(cx).map_err(to_boxed)?;
123		match (grpc, web) {
124			(Poll::Ready(_), Poll::Ready(_)) => Poll::Ready(Ok(())),
125			_ => Poll::Pending,
126		}
127	}
128
129	fn call(&mut self, req: Request<Body>) -> Self::Future {
130		let is_grpc = req
131			.headers()
132			.get("content-type")
133			.map(|x| x.as_bytes().starts_with(b"application/grpc"))
134			.unwrap_or_default();
135		if is_grpc {
136			EncapsulatedFuture::Grpc(self.grpc.call(req))
137		} else {
138			EncapsulatedFuture::Web(self.web.call(req))
139		}
140	}
141}
142
143/// Type to encapsulate both inner services Futures
144///
145///Because [poll(cx)][Future::poll] uses a pinned mutable reference,
146/// this enum needs to project the pin.
147/// That way is possible to call poll on the inner future
148#[pin_project(project = EncapsulatedProj)]
149pub enum EncapsulatedFuture<GrpcFuture, WebFuture> {
150	///Encapsulates a future from Grpc service
151	Grpc(#[pin] GrpcFuture),
152	///Encapsulates a future from Web service
153	Web(#[pin] WebFuture),
154}
155/// This implementation should map the response and the error from the inner futures
156///
157/// The response has its body mapped to another enum, the enum should implement `HttpBody`
158///
159impl<GrpcFuture, WebFuture, GrpcResponseBody, WebResponseBody, GrpcError, WebError> Future
160	for EncapsulatedFuture<GrpcFuture, WebFuture>
161where
162	GrpcFuture: Future<Output = Result<Response<GrpcResponseBody>, GrpcError>>,
163	WebFuture: Future<Output = Result<Response<WebResponseBody>, WebError>>,
164	GrpcError: Into<BoxedError>,
165	WebError: Into<BoxedError>,
166{
167	/// We should output `Result<Response<impl HttpBody>, Multiplexer::Error>`
168	type Output = Result<Response<EncapsulatedBody<GrpcResponseBody, WebResponseBody>>, BoxedError>;
169
170	fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
171		match self.project() {
172			EncapsulatedProj::Grpc(future) => future
173				.poll(cx)
174				.map_ok(EncapsulatedBody::map_grpc)
175				.map_err(to_boxed),
176			EncapsulatedProj::Web(future) => future
177				.poll(cx)
178				.map_ok(EncapsulatedBody::map_web)
179				.map_err(to_boxed),
180		}
181	}
182}
183
184/// Type to encapsulate both inner services HttpBody types
185///
186/// Because [poll_data(cx)][HttpBody::poll_data] and [poll_trailers(cx)][HttpBody::poll_trailers] uses pinned reference, this enum needs to project the pin.
187///
188/// This enum is used as the body type in [EncapsulatedFuture].
189#[pin_project(project = BodyProj)]
190pub enum EncapsulatedBody<GrpcBody, WebBody> {
191	///Encapsulates the body from Grpc service
192	Grpc(#[pin] GrpcBody),
193	///Encapsulates the body from Web service
194	Web(#[pin] WebBody),
195}
196impl<GrpcBody, WebBody> EncapsulatedBody<GrpcBody, WebBody> {
197	fn map_grpc(response: Response<GrpcBody>) -> Response<Self> {
198		response.map(EncapsulatedBody::Grpc)
199	}
200	fn map_web(response: Response<WebBody>) -> Response<Self> {
201		response.map(EncapsulatedBody::Web)
202	}
203}
204fn into_data<T: Into<hyper::body::Bytes>>(data: T) -> hyper::body::Bytes {
205	data.into()
206}
207impl<GrpcBody, WebBody, GrpcError, WebError> HttpBody for EncapsulatedBody<GrpcBody, WebBody>
208where
209	GrpcBody: HttpBody<Error = GrpcError>,
210	WebBody: HttpBody<Error = WebError>,
211	GrpcBody::Error: Into<BoxedError>,
212	WebBody::Error: Into<BoxedError>,
213	GrpcBody::Data: Into<hyper::body::Bytes>,
214	WebBody::Data: Into<hyper::body::Bytes>,
215{
216	type Data = hyper::body::Bytes;
217
218	type Error = BoxedError;
219
220	fn poll_data(
221		self: std::pin::Pin<&mut Self>,
222		cx: &mut std::task::Context<'_>,
223	) -> Poll<Option<Result<Self::Data, Self::Error>>> {
224		match self.project() {
225			BodyProj::Grpc(body) => body.poll_data(cx).map_ok(into_data).map_err(to_boxed),
226			BodyProj::Web(body) => body.poll_data(cx).map_ok(into_data).map_err(to_boxed),
227		}
228	}
229
230	fn poll_trailers(
231		self: std::pin::Pin<&mut Self>,
232		cx: &mut std::task::Context<'_>,
233	) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
234		match self.project() {
235			BodyProj::Grpc(body) => body.poll_trailers(cx).map_err(to_boxed),
236			BodyProj::Web(body) => body.poll_trailers(cx).map_err(to_boxed),
237		}
238	}
239}
240
#[cfg(test)]
mod tests {
	use std::{convert::Infallible, future::ready};

	use crate::{EncapsulatedBody, Multiplexer};
	use hyper::{
		body::HttpBody, header::CONTENT_TYPE, service::service_fn, Body, HeaderMap, Request,
		Response,
	};
	use tower::{Service, ServiceExt}; // ServiceExt provides ready()

	/// Builds a service that answers every request with `text` as its body.
	fn text_service(
		text: &'static str,
	) -> impl Service<Request<Body>, Response = Response<Body>, Error = Infallible> {
		service_fn(move |_req: Request<Body>| {
			ready(Ok::<Response<Body>, Infallible>(Response::new(Body::from(text))))
		})
	}

	/// Sends `request` through a fresh multiplexer and returns the response body.
	async fn route(request: Request<Body>) -> hyper::body::Bytes {
		let mut multiplex =
			Multiplexer::new(text_service("gRPC service"), text_service("web service"));
		multiplex.ready().await.unwrap();
		let response = multiplex.call(request).await.unwrap();
		hyper::body::to_bytes(response.into_body()).await.unwrap()
	}

	// This test only checks that construction compiles
	#[test]
	fn new_multiplex_receives_two_services() {
		let _multiplex = Multiplexer::new(text_service("Service 1"), text_service("Service 2"));
	}

	#[tokio::test]
	async fn new_multiplex_is_ready() {
		let mut multiplex =
			Multiplexer::new(text_service("gRPC service"), text_service("web service"));

		multiplex.ready().await.unwrap();
	}

	// One multiplexer is reused for a web request followed by a gRPC request
	#[tokio::test]
	async fn multiplexer_request_to_web() {
		let mut multiplex =
			Multiplexer::new(text_service("gRPC service"), text_service("web service"));
		multiplex.ready().await.unwrap();
		{
			//Request web
			let request = Request::new(Body::empty());
			let response = multiplex.call(request).await.unwrap();
			let content = hyper::body::to_bytes(response.into_body()).await.unwrap();

			assert_ne!(content.len(), 0);
			assert_eq!(content, "web service");
		}
		multiplex.ready().await.unwrap();
		{
			//Request grpc
			let request = Request::builder()
				.header(CONTENT_TYPE, "application/grpc")
				.body(Body::empty())
				.unwrap();
			let response = multiplex.call(request).await.unwrap();
			let content = hyper::body::to_bytes(response.into_body()).await.unwrap();

			assert_ne!(content.len(), 0);
			assert_eq!(content, "gRPC service");
		}
	}

	// Routing matches by prefix, so gRPC content subtypes reach the gRPC service
	#[tokio::test]
	async fn multiplexer_routes_grpc_subtypes_to_grpc() {
		let request = Request::builder()
			.header(CONTENT_TYPE, "application/grpc+proto")
			.body(Body::empty())
			.unwrap();
		assert_eq!(route(request).await, "gRPC service");
	}

	#[tokio::test]
	async fn multiplexer_routes_other_content_types_to_web() {
		let request = Request::builder()
			.header(CONTENT_TYPE, "application/json")
			.body(Body::empty())
			.unwrap();
		assert_eq!(route(request).await, "web service");
	}

	#[tokio::test]
	async fn encapsulated_body_poll_data_grpc() {
		let string = "body grpc";
		let body = EncapsulatedBody::<Body, Body>::Grpc(Body::from(string));

		let data = hyper::body::to_bytes(body).await.unwrap();
		assert_eq!(data, string);
	}

	#[tokio::test]
	async fn encapsulated_body_poll_data_web() {
		let string = "body web";
		// Must exercise the Web variant (previously built Grpc by mistake)
		let body = EncapsulatedBody::<Body, Body>::Web(Body::from(string));

		let data = hyper::body::to_bytes(body).await.unwrap();
		assert_eq!(data, string);
	}

	#[tokio::test]
	async fn encapsulated_body_poll_trailers_grpc() {
		let (mut sender, body) = Body::channel();
		let mut header_map = HeaderMap::new();
		header_map.insert("From", "grpc sender".parse().unwrap());
		sender.send_trailers(header_map.clone()).await.unwrap();

		let mut body = EncapsulatedBody::<Body, Body>::Grpc(body);

		let headers = body
			.trailers()
			.await
			.unwrap()
			.expect("Should return trailers!");
		assert_eq!(headers, header_map);
	}

	#[tokio::test]
	async fn encapsulated_body_poll_trailers_web() {
		let (mut sender, body) = Body::channel();
		let mut header_map = HeaderMap::new();
		header_map.insert("From", "web sender".parse().unwrap());
		sender.send_trailers(header_map.clone()).await.unwrap();

		let mut body = EncapsulatedBody::<Body, Body>::Web(body);

		let headers = body
			.trailers()
			.await
			.unwrap()
			.expect("Should return trailers!");
		assert_eq!(headers, header_map);
	}
}