1use std::convert::TryFrom;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use chrono::{DateTime, Duration, Utc};
7use hmac::{Hmac, Mac};
8use rand::distributions::Alphanumeric;
9use rand::{thread_rng, Rng};
10use serde::{Deserialize, Serialize};
11use sha2::Sha256;
12
13use crate::error::{Error, ErrorCode};
14use crate::rest::RestInner;
15use crate::{http, rest, Result};
16
/// Largest token string, in bytes, that `Auth::request_token` will accept (128 KiB).
const MAX_TOKEN_LENGTH: usize = 128 << 10;
20
21mod duration {
22 use std::fmt;
23
24 use super::*;
25 use serde::{de, Deserializer, Serializer};
26
27 #[derive(Debug)]
28 pub struct MilliSecondsTimestampVisitor;
29
30 impl<'de> de::Visitor<'de> for MilliSecondsTimestampVisitor {
31 type Value = Duration;
32
33 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34 formatter.write_str("a duration in milliseconds")
35 }
36
37 fn visit_i64<E>(self, value: i64) -> std::result::Result<Self::Value, E>
39 where
40 E: de::Error,
41 {
42 Ok(Duration::milliseconds(value))
43 }
44 }
45
46 pub fn deserialize<'de, D>(d: D) -> std::result::Result<Duration, D::Error>
47 where
48 D: Deserializer<'de>,
49 {
50 d.deserialize_u64(MilliSecondsTimestampVisitor)
51 }
52
53 pub fn serialize<S>(d: &Duration, serializer: S) -> std::result::Result<S::Ok, S::Error>
54 where
55 S: Serializer,
56 {
57 let n = d.num_milliseconds();
58 serializer.serialize_i64(n)
59 }
60}
61
62#[derive(Clone)]
63pub enum Credential {
64 TokenDetails(TokenDetails),
65 TokenRequest(TokenRequest),
66 Callback(Arc<dyn AuthCallback>),
67 Key(Key),
68 Url(reqwest::Url),
69}
70
71impl std::fmt::Debug for Credential {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Self::TokenDetails(arg0) => f.debug_tuple("TokenDetails").field(arg0).finish(),
75 Self::TokenRequest(arg0) => f.debug_tuple("TokenRequest").field(arg0).finish(),
76 Self::Key(arg0) => f.debug_tuple("Key").field(arg0).finish(),
77 Self::Callback(_) => f.debug_tuple("Callback").field(&"Fn").finish(),
78 Self::Url(arg0) => f.debug_tuple("Url").field(arg0).finish(),
79 }
80 }
81}
82
83#[derive(Debug, Clone, Default)]
84pub struct AuthOptions {
85 pub token: Option<Credential>,
86 pub headers: Option<http::HeaderMap>,
87 pub method: http::Method,
88 pub params: Option<http::UrlQuery>,
89}
90
91#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
93pub struct Key {
94 #[serde(rename(deserialize = "keyName"))]
95 pub name: String,
96 pub value: String,
97}
98
99impl Key {
100 pub fn new(s: &str) -> Result<Self> {
101 if let [name, value] = s.splitn(2, ':').collect::<Vec<&str>>()[..] {
102 Ok(Key {
103 name: name.to_string(),
104 value: value.to_string(),
105 })
106 } else {
107 Err(Error::new(ErrorCode::BadRequest, "Invalid key"))
108 }
109 }
110}
111
112impl TryFrom<&str> for Key {
113 type Error = Error;
114
115 fn try_from(s: &str) -> Result<Self> {
130 Self::new(s)
131 }
132}
133
134impl Key {
135 pub fn sign(&self, params: &TokenParams) -> Result<TokenRequest> {
155 params.sign(self)
156 }
157}
158
159#[derive(Clone, Debug)]
161pub struct Auth<'a> {
162 pub(crate) rest: &'a rest::Rest,
163}
164
165impl<'a> Auth<'a> {
166 pub fn new(rest: &'a rest::Rest) -> Self {
167 Self { rest }
168 }
169
170 fn inner(&self) -> &RestInner {
171 &self.rest.inner
172 }
173
174 pub fn create_token_request(
176 &self,
177 params: &TokenParams,
178 options: &AuthOptions,
179 ) -> Result<TokenRequest> {
180 let key = match &options.token {
181 Some(Credential::Key(k)) => k,
182 _ => {
183 return Err(Error::new(
184 ErrorCode::UnableToObtainCredentialsFromGivenParameters,
185 "API key is required to create signed token requests",
186 ))
187 }
188 };
189 params.sign(key)
190 }
191
192 pub(crate) fn exchange(
201 &self,
202 req: &TokenRequest,
203 ) -> Pin<Box<dyn Future<Output = Result<TokenDetails>> + Send + 'a>> {
204 let req = self
205 .rest
206 .request(
207 http::Method::POST,
208 &format!("/keys/{}/requestToken", req.key_name),
209 )
210 .authenticate(false)
211 .body(req);
212
213 Box::pin(async move { req.send().await?.body().await.map_err(Into::into) })
214 }
215
216 fn request_url<'b>(
218 &'b self,
219 url: &'b reqwest::Url,
220 ) -> Pin<Box<dyn Future<Output = Result<TokenDetails>> + Send + 'b>> {
221 let fut = async move {
222 let res = self
223 .rest
224 .request_url(Default::default(), url.clone())
225 .authenticate(false)
226 .send()
227 .await?;
228
229 let content_type = res.content_type().ok_or_else(|| {
231 Error::new(
232 ErrorCode::ErrorFromClientTokenCallback,
233 "authUrl response is missing a content-type header",
234 )
235 })?;
236 match content_type.essence_str() {
237 "application/json" => {
238 let token: RequestOrDetails = res.json().await?;
242 match token {
243 RequestOrDetails::Request(r) => self.exchange(&r).await,
244 RequestOrDetails::Details(d) => Ok(d),
245 }
246 },
247
248 "text/plain" | "application/jwt" => {
249 let token = res.text().await?;
251 Ok(TokenDetails::from(token))
252 },
253
254 _ => Err(Error::new(ErrorCode::ErrorFromClientTokenCallback, format!("authUrl responded with unacceptable content-type {}, should be either text/plain, application/jwt or application/json", content_type))),
256 }
257 };
258
259 Box::pin(fut)
260 }
261
262 pub async fn request_token(
263 &self,
264 params: &TokenParams,
265 options: &AuthOptions,
266 ) -> Result<TokenDetails> {
267 let token = options.token.as_ref().ok_or_else(|| {
268 Error::new(
269 ErrorCode::NoWayToRenewAuthToken,
270 "no means provided to renew auth token",
271 )
272 })?;
273
274 let mut details = match token {
275 Credential::TokenDetails(token) => Ok(token.clone()),
276 Credential::TokenRequest(r) => self.exchange(r).await,
277 Credential::Callback(f) => match f.token(params).await {
278 Ok(token) => token.into_details(self).await,
279 Err(e) => Err(e),
280 },
281 Credential::Key(k) => self.exchange(¶ms.sign(k)?).await,
282 Credential::Url(url) => self.request_url(url).await,
283 };
284
285 if matches!(token, Credential::Callback(_) | Credential::Url(_)) {
286 if let Err(ref mut err) = details {
287 if err.code == ErrorCode::BadRequest {
289 err.code = ErrorCode::ErrorFromClientTokenCallback;
290 err.status_code = Some(401);
291 }
292 };
293 }
294
295 let details = details?;
296
297 if details.token.len() > MAX_TOKEN_LENGTH {
299 return Err(Error::with_status(
300 ErrorCode::ErrorFromClientTokenCallback,
301 401,
302 format!(
303 "Token string exceeded max permitted length (was {} bytes)",
304 details.token.len()
305 ),
306 ));
307 }
308
309 Ok(details)
310 }
311
312 pub(crate) async fn with_auth_headers(&self, req: &mut reqwest::Request) -> Result<()> {
314 if let Credential::Key(k) = &self.inner().opts.credential {
315 return Self::set_basic_auth(req, k);
316 }
317
318 let options = AuthOptions {
319 token: Some(self.inner().opts.credential.clone()),
320 ..Default::default()
321 };
322
323 let res = self.request_token(&Default::default(), &options).await?;
325 Self::set_bearer_auth(req, &res.token)
326 }
327
328 fn set_bearer_auth(req: &mut reqwest::Request, token: &str) -> Result<()> {
329 Self::set_header(
330 req,
331 reqwest::header::AUTHORIZATION,
332 format!("Bearer {}", token),
333 )
334 }
335
336 fn set_basic_auth(req: &mut reqwest::Request, key: &Key) -> Result<()> {
337 let encoded = base64::encode(format!("{}:{}", key.name, key.value));
338 Self::set_header(
339 req,
340 reqwest::header::AUTHORIZATION,
341 format!("Basic {}", encoded),
342 )
343 }
344
345 fn set_header(req: &mut reqwest::Request, key: http::HeaderName, value: String) -> Result<()> {
346 req.headers_mut().append(key, value.parse()?);
347 Ok(())
348 }
349
350 fn generate_nonce() -> String {
352 thread_rng()
353 .sample_iter(&Alphanumeric)
354 .take(16)
355 .map(char::from)
356 .collect()
357 }
358
359 fn compute_mac(
366 key: &Key,
367 ttl: Duration,
368 capability: &str,
369 client_id: Option<&str>,
370 timestamp: DateTime<Utc>,
371 nonce: &str,
372 ) -> Result<String> {
373 let mut mac = Hmac::<Sha256>::new_from_slice(key.value.as_bytes())?;
374
375 mac.update(key.name.as_bytes());
376 mac.update(b"\n");
377
378 mac.update(ttl.num_milliseconds().to_string().as_bytes());
379 mac.update(b"\n");
380
381 mac.update(capability.as_bytes());
382 mac.update(b"\n");
383
384 mac.update(client_id.map(|c| c.as_bytes()).unwrap_or_default());
385 mac.update(b"\n");
386
387 mac.update(timestamp.timestamp_millis().to_string().as_bytes());
388 mac.update(b"\n");
389
390 mac.update(nonce.as_bytes());
391 mac.update(b"\n");
392
393 Ok(base64::encode(mac.finalize().into_bytes()))
394 }
395}
396
397#[derive(Clone, Debug)]
401pub struct TokenParams {
402 pub capability: String,
403 pub client_id: Option<String>,
404 pub nonce: Option<String>,
405 pub timestamp: Option<DateTime<Utc>>,
406 pub ttl: Duration,
407}
408
409impl Default for TokenParams {
410 fn default() -> Self {
411 Self {
412 capability: "{\"*\":[\"*\"]}".to_string(),
413 client_id: Default::default(),
414 nonce: Default::default(),
415 timestamp: Default::default(),
416 ttl: Duration::minutes(60),
417 }
418 }
419}
420
421impl TokenParams {
422 pub fn new() -> Self {
423 Default::default()
424 }
425
426 pub fn capability(mut self, capability: &str) -> Self {
428 self.capability = capability.to_string();
429 self
430 }
431
432 pub fn client_id(mut self, client_id: &str) -> Self {
434 self.client_id = Some(client_id.to_string());
435 self
436 }
437
438 pub fn ttl(mut self, ttl: Duration) -> Self {
440 self.ttl = ttl;
441 self
442 }
443
444 pub fn timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
446 self.timestamp = Some(timestamp);
447 self
448 }
449
450 fn sign(&self, key: &Key) -> Result<TokenRequest> {
455 if let Some(ref client_id) = self.client_id {
457 if client_id.is_empty() {
458 return Err(Error::new(
459 ErrorCode::InvalidClientID,
460 "client_id can’t be an empty string",
461 ));
462 }
463 }
464
465 let nonce = self.nonce.clone().unwrap_or_else(Auth::generate_nonce);
466 let timestamp = self.timestamp.unwrap_or_else(Utc::now);
467 let key_name = key.name.clone();
468
469 let req = TokenRequest {
470 mac: Auth::compute_mac(
471 key,
472 self.ttl,
473 &self.capability,
474 self.client_id.as_deref(),
475 timestamp,
476 &nonce,
477 )?,
478 key_name,
479 timestamp,
480 capability: self.capability.clone(),
481 client_id: self.client_id.clone(),
482 nonce,
483 ttl: self.ttl,
484 };
485
486 Ok(req)
487 }
488}
489
490#[derive(Clone, Debug, Deserialize, Serialize)]
494#[serde(rename_all = "camelCase")]
495pub struct TokenRequest {
496 pub key_name: String,
497 #[serde(with = "chrono::serde::ts_milliseconds")]
498 pub timestamp: DateTime<Utc>,
499 pub capability: String,
500 #[serde(skip_serializing_if = "Option::is_none")]
501 pub client_id: Option<String>,
502 pub mac: String,
503 pub nonce: String,
504 #[serde(with = "duration")]
505 pub ttl: Duration,
506}
507
508#[derive(Clone, Debug, Deserialize)]
513#[serde(rename_all = "camelCase")]
514pub struct TokenDetails {
515 pub token: String,
516 #[serde(flatten)]
517 pub metadata: Option<TokenMetadata>,
518}
519
520impl TokenDetails {
521 pub fn token(s: String) -> Self {
522 Self {
523 token: s,
524 metadata: None,
525 }
526 }
527}
528
529impl From<String> for TokenDetails {
530 fn from(token: String) -> Self {
531 TokenDetails {
532 token,
533 metadata: None,
534 }
535 }
536}
537
538#[derive(Clone, Debug, Deserialize)]
539#[serde(rename_all = "camelCase")]
540pub struct TokenMetadata {
541 #[serde(with = "chrono::serde::ts_milliseconds")]
542 pub expires: DateTime<Utc>,
543 #[serde(with = "chrono::serde::ts_milliseconds")]
544 pub issued: DateTime<Utc>,
545 pub capability: String,
546 #[serde(skip_serializing_if = "Option::is_none")]
547 pub client_id: Option<String>,
548}
549
550#[derive(Clone, Debug, Deserialize)]
551#[serde(untagged)]
552pub enum RequestOrDetails {
553 Request(TokenRequest),
554 Details(TokenDetails),
555}
556
557impl RequestOrDetails {
558 async fn into_details(self, auth: &Auth<'_>) -> Result<TokenDetails> {
559 match self {
560 RequestOrDetails::Request(r) => auth.exchange(&r).await,
561 RequestOrDetails::Details(d) => Ok(d),
562 }
563 }
564}
565
566pub trait AuthCallback: Send + Sync {
567 fn token<'a>(
568 &'a self,
569 params: &'a TokenParams,
570 ) -> Pin<Box<dyn Send + Future<Output = Result<RequestOrDetails>> + 'a>>;
571}