connect_rpc/
request.rs

1use std::{borrow::Cow, collections::HashMap, time::Duration};
2
3use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
4use http::{
5    header,
6    uri::{Authority, Scheme},
7    HeaderMap, Method, Uri,
8};
9
10use crate::{
11    common::{
12        streaming_message_codec, unary_message_codec, CONNECT_ACCEPT_ENCODING,
13        CONNECT_CONTENT_ENCODING, CONNECT_PROTOCOL_VERSION, CONNECT_TIMEOUT_MS, PROTOCOL_VERSION_1,
14        STREAMING_CONTENT_TYPE_PREFIX,
15    },
16    metadata::Metadata,
17    Error,
18};
19
20pub mod builder;
21
22/// A Connect request.
23pub trait ConnectRequest {
24    /// Returns the connect protocol version.
25    fn connect_protocol_version(&self) -> Option<&str>;
26
27    /// Returns the URI scheme.
28    fn scheme(&self) -> Option<&Scheme>;
29
30    /// Returns the URI authority.
31    fn authority(&self) -> Option<&Authority>;
32
33    /// Returns the URI path.
34    fn path(&self) -> &str;
35
36    /// Splits a protobuf RPC request path into routing prefix, service name,
37    /// and method name.
38    ///
39    /// Returns `None` if the request path does not contain a `/`.
40    fn protobuf_rpc_parts(&self) -> Option<(&str, &str, &str)> {
41        let (prefix, method) = self.path().rsplit_once('/')?;
42        let (routing_prefix, service) = prefix.rsplit_once('/')?;
43        Some((routing_prefix, service, method))
44    }
45
46    /// Returns the message codec.
47    fn message_codec(&self) -> Result<&str, Error>;
48
49    /// Returns the timeout.
50    fn timeout(&self) -> Option<Duration>;
51
52    /// Returns the content encoding (e.g. compression).
53    fn content_encoding(&self) -> Option<&str>;
54
55    /// Returns the accept encoding(s).
56    fn accept_encoding(&self) -> impl Iterator<Item = &str>;
57
58    /// Returns the metadata.
59    fn metadata(&self) -> &impl Metadata;
60
61    /// Validates the request.
62    fn validate(&self) -> Result<(), Error>;
63}
64
65/// Connect request types.
66pub enum ConnectRequestType<T> {
67    Unary(UnaryRequest<T>),
68    Streaming(StreamingRequest<T>),
69    UnaryGet(UnaryGetRequest),
70}
71
72impl<T> ConnectRequestType<T> {
73    pub fn from_http(req: http::Request<T>) -> Self {
74        if req.method() == Method::GET {
75            Self::UnaryGet(req.map(|_| ()).into())
76        } else if req.headers().get(header::CONTENT_TYPE).is_some_and(|ct| {
77            ct.to_str()
78                .unwrap_or_default()
79                .starts_with(STREAMING_CONTENT_TYPE_PREFIX)
80        }) {
81            Self::Streaming(req.into())
82        } else {
83            Self::Unary(req.into())
84        }
85    }
86}
87
88/// A [`ConnectRequest`] backed by an [`http::Request`]
89trait HttpConnectRequest {
90    fn http_uri(&self) -> &Uri;
91
92    fn http_headers(&self) -> &HeaderMap;
93
94    fn http_message_codec(&self) -> Result<&str, Error>;
95
96    fn http_connect_protocol_version(&self) -> Option<&str> {
97        self.http_headers()
98            .get(CONNECT_PROTOCOL_VERSION)?
99            .to_str()
100            .ok()
101    }
102
103    fn http_content_encoding(&self) -> Option<&str>;
104
105    fn http_accept_encoding(&self) -> impl Iterator<Item = &str> {
106        self.http_headers()
107            .get_all(header::ACCEPT_ENCODING)
108            .into_iter()
109            .filter_map(|val| val.to_str().ok())
110    }
111
112    fn http_validate(&self) -> Result<(), Error>
113    where
114        Self: Sized,
115    {
116        validate_request(self)
117    }
118}
119
120fn validate_request(req: &impl HttpConnectRequest) -> Result<(), Error> {
121    match req.http_connect_protocol_version() {
122        None => (),
123        Some(ver) if ver == PROTOCOL_VERSION_1 => (),
124        Some(ver) => {
125            return Err(Error::InvalidRequest(format!(
126                "unknown connect-protocol-version {ver:?}"
127            )));
128        }
129    }
130    let _ = req.http_message_codec()?;
131    Ok(())
132}
133
134impl<T: HttpConnectRequest> ConnectRequest for T {
135    fn connect_protocol_version(&self) -> Option<&str> {
136        HttpConnectRequest::http_connect_protocol_version(self)
137    }
138
139    fn scheme(&self) -> Option<&Scheme> {
140        self.http_uri().scheme()
141    }
142
143    fn authority(&self) -> Option<&Authority> {
144        self.http_uri().authority()
145    }
146
147    fn path(&self) -> &str {
148        self.http_uri().path()
149    }
150
151    fn message_codec(&self) -> Result<&str, Error> {
152        self.http_message_codec()
153    }
154
155    fn timeout(&self) -> Option<Duration> {
156        let timeout_ms: u64 = self
157            .http_headers()
158            .get(CONNECT_TIMEOUT_MS)?
159            .to_str()
160            .ok()?
161            .parse()
162            .ok()?;
163        Some(Duration::from_millis(timeout_ms))
164    }
165
166    fn content_encoding(&self) -> Option<&str> {
167        self.http_content_encoding()
168    }
169
170    fn accept_encoding(&self) -> impl Iterator<Item = &str> {
171        self.http_accept_encoding()
172    }
173
174    fn metadata(&self) -> &impl Metadata {
175        self.http_headers()
176    }
177
178    fn validate(&self) -> Result<(), Error> {
179        self.http_validate()
180    }
181}
182
183/// A Connect unary request.
184pub struct UnaryRequest<T>(http::Request<T>);
185
186impl<T> HttpConnectRequest for UnaryRequest<T> {
187    fn http_uri(&self) -> &Uri {
188        self.0.uri()
189    }
190
191    fn http_headers(&self) -> &HeaderMap {
192        self.0.headers()
193    }
194
195    fn http_message_codec(&self) -> Result<&str, Error> {
196        unary_message_codec(self.http_headers())
197    }
198
199    fn http_content_encoding(&self) -> Option<&str> {
200        self.http_headers()
201            .get(header::CONTENT_ENCODING)?
202            .to_str()
203            .ok()
204    }
205}
206
207impl<T> From<http::Request<T>> for UnaryRequest<T> {
208    fn from(req: http::Request<T>) -> Self {
209        Self(req)
210    }
211}
212
213impl<T> From<UnaryRequest<T>> for http::Request<T> {
214    fn from(req: UnaryRequest<T>) -> Self {
215        req.0
216    }
217}
218
219/// A Connect streaming request.
220pub struct StreamingRequest<T>(http::Request<T>);
221
222impl<T> HttpConnectRequest for StreamingRequest<T> {
223    fn http_uri(&self) -> &Uri {
224        self.0.uri()
225    }
226
227    fn http_headers(&self) -> &HeaderMap {
228        self.0.headers()
229    }
230
231    fn http_message_codec(&self) -> Result<&str, Error> {
232        streaming_message_codec(self.http_headers())
233    }
234
235    fn http_content_encoding(&self) -> Option<&str> {
236        self.http_headers()
237            .get(CONNECT_CONTENT_ENCODING)?
238            .to_str()
239            .ok()
240    }
241
242    fn http_accept_encoding(&self) -> impl Iterator<Item = &str> {
243        self.http_headers()
244            .get_all(CONNECT_ACCEPT_ENCODING)
245            .into_iter()
246            .filter_map(|val| val.to_str().ok())
247    }
248}
249
250impl<T> From<http::Request<T>> for StreamingRequest<T> {
251    fn from(req: http::Request<T>) -> Self {
252        Self(req)
253    }
254}
255
256impl<T> From<StreamingRequest<T>> for http::Request<T> {
257    fn from(req: StreamingRequest<T>) -> Self {
258        req.0
259    }
260}
261
262/// A Connect unary GET request.
263pub struct UnaryGetRequest {
264    inner: http::Request<()>,
265    query: HashMap<String, String>,
266}
267
268impl UnaryGetRequest {
269    pub fn message(&self) -> Result<Cow<[u8]>, Error> {
270        let message = self
271            .query
272            .get("message")
273            .ok_or(Error::invalid_request("missing message"))?;
274        let is_b64 = self.query.get("base64").map(|s| s.as_str()) == Some("1");
275        if is_b64 {
276            Ok(BASE64_URL_SAFE_NO_PAD.decode(message)?.into())
277        } else {
278            Ok(
279                match percent_encoding::percent_decode_str(message)
280                    .decode_utf8()
281                    .map_err(|_| Error::invalid_request("message not valid utf8"))?
282                {
283                    Cow::Borrowed(s) => s.as_bytes().into(),
284                    Cow::Owned(s) => s.into_bytes().into(),
285                },
286            )
287        }
288    }
289}
290
291impl HttpConnectRequest for UnaryGetRequest {
292    fn http_uri(&self) -> &Uri {
293        self.inner.uri()
294    }
295
296    fn http_headers(&self) -> &HeaderMap {
297        self.inner.headers()
298    }
299
300    fn http_message_codec(&self) -> Result<&str, Error> {
301        self.query
302            .get("encoding")
303            .map(|s| s.as_str())
304            .ok_or(Error::invalid_request("missing 'encoding' param"))
305    }
306
307    fn http_connect_protocol_version(&self) -> Option<&str> {
308        self.query.get("connect")?.strip_prefix("v")
309    }
310
311    fn http_content_encoding(&self) -> Option<&str> {
312        self.query.get("encoding").map(|s| s.as_str())
313    }
314
315    fn http_validate(&self) -> Result<(), Error>
316    where
317        Self: Sized,
318    {
319        validate_request(self)?;
320        if !self.query.contains_key("message") {
321            return Err(Error::invalid_request("missing 'message' param"));
322        }
323        Ok(())
324    }
325}
326
327impl From<http::Request<()>> for UnaryGetRequest {
328    fn from(req: http::Request<()>) -> Self {
329        let query: HashMap<_, _> =
330            form_urlencoded::parse(req.uri().query().unwrap_or_default().as_bytes())
331                .map(|(k, v)| (k.to_string(), v.to_string()))
332                .collect();
333        Self { inner: req, query }
334    }
335}
336
337impl From<UnaryGetRequest> for http::Request<()> {
338    fn from(req: UnaryGetRequest) -> Self {
339        req.inner
340    }
341}