1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
mod hyper_content_digest;
mod hyper_http;

// Hyper-specific extensions for generating and verifying HTTP message signatures.

/// Lowercase field name of the `content-digest` HTTP header.
const CONTENT_DIGEST_HEADER: &str = "content-digest";

/// Hash algorithm used to compute a `content-digest` header value.
///
/// Its `Display` form is the algorithm key written into the header
/// (`sha-256` or `sha-512`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContentDigestType {
  /// SHA-256 digest, rendered as `sha-256`.
  Sha256,
  /// SHA-512 digest, rendered as `sha-512`.
  Sha512,
}

impl std::fmt::Display for ContentDigestType {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      ContentDigestType::Sha256 => write!(f, "sha-256"),
      ContentDigestType::Sha512 => write!(f, "sha-512"),
    }
  }
}

pub use httpsig::prelude;
pub use hyper_content_digest::{ContentDigest, RequestContentDigest};
pub use hyper_http::RequestMessageSignature;

#[cfg(test)]
mod tests {
  use super::{
    prelude::{message_component::*, *},
    *,
  };
  use http::Request;
  use http_body_util::Full;
  use httpsig::prelude::{PublicKey, SecretKey};

  const EDDSA_SECRET_KEY: &str = r##"-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIDSHAE++q1BP7T8tk+mJtS+hLf81B0o6CFyWgucDFN/C
-----END PRIVATE KEY-----
"##;
  const EDDSA_PUBLIC_KEY: &str = r##"-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA1ixMQcxO46PLlgQfYS46ivFd+n0CcDHSKUnuhm3i1O0=
-----END PUBLIC KEY-----
"##;
  // const EDDSA_KEY_ID: &str = "gjrE7ACMxgzYfFHgabgf4kLTg1eKIdsJ94AiFTFj1is";

  /// Components covered by the test signature. Because `date` is covered,
  /// altering that header after signing must invalidate the signature.
  const COVERED_COMPONENTS: &[&str] = &["@method", "date", "content-type", "content-digest"];

  /// Builds a GET request with a small JSON body, then attaches a SHA-256
  /// content digest via `set_content_digest`.
  ///
  /// `content-type` is given twice on purpose: `Request::builder().header`
  /// appends instead of replacing, so the request carries a multi-valued field.
  async fn build_request() -> anyhow::Result<Request<Full<bytes::Bytes>>> {
    let body = Full::new(&b"{\"hello\": \"world\"}"[..]);
    let req = Request::builder()
      .method("GET")
      .uri("https://example.com/parameters?var=this%20is%20a%20big%0Amultiline%20value&bar=with+plus+whitespace&fa%C3%A7ade%22%3A%20=something")
      .header("date", "Sun, 09 May 2021 18:30:00 GMT")
      .header("content-type", "application/json")
      .header("content-type", "application/json-patch+json")
      .body(body)
      .unwrap();
    req.set_content_digest(&ContentDigestType::Sha256).await
  }

  /// Builds the request from `build_request` and signs it with the EdDSA test
  /// key over `COVERED_COMPONENTS`, using the signature name `custom_sig_name`.
  async fn build_signed_request() -> Request<Full<bytes::Bytes>> {
    let mut req = build_request().await.unwrap();

    let secret_key = SecretKey::from_pem(EDDSA_SECRET_KEY).unwrap();

    let covered_components = COVERED_COMPONENTS
      .iter()
      .map(|v| HttpMessageComponentId::try_from(*v))
      .collect::<Result<Vec<_>, _>>()
      .unwrap();
    let mut signature_params = HttpSignatureParams::try_new(&covered_components).unwrap();

    // set key information, alg and keyid
    signature_params.set_key_info(&secret_key);

    // set custom signature name
    req
      .set_message_signature(&signature_params, &secret_key, Some("custom_sig_name"))
      .await
      .unwrap();
    req
  }

  #[test]
  fn test_content_digest_type() {
    assert_eq!(ContentDigestType::Sha256.to_string(), "sha-256");
    assert_eq!(ContentDigestType::Sha512.to_string(), "sha-512");
  }

  #[tokio::test]
  async fn test_set_verify() {
    // show usage of set_message_signature and verify_message_signature
    let req = build_signed_request().await;

    let signature_input = req.headers().get("signature-input").unwrap().to_str().unwrap();
    let signature = req.headers().get("signature").unwrap().to_str().unwrap();
    assert!(signature_input.starts_with(r##"custom_sig_name=("##));
    assert!(signature.starts_with(r##"custom_sig_name=:"##));

    // verify without checking key_id
    let public_key = PublicKey::from_pem(EDDSA_PUBLIC_KEY).unwrap();
    let verification_res = req.verify_message_signature(&public_key, None).await.unwrap();
    assert!(verification_res);

    // verify with checking key_id
    let key_id = public_key.key_id();
    let verification_res = req.verify_message_signature(&public_key, Some(&key_id)).await.unwrap();
    assert!(verification_res);

    let verification_res = req.verify_message_signature(&public_key, Some("NotFoundKeyId")).await;
    assert!(verification_res.is_err());
  }

  #[tokio::test]
  async fn test_verify_fails_after_tampering() {
    let mut req = build_signed_request().await;
    let public_key = PublicKey::from_pem(EDDSA_PUBLIC_KEY).unwrap();

    // Sanity check: the untouched request verifies.
    assert!(req.verify_message_signature(&public_key, None).await.unwrap());

    // `date` is a covered component, so changing it must break verification.
    // Either `Ok(false)` or an error is acceptable; only `Ok(true)` is a failure.
    req
      .headers_mut()
      .insert("date", http::HeaderValue::from_static("Mon, 10 May 2021 18:30:00 GMT"));
    let verification_res = req.verify_message_signature(&public_key, None).await;
    assert!(!matches!(verification_res, Ok(true)));
  }
}