1use core::net::SocketAddr;
2use core::time::Duration;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
6use jsonwebtoken::{DecodingKey, Validation, decode};
7pub use mockito;
8use mockito::{Matcher, ServerGuard};
9use serde::{Deserialize, Serialize};
10use simd_json::json;
11
/// A mockito-backed stand-in for the Supabase HTTP API.
///
/// Every mock created through the `register_jwt*` methods is pushed into
/// `api_mock`, so the `mockito::Mock` handles live as long as this value.
// NOTE(review): holding the handles presumably keeps mockito from discarding
// the mocks early — confirm against mockito's `Mock` drop semantics.
pub struct SupabaseMockServer {
    /// Handles for every mock registered on `mockito_server`.
    // Declared before `mockito_server` on purpose: Rust drops struct fields in
    // declaration order, so the mocks are dropped before the server guard.
    pub api_mock: Vec<mockito::Mock>,
    /// The underlying mockito server serving the mocked endpoints.
    pub mockito_server: ServerGuard,
}
16
17impl SupabaseMockServer {
18 #[must_use]
19 pub async fn new() -> Self {
20 let server = mockito::Server::new_async().await;
21 Self {
22 mockito_server: server,
23 api_mock: vec![],
24 }
25 }
26
27 #[must_use]
28 pub fn server_address(&self) -> SocketAddr {
29 self.mockito_server.socket_address()
30 }
31
32 pub fn server_url(&self) -> Result<url::Url, url::ParseError> {
38 self.mockito_server.url().parse()
39 }
40
41 pub fn register_jwt(&mut self, jwt: &str) -> Result<&mut Self, JwtParseError> {
47 self.register_jwt_password(jwt)?.register_jwt_refresh(jwt)
48 }
49
50 pub fn register_jwt_password(&mut self, jwt: &str) -> Result<&mut Self, JwtParseError> {
56 let parsed_jwt = parse_jwt(jwt)?;
57 let current_ts = current_ts();
58 let expires_at = parsed_jwt.exp;
59 let expires_in = expires_at.abs_diff(current_ts.as_secs());
60 self.register_jwt_custom_grant_type(jwt, "password", Duration::from_millis(expires_in));
61 Ok(self)
62 }
63
64 pub fn register_jwt_refresh(&mut self, jwt: &str) -> Result<&mut Self, JwtParseError> {
70 let parsed_jwt = parse_jwt(jwt)?;
71 let current_ts = current_ts();
72 let expires_at = parsed_jwt.exp;
73 let expires_in = expires_at.abs_diff(current_ts.as_secs());
74 self.register_jwt_custom_grant_type(
75 jwt,
76 "refresh_token",
77 Duration::from_millis(expires_in),
78 );
79 Ok(self)
80 }
81
82 fn register_jwt_custom_grant_type(
83 &mut self,
84 jwt: &str,
85 grant_type: &str,
86 expires_in: Duration,
87 ) {
88 let body = json!({
89 "access_token": jwt,
90 "refresh_token": "some-refresh-token",
91 "expires_in": expires_in.as_secs(),
92 "token_type": "bearer",
93 "user": {
94 "id": "user-id",
95 "email": "user@example.com"
96 }
97 });
98 let body = simd_json::to_string(&body).unwrap_or_else(|_| "{}".to_owned());
99 let mock = self
100 .mockito_server
101 .mock("POST", "/auth/v1/token")
102 .match_query(Matcher::Regex(format!("grant_type={grant_type}")))
103 .with_status(200)
104 .with_header("content-type", "application/json")
105 .with_body(body)
106 .create();
107 self.api_mock.push(mock);
108 }
109}
110
111pub fn make_jwt(expires_in: Duration) -> Result<String, JwtParseError> {
117 let issued_at = current_ts().as_secs();
119
120 let exp = issued_at
121 .checked_add(expires_in.as_secs())
122 .ok_or(JwtParseError::InvalidJwt)?;
123
124 let claims = Claims {
125 iat: issued_at,
126 exp,
127 };
128
129 let mut header = Header::new(Algorithm::HS256);
131 header.kid = Some("secret".to_owned());
132
133 encode(&header, &claims, &EncodingKey::from_secret(SECRET))
134 .map_err(|_err| JwtParseError::InvalidJwt)
135}
136
/// Wall-clock time elapsed since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, this returns a zero
/// duration instead of failing.
fn current_ts() -> Duration {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    }
}
147
/// Hard-coded HMAC key used both to sign tokens in `make_jwt` and to verify
/// them in `parse_jwt`. It is a fixed literal, so it offers no secrecy.
const SECRET: &[u8] = b"SECRET";
149
150#[derive(Debug, Serialize, Deserialize)]
152pub struct Claims {
153 iat: u64,
154 exp: u64,
155}
156
157pub fn parse_jwt(token: &str) -> Result<Claims, JwtParseError> {
162 let mut validation = Validation::new(Algorithm::HS256);
164 validation.required_spec_claims = ["exp".to_owned(), "iat".to_owned()].into_iter().collect();
165
166 let data = decode::<Claims>(token, &DecodingKey::from_secret(SECRET), &validation)
168 .map_err(|_err| JwtParseError::InvalidJwt)?;
169
170 if data.header.kid.as_deref() != Some("secret") {
172 return Err(JwtParseError::InvalidJwt);
173 }
174
175 Ok(data.claims)
176}
177
/// Errors raised while creating or validating the mock JWTs.
#[derive(Debug, thiserror::Error)]
pub enum JwtParseError {
    /// Base64 decoding failed.
    // NOTE(review): no function in this file constructs this variant; it may be
    // kept for API compatibility — confirm before removing.
    #[error("Base64 decode error: {0}")]
    Base64Decode(#[from] base64::DecodeError),

    /// The token could not be created, decoded, or validated. This is returned
    /// by `make_jwt` (expiry overflow or signing failure) and by `parse_jwt`
    /// (decode/validation failure or unexpected `kid`). The underlying cause is
    /// discarded.
    #[error("Invalid JWT")]
    InvalidJwt,

    /// JSON handling via `simd_json` failed.
    // NOTE(review): no function in this file constructs this variant; it may be
    // kept for API compatibility — confirm before removing.
    #[error("JSON parse error: {0}")]
    JsonParse(#[from] simd_json::Error),
}