1use core::ops::Div as _;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use core::time::Duration;
5
6use futures::Stream;
7use reqwest::header::InvalidHeaderValue;
8use thiserror::Error;
9use tokio::task::JoinSet;
10
11use crate::auth_client::requests::{GrantType, TokenRequest};
12use crate::auth_client::{ApiClient, Request};
13use crate::error::AuthError;
14use crate::types::{AccessTokenResponseSchema, ErrorSchema, LoginCredentials, TokenRequestBody};
15
16#[derive(Clone, Debug, PartialEq, Eq, typed_builder::TypedBuilder)]
17pub struct SupabaseAuthConfig {
18 pub api_key: String,
19 pub max_reconnect_attempts: u8,
20 pub reconnect_interval: core::time::Duration,
21 pub url: url::Url,
22}
23
24pub struct JwtStream {
25 config: SupabaseAuthConfig,
26}
27
28impl JwtStream {
29 #[must_use]
31 pub const fn new(config: SupabaseAuthConfig) -> Self {
32 Self { config }
33 }
34
35 #[tracing::instrument(skip_all, err)]
42 pub fn sign_in(&self, params: LoginCredentials) -> Result<JwtRefreshStream, AuthError> {
43 let client = ApiClient::new_unauthenticated(&self.config.url, &self.config.api_key)?;
44 let max_reconnect_attempts = usize::from(self.config.max_reconnect_attempts);
45 Ok(JwtRefreshStream {
46 api_key: self.config.api_key.clone(),
47 client,
48 token_body: params,
49 max_reconnect_attempts,
50 current_reconnect_attempts: 0,
51 background_tasks: JoinSet::new(),
52 reconnect_interval: self.config.reconnect_interval,
53 })
54 }
55}
56
57pub struct JwtRefreshStream {
58 pub api_key: String,
59 pub client: ApiClient,
60 pub token_body: LoginCredentials,
61 pub max_reconnect_attempts: usize,
62 pub current_reconnect_attempts: usize,
63 pub background_tasks: JoinSet<Result<AccessTokenResponseSchema, RefreshStreamError>>,
64 pub reconnect_interval: Duration,
65}
66
67impl JwtRefreshStream {
68 fn login_request(
69 &self,
70 ) -> Result<Request<AccessTokenResponseSchema, ErrorSchema>, RefreshStreamError> {
71 let req = self.client.build_request(
72 &TokenRequest::builder()
73 .grant_type(GrantType::Password)
74 .payload(
75 TokenRequestBody::builder()
76 .email(self.token_body.email.clone())
77 .password(self.token_body.password.clone())
78 .phone(self.token_body.phone.clone())
79 .build(),
80 )
81 .build(),
82 )?;
83 Ok(req)
84 }
85
86 fn spawn_login_task(&mut self, delay: Option<core::time::Duration>) {
87 let request = match self.login_request() {
88 Ok(req) => req,
89 Err(err) => {
90 tracing::error!(?err, "Failed to build login request");
91 return;
92 }
93 };
94 let task = async move {
95 if let Some(duration) = delay {
96 tokio::time::sleep(duration).await;
97 }
98 auth_request(request).await
99 };
100 self.background_tasks.spawn(task);
101 }
102
103 fn spawn_refresh_task(&mut self, access_token: &AccessTokenResponseSchema) {
104 let Some(refresh_token) = access_token.refresh_token.clone() else {
106 tracing::warn!("`refresh_token` not present");
107 return;
108 };
109
110 let Some(expires_in) = access_token.expires_in else {
112 tracing::warn!("`expires_in` not present");
113 return;
114 };
115
116 let token_request_body = TokenRequestBody::builder()
118 .refresh_token(refresh_token)
119 .build();
120
121 let token_request = TokenRequest::builder()
123 .grant_type(GrantType::RefreshToken)
124 .payload(token_request_body)
125 .build();
126
127 let Ok(request) = self.client.build_request(&token_request) else {
129 tracing::warn!("could not build refresh task request");
130 return;
131 };
132
133 let task = async move {
135 let refresh_in =
136 calculate_refresh_sleep_duration(u64::try_from(expires_in).unwrap_or(0));
137 tokio::time::sleep(refresh_in).await;
138 auth_request(request).await
139 };
140
141 self.background_tasks.spawn(task);
143 }
144}
145
146impl Stream for JwtRefreshStream {
147 type Item = Result<AccessTokenResponseSchema, RefreshStreamError>;
148
149 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 match self.background_tasks.poll_join_next(cx) {
151 Poll::Ready(Some(Ok(item))) => {
152 match &item {
153 Ok(access_token) => {
154 self.current_reconnect_attempts = 0;
156 self.spawn_refresh_task(access_token);
158 cx.waker().wake_by_ref();
159 }
160 Err(err) => {
161 if self.current_reconnect_attempts >= self.max_reconnect_attempts {
162 tracing::error!(
163 ?err,
164 "Max reconnect attempts exceeded; terminating stream"
165 );
166 return Poll::Ready(None);
167 }
168 tracing::warn!(
169 attempts = self.current_reconnect_attempts,
170 max_attempts = self.max_reconnect_attempts,
171 "Login failed; retrying"
172 );
173 self.current_reconnect_attempts =
174 self.current_reconnect_attempts.saturating_add(1);
175 let duration = self.reconnect_interval;
177 self.spawn_login_task(Some(duration));
178 cx.waker().wake_by_ref();
179 }
180 }
181 Poll::Ready(Some(item))
182 }
183 Poll::Ready(Some(Err(join_err))) => {
184 tracing::error!(?join_err, "Task panicked; terminating stream");
185 cx.waker().wake_by_ref();
186 Poll::Ready(None)
187 }
188 Poll::Ready(None) => {
189 if self.current_reconnect_attempts >= self.max_reconnect_attempts {
191 tracing::error!("Max reconnect attempts exceeded; terminating stream");
192 return Poll::Ready(None);
193 }
194 tracing::debug!("No tasks running; attempting initial login");
195 self.current_reconnect_attempts = self.current_reconnect_attempts.saturating_add(1);
196 self.spawn_login_task(None);
197 cx.waker().wake_by_ref();
199 Poll::Pending
200 }
201 Poll::Pending => Poll::Pending,
202 }
203 }
204}
205
206async fn auth_request(
207 request: Request<AccessTokenResponseSchema, ErrorSchema>,
208) -> Result<AccessTokenResponseSchema, RefreshStreamError> {
209 let res = request.execute().await?.json().await??;
210 Ok(res)
211}
212
/// Returns how long to wait before refreshing a token that expires in `expires_in` seconds.
///
/// The refresh is placed at the halfway point of the token's lifetime. Sub-second precision
/// is kept, so an odd number of seconds yields a trailing 500 ms.
fn calculate_refresh_sleep_duration(expires_in: u64) -> Duration {
    let lifetime = Duration::from_secs(expires_in);
    lifetime / 2
}
216
217#[derive(Debug, Error)]
218pub enum RefreshStreamError {
219 #[error("Request error: {0}")]
220 Reqwest(#[from] reqwest::Error),
221 #[error("JSON parse error: {0}")]
222 JsonParse(#[from] simd_json::Error),
223 #[error("Supabase API error: {0}")]
224 SupabaseApiError(String),
225 #[error("Auth error: {0}")]
226 AuthError(#[from] AuthError),
227 #[error("Auth error: {0}")]
228 ErrorResponse(#[from] ErrorSchema),
229}
230
231#[derive(Debug, Error)]
232pub enum SignInError {
233 #[error(transparent)]
234 InvalidHeaderValue(#[from] InvalidHeaderValue),
235
236 #[error(transparent)]
237 ReqwestError(#[from] reqwest::Error),
238
239 #[error(transparent)]
240 UrlParseError(#[from] url::ParseError),
241}
242
#[cfg(test)]
#[expect(clippy::unwrap_used, reason = "allow for tests")]
mod auth_tests {
    use core::time::Duration;

    use futures::StreamExt as _;
    use mockito::Matcher;
    use pretty_assertions::assert_eq;
    use rp_supabase_mock::{SupabaseMockServer, make_jwt};
    use rstest::rstest;
    use test_log::test;
    use tokio::time::timeout;

    use super::*;

    /// Upper bound on how long a single `stream.next()` may take inside a test.
    const NEXT_ITEM_TIMEOUT: Duration = Duration::from_secs(5);

    fn ms(ms: u32) -> Duration {
        Duration::from_millis(ms.into())
    }

    /// Login credentials shared by every test.
    fn credentials() -> LoginCredentials {
        LoginCredentials::builder()
            .email("user@example.com".to_owned())
            .password("password".to_owned())
            .build()
    }

    /// Auth config pointing at `server_url` with the given retry policy.
    fn config(
        server_url: url::Url,
        max_reconnect_attempts: u8,
        reconnect_interval: Duration,
    ) -> SupabaseAuthConfig {
        SupabaseAuthConfig {
            url: server_url,
            api_key: "api-key".to_owned(),
            max_reconnect_attempts,
            reconnect_interval,
        }
    }

    /// Awaits the next stream item, panicking if none arrives within [`NEXT_ITEM_TIMEOUT`].
    async fn next_item(
        stream: &mut JwtRefreshStream,
    ) -> Option<Result<AccessTokenResponseSchema, RefreshStreamError>> {
        timeout(NEXT_ITEM_TIMEOUT, stream.next()).await.unwrap()
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(5_000))]
    async fn test_successful_password_login() {
        let access_token = make_jwt(Duration::from_secs(3600)).unwrap();
        let mut server = SupabaseMockServer::new().await;
        server.register_jwt_password(&access_token).unwrap();
        let supabase_auth =
            JwtStream::new(config(server.server_url().unwrap(), 1, Duration::from_secs(1)));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        let auth_response = next_item(&mut stream).await.unwrap().unwrap();
        assert_eq!(auth_response.access_token.unwrap(), access_token);
        assert_eq!(auth_response.refresh_token.unwrap(), "some-refresh-token");
        assert_eq!(
            auth_response.user.unwrap().email.unwrap(),
            "user@example.com"
        );
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(100))]
    async fn test_password_login_error() {
        let mut server = SupabaseMockServer::new().await;
        let _mock = server
            .mockito_server
            .mock("POST", "/auth/v1/token")
            .match_query(Matcher::Regex("grant_type=password".to_owned()))
            .with_status(400)
            .create();
        let supabase_auth =
            JwtStream::new(config(server.server_url().unwrap(), 2, Duration::from_secs(1)));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        // With a budget of two attempts, the first failure is yielded before retrying.
        next_item(&mut stream).await.unwrap().unwrap_err();
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(100))]
    async fn test_password_login_error_no_retries() {
        let mut server = SupabaseMockServer::new().await;
        let _mock = server
            .mockito_server
            .mock("POST", "/auth/v1/token")
            .match_query(Matcher::Regex("grant_type=password".to_owned()))
            .with_status(400)
            .create();
        let supabase_auth =
            JwtStream::new(config(server.server_url().unwrap(), 1, Duration::from_secs(1)));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        // A budget of one attempt ends the stream on the first failure without yielding it.
        assert!(next_item(&mut stream).await.is_none());
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(100))]
    async fn test_retry_on_login_error() {
        let mut server = SupabaseMockServer::new().await;
        let _mock = server
            .mockito_server
            .mock("POST", "/auth/v1/token")
            .match_query(Matcher::Regex("grant_type=password".to_owned()))
            .with_status(500)
            .create();
        let supabase_auth =
            JwtStream::new(config(server.server_url().unwrap(), 2, Duration::from_millis(20)));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        // The first attempt hits the 500 mock and is yielded as an error.
        next_item(&mut stream).await.unwrap().unwrap_err();

        // The delayed retry is expected to succeed once a success mock is registered.
        server
            .register_jwt_password(&make_jwt(Duration::from_secs(3600)).unwrap())
            .unwrap();
        let auth_response = next_item(&mut stream).await.unwrap().unwrap();
        assert_eq!(auth_response.refresh_token.unwrap(), "some-refresh-token");
        assert_eq!(
            auth_response.user.unwrap().email.unwrap(),
            "user@example.com"
        );
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(3_000))]
    async fn test_use_refresh_token_on_expiry() {
        let mut server = SupabaseMockServer::new().await;
        // Very short-lived first token; presumably the mock derives `expires_in` from it,
        // so the refresh becomes due almost immediately.
        let first_access_token = make_jwt(Duration::from_millis(5)).unwrap();
        server.register_jwt_password(&first_access_token).unwrap();

        let new_access_token = make_jwt(Duration::from_secs(3600)).unwrap();
        server.register_jwt_refresh(&new_access_token).unwrap();
        let supabase_auth =
            JwtStream::new(config(server.server_url().unwrap(), 1, Duration::from_millis(20)));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        let auth_response1 = next_item(&mut stream).await.unwrap().unwrap();
        assert_eq!(auth_response1.access_token.unwrap(), first_access_token);
        assert_eq!(
            auth_response1.user.unwrap().email.unwrap(),
            "user@example.com"
        );

        let auth_response2 = next_item(&mut stream).await.unwrap().unwrap();
        assert_eq!(auth_response2.access_token.unwrap(), new_access_token);
        assert_eq!(
            auth_response2.user.unwrap().email.unwrap(),
            "user@example.com"
        );
    }
}