rp_supabase_auth/
jwt_stream.rs

1use core::ops::Div as _;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use core::time::Duration;
5
6use futures::Stream;
7use reqwest::header::InvalidHeaderValue;
8use thiserror::Error;
9use tokio::task::JoinSet;
10
11use crate::auth_client::requests::{GrantType, TokenRequest};
12use crate::auth_client::{ApiClient, Request};
13use crate::error::AuthError;
14use crate::types::{AccessTokenResponseSchema, ErrorSchema, LoginCredentials, TokenRequestBody};
15
16#[derive(Clone, Debug, PartialEq, Eq, typed_builder::TypedBuilder)]
17pub struct SupabaseAuthConfig {
18    pub api_key: String,
19    pub max_reconnect_attempts: u8,
20    pub reconnect_interval: core::time::Duration,
21    pub url: url::Url,
22}
23
24pub struct JwtStream {
25    config: SupabaseAuthConfig,
26}
27
28impl JwtStream {
29    /// Creates a new [`SupabaseAuth`].
30    #[must_use]
31    pub const fn new(config: SupabaseAuthConfig) -> Self {
32        Self { config }
33    }
34
35    /// Creates a Stream that will attempt to log in to supabase and periodically refresh the JWT
36    ///
37    /// # Errors
38    ///
39    /// This function will return an error if the provided supabase url cannot be joined with the
40    /// expected suffix or if the client cannot be created.
41    #[tracing::instrument(skip_all, err)]
42    pub fn sign_in(&self, params: LoginCredentials) -> Result<JwtRefreshStream, AuthError> {
43        let client = ApiClient::new_unauthenticated(&self.config.url, &self.config.api_key)?;
44        let max_reconnect_attempts = usize::from(self.config.max_reconnect_attempts);
45        Ok(JwtRefreshStream {
46            api_key: self.config.api_key.clone(),
47            client,
48            token_body: params,
49            max_reconnect_attempts,
50            current_reconnect_attempts: 0,
51            background_tasks: JoinSet::new(),
52            reconnect_interval: self.config.reconnect_interval,
53        })
54    }
55}
56
57pub struct JwtRefreshStream {
58    pub api_key: String,
59    pub client: ApiClient,
60    pub token_body: LoginCredentials,
61    pub max_reconnect_attempts: usize,
62    pub current_reconnect_attempts: usize,
63    pub background_tasks: JoinSet<Result<AccessTokenResponseSchema, RefreshStreamError>>,
64    pub reconnect_interval: Duration,
65}
66
67impl JwtRefreshStream {
68    fn login_request(
69        &self,
70    ) -> Result<Request<AccessTokenResponseSchema, ErrorSchema>, RefreshStreamError> {
71        let req = self.client.build_request(
72            &TokenRequest::builder()
73                .grant_type(GrantType::Password)
74                .payload(
75                    TokenRequestBody::builder()
76                        .email(self.token_body.email.clone())
77                        .password(self.token_body.password.clone())
78                        .phone(self.token_body.phone.clone())
79                        .build(),
80                )
81                .build(),
82        )?;
83        Ok(req)
84    }
85
86    fn spawn_login_task(&mut self, delay: Option<core::time::Duration>) {
87        let request = match self.login_request() {
88            Ok(req) => req,
89            Err(err) => {
90                tracing::error!(?err, "Failed to build login request");
91                return;
92            }
93        };
94        let task = async move {
95            if let Some(duration) = delay {
96                tokio::time::sleep(duration).await;
97            }
98            auth_request(request).await
99        };
100        self.background_tasks.spawn(task);
101    }
102
103    fn spawn_refresh_task(&mut self, access_token: &AccessTokenResponseSchema) {
104        // Attempt to extract refresh_token
105        let Some(refresh_token) = access_token.refresh_token.clone() else {
106            tracing::warn!("`refresh_token` not present");
107            return;
108        };
109
110        // Attempt to extract expires_in
111        let Some(expires_in) = access_token.expires_in else {
112            tracing::warn!("`expires_in` not present");
113            return;
114        };
115
116        // Build the TokenRequestBody
117        let token_request_body = TokenRequestBody::builder()
118            .refresh_token(refresh_token)
119            .build();
120
121        // Build the TokenRequest
122        let token_request = TokenRequest::builder()
123            .grant_type(GrantType::RefreshToken)
124            .payload(token_request_body)
125            .build();
126
127        // Attempt to build the request
128        let Ok(request) = self.client.build_request(&token_request) else {
129            tracing::warn!("could not build refresh task request");
130            return;
131        };
132
133        // Create the asynchronous task
134        let task = async move {
135            let refresh_in =
136                calculate_refresh_sleep_duration(u64::try_from(expires_in).unwrap_or(0));
137            tokio::time::sleep(refresh_in).await;
138            auth_request(request).await
139        };
140
141        // Spawn the background task
142        self.background_tasks.spawn(task);
143    }
144}
145
146impl Stream for JwtRefreshStream {
147    type Item = Result<AccessTokenResponseSchema, RefreshStreamError>;
148
149    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        match self.background_tasks.poll_join_next(cx) {
151            Poll::Ready(Some(Ok(item))) => {
152                match &item {
153                    Ok(access_token) => {
154                        // Reset reconnect attempts on success
155                        self.current_reconnect_attempts = 0;
156                        // Spawn a task to refresh the token before it expires
157                        self.spawn_refresh_task(access_token);
158                        cx.waker().wake_by_ref();
159                    }
160                    Err(err) => {
161                        if self.current_reconnect_attempts >= self.max_reconnect_attempts {
162                            tracing::error!(
163                                ?err,
164                                "Max reconnect attempts exceeded; terminating stream"
165                            );
166                            return Poll::Ready(None);
167                        }
168                        tracing::warn!(
169                            attempts = self.current_reconnect_attempts,
170                            max_attempts = self.max_reconnect_attempts,
171                            "Login failed; retrying"
172                        );
173                        self.current_reconnect_attempts =
174                            self.current_reconnect_attempts.saturating_add(1);
175                        // Spawn a login task with a delay
176                        let duration = self.reconnect_interval;
177                        self.spawn_login_task(Some(duration));
178                        cx.waker().wake_by_ref();
179                    }
180                }
181                Poll::Ready(Some(item))
182            }
183            Poll::Ready(Some(Err(join_err))) => {
184                tracing::error!(?join_err, "Task panicked; terminating stream");
185                cx.waker().wake_by_ref();
186                Poll::Ready(None)
187            }
188            Poll::Ready(None) => {
189                // No tasks left; start the initial login attempt
190                if self.current_reconnect_attempts >= self.max_reconnect_attempts {
191                    tracing::error!("Max reconnect attempts exceeded; terminating stream");
192                    return Poll::Ready(None);
193                }
194                tracing::debug!("No tasks running; attempting initial login");
195                self.current_reconnect_attempts = self.current_reconnect_attempts.saturating_add(1);
196                self.spawn_login_task(None);
197                // Yield control to allow the task to start
198                cx.waker().wake_by_ref();
199                Poll::Pending
200            }
201            Poll::Pending => Poll::Pending,
202        }
203    }
204}
205
206async fn auth_request(
207    request: Request<AccessTokenResponseSchema, ErrorSchema>,
208) -> Result<AccessTokenResponseSchema, RefreshStreamError> {
209    let res = request.execute().await?.json().await??;
210    Ok(res)
211}
212
/// Returns how long to wait before refreshing a token that expires in `expires_in` seconds.
///
/// The refresh is scheduled at the midpoint of the token's lifetime. The remaining half is left
/// for the refresh request to complete.
fn calculate_refresh_sleep_duration(expires_in: u64) -> Duration {
    let lifetime = Duration::from_secs(expires_in);
    lifetime.div(2_u32)
}
216
217#[derive(Debug, Error)]
218pub enum RefreshStreamError {
219    #[error("Request error: {0}")]
220    Reqwest(#[from] reqwest::Error),
221    #[error("JSON parse error: {0}")]
222    JsonParse(#[from] simd_json::Error),
223    #[error("Supabase API error: {0}")]
224    SupabaseApiError(String),
225    #[error("Auth error: {0}")]
226    AuthError(#[from] AuthError),
227    #[error("Auth error: {0}")]
228    ErrorResponse(#[from] ErrorSchema),
229}
230
231#[derive(Debug, Error)]
232pub enum SignInError {
233    #[error(transparent)]
234    InvalidHeaderValue(#[from] InvalidHeaderValue),
235
236    #[error(transparent)]
237    ReqwestError(#[from] reqwest::Error),
238
239    #[error(transparent)]
240    UrlParseError(#[from] url::ParseError),
241}
242
#[cfg(test)]
#[expect(clippy::unwrap_used, reason = "allow for tests")]
mod auth_tests {
    use core::time::Duration;

    use futures::StreamExt as _;
    use mockito::Matcher;
    use pretty_assertions::assert_eq;
    use rp_supabase_mock::{SupabaseMockServer, make_jwt};
    use rstest::rstest;
    use test_log::test;
    use tokio::time::timeout;

    use super::*;

    /// Millisecond [`Duration`] helper for the `#[timeout(...)]` attributes.
    fn ms(ms: u32) -> Duration {
        Duration::from_millis(ms.into())
    }

    /// Credentials used by every test's password login.
    fn credentials() -> LoginCredentials {
        LoginCredentials::builder()
            .email("user@example.com".to_owned())
            .password("password".to_owned())
            .build()
    }

    /// Builds a config for `url` with the given retry policy and a fixed test API key.
    fn make_config(
        url: url::Url,
        max_reconnect_attempts: u8,
        reconnect_interval: Duration,
    ) -> SupabaseAuthConfig {
        SupabaseAuthConfig {
            url,
            api_key: "api-key".to_owned(),
            max_reconnect_attempts,
            reconnect_interval,
        }
    }

    /// Registers a password-grant mock on `server` that always answers with `status`.
    ///
    /// The returned mock must be kept alive for as long as it should stay registered.
    fn mock_password_status(server: &mut SupabaseMockServer, status: usize) -> mockito::Mock {
        server
            .mockito_server
            .mock("POST", "/auth/v1/token")
            .match_query(Matcher::Regex("grant_type=password".to_owned()))
            .with_status(status)
            .create()
    }

    #[test]
    fn test_refresh_sleep_is_half_of_expiry() {
        assert_eq!(calculate_refresh_sleep_duration(0), Duration::ZERO);
        assert_eq!(calculate_refresh_sleep_duration(3), Duration::from_millis(1_500));
        assert_eq!(
            calculate_refresh_sleep_duration(3_600),
            Duration::from_secs(1_800)
        );
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(5_000))]
    async fn test_successful_password_login() {
        let access_token = make_jwt(Duration::from_secs(3600)).unwrap();
        let mut server = SupabaseMockServer::new().await;
        server.register_jwt_password(&access_token).unwrap();
        let supabase_auth = JwtStream::new(make_config(
            server.server_url().unwrap(),
            1,
            Duration::from_secs(1),
        ));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        let response = timeout(Duration::from_secs(5), stream.next())
            .await
            .unwrap()
            .unwrap();

        let auth_response = response.unwrap();
        assert_eq!(auth_response.access_token.unwrap(), access_token);
        assert_eq!(auth_response.refresh_token.unwrap(), "some-refresh-token");
        assert_eq!(
            auth_response.user.unwrap().email.unwrap(),
            "user@example.com"
        );
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(100))]
    async fn test_password_login_error() {
        let mut server = SupabaseMockServer::new().await;
        let _m1 = mock_password_status(&mut server, 400);
        let supabase_auth = JwtStream::new(make_config(
            server.server_url().unwrap(),
            2,
            Duration::from_secs(1),
        ));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        // With retries left, the failure is surfaced as an item rather than ending the stream.
        let response = timeout(Duration::from_secs(5), stream.next())
            .await
            .unwrap()
            .unwrap();

        response.unwrap_err();
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(100))]
    async fn test_password_login_error_no_retries() {
        let mut server = SupabaseMockServer::new().await;
        let _m1 = mock_password_status(&mut server, 400);
        let supabase_auth = JwtStream::new(make_config(
            server.server_url().unwrap(),
            1,
            Duration::from_secs(1),
        ));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        // A budget of one attempt means the first failure ends the stream.
        let response = timeout(Duration::from_secs(5), stream.next())
            .await
            .unwrap();

        assert!(response.is_none());
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(100))]
    async fn test_retry_on_login_error() {
        let mut server = SupabaseMockServer::new().await;
        let _m1 = mock_password_status(&mut server, 500);
        let supabase_auth = JwtStream::new(make_config(
            server.server_url().unwrap(),
            2,
            Duration::from_millis(20),
        ));

        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        let response = timeout(Duration::from_secs(5), stream.next())
            .await
            .unwrap()
            .unwrap();
        response.unwrap_err();

        // Make the retried login succeed.
        server
            .register_jwt_password(&make_jwt(Duration::from_secs(3600)).unwrap())
            .unwrap();
        let response = timeout(Duration::from_secs(10), stream.next())
            .await
            .unwrap()
            .unwrap();

        let auth_response = response.unwrap();
        assert_eq!(auth_response.refresh_token.unwrap(), "some-refresh-token");
        assert_eq!(
            auth_response.user.unwrap().email.unwrap(),
            "user@example.com"
        );
    }

    #[rstest]
    #[test(tokio::test)]
    #[timeout(ms(3_000))]
    async fn test_use_refresh_token_on_expiry() {
        // setup
        let mut server = SupabaseMockServer::new().await;
        let first_access_token = make_jwt(Duration::from_millis(5)).unwrap();
        server.register_jwt_password(&first_access_token).unwrap();

        let new_access_token = make_jwt(Duration::from_secs(3600)).unwrap();
        server.register_jwt_refresh(&new_access_token).unwrap();
        let supabase_auth = JwtStream::new(make_config(
            server.server_url().unwrap(),
            1,
            Duration::from_millis(20),
        ));

        // action
        let mut stream = supabase_auth.sign_in(credentials()).unwrap();

        // Get the initial token
        let response1 = timeout(Duration::from_secs(5), stream.next())
            .await
            .unwrap()
            .unwrap();
        let auth_response1 = response1.unwrap();
        assert_eq!(auth_response1.access_token.unwrap(), first_access_token);
        assert_eq!(
            auth_response1.user.unwrap().email.unwrap(),
            "user@example.com"
        );

        // Wait for token to expire and refresh
        let response2 = timeout(Duration::from_secs(5), stream.next())
            .await
            .unwrap()
            .unwrap();
        let auth_response2 = response2.unwrap();
        assert_eq!(auth_response2.access_token.unwrap(), new_access_token);
        assert_eq!(
            auth_response2.user.unwrap().email.unwrap(),
            "user@example.com"
        );
    }
}