http_signature_normalization_reqwest/
lib.rs

1use http_signature_normalization::create::Signed;
2use httpdate::HttpDate;
3use reqwest::{
4    header::{InvalidHeaderValue, ToStrError},
5    Request, RequestBuilder,
6};
7use std::{
8    convert::TryInto,
9    fmt::Display,
10    time::{Duration, SystemTime},
11};
12
13pub use http_signature_normalization::RequiredError;
14
15#[cfg(feature = "digest")]
16pub mod digest;
17
18pub mod prelude {
19    pub use crate::{Config, Sign, SignError};
20
21    #[cfg(feature = "default-spawner")]
22    pub use crate::default_spawner::DefaultSpawner;
23
24    #[cfg(feature = "digest")]
25    pub use crate::digest::{DigestCreate, SignExt};
26}
27
28#[cfg(feature = "default-spawner")]
29pub use default_spawner::DefaultSpawner;
30
#[cfg(feature = "default-spawner")]
/// Configuration for creating HTTP Signatures on outgoing requests
///
/// By default, the config is set up to create signatures that expire after 10 seconds,
/// and use the `(created)` and `(expires)` fields that were introduced in draft 11.
///
/// With the `default-spawner` feature enabled, the `Spawner` type parameter defaults to
/// [`DefaultSpawner`].
#[derive(Clone, Debug, Default)]
pub struct Config<Spawner = DefaultSpawner> {
    /// Settings of the underlying `http_signature_normalization` signer
    config: http_signature_normalization::Config,

    /// When true, a `Host` entry is included in the signature input
    set_host: bool,

    /// When true, a `Date` header is added to requests that do not already carry one
    set_date: bool,

    /// Runs the blocking signing step off the async executor
    spawner: Spawner,
}
50
51#[cfg(not(feature = "default-spawner"))]
52#[derive(Clone, Debug, Default)]
53/// Configuration for signing and verifying signatures
54///
55/// By default, the config is set up to create and verify signatures that expire after 10 seconds,
56/// and use the `(created)` and `(expires)` fields that were introduced in draft 11
57pub struct Config<Spawner> {
58    /// The inner config type
59    config: http_signature_normalization::Config,
60
61    /// Whether to set the Host header
62    set_host: bool,
63
64    /// Whether to set the Date header
65    set_date: bool,
66
67    /// How to spawn blocking tasks
68    spawner: Spawner,
69}
70
#[cfg(feature = "default-spawner")]
mod default_spawner {
    use super::{Canceled, Config, Spawn};

    impl Config<DefaultSpawner> {
        /// Create a new config that runs blocking work with [`DefaultSpawner`]
        pub fn new() -> Self {
            Default::default()
        }
    }

    /// A default implementation of [`Spawn`] that runs blocking operations with
    /// `tokio::task::spawn_blocking`
    #[derive(Clone, Copy, Debug, Default)]
    pub struct DefaultSpawner;

    /// The future returned by [`DefaultSpawner`] when spawning blocking operations on the
    /// tokio blocking threadpool
    #[derive(Debug)]
    pub struct DefaultSpawnerFuture<Out> {
        inner: tokio::task::JoinHandle<Out>,
    }

    impl Spawn for DefaultSpawner {
        type Future<T> = DefaultSpawnerFuture<T> where T: Send;

        fn spawn_blocking<Func, Out>(&self, func: Func) -> Self::Future<Out>
        where
            Func: FnOnce() -> Out + Send + 'static,
            Out: Send + 'static,
        {
            DefaultSpawnerFuture {
                inner: tokio::task::spawn_blocking(func),
            }
        }
    }

    impl<Out> std::future::Future for DefaultSpawnerFuture<Out> {
        type Output = Result<Out, Canceled>;

        fn poll(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Self::Output> {
            // Call `poll` through the trait path: `Future` is not imported in this module and
            // is only part of the prelude from the 2024 edition onwards, so method-call syntax
            // would not resolve.
            let res = std::task::ready!(std::future::Future::poll(
                std::pin::Pin::new(&mut self.inner),
                cx
            ));

            // Any JoinError (panic or cancellation of the blocking task) becomes `Canceled`
            std::task::Poll::Ready(res.map_err(|_| Canceled))
        }
    }
}
119
/// Error produced when a blocking operation panicked (or was otherwise cancelled) and
/// therefore cannot return a result
#[derive(Debug)]
pub struct Canceled;

impl std::error::Error for Canceled {}

impl std::fmt::Display for Canceled {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Operation was canceled")
    }
}
131
132/// A trait dictating how to spawn a future onto a blocking threadpool. By default,
133/// http-signature-normalization-actix will use tokio's built-in blocking threadpool, but this
134/// can be customized
135pub trait Spawn {
136    /// The future type returned by spawn_blocking
137    type Future<T>: std::future::Future<Output = Result<T, Canceled>> + Send
138    where
139        T: Send;
140
141    /// Spawn the blocking function onto the threadpool
142    fn spawn_blocking<Func, Out>(&self, func: Func) -> Self::Future<Out>
143    where
144        Func: FnOnce() -> Out + Send + 'static,
145        Out: Send + 'static;
146}
147
148/// A trait implemented by the reqwest RequestBuilder type to add an HTTP Signature to the request
149#[async_trait::async_trait]
150pub trait Sign {
151    /// Add an Authorization Signature to the request
152    async fn authorization_signature<F, E, K, S>(
153        self,
154        config: &Config<S>,
155        key_id: K,
156        f: F,
157    ) -> Result<Request, E>
158    where
159        Self: Sized,
160        F: FnOnce(&str) -> Result<String, E> + Send + 'static,
161        E: From<SignError> + From<reqwest::Error> + Send + 'static,
162        K: Display + Send,
163        S: Spawn + Send + Sync;
164
165    /// Add a Signature to the request
166    async fn signature<F, E, K, S>(self, config: &Config<S>, key_id: K, f: F) -> Result<Request, E>
167    where
168        Self: Sized,
169        F: FnOnce(&str) -> Result<String, E> + Send + 'static,
170        E: From<SignError> + From<reqwest::Error> + Send + 'static,
171        K: Display + Send,
172        S: Spawn + Send + Sync;
173}
174
175#[derive(Debug, thiserror::Error)]
176pub enum SignError {
177    #[error("Failed to read header, {0}")]
178    /// An error occurred when reading the request's headers
179    Header(#[from] ToStrError),
180
181    #[error("Failed to write header, {0}")]
182    /// An error occured when adding a new header
183    NewHeader(#[from] InvalidHeaderValue),
184
185    #[error("{0}")]
186    /// Some headers were marked as required, but are missing
187    RequiredError(#[from] RequiredError),
188
189    #[error("No host provided for URL, {0}")]
190    /// Missing host
191    Host(String),
192
193    #[error("Cannot sign request with body already present")]
194    BodyPresent,
195
196    #[error("Panic in spawn blocking")]
197    Canceled,
198}
199
200impl<Spawner> Config<Spawner> {
201    /// Create a new config with the provided spawner
202    pub fn new_with_spawner(spawner: Spawner) -> Self {
203        Config {
204            config: Default::default(),
205            set_host: Default::default(),
206            set_date: Default::default(),
207            spawner,
208        }
209    }
210
211    /// This method can be used to include the Host header in the HTTP Signature without
212    /// interfering with Reqwest's built-in Host mechanisms
213    pub fn set_host_header(self) -> Self {
214        Config {
215            config: self.config,
216            set_host: true,
217            set_date: self.set_date,
218            spawner: self.spawner,
219        }
220    }
221
222    /// Enable mastodon compatibility
223    ///
224    /// This is the same as disabling the use of `(created)` and `(expires)` signature fields,
225    /// requiring the Date header, and requiring the Host header
226    pub fn mastodon_compat(self) -> Self {
227        Config {
228            config: self.config.mastodon_compat(),
229            set_host: true,
230            set_date: true,
231            spawner: self.spawner,
232        }
233    }
234
235    /// Require the Digest header be set
236    ///
237    /// This is useful for POST, PUT, and PATCH requests, but doesn't make sense for GET or DELETE.
238    pub fn require_digest(self) -> Self {
239        Config {
240            config: self.config.require_digest(),
241            set_host: self.set_host,
242            set_date: self.set_date,
243            spawner: self.spawner,
244        }
245    }
246
247    /// Opt out of using the (created) and (expires) fields introduced in draft 11
248    ///
249    /// Note that by enabling this, the Date header becomes required on requests. This is to
250    /// prevent replay attacks
251    pub fn dont_use_created_field(self) -> Self {
252        Config {
253            config: self.config.dont_use_created_field(),
254            set_host: self.set_host,
255            set_date: self.set_date,
256            spawner: self.spawner,
257        }
258    }
259
260    /// Set the expiration to a custom duration
261    pub fn set_expiration(self, expiries_after: Duration) -> Self {
262        Config {
263            config: self.config.set_expiration(expiries_after),
264            set_host: self.set_host,
265            set_date: self.set_date,
266            spawner: self.spawner,
267        }
268    }
269
270    /// Require a header on signed requests
271    pub fn require_header(self, header: &str) -> Self {
272        Config {
273            config: self.config.require_header(header),
274            set_host: self.set_host,
275            set_date: self.set_date,
276            spawner: self.spawner,
277        }
278    }
279
280    pub fn set_spawner<NewSpawner: Spawn>(self, spawner: NewSpawner) -> Config<NewSpawner> {
281        Config {
282            config: self.config,
283            set_host: self.set_host,
284            set_date: self.set_date,
285            spawner,
286        }
287    }
288}
289
290#[async_trait::async_trait]
291impl Sign for RequestBuilder {
292    async fn authorization_signature<F, E, K, S>(
293        self,
294        config: &Config<S>,
295        key_id: K,
296        f: F,
297    ) -> Result<Request, E>
298    where
299        F: FnOnce(&str) -> Result<String, E> + Send + 'static,
300        E: From<SignError> + From<reqwest::Error> + Send + 'static,
301        K: Display + Send,
302        S: Spawn + Send + Sync,
303    {
304        let mut request = self.build()?;
305        let signed = prepare(&mut request, config, key_id, f).await?;
306
307        let auth_header = signed.authorization_header();
308        request.headers_mut().insert(
309            "Authorization",
310            auth_header.parse().map_err(SignError::NewHeader)?,
311        );
312
313        Ok(request)
314    }
315
316    async fn signature<F, E, K, S>(self, config: &Config<S>, key_id: K, f: F) -> Result<Request, E>
317    where
318        F: FnOnce(&str) -> Result<String, E> + Send + 'static,
319        E: From<SignError> + From<reqwest::Error> + Send + 'static,
320        K: Display + Send,
321        S: Spawn + Send + Sync,
322    {
323        let mut request = self.build()?;
324        let signed = prepare(&mut request, config, key_id, f).await?;
325
326        let sig_header = signed.signature_header();
327
328        request.headers_mut().insert(
329            "Signature",
330            sig_header.parse().map_err(SignError::NewHeader)?,
331        );
332
333        Ok(request)
334    }
335}
336
337async fn prepare<F, E, K, S>(
338    req: &mut Request,
339    config: &Config<S>,
340    key_id: K,
341    f: F,
342) -> Result<Signed, E>
343where
344    F: FnOnce(&str) -> Result<String, E> + Send + 'static,
345    E: From<SignError> + Send + 'static,
346    K: Display + Send,
347    S: Spawn,
348{
349    if config.set_date && !req.headers().contains_key("date") {
350        req.headers_mut().insert(
351            "date",
352            HttpDate::from(SystemTime::now())
353                .to_string()
354                .try_into()
355                .map_err(SignError::from)?,
356        );
357    }
358    let mut bt = std::collections::BTreeMap::new();
359    for (k, v) in req.headers().iter() {
360        bt.insert(
361            k.as_str().to_owned(),
362            v.to_str().map_err(SignError::from)?.to_owned(),
363        );
364    }
365    if config.set_host {
366        let header_string = req
367            .url()
368            .host()
369            .ok_or_else(|| SignError::Host(req.url().to_string()))?
370            .to_string();
371
372        let header_string = match req.url().port() {
373            None | Some(443) | Some(80) => header_string,
374            Some(port) => format!("{}:{}", header_string, port),
375        };
376
377        bt.insert("Host".to_string(), header_string);
378    }
379    let path_and_query = if let Some(query) = req.url().query() {
380        format!("{}?{}", req.url().path(), query)
381    } else {
382        req.url().path().to_string()
383    };
384    let unsigned = config
385        .config
386        .begin_sign(req.method().as_str(), &path_and_query, bt)
387        .map_err(SignError::from)?;
388
389    let key_string = key_id.to_string();
390    let signed = config
391        .spawner
392        .spawn_blocking(move || unsigned.sign(key_string, f))
393        .await
394        .map_err(|_| SignError::Canceled)??;
395    Ok(signed)
396}
397
#[cfg(feature = "middleware")]
mod middleware {
    //! `Sign` support for `reqwest_middleware::RequestBuilder`, mirroring the implementation
    //! for plain reqwest request builders.

    use super::{prepare, Config, Sign, SignError, Spawn};
    use reqwest::Request;
    use reqwest_middleware::RequestBuilder;
    use std::fmt::Display;

    #[async_trait::async_trait]
    impl Sign for RequestBuilder {
        async fn authorization_signature<F, E, K, S>(
            self,
            config: &Config<S>,
            key_id: K,
            f: F,
        ) -> Result<Request, E>
        where
            F: FnOnce(&str) -> Result<String, E> + Send + 'static,
            E: From<SignError> + From<reqwest::Error> + Send + 'static,
            K: Display + Send,
            S: Spawn + Send + Sync,
        {
            // Build first so the signature covers the request exactly as it will be sent
            let mut outgoing = self.build()?;
            let signed = prepare(&mut outgoing, config, key_id, f).await?;

            let header_value = signed
                .authorization_header()
                .parse::<reqwest::header::HeaderValue>()
                .map_err(SignError::NewHeader)?;
            outgoing.headers_mut().insert("Authorization", header_value);

            Ok(outgoing)
        }

        async fn signature<F, E, K, S>(
            self,
            config: &Config<S>,
            key_id: K,
            f: F,
        ) -> Result<Request, E>
        where
            F: FnOnce(&str) -> Result<String, E> + Send + 'static,
            E: From<SignError> + From<reqwest::Error> + Send + 'static,
            K: Display + Send,
            S: Spawn + Send + Sync,
        {
            let mut outgoing = self.build()?;
            let signed = prepare(&mut outgoing, config, key_id, f).await?;

            let header_value = signed
                .signature_header()
                .parse::<reqwest::header::HeaderValue>()
                .map_err(SignError::NewHeader)?;
            outgoing.headers_mut().insert("Signature", header_value);

            Ok(outgoing)
        }
    }
}