Skip to main content

authkestra_core/
strategy.rs

1use crate::error::AuthError;
2use async_trait::async_trait;
3use http::request::Parts;
4use std::marker::PhantomData;
5
6/// Trait for an authentication strategy.
7///
8/// A strategy is responsible for extracting credentials from a request
9/// and validating them to produce an identity.
10#[async_trait]
11pub trait AuthenticationStrategy<I>: Send + Sync {
12    /// Attempt to authenticate the request.
13    ///
14    /// Returns:
15    /// - `Ok(Some(identity))` if authentication was successful.
16    /// - `Ok(None)` if the strategy did not find relevant credentials (e.g., missing header).
17    /// - `Err(AuthError)` if authentication failed (e.g., invalid token, DB error).
18    async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError>;
19}
20
21/// Trait for a provider that validates username and password (Basic Auth).
22#[async_trait]
23pub trait BasicAuthenticator: Send + Sync {
24    /// The type of identity returned by this authenticator.
25    type Identity;
26    /// Validate the credentials.
27    async fn authenticate(
28        &self,
29        username: &str,
30        password: &str,
31    ) -> Result<Option<Self::Identity>, AuthError>;
32}
33
34/// Strategy for Basic authentication.
35pub struct BasicStrategy<P, I> {
36    authenticator: P,
37    _marker: PhantomData<I>,
38}
39
40impl<P, I> BasicStrategy<P, I> {
41    /// Create a new BasicStrategy with the given authenticator.
42    pub fn new(authenticator: P) -> Self {
43        Self {
44            authenticator,
45            _marker: PhantomData,
46        }
47    }
48}
49
50#[async_trait]
51impl<P, I> AuthenticationStrategy<I> for BasicStrategy<P, I>
52where
53    P: BasicAuthenticator<Identity = I> + Send + Sync,
54    I: Send + Sync + 'static,
55{
56    async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
57        if let Some((username, password)) = utils::extract_basic_credentials(&parts.headers) {
58            self.authenticator.authenticate(&username, &password).await
59        } else {
60            Ok(None)
61        }
62    }
63}
64
65/// Trait for a validator that verifies a token.
66#[async_trait]
67pub trait TokenValidator: Send + Sync {
68    /// The type of identity returned by this validator.
69    type Identity;
70    /// Validate the token.
71    async fn validate(&self, token: &str) -> Result<Option<Self::Identity>, AuthError>;
72}
73
74/// Strategy for Token (Bearer) authentication.
75pub struct TokenStrategy<V, I> {
76    validator: V,
77    _marker: PhantomData<I>,
78}
79
80impl<V, I> TokenStrategy<V, I> {
81    /// Create a new TokenStrategy with the given validator.
82    pub fn new(validator: V) -> Self {
83        Self {
84            validator,
85            _marker: PhantomData,
86        }
87    }
88}
89
90#[async_trait]
91impl<V, I> AuthenticationStrategy<I> for TokenStrategy<V, I>
92where
93    V: TokenValidator<Identity = I> + Send + Sync,
94    I: Send + Sync + 'static,
95{
96    async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
97        if let Some(token) = utils::extract_bearer_token(&parts.headers) {
98            self.validator.validate(token).await
99        } else {
100            Ok(None)
101        }
102    }
103}
104
105/// Strategy for custom header authentication.
106pub struct HeaderStrategy<F, I> {
107    header_name: http::header::HeaderName,
108    validator: F,
109    _marker: PhantomData<I>,
110}
111
112impl<F, I> HeaderStrategy<F, I> {
113    /// Create a new HeaderStrategy.
114    pub fn new(header_name: http::header::HeaderName, validator: F) -> Self {
115        Self {
116            header_name,
117            validator,
118            _marker: PhantomData,
119        }
120    }
121}
122
123#[async_trait]
124impl<F, I, Fut> AuthenticationStrategy<I> for HeaderStrategy<F, I>
125where
126    F: Fn(String) -> Fut + Send + Sync,
127    Fut: std::future::Future<Output = Result<Option<I>, AuthError>> + Send,
128    I: Send + Sync + 'static,
129{
130    async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
131        if let Some(value) = parts.headers.get(&self.header_name) {
132            if let Ok(value_str) = value.to_str() {
133                return (self.validator)(value_str.to_string()).await;
134            }
135        }
136        Ok(None)
137    }
138}
139
140/// Trait for a session store that can load an identity.
141#[async_trait]
142pub trait SessionProvider: Send + Sync {
143    /// The type of identity returned by this provider.
144    type Identity;
145    /// Load the identity associated with the session ID.
146    async fn load_session(&self, session_id: &str) -> Result<Option<Self::Identity>, AuthError>;
147}
148
149/// Strategy for Session authentication.
150pub struct SessionStrategy<P, I> {
151    provider: P,
152    cookie_name: String,
153    _marker: PhantomData<I>,
154}
155
156impl<P, I> SessionStrategy<P, I> {
157    /// Create a new SessionStrategy.
158    pub fn new(provider: P, cookie_name: impl Into<String>) -> Self {
159        Self {
160            provider,
161            cookie_name: cookie_name.into(),
162            _marker: PhantomData,
163        }
164    }
165}
166
167#[async_trait]
168impl<P, I> AuthenticationStrategy<I> for SessionStrategy<P, I>
169where
170    P: SessionProvider<Identity = I> + Send + Sync,
171    I: Send + Sync + 'static,
172{
173    async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
174        if let Some(session_id) = utils::extract_cookie(&parts.headers, &self.cookie_name) {
175            self.provider.load_session(session_id).await
176        } else {
177            Ok(None)
178        }
179    }
180}
181
182/// Utility functions for common authentication tasks.
183pub mod utils {
184    use http::header::{HeaderMap, AUTHORIZATION};
185
186    /// Extract the Bearer token from the Authorization header.
187    pub fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
188        headers
189            .get(AUTHORIZATION)?
190            .to_str()
191            .ok()?
192            .strip_prefix("Bearer ")
193            .map(|s| s.trim())
194    }
195
196    /// Extract Basic credentials from the Authorization header.
197    pub fn extract_basic_credentials(headers: &HeaderMap) -> Option<(String, String)> {
198        let auth_header = headers.get(AUTHORIZATION)?.to_str().ok()?;
199        if !auth_header.starts_with("Basic ") {
200            return None;
201        }
202        let encoded = auth_header.strip_prefix("Basic ")?.trim();
203        let decoded =
204            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded).ok()?;
205        let decoded_str = String::from_utf8(decoded).ok()?;
206        let mut parts = decoded_str.splitn(2, ':');
207        let username = parts.next()?.to_string();
208        let password = parts.next()?.to_string();
209        Some((username, password))
210    }
211
212    /// Extract a cookie value by name.
213    pub fn extract_cookie<'a>(headers: &'a http::HeaderMap, name: &str) -> Option<&'a str> {
214        let cookie_header = headers.get(http::header::COOKIE)?.to_str().ok()?;
215        for cookie in cookie_header.split(';') {
216            let mut parts = cookie.splitn(2, '=');
217            let k = parts.next()?.trim();
218            let v = parts.next()?.trim();
219            if k == name {
220                return Some(v);
221            }
222        }
223        None
224    }
225}