ably/
auth.rs

1use std::convert::TryFrom;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use chrono::{DateTime, Duration, Utc};
7use hmac::{Hmac, Mac};
8use rand::distributions::Alphanumeric;
9use rand::{thread_rng, Rng};
10use serde::{Deserialize, Serialize};
11use sha2::Sha256;
12
13use crate::error::{Error, ErrorCode};
14use crate::rest::RestInner;
15use crate::{http, rest, Result};
16
/// Upper bound, in bytes, on the length of an accepted token string (128 KiB).
///
/// `Auth::request_token` rejects any longer token with an
/// `ErrorCode::ErrorFromClientTokenCallback` error and HTTP status 401 (RSA4f).
const MAX_TOKEN_LENGTH: usize = 1024 * 128;
20
21mod duration {
22    use std::fmt;
23
24    use super::*;
25    use serde::{de, Deserializer, Serializer};
26
27    #[derive(Debug)]
28    pub struct MilliSecondsTimestampVisitor;
29
30    impl<'de> de::Visitor<'de> for MilliSecondsTimestampVisitor {
31        type Value = Duration;
32
33        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34            formatter.write_str("a duration in milliseconds")
35        }
36
37        /// Deserialize a timestamp in milliseconds since the epoch
38        fn visit_i64<E>(self, value: i64) -> std::result::Result<Self::Value, E>
39        where
40            E: de::Error,
41        {
42            Ok(Duration::milliseconds(value))
43        }
44    }
45
46    pub fn deserialize<'de, D>(d: D) -> std::result::Result<Duration, D::Error>
47    where
48        D: Deserializer<'de>,
49    {
50        d.deserialize_u64(MilliSecondsTimestampVisitor)
51    }
52
53    pub fn serialize<S>(d: &Duration, serializer: S) -> std::result::Result<S::Ok, S::Error>
54    where
55        S: Serializer,
56    {
57        let n = d.num_milliseconds();
58        serializer.serialize_i64(n)
59    }
60}
61
62#[derive(Clone)]
63pub enum Credential {
64    TokenDetails(TokenDetails),
65    TokenRequest(TokenRequest),
66    Callback(Arc<dyn AuthCallback>),
67    Key(Key),
68    Url(reqwest::Url),
69}
70
71impl std::fmt::Debug for Credential {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            Self::TokenDetails(arg0) => f.debug_tuple("TokenDetails").field(arg0).finish(),
75            Self::TokenRequest(arg0) => f.debug_tuple("TokenRequest").field(arg0).finish(),
76            Self::Key(arg0) => f.debug_tuple("Key").field(arg0).finish(),
77            Self::Callback(_) => f.debug_tuple("Callback").field(&"Fn").finish(),
78            Self::Url(arg0) => f.debug_tuple("Url").field(arg0).finish(),
79        }
80    }
81}
82
83#[derive(Debug, Clone, Default)]
84pub struct AuthOptions {
85    pub token: Option<Credential>,
86    pub headers: Option<http::HeaderMap>,
87    pub method: http::Method,
88    pub params: Option<http::UrlQuery>,
89}
90
91/// An API Key used to authenticate with the REST API using HTTP Basic Auth.
92#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
93pub struct Key {
94    #[serde(rename(deserialize = "keyName"))]
95    pub name: String,
96    pub value: String,
97}
98
99impl Key {
100    pub fn new(s: &str) -> Result<Self> {
101        if let [name, value] = s.splitn(2, ':').collect::<Vec<&str>>()[..] {
102            Ok(Key {
103                name: name.to_string(),
104                value: value.to_string(),
105            })
106        } else {
107            Err(Error::new(ErrorCode::BadRequest, "Invalid key"))
108        }
109    }
110}
111
112impl TryFrom<&str> for Key {
113    type Error = Error;
114
115    /// Parse an API Key from a string of the form '<keyName>:<keySecret>'.
116    ///
117    /// # Example
118    ///
119    /// ```
120    /// use std::convert::TryFrom;
121    /// use ably::auth;
122    ///
123    /// let res = auth::Key::try_from("ABC123.DEF456:XXXXXXXXXXXX");
124    /// assert!(res.is_ok());
125    ///
126    /// let res = auth::Key::try_from("not-a-valid-key");
127    /// assert!(res.is_err());
128    /// ```
129    fn try_from(s: &str) -> Result<Self> {
130        Self::new(s)
131    }
132}
133
134impl Key {
135    /// Use the API key to sign the given TokenParams, returning a signed
136    /// TokenRequest which can be exchanged for a token.
137    ///
138    /// # Example
139    ///
140    /// ```
141    /// # async fn run() -> ably::Result<()> {
142    /// use std::convert::TryFrom;
143    /// use ably::auth;
144    ///
145    /// let key = auth::Key::try_from("ABC123.DEF456:XXXXXXXXXXXX").unwrap();
146    ///
147    /// let mut params = auth::TokenParams::default();
148    /// params.client_id = Some("test@example.com".to_string());
149    ///
150    /// let req = key.sign(&params).unwrap();
151    /// # Ok(())
152    /// # }
153    /// ```
154    pub fn sign(&self, params: &TokenParams) -> Result<TokenRequest> {
155        params.sign(self)
156    }
157}
158
159/// Provides functions relating to Ably API authentication.
160#[derive(Clone, Debug)]
161pub struct Auth<'a> {
162    pub(crate) rest: &'a rest::Rest,
163}
164
165impl<'a> Auth<'a> {
166    pub fn new(rest: &'a rest::Rest) -> Self {
167        Self { rest }
168    }
169
170    fn inner(&self) -> &RestInner {
171        &self.rest.inner
172    }
173
174    /// Start building a TokenRequest to be signed by a local API key.
175    pub fn create_token_request(
176        &self,
177        params: &TokenParams,
178        options: &AuthOptions,
179    ) -> Result<TokenRequest> {
180        let key = match &options.token {
181            Some(Credential::Key(k)) => k,
182            _ => {
183                return Err(Error::new(
184                    ErrorCode::UnableToObtainCredentialsFromGivenParameters,
185                    "API key is required to create signed token requests",
186                ))
187            }
188        };
189        params.sign(key)
190    }
191
192    /// Exchange a TokenRequest for a token by making a HTTP request to the
193    /// [requestToken endpoint] in the Ably REST API.
194    ///
195    /// Returns a boxed future rather than using async since this is both
196    /// called from and calls out to RequestBuilder.send, and recursive
197    /// async functions are not supported.
198    ///
199    /// [requestToken endpoint]: https://docs.ably.io/rest-api/#request-token
200    pub(crate) fn exchange(
201        &self,
202        req: &TokenRequest,
203    ) -> Pin<Box<dyn Future<Output = Result<TokenDetails>> + Send + 'a>> {
204        let req = self
205            .rest
206            .request(
207                http::Method::POST,
208                &format!("/keys/{}/requestToken", req.key_name),
209            )
210            .authenticate(false)
211            .body(req);
212
213        Box::pin(async move { req.send().await?.body().await.map_err(Into::into) })
214    }
215
216    /// Request a token from the URL.
217    fn request_url<'b>(
218        &'b self,
219        url: &'b reqwest::Url,
220    ) -> Pin<Box<dyn Future<Output = Result<TokenDetails>> + Send + 'b>> {
221        let fut = async move {
222            let res = self
223                .rest
224                .request_url(Default::default(), url.clone())
225                .authenticate(false)
226                .send()
227                .await?;
228
229            // Parse the token response based on the Content-Type header.
230            let content_type = res.content_type().ok_or_else(|| {
231                Error::new(
232                    ErrorCode::ErrorFromClientTokenCallback,
233                    "authUrl response is missing a content-type header",
234                )
235            })?;
236            match content_type.essence_str() {
237            "application/json" => {
238                // Expect a JSON encoded TokenRequest or TokenDetails, and just
239                // let serde figure out which Token variant to decode the JSON
240                // response into.
241                let token: RequestOrDetails = res.json().await?;
242                match token {
243                    RequestOrDetails::Request(r) => self.exchange(&r).await,
244                    RequestOrDetails::Details(d) => Ok(d),
245                }
246            },
247
248            "text/plain" | "application/jwt" => {
249                // Expect a literal token string.
250                let token = res.text().await?;
251                Ok(TokenDetails::from(token))
252            },
253
254            // Anything else is an error.
255            _ => Err(Error::new(ErrorCode::ErrorFromClientTokenCallback, format!("authUrl responded with unacceptable content-type {}, should be either text/plain, application/jwt or application/json", content_type))),
256        }
257        };
258
259        Box::pin(fut)
260    }
261
262    pub async fn request_token(
263        &self,
264        params: &TokenParams,
265        options: &AuthOptions,
266    ) -> Result<TokenDetails> {
267        let token = options.token.as_ref().ok_or_else(|| {
268            Error::new(
269                ErrorCode::NoWayToRenewAuthToken,
270                "no means provided to renew auth token",
271            )
272        })?;
273
274        let mut details = match token {
275            Credential::TokenDetails(token) => Ok(token.clone()),
276            Credential::TokenRequest(r) => self.exchange(r).await,
277            Credential::Callback(f) => match f.token(params).await {
278                Ok(token) => token.into_details(self).await,
279                Err(e) => Err(e),
280            },
281            Credential::Key(k) => self.exchange(&params.sign(k)?).await,
282            Credential::Url(url) => self.request_url(url).await,
283        };
284
285        if matches!(token, Credential::Callback(_) | Credential::Url(_)) {
286            if let Err(ref mut err) = details {
287                // Normalise auth error according to RSA4e.
288                if err.code == ErrorCode::BadRequest {
289                    err.code = ErrorCode::ErrorFromClientTokenCallback;
290                    err.status_code = Some(401);
291                }
292            };
293        }
294
295        let details = details?;
296
297        // Reject tokens with size greater than 128KiB (RSA4f).
298        if details.token.len() > MAX_TOKEN_LENGTH {
299            return Err(Error::with_status(
300                ErrorCode::ErrorFromClientTokenCallback,
301                401,
302                format!(
303                    "Token string exceeded max permitted length (was {} bytes)",
304                    details.token.len()
305                ),
306            ));
307        }
308
309        Ok(details)
310    }
311
312    /// Set the Authorization header in the given request.
313    pub(crate) async fn with_auth_headers(&self, req: &mut reqwest::Request) -> Result<()> {
314        if let Credential::Key(k) = &self.inner().opts.credential {
315            return Self::set_basic_auth(req, k);
316        }
317
318        let options = AuthOptions {
319            token: Some(self.inner().opts.credential.clone()),
320            ..Default::default()
321        };
322
323        // TODO defaults
324        let res = self.request_token(&Default::default(), &options).await?;
325        Self::set_bearer_auth(req, &res.token)
326    }
327
328    fn set_bearer_auth(req: &mut reqwest::Request, token: &str) -> Result<()> {
329        Self::set_header(
330            req,
331            reqwest::header::AUTHORIZATION,
332            format!("Bearer {}", token),
333        )
334    }
335
336    fn set_basic_auth(req: &mut reqwest::Request, key: &Key) -> Result<()> {
337        let encoded = base64::encode(format!("{}:{}", key.name, key.value));
338        Self::set_header(
339            req,
340            reqwest::header::AUTHORIZATION,
341            format!("Basic {}", encoded),
342        )
343    }
344
345    fn set_header(req: &mut reqwest::Request, key: http::HeaderName, value: String) -> Result<()> {
346        req.headers_mut().append(key, value.parse()?);
347        Ok(())
348    }
349
350    /// Generate a random 16 character nonce to use in a TokenRequest.
351    fn generate_nonce() -> String {
352        thread_rng()
353            .sample_iter(&Alphanumeric)
354            .take(16)
355            .map(char::from)
356            .collect()
357    }
358
359    /// Use the given API key to compute the HMAC of the canonicalised
360    /// representation of the given TokenRequest.
361    ///
362    /// See the [REST API Token Request Spec] for further details.
363    ///
364    /// [REST API Token Request Spec]: https://docs.ably.io/rest-api/token-request-spec/
365    fn compute_mac(
366        key: &Key,
367        ttl: Duration,
368        capability: &str,
369        client_id: Option<&str>,
370        timestamp: DateTime<Utc>,
371        nonce: &str,
372    ) -> Result<String> {
373        let mut mac = Hmac::<Sha256>::new_from_slice(key.value.as_bytes())?;
374
375        mac.update(key.name.as_bytes());
376        mac.update(b"\n");
377
378        mac.update(ttl.num_milliseconds().to_string().as_bytes());
379        mac.update(b"\n");
380
381        mac.update(capability.as_bytes());
382        mac.update(b"\n");
383
384        mac.update(client_id.map(|c| c.as_bytes()).unwrap_or_default());
385        mac.update(b"\n");
386
387        mac.update(timestamp.timestamp_millis().to_string().as_bytes());
388        mac.update(b"\n");
389
390        mac.update(nonce.as_bytes());
391        mac.update(b"\n");
392
393        Ok(base64::encode(mac.finalize().into_bytes()))
394    }
395}
396
397/// An Ably [TokenParams] object.
398///
399/// [TokenParams]: https://docs.ably.io/realtime/types/#token-params
400#[derive(Clone, Debug)]
401pub struct TokenParams {
402    pub capability: String,
403    pub client_id: Option<String>,
404    pub nonce: Option<String>,
405    pub timestamp: Option<DateTime<Utc>>,
406    pub ttl: Duration,
407}
408
409impl Default for TokenParams {
410    fn default() -> Self {
411        Self {
412            capability: "{\"*\":[\"*\"]}".to_string(),
413            client_id: Default::default(),
414            nonce: Default::default(),
415            timestamp: Default::default(),
416            ttl: Duration::minutes(60),
417        }
418    }
419}
420
421impl TokenParams {
422    pub fn new() -> Self {
423        Default::default()
424    }
425
426    /// Set the desired capability.
427    pub fn capability(mut self, capability: &str) -> Self {
428        self.capability = capability.to_string();
429        self
430    }
431
432    /// Set the desired client_id.
433    pub fn client_id(mut self, client_id: &str) -> Self {
434        self.client_id = Some(client_id.to_string());
435        self
436    }
437
438    /// Set the desired TTL.
439    pub fn ttl(mut self, ttl: Duration) -> Self {
440        self.ttl = ttl;
441        self
442    }
443
444    /// Set the timestamp.
445    pub fn timestamp(mut self, timestamp: DateTime<Utc>) -> Self {
446        self.timestamp = Some(timestamp);
447        self
448    }
449
450    /// Generate a signed TokenRequest for these TokenParams using the steps
451    /// described in the [REST API Token Request Spec].
452    ///
453    /// [REST API Token Request Spec]: https://ably.com/documentation/rest-api/token-request-spec
454    fn sign(&self, key: &Key) -> Result<TokenRequest> {
455        // if client_id is set, it must be a non-empty string
456        if let Some(ref client_id) = self.client_id {
457            if client_id.is_empty() {
458                return Err(Error::new(
459                    ErrorCode::InvalidClientID,
460                    "client_id can’t be an empty string",
461                ));
462            }
463        }
464
465        let nonce = self.nonce.clone().unwrap_or_else(Auth::generate_nonce);
466        let timestamp = self.timestamp.unwrap_or_else(Utc::now);
467        let key_name = key.name.clone();
468
469        let req = TokenRequest {
470            mac: Auth::compute_mac(
471                key,
472                self.ttl,
473                &self.capability,
474                self.client_id.as_deref(),
475                timestamp,
476                &nonce,
477            )?,
478            key_name,
479            timestamp,
480            capability: self.capability.clone(),
481            client_id: self.client_id.clone(),
482            nonce,
483            ttl: self.ttl,
484        };
485
486        Ok(req)
487    }
488}
489
490/// An Ably [TokenRequest] object.
491///
492/// [TokenRequest]: https://docs.ably.io/realtime/types/#token-request
493#[derive(Clone, Debug, Deserialize, Serialize)]
494#[serde(rename_all = "camelCase")]
495pub struct TokenRequest {
496    pub key_name: String,
497    #[serde(with = "chrono::serde::ts_milliseconds")]
498    pub timestamp: DateTime<Utc>,
499    pub capability: String,
500    #[serde(skip_serializing_if = "Option::is_none")]
501    pub client_id: Option<String>,
502    pub mac: String,
503    pub nonce: String,
504    #[serde(with = "duration")]
505    pub ttl: Duration,
506}
507
508/// The token details returned in a successful response from the [REST
509/// requestToken endpoint].
510///
511/// [REST requestToken endpoint]: https://docs.ably.io/rest-api/#request-token
512#[derive(Clone, Debug, Deserialize)]
513#[serde(rename_all = "camelCase")]
514pub struct TokenDetails {
515    pub token: String,
516    #[serde(flatten)]
517    pub metadata: Option<TokenMetadata>,
518}
519
520impl TokenDetails {
521    pub fn token(s: String) -> Self {
522        Self {
523            token: s,
524            metadata: None,
525        }
526    }
527}
528
529impl From<String> for TokenDetails {
530    fn from(token: String) -> Self {
531        TokenDetails {
532            token,
533            metadata: None,
534        }
535    }
536}
537
538#[derive(Clone, Debug, Deserialize)]
539#[serde(rename_all = "camelCase")]
540pub struct TokenMetadata {
541    #[serde(with = "chrono::serde::ts_milliseconds")]
542    pub expires: DateTime<Utc>,
543    #[serde(with = "chrono::serde::ts_milliseconds")]
544    pub issued: DateTime<Utc>,
545    pub capability: String,
546    #[serde(skip_serializing_if = "Option::is_none")]
547    pub client_id: Option<String>,
548}
549
550#[derive(Clone, Debug, Deserialize)]
551#[serde(untagged)]
552pub enum RequestOrDetails {
553    Request(TokenRequest),
554    Details(TokenDetails),
555}
556
557impl RequestOrDetails {
558    async fn into_details(self, auth: &Auth<'_>) -> Result<TokenDetails> {
559        match self {
560            RequestOrDetails::Request(r) => auth.exchange(&r).await,
561            RequestOrDetails::Details(d) => Ok(d),
562        }
563    }
564}
565
566pub trait AuthCallback: Send + Sync {
567    fn token<'a>(
568        &'a self,
569        params: &'a TokenParams,
570    ) -> Pin<Box<dyn Send + Future<Output = Result<RequestOrDetails>> + 'a>>;
571}