1#![warn(missing_docs)]
15
16use async_trait::async_trait;
17use authkestra_core::{
18 AuthError, CredentialsProvider, Identity, OAuthProvider, OAuthToken, SessionConfig,
19 SessionStore, UserMapper,
20};
21use authkestra_token::TokenManager;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25pub mod client_credentials_flow;
27pub mod device_flow;
29
30pub use client_credentials_flow::ClientCredentialsFlow;
31pub use device_flow::{DeviceAuthorizationResponse, DeviceFlow};
32
33#[async_trait]
35pub trait ErasedOAuthFlow: Send + Sync {
36 fn provider_id(&self) -> String;
38 fn initiate_login(&self, scopes: &[&str], pkce_challenge: Option<&str>) -> (String, String);
40 async fn finalize_login(
42 &self,
43 code: &str,
44 received_state: &str,
45 expected_state: &str,
46 pkce_verifier: Option<&str>,
47 ) -> Result<(authkestra_core::Identity, authkestra_core::OAuthToken), authkestra_core::AuthError>;
48}
49
50pub struct OAuth2Flow<P: OAuthProvider, M: UserMapper = ()> {
52 provider: P,
53 mapper: Option<M>,
54}
55
56#[async_trait]
57impl<P: OAuthProvider, M: UserMapper> ErasedOAuthFlow for OAuth2Flow<P, M> {
58 fn provider_id(&self) -> String {
59 self.provider.provider_id().to_string()
60 }
61
62 fn initiate_login(&self, scopes: &[&str], pkce_challenge: Option<&str>) -> (String, String) {
63 self.initiate_login(scopes, pkce_challenge)
64 }
65
66 async fn finalize_login(
67 &self,
68 code: &str,
69 received_state: &str,
70 expected_state: &str,
71 pkce_verifier: Option<&str>,
72 ) -> Result<(authkestra_core::Identity, authkestra_core::OAuthToken), authkestra_core::AuthError>
73 {
74 let (identity, token, _) = self
75 .finalize_login(code, received_state, expected_state, pkce_verifier)
76 .await?;
77 Ok((identity, token))
78 }
79}
80
81impl<P: OAuthProvider> OAuth2Flow<P, ()> {
82 pub fn new(provider: P) -> Self {
84 Self {
85 provider,
86 mapper: None,
87 }
88 }
89}
90
91impl<P: OAuthProvider, M: UserMapper> OAuth2Flow<P, M> {
92 pub fn with_mapper(provider: P, mapper: M) -> Self {
94 Self {
95 provider,
96 mapper: Some(mapper),
97 }
98 }
99
100 pub fn initiate_login(
102 &self,
103 scopes: &[&str],
104 pkce_challenge: Option<&str>,
105 ) -> (String, String) {
106 let state = uuid::Uuid::new_v4().to_string();
107 let url = self
108 .provider
109 .get_authorization_url(&state, scopes, pkce_challenge);
110 (url, state)
111 }
112
113 pub async fn finalize_login(
116 &self,
117 code: &str,
118 received_state: &str,
119 expected_state: &str,
120 pkce_verifier: Option<&str>,
121 ) -> Result<(Identity, OAuthToken, Option<M::LocalUser>), AuthError> {
122 if received_state != expected_state {
123 return Err(AuthError::CsrfMismatch);
124 }
125 let (identity, token) = self
126 .provider
127 .exchange_code_for_identity(code, pkce_verifier)
128 .await?;
129
130 let local_user = if let Some(mapper) = &self.mapper {
131 Some(mapper.map_user(&identity).await?)
132 } else {
133 None
134 };
135
136 Ok((identity, token, local_user))
137 }
138
139 pub async fn refresh_access_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
141 self.provider.refresh_token(refresh_token).await
142 }
143
144 pub async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
146 self.provider.revoke_token(token).await
147 }
148}
149
150#[derive(Clone)]
152pub struct Authkestra {
153 pub providers: HashMap<String, Arc<dyn ErasedOAuthFlow>>,
155 pub session_store: Arc<dyn SessionStore>,
157 pub session_config: SessionConfig,
159 pub token_manager: Arc<TokenManager>,
161}
162
163impl Authkestra {
164 pub fn builder() -> AuthkestraBuilder {
166 AuthkestraBuilder::default()
167 }
168}
169
170#[derive(Default)]
172pub struct AuthkestraBuilder {
173 providers: HashMap<String, Arc<dyn ErasedOAuthFlow>>,
174 session_store: Option<Arc<dyn SessionStore>>,
175 session_config: SessionConfig,
176 token_manager: Option<Arc<TokenManager>>,
177}
178
179impl AuthkestraBuilder {
180 pub fn provider<P, M>(mut self, flow: OAuth2Flow<P, M>) -> Self
182 where
183 P: OAuthProvider + 'static,
184 M: UserMapper + 'static,
185 {
186 let id = flow.provider_id();
187 self.providers.insert(id, Arc::new(flow));
188 self
189 }
190
191 pub fn session_store(mut self, store: Arc<dyn SessionStore>) -> Self {
193 self.session_store = Some(store);
194 self
195 }
196
197 pub fn session_config(mut self, config: SessionConfig) -> Self {
199 self.session_config = config;
200 self
201 }
202
203 pub fn token_manager(mut self, manager: Arc<TokenManager>) -> Self {
205 self.token_manager = Some(manager);
206 self
207 }
208
209 pub fn build(self) -> Authkestra {
211 Authkestra {
212 providers: self.providers,
213 session_store: self
214 .session_store
215 .unwrap_or_else(|| Arc::new(authkestra_core::MemoryStore::default())), session_config: self.session_config,
217 token_manager: self
218 .token_manager
219 .unwrap_or_else(|| Arc::new(TokenManager::new(b"secret", None))),
220 }
221 }
222}
223
224pub struct CredentialsFlow<P: CredentialsProvider, M: UserMapper = ()> {
226 provider: P,
227 mapper: Option<M>,
228}
229
230impl<P: CredentialsProvider> CredentialsFlow<P, ()> {
231 pub fn new(provider: P) -> Self {
233 Self {
234 provider,
235 mapper: None,
236 }
237 }
238}
239
240impl<P: CredentialsProvider, M: UserMapper> CredentialsFlow<P, M> {
241 pub fn with_mapper(provider: P, mapper: M) -> Self {
243 Self {
244 provider,
245 mapper: Some(mapper),
246 }
247 }
248
249 pub async fn authenticate(
251 &self,
252 creds: P::Credentials,
253 ) -> Result<(Identity, Option<M::LocalUser>), AuthError> {
254 let identity = self.provider.authenticate(creds).await?;
255
256 let local_user = if let Some(mapper) = &self.mapper {
257 Some(mapper.map_user(&identity).await?)
258 } else {
259 None
260 };
261
262 Ok((identity, local_user))
263 }
264}