cognite/api/
authenticator.rs

1use crate::dto::utils::MaybeStringU64;
2use async_trait::async_trait;
3use futures_locks::RwLock;
4use reqwest::{
5    header::{HeaderMap, HeaderValue},
6    StatusCode,
7};
8use reqwest_middleware::ClientWithMiddleware;
9use serde::{Deserialize, Serialize};
10use std::time::{Duration, Instant};
11use std::{fmt::Display, sync::Arc};
12use thiserror::Error;
13
14/// Type of closure for a synchronous auth callback.
15type CustomAuthCallback =
16    dyn Fn(&mut HeaderMap, &ClientWithMiddleware) -> Result<(), AuthenticatorError> + Send + Sync;
17
18#[async_trait]
19/// Trait for a custom authenticator. This should set the necessary headers in `headers` before each
20/// request. Note that this may be called from multiple places in parallel.
21pub trait CustomAuthenticator {
22    /// Set the required headers for authentication. This may use the provided
23    /// `client` to perform a request, if necessary. This will be called frequently, so
24    /// make sure it only makes external requests when needed.
25    ///
26    /// # Arguments
27    ///
28    /// * `headers` - Header map to modify.
29    /// * `client` - Client used to perform any external authentication requests.
30    async fn set_headers(
31        &self,
32        headers: &mut HeaderMap,
33        client: &ClientWithMiddleware,
34    ) -> Result<(), AuthenticatorError>;
35}
36
37/// Enumeration of the possible authentication methods available.
38#[derive(Clone)]
39pub enum AuthHeaderManager {
40    /// Authenticator that makes OIDC requests to obtain tokens.
41    OIDCToken(Arc<Authenticator>),
42    /// A fixed OIDC token
43    FixedToken(String),
44    /// An internal auth ticket.
45    AuthTicket(String),
46    /// A synchronous authentication method.
47    Custom(Arc<CustomAuthCallback>),
48    /// An async authentication method.
49    CustomAsync(Arc<dyn CustomAuthenticator + Send + Sync>),
50}
51
52impl AuthHeaderManager {
53    /// Set necessary headers in `headers`. This will sometimes request tokens from
54    /// the identity provider.
55    ///
56    /// # Arguments
57    ///
58    /// * `headers` - Request header collection.
59    /// * `client` - Reqwest client used to send authentication requests, if necessary.
60    pub async fn set_headers(
61        &self,
62        headers: &mut HeaderMap,
63        client: &ClientWithMiddleware,
64    ) -> Result<(), AuthenticatorError> {
65        match self {
66            AuthHeaderManager::OIDCToken(a) => {
67                let token = a.get_token(client).await?;
68                let auth_header_value =
69                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
70                        AuthenticatorError::internal_error(
71                            "Failed to set authorization bearer token".to_string(),
72                            Some(e.to_string()),
73                        )
74                    })?;
75                headers.insert("Authorization", auth_header_value);
76            }
77            AuthHeaderManager::FixedToken(token) => {
78                let auth_header_value =
79                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
80                        AuthenticatorError::internal_error(
81                            "Failed to set authorization bearer token".to_string(),
82                            Some(e.to_string()),
83                        )
84                    })?;
85                headers.insert("Authorization", auth_header_value);
86            }
87            AuthHeaderManager::AuthTicket(t) => {
88                let auth_ticket_header_value = HeaderValue::from_str(t).map_err(|e| {
89                    AuthenticatorError::internal_error(
90                        "Failed to set auth ticket".to_string(),
91                        Some(e.to_string()),
92                    )
93                })?;
94                headers.insert("auth-ticket", auth_ticket_header_value);
95            }
96            AuthHeaderManager::Custom(c) => c(headers, client)?,
97            AuthHeaderManager::CustomAsync(c) => c.set_headers(headers, client).await?,
98        }
99        Ok(())
100    }
101}
102
/// Configuration for authentication using the OIDC authenticator.
///
/// Tokens are requested from `token_url` with the OAuth 2.0 client-credentials grant.
pub struct AuthenticatorConfig {
    /// Service principal client ID.
    pub client_id: String,
    /// IdP token URL.
    pub token_url: String,
    /// Service principal client secret.
    pub secret: String,
    /// Optional resource, sent as the `resource` parameter of the token request.
    pub resource: Option<String>,
    /// Optional audience, sent as the `audience` parameter of the token request.
    pub audience: Option<String>,
    /// Optional space-separated list of scopes, sent as the `scope` parameter of the
    /// token request.
    pub scopes: Option<String>,
    /// Optional default token lifetime, in seconds.
    /// If this is set, the authenticator will fall back on it when the identity
    /// provider returns a token response without `expires_in`.
    /// If this is not set and `expires_in` is missing, requesting a token fails with an
    /// error.
    pub default_expires_in: Option<u64>,
}
123
124#[derive(Serialize, Deserialize, Debug)]
125struct AuthenticatorRequest {
126    client_id: String,
127    client_secret: String,
128    resource: Option<String>,
129    audience: Option<String>,
130    scope: Option<String>,
131    grant_type: String,
132}
133
134impl AuthenticatorRequest {
135    fn new(config: AuthenticatorConfig) -> AuthenticatorRequest {
136        AuthenticatorRequest {
137            client_id: config.client_id,
138            client_secret: config.secret,
139            grant_type: "client_credentials".to_string(),
140            resource: config.resource,
141            audience: config.audience,
142            scope: config.scopes,
143        }
144    }
145}
146
147#[derive(Serialize, Deserialize, Debug)]
148struct AuthenticatorResponse {
149    access_token: String,
150    expires_in: Option<MaybeStringU64>,
151}
152
153#[derive(Serialize, Deserialize, Debug, Error)]
154/// Error from an authenticator request.
155pub struct AuthenticatorError {
156    /// Error message
157    pub error: String,
158    /// Detailed error description.
159    pub error_description: Option<String>,
160    /// Error URI.
161    pub error_uri: Option<String>,
162}
163
164impl AuthenticatorError {
165    /// Create an authenticator error from message and description.
166    ///
167    /// # Arguments
168    ///
169    /// * `error` - Short error message
170    /// * `error_description` - Detailed error description.
171    pub fn internal_error(error: String, error_description: Option<String>) -> Self {
172        Self {
173            error,
174            error_description,
175            error_uri: None,
176        }
177    }
178}
179
180impl Display for AuthenticatorError {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        write!(f, "{}", self.error,)?;
183        if let Some(error_description) = &self.error_description {
184            write!(f, ": {error_description}")?;
185        }
186        if let Some(error_uri) = &self.error_uri {
187            write!(f, " ({error_uri})")?;
188        }
189        Ok(())
190    }
191}
192
/// Token cache kept behind the authenticator's lock.
struct AuthenticatorState {
    /// Most recently fetched access token, or `None` before the first fetch.
    last_token: Option<String>,
    /// Instant after which `last_token` is no longer reused.
    current_token_expiry: Instant,
}
197
/// Result from getting a token, including the time it is due for renewal.
pub struct AuthenticatorResult {
    /// The access token string.
    token: String,
    /// The instant after which the authenticator stops reusing this token and requests a
    /// new one. When the identity provider reports `expires_in`, this is 60 seconds
    /// before the token actually expires, not the expiry itself.
    expiry: Instant,
}
205
206/// Simple OIDC authenticator.
207pub struct Authenticator {
208    req: AuthenticatorRequest,
209    state: RwLock<AuthenticatorState>,
210    token_url: String,
211    default_expires_in: Option<Duration>,
212}
213
214impl AuthenticatorResult {
215    /// Get the token string.
216    pub fn token(&self) -> &str {
217        &self.token
218    }
219
220    /// Consume self and get the token string.
221    pub fn into_token(self) -> String {
222        self.token
223    }
224
225    /// Get the expiry time.
226    pub fn expiry(&self) -> Instant {
227        self.expiry
228    }
229}
230
231impl Authenticator {
232    /// Create a new authenticator with given config.
233    ///
234    /// # Arguments
235    ///
236    /// * `config` - Authenticator configuration.
237    pub fn new(config: AuthenticatorConfig) -> Authenticator {
238        Authenticator {
239            token_url: config.token_url.clone(),
240            default_expires_in: config.default_expires_in.map(Duration::from_secs),
241            req: AuthenticatorRequest::new(config),
242            state: RwLock::new(AuthenticatorState {
243                last_token: None,
244                current_token_expiry: Instant::now(),
245            }),
246        }
247    }
248
249    async fn request_token(
250        &self,
251        client: &ClientWithMiddleware,
252    ) -> Result<AuthenticatorResult, AuthenticatorError> {
253        let response = client
254            .post(&self.token_url)
255            .form(&self.req)
256            .send()
257            .await
258            .map_err(|e| {
259                AuthenticatorError::internal_error(
260                    "Something went wrong when sending the request".to_string(),
261                    Some(e.to_string()),
262                )
263            })?;
264
265        let status = response.status();
266
267        let start = Instant::now();
268
269        let response = response.text().await.map_err(|e| {
270            AuthenticatorError::internal_error(
271                "Failed to receive response contents".to_owned(),
272                Some(e.to_string()),
273            )
274        })?;
275
276        if status != StatusCode::OK {
277            return match serde_json::from_str(&response) {
278                Ok(e) => Err(e),
279                Err(e) => Err(AuthenticatorError::internal_error(
280                    format!("Something went wrong (status: {status}), but the response error couldn't be deserialized. Raw response: {response}")
281                    , Some(e.to_string())))
282            };
283        }
284
285        let response: AuthenticatorResponse = serde_json::from_str(&response).map_err(|e| {
286            AuthenticatorError::internal_error(
287                "Failed to deserialize response from OAuth endpoint".to_string(),
288                Some(e.to_string()),
289            )
290        })?;
291
292        let token = response.access_token;
293        let Some(expires_in) = response
294            .expires_in
295            // Subtract 60 as a buffer. We do retry on 401s, but it's best to renew the
296            // token before it expires. If for whatever reason expires_in is less than 60,
297            // we will just always renew before sending a request. We won't (hopefully)
298            // get an infinite loop.
299            .map(|m| Duration::from_secs(m.0.saturating_sub(60)))
300            .or(self.default_expires_in)
301        else {
302            return Err(AuthenticatorError::internal_error(
303                "Missing expires_in in response, and no default expiration configured".to_owned(),
304                None,
305            ));
306        };
307
308        Ok(AuthenticatorResult {
309            token,
310            expiry: start + expires_in,
311        })
312    }
313
314    /// Get a token. This will only fetch a new token if it is about
315    /// to expire (will expire in the next 60 seconds). This also
316    /// returns when the next token will be requested. This is the time
317    /// when the authenticator will refresh the token, so the actual
318    /// expiry time minus 60 seconds.
319    ///
320    /// # Arguments
321    ///
322    /// * `client` - Reqwest client to use for requests to the IdP.
323    pub async fn get_token_with_expiry(
324        &self,
325        client: &ClientWithMiddleware,
326    ) -> Result<AuthenticatorResult, AuthenticatorError> {
327        let now = Instant::now();
328        {
329            let state = &*self.state.read().await;
330            if let Some(last) = &state.last_token {
331                if state.current_token_expiry > now {
332                    return Ok(AuthenticatorResult {
333                        token: last.clone(),
334                        expiry: state.current_token_expiry,
335                    });
336                }
337            }
338        }
339
340        // If the token is expired, release the read lock and try to acquire a write lock.
341        let mut write = self.state.write().await;
342
343        // Need to check here too, in case we were blocked in this write lock by another thread
344        // fetching the token.
345        if let Some(last) = &write.last_token {
346            if write.current_token_expiry > now {
347                return Ok(AuthenticatorResult {
348                    token: last.clone(),
349                    expiry: write.current_token_expiry,
350                });
351            }
352        }
353
354        match self.request_token(client).await {
355            Ok(response) => {
356                write.current_token_expiry = response.expiry;
357                write.last_token = Some(response.token.clone());
358                Ok(response)
359            }
360            Err(e) => Err(e),
361        }
362    }
363
364    /// Get a token. This will only fetch a new token if it is about
365    /// to expire (will expire in the next 60 seconds).
366    ///
367    /// # Arguments
368    ///
369    /// * `client` - Reqwest client to use for requests to the IdP.
370    pub async fn get_token(
371        &self,
372        client: &ClientWithMiddleware,
373    ) -> Result<String, AuthenticatorError> {
374        Ok(self.get_token_with_expiry(client).await?.token)
375    }
376}