Skip to main content

prost_twirp/
service_run.rs

1use std::error::Error;
2use std::fmt::{self, Display, Formatter};
3use std::future::ready;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::{future, Future, TryFutureExt};
9use hyper::client::HttpConnector;
10use hyper::header::{HeaderMap, ALLOW, CONTENT_LENGTH, CONTENT_TYPE};
11use hyper::http::{self, HeaderValue};
12use hyper::service::Service;
13use hyper::{Body, Client, Method, Request, Response, StatusCode, Uri, Version};
14use prost::{DecodeError, EncodeError, Message};
15
16/// The type of every service response
17pub type PTRes<O> =
18    Pin<Box<dyn Future<Output = Result<ServiceResponse<O>, ProstTwirpError>> + Send + 'static>>;
19
20static JSON_CONTENT_TYPE: &str = "application/json";
21static PROTOBUF_CONTENT_TYPE: &str = "application/protobuf";
22
23/// A request with HTTP info and a proto request payload object.
24#[derive(Debug)]
25pub struct ServiceRequest<T: Message> {
26    /// The URI of the original request
27    ///
28    /// When using a client, this will be overridden with the proper URI. It is only valuable for servers.
29    pub uri: Uri,
30    /// The request method; should always be `POST`.
31    pub method: Method,
32    /// The HTTP version, rarely changed from the default.
33    pub version: Version,
34    /// The set of headers
35    ///
36    /// Should always at least have `Content-Type`. Clients will override `Content-Length` on serialization.
37    pub headers: HeaderMap,
38    /// The request body as a proto `Message`, representing the arguments of the proto rpc.
39    pub input: T,
40}
41
42impl<T: Message> ServiceRequest<T> {
43    /// Create new service request with the given input object
44    ///
45    /// This automatically sets the `Content-Type` header as `application/protobuf`.
46    pub fn new(input: T) -> ServiceRequest<T> {
47        let mut headers = HeaderMap::new();
48        headers.insert(
49            CONTENT_TYPE,
50            HeaderValue::from_static(PROTOBUF_CONTENT_TYPE),
51        );
52        ServiceRequest {
53            uri: Default::default(),
54            method: Method::POST,
55            version: Version::default(),
56            headers,
57            input,
58        }
59    }
60
61    /// Copy this request with a different input value
62    pub fn clone_with_input(&self, input: T) -> ServiceRequest<T> {
63        ServiceRequest {
64            uri: self.uri.clone(),
65            method: self.method.clone(),
66            version: self.version,
67            headers: self.headers.clone(),
68            input,
69        }
70    }
71}
72
73impl<T: Message + Default + 'static> From<T> for ServiceRequest<T> {
74    fn from(v: T) -> ServiceRequest<T> {
75        ServiceRequest::new(v)
76    }
77}
78
79impl<T: Message + Default + 'static> ServiceRequest<T> {
80    /// Serialize into a hyper request.
81    pub fn to_hyper_request(&self) -> Result<Request<Body>, ProstTwirpError> {
82        let mut body = Vec::new();
83        self.input
84            .encode(&mut body)
85            .map_err(ProstTwirpError::ProstEncodeError)?;
86        let mut builder = Request::post(self.uri.clone());
87        builder.headers_mut().unwrap().clone_from(&self.headers);
88        builder
89            .header(CONTENT_LENGTH, body.len() as u64)
90            .body(Body::from(body))
91            .map_err(ProstTwirpError::from)
92    }
93
94    pub async fn from_hyper_request(
95        req: Request<Body>,
96    ) -> Result<ServiceRequest<T>, ProstTwirpError> {
97        if req.method() != Method::POST {
98            return Err(ProstTwirpError::InvalidMethod);
99        } else if req
100            .headers()
101            .get(CONTENT_TYPE)
102            .map_or(true, |v| v != PROTOBUF_CONTENT_TYPE)
103        {
104            return Err(ProstTwirpError::InvalidContentType);
105        }
106        let uri = req.uri().clone();
107        let method = req.method().clone();
108        let version = req.version();
109        let headers = req.headers().clone();
110        let body_bytes = hyper::body::to_bytes(req.into_body()).await?;
111        match T::decode(body_bytes.clone()) {
112            Ok(input) => Ok(ServiceRequest {
113                uri,
114                method,
115                version,
116                headers,
117                input,
118            }),
119            Err(err) => Err(ProstTwirpError::AfterBodyError {
120                status: None,
121                method: Some(method),
122                version,
123                headers,
124                err: Box::new(ProstTwirpError::ProstDecodeError(err)),
125                body: body_bytes.to_vec(),
126            }),
127        }
128    }
129}
130
131/// A response with HTTP info and the output object as a protobuf [Message].
132#[derive(Debug)]
133pub struct ServiceResponse<M: Message> {
134    /// The HTTP version
135    pub version: Version,
136    /// The set of headers
137    ///
138    /// Should always at least have `Content-Type`. Servers will override `Content-Length` on serialization.
139    pub headers: HeaderMap,
140    /// The status code
141    pub status: StatusCode,
142    /// The output object
143    pub output: M,
144}
145
146impl<M: Message> ServiceResponse<M> {
147    /// Create new service request with the given input object
148    ///
149    /// This automatically sets the `Content-Type` header as `application/protobuf`.
150    pub fn new(output: M) -> ServiceResponse<M> {
151        let mut headers = HeaderMap::new();
152        headers.insert(
153            CONTENT_TYPE,
154            HeaderValue::from_static(PROTOBUF_CONTENT_TYPE),
155        );
156        ServiceResponse {
157            version: Version::default(),
158            headers,
159            status: StatusCode::OK,
160            output,
161        }
162    }
163
164    /// Copy this response with a different output value
165    pub fn clone_with_output(&self, output: M) -> ServiceResponse<M> {
166        ServiceResponse {
167            version: self.version,
168            headers: self.headers.clone(),
169            status: self.status,
170            output,
171        }
172    }
173}
174
175impl<M: Message + Default + 'static> From<M> for ServiceResponse<M> {
176    fn from(v: M) -> ServiceResponse<M> {
177        ServiceResponse::new(v)
178    }
179}
180
181impl<M: Message + Default> ServiceResponse<M> {
182    /// Deserialze an object response from a hyper response.
183    pub async fn from_hyper_response(resp: Response<Body>) -> Result<Self, ProstTwirpError> {
184        let version = resp.version();
185        let headers = resp.headers().clone();
186        let status = resp.status();
187        let body_bytes = hyper::body::to_bytes(resp.into_body()).await?;
188        let err = if status.is_success() {
189            match M::decode(&*body_bytes) {
190                Ok(output) => {
191                    return Ok(ServiceResponse {
192                        version,
193                        headers,
194                        status,
195                        output,
196                    })
197                }
198                Err(err) => ProstTwirpError::ProstDecodeError(err),
199            }
200        } else {
201            match TwirpError::from_json_bytes(status, &body_bytes) {
202                Ok(err) => ProstTwirpError::TwirpError(err),
203                Err(err) => ProstTwirpError::JsonDecodeError(err),
204            }
205        };
206        Err(ProstTwirpError::AfterBodyError {
207            body: body_bytes.to_vec(),
208            method: None,
209            version,
210            headers,
211            status: Some(status),
212            err: Box::new(err),
213        })
214    }
215
216    /// Serialize an object response into a hyper response.
217    pub fn to_hyper_response(&self) -> Result<Response<Body>, ProstTwirpError> {
218        let body_bytes = self.output.encode_to_vec();
219        let mut builder = Response::builder().status(self.status);
220        builder.headers_mut().unwrap().clone_from(&self.headers);
221        builder
222            .header(CONTENT_LENGTH, body_bytes.len() as u64)
223            .body(body_bytes.into())
224            .map_err(ProstTwirpError::from)
225    }
226}
227
/// A JSON-serializable Twirp error.
///
/// `status` is carried as the HTTP status code. The other fields form the JSON body
/// (`error_type`, `msg` and, when present, `meta`).
#[derive(Debug, Clone)]
pub struct TwirpError {
    /// HTTP status code the error is sent with or was received with; not part of the JSON body.
    pub status: StatusCode,
    /// Error type string, serialized as the JSON `error_type` field.
    pub error_type: String,
    /// Error message, serialized as the JSON `msg` field.
    pub msg: String,
    /// Optional extra data, serialized as the JSON `meta` field only when `Some`.
    pub meta: Option<serde_json::Value>,
}
236
237impl TwirpError {
238    /// Create a Twirp error with no meta
239    pub fn new(status: StatusCode, error_type: &str, msg: &str) -> TwirpError {
240        TwirpError::new_meta(status, error_type, msg, None)
241    }
242
243    /// Create a Twirp error with optional meta
244    pub fn new_meta(
245        status: StatusCode,
246        error_type: &str,
247        msg: &str,
248        meta: Option<serde_json::Value>,
249    ) -> TwirpError {
250        TwirpError {
251            status,
252            error_type: error_type.to_string(),
253            msg: msg.to_string(),
254            meta,
255        }
256    }
257
258    /// Create a hyper response for this error and the given status code
259    pub fn to_hyper_response(&self) -> Response<Body> {
260        let body_bytes = self
261            .to_json_bytes()
262            .unwrap_or_else(|_| "{}".as_bytes().to_vec());
263        let body_len = body_bytes.len() as u64;
264        Response::builder()
265            .status(self.status)
266            .header(CONTENT_TYPE, JSON_CONTENT_TYPE)
267            .header(CONTENT_LENGTH, HeaderValue::from(body_len))
268            .header(ALLOW, HeaderValue::from_static("POST"))
269            .body(Body::from(body_bytes))
270            .expect("failed to serialize twirp error")
271        // The potential panic here is not desirable but it seems highly
272        // unlikely that we fail to serialize a body from a simple string
273        // like this.
274    }
275
276    /// Create error from Serde JSON value
277    pub fn from_json(status: StatusCode, json: serde_json::Value) -> TwirpError {
278        let error_type = json["error_type"].as_str();
279        TwirpError {
280            status,
281            error_type: error_type.unwrap_or("<no code>").to_string(),
282            msg: json["msg"].as_str().unwrap_or("<no message>").to_string(),
283            // Put the whole thing as meta if there was no type
284            meta: if error_type.is_some() {
285                json.get("meta").cloned()
286            } else {
287                Some(json.clone())
288            },
289        }
290    }
291
292    /// Create error from byte array
293    pub fn from_json_bytes(status: StatusCode, json: &[u8]) -> serde_json::Result<TwirpError> {
294        serde_json::from_slice(json).map(|v| TwirpError::from_json(status, v))
295    }
296
297    /// Create Serde JSON value from error
298    pub fn to_json(&self) -> serde_json::Value {
299        let mut props = serde_json::map::Map::new();
300        props.insert(
301            "error_type".to_string(),
302            serde_json::Value::String(self.error_type.clone()),
303        );
304        props.insert(
305            "msg".to_string(),
306            serde_json::Value::String(self.msg.clone()),
307        );
308        if let Some(ref meta) = self.meta {
309            props.insert("meta".to_string(), meta.clone());
310        }
311        serde_json::Value::Object(props)
312    }
313
314    /// Create byte array from error
315    pub fn to_json_bytes(&self) -> serde_json::Result<Vec<u8>> {
316        serde_json::to_vec(&self.to_json())
317    }
318}
319
320impl Error for TwirpError {}
321
322impl Display for TwirpError {
323    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
324        write!(f, "{:?} {}: {}", self.status, self.error_type, self.msg)
325    }
326}
327
328impl From<TwirpError> for ProstTwirpError {
329    fn from(v: TwirpError) -> ProstTwirpError {
330        ProstTwirpError::TwirpError(v)
331    }
332}
333
/// An error that can occur while serving or calling a Twirp service.
///
/// Servers usually turn it into an HTTP response with [ProstTwirpError::into_hyper_response].
#[derive(Debug)]
#[non_exhaustive]
pub enum ProstTwirpError {
    /// A Twirp error with a status, type, message and optional metadata.
    TwirpError(TwirpError),
    /// JSON could not be decoded, e.g. a non-success response body in
    /// [ServiceResponse::from_hyper_response].
    JsonDecodeError(serde_json::Error),
    /// A protobuf message could not be encoded.
    ProstEncodeError(EncodeError),
    /// A protobuf message could not be decoded.
    ProstDecodeError(DecodeError),
    /// An error reported by hyper, e.g. while sending a request or reading a body.
    HyperError(hyper::Error),
    /// An HTTP request or response could not be built.
    HttpError(http::Error),
    /// A URI could not be parsed, e.g. the one [HyperClient::go] builds from the root URL and path.
    InvalidUri(http::uri::InvalidUri),
    /// The HTTP method was not `POST`.
    InvalidMethod,
    /// The request `Content-Type` was missing or not `application/protobuf`.
    InvalidContentType,
    /// No matching method was found for the request.
    NotFound,
    /// Wraps another `ProstTwirpError` raised after a body was read, keeping the HTTP
    /// metadata and raw body of the message that failed.
    AfterBodyError {
        /// The raw request or response body that was read before the error happened.
        body: Vec<u8>,
        /// The request method; set for errors raised while reading a request (server side).
        method: Option<Method>,
        /// The HTTP version of the request or response.
        version: Version,
        /// The headers of the request or response.
        headers: HeaderMap,
        /// The response status; set for errors raised while reading a response (client side).
        status: Option<StatusCode>,
        /// The underlying error.
        err: Box<ProstTwirpError>,
    },
}
374
375impl ProstTwirpError {
376    /// This same error, or the underlying error if it is an `AfterBodyError`
377    pub fn root_err(self) -> ProstTwirpError {
378        match self {
379            ProstTwirpError::AfterBodyError { err, .. } => err.root_err(),
380            _ => self,
381        }
382    }
383
384    pub fn into_hyper_response(self) -> Result<Response<Body>, hyper::Error> {
385        let external_err = match self {
386            ProstTwirpError::TwirpError(err) => err,
387            // Just propagate hyper errors
388            ProstTwirpError::HyperError(err) => return Err(err),
389            ProstTwirpError::InvalidMethod => TwirpError::new(
390                StatusCode::METHOD_NOT_ALLOWED,
391                "bad_method",
392                "Method must be POST",
393            ),
394            ProstTwirpError::ProstDecodeError(_) => TwirpError::new(
395                StatusCode::BAD_REQUEST,
396                "protobuf_decode_err",
397                "Invalid protobuf body",
398            ),
399            ProstTwirpError::InvalidContentType => TwirpError::new(
400                StatusCode::UNSUPPORTED_MEDIA_TYPE,
401                "bad_content_type",
402                "Content type must be application/protobuf",
403            ),
404            ProstTwirpError::NotFound => TwirpError::new(
405                StatusCode::NOT_FOUND,
406                "not_found",
407                "The requested method was not found",
408            ),
409            _ => TwirpError::new(
410                StatusCode::INTERNAL_SERVER_ERROR,
411                "internal_err",
412                "Internal error",
413            ),
414        };
415        Ok(external_err.to_hyper_response())
416    }
417}
418
419impl From<hyper::Error> for ProstTwirpError {
420    fn from(v: hyper::Error) -> ProstTwirpError {
421        ProstTwirpError::HyperError(v)
422    }
423}
424
425impl From<http::Error> for ProstTwirpError {
426    fn from(v: http::Error) -> ProstTwirpError {
427        ProstTwirpError::HttpError(v)
428    }
429}
430
431impl Display for ProstTwirpError {
432    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
433        write!(f, "{:?}", self)
434    }
435}
436
437impl Error for ProstTwirpError {
438    fn source(&self) -> Option<&(dyn Error + 'static)> {
439        match self {
440            ProstTwirpError::TwirpError(err) => Some(err),
441            ProstTwirpError::JsonDecodeError(err) => Some(err),
442            ProstTwirpError::ProstEncodeError(err) => Some(err),
443            ProstTwirpError::ProstDecodeError(err) => Some(err),
444            ProstTwirpError::HyperError(err) => Some(err),
445            ProstTwirpError::HttpError(err) => Some(err),
446            ProstTwirpError::InvalidUri(err) => Some(err),
447            ProstTwirpError::InvalidMethod => None,
448            ProstTwirpError::InvalidContentType => None,
449            ProstTwirpError::NotFound => None,
450            ProstTwirpError::AfterBodyError { err, .. } => Some(err),
451        }
452    }
453}
454
/// A Twirp client wrapping a hyper [Client] that uses [HttpConnector].
#[derive(Debug)]
pub struct HyperClient {
    /// The hyper client used to send every request.
    pub client: Client<HttpConnector>,
    /// The root URL without any path attached; method paths are appended after a single `/`.
    /// [HyperClient::new] strips trailing `/` characters from it.
    pub root_url: String,
}
463
464impl HyperClient {
465    /// Create a new client wrapper for the given client and root using protobuf
466    pub fn new(client: Client<HttpConnector>, root_url: &str) -> HyperClient {
467        HyperClient {
468            client,
469            root_url: root_url.trim_end_matches('/').to_string(),
470        }
471    }
472
473    /// Invoke the given request for the given path and return a boxed future result
474    pub fn go<I, O>(&self, path: &str, req: ServiceRequest<I>) -> PTRes<O>
475    where
476        I: Message + Default + 'static,
477        O: Message + Default + 'static,
478    {
479        // Build the URI
480        let uri = match format!("{}/{}", self.root_url, path.trim_start_matches('/')).parse() {
481            Err(err) => return Box::pin(ready(Err(ProstTwirpError::InvalidUri(err)))),
482            Ok(v) => v,
483        };
484        // Build the request
485        let mut hyper_req = match req.to_hyper_request() {
486            Err(err) => return Box::pin(ready(Err(err))),
487            Ok(v) => v,
488        };
489        *hyper_req.uri_mut() = uri;
490        // Run the request and map the response
491        Box::pin(
492            self.client
493                .request(hyper_req)
494                .map_err(ProstTwirpError::HyperError)
495                .and_then(ServiceResponse::from_hyper_response),
496        )
497    }
498}
499
/// A trait for the heart of a Twirp service: responding to every service method.
///
/// Implementations are responsible for:
///
/// 1. Matching the URL to a service method (or returning a 404, e.g. via
///    [ProstTwirpError::NotFound]).
/// 2. Decoding the request body into a protobuf message, typically
///    using [ServiceRequest::from_hyper_request] for the appropriate
///    message.
/// 3. Calling the application logic to handle the request.
/// 4. Encoding the response into a protobuf message, typically using
///    [ServiceResponse::to_hyper_response].
///
/// An implementation of this trait is generated by the `service-gen`
/// integration, or it can be implemented manually.
pub trait HyperService {
    /// Accept a raw service request and return a boxed future of a raw service response.
    ///
    /// The boxed `dyn Future` has an implicit `'static` bound, so it cannot borrow `self`;
    /// clone whatever state the future needs. When wrapped in [HyperServer], an `Err` is
    /// turned into a response with [ProstTwirpError::into_hyper_response].
    fn handle(
        &self,
        req: Request<Body>,
    ) -> Pin<Box<dyn Future<Output = Result<Response<Body>, ProstTwirpError>> + Send>>;
}
521
522/// A wrapper for a [HyperService] trait that keeps a [Arc] version of the
523/// service.
524///
525/// This layer checkcs preconditions of the request (the method and content
526/// type) and translates any errors into the Twirp json format.
527///
528/// TODO: Perhaps a clearer name indicating this is a layer?
529///
530/// TODO: Perhaps change to a Tower `Layer`, although that would require
531/// another dependency on `tower_layer`.
532pub struct HyperServer<T: HyperService + Send + Sync + 'static> {
533    /// The `Arc` version of the service
534    ///
535    /// Needed because of [hyper Service lifetimes](https://github.com/tokio-rs/tokio-service/issues/9)
536    pub service: Arc<T>,
537}
538
539impl<T: HyperService + Send + Sync + 'static> HyperServer<T> {
540    /// Create a new service wrapper for the given impl
541    pub fn new(service: T) -> HyperServer<T> {
542        HyperServer {
543            service: Arc::new(service),
544        }
545    }
546}
547
548impl<T: 'static + HyperService + Send + Sync> Service<Request<Body>> for HyperServer<T> {
549    type Response = Response<Body>;
550    type Error = hyper::Error;
551    type Future = Pin<Box<dyn (Future<Output = Result<Self::Response, Self::Error>>) + Send>>;
552
553    fn poll_ready(&mut self, _context: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
554        Poll::Ready(Ok(()))
555    }
556
557    fn call(&mut self, req: Request<Body>) -> Self::Future {
558        // Ug: https://github.com/tokio-rs/tokio-service/issues/9
559        let service = self.service.clone();
560        Box::pin(
561            service
562                .handle(req)
563                .or_else(|err| future::ready(err.into_hyper_response())),
564        )
565    }
566}