1use futures_util::{future, stream, StreamExt, TryStreamExt};
19use http::HeaderValue;
20use http_body::Body;
21use prost::Message;
22use serde::{Deserialize, Serialize};
23use std::marker::PhantomData;
24
25use crate::{
26 invocation::Request,
27 status::Status,
28 triple::{
29 client::triple::get_codec,
30 codec::{Decoder, Encoder},
31 compression::{CompressionEncoding, COMPRESSIONS},
32 decode::Decoding,
33 encode::encode_server,
34 server::service::{ClientStreamingSvc, ServerStreamingSvc, StreamingSvc, UnarySvc},
35 },
36 BoxBody,
37};
38
/// Header name `grpc-accept-encoding`; in the gRPC HTTP/2 protocol it lists the
/// message encodings the sender is willing to accept.
pub const GRPC_ACCEPT_ENCODING: &str = "grpc-accept-encoding";
/// Header name `grpc-encoding`: read from requests to choose how the request body is
/// decompressed, and set on responses to announce the encoding applied to the body.
pub const GRPC_ENCODING: &str = "grpc-encoding";
41
42pub struct TripleServer<M1, M2> {
43 _pd: PhantomData<(M1, M2)>,
44 compression: Option<CompressionEncoding>,
45}
46
47impl<M1, M2> TripleServer<M1, M2> {
48 pub fn new() -> Self {
49 Self {
50 _pd: PhantomData,
51 compression: Some(CompressionEncoding::Gzip),
52 }
53 }
54}
55
56impl<M1, M2> TripleServer<M1, M2>
57where
58 M1: Message + for<'a> Deserialize<'a> + Default + 'static,
59 M2: Message + Serialize + Default + 'static,
60{
61 pub async fn client_streaming<S, B>(
62 &mut self,
63 mut service: S,
64 req: http::Request<B>,
65 ) -> http::Response<BoxBody>
66 where
67 S: ClientStreamingSvc<M1, Response = M2>,
68 B: Body + Send + 'static,
69 B::Error: Into<crate::Error> + Send,
70 {
71 let content_type = req
72 .headers()
73 .get("content-type")
74 .cloned()
75 .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
76 let content_type_str = content_type.to_str().unwrap();
77 let (decoder, encoder): (
78 Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
79 Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
80 ) = get_codec(content_type_str);
81 let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
82 if self.compression.is_none() || accept_encoding.is_none() {
83 accept_encoding = None;
84 }
85
86 let compression = match self.get_encoding_from_req(req.headers()) {
88 Ok(val) => val,
89 Err(status) => return status.to_http(),
90 };
91
92 let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
93
94 let resp = service.call(Request::from_http(req_stream)).await;
95
96 let (mut parts, resp_body) = match resp {
97 Ok(v) => v.into_http().into_parts(),
98 Err(err) => return err.to_http(),
99 };
100
101 let resp_body = encode_server(
102 encoder,
103 stream::once(future::ready(resp_body)).map(Ok).into_stream(),
104 accept_encoding,
105 true,
106 );
107
108 parts
109 .headers
110 .insert(http::header::CONTENT_TYPE, content_type);
111 if let Some(encoding) = accept_encoding {
112 parts
113 .headers
114 .insert(GRPC_ENCODING, encoding.into_header_value());
115 }
116 parts.status = http::StatusCode::OK;
117 http::Response::from_parts(parts, BoxBody::new(resp_body))
118 }
119
120 pub async fn bidi_streaming<S, B>(
121 &mut self,
122 mut service: S,
123 req: http::Request<B>,
124 ) -> http::Response<BoxBody>
125 where
126 S: StreamingSvc<M1, Response = M2>,
127 S::ResponseStream: Send + 'static,
128 B: Body + Send + 'static,
129 B::Error: Into<crate::Error> + Send,
130 {
131 let content_type = req
132 .headers()
133 .get("content-type")
134 .cloned()
135 .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
136 let content_type_str = content_type.to_str().unwrap();
137 let (decoder, encoder): (
138 Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
139 Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
140 ) = get_codec(content_type_str);
141 let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
144 if self.compression.is_none() || accept_encoding.is_none() {
145 accept_encoding = None;
146 }
147
148 let compression = match self.get_encoding_from_req(req.headers()) {
150 Ok(val) => val,
151 Err(status) => return status.to_http(),
152 };
153
154 let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
155
156 let resp = service.call(Request::from_http(req_stream)).await;
157
158 let (mut parts, resp_body) = match resp {
159 Ok(v) => v.into_http().into_parts(),
160 Err(err) => return err.to_http(),
161 };
162 let resp_body = encode_server(encoder, resp_body, compression, true);
163
164 parts
165 .headers
166 .insert(http::header::CONTENT_TYPE, content_type);
167 if let Some(encoding) = accept_encoding {
168 parts
169 .headers
170 .insert(GRPC_ENCODING, encoding.into_header_value());
171 }
172 parts.status = http::StatusCode::OK;
173 http::Response::from_parts(parts, BoxBody::new(resp_body))
174 }
175
176 pub async fn server_streaming<S, B>(
177 &mut self,
178 mut service: S,
179 req: http::Request<B>,
180 ) -> http::Response<BoxBody>
181 where
182 S: ServerStreamingSvc<M1, Response = M2>,
183 S::ResponseStream: Send + 'static,
184 B: Body + Send + 'static,
185 B::Error: Into<crate::Error> + Send,
186 {
187 let content_type = req
188 .headers()
189 .get("content-type")
190 .cloned()
191 .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
192 let content_type_str = content_type.to_str().unwrap();
193 let (decoder, encoder): (
194 Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
195 Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
196 ) = get_codec(content_type_str);
197 let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
200 if self.compression.is_none() || accept_encoding.is_none() {
201 accept_encoding = None;
202 }
203
204 let compression = match self.get_encoding_from_req(req.headers()) {
206 Ok(val) => val,
207 Err(status) => return status.to_http(),
208 };
209 let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
210 let (parts, mut body) = Request::from_http(req_stream).into_parts();
211 let msg = body.try_next().await.unwrap().ok_or_else(|| {
212 crate::status::Status::new(crate::status::Code::Unknown, "request wrong".to_string())
213 });
214 let msg = match msg {
215 Ok(v) => v,
216 Err(err) => return err.to_http(),
217 };
218
219 let resp = service.call(Request::from_parts(parts, msg)).await;
220
221 let (mut parts, resp_body) = match resp {
222 Ok(v) => v.into_http().into_parts(),
223 Err(err) => return err.to_http(),
224 };
225 let resp_body = encode_server(encoder, resp_body, compression, true);
226
227 parts
228 .headers
229 .insert(http::header::CONTENT_TYPE, content_type);
230 if let Some(encoding) = accept_encoding {
231 parts
232 .headers
233 .insert(GRPC_ENCODING, encoding.into_header_value());
234 }
235 parts.status = http::StatusCode::OK;
236 http::Response::from_parts(parts, BoxBody::new(resp_body))
237 }
238
239 pub async fn unary<S, B>(
240 &mut self,
241 mut service: S,
242 req: http::Request<B>,
243 ) -> http::Response<BoxBody>
244 where
245 S: UnarySvc<M1, Response = M2>,
246 B: Body + Send + 'static,
247 B::Error: Into<crate::Error> + Send,
248 {
249 let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
250 if self.compression.is_none() || accept_encoding.is_none() {
251 accept_encoding = None;
252 }
253 let compression = match self.get_encoding_from_req(req.headers()) {
254 Ok(val) => val,
255 Err(status) => return status.to_http(),
256 };
257 let content_type = req
258 .headers()
259 .get("content-type")
260 .cloned()
261 .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
262 let content_type_str = content_type.to_str().unwrap();
263 let handle_request_as_grpc = content_type_str.contains("grpc");
265 let (decoder, encoder): (
266 Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
267 Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
268 ) = get_codec(content_type_str);
269 let req_stream =
270 req.map(|body| Decoding::new(body, decoder, compression, handle_request_as_grpc));
271 let (parts, mut body) = Request::from_http(req_stream).into_parts();
272 let msg = body.try_next().await.unwrap().ok_or_else(|| {
273 crate::status::Status::new(crate::status::Code::Unknown, "request wrong".to_string())
274 });
275 let msg = match msg {
276 Ok(v) => v,
277 Err(err) => return err.to_http(),
278 };
279
280 let resp = service.call(Request::from_parts(parts, msg)).await;
281
282 let (mut parts, resp_body) = match resp {
283 Ok(v) => v.into_http().into_parts(),
284 Err(err) => return err.to_http(),
285 };
286 let resp_body = encode_server(
287 encoder,
288 stream::once(future::ready(resp_body)).map(Ok).into_stream(),
289 accept_encoding,
290 handle_request_as_grpc,
291 );
292
293 parts
294 .headers
295 .insert(http::header::CONTENT_TYPE, content_type);
296 if let Some(encoding) = accept_encoding {
297 parts
298 .headers
299 .insert(GRPC_ENCODING, encoding.into_header_value());
300 }
301 parts.status = http::StatusCode::OK;
302 http::Response::from_parts(parts, BoxBody::new(resp_body))
303 }
304
305 fn get_encoding_from_req(
306 &self,
307 header: &http::HeaderMap,
308 ) -> Result<Option<CompressionEncoding>, crate::status::Status> {
309 let encoding = match header.get(GRPC_ENCODING) {
310 Some(val) => val.to_str().unwrap(),
311 None => return Ok(None),
312 };
313
314 let compression = match COMPRESSIONS.get(encoding) {
315 Some(val) => val.to_owned(),
316 None => {
317 let status = crate::status::Status::new(
318 crate::status::Code::Unimplemented,
319 format!("grpc-accept-encoding: {} not support!", encoding),
320 );
321
322 return Err(status);
323 }
324 };
325 Ok(compression)
326 }
327}