livekit_api/services/twirp_client.rs

1// Copyright 2023 LiveKit, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use http::{
18    header::{HeaderMap, HeaderValue, CONTENT_TYPE},
19    StatusCode,
20};
21use serde::Deserialize;
22use thiserror::Error;
23
24use crate::http_client;
25
/// Route prefix placed in front of `{pkg}.{service}/{method}` when a client is built
/// without an explicit prefix.
pub const DEFAULT_PREFIX: &str = "/twirp";
27
28#[derive(Debug, Error)]
29pub enum TwirpError {
30    #[cfg(feature = "services-tokio")]
31    #[error("failed to execute the request: {0}")]
32    Request(#[from] reqwest::Error),
33    #[cfg(feature = "services-async")]
34    #[error("failed to execute the request: {0}")]
35    Request(#[from] std::io::Error),
36    #[error("twirp error: {0}")]
37    Twirp(TwirpErrorCode),
38    #[error("url error: {0}")]
39    Url(#[from] url::ParseError),
40    #[error("prost error: {0}")]
41    Prost(#[from] prost::DecodeError),
42}
43
44#[derive(Debug, Deserialize)]
45pub struct TwirpErrorCode {
46    pub code: String,
47    pub msg: String,
48}
49
50impl TwirpErrorCode {
51    pub const CANCELED: &'static str = "canceled";
52    pub const UNKNOWN: &'static str = "unknown";
53    pub const INVALID_ARGUMENT: &'static str = "invalid_argument";
54    pub const MALFORMED: &'static str = "malformed";
55    pub const DEADLINE_EXCEEDED: &'static str = "deadline_exceeded";
56    pub const NOT_FOUND: &'static str = "not_found";
57    pub const BAD_ROUTE: &'static str = "bad_route";
58    pub const ALREADY_EXISTS: &'static str = "already_exists";
59    pub const PERMISSION_DENIED: &'static str = "permission_denied";
60    pub const UNAUTHENTICATED: &'static str = "unauthenticated";
61    pub const RESOURCE_EXHAUSTED: &'static str = "resource_exhausted";
62    pub const FAILED_PRECONDITION: &'static str = "failed_precondition";
63    pub const ABORTED: &'static str = "aborted";
64    pub const OUT_OF_RANGE: &'static str = "out_of_range";
65    pub const UNIMPLEMENTED: &'static str = "unimplemented";
66    pub const INTERNAL: &'static str = "internal";
67    pub const UNAVAILABLE: &'static str = "unavailable";
68    pub const DATA_LOSS: &'static str = "dataloss";
69}
70
71impl Display for TwirpErrorCode {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        write!(f, "{}: {}", self.code, self.msg)
74    }
75}
76
77pub type TwirpResult<T> = Result<T, TwirpError>;
78
79#[derive(Debug)]
80pub struct TwirpClient {
81    host: String,
82    pkg: String,
83    prefix: String,
84    client: http_client::Client,
85}
86
87impl TwirpClient {
88    pub fn new(host: &str, pkg: &str, prefix: Option<&str>) -> Self {
89        Self {
90            host: host.to_owned(),
91            pkg: pkg.to_owned(),
92            prefix: prefix.unwrap_or(DEFAULT_PREFIX).to_owned(),
93            client: http_client::Client::new(),
94        }
95    }
96
97    pub async fn request<D: prost::Message, R: prost::Message + Default>(
98        &self,
99        service: &str,
100        method: &str,
101        data: D,
102        mut headers: HeaderMap,
103    ) -> TwirpResult<R> {
104        let mut url = url::Url::parse(&self.host)?;
105
106        if let Ok(mut segs) = url.path_segments_mut() {
107            segs.push(&format!("{}/{}.{}/{}", self.prefix, self.pkg, service, method));
108        }
109
110        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/protobuf"));
111
112        let resp = self.client.post(url).headers(headers).body(data.encode_to_vec()).send().await?;
113
114        if resp.status() == StatusCode::OK {
115            Ok(R::decode(resp.bytes().await?)?)
116        } else {
117            let err: TwirpErrorCode = resp.json().await?;
118            Err(TwirpError::Twirp(err))
119        }
120    }
121}