dubbo/triple/server/
triple.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use futures_util::{future, stream, StreamExt, TryStreamExt};
19use http::HeaderValue;
20use http_body::Body;
21use prost::Message;
22use serde::{Deserialize, Serialize};
23use std::marker::PhantomData;
24
25use crate::{
26    invocation::Request,
27    status::Status,
28    triple::{
29        client::triple::get_codec,
30        codec::{Decoder, Encoder},
31        compression::{CompressionEncoding, COMPRESSIONS},
32        decode::Decoding,
33        encode::encode_server,
34        server::service::{ClientStreamingSvc, ServerStreamingSvc, StreamingSvc, UnarySvc},
35    },
36    BoxBody,
37};
38
/// Name of the HTTP header carrying the compression encodings a peer accepts.
pub const GRPC_ACCEPT_ENCODING: &str = "grpc-accept-encoding";
/// Name of the HTTP header naming the compression encoding applied to a message body.
pub const GRPC_ENCODING: &str = "grpc-encoding";
41
42pub struct TripleServer<M1, M2> {
43    _pd: PhantomData<(M1, M2)>,
44    compression: Option<CompressionEncoding>,
45}
46
47impl<M1, M2> TripleServer<M1, M2> {
48    pub fn new() -> Self {
49        Self {
50            _pd: PhantomData,
51            compression: Some(CompressionEncoding::Gzip),
52        }
53    }
54}
55
56impl<M1, M2> TripleServer<M1, M2>
57where
58    M1: Message + for<'a> Deserialize<'a> + Default + 'static,
59    M2: Message + Serialize + Default + 'static,
60{
61    pub async fn client_streaming<S, B>(
62        &mut self,
63        mut service: S,
64        req: http::Request<B>,
65    ) -> http::Response<BoxBody>
66    where
67        S: ClientStreamingSvc<M1, Response = M2>,
68        B: Body + Send + 'static,
69        B::Error: Into<crate::Error> + Send,
70    {
71        let content_type = req
72            .headers()
73            .get("content-type")
74            .cloned()
75            .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
76        let content_type_str = content_type.to_str().unwrap();
77        let (decoder, encoder): (
78            Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
79            Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
80        ) = get_codec(content_type_str);
81        let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
82        if self.compression.is_none() || accept_encoding.is_none() {
83            accept_encoding = None;
84        }
85
86        // Get grpc_encoding from http_header, decompress message.
87        let compression = match self.get_encoding_from_req(req.headers()) {
88            Ok(val) => val,
89            Err(status) => return status.to_http(),
90        };
91
92        let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
93
94        let resp = service.call(Request::from_http(req_stream)).await;
95
96        let (mut parts, resp_body) = match resp {
97            Ok(v) => v.into_http().into_parts(),
98            Err(err) => return err.to_http(),
99        };
100
101        let resp_body = encode_server(
102            encoder,
103            stream::once(future::ready(resp_body)).map(Ok).into_stream(),
104            accept_encoding,
105            true,
106        );
107
108        parts
109            .headers
110            .insert(http::header::CONTENT_TYPE, content_type);
111        if let Some(encoding) = accept_encoding {
112            parts
113                .headers
114                .insert(GRPC_ENCODING, encoding.into_header_value());
115        }
116        parts.status = http::StatusCode::OK;
117        http::Response::from_parts(parts, BoxBody::new(resp_body))
118    }
119
120    pub async fn bidi_streaming<S, B>(
121        &mut self,
122        mut service: S,
123        req: http::Request<B>,
124    ) -> http::Response<BoxBody>
125    where
126        S: StreamingSvc<M1, Response = M2>,
127        S::ResponseStream: Send + 'static,
128        B: Body + Send + 'static,
129        B::Error: Into<crate::Error> + Send,
130    {
131        let content_type = req
132            .headers()
133            .get("content-type")
134            .cloned()
135            .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
136        let content_type_str = content_type.to_str().unwrap();
137        let (decoder, encoder): (
138            Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
139            Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
140        ) = get_codec(content_type_str);
141        // Firstly, get grpc_accept_encoding from http_header, get compression
142        // Secondly, if server enable compression and compression is valid, this method should compress response
143        let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
144        if self.compression.is_none() || accept_encoding.is_none() {
145            accept_encoding = None;
146        }
147
148        // Get grpc_encoding from http_header, decompress message.
149        let compression = match self.get_encoding_from_req(req.headers()) {
150            Ok(val) => val,
151            Err(status) => return status.to_http(),
152        };
153
154        let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
155
156        let resp = service.call(Request::from_http(req_stream)).await;
157
158        let (mut parts, resp_body) = match resp {
159            Ok(v) => v.into_http().into_parts(),
160            Err(err) => return err.to_http(),
161        };
162        let resp_body = encode_server(encoder, resp_body, compression, true);
163
164        parts
165            .headers
166            .insert(http::header::CONTENT_TYPE, content_type);
167        if let Some(encoding) = accept_encoding {
168            parts
169                .headers
170                .insert(GRPC_ENCODING, encoding.into_header_value());
171        }
172        parts.status = http::StatusCode::OK;
173        http::Response::from_parts(parts, BoxBody::new(resp_body))
174    }
175
176    pub async fn server_streaming<S, B>(
177        &mut self,
178        mut service: S,
179        req: http::Request<B>,
180    ) -> http::Response<BoxBody>
181    where
182        S: ServerStreamingSvc<M1, Response = M2>,
183        S::ResponseStream: Send + 'static,
184        B: Body + Send + 'static,
185        B::Error: Into<crate::Error> + Send,
186    {
187        let content_type = req
188            .headers()
189            .get("content-type")
190            .cloned()
191            .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
192        let content_type_str = content_type.to_str().unwrap();
193        let (decoder, encoder): (
194            Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
195            Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
196        ) = get_codec(content_type_str);
197        // Firstly, get grpc_accept_encoding from http_header, get compression
198        // Secondly, if server enable compression and compression is valid, this method should compress response
199        let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
200        if self.compression.is_none() || accept_encoding.is_none() {
201            accept_encoding = None;
202        }
203
204        // Get grpc_encoding from http_header, decompress message.
205        let compression = match self.get_encoding_from_req(req.headers()) {
206            Ok(val) => val,
207            Err(status) => return status.to_http(),
208        };
209        let req_stream = req.map(|body| Decoding::new(body, decoder, compression, true));
210        let (parts, mut body) = Request::from_http(req_stream).into_parts();
211        let msg = body.try_next().await.unwrap().ok_or_else(|| {
212            crate::status::Status::new(crate::status::Code::Unknown, "request wrong".to_string())
213        });
214        let msg = match msg {
215            Ok(v) => v,
216            Err(err) => return err.to_http(),
217        };
218
219        let resp = service.call(Request::from_parts(parts, msg)).await;
220
221        let (mut parts, resp_body) = match resp {
222            Ok(v) => v.into_http().into_parts(),
223            Err(err) => return err.to_http(),
224        };
225        let resp_body = encode_server(encoder, resp_body, compression, true);
226
227        parts
228            .headers
229            .insert(http::header::CONTENT_TYPE, content_type);
230        if let Some(encoding) = accept_encoding {
231            parts
232                .headers
233                .insert(GRPC_ENCODING, encoding.into_header_value());
234        }
235        parts.status = http::StatusCode::OK;
236        http::Response::from_parts(parts, BoxBody::new(resp_body))
237    }
238
239    pub async fn unary<S, B>(
240        &mut self,
241        mut service: S,
242        req: http::Request<B>,
243    ) -> http::Response<BoxBody>
244    where
245        S: UnarySvc<M1, Response = M2>,
246        B: Body + Send + 'static,
247        B::Error: Into<crate::Error> + Send,
248    {
249        let mut accept_encoding = CompressionEncoding::from_accept_encoding(req.headers());
250        if self.compression.is_none() || accept_encoding.is_none() {
251            accept_encoding = None;
252        }
253        let compression = match self.get_encoding_from_req(req.headers()) {
254            Ok(val) => val,
255            Err(status) => return status.to_http(),
256        };
257        let content_type = req
258            .headers()
259            .get("content-type")
260            .cloned()
261            .unwrap_or(HeaderValue::from_str("application/grpc+proto").unwrap());
262        let content_type_str = content_type.to_str().unwrap();
263        //Determine whether to use the gRPC mode to handle request data
264        let handle_request_as_grpc = content_type_str.contains("grpc");
265        let (decoder, encoder): (
266            Box<dyn Decoder<Item = M1, Error = Status> + Send + 'static>,
267            Box<dyn Encoder<Error = Status, Item = M2> + Send + 'static>,
268        ) = get_codec(content_type_str);
269        let req_stream =
270            req.map(|body| Decoding::new(body, decoder, compression, handle_request_as_grpc));
271        let (parts, mut body) = Request::from_http(req_stream).into_parts();
272        let msg = body.try_next().await.unwrap().ok_or_else(|| {
273            crate::status::Status::new(crate::status::Code::Unknown, "request wrong".to_string())
274        });
275        let msg = match msg {
276            Ok(v) => v,
277            Err(err) => return err.to_http(),
278        };
279
280        let resp = service.call(Request::from_parts(parts, msg)).await;
281
282        let (mut parts, resp_body) = match resp {
283            Ok(v) => v.into_http().into_parts(),
284            Err(err) => return err.to_http(),
285        };
286        let resp_body = encode_server(
287            encoder,
288            stream::once(future::ready(resp_body)).map(Ok).into_stream(),
289            accept_encoding,
290            handle_request_as_grpc,
291        );
292
293        parts
294            .headers
295            .insert(http::header::CONTENT_TYPE, content_type);
296        if let Some(encoding) = accept_encoding {
297            parts
298                .headers
299                .insert(GRPC_ENCODING, encoding.into_header_value());
300        }
301        parts.status = http::StatusCode::OK;
302        http::Response::from_parts(parts, BoxBody::new(resp_body))
303    }
304
305    fn get_encoding_from_req(
306        &self,
307        header: &http::HeaderMap,
308    ) -> Result<Option<CompressionEncoding>, crate::status::Status> {
309        let encoding = match header.get(GRPC_ENCODING) {
310            Some(val) => val.to_str().unwrap(),
311            None => return Ok(None),
312        };
313
314        let compression = match COMPRESSIONS.get(encoding) {
315            Some(val) => val.to_owned(),
316            None => {
317                let status = crate::status::Status::new(
318                    crate::status::Code::Unimplemented,
319                    format!("grpc-accept-encoding: {} not support!", encoding),
320                );
321
322                return Err(status);
323            }
324        };
325        Ok(compression)
326    }
327}