Skip to main content

reqsign_core/
request.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::borrow::Cow;
19use std::mem;
20use std::time::Duration;
21
22use crate::{Error, Result};
23use http::HeaderMap;
24use http::HeaderValue;
25use http::Method;
26use http::Uri;
27use http::header::HeaderName;
28use http::uri::Authority;
29use http::uri::PathAndQuery;
30use http::uri::Scheme;
31use std::str::FromStr;
32
33/// Signing context for request.
34#[derive(Debug)]
35pub struct SigningRequest {
36    /// HTTP method.
37    pub method: Method,
38    /// HTTP scheme.
39    pub scheme: Scheme,
40    /// HTTP authority.
41    pub authority: Authority,
42    /// HTTP path.
43    pub path: String,
44    /// HTTP query parameters.
45    pub query: Vec<(String, String)>,
46    /// HTTP headers.
47    pub headers: HeaderMap,
48}
49
50impl SigningRequest {
51    /// Build a signing context from http::request::Parts.
52    pub fn build(parts: &mut http::request::Parts) -> Result<Self> {
53        let uri = mem::take(&mut parts.uri).into_parts();
54        let paq = uri
55            .path_and_query
56            .unwrap_or_else(|| PathAndQuery::from_static("/"));
57
58        Ok(SigningRequest {
59            method: parts.method.clone(),
60            scheme: uri.scheme.unwrap_or(Scheme::HTTP),
61            authority: uri.authority.ok_or_else(|| {
62                Error::request_invalid("request without authority is invalid for signing")
63            })?,
64            path: paq.path().to_string(),
65            query: paq
66                .query()
67                .map(|v| {
68                    form_urlencoded::parse(v.as_bytes())
69                        .map(|(k, v)| (k.into_owned(), v.into_owned()))
70                        .collect()
71                })
72                .unwrap_or_default(),
73
74            // Take the headers out of the request to avoid copy.
75            // We will return it back when apply the context.
76            headers: mem::take(&mut parts.headers),
77        })
78    }
79
80    /// Apply the signing context back to http::request::Parts.
81    pub fn apply(mut self, parts: &mut http::request::Parts) -> Result<()> {
82        let query_size = self.query_size();
83
84        // Return headers back.
85        mem::swap(&mut parts.headers, &mut self.headers);
86        parts.method = self.method;
87        parts.uri = {
88            let mut uri_parts = mem::take(&mut parts.uri).into_parts();
89            // Return scheme bakc.
90            uri_parts.scheme = Some(self.scheme);
91            // Return authority back.
92            uri_parts.authority = Some(self.authority);
93            // Build path and query.
94            uri_parts.path_and_query =
95                {
96                    let paq = if query_size == 0 {
97                        self.path
98                    } else {
99                        let mut s = self.path;
100                        s.reserve(query_size + 1);
101
102                        s.push('?');
103                        for (i, (k, v)) in self.query.iter().enumerate() {
104                            if i > 0 {
105                                s.push('&');
106                            }
107
108                            s.push_str(k);
109                            if !v.is_empty() {
110                                s.push('=');
111                                s.push_str(v);
112                            }
113                        }
114
115                        s
116                    };
117
118                    Some(PathAndQuery::from_str(&paq).map_err(|e| {
119                        Error::request_invalid("invalid path and query").with_source(e)
120                    })?)
121                };
122            Uri::from_parts(uri_parts)
123                .map_err(|e| Error::request_invalid("failed to build URI").with_source(e))?
124        };
125
126        Ok(())
127    }
128
129    /// Get the path percent decoded.
130    pub fn path_percent_decoded(&self) -> Cow<'_, str> {
131        percent_encoding::percent_decode_str(&self.path).decode_utf8_lossy()
132    }
133
134    /// Get query size.
135    #[inline]
136    pub fn query_size(&self) -> usize {
137        self.query
138            .iter()
139            .map(|(k, v)| k.len() + v.len())
140            .sum::<usize>()
141    }
142
143    /// Push a new query pair into query list.
144    #[inline]
145    pub fn query_push(&mut self, key: impl Into<String>, value: impl Into<String>) {
146        self.query.push((key.into(), value.into()));
147    }
148
149    /// Push a query string into query list.
150    #[inline]
151    pub fn query_append(&mut self, query: &str) {
152        self.query.push((query.to_string(), "".to_string()));
153    }
154
155    /// Get query value by filter.
156    pub fn query_to_vec_with_filter(&self, filter: impl Fn(&str) -> bool) -> Vec<(String, String)> {
157        self.query
158            .iter()
159            // Filter all queries
160            .filter(|(k, _)| filter(k))
161            // Clone all queries
162            .map(|(k, v)| (k.to_string(), v.to_string()))
163            .collect()
164    }
165
166    /// Convert sorted query to string.
167    ///
168    /// ```shell
169    /// [(a, b), (c, d)] => "a:b\nc:d"
170    /// ```
171    pub fn query_to_string(mut query: Vec<(String, String)>, sep: &str, join: &str) -> String {
172        let mut s = String::with_capacity(16);
173
174        // Sort via header name.
175        query.sort();
176
177        for (idx, (k, v)) in query.into_iter().enumerate() {
178            if idx != 0 {
179                s.push_str(join);
180            }
181
182            s.push_str(&k);
183            if !v.is_empty() {
184                s.push_str(sep);
185                s.push_str(&v);
186            }
187        }
188
189        s
190    }
191
192    /// Convert sorted query to percent decoded string.
193    ///
194    /// ```shell
195    /// [(a, b), (c, d)] => "a:b\nc:d"
196    /// ```
197    pub fn query_to_percent_decoded_string(
198        mut query: Vec<(String, String)>,
199        sep: &str,
200        join: &str,
201    ) -> String {
202        let mut s = String::with_capacity(16);
203
204        // Sort via header name.
205        query.sort();
206
207        for (idx, (k, v)) in query.into_iter().enumerate() {
208            if idx != 0 {
209                s.push_str(join);
210            }
211
212            s.push_str(&k);
213            if !v.is_empty() {
214                s.push_str(sep);
215                s.push_str(&percent_encoding::percent_decode_str(&v).decode_utf8_lossy());
216            }
217        }
218
219        s
220    }
221
222    /// Get header value by name.
223    ///
224    /// Returns empty string if header not found.
225    #[inline]
226    pub fn header_get_or_default(&self, key: &HeaderName) -> Result<&str> {
227        match self.headers.get(key) {
228            Some(v) => v
229                .to_str()
230                .map_err(|e| Error::request_invalid("invalid header value").with_source(e)),
231            None => Ok(""),
232        }
233    }
234
235    /// Normalize header value.
236    pub fn header_value_normalize(v: &mut HeaderValue) {
237        let bs = v.as_bytes();
238
239        let starting_index = bs.iter().position(|b| *b != b' ').unwrap_or(0);
240        let ending_offset = bs.iter().rev().position(|b| *b != b' ').unwrap_or(0);
241        let ending_index = bs.len() - ending_offset;
242
243        // This can't fail because we started with a valid HeaderValue and then only trimmed spaces
244        *v = HeaderValue::from_bytes(&bs[starting_index..ending_index])
245            .expect("invalid header value")
246    }
247
248    /// Get header names as sorted vector.
249    pub fn header_name_to_vec_sorted(&self) -> Vec<&str> {
250        let mut h = self
251            .headers
252            .keys()
253            .map(|k| k.as_str())
254            .collect::<Vec<&str>>();
255        h.sort_unstable();
256
257        h
258    }
259
260    /// Get header names with given prefix.
261    pub fn header_to_vec_with_prefix(&self, prefix: &str) -> Vec<(String, String)> {
262        self.headers
263            .iter()
264            // Filter all header that starts with prefix
265            .filter(|(k, _)| k.as_str().starts_with(prefix))
266            // Convert all header name to lowercase
267            .map(|(k, v)| {
268                (
269                    k.as_str().to_lowercase(),
270                    v.to_str().expect("must be valid header").to_string(),
271                )
272            })
273            .collect()
274    }
275
276    /// Convert sorted headers to string.
277    ///
278    /// ```shell
279    /// [(a, b), (c, d)] => "a:b\nc:d"
280    /// ```
281    pub fn header_to_string(mut headers: Vec<(String, String)>, sep: &str, join: &str) -> String {
282        let mut s = String::with_capacity(16);
283
284        // Sort via header name.
285        headers.sort();
286
287        for (idx, (k, v)) in headers.into_iter().enumerate() {
288            if idx != 0 {
289                s.push_str(join);
290            }
291
292            s.push_str(&k);
293            s.push_str(sep);
294            s.push_str(&v);
295        }
296
297        s
298    }
299}
300
/// SigningMethod is the way a signature is attached to a request.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SigningMethod {
    /// Signing with headers.
    Header,
    /// Signing with query parameters.
    ///
    /// NOTE(review): the `Duration` is presumably the expiry of the signed
    /// request. Confirm against the signer implementations.
    Query(Duration),
}