1use anyhow::anyhow;
2use aws_lc_rs::{
3 constant_time,
4 digest::{digest, SHA256, SHA256_OUTPUT_LEN},
5};
6use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
7use educe::Educe;
8use http::{header::AUTHORIZATION, HeaderValue};
9use rand::{distributions::Standard, prelude::Distribution};
10use regex::Regex;
11use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
12use std::{
13 str::{self, FromStr},
14 sync::OnceLock,
15};
16
/// Name of the HTTP request header that carries a DAP-Auth-Token; used by
/// `AuthenticationToken::request_authentication` for the `DapAuth` variant.
pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token";
19
/// A secret token presented by a client to authenticate an HTTP request, either as an HTTP
/// bearer token or as a DAP-Auth-Token.
///
/// Serialized adjacently tagged: the variant name goes in `type` and the token string in `token`,
/// e.g. `{type: "Bearer", token: "AAAA"}` in YAML. The wrapped token types omit their value from
/// `Debug` output and compare in constant time, so this enum inherits both properties.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", content = "token")]
#[non_exhaustive]
pub enum AuthenticationToken {
    /// A bearer token, presented as `Authorization: Bearer <token>`.
    Bearer(BearerToken),

    /// A DAP-Auth-Token, presented verbatim in the `DAP-Auth-Token` header.
    DapAuth(DapAuthToken),
}
44
45impl AuthenticationToken {
46 pub fn new_bearer_token_from_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<Self, anyhow::Error> {
48 BearerToken::try_from(bytes.as_ref().to_vec()).map(AuthenticationToken::Bearer)
49 }
50
51 pub fn new_bearer_token_from_string<T: Into<String>>(string: T) -> Result<Self, anyhow::Error> {
53 BearerToken::try_from(string.into()).map(AuthenticationToken::Bearer)
54 }
55
56 pub fn new_dap_auth_token_from_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<Self, anyhow::Error> {
58 DapAuthToken::try_from(bytes.as_ref().to_vec()).map(AuthenticationToken::DapAuth)
59 }
60
61 pub fn new_dap_auth_token_from_string<T: Into<String>>(
63 string: T,
64 ) -> Result<Self, anyhow::Error> {
65 DapAuthToken::try_from(string.into()).map(AuthenticationToken::DapAuth)
66 }
67
68 pub fn request_authentication(&self) -> (&'static str, String) {
71 match self {
72 Self::Bearer(token) => (AUTHORIZATION.as_str(), format!("Bearer {}", token.as_str())),
73 Self::DapAuth(token) => (DAP_AUTH_HEADER, token.as_str().to_string()),
75 }
76 }
77
78 pub fn as_str(&self) -> &str {
80 match self {
81 Self::DapAuth(token) => token.as_str(),
82 Self::Bearer(token) => token.as_str(),
83 }
84 }
85}
86
87impl AsRef<[u8]> for AuthenticationToken {
88 fn as_ref(&self) -> &[u8] {
89 match self {
90 Self::DapAuth(token) => token.as_ref(),
91 Self::Bearer(token) => token.as_ref(),
92 }
93 }
94}
95
96impl FromStr for AuthenticationToken {
97 type Err = anyhow::Error;
98
99 fn from_str(s: &str) -> Result<Self, Self::Err> {
103 if let Some(s) = s.strip_prefix("bearer:") {
104 return Ok(Self::Bearer(BearerToken::from_str(s)?));
105 }
106 if let Some(s) = s.strip_prefix("dap:") {
107 return Ok(Self::DapAuth(DapAuthToken::from_str(s)?));
108 }
109 Err(anyhow!(
110 "bad or missing prefix on authentication token flag value"
111 ))
112 }
113}
114
115impl Distribution<AuthenticationToken> for Standard {
116 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> AuthenticationToken {
117 AuthenticationToken::Bearer(Standard::sample(self, rng))
118 }
119}
120
/// A DAP-Auth-Token: any string that is accepted as an HTTP header value.
///
/// The secret value is skipped by the `Debug` implementation and serialized as a plain string.
/// Deserialization and all constructors run the header-value validation.
#[derive(Clone, Educe, Serialize)]
#[educe(Debug)]
#[serde(transparent)]
pub struct DapAuthToken(#[educe(Debug(ignore))] String);
135
136impl DapAuthToken {
137 pub fn as_str(&self) -> &str {
139 &self.0
140 }
141
142 fn validate(value: &str) -> Result<(), anyhow::Error> {
144 HeaderValue::try_from(value)?;
145 Ok(())
146 }
147}
148
149impl AsRef<str> for DapAuthToken {
150 fn as_ref(&self) -> &str {
151 &self.0
152 }
153}
154
155impl AsRef<[u8]> for DapAuthToken {
156 fn as_ref(&self) -> &[u8] {
157 self.0.as_bytes()
158 }
159}
160
161impl From<DapAuthToken> for AuthenticationToken {
162 fn from(value: DapAuthToken) -> Self {
163 Self::DapAuth(value)
164 }
165}
166
167impl TryFrom<String> for DapAuthToken {
168 type Error = anyhow::Error;
169
170 fn try_from(value: String) -> Result<Self, Self::Error> {
171 Self::validate(&value)?;
172 Ok(Self(value))
173 }
174}
175
176impl FromStr for DapAuthToken {
177 type Err = anyhow::Error;
178
179 fn from_str(s: &str) -> Result<Self, Self::Err> {
180 Self::try_from(s.to_string())
181 }
182}
183
184impl TryFrom<Vec<u8>> for DapAuthToken {
185 type Error = anyhow::Error;
186
187 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
188 Self::try_from(String::from_utf8(value)?)
189 }
190}
191
192impl<'de> Deserialize<'de> for DapAuthToken {
193 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
194 where
195 D: Deserializer<'de>,
196 {
197 String::deserialize(deserializer)
198 .and_then(|string| Self::try_from(string).map_err(D::Error::custom))
199 }
200}
201
202impl PartialEq for DapAuthToken {
203 fn eq(&self, other: &Self) -> bool {
204 constant_time::verify_slices_are_equal(self.0.as_ref(), other.0.as_ref()).is_ok()
209 }
210}
211
212impl Eq for DapAuthToken {}
213
214impl Distribution<DapAuthToken> for Standard {
215 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> DapAuthToken {
216 DapAuthToken(URL_SAFE_NO_PAD.encode(rng.gen::<[u8; 16]>()))
217 }
218}
219
/// An HTTP bearer token. Its value must match the `b64token` syntax: one or more characters from
/// `A-Z a-z 0-9 - . _ ~ + /`, followed by any number of `=`.
///
/// The secret value is skipped by the `Debug` implementation and serialized as a plain string.
/// Deserialization and all constructors run the syntax validation.
#[derive(Clone, Educe, Serialize)]
#[educe(Debug)]
#[serde(transparent)]
pub struct BearerToken(#[educe(Debug(ignore))] String);
234
235impl BearerToken {
236 pub fn as_str(&self) -> &str {
238 &self.0
239 }
240
241 fn validate(value: &str) -> Result<(), anyhow::Error> {
245 static REGEX: OnceLock<Regex> = OnceLock::new();
246
247 let regex = REGEX.get_or_init(|| Regex::new("^[-A-Za-z0-9._~+/]+=*$").unwrap());
248
249 if regex.is_match(value) {
250 Ok(())
251 } else {
252 Err(anyhow::anyhow!("bearer token has invalid format"))
253 }
254 }
255}
256
257impl AsRef<str> for BearerToken {
258 fn as_ref(&self) -> &str {
259 &self.0
260 }
261}
262
263impl AsRef<[u8]> for BearerToken {
264 fn as_ref(&self) -> &[u8] {
265 self.0.as_bytes()
266 }
267}
268
269impl From<BearerToken> for AuthenticationToken {
270 fn from(value: BearerToken) -> Self {
271 Self::Bearer(value)
272 }
273}
274
275impl TryFrom<String> for BearerToken {
276 type Error = anyhow::Error;
277
278 fn try_from(value: String) -> Result<Self, Self::Error> {
279 Self::validate(&value)?;
280 Ok(Self(value))
281 }
282}
283
284impl FromStr for BearerToken {
285 type Err = anyhow::Error;
286
287 fn from_str(s: &str) -> Result<Self, Self::Err> {
288 Self::try_from(s.to_string())
289 }
290}
291
292impl TryFrom<Vec<u8>> for BearerToken {
293 type Error = anyhow::Error;
294
295 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
296 Self::try_from(String::from_utf8(value)?)
297 }
298}
299
300impl<'de> Deserialize<'de> for BearerToken {
301 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
302 where
303 D: Deserializer<'de>,
304 {
305 String::deserialize(deserializer)
306 .and_then(|string| Self::try_from(string).map_err(D::Error::custom))
307 }
308}
309
310impl PartialEq for BearerToken {
311 fn eq(&self, other: &Self) -> bool {
312 constant_time::verify_slices_are_equal(self.0.as_bytes(), other.0.as_bytes()).is_ok()
317 }
318}
319
320impl Eq for BearerToken {}
321
322impl Distribution<BearerToken> for Standard {
323 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BearerToken {
324 BearerToken(URL_SAFE_NO_PAD.encode(rng.gen::<[u8; 16]>()))
325 }
326}
327
/// The SHA-256 digest of an `AuthenticationToken`, tagged with the kind of token it was computed
/// from. An incoming token can be checked against it with `AuthenticationTokenHash::validate`.
///
/// Serialized adjacently tagged with `type` and `hash`. The digest is written as unpadded
/// URL-safe base64, and on input it must decode to exactly `SHA256_OUTPUT_LEN` bytes. `Debug`
/// output omits the digest.
#[derive(Clone, Educe, Deserialize, Serialize, Eq)]
#[educe(Debug)]
#[serde(tag = "type", content = "hash")]
#[non_exhaustive]
pub enum AuthenticationTokenHash {
    /// Digest of a bearer token.
    Bearer(
        #[educe(Debug(ignore))]
        #[serde(
            serialize_with = "AuthenticationTokenHash::serialize_contents",
            deserialize_with = "AuthenticationTokenHash::deserialize_contents"
        )]
        [u8; SHA256_OUTPUT_LEN],
    ),

    /// Digest of a DAP-Auth-Token.
    DapAuth(
        #[educe(Debug(ignore))]
        #[serde(
            serialize_with = "AuthenticationTokenHash::serialize_contents",
            deserialize_with = "AuthenticationTokenHash::deserialize_contents"
        )]
        [u8; SHA256_OUTPUT_LEN],
    ),
}
367
368impl AuthenticationTokenHash {
369 pub fn validate(&self, incoming_token: &AuthenticationToken) -> bool {
371 &Self::from(incoming_token) == self
372 }
373
374 fn serialize_contents<S: Serializer>(
375 value: &[u8; SHA256_OUTPUT_LEN],
376 serializer: S,
377 ) -> Result<S::Ok, S::Error> {
378 serializer.serialize_str(&URL_SAFE_NO_PAD.encode(value))
379 }
380
381 fn deserialize_contents<'de, D>(deserializer: D) -> Result<[u8; SHA256_OUTPUT_LEN], D::Error>
382 where
383 D: Deserializer<'de>,
384 {
385 let b64_digest: String = Deserialize::deserialize(deserializer)?;
386 let decoded = URL_SAFE_NO_PAD
387 .decode(b64_digest)
388 .map_err(D::Error::custom)?;
389
390 decoded
391 .try_into()
392 .map_err(|_| D::Error::custom("digest has wrong length"))
393 }
394}
395
396impl From<&AuthenticationToken> for AuthenticationTokenHash {
397 fn from(value: &AuthenticationToken) -> Self {
398 let digest = digest(&SHA256, value.as_ref()).as_ref().try_into().unwrap();
401
402 match value {
403 AuthenticationToken::Bearer(_) => Self::Bearer(digest),
404 AuthenticationToken::DapAuth(_) => Self::DapAuth(digest),
405 }
406 }
407}
408
409impl PartialEq for AuthenticationTokenHash {
410 fn eq(&self, other: &Self) -> bool {
411 let (self_digest, other_digest) = match (self, other) {
412 (Self::Bearer(self_digest), Self::Bearer(other_digest)) => (self_digest, other_digest),
413 (Self::DapAuth(self_digest), Self::DapAuth(other_digest)) => {
414 (self_digest, other_digest)
415 }
416 _ => return false,
417 };
418
419 constant_time::verify_slices_are_equal(self_digest.as_ref(), other_digest.as_ref()).is_ok()
421 }
422}
423
424impl AsRef<[u8]> for AuthenticationTokenHash {
425 fn as_ref(&self) -> &[u8] {
426 match self {
427 Self::Bearer(inner) => inner.as_slice(),
428 Self::DapAuth(inner) => inner.as_slice(),
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use crate::auth_tokens::{AuthenticationToken, AuthenticationTokenHash};
    use rand::random;
    use std::str::FromStr as _;

    /// A DAP-Auth-Token may contain punctuation such as `!@#$`, since it only needs to be a valid
    /// HTTP header value.
    #[test]
    fn valid_dap_auth_token() {
        serde_yaml::from_str::<AuthenticationToken>(
            "{type: \"DapAuth\", token: \"correct-horse-battery-staple-!@#$\"}",
        )
        .unwrap();
    }

    /// A bearer token using several permitted punctuation characters plus trailing `=` padding.
    #[test]
    fn valid_bearer_token() {
        serde_yaml::from_str::<AuthenticationToken>(
            "{type: \"Bearer\", token: \"AAAAAAA~-_/A===\"}",
        )
        .unwrap();
    }

    /// Control characters (vertical tab, NUL) are not valid header values, so deserializing
    /// them as DAP-Auth-Tokens must fail.
    #[test]
    fn reject_invalid_auth_token_dap_auth() {
        serde_yaml::from_str::<AuthenticationToken>("{type: \"DapAuth\", token: \"\\x0b\"}")
            .unwrap_err();
        serde_yaml::from_str::<AuthenticationToken>("{type: \"DapAuth\", token: \"\\x00\"}")
            .unwrap_err();
    }

    /// Bearer tokens reject non-ASCII, characters outside the allowed alphabet, padding-only
    /// values, and padding followed by more token characters.
    #[test]
    fn reject_invalid_auth_token_bearer() {
        serde_yaml::from_str::<AuthenticationToken>("{type: \"Bearer\", token: \"é\"}")
            .unwrap_err();
        serde_yaml::from_str::<AuthenticationToken>("{type: \"Bearer\", token: \"^\"}")
            .unwrap_err();
        serde_yaml::from_str::<AuthenticationToken>("{type: \"Bearer\", token: \"=\"}")
            .unwrap_err();
        serde_yaml::from_str::<AuthenticationToken>("{type: \"Bearer\", token: \"AAAA==AAA\"}")
            .unwrap_err();
    }

    /// `FromStr` selects the token kind from a `bearer:` or `dap:` prefix and rejects anything
    /// else (unknown prefix or no prefix at all).
    #[test]
    fn authentication_token_from_str() {
        for (value, expected_result) in [
            (
                "bearer:foo",
                Some(AuthenticationToken::new_bearer_token_from_string("foo").unwrap()),
            ),
            (
                "dap:foo",
                Some(AuthenticationToken::new_dap_auth_token_from_string("foo").unwrap()),
            ),
            ("badtype:foo", None),
            ("notype", None),
        ] {
            let rslt = AuthenticationToken::from_str(value);
            match expected_result {
                Some(expected_result) => assert_eq!(rslt.unwrap(), expected_result),
                None => assert!(rslt.is_err()),
            }
        }
    }

    /// A well-formed 32-byte digest in unpadded URL-safe base64 deserializes for both kinds.
    #[rstest::rstest]
    #[case::bearer(r#"{ type: "Bearer", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4vdQ" }"#)]
    #[case::dap_auth(r#"{ type: "DapAuth", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4vdQ" }"#)]
    #[test]
    fn serde_aggregator_token_hash_valid(#[case] yaml: &str) {
        serde_yaml::from_str::<AuthenticationTokenHash>(yaml).unwrap();
    }

    /// Hashes must be rejected when the string is not URL-safe base64 (`+` is outside that
    /// alphabet) or when it decodes to the wrong number of bytes (truncated digest).
    #[rstest::rstest]
    #[case::bearer_token_invalid_encoding(r#"{ type: "Bearer", hash: "+" }"#)]
    #[case::bearer_token_wrong_length(
        r#"{ type: "Bearer", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4" }"#
    )]
    #[case::dap_auth_token_invalid_encoding(r#"{ type: "DapAuth", hash: "+" }"#)]
    #[case::dap_auth_token_wrong_length(
        r#"{ type: "DapAuth", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4" }"#
    )]
    #[test]
    fn serde_aggregator_token_hash_invalid(#[case] yaml: &str) {
        serde_yaml::from_str::<AuthenticationTokenHash>(yaml).unwrap_err();
    }

    /// Token equality and hash validation: a token matches itself and its own hash, but not a
    /// different random token or a token of the other kind.
    #[test]
    fn validate_token() {
        let dap_auth_token_1 = AuthenticationToken::DapAuth(random());
        let dap_auth_token_2 = AuthenticationToken::DapAuth(random());
        let bearer_token_1 = AuthenticationToken::Bearer(random());
        let bearer_token_2 = AuthenticationToken::Bearer(random());

        assert_eq!(dap_auth_token_1, dap_auth_token_1);
        assert_ne!(dap_auth_token_1, dap_auth_token_2);
        assert_eq!(bearer_token_1, bearer_token_1);
        assert_ne!(bearer_token_1, bearer_token_2);
        assert_ne!(dap_auth_token_1, bearer_token_1);

        assert!(AuthenticationTokenHash::from(&dap_auth_token_1).validate(&dap_auth_token_1));
        assert!(!AuthenticationTokenHash::from(&dap_auth_token_1).validate(&dap_auth_token_2));
        assert!(AuthenticationTokenHash::from(&bearer_token_1).validate(&bearer_token_1));
        assert!(!AuthenticationTokenHash::from(&bearer_token_1).validate(&bearer_token_2));
        assert!(!AuthenticationTokenHash::from(&dap_auth_token_1).validate(&bearer_token_1));
    }
}