1use std::{borrow::Cow, collections::HashMap, time::Duration};
2
3use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
4use http::{
5 header,
6 uri::{Authority, Scheme},
7 HeaderMap, Method, Uri,
8};
9
10use crate::{
11 common::{
12 streaming_message_codec, unary_message_codec, CONNECT_ACCEPT_ENCODING,
13 CONNECT_CONTENT_ENCODING, CONNECT_PROTOCOL_VERSION, CONNECT_TIMEOUT_MS, PROTOCOL_VERSION_1,
14 STREAMING_CONTENT_TYPE_PREFIX,
15 },
16 metadata::Metadata,
17 Error,
18};
19
20pub mod builder;
21
22pub trait ConnectRequest {
24 fn connect_protocol_version(&self) -> Option<&str>;
26
27 fn scheme(&self) -> Option<&Scheme>;
29
30 fn authority(&self) -> Option<&Authority>;
32
33 fn path(&self) -> &str;
35
36 fn protobuf_rpc_parts(&self) -> Option<(&str, &str, &str)> {
41 let (prefix, method) = self.path().rsplit_once('/')?;
42 let (routing_prefix, service) = prefix.rsplit_once('/')?;
43 Some((routing_prefix, service, method))
44 }
45
46 fn message_codec(&self) -> Result<&str, Error>;
48
49 fn timeout(&self) -> Option<Duration>;
51
52 fn content_encoding(&self) -> Option<&str>;
54
55 fn accept_encoding(&self) -> impl Iterator<Item = &str>;
57
58 fn metadata(&self) -> &impl Metadata;
60
61 fn validate(&self) -> Result<(), Error>;
63}
64
65pub enum ConnectRequestType<T> {
67 Unary(UnaryRequest<T>),
68 Streaming(StreamingRequest<T>),
69 UnaryGet(UnaryGetRequest),
70}
71
72impl<T> ConnectRequestType<T> {
73 pub fn from_http(req: http::Request<T>) -> Self {
74 if req.method() == Method::GET {
75 Self::UnaryGet(req.map(|_| ()).into())
76 } else if req.headers().get(header::CONTENT_TYPE).is_some_and(|ct| {
77 ct.to_str()
78 .unwrap_or_default()
79 .starts_with(STREAMING_CONTENT_TYPE_PREFIX)
80 }) {
81 Self::Streaming(req.into())
82 } else {
83 Self::Unary(req.into())
84 }
85 }
86}
87
88trait HttpConnectRequest {
90 fn http_uri(&self) -> &Uri;
91
92 fn http_headers(&self) -> &HeaderMap;
93
94 fn http_message_codec(&self) -> Result<&str, Error>;
95
96 fn http_connect_protocol_version(&self) -> Option<&str> {
97 self.http_headers()
98 .get(CONNECT_PROTOCOL_VERSION)?
99 .to_str()
100 .ok()
101 }
102
103 fn http_content_encoding(&self) -> Option<&str>;
104
105 fn http_accept_encoding(&self) -> impl Iterator<Item = &str> {
106 self.http_headers()
107 .get_all(header::ACCEPT_ENCODING)
108 .into_iter()
109 .filter_map(|val| val.to_str().ok())
110 }
111
112 fn http_validate(&self) -> Result<(), Error>
113 where
114 Self: Sized,
115 {
116 validate_request(self)
117 }
118}
119
120fn validate_request(req: &impl HttpConnectRequest) -> Result<(), Error> {
121 match req.http_connect_protocol_version() {
122 None => (),
123 Some(ver) if ver == PROTOCOL_VERSION_1 => (),
124 Some(ver) => {
125 return Err(Error::InvalidRequest(format!(
126 "unknown connect-protocol-version {ver:?}"
127 )));
128 }
129 }
130 let _ = req.http_message_codec()?;
131 Ok(())
132}
133
134impl<T: HttpConnectRequest> ConnectRequest for T {
135 fn connect_protocol_version(&self) -> Option<&str> {
136 HttpConnectRequest::http_connect_protocol_version(self)
137 }
138
139 fn scheme(&self) -> Option<&Scheme> {
140 self.http_uri().scheme()
141 }
142
143 fn authority(&self) -> Option<&Authority> {
144 self.http_uri().authority()
145 }
146
147 fn path(&self) -> &str {
148 self.http_uri().path()
149 }
150
151 fn message_codec(&self) -> Result<&str, Error> {
152 self.http_message_codec()
153 }
154
155 fn timeout(&self) -> Option<Duration> {
156 let timeout_ms: u64 = self
157 .http_headers()
158 .get(CONNECT_TIMEOUT_MS)?
159 .to_str()
160 .ok()?
161 .parse()
162 .ok()?;
163 Some(Duration::from_millis(timeout_ms))
164 }
165
166 fn content_encoding(&self) -> Option<&str> {
167 self.http_content_encoding()
168 }
169
170 fn accept_encoding(&self) -> impl Iterator<Item = &str> {
171 self.http_accept_encoding()
172 }
173
174 fn metadata(&self) -> &impl Metadata {
175 self.http_headers()
176 }
177
178 fn validate(&self) -> Result<(), Error> {
179 self.http_validate()
180 }
181}
182
183pub struct UnaryRequest<T>(http::Request<T>);
185
186impl<T> HttpConnectRequest for UnaryRequest<T> {
187 fn http_uri(&self) -> &Uri {
188 self.0.uri()
189 }
190
191 fn http_headers(&self) -> &HeaderMap {
192 self.0.headers()
193 }
194
195 fn http_message_codec(&self) -> Result<&str, Error> {
196 unary_message_codec(self.http_headers())
197 }
198
199 fn http_content_encoding(&self) -> Option<&str> {
200 self.http_headers()
201 .get(header::CONTENT_ENCODING)?
202 .to_str()
203 .ok()
204 }
205}
206
207impl<T> From<http::Request<T>> for UnaryRequest<T> {
208 fn from(req: http::Request<T>) -> Self {
209 Self(req)
210 }
211}
212
213impl<T> From<UnaryRequest<T>> for http::Request<T> {
214 fn from(req: UnaryRequest<T>) -> Self {
215 req.0
216 }
217}
218
219pub struct StreamingRequest<T>(http::Request<T>);
221
222impl<T> HttpConnectRequest for StreamingRequest<T> {
223 fn http_uri(&self) -> &Uri {
224 self.0.uri()
225 }
226
227 fn http_headers(&self) -> &HeaderMap {
228 self.0.headers()
229 }
230
231 fn http_message_codec(&self) -> Result<&str, Error> {
232 streaming_message_codec(self.http_headers())
233 }
234
235 fn http_content_encoding(&self) -> Option<&str> {
236 self.http_headers()
237 .get(CONNECT_CONTENT_ENCODING)?
238 .to_str()
239 .ok()
240 }
241
242 fn http_accept_encoding(&self) -> impl Iterator<Item = &str> {
243 self.http_headers()
244 .get_all(CONNECT_ACCEPT_ENCODING)
245 .into_iter()
246 .filter_map(|val| val.to_str().ok())
247 }
248}
249
250impl<T> From<http::Request<T>> for StreamingRequest<T> {
251 fn from(req: http::Request<T>) -> Self {
252 Self(req)
253 }
254}
255
256impl<T> From<StreamingRequest<T>> for http::Request<T> {
257 fn from(req: StreamingRequest<T>) -> Self {
258 req.0
259 }
260}
261
262pub struct UnaryGetRequest {
264 inner: http::Request<()>,
265 query: HashMap<String, String>,
266}
267
268impl UnaryGetRequest {
269 pub fn message(&self) -> Result<Cow<[u8]>, Error> {
270 let message = self
271 .query
272 .get("message")
273 .ok_or(Error::invalid_request("missing message"))?;
274 let is_b64 = self.query.get("base64").map(|s| s.as_str()) == Some("1");
275 if is_b64 {
276 Ok(BASE64_URL_SAFE_NO_PAD.decode(message)?.into())
277 } else {
278 Ok(
279 match percent_encoding::percent_decode_str(message)
280 .decode_utf8()
281 .map_err(|_| Error::invalid_request("message not valid utf8"))?
282 {
283 Cow::Borrowed(s) => s.as_bytes().into(),
284 Cow::Owned(s) => s.into_bytes().into(),
285 },
286 )
287 }
288 }
289}
290
291impl HttpConnectRequest for UnaryGetRequest {
292 fn http_uri(&self) -> &Uri {
293 self.inner.uri()
294 }
295
296 fn http_headers(&self) -> &HeaderMap {
297 self.inner.headers()
298 }
299
300 fn http_message_codec(&self) -> Result<&str, Error> {
301 self.query
302 .get("encoding")
303 .map(|s| s.as_str())
304 .ok_or(Error::invalid_request("missing 'encoding' param"))
305 }
306
307 fn http_connect_protocol_version(&self) -> Option<&str> {
308 self.query.get("connect")?.strip_prefix("v")
309 }
310
311 fn http_content_encoding(&self) -> Option<&str> {
312 self.query.get("encoding").map(|s| s.as_str())
313 }
314
315 fn http_validate(&self) -> Result<(), Error>
316 where
317 Self: Sized,
318 {
319 validate_request(self)?;
320 if !self.query.contains_key("message") {
321 return Err(Error::invalid_request("missing 'message' param"));
322 }
323 Ok(())
324 }
325}
326
327impl From<http::Request<()>> for UnaryGetRequest {
328 fn from(req: http::Request<()>) -> Self {
329 let query: HashMap<_, _> =
330 form_urlencoded::parse(req.uri().query().unwrap_or_default().as_bytes())
331 .map(|(k, v)| (k.to_string(), v.to_string()))
332 .collect();
333 Self { inner: req, query }
334 }
335}
336
337impl From<UnaryGetRequest> for http::Request<()> {
338 fn from(req: UnaryGetRequest) -> Self {
339 req.inner
340 }
341}