1use std::error::Error;
2use std::fmt::{self, Display, Formatter};
3use std::future::ready;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::{future, Future, TryFutureExt};
9use hyper::client::HttpConnector;
10use hyper::header::{HeaderMap, ALLOW, CONTENT_LENGTH, CONTENT_TYPE};
11use hyper::http::{self, HeaderValue};
12use hyper::service::Service;
13use hyper::{Body, Client, Method, Request, Response, StatusCode, Uri, Version};
14use prost::{DecodeError, EncodeError, Message};
15
16pub type PTRes<O> =
18 Pin<Box<dyn Future<Output = Result<ServiceResponse<O>, ProstTwirpError>> + Send + 'static>>;
19
20static JSON_CONTENT_TYPE: &str = "application/json";
21static PROTOBUF_CONTENT_TYPE: &str = "application/protobuf";
22
23#[derive(Debug)]
25pub struct ServiceRequest<T: Message> {
26 pub uri: Uri,
30 pub method: Method,
32 pub version: Version,
34 pub headers: HeaderMap,
38 pub input: T,
40}
41
42impl<T: Message> ServiceRequest<T> {
43 pub fn new(input: T) -> ServiceRequest<T> {
47 let mut headers = HeaderMap::new();
48 headers.insert(
49 CONTENT_TYPE,
50 HeaderValue::from_static(PROTOBUF_CONTENT_TYPE),
51 );
52 ServiceRequest {
53 uri: Default::default(),
54 method: Method::POST,
55 version: Version::default(),
56 headers,
57 input,
58 }
59 }
60
61 pub fn clone_with_input(&self, input: T) -> ServiceRequest<T> {
63 ServiceRequest {
64 uri: self.uri.clone(),
65 method: self.method.clone(),
66 version: self.version,
67 headers: self.headers.clone(),
68 input,
69 }
70 }
71}
72
73impl<T: Message + Default + 'static> From<T> for ServiceRequest<T> {
74 fn from(v: T) -> ServiceRequest<T> {
75 ServiceRequest::new(v)
76 }
77}
78
79impl<T: Message + Default + 'static> ServiceRequest<T> {
80 pub fn to_hyper_request(&self) -> Result<Request<Body>, ProstTwirpError> {
82 let mut body = Vec::new();
83 self.input
84 .encode(&mut body)
85 .map_err(ProstTwirpError::ProstEncodeError)?;
86 let mut builder = Request::post(self.uri.clone());
87 builder.headers_mut().unwrap().clone_from(&self.headers);
88 builder
89 .header(CONTENT_LENGTH, body.len() as u64)
90 .body(Body::from(body))
91 .map_err(ProstTwirpError::from)
92 }
93
94 pub async fn from_hyper_request(
95 req: Request<Body>,
96 ) -> Result<ServiceRequest<T>, ProstTwirpError> {
97 if req.method() != Method::POST {
98 return Err(ProstTwirpError::InvalidMethod);
99 } else if req
100 .headers()
101 .get(CONTENT_TYPE)
102 .map_or(true, |v| v != PROTOBUF_CONTENT_TYPE)
103 {
104 return Err(ProstTwirpError::InvalidContentType);
105 }
106 let uri = req.uri().clone();
107 let method = req.method().clone();
108 let version = req.version();
109 let headers = req.headers().clone();
110 let body_bytes = hyper::body::to_bytes(req.into_body()).await?;
111 match T::decode(body_bytes.clone()) {
112 Ok(input) => Ok(ServiceRequest {
113 uri,
114 method,
115 version,
116 headers,
117 input,
118 }),
119 Err(err) => Err(ProstTwirpError::AfterBodyError {
120 status: None,
121 method: Some(method),
122 version,
123 headers,
124 err: Box::new(ProstTwirpError::ProstDecodeError(err)),
125 body: body_bytes.to_vec(),
126 }),
127 }
128 }
129}
130
131#[derive(Debug)]
133pub struct ServiceResponse<M: Message> {
134 pub version: Version,
136 pub headers: HeaderMap,
140 pub status: StatusCode,
142 pub output: M,
144}
145
146impl<M: Message> ServiceResponse<M> {
147 pub fn new(output: M) -> ServiceResponse<M> {
151 let mut headers = HeaderMap::new();
152 headers.insert(
153 CONTENT_TYPE,
154 HeaderValue::from_static(PROTOBUF_CONTENT_TYPE),
155 );
156 ServiceResponse {
157 version: Version::default(),
158 headers,
159 status: StatusCode::OK,
160 output,
161 }
162 }
163
164 pub fn clone_with_output(&self, output: M) -> ServiceResponse<M> {
166 ServiceResponse {
167 version: self.version,
168 headers: self.headers.clone(),
169 status: self.status,
170 output,
171 }
172 }
173}
174
175impl<M: Message + Default + 'static> From<M> for ServiceResponse<M> {
176 fn from(v: M) -> ServiceResponse<M> {
177 ServiceResponse::new(v)
178 }
179}
180
181impl<M: Message + Default> ServiceResponse<M> {
182 pub async fn from_hyper_response(resp: Response<Body>) -> Result<Self, ProstTwirpError> {
184 let version = resp.version();
185 let headers = resp.headers().clone();
186 let status = resp.status();
187 let body_bytes = hyper::body::to_bytes(resp.into_body()).await?;
188 let err = if status.is_success() {
189 match M::decode(&*body_bytes) {
190 Ok(output) => {
191 return Ok(ServiceResponse {
192 version,
193 headers,
194 status,
195 output,
196 })
197 }
198 Err(err) => ProstTwirpError::ProstDecodeError(err),
199 }
200 } else {
201 match TwirpError::from_json_bytes(status, &body_bytes) {
202 Ok(err) => ProstTwirpError::TwirpError(err),
203 Err(err) => ProstTwirpError::JsonDecodeError(err),
204 }
205 };
206 Err(ProstTwirpError::AfterBodyError {
207 body: body_bytes.to_vec(),
208 method: None,
209 version,
210 headers,
211 status: Some(status),
212 err: Box::new(err),
213 })
214 }
215
216 pub fn to_hyper_response(&self) -> Result<Response<Body>, ProstTwirpError> {
218 let body_bytes = self.output.encode_to_vec();
219 let mut builder = Response::builder().status(self.status);
220 builder.headers_mut().unwrap().clone_from(&self.headers);
221 builder
222 .header(CONTENT_LENGTH, body_bytes.len() as u64)
223 .body(body_bytes.into())
224 .map_err(ProstTwirpError::from)
225 }
226}
227
228#[derive(Debug, Clone)]
230pub struct TwirpError {
231 pub status: StatusCode,
232 pub error_type: String,
233 pub msg: String,
234 pub meta: Option<serde_json::Value>,
235}
236
237impl TwirpError {
238 pub fn new(status: StatusCode, error_type: &str, msg: &str) -> TwirpError {
240 TwirpError::new_meta(status, error_type, msg, None)
241 }
242
243 pub fn new_meta(
245 status: StatusCode,
246 error_type: &str,
247 msg: &str,
248 meta: Option<serde_json::Value>,
249 ) -> TwirpError {
250 TwirpError {
251 status,
252 error_type: error_type.to_string(),
253 msg: msg.to_string(),
254 meta,
255 }
256 }
257
258 pub fn to_hyper_response(&self) -> Response<Body> {
260 let body_bytes = self
261 .to_json_bytes()
262 .unwrap_or_else(|_| "{}".as_bytes().to_vec());
263 let body_len = body_bytes.len() as u64;
264 Response::builder()
265 .status(self.status)
266 .header(CONTENT_TYPE, JSON_CONTENT_TYPE)
267 .header(CONTENT_LENGTH, HeaderValue::from(body_len))
268 .header(ALLOW, HeaderValue::from_static("POST"))
269 .body(Body::from(body_bytes))
270 .expect("failed to serialize twirp error")
271 }
275
276 pub fn from_json(status: StatusCode, json: serde_json::Value) -> TwirpError {
278 let error_type = json["error_type"].as_str();
279 TwirpError {
280 status,
281 error_type: error_type.unwrap_or("<no code>").to_string(),
282 msg: json["msg"].as_str().unwrap_or("<no message>").to_string(),
283 meta: if error_type.is_some() {
285 json.get("meta").cloned()
286 } else {
287 Some(json.clone())
288 },
289 }
290 }
291
292 pub fn from_json_bytes(status: StatusCode, json: &[u8]) -> serde_json::Result<TwirpError> {
294 serde_json::from_slice(json).map(|v| TwirpError::from_json(status, v))
295 }
296
297 pub fn to_json(&self) -> serde_json::Value {
299 let mut props = serde_json::map::Map::new();
300 props.insert(
301 "error_type".to_string(),
302 serde_json::Value::String(self.error_type.clone()),
303 );
304 props.insert(
305 "msg".to_string(),
306 serde_json::Value::String(self.msg.clone()),
307 );
308 if let Some(ref meta) = self.meta {
309 props.insert("meta".to_string(), meta.clone());
310 }
311 serde_json::Value::Object(props)
312 }
313
314 pub fn to_json_bytes(&self) -> serde_json::Result<Vec<u8>> {
316 serde_json::to_vec(&self.to_json())
317 }
318}
319
320impl Error for TwirpError {}
321
322impl Display for TwirpError {
323 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
324 write!(f, "{:?} {}: {}", self.status, self.error_type, self.msg)
325 }
326}
327
328impl From<TwirpError> for ProstTwirpError {
329 fn from(v: TwirpError) -> ProstTwirpError {
330 ProstTwirpError::TwirpError(v)
331 }
332}
333
334#[derive(Debug)]
336#[non_exhaustive]
337pub enum ProstTwirpError {
338 TwirpError(TwirpError),
340 JsonDecodeError(serde_json::Error),
342 ProstEncodeError(EncodeError),
344 ProstDecodeError(DecodeError),
346 HyperError(hyper::Error),
348 HttpError(http::Error),
350 InvalidUri(http::uri::InvalidUri),
352 InvalidMethod,
354 InvalidContentType,
356 NotFound,
358 AfterBodyError {
360 body: Vec<u8>,
362 method: Option<Method>,
364 version: Version,
366 headers: HeaderMap,
368 status: Option<StatusCode>,
370 err: Box<ProstTwirpError>,
372 },
373}
374
375impl ProstTwirpError {
376 pub fn root_err(self) -> ProstTwirpError {
378 match self {
379 ProstTwirpError::AfterBodyError { err, .. } => err.root_err(),
380 _ => self,
381 }
382 }
383
384 pub fn into_hyper_response(self) -> Result<Response<Body>, hyper::Error> {
385 let external_err = match self {
386 ProstTwirpError::TwirpError(err) => err,
387 ProstTwirpError::HyperError(err) => return Err(err),
389 ProstTwirpError::InvalidMethod => TwirpError::new(
390 StatusCode::METHOD_NOT_ALLOWED,
391 "bad_method",
392 "Method must be POST",
393 ),
394 ProstTwirpError::ProstDecodeError(_) => TwirpError::new(
395 StatusCode::BAD_REQUEST,
396 "protobuf_decode_err",
397 "Invalid protobuf body",
398 ),
399 ProstTwirpError::InvalidContentType => TwirpError::new(
400 StatusCode::UNSUPPORTED_MEDIA_TYPE,
401 "bad_content_type",
402 "Content type must be application/protobuf",
403 ),
404 ProstTwirpError::NotFound => TwirpError::new(
405 StatusCode::NOT_FOUND,
406 "not_found",
407 "The requested method was not found",
408 ),
409 _ => TwirpError::new(
410 StatusCode::INTERNAL_SERVER_ERROR,
411 "internal_err",
412 "Internal error",
413 ),
414 };
415 Ok(external_err.to_hyper_response())
416 }
417}
418
419impl From<hyper::Error> for ProstTwirpError {
420 fn from(v: hyper::Error) -> ProstTwirpError {
421 ProstTwirpError::HyperError(v)
422 }
423}
424
425impl From<http::Error> for ProstTwirpError {
426 fn from(v: http::Error) -> ProstTwirpError {
427 ProstTwirpError::HttpError(v)
428 }
429}
430
431impl Display for ProstTwirpError {
432 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
433 write!(f, "{:?}", self)
434 }
435}
436
437impl Error for ProstTwirpError {
438 fn source(&self) -> Option<&(dyn Error + 'static)> {
439 match self {
440 ProstTwirpError::TwirpError(err) => Some(err),
441 ProstTwirpError::JsonDecodeError(err) => Some(err),
442 ProstTwirpError::ProstEncodeError(err) => Some(err),
443 ProstTwirpError::ProstDecodeError(err) => Some(err),
444 ProstTwirpError::HyperError(err) => Some(err),
445 ProstTwirpError::HttpError(err) => Some(err),
446 ProstTwirpError::InvalidUri(err) => Some(err),
447 ProstTwirpError::InvalidMethod => None,
448 ProstTwirpError::InvalidContentType => None,
449 ProstTwirpError::NotFound => None,
450 ProstTwirpError::AfterBodyError { err, .. } => Some(err),
451 }
452 }
453}
454
455#[derive(Debug)]
457pub struct HyperClient {
458 pub client: Client<HttpConnector>,
460 pub root_url: String,
462}
463
464impl HyperClient {
465 pub fn new(client: Client<HttpConnector>, root_url: &str) -> HyperClient {
467 HyperClient {
468 client,
469 root_url: root_url.trim_end_matches('/').to_string(),
470 }
471 }
472
473 pub fn go<I, O>(&self, path: &str, req: ServiceRequest<I>) -> PTRes<O>
475 where
476 I: Message + Default + 'static,
477 O: Message + Default + 'static,
478 {
479 let uri = match format!("{}/{}", self.root_url, path.trim_start_matches('/')).parse() {
481 Err(err) => return Box::pin(ready(Err(ProstTwirpError::InvalidUri(err)))),
482 Ok(v) => v,
483 };
484 let mut hyper_req = match req.to_hyper_request() {
486 Err(err) => return Box::pin(ready(Err(err))),
487 Ok(v) => v,
488 };
489 *hyper_req.uri_mut() = uri;
490 Box::pin(
492 self.client
493 .request(hyper_req)
494 .map_err(ProstTwirpError::HyperError)
495 .and_then(ServiceResponse::from_hyper_response),
496 )
497 }
498}
499
500pub trait HyperService {
515 fn handle(
517 &self,
518 req: Request<Body>,
519 ) -> Pin<Box<dyn Future<Output = Result<Response<Body>, ProstTwirpError>> + Send>>;
520}
521
522pub struct HyperServer<T: HyperService + Send + Sync + 'static> {
533 pub service: Arc<T>,
537}
538
539impl<T: HyperService + Send + Sync + 'static> HyperServer<T> {
540 pub fn new(service: T) -> HyperServer<T> {
542 HyperServer {
543 service: Arc::new(service),
544 }
545 }
546}
547
548impl<T: 'static + HyperService + Send + Sync> Service<Request<Body>> for HyperServer<T> {
549 type Response = Response<Body>;
550 type Error = hyper::Error;
551 type Future = Pin<Box<dyn (Future<Output = Result<Self::Response, Self::Error>>) + Send>>;
552
553 fn poll_ready(&mut self, _context: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
554 Poll::Ready(Ok(()))
555 }
556
557 fn call(&mut self, req: Request<Body>) -> Self::Future {
558 let service = self.service.clone();
560 Box::pin(
561 service
562 .handle(req)
563 .or_else(|err| future::ready(err.into_hyper_response())),
564 )
565 }
566}