1use http_signature_normalization::create::Signed;
2use httpdate::HttpDate;
3use reqwest::{
4 header::{InvalidHeaderValue, ToStrError},
5 Request, RequestBuilder,
6};
7use std::{
8 convert::TryInto,
9 fmt::Display,
10 time::{Duration, SystemTime},
11};
12
13pub use http_signature_normalization::RequiredError;
14
15#[cfg(feature = "digest")]
16pub mod digest;
17
18pub mod prelude {
19 pub use crate::{Config, Sign, SignError};
20
21 #[cfg(feature = "default-spawner")]
22 pub use crate::default_spawner::DefaultSpawner;
23
24 #[cfg(feature = "digest")]
25 pub use crate::digest::{DigestCreate, SignExt};
26}
27
28#[cfg(feature = "default-spawner")]
29pub use default_spawner::DefaultSpawner;
30
/// Configuration for signing reqwest requests.
///
/// Wraps the `http_signature_normalization::Config` used to begin a signature and adds the
/// request-level switches this crate applies before signing. With the `default-spawner` feature
/// enabled, the spawner type parameter defaults to [`DefaultSpawner`].
#[cfg(feature = "default-spawner")]
#[derive(Clone, Debug, Default)]
pub struct Config<Spawner = DefaultSpawner> {
    /// Settings handed to the normalization crate when a signature is begun.
    config: http_signature_normalization::Config,

    /// When `true`, a `Host` entry derived from the request URL is added to the signed headers.
    set_host: bool,

    /// When `true`, a `date` header holding the current time is inserted into the request if it
    /// does not already carry one.
    set_date: bool,

    /// Runs the blocking signing closure.
    spawner: Spawner,
}
50
51#[cfg(not(feature = "default-spawner"))]
52#[derive(Clone, Debug, Default)]
53pub struct Config<Spawner> {
58 config: http_signature_normalization::Config,
60
61 set_host: bool,
63
64 set_date: bool,
66
67 spawner: Spawner,
69}
70
/// Tokio-backed [`Spawn`] implementation, available with the `default-spawner` feature.
#[cfg(feature = "default-spawner")]
mod default_spawner {
    use super::{Canceled, Config, Spawn};
    use std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
    };

    impl Config<DefaultSpawner> {
        /// Creates a configuration with every option at its default and the tokio-backed
        /// [`DefaultSpawner`].
        pub fn new() -> Self {
            Self::default()
        }
    }

    /// Spawner that runs blocking work through `tokio::task::spawn_blocking`.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct DefaultSpawner;

    /// Future returned by [`DefaultSpawner`]; wraps the tokio join handle of the spawned task.
    pub struct DefaultSpawnerFuture<Out> {
        inner: tokio::task::JoinHandle<Out>,
    }

    impl Spawn for DefaultSpawner {
        type Future<T> = DefaultSpawnerFuture<T> where T: Send;

        /// Hands `func` to tokio's blocking pool and returns a future for its result.
        fn spawn_blocking<Func, Out>(&self, func: Func) -> Self::Future<Out>
        where
            Func: FnOnce() -> Out + Send + 'static,
            Out: Send + 'static,
        {
            let inner = tokio::task::spawn_blocking(func);

            DefaultSpawnerFuture { inner }
        }
    }

    impl<Out> Future for DefaultSpawnerFuture<Out> {
        type Output = Result<Out, Canceled>;

        /// Polls the wrapped join handle; any join error is reported as [`Canceled`].
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            Pin::new(&mut self.inner).poll(cx).map_err(|_| Canceled)
        }
    }
}
119
/// Error produced by a [`Spawn`] future when the spawned operation did not yield a value.
#[derive(Debug)]
pub struct Canceled;

impl std::fmt::Display for Canceled {
    /// Writes the fixed, human-readable cancellation message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Operation was canceled")
    }
}

impl std::error::Error for Canceled {}
131
132pub trait Spawn {
136 type Future<T>: std::future::Future<Output = Result<T, Canceled>> + Send
138 where
139 T: Send;
140
141 fn spawn_blocking<Func, Out>(&self, func: Func) -> Self::Future<Out>
143 where
144 Func: FnOnce() -> Out + Send + 'static,
145 Out: Send + 'static;
146}
147
148#[async_trait::async_trait]
150pub trait Sign {
151 async fn authorization_signature<F, E, K, S>(
153 self,
154 config: &Config<S>,
155 key_id: K,
156 f: F,
157 ) -> Result<Request, E>
158 where
159 Self: Sized,
160 F: FnOnce(&str) -> Result<String, E> + Send + 'static,
161 E: From<SignError> + From<reqwest::Error> + Send + 'static,
162 K: Display + Send,
163 S: Spawn + Send + Sync;
164
165 async fn signature<F, E, K, S>(self, config: &Config<S>, key_id: K, f: F) -> Result<Request, E>
167 where
168 Self: Sized,
169 F: FnOnce(&str) -> Result<String, E> + Send + 'static,
170 E: From<SignError> + From<reqwest::Error> + Send + 'static,
171 K: Display + Send,
172 S: Spawn + Send + Sync;
173}
174
175#[derive(Debug, thiserror::Error)]
176pub enum SignError {
177 #[error("Failed to read header, {0}")]
178 Header(#[from] ToStrError),
180
181 #[error("Failed to write header, {0}")]
182 NewHeader(#[from] InvalidHeaderValue),
184
185 #[error("{0}")]
186 RequiredError(#[from] RequiredError),
188
189 #[error("No host provided for URL, {0}")]
190 Host(String),
192
193 #[error("Cannot sign request with body already present")]
194 BodyPresent,
195
196 #[error("Panic in spawn blocking")]
197 Canceled,
198}
199
200impl<Spawner> Config<Spawner> {
201 pub fn new_with_spawner(spawner: Spawner) -> Self {
203 Config {
204 config: Default::default(),
205 set_host: Default::default(),
206 set_date: Default::default(),
207 spawner,
208 }
209 }
210
211 pub fn set_host_header(self) -> Self {
214 Config {
215 config: self.config,
216 set_host: true,
217 set_date: self.set_date,
218 spawner: self.spawner,
219 }
220 }
221
222 pub fn mastodon_compat(self) -> Self {
227 Config {
228 config: self.config.mastodon_compat(),
229 set_host: true,
230 set_date: true,
231 spawner: self.spawner,
232 }
233 }
234
235 pub fn require_digest(self) -> Self {
239 Config {
240 config: self.config.require_digest(),
241 set_host: self.set_host,
242 set_date: self.set_date,
243 spawner: self.spawner,
244 }
245 }
246
247 pub fn dont_use_created_field(self) -> Self {
252 Config {
253 config: self.config.dont_use_created_field(),
254 set_host: self.set_host,
255 set_date: self.set_date,
256 spawner: self.spawner,
257 }
258 }
259
260 pub fn set_expiration(self, expiries_after: Duration) -> Self {
262 Config {
263 config: self.config.set_expiration(expiries_after),
264 set_host: self.set_host,
265 set_date: self.set_date,
266 spawner: self.spawner,
267 }
268 }
269
270 pub fn require_header(self, header: &str) -> Self {
272 Config {
273 config: self.config.require_header(header),
274 set_host: self.set_host,
275 set_date: self.set_date,
276 spawner: self.spawner,
277 }
278 }
279
280 pub fn set_spawner<NewSpawner: Spawn>(self, spawner: NewSpawner) -> Config<NewSpawner> {
281 Config {
282 config: self.config,
283 set_host: self.set_host,
284 set_date: self.set_date,
285 spawner,
286 }
287 }
288}
289
290#[async_trait::async_trait]
291impl Sign for RequestBuilder {
292 async fn authorization_signature<F, E, K, S>(
293 self,
294 config: &Config<S>,
295 key_id: K,
296 f: F,
297 ) -> Result<Request, E>
298 where
299 F: FnOnce(&str) -> Result<String, E> + Send + 'static,
300 E: From<SignError> + From<reqwest::Error> + Send + 'static,
301 K: Display + Send,
302 S: Spawn + Send + Sync,
303 {
304 let mut request = self.build()?;
305 let signed = prepare(&mut request, config, key_id, f).await?;
306
307 let auth_header = signed.authorization_header();
308 request.headers_mut().insert(
309 "Authorization",
310 auth_header.parse().map_err(SignError::NewHeader)?,
311 );
312
313 Ok(request)
314 }
315
316 async fn signature<F, E, K, S>(self, config: &Config<S>, key_id: K, f: F) -> Result<Request, E>
317 where
318 F: FnOnce(&str) -> Result<String, E> + Send + 'static,
319 E: From<SignError> + From<reqwest::Error> + Send + 'static,
320 K: Display + Send,
321 S: Spawn + Send + Sync,
322 {
323 let mut request = self.build()?;
324 let signed = prepare(&mut request, config, key_id, f).await?;
325
326 let sig_header = signed.signature_header();
327
328 request.headers_mut().insert(
329 "Signature",
330 sig_header.parse().map_err(SignError::NewHeader)?,
331 );
332
333 Ok(request)
334 }
335}
336
337async fn prepare<F, E, K, S>(
338 req: &mut Request,
339 config: &Config<S>,
340 key_id: K,
341 f: F,
342) -> Result<Signed, E>
343where
344 F: FnOnce(&str) -> Result<String, E> + Send + 'static,
345 E: From<SignError> + Send + 'static,
346 K: Display + Send,
347 S: Spawn,
348{
349 if config.set_date && !req.headers().contains_key("date") {
350 req.headers_mut().insert(
351 "date",
352 HttpDate::from(SystemTime::now())
353 .to_string()
354 .try_into()
355 .map_err(SignError::from)?,
356 );
357 }
358 let mut bt = std::collections::BTreeMap::new();
359 for (k, v) in req.headers().iter() {
360 bt.insert(
361 k.as_str().to_owned(),
362 v.to_str().map_err(SignError::from)?.to_owned(),
363 );
364 }
365 if config.set_host {
366 let header_string = req
367 .url()
368 .host()
369 .ok_or_else(|| SignError::Host(req.url().to_string()))?
370 .to_string();
371
372 let header_string = match req.url().port() {
373 None | Some(443) | Some(80) => header_string,
374 Some(port) => format!("{}:{}", header_string, port),
375 };
376
377 bt.insert("Host".to_string(), header_string);
378 }
379 let path_and_query = if let Some(query) = req.url().query() {
380 format!("{}?{}", req.url().path(), query)
381 } else {
382 req.url().path().to_string()
383 };
384 let unsigned = config
385 .config
386 .begin_sign(req.method().as_str(), &path_and_query, bt)
387 .map_err(SignError::from)?;
388
389 let key_string = key_id.to_string();
390 let signed = config
391 .spawner
392 .spawn_blocking(move || unsigned.sign(key_string, f))
393 .await
394 .map_err(|_| SignError::Canceled)??;
395 Ok(signed)
396}
397
/// Signing support for `reqwest_middleware::RequestBuilder`, enabled by the `middleware`
/// feature. Mirrors the implementation for the plain reqwest builder.
#[cfg(feature = "middleware")]
mod middleware {
    use super::{prepare, Config, Sign, SignError, Spawn};
    use reqwest::Request;
    use reqwest_middleware::RequestBuilder;
    use std::fmt::Display;

    #[async_trait::async_trait]
    impl Sign for RequestBuilder {
        /// Builds the request, signs it, and stores the result in the `Authorization` header.
        async fn authorization_signature<F, E, K, S>(
            self,
            config: &Config<S>,
            key_id: K,
            f: F,
        ) -> Result<Request, E>
        where
            F: FnOnce(&str) -> Result<String, E> + Send + 'static,
            E: From<SignError> + From<reqwest::Error> + Send + 'static,
            K: Display + Send,
            S: Spawn + Send + Sync,
        {
            let mut request = self.build()?;
            let signed = prepare(&mut request, config, key_id, f).await?;

            let value = signed
                .authorization_header()
                .parse::<reqwest::header::HeaderValue>()
                .map_err(SignError::NewHeader)?;
            request.headers_mut().insert("Authorization", value);

            Ok(request)
        }

        /// Builds the request, signs it, and stores the result in the `Signature` header.
        async fn signature<F, E, K, S>(
            self,
            config: &Config<S>,
            key_id: K,
            f: F,
        ) -> Result<Request, E>
        where
            F: FnOnce(&str) -> Result<String, E> + Send + 'static,
            E: From<SignError> + From<reqwest::Error> + Send + 'static,
            K: Display + Send,
            S: Spawn + Send + Sync,
        {
            let mut request = self.build()?;
            let signed = prepare(&mut request, config, key_id, f).await?;

            let value = signed
                .signature_header()
                .parse::<reqwest::header::HeaderValue>()
                .map_err(SignError::NewHeader)?;
            request.headers_mut().insert("Signature", value);

            Ok(request)
        }
    }
}