1use http::{
2 header,
3 uri::{Authority, Parts, PathAndQuery, Scheme},
4 HeaderMap, HeaderName, HeaderValue, Method, Request, Uri,
5};
6
7use base64::{engine::general_purpose::URL_SAFE_NO_PAD as BASE64_URL_SAFE, Engine};
8
9use crate::{
10 common::{
11 is_valid_http_token, CONNECT_ACCEPT_ENCODING, CONNECT_CONTENT_ENCODING,
12 CONNECT_PROTOCOL_VERSION, CONNECT_TIMEOUT_MS, CONTENT_TYPE_PREFIX, PROTOCOL_VERSION_1,
13 },
14 metadata::Metadata,
15 Error,
16};
17
18use super::{StreamingRequest, UnaryGetRequest, UnaryRequest};
19
/// Builder for Connect requests.
///
/// Configure the target (scheme/authority/path), metadata, message codec,
/// timeout and encodings, then finish with [`RequestBuilder::unary`],
/// [`RequestBuilder::streaming`] or [`RequestBuilder::unary_get`].
#[derive(Debug, Default)]
pub struct RequestBuilder {
    /// URI scheme, combined with `authority` and `path` when the URI is built.
    scheme: Option<Scheme>,
    /// URI authority, combined with `scheme` and `path` when the URI is built.
    authority: Option<Authority>,
    /// Request path without any query string; set by `path` (which adds a
    /// leading `/` and rejects `?`) or by `uri` (which discards the query).
    path: Option<String>,
    /// Request metadata; becomes the base set of request headers.
    metadata: HeaderMap,
    /// Lowercased, token-validated message codec name. Used for the
    /// Content-Type of POST requests and the `encoding` query param of GET.
    message_codec: Option<String>,
    /// Pre-rendered value for the `CONNECT_TIMEOUT_MS` header.
    timeout_ms: Option<HeaderValue>,
    /// Token-validated content-encoding (compression) name for the request.
    content_encoding: Option<String>,
    /// Accepted encodings; each is appended as a separate header value.
    accept_encoding: Vec<HeaderValue>,
}
31
32impl RequestBuilder {
33 pub fn scheme(
37 mut self,
38 scheme: impl TryInto<Scheme, Error: Into<Error>>,
39 ) -> Result<Self, Error> {
40 self.scheme = Some(scheme.try_into().map_err(Into::into)?);
41 Ok(self)
42 }
43
44 pub fn authority(
46 mut self,
47 authority: impl TryInto<Authority, Error: Into<Error>>,
48 ) -> Result<Self, Error> {
49 self.authority = Some(authority.try_into().map_err(Into::into)?);
50 Ok(self)
51 }
52
53 pub fn path(mut self, path: impl Into<String>) -> Result<Self, Error> {
59 let mut path = path.into();
60 if path.contains('?') {
61 return Err(Error::invalid_request(
62 "path may not contain query params ('?')",
63 ));
64 }
65 if !path.starts_with('/') {
66 path = format!("/{path}");
67 }
68 self.path = Some(path);
69 Ok(self)
70 }
71
72 pub fn protobuf_rpc(
76 self,
77 full_service_name: impl AsRef<str>,
78 method_name: impl AsRef<str>,
79 ) -> Result<Self, Error> {
80 self.path(format!(
81 "/{}/{}",
82 full_service_name.as_ref(),
83 method_name.as_ref()
84 ))
85 }
86
87 pub fn protobuf_rpc_with_routing_prefix(
90 self,
91 routing_prefix: impl Into<String>,
92 full_service_name: impl AsRef<str>,
93 method_name: impl AsRef<str>,
94 ) -> Result<Self, Error> {
95 let mut routing_prefix = routing_prefix.into();
96 if !routing_prefix.ends_with('/') {
97 routing_prefix = format!("{routing_prefix}/");
98 }
99 self.path(format!(
100 "{routing_prefix}{}/{}",
101 full_service_name.as_ref(),
102 method_name.as_ref()
103 ))
104 }
105
106 pub fn uri(mut self, uri: impl TryInto<Uri, Error: Into<Error>>) -> Result<Self, Error> {
110 let uri: Uri = uri.try_into().map_err(Into::into)?;
111 let Parts {
112 scheme,
113 authority,
114 path_and_query,
115 ..
116 } = uri.into_parts();
117 self.scheme = scheme;
118 self.authority = authority;
119 self.path = path_and_query.map(|paq| paq.path().to_string());
120 Ok(self)
121 }
122
123 pub fn ascii_metadata(
125 mut self,
126 key: impl TryInto<HeaderName, Error: Into<Error>>,
127 val: impl Into<String>,
128 ) -> Result<Self, Error> {
129 self.metadata.append_ascii(key, val)?;
130 Ok(self)
131 }
132
133 pub fn binary_metadata(
135 mut self,
136 key: impl TryInto<HeaderName, Error: Into<Error>>,
137 val: impl AsRef<[u8]>,
138 ) -> Result<Self, Error> {
139 self.metadata.append_binary(key, val)?;
140 Ok(self)
141 }
142
143 pub fn message_codec(mut self, message_codec: impl Into<String>) -> Result<Self, Error> {
151 let mut message_codec: String = message_codec.into();
152 message_codec.make_ascii_lowercase();
153 if !is_valid_http_token(&message_codec) {
154 return Err(Error::invalid_request("invalid message codec"));
155 }
156 self.message_codec = Some(message_codec);
157 Ok(self)
158 }
159
160 pub fn timeout_ms(mut self, timeout_ms: u64) -> Result<Self, Error> {
162 let timeout = timeout_ms.to_string();
164 if timeout.len() > 10 {
165 return Err(Error::invalid_request("timeout too large"));
166 }
167 self.timeout_ms = Some(timeout.try_into().unwrap());
168 Ok(self)
169 }
170
171 pub fn clear_timeout(mut self) -> Self {
173 self.timeout_ms = None;
174 self
175 }
176
177 pub fn content_encoding(mut self, content_encoding: impl Into<String>) -> Result<Self, Error> {
179 let content_encoding = content_encoding.into();
180 if !is_valid_http_token(&content_encoding) {
181 return Err(Error::invalid_request("invalid content encoding"));
182 }
183 self.content_encoding = Some(content_encoding);
184 Ok(self)
185 }
186
187 pub fn accept_encoding<T: TryInto<HeaderValue, Error: Into<Error>>>(
189 mut self,
190 accept_encodings: impl IntoIterator<Item = T>,
191 ) -> Result<Self, Error> {
192 self.accept_encoding = accept_encodings
193 .into_iter()
194 .map(|v| v.try_into().map_err(Into::into))
195 .collect::<Result<_, _>>()?;
196 Ok(self)
197 }
198
199 fn common_request<T>(&mut self, method: Method, body: T) -> Result<http::Request<T>, Error> {
201 let mut req = Request::new(body);
202 *req.method_mut() = method;
203 let mut headers: HeaderMap = std::mem::take(&mut self.metadata);
204 headers.insert(CONNECT_PROTOCOL_VERSION, PROTOCOL_VERSION_1);
206 if let Some(timeout) = self.timeout_ms.take() {
208 headers.insert(CONNECT_TIMEOUT_MS, timeout);
209 }
210 *req.headers_mut() = headers;
211 Ok(req)
212 }
213
214 pub fn unary<T>(mut self, body: T) -> Result<UnaryRequest<T>, Error> {
218 let mut req = self.common_request(Method::POST, body)?;
219 *req.uri_mut() = build_uri(self.scheme, self.authority, self.path)?;
220
221 if let Some(message_codec) = &self.message_codec {
223 req.headers_mut().insert(
224 header::CONTENT_TYPE,
225 (format!("{CONTENT_TYPE_PREFIX}{message_codec}")).try_into()?,
226 );
227 }
228 if let Some(content_encoding) = self.content_encoding.take() {
230 req.headers_mut()
231 .insert(header::CONTENT_ENCODING, content_encoding.try_into()?);
232 }
233 for value in std::mem::take(&mut self.accept_encoding) {
235 req.headers_mut().append(header::ACCEPT_ENCODING, value);
236 }
237 Ok(req.into())
238 }
239
240 pub fn streaming<T>(mut self, body: T) -> Result<StreamingRequest<T>, Error> {
244 let mut req = self.common_request(Method::POST, body)?;
245 *req.uri_mut() = build_uri(self.scheme, self.authority, self.path)?;
246
247 if let Some(message_codec) = &self.message_codec {
249 req.headers_mut().insert(
250 header::CONTENT_TYPE,
251 (format!("{CONTENT_TYPE_PREFIX}{message_codec}")).try_into()?,
252 );
253 }
254 if let Some(content_encoding) = self.content_encoding.take() {
256 req.headers_mut()
257 .insert(CONNECT_CONTENT_ENCODING, content_encoding.try_into()?);
258 }
259 for value in std::mem::take(&mut self.accept_encoding) {
261 req.headers_mut().append(CONNECT_ACCEPT_ENCODING, value);
262 }
263 Ok(req.into())
264 }
265
266 pub fn unary_get(mut self, message: impl AsRef<[u8]>) -> Result<UnaryGetRequest, Error> {
270 let mut req = self.common_request(Method::GET, ())?;
271 *req.method_mut() = Method::GET;
272
273 let path_and_query = {
274 let path = self.path.ok_or(Error::invalid_request("path required"))?;
275 let query = {
276 let mut query = form_urlencoded::Serializer::new("?".to_string());
277 query
278 .append_pair("message", &BASE64_URL_SAFE.encode(message))
280 .append_pair("base64", "1")
282 .append_pair("connect", "v1");
284 if let Some(message_codec) = &self.message_codec {
285 query.append_pair("encoding", message_codec);
287 } else {
288 return Err(Error::invalid_request("message codec required"));
289 }
290 if let Some(content_encoding) = &self.content_encoding {
291 query.append_pair("compression", content_encoding);
293 }
294 query.finish()
295 };
296 Some(format!("{path}?{query}"))
297 };
298 *req.uri_mut() = build_uri(self.scheme, self.authority, path_and_query)?;
299
300 for value in std::mem::take(&mut self.accept_encoding) {
302 req.headers_mut().append(header::ACCEPT_ENCODING, value);
303 }
304 Ok(req.into())
305 }
306}
307
308fn build_uri(
309 scheme: Option<Scheme>,
310 authority: Option<Authority>,
311 path_and_query: Option<impl TryInto<PathAndQuery, Error: Into<Error>>>,
312) -> Result<Uri, Error> {
313 Ok(Uri::from_parts({
314 let mut parts = Parts::default();
315 parts.scheme = scheme;
316 parts.authority = authority;
317 parts.path_and_query = path_and_query
318 .map(TryInto::try_into)
319 .transpose()
320 .map_err(Into::into)?;
321 parts
322 })?)
323}