http_signature_normalization/
lib.rs1#![deny(missing_docs)]
2
3use std::{
45 collections::{BTreeMap, HashSet},
46 num::ParseIntError,
47 time::{Duration, SystemTime, UNIX_EPOCH},
48};
49
50pub mod create;
51pub mod verify;
52
53use self::{
54 create::Unsigned,
55 verify::{ParseSignatureError, ParsedHeader, Unverified, ValidateError},
56};
57
// Pseudo-header names that may appear in the list of signed headers.
// Their values are synthesized by `build_signing_string` rather than taken from the request.
const REQUEST_TARGET: &str = "(request-target)";
const CREATED: &str = "(created)";
const EXPIRES: &str = "(expires)";

// Parameter names (plus the fixed algorithm value) of the signature header.
// NOTE(review): presumably consumed by the `create`/`verify` modules when emitting
// and parsing `Signature`/`Authorization` headers — confirm there.
const KEY_ID_FIELD: &str = "keyId";
const ALGORITHM_FIELD: &str = "algorithm";
const ALGORITHM_VALUE: &str = "hs2019";
const CREATED_FIELD: &str = "created";
const EXPIRES_FIELD: &str = "expires";
const HEADERS_FIELD: &str = "headers";
const SIGNATURE_FIELD: &str = "signature";
69
/// Configuration for creating and verifying HTTP signatures.
///
/// Build one with [`Config::new`] (or `Default`) and adjust it with the builder-style
/// methods such as [`Config::require_header`] and [`Config::set_expiration`].
#[derive(Clone, Debug)]
pub struct Config {
    // Added to the creation time to compute `(expires)` when signing, and passed to
    // validation when verifying.
    expires_after: Duration,
    // Whether the `(created)` and `(expires)` pseudo-headers are part of signatures.
    use_created_field: bool,
    // Header names (lowercased by `require_header`) that must be covered by a signature.
    required_headers: Vec<String>,
}
80
81#[derive(Debug)]
82pub enum PrepareVerifyError {
86 Validate(ValidateError),
88
89 Parse(ParseSignatureError),
91
92 Required(RequiredError),
94}
95
96impl std::fmt::Display for PrepareVerifyError {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 Self::Validate(ref e) => std::fmt::Display::fmt(e, f),
100 Self::Parse(ref e) => std::fmt::Display::fmt(e, f),
101 Self::Required(ref e) => std::fmt::Display::fmt(e, f),
102 }
103 }
104}
105
106impl std::error::Error for PrepareVerifyError {
107 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
108 match self {
109 Self::Validate(ref e) => Some(e),
110 Self::Parse(ref e) => Some(e),
111 Self::Required(ref e) => Some(e),
112 }
113 }
114}
115
116impl From<ValidateError> for PrepareVerifyError {
117 fn from(e: ValidateError) -> Self {
118 Self::Validate(e)
119 }
120}
121impl From<ParseSignatureError> for PrepareVerifyError {
122 fn from(e: ParseSignatureError) -> Self {
123 Self::Parse(e)
124 }
125}
126impl From<RequiredError> for PrepareVerifyError {
127 fn from(e: RequiredError) -> Self {
128 Self::Required(e)
129 }
130}
131
/// Error indicating that one or more required headers were missing.
///
/// Holds the names of the required headers that were not found.
#[derive(Debug)]
pub struct RequiredError(HashSet<String>);

impl std::fmt::Display for RequiredError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `HashSet` iterates in an arbitrary order (with the default hasher it typically
        // differs between processes), so print the names sorted to keep the message
        // deterministic. A `BTreeSet`'s `Debug` output has the same `{"a", "b"}` shape.
        let sorted: std::collections::BTreeSet<&String> = self.0.iter().collect();
        write!(f, "Missing required headers {:?}", sorted)
    }
}

impl std::error::Error for RequiredError {}

impl RequiredError {
    /// Borrow the set of missing header names.
    pub fn headers(&self) -> &HashSet<String> {
        &self.0
    }

    /// Take ownership of the missing header names, leaving an empty set in this error.
    pub fn take_headers(&mut self) -> HashSet<String> {
        std::mem::take(&mut self.0)
    }
}
155
156impl Config {
157 pub const fn new() -> Self {
159 Config {
160 expires_after: Duration::from_secs(10),
161 use_created_field: true,
162 required_headers: Vec::new(),
163 }
164 }
165
166 pub fn mastodon_compat(self) -> Self {
171 self.dont_use_created_field().require_header("host")
172 }
173
174 pub fn require_digest(self) -> Self {
178 self.require_header("Digest")
179 }
180
181 pub fn dont_use_created_field(mut self) -> Self {
186 self.use_created_field = false;
187 self.require_header("date")
188 }
189
190 pub const fn set_expiration(mut self, expires_after: Duration) -> Self {
192 self.expires_after = expires_after;
193 self
194 }
195
196 pub fn require_header(mut self, header: &str) -> Self {
198 self.required_headers.push(header.to_lowercase());
199 self
200 }
201
202 pub fn begin_sign(
205 &self,
206 method: &str,
207 path_and_query: &str,
208 headers: BTreeMap<String, String>,
209 ) -> Result<Unsigned, RequiredError> {
210 let mut headers = headers
211 .into_iter()
212 .map(|(k, v)| (k.to_lowercase(), v))
213 .collect();
214
215 let sig_headers = self.build_headers_list(&headers);
216
217 let (created, expires) = if self.use_created_field {
218 let created = SystemTime::now();
219 let expires = created + self.expires_after;
220
221 (Some(created), Some(expires))
222 } else {
223 (None, None)
224 };
225
226 let signing_string = build_signing_string(
227 method,
228 path_and_query,
229 created,
230 expires,
231 &sig_headers,
232 &mut headers,
233 self.required_headers.iter().cloned().collect(),
234 )?;
235
236 Ok(Unsigned {
237 signing_string,
238 sig_headers,
239 created,
240 expires,
241 })
242 }
243
244 pub fn begin_verify(
247 &self,
248 method: &str,
249 path_and_query: &str,
250 headers: BTreeMap<String, String>,
251 ) -> Result<Unverified, PrepareVerifyError> {
252 let mut headers: BTreeMap<String, String> = headers
253 .into_iter()
254 .map(|(k, v)| (k.to_lowercase(), v))
255 .collect();
256
257 let header = headers
258 .remove("authorization")
259 .or_else(|| headers.remove("signature"))
260 .ok_or(ValidateError::Missing)?;
261
262 let parsed_header: ParsedHeader = header.parse()?;
263 let unvalidated = parsed_header.into_unvalidated(
264 method,
265 path_and_query,
266 &mut headers,
267 self.required_headers.iter().cloned().collect(),
268 )?;
269
270 Ok(unvalidated.validate(self.expires_after)?)
271 }
272
273 fn build_headers_list(&self, btm: &BTreeMap<String, String>) -> Vec<String> {
274 let http_header_keys: Vec<String> = btm.keys().cloned().collect();
275
276 let mut sig_headers = if self.use_created_field {
277 vec![
278 REQUEST_TARGET.to_owned(),
279 CREATED.to_owned(),
280 EXPIRES.to_owned(),
281 ]
282 } else {
283 vec![REQUEST_TARGET.to_owned()]
284 };
285
286 sig_headers.extend(http_header_keys);
287
288 sig_headers
289 }
290}
291
292fn build_signing_string(
293 method: &str,
294 path_and_query: &str,
295 created: Option<SystemTime>,
296 expires: Option<SystemTime>,
297 sig_headers: &[String],
298 btm: &mut BTreeMap<String, String>,
299 mut required_headers: HashSet<String>,
300) -> Result<String, RequiredError> {
301 let request_target = format!("{} {}", method.to_string().to_lowercase(), path_and_query);
302
303 btm.insert(REQUEST_TARGET.to_owned(), request_target);
304 if let Some(created) = created {
305 btm.insert(CREATED.to_owned(), unix_timestamp(created).to_string());
306 }
307 if let Some(expires) = expires {
308 btm.insert(EXPIRES.to_owned(), unix_timestamp(expires).to_string());
309 }
310
311 let signing_string = sig_headers
312 .iter()
313 .filter_map(|h| {
314 let opt = btm.remove(h).map(|v| format!("{}: {}", h, v));
315 if opt.is_some() {
316 required_headers.remove(h);
317 }
318 opt
319 })
320 .collect::<Vec<_>>()
321 .join("\n");
322
323 if !required_headers.is_empty() {
324 return Err(RequiredError(required_headers));
325 }
326
327 Ok(signing_string)
328}
329
330impl Default for Config {
331 fn default() -> Self {
332 Self::new()
333 }
334}
335
/// Convert `time` into whole seconds since the Unix epoch (sub-second part truncated).
///
/// # Panics
///
/// Panics if `time` is earlier than [`UNIX_EPOCH`].
fn unix_timestamp(time: SystemTime) -> u64 {
    let since_epoch: Duration = time
        .duration_since(UNIX_EPOCH)
        .expect("UNIX_EPOCH is never in the future");
    since_epoch.as_secs()
}
341
/// Parse a decimal count of seconds since the Unix epoch into a [`SystemTime`].
///
/// # Errors
///
/// Returns a [`ParseIntError`] if `s` is not a valid `u64`. Also returns one, of kind
/// [`std::num::IntErrorKind::PosOverflow`], if the resulting instant cannot be
/// represented as a `SystemTime` on this platform.
fn parse_unix_timestamp(s: &str) -> Result<SystemTime, ParseIntError> {
    let secs: u64 = s.parse()?;
    // `SystemTime + Duration` panics on overflow. Callers may pass values taken from a
    // received signature header, so report the overflow as an error instead of panicking.
    UNIX_EPOCH
        .checked_add(Duration::from_secs(secs))
        .ok_or_else(|| {
            // `ParseIntError` has no public constructor, so produce a genuine
            // "number too large to fit in target type" error by parsing u64::MAX + 1.
            "18446744073709551616"
                .parse::<u64>()
                .expect_err("u64::MAX + 1 never fits in a u64")
        })
}
346
#[cfg(test)]
mod tests {
    use super::Config;
    use std::collections::BTreeMap;

    /// Headers shared by every test: a single mixed-case `Content-Type` entry.
    fn prepare_headers() -> BTreeMap<String, String> {
        std::iter::once((
            "Content-Type".to_owned(),
            "application/activity+json".to_owned(),
        ))
        .collect()
    }

    /// Signing must fail when a required header (`date`) is absent.
    #[test]
    fn required_header() {
        let config = Config::default().require_header("date");

        let outcome = config.begin_sign("GET", "/foo?bar=baz", prepare_headers());

        assert!(outcome.is_err())
    }

    /// A signature carried in the `Authorization` header verifies with the same config.
    #[test]
    fn round_trip_authorization() {
        let config = Config::default().require_header("content-type");

        let unsigned = config
            .begin_sign("GET", "/foo?bar=baz", prepare_headers())
            .unwrap();
        let signed = unsigned
            .sign("hi".to_owned(), |s| {
                Ok(s.to_owned()) as Result<_, std::io::Error>
            })
            .unwrap();

        let mut request_headers = prepare_headers();
        request_headers.insert("Authorization".to_owned(), signed.authorization_header());

        let unverified = config
            .begin_verify("GET", "/foo?bar=baz", request_headers)
            .unwrap();

        assert!(unverified.verify(|sig, signing_string| sig == signing_string));
    }

    /// A signature carried in the `Signature` header verifies with the same config.
    #[test]
    fn round_trip_signature() {
        let config = Config::default();

        let unsigned = config
            .begin_sign("GET", "/foo?bar=baz", prepare_headers())
            .unwrap();
        let signed = unsigned
            .sign("hi".to_owned(), |s| {
                Ok(s.to_owned()) as Result<_, std::io::Error>
            })
            .unwrap();

        let mut request_headers = prepare_headers();
        request_headers.insert("Signature".to_owned(), signed.signature_header());

        let unverified = config
            .begin_verify("GET", "/foo?bar=baz", request_headers)
            .unwrap();

        assert!(unverified.verify(|sig, signing_string| sig == signing_string));
    }
}