axum_supabase_auth/auth/api/
mod.rs

1mod types;
2
3use super::types::*;
4use crate::api::types::HealthCheckResponse;
5use bon::bon;
6use either::Either;
7use oauth2::{PkceCodeChallenge, PkceCodeVerifier};
8use reqwest::header::{HeaderMap, HeaderValue};
9use reqwest::{Client, Method, StatusCode, Url};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror::Error;
16use tracing::{error, instrument, warn};
17
18#[derive(Clone)]
19pub struct Api {
20    url: Url,
21    client: Client,
22    headers: Arc<HeaderMap>,
23}
24
25#[bon]
26impl Api {
27    #[builder(finish_fn=send)]
28    async fn send_request<T, B>(
29        &self,
30        #[builder(start_fn)] method: Method,
31        #[builder(start_fn)] endpoint: &str,
32        query: Option<&[(&str, &str)]>,
33        body: Option<&B>,
34        access_token: Option<&str>,
35    ) -> Result<T, ApiError>
36    where
37        T: DeserializeOwned,
38        B: Serialize + ?Sized,
39    {
40        let url = self.url.join(endpoint)?;
41
42        let mut request = self
43            .client
44            .request(method.clone(), url)
45            .headers((*self.headers).clone());
46
47        if let Some(q) = query {
48            if !q.is_empty() {
49                request = request.query(q);
50            }
51        }
52
53        if let Some(b) = body {
54            request = request.json(b);
55        }
56
57        if let Some(token) = access_token {
58            request = request.bearer_auth(token);
59        }
60
61        let response = request.send().await?;
62
63        match response.error_for_status_ref() {
64            Ok(_) => {}
65            Err(err) => {
66                return if let Some(status) = err.status() {
67                    let body: ApiErrorResponse = response.json().await.unwrap_or_default();
68
69                    warn!(%err,
70                        method = %method,
71                        endpoint = %endpoint,
72                        body = ?body,
73                        status = status.as_u16(),
74                        "Request failed");
75
76                    Err(ApiError::HttpError(err, status))
77                } else {
78                    Err(ApiError::Unknown(err))
79                }
80            }
81        };
82
83        let result = response.json::<T>().await?;
84        Ok(result)
85    }
86}
87
88impl Api {
89    pub fn new(url: Url, timeout: Duration, api_key: &str) -> Self {
90        let client = Client::builder()
91            .timeout(timeout)
92            .user_agent("portal")
93            .build()
94            .unwrap();
95
96        let mut headers = HeaderMap::new();
97        headers.insert("apiKey", HeaderValue::from_str(api_key).unwrap());
98        let headers = Arc::new(headers);
99
100        Self {
101            url,
102            client,
103            headers,
104        }
105    }
106
107    /// Signs up a new user.
108    ///
109    /// # Arguments
110    ///
111    /// * `email_or_phone` - The user's email or phone number.
112    /// * `password` - The user's password.
113    ///
114    /// # Returns
115    ///
116    /// A `SignUpResponse` which may contain either a `User` or a `Session`, depending on the server configuration.
117    #[instrument(skip(self, password), fields(user_id))]
118    pub async fn sign_up(
119        &self,
120        email_or_phone: EmailOrPhone,
121        password: impl AsRef<str> + Sized,
122    ) -> Result<SignUpResponse, ApiError> {
123        self.send_request(Method::POST, "signup")
124            .body(&self.sign_in_up_body(&email_or_phone, &password))
125            .send()
126            .await
127    }
128
129    #[instrument(skip(self, password))]
130    pub async fn sign_in(
131        &self,
132        email_or_phone: EmailOrPhone,
133        password: impl AsRef<str>,
134    ) -> Result<Session, ApiError> {
135        self.send_request(Method::POST, "token")
136            .query(&[("grant_type", "password")])
137            .body(&self.sign_in_up_body(&email_or_phone, &password))
138            .send()
139            .await
140    }
141
142    fn sign_in_up_body<'a>(
143        &'a self,
144        email_or_phone: &'a EmailOrPhone,
145        password: &'a impl AsRef<str>,
146    ) -> SignInUpBody<'a> {
147        match email_or_phone {
148            EmailOrPhone::Email(email) => SignInUpBody {
149                email: Some(email),
150                phone: None,
151                password: password.as_ref(),
152            },
153            EmailOrPhone::Phone(phone) => SignInUpBody {
154                email: None,
155                phone: Some(phone.as_str()),
156                password: password.as_ref(),
157            },
158        }
159    }
160
161    #[instrument(skip(self, access_token))]
162    pub async fn logout(&self, access_token: impl AsRef<str>) -> Result<(), ApiError> {
163        let endpoint = self.url.join("logout")?;
164
165        self.client
166            .post(endpoint)
167            .headers((*self.headers).clone())
168            .bearer_auth(access_token.as_ref())
169            .send()
170            .await?
171            .error_for_status()?;
172
173        Ok(())
174    }
175
176    #[instrument(skip(self, access_token))]
177    pub async fn get_user(&self, access_token: impl AsRef<str>) -> Result<User, ApiError> {
178        self.send_request::<_, ()>(Method::GET, "user")
179            .access_token(access_token.as_ref())
180            .send()
181            .await
182    }
183
184    #[instrument(skip(self))]
185    pub async fn health_check(&self) -> Result<HealthCheckResponse, ApiError> {
186        self.send_request::<_, ()>(Method::GET, "health")
187            .send()
188            .await
189    }
190
191    #[instrument(skip(self, refresh_token))]
192    pub async fn refresh_access_token(
193        &self,
194        refresh_token: impl AsRef<str>,
195    ) -> Result<Session, ApiError> {
196        self.send_request(Method::POST, "token")
197            .query(&[("grant_type", "refresh_token")])
198            .body(&json!({
199                "refresh_token": refresh_token.as_ref(),
200            }))
201            .send()
202            .await
203    }
204
205    pub async fn list_users(&self, access_token: impl AsRef<str>) -> Result<UserList, ApiError> {
206        self.list_users_query(access_token, &[]).await
207    }
208
209    #[instrument(skip(self, access_token, query))]
210    pub async fn list_users_query(
211        &self,
212        access_token: impl AsRef<str>,
213        query: &[(&str, &str)],
214    ) -> Result<UserList, ApiError> {
215        self.send_request::<_, ()>(Method::GET, "admin/users")
216            .query(query)
217            .access_token(access_token.as_ref())
218            .send()
219            .await
220    }
221
222    pub fn create_pkce_oauth_url(&self, req: OAuthRequest, challenge: PkceCodeChallenge) -> Url {
223        let query = format!(
224            "provider={}&redirect_to={}&code_challenge={}&code_challenge_method={}",
225            req.provider,
226            req.redirect_to,
227            challenge.as_str(),
228            challenge.method().as_str()
229        );
230
231        let mut endpoint = self.url.join("authorize").unwrap();
232        endpoint.set_query(Some(&query));
233
234        endpoint
235    }
236
237    pub async fn exchange_code_for_session(
238        &self,
239        code: &str,
240        verifier: &PkceCodeVerifier,
241    ) -> Result<Session, ApiError> {
242        self.send_request(Method::POST, "token")
243            .query(&[("grant_type", "pkce")])
244            .body(&json!({
245                "auth_code": code,
246                "code_verifier": verifier.secret(),
247            }))
248            .send()
249            .await
250    }
251}
252
253#[derive(Serialize)]
254struct SignInUpBody<'a> {
255    #[serde(skip_serializing_if = "Option::is_none")]
256    email: Option<&'a str>,
257    #[serde(skip_serializing_if = "Option::is_none")]
258    phone: Option<&'a str>,
259    password: &'a str,
260}
261
262#[derive(Deserialize, Debug)]
263#[serde(transparent)]
264pub struct SignUpResponse {
265    #[serde(with = "either::serde_untagged")]
266    inner: Either<User, Session>,
267}
268
269impl SignUpResponse {
270    pub fn session(self) -> Option<Session> {
271        self.into()
272    }
273
274    pub fn user(self) -> Option<User> {
275        self.into()
276    }
277}
278
279impl AsRef<User> for SignUpResponse {
280    fn as_ref(&self) -> &User {
281        match self.inner {
282            Either::Left(ref user) => user,
283            Either::Right(ref session) => &session.user,
284        }
285    }
286}
287
288impl From<SignUpResponse> for Option<User> {
289    fn from(val: SignUpResponse) -> Self {
290        val.inner.left()
291    }
292}
293
294impl From<SignUpResponse> for Option<Session> {
295    fn from(val: SignUpResponse) -> Self {
296        val.inner.right()
297    }
298}
299
300/// https://github.com/supabase/auth/blob/master/internal/api/errors.go
301#[derive(Debug, Error)]
302pub enum ApiError {
303    #[error("Http error with status {1}: {0}")]
304    HttpError(#[source] reqwest::Error, StatusCode),
305    #[error(transparent)]
306    Unknown(#[from] reqwest::Error),
307    #[error("URL parsing error: {0}")]
308    UrlError(#[from] url::ParseError),
309}
310
311#[derive(Debug, Deserialize, Default)]
312pub struct ApiErrorResponse {
313    pub code: u16,
314    pub error_code: ApiErrorCode,
315    pub msg: String,
316}
317
318#[derive(Debug, Deserialize)]
319#[serde(rename_all = "snake_case")]
320pub enum ApiErrorCode {
321    Unknown,
322    SignupDisabled,
323    UserAlreadyExists,
324}
325
326impl Default for ApiErrorCode {
327    fn default() -> Self {
328        Self::Unknown
329    }
330}