Skip to main content

better_fetch/
auth.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use base64::Engine;
6use http::header::{HeaderValue, AUTHORIZATION};
7use http::HeaderMap;
8
9/// Authentication configuration for a client or request.
10#[derive(Clone)]
11pub enum Auth {
12    Bearer {
13        token: TokenSource,
14    },
15    Basic {
16        username: TokenSource,
17        password: TokenSource,
18    },
19    Custom {
20        prefix: String,
21        value: TokenSource,
22    },
23}
24
25/// Source for credential values (static, sync, or async).
26#[derive(Clone)]
27pub enum TokenSource {
28    Static(String),
29    Fn(Arc<dyn Fn() -> Option<String> + Send + Sync>),
30    AsyncFn(Arc<dyn AsyncTokenProvider>),
31}
32
/// Asynchronous source of a credential value.
///
/// Any `Send + Sync` closure returning a `Send + 'static` future of
/// `Option<String>` implements this trait automatically.
pub trait AsyncTokenProvider
where
    Self: Send + Sync,
{
    /// Begins resolving the credential; the future yields `None` when no
    /// credential is available.
    fn resolve(&self) -> Pin<Box<dyn Future<Output = Option<String>> + Send + '_>>;
}
37
38impl<F, Fut> AsyncTokenProvider for F
39where
40    F: Send + Sync,
41    F: Fn() -> Fut,
42    Fut: Future<Output = Option<String>> + Send + 'static,
43{
44    fn resolve(&self) -> Pin<Box<dyn Future<Output = Option<String>> + Send + '_>> {
45        Box::pin((self)())
46    }
47}
48
49impl Auth {
50    pub fn bearer(token: impl Into<String>) -> Self {
51        Self::Bearer {
52            token: TokenSource::Static(token.into()),
53        }
54    }
55
56    pub fn bearer_fn(f: impl Fn() -> Option<String> + Send + Sync + 'static) -> Self {
57        Self::Bearer {
58            token: TokenSource::Fn(Arc::new(f)),
59        }
60    }
61
62    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
63        Self::Basic {
64            username: TokenSource::Static(username.into()),
65            password: TokenSource::Static(password.into()),
66        }
67    }
68
69    pub async fn apply(&self, headers: &mut HeaderMap) -> crate::Result<()> {
70        match self {
71            Self::Bearer { token } => {
72                if let Some(value) = resolve_token(token).await? {
73                    set_authorization(headers, format!("Bearer {value}"))?;
74                }
75            }
76            Self::Basic { username, password } => {
77                let user = resolve_token(username).await?;
78                let pass = resolve_token(password).await?;
79                if let (Some(u), Some(p)) = (user, pass) {
80                    let encoded =
81                        base64::engine::general_purpose::STANDARD.encode(format!("{u}:{p}"));
82                    set_authorization(headers, format!("Basic {encoded}"))?;
83                }
84            }
85            Self::Custom { prefix, value } => {
86                if let Some(v) = resolve_token(value).await? {
87                    set_authorization(headers, format!("{prefix} {v}"))?;
88                }
89            }
90        }
91        Ok(())
92    }
93}
94
95async fn resolve_token(source: &TokenSource) -> crate::Result<Option<String>> {
96    match source {
97        TokenSource::Static(s) => Ok(Some(s.clone())),
98        TokenSource::Fn(f) => Ok(f()),
99        TokenSource::AsyncFn(f) => Ok(f.resolve().await),
100    }
101}
102
103fn set_authorization(headers: &mut HeaderMap, value: String) -> crate::Result<()> {
104    let header_value = HeaderValue::from_str(&value)
105        .map_err(|e| crate::error::Error::Other(format!("invalid authorization header: {e}")))?;
106    headers.insert(AUTHORIZATION, header_value);
107    Ok(())
108}