cognite/
auth.rs

1use anyhow::Context;
2use async_trait::async_trait;
3use http::Extensions;
4use reqwest::{Request, Response};
5use reqwest_middleware::{ClientWithMiddleware, Middleware, Next, Result};
6
7use crate::AuthHeaderManager;
8
9/// Middleware for token authentication.
10///
11/// Note that in order to use this, you need to add `ClientWithMiddleware` as an extension
12/// to your requests.
13pub struct AuthenticatorMiddleware {
14    authenticator: AuthHeaderManager,
15}
16
/// Private marker stored in a request's extensions while `AuthenticatorMiddleware` is
/// calling its header manager.
///
/// While the marker is present, the middleware forwards requests without trying to
/// authenticate them again, so it never re-enters itself for the same extension set.
#[derive(Clone)]
struct AuthenticatorFlag;
19
/// Marker that opts a single request out of authentication.
///
/// When a value of this type is present in a request's `http::Extensions`,
/// `AuthenticatorMiddleware` forwards the request without setting any authentication
/// headers.
#[derive(Clone, Copy, Debug, Default)]
pub struct SkipAuthentication;
23
24impl AuthenticatorMiddleware {
25    /// Create a new authenticator middleware from an authenticator.
26    ///
27    /// # Arguments
28    ///
29    /// * `authenticator` - Header manager.
30    pub fn new(authenticator: AuthHeaderManager) -> crate::Result<Self> {
31        Ok(Self { authenticator })
32    }
33}
34
35#[async_trait]
36impl Middleware for AuthenticatorMiddleware {
37    async fn handle(
38        &self,
39        mut req: Request,
40        extensions: &mut Extensions,
41        next: Next<'_>,
42    ) -> Result<Response> {
43        // Since we are reusing the client, we want to avoid infinitely calling the authenticator recursively,
44        // so we add a flag indicating that we have already called the authenticator in this chain.
45        if extensions.get::<AuthenticatorFlag>().is_none()
46            && extensions.get::<SkipAuthentication>().is_none()
47        {
48            // Add the flag before we call the authenticator, this prevents the authenticator from
49            // attempting to add headers to its own request, which would deadlock.
50            extensions.insert(AuthenticatorFlag);
51            // This is all a little hacky, we add the client itself as an extension to the request
52            // so that we can use it from in here. The deadlocky-ness of this is exactly why it isn't
53            // possible by default.
54            // If it isn't in there we assume that it isn't supposed to be there, and skip the whole layer.
55            if let Some(client) = extensions.get::<ClientWithMiddleware>() {
56                self.authenticator
57                    .set_headers(req.headers_mut(), client)
58                    .await
59                    .map_err(|e| reqwest_middleware::Error::Middleware(e.into()))
60                    .context("Failed to authenticate request")?;
61            }
62            // Once we're done, remove the flag
63            extensions.remove::<AuthenticatorFlag>();
64        }
65        next.run(req, extensions).await
66    }
67}