axum_supabase_auth/auth/api/
mod.rs1mod types;
2
3use super::types::*;
4use crate::api::types::HealthCheckResponse;
5use bon::bon;
6use either::Either;
7use oauth2::{PkceCodeChallenge, PkceCodeVerifier};
8use reqwest::header::{HeaderMap, HeaderValue};
9use reqwest::{Client, Method, StatusCode, Url};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror::Error;
16use tracing::{error, instrument, warn};
17
18#[derive(Clone)]
19pub struct Api {
20 url: Url,
21 client: Client,
22 headers: Arc<HeaderMap>,
23}
24
25#[bon]
26impl Api {
27 #[builder(finish_fn=send)]
28 async fn send_request<T, B>(
29 &self,
30 #[builder(start_fn)] method: Method,
31 #[builder(start_fn)] endpoint: &str,
32 query: Option<&[(&str, &str)]>,
33 body: Option<&B>,
34 access_token: Option<&str>,
35 ) -> Result<T, ApiError>
36 where
37 T: DeserializeOwned,
38 B: Serialize + ?Sized,
39 {
40 let url = self.url.join(endpoint)?;
41
42 let mut request = self
43 .client
44 .request(method.clone(), url)
45 .headers((*self.headers).clone());
46
47 if let Some(q) = query {
48 if !q.is_empty() {
49 request = request.query(q);
50 }
51 }
52
53 if let Some(b) = body {
54 request = request.json(b);
55 }
56
57 if let Some(token) = access_token {
58 request = request.bearer_auth(token);
59 }
60
61 let response = request.send().await?;
62
63 match response.error_for_status_ref() {
64 Ok(_) => {}
65 Err(err) => {
66 return if let Some(status) = err.status() {
67 let body: ApiErrorResponse = response.json().await.unwrap_or_default();
68
69 warn!(%err,
70 method = %method,
71 endpoint = %endpoint,
72 body = ?body,
73 status = status.as_u16(),
74 "Request failed");
75
76 Err(ApiError::HttpError(err, status))
77 } else {
78 Err(ApiError::Unknown(err))
79 }
80 }
81 };
82
83 let result = response.json::<T>().await?;
84 Ok(result)
85 }
86}
87
88impl Api {
89 pub fn new(url: Url, timeout: Duration, api_key: &str) -> Self {
90 let client = Client::builder()
91 .timeout(timeout)
92 .user_agent("portal")
93 .build()
94 .unwrap();
95
96 let mut headers = HeaderMap::new();
97 headers.insert("apiKey", HeaderValue::from_str(api_key).unwrap());
98 let headers = Arc::new(headers);
99
100 Self {
101 url,
102 client,
103 headers,
104 }
105 }
106
107 #[instrument(skip(self, password), fields(user_id))]
118 pub async fn sign_up(
119 &self,
120 email_or_phone: EmailOrPhone,
121 password: impl AsRef<str> + Sized,
122 ) -> Result<SignUpResponse, ApiError> {
123 self.send_request(Method::POST, "signup")
124 .body(&self.sign_in_up_body(&email_or_phone, &password))
125 .send()
126 .await
127 }
128
129 #[instrument(skip(self, password))]
130 pub async fn sign_in(
131 &self,
132 email_or_phone: EmailOrPhone,
133 password: impl AsRef<str>,
134 ) -> Result<Session, ApiError> {
135 self.send_request(Method::POST, "token")
136 .query(&[("grant_type", "password")])
137 .body(&self.sign_in_up_body(&email_or_phone, &password))
138 .send()
139 .await
140 }
141
142 fn sign_in_up_body<'a>(
143 &'a self,
144 email_or_phone: &'a EmailOrPhone,
145 password: &'a impl AsRef<str>,
146 ) -> SignInUpBody<'a> {
147 match email_or_phone {
148 EmailOrPhone::Email(email) => SignInUpBody {
149 email: Some(email),
150 phone: None,
151 password: password.as_ref(),
152 },
153 EmailOrPhone::Phone(phone) => SignInUpBody {
154 email: None,
155 phone: Some(phone.as_str()),
156 password: password.as_ref(),
157 },
158 }
159 }
160
161 #[instrument(skip(self, access_token))]
162 pub async fn logout(&self, access_token: impl AsRef<str>) -> Result<(), ApiError> {
163 let endpoint = self.url.join("logout")?;
164
165 self.client
166 .post(endpoint)
167 .headers((*self.headers).clone())
168 .bearer_auth(access_token.as_ref())
169 .send()
170 .await?
171 .error_for_status()?;
172
173 Ok(())
174 }
175
176 #[instrument(skip(self, access_token))]
177 pub async fn get_user(&self, access_token: impl AsRef<str>) -> Result<User, ApiError> {
178 self.send_request::<_, ()>(Method::GET, "user")
179 .access_token(access_token.as_ref())
180 .send()
181 .await
182 }
183
184 #[instrument(skip(self))]
185 pub async fn health_check(&self) -> Result<HealthCheckResponse, ApiError> {
186 self.send_request::<_, ()>(Method::GET, "health")
187 .send()
188 .await
189 }
190
191 #[instrument(skip(self, refresh_token))]
192 pub async fn refresh_access_token(
193 &self,
194 refresh_token: impl AsRef<str>,
195 ) -> Result<Session, ApiError> {
196 self.send_request(Method::POST, "token")
197 .query(&[("grant_type", "refresh_token")])
198 .body(&json!({
199 "refresh_token": refresh_token.as_ref(),
200 }))
201 .send()
202 .await
203 }
204
205 pub async fn list_users(&self, access_token: impl AsRef<str>) -> Result<UserList, ApiError> {
206 self.list_users_query(access_token, &[]).await
207 }
208
209 #[instrument(skip(self, access_token, query))]
210 pub async fn list_users_query(
211 &self,
212 access_token: impl AsRef<str>,
213 query: &[(&str, &str)],
214 ) -> Result<UserList, ApiError> {
215 self.send_request::<_, ()>(Method::GET, "admin/users")
216 .query(query)
217 .access_token(access_token.as_ref())
218 .send()
219 .await
220 }
221
222 pub fn create_pkce_oauth_url(&self, req: OAuthRequest, challenge: PkceCodeChallenge) -> Url {
223 let query = format!(
224 "provider={}&redirect_to={}&code_challenge={}&code_challenge_method={}",
225 req.provider,
226 req.redirect_to,
227 challenge.as_str(),
228 challenge.method().as_str()
229 );
230
231 let mut endpoint = self.url.join("authorize").unwrap();
232 endpoint.set_query(Some(&query));
233
234 endpoint
235 }
236
237 pub async fn exchange_code_for_session(
238 &self,
239 code: &str,
240 verifier: &PkceCodeVerifier,
241 ) -> Result<Session, ApiError> {
242 self.send_request(Method::POST, "token")
243 .query(&[("grant_type", "pkce")])
244 .body(&json!({
245 "auth_code": code,
246 "code_verifier": verifier.secret(),
247 }))
248 .send()
249 .await
250 }
251}
252
253#[derive(Serialize)]
254struct SignInUpBody<'a> {
255 #[serde(skip_serializing_if = "Option::is_none")]
256 email: Option<&'a str>,
257 #[serde(skip_serializing_if = "Option::is_none")]
258 phone: Option<&'a str>,
259 password: &'a str,
260}
261
262#[derive(Deserialize, Debug)]
263#[serde(transparent)]
264pub struct SignUpResponse {
265 #[serde(with = "either::serde_untagged")]
266 inner: Either<User, Session>,
267}
268
269impl SignUpResponse {
270 pub fn session(self) -> Option<Session> {
271 self.into()
272 }
273
274 pub fn user(self) -> Option<User> {
275 self.into()
276 }
277}
278
279impl AsRef<User> for SignUpResponse {
280 fn as_ref(&self) -> &User {
281 match self.inner {
282 Either::Left(ref user) => user,
283 Either::Right(ref session) => &session.user,
284 }
285 }
286}
287
288impl From<SignUpResponse> for Option<User> {
289 fn from(val: SignUpResponse) -> Self {
290 val.inner.left()
291 }
292}
293
294impl From<SignUpResponse> for Option<Session> {
295 fn from(val: SignUpResponse) -> Self {
296 val.inner.right()
297 }
298}
299
300#[derive(Debug, Error)]
302pub enum ApiError {
303 #[error("Http error with status {1}: {0}")]
304 HttpError(#[source] reqwest::Error, StatusCode),
305 #[error(transparent)]
306 Unknown(#[from] reqwest::Error),
307 #[error("URL parsing error: {0}")]
308 UrlError(#[from] url::ParseError),
309}
310
311#[derive(Debug, Deserialize, Default)]
312pub struct ApiErrorResponse {
313 pub code: u16,
314 pub error_code: ApiErrorCode,
315 pub msg: String,
316}
317
318#[derive(Debug, Deserialize)]
319#[serde(rename_all = "snake_case")]
320pub enum ApiErrorCode {
321 Unknown,
322 SignupDisabled,
323 UserAlreadyExists,
324}
325
326impl Default for ApiErrorCode {
327 fn default() -> Self {
328 Self::Unknown
329 }
330}