authkestra_core/
strategy.rs1use crate::error::AuthError;
2use async_trait::async_trait;
3use http::request::Parts;
4use std::marker::PhantomData;
5
6#[async_trait]
11pub trait AuthenticationStrategy<I>: Send + Sync {
12 async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError>;
19}
20
21#[async_trait]
23pub trait BasicAuthenticator: Send + Sync {
24 type Identity;
26 async fn authenticate(
28 &self,
29 username: &str,
30 password: &str,
31 ) -> Result<Option<Self::Identity>, AuthError>;
32}
33
34pub struct BasicStrategy<P, I> {
36 authenticator: P,
37 _marker: PhantomData<I>,
38}
39
40impl<P, I> BasicStrategy<P, I> {
41 pub fn new(authenticator: P) -> Self {
43 Self {
44 authenticator,
45 _marker: PhantomData,
46 }
47 }
48}
49
50#[async_trait]
51impl<P, I> AuthenticationStrategy<I> for BasicStrategy<P, I>
52where
53 P: BasicAuthenticator<Identity = I> + Send + Sync,
54 I: Send + Sync + 'static,
55{
56 async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
57 if let Some((username, password)) = utils::extract_basic_credentials(&parts.headers) {
58 self.authenticator.authenticate(&username, &password).await
59 } else {
60 Ok(None)
61 }
62 }
63}
64
65#[async_trait]
67pub trait TokenValidator: Send + Sync {
68 type Identity;
70 async fn validate(&self, token: &str) -> Result<Option<Self::Identity>, AuthError>;
72}
73
74pub struct TokenStrategy<V, I> {
76 validator: V,
77 _marker: PhantomData<I>,
78}
79
80impl<V, I> TokenStrategy<V, I> {
81 pub fn new(validator: V) -> Self {
83 Self {
84 validator,
85 _marker: PhantomData,
86 }
87 }
88}
89
90#[async_trait]
91impl<V, I> AuthenticationStrategy<I> for TokenStrategy<V, I>
92where
93 V: TokenValidator<Identity = I> + Send + Sync,
94 I: Send + Sync + 'static,
95{
96 async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
97 if let Some(token) = utils::extract_bearer_token(&parts.headers) {
98 self.validator.validate(token).await
99 } else {
100 Ok(None)
101 }
102 }
103}
104
105pub struct HeaderStrategy<F, I> {
107 header_name: http::header::HeaderName,
108 validator: F,
109 _marker: PhantomData<I>,
110}
111
112impl<F, I> HeaderStrategy<F, I> {
113 pub fn new(header_name: http::header::HeaderName, validator: F) -> Self {
115 Self {
116 header_name,
117 validator,
118 _marker: PhantomData,
119 }
120 }
121}
122
123#[async_trait]
124impl<F, I, Fut> AuthenticationStrategy<I> for HeaderStrategy<F, I>
125where
126 F: Fn(String) -> Fut + Send + Sync,
127 Fut: std::future::Future<Output = Result<Option<I>, AuthError>> + Send,
128 I: Send + Sync + 'static,
129{
130 async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
131 if let Some(value) = parts.headers.get(&self.header_name) {
132 if let Ok(value_str) = value.to_str() {
133 return (self.validator)(value_str.to_string()).await;
134 }
135 }
136 Ok(None)
137 }
138}
139
140#[async_trait]
142pub trait SessionProvider: Send + Sync {
143 type Identity;
145 async fn load_session(&self, session_id: &str) -> Result<Option<Self::Identity>, AuthError>;
147}
148
149pub struct SessionStrategy<P, I> {
151 provider: P,
152 cookie_name: String,
153 _marker: PhantomData<I>,
154}
155
156impl<P, I> SessionStrategy<P, I> {
157 pub fn new(provider: P, cookie_name: impl Into<String>) -> Self {
159 Self {
160 provider,
161 cookie_name: cookie_name.into(),
162 _marker: PhantomData,
163 }
164 }
165}
166
167#[async_trait]
168impl<P, I> AuthenticationStrategy<I> for SessionStrategy<P, I>
169where
170 P: SessionProvider<Identity = I> + Send + Sync,
171 I: Send + Sync + 'static,
172{
173 async fn authenticate(&self, parts: &Parts) -> Result<Option<I>, AuthError> {
174 if let Some(session_id) = utils::extract_cookie(&parts.headers, &self.cookie_name) {
175 self.provider.load_session(session_id).await
176 } else {
177 Ok(None)
178 }
179 }
180}
181
182pub mod utils {
184 use http::header::{HeaderMap, AUTHORIZATION};
185
186 pub fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
188 headers
189 .get(AUTHORIZATION)?
190 .to_str()
191 .ok()?
192 .strip_prefix("Bearer ")
193 .map(|s| s.trim())
194 }
195
196 pub fn extract_basic_credentials(headers: &HeaderMap) -> Option<(String, String)> {
198 let auth_header = headers.get(AUTHORIZATION)?.to_str().ok()?;
199 if !auth_header.starts_with("Basic ") {
200 return None;
201 }
202 let encoded = auth_header.strip_prefix("Basic ")?.trim();
203 let decoded =
204 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded).ok()?;
205 let decoded_str = String::from_utf8(decoded).ok()?;
206 let mut parts = decoded_str.splitn(2, ':');
207 let username = parts.next()?.to_string();
208 let password = parts.next()?.to_string();
209 Some((username, password))
210 }
211
212 pub fn extract_cookie<'a>(headers: &'a http::HeaderMap, name: &str) -> Option<&'a str> {
214 let cookie_header = headers.get(http::header::COOKIE)?.to_str().ok()?;
215 for cookie in cookie_header.split(';') {
216 let mut parts = cookie.splitn(2, '=');
217 let k = parts.next()?.trim();
218 let v = parts.next()?.trim();
219 if k == name {
220 return Some(v);
221 }
222 }
223 None
224 }
225}