rp_supabase_mock/
lib.rs

1use core::net::SocketAddr;
2use core::time::Duration;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
6use jsonwebtoken::{DecodingKey, Validation, decode};
7pub use mockito;
8use mockito::{Matcher, ServerGuard};
9use serde::{Deserialize, Serialize};
10use simd_json::json;
11
12pub struct SupabaseMockServer {
13    pub api_mock: Vec<mockito::Mock>,
14    pub mockito_server: ServerGuard,
15}
16
17impl SupabaseMockServer {
18    #[must_use]
19    pub async fn new() -> Self {
20        let server = mockito::Server::new_async().await;
21        Self {
22            mockito_server: server,
23            api_mock: vec![],
24        }
25    }
26
27    #[must_use]
28    pub fn server_address(&self) -> SocketAddr {
29        self.mockito_server.socket_address()
30    }
31
32    /// Returns the server URL.
33    ///
34    /// # Errors
35    ///
36    /// Returns an error if the server URL cannot be parsed.
37    pub fn server_url(&self) -> Result<url::Url, url::ParseError> {
38        self.mockito_server.url().parse()
39    }
40
41    /// Registers a JWT token for both password and refresh grant types.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the JWT token cannot be parsed or does not have an expiration time.
46    pub fn register_jwt(&mut self, jwt: &str) -> Result<&mut Self, JwtParseError> {
47        self.register_jwt_password(jwt)?.register_jwt_refresh(jwt)
48    }
49
50    /// Registers a JWT token for password authentication.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the JWT token cannot be parsed or does not have an expiration time.
55    pub fn register_jwt_password(&mut self, jwt: &str) -> Result<&mut Self, JwtParseError> {
56        let parsed_jwt = parse_jwt(jwt)?;
57        let current_ts = current_ts();
58        let expires_at = parsed_jwt.exp;
59        let expires_in = expires_at.abs_diff(current_ts.as_secs());
60        self.register_jwt_custom_grant_type(jwt, "password", Duration::from_millis(expires_in));
61        Ok(self)
62    }
63
64    /// Registers a JWT token for refresh token authentication.
65    ///
66    /// # Errors
67    ///
68    /// Returns an error if the JWT token cannot be parsed or does not have an expiration time.
69    pub fn register_jwt_refresh(&mut self, jwt: &str) -> Result<&mut Self, JwtParseError> {
70        let parsed_jwt = parse_jwt(jwt)?;
71        let current_ts = current_ts();
72        let expires_at = parsed_jwt.exp;
73        let expires_in = expires_at.abs_diff(current_ts.as_secs());
74        self.register_jwt_custom_grant_type(
75            jwt,
76            "refresh_token",
77            Duration::from_millis(expires_in),
78        );
79        Ok(self)
80    }
81
82    fn register_jwt_custom_grant_type(
83        &mut self,
84        jwt: &str,
85        grant_type: &str,
86        expires_in: Duration,
87    ) {
88        let body = json!({
89            "access_token": jwt,
90            "refresh_token": "some-refresh-token",
91            "expires_in": expires_in.as_secs(),
92            "token_type": "bearer",
93            "user": {
94                "id": "user-id",
95                "email": "user@example.com"
96            }
97        });
98        let body = simd_json::to_string(&body).unwrap_or_else(|_| "{}".to_owned());
99        let mock = self
100            .mockito_server
101            .mock("POST", "/auth/v1/token")
102            .match_query(Matcher::Regex(format!("grant_type={grant_type}")))
103            .with_status(200)
104            .with_header("content-type", "application/json")
105            .with_body(body)
106            .create();
107        self.api_mock.push(mock);
108    }
109}
110
111/// Creates a new JWT token with the specified expiration time.
112///
113/// # Errors
114///
115/// Returns an error if the JWT key pair cannot be generated or the JWT token cannot be signed.
116pub fn make_jwt(expires_in: Duration) -> Result<String, JwtParseError> {
117    // `iat` and `exp` must be *seconds* since the epoch for JWTs.
118    let issued_at = current_ts().as_secs();
119
120    let exp = issued_at
121        .checked_add(expires_in.as_secs())
122        .ok_or(JwtParseError::InvalidJwt)?;
123
124    let claims = Claims {
125        iat: issued_at,
126        exp,
127    };
128
129    // Build an explicit header so we can keep the `"kid": "secret"` you had.
130    let mut header = Header::new(Algorithm::HS256);
131    header.kid = Some("secret".to_owned());
132
133    encode(&header, &claims, &EncodingKey::from_secret(SECRET))
134        .map_err(|_err| JwtParseError::InvalidJwt)
135}
136
/// Returns the time elapsed since the Unix epoch, according to the system clock.
///
/// This never panics: if the system clock reports a time before the epoch,
/// [`Duration::ZERO`] is returned instead.
fn current_ts() -> Duration {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
}
147
/// Shared HMAC key used both to sign tokens in [`make_jwt`] and to verify them in [`parse_jwt`].
const SECRET: &[u8] = "SECRET".as_bytes();
149
150/// Same struct we used for encoding; now we just add `Deserialize`.
151#[derive(Debug, Serialize, Deserialize)]
152pub struct Claims {
153    iat: u64,
154    exp: u64,
155}
156
157/// Parse JWT
158///
159/// # Errors
160/// if the JWT cannot be parsed or the claims are invalid
161pub fn parse_jwt(token: &str) -> Result<Claims, JwtParseError> {
162    // Accept only HS256 and require exp to be in the future.
163    let mut validation = Validation::new(Algorithm::HS256);
164    validation.required_spec_claims = ["exp".to_owned(), "iat".to_owned()].into_iter().collect();
165
166    // Perform the decode + signature check
167    let data = decode::<Claims>(token, &DecodingKey::from_secret(SECRET), &validation)
168        .map_err(|_err| JwtParseError::InvalidJwt)?;
169
170    // Optional defense-in-depth: ensure the kid is what we expect.
171    if data.header.kid.as_deref() != Some("secret") {
172        return Err(JwtParseError::InvalidJwt);
173    }
174
175    Ok(data.claims)
176}
177
/// Errors produced while creating or validating the mock JWTs in this crate.
#[derive(Debug, thiserror::Error)]
pub enum JwtParseError {
    /// A base64 decoding error, convertible via `From<base64::DecodeError>`.
    ///
    /// NOTE(review): no function in this module appears to construct this variant
    /// (other than through the `From` impl) — confirm whether it is still needed.
    #[error("Base64 decode error: {0}")]
    Base64Decode(#[from] base64::DecodeError),

    /// The token could not be signed, decoded, or validated (bad signature, wrong
    /// algorithm, missing or expired claims, unexpected `kid`), or its expiry
    /// timestamp overflowed.
    #[error("Invalid JWT")]
    InvalidJwt,

    /// A `simd_json` error, convertible via `From<simd_json::Error>`.
    ///
    /// NOTE(review): no function in this module appears to construct this variant
    /// (JSON serialization failures in the mock server fall back to `"{}"`) —
    /// confirm whether it is still needed.
    #[error("JSON parse error: {0}")]
    JsonParse(#[from] simd_json::Error),
}