axum_supabase_auth/auth/service.rs

1use crate::api::{ApiError, SignUpResponse};
2use crate::auth::api::Api;
3use crate::auth::ClientError;
4use crate::{Auth, EmailOrPhone, OAuthRequest, OAuthResponse, Session, SessionAuth, User};
5use axum::http::StatusCode;
6use base64::prelude::{Engine as _, BASE64_STANDARD};
7use oauth2::{PkceCodeChallenge, PkceCodeVerifier};
8use reqwest::Url;
9use std::sync::Arc;
10use std::time::Duration;
11use tracing::error;
12
13#[derive(Clone)]
14pub struct AuthService {
15    api: Arc<Api>,
16}
17
18impl AuthService {
19    pub fn new(url: Url, api_key: &str) -> Self {
20        Self::new_with_timeout(url, api_key, Duration::from_secs(2))
21    }
22
23    pub fn new_with_timeout(url: Url, api_key: &str, timeout: Duration) -> Self {
24        Self {
25            api: Arc::new(Api::new(url, timeout, api_key)),
26        }
27    }
28}
29
30impl Auth for AuthService {
31    async fn sign_up(
32        &self,
33        email_or_phone: EmailOrPhone,
34        password: impl AsRef<str>,
35    ) -> Result<SignUpResponse, ClientError> {
36        match self.api.sign_up(email_or_phone, password).await {
37            Ok(session) => Ok(session),
38            Err(ApiError::HttpError(_, StatusCode::UNPROCESSABLE_ENTITY)) => {
39                Err(ClientError::AlreadySignedUp)
40            }
41            Err(e) => {
42                error!("Error signing up: {:?}", e);
43                Err(ClientError::InternalError)
44            }
45        }
46    }
47
48    async fn sign_in(
49        &self,
50        email_or_phone: EmailOrPhone,
51        password: impl AsRef<str>,
52    ) -> Result<Session, ClientError> {
53        match self.api.sign_in(email_or_phone, password).await {
54            Ok(session) => Ok(session),
55            Err(ApiError::HttpError(_, StatusCode::BAD_REQUEST)) => {
56                Err(ClientError::WrongCredentials)
57            }
58            Err(e) => {
59                error!("Error signing in: {:?}", e);
60                Err(ClientError::InternalError)
61            }
62        }
63    }
64
65    async fn exchange_code_for_session(
66        &self,
67        code: &str,
68        csrf_token_b64: &str,
69    ) -> Result<Session, ClientError> {
70        let csrf_token = BASE64_STANDARD
71            .decode(csrf_token_b64)
72            .map_err(|_| ClientError::WrongToken)?;
73        let verifier = PkceCodeVerifier::new(
74            String::from_utf8(csrf_token).map_err(|_| ClientError::WrongToken)?,
75        );
76
77        match self.api.exchange_code_for_session(code, &verifier).await {
78            Ok(session) => Ok(session),
79            Err(e) => {
80                error!("Error exchanging code for session: {:?}", e);
81                Err(ClientError::InternalError)
82            }
83        }
84    }
85
86    fn create_oauth_url(&self, req: OAuthRequest) -> Result<OAuthResponse, ClientError> {
87        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
88
89        let url = self.api.create_pkce_oauth_url(req, pkce_challenge);
90        let csrf_token = BASE64_STANDARD.encode(pkce_verifier.secret());
91
92        Ok(OAuthResponse {
93            supabase_url: url.to_string(),
94            csrf_token,
95        })
96    }
97
98    fn with_token(&self, access_token: String) -> impl SessionAuth {
99        SessionAuthService::with_token(self.clone(), access_token)
100    }
101
102    fn with_refresh_token(&self, access_token: String, refresh_token: String) -> impl SessionAuth {
103        SessionAuthService::with_refresh_token(self.clone(), access_token, refresh_token)
104    }
105}
106
107#[derive(Clone)]
108pub struct SessionAuthService {
109    auth: AuthService,
110    access_token: String,
111    refresh_token: Option<String>,
112}
113
114impl AsRef<AuthService> for SessionAuthService {
115    fn as_ref(&self) -> &AuthService {
116        &self.auth
117    }
118}
119
120impl SessionAuthService {
121    fn with_token(auth: AuthService, access_token: String) -> Self {
122        Self {
123            auth,
124            access_token,
125            refresh_token: None,
126        }
127    }
128
129    fn with_refresh_token(auth: AuthService, access_token: String, refresh_token: String) -> Self {
130        Self {
131            auth,
132            access_token,
133            refresh_token: Some(refresh_token),
134        }
135    }
136}
137
138impl SessionAuth for SessionAuthService {
139    async fn logout(&self) -> Result<(), ClientError> {
140        match self.auth.api.logout(&self.access_token).await {
141            Ok(_) => Ok(()),
142            Err(e) => {
143                error!("Error logging out: {:?}", e);
144                Err(ClientError::InternalError)
145            }
146        }
147    }
148
149    async fn list_users(&self) -> Result<Vec<User>, ClientError> {
150        match self.auth.api.list_users(&self.access_token).await {
151            Ok(users) => Ok(users.users),
152            Err(ApiError::HttpError(_, StatusCode::FORBIDDEN)) => {
153                Err(ClientError::WrongCredentials)
154            }
155            Err(e) => {
156                error!("Error listing users: {:?}", e);
157                Err(ClientError::InternalError)
158            }
159        }
160    }
161
162    async fn refresh(&mut self) -> Result<Session, ClientError> {
163        let refresh_token = match self.refresh_token {
164            Some(ref refresh_token) => refresh_token,
165            None => return Err(ClientError::MissingRefreshToken),
166        };
167
168        let session = match self.auth.api.refresh_access_token(refresh_token).await {
169            Ok(session) => session,
170            Err(e) => {
171                error!("Error refreshing token: {:?}", e);
172                return Err(ClientError::InternalError);
173            }
174        };
175
176        self.access_token = session.access_token.clone();
177        self.refresh_token = Some(session.refresh_token.clone());
178        Ok(session)
179    }
180}