1use crate::error::{AuthErrorCode, HyperStackError};
2use base64::Engine as _;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use url::Url;
10
/// Seconds before a token's expiry at which it is already treated as expiring
/// (used by `token_is_expiring` and `token_refresh_delay`).
pub const TOKEN_REFRESH_BUFFER_SECONDS: u64 = 60;
/// Lower bound, in seconds, applied to every delay returned by `token_refresh_delay`.
pub const MIN_REFRESH_DELAY_SECONDS: u64 = 1;
/// Query-string parameter name that carries the token for [`TokenTransport::QueryParameter`].
pub const DEFAULT_QUERY_PARAMETER: &str = "hs_token";
/// Token endpoint used when only a publishable key is configured and the websocket URL
/// is a hosted Hyperstack URL.
pub const DEFAULT_HOSTED_TOKEN_ENDPOINT: &str = "https://api.usehyperstack.com/ws/sessions";
/// Host suffix (compared against the lowercased host) that identifies hosted Hyperstack
/// websocket URLs.
pub const HOSTED_WEBSOCKET_SUFFIX: &str = ".stack.usehyperstack.com";
16
/// A token string plus an optional expiry time.
///
/// An `expires_at` of `None` means the expiry is unknown. The refresh helpers in this
/// module (`token_is_expiring`, `token_refresh_delay`) treat an unknown expiry as
/// "never refresh proactively".
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthToken {
    /// The raw token string.
    pub token: String,
    /// Expiry time, presumably Unix epoch seconds (the refresh helpers compare it against
    /// `now_epoch_seconds`) — confirm with token issuers.
    pub expires_at: Option<u64>,
}

impl AuthToken {
    /// Creates a token with no known expiry.
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self {
            token,
            expires_at: None,
        }
    }

    /// Returns this token with its expiry set to `expires_at`, replacing any previous value.
    pub fn with_expiry(self, expires_at: u64) -> Self {
        Self {
            expires_at: Some(expires_at),
            ..self
        }
    }
}

impl From<String> for AuthToken {
    fn from(token: String) -> Self {
        AuthToken::new(token)
    }
}

impl From<&str> for AuthToken {
    fn from(token: &str) -> Self {
        AuthToken::new(token)
    }
}
48
/// Boxed, `Send` future returned by a [`TokenProvider`], resolving to an [`AuthToken`]
/// or a [`HyperStackError`].
pub type TokenProviderFuture =
    Pin<Box<dyn Future<Output = Result<AuthToken, HyperStackError>> + Send>>;
/// Type-erased async token callback, as stored by [`AuthConfig::with_token_provider`].
pub type TokenProvider = dyn Fn() -> TokenProviderFuture + Send + Sync;

/// How the resolved token is attached to the websocket connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenTransport {
    /// Append the token to the websocket URL as the `hs_token` query parameter (default).
    #[default]
    QueryParameter,
    /// Leave the websocket URL untouched.
    // NOTE(review): presumably the token is then sent as an `Authorization: Bearer` header
    // by the connection code — confirm at the call site.
    Bearer,
}
59
/// Authentication settings for a websocket connection.
///
/// Built with the `with_*` methods. When several sources are configured,
/// `AuthConfig::resolve_strategy` picks exactly one, in this order: static token,
/// token provider, token endpoint, then the publishable-key default (hosted URLs only).
#[derive(Clone, Default)]
pub struct AuthConfig {
    /// Static token; wins over every other source and is never refreshed.
    pub(crate) token: Option<String>,
    /// Async callback that produces tokens.
    pub(crate) get_token: Option<Arc<TokenProvider>>,
    /// Custom token endpoint URL.
    pub(crate) token_endpoint: Option<String>,
    /// Publishable key; on its own it only selects the default hosted token endpoint
    /// for hosted websocket URLs.
    pub(crate) publishable_key: Option<String>,
    /// Extra headers, presumably sent with token endpoint requests — confirm at the call
    /// site. Only the header names appear in `Debug` output.
    pub(crate) token_endpoint_headers: HashMap<String, String>,
    /// How the token is attached to the websocket connection.
    pub(crate) token_transport: TokenTransport,
}
69
70impl fmt::Debug for AuthConfig {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 f.debug_struct("AuthConfig")
73 .field("has_token", &self.token.is_some())
74 .field("has_get_token", &self.get_token.is_some())
75 .field("token_endpoint", &self.token_endpoint)
76 .field(
77 "publishable_key",
78 &self.publishable_key.as_ref().map(|_| "***"),
79 )
80 .field(
81 "token_endpoint_headers",
82 &self.token_endpoint_headers.keys().collect::<Vec<_>>(),
83 )
84 .field("token_transport", &self.token_transport)
85 .finish()
86 }
87}
88
89impl AuthConfig {
90 pub fn with_token(mut self, token: impl Into<String>) -> Self {
91 self.token = Some(token.into());
92 self
93 }
94
95 pub fn with_publishable_key(mut self, publishable_key: impl Into<String>) -> Self {
96 self.publishable_key = Some(publishable_key.into());
97 self
98 }
99
100 pub fn with_token_endpoint(mut self, token_endpoint: impl Into<String>) -> Self {
101 self.token_endpoint = Some(token_endpoint.into());
102 self
103 }
104
105 pub fn with_token_endpoint_header(
106 mut self,
107 key: impl Into<String>,
108 value: impl Into<String>,
109 ) -> Self {
110 self.token_endpoint_headers.insert(key.into(), value.into());
111 self
112 }
113
114 pub fn with_token_transport(mut self, transport: TokenTransport) -> Self {
115 self.token_transport = transport;
116 self
117 }
118
119 pub fn with_token_provider<F, Fut>(mut self, provider: F) -> Self
120 where
121 F: Fn() -> Fut + Send + Sync + 'static,
122 Fut: Future<Output = Result<AuthToken, HyperStackError>> + Send + 'static,
123 {
124 self.get_token = Some(Arc::new(move || Box::pin(provider())));
125 self
126 }
127
128 pub(crate) fn resolve_strategy(&self, websocket_url: &str) -> ResolvedAuthStrategy {
129 if let Some(token) = self.token.clone() {
130 return ResolvedAuthStrategy::StaticToken(token);
131 }
132
133 if let Some(get_token) = self.get_token.clone() {
134 return ResolvedAuthStrategy::TokenProvider(get_token);
135 }
136
137 if let Some(token_endpoint) = self.token_endpoint.clone() {
138 return ResolvedAuthStrategy::TokenEndpoint(token_endpoint);
139 }
140
141 if self.publishable_key.is_some() && is_hosted_hyperstack_websocket_url(websocket_url) {
142 return ResolvedAuthStrategy::TokenEndpoint(DEFAULT_HOSTED_TOKEN_ENDPOINT.to_string());
143 }
144
145 ResolvedAuthStrategy::None
146 }
147
148 pub(crate) fn has_refreshable_auth(&self, websocket_url: &str) -> bool {
149 matches!(
150 self.resolve_strategy(websocket_url),
151 ResolvedAuthStrategy::TokenProvider(_) | ResolvedAuthStrategy::TokenEndpoint(_)
152 )
153 }
154}
155
/// The single auth mechanism chosen by `AuthConfig::resolve_strategy`.
#[derive(Clone)]
pub(crate) enum ResolvedAuthStrategy {
    /// No usable credentials. This includes a publishable key paired with a non-hosted URL.
    None,
    /// A fixed token; `AuthConfig::has_refreshable_auth` reports it as not refreshable.
    StaticToken(String),
    /// User-supplied async callback that produces tokens.
    TokenProvider(Arc<TokenProvider>),
    /// Token endpoint URL: either the custom endpoint or `DEFAULT_HOSTED_TOKEN_ENDPOINT`.
    TokenEndpoint(String),
}
163
164#[derive(Debug, Deserialize)]
165pub(crate) struct TokenEndpointResponse {
166 pub token: String,
167 #[serde(default)]
168 pub expires_at: Option<u64>,
169 #[serde(default, rename = "expiresAt")]
170 pub expires_at_camel: Option<u64>,
171}
172
173impl TokenEndpointResponse {
174 pub fn into_auth_token(self) -> AuthToken {
175 AuthToken {
176 token: self.token,
177 expires_at: self.expires_at.or(self.expires_at_camel),
178 }
179 }
180}
181
/// Request payload for a token endpoint.
// NOTE(review): presumably serialized as JSON, i.e. `{"websocket_url": "..."}` — confirm
// at the call site.
#[derive(Debug, Serialize)]
pub(crate) struct TokenEndpointRequest<'a> {
    /// Websocket URL the requested token is for.
    pub websocket_url: &'a str,
}
186
187pub(crate) fn parse_jwt_expiry(token: &str) -> Option<u64> {
188 let mut parts = token.split('.');
189 let _header = parts.next()?;
190 let payload = parts.next()?;
191 let _signature = parts.next()?;
192
193 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
194 .decode(payload.as_bytes())
195 .ok()?;
196 let payload: JwtPayload = serde_json::from_slice(&decoded).ok()?;
197 payload.exp
198}
199
200pub(crate) fn token_is_expiring(expires_at: Option<u64>, now_epoch_seconds: u64) -> bool {
201 match expires_at {
202 Some(exp) => now_epoch_seconds >= exp.saturating_sub(TOKEN_REFRESH_BUFFER_SECONDS),
203 None => false,
204 }
205}
206
207pub(crate) fn token_refresh_delay(expires_at: Option<u64>, now_epoch_seconds: u64) -> Option<u64> {
208 let expires_at = expires_at?;
209 let refresh_at = expires_at.saturating_sub(TOKEN_REFRESH_BUFFER_SECONDS);
210 Some(
211 refresh_at
212 .saturating_sub(now_epoch_seconds)
213 .max(MIN_REFRESH_DELAY_SECONDS),
214 )
215}
216
217pub(crate) fn is_hosted_hyperstack_websocket_url(websocket_url: &str) -> bool {
218 Url::parse(websocket_url)
219 .ok()
220 .and_then(|url| url.host_str().map(str::to_ascii_lowercase))
221 .is_some_and(|host| host.ends_with(HOSTED_WEBSOCKET_SUFFIX))
222}
223
224pub(crate) fn build_websocket_url(
225 websocket_url: &str,
226 token: Option<&str>,
227 transport: TokenTransport,
228) -> Result<String, HyperStackError> {
229 if transport == TokenTransport::Bearer || token.is_none() {
230 return Ok(websocket_url.to_string());
231 }
232
233 let mut url = Url::parse(websocket_url)
234 .map_err(|error| HyperStackError::ConnectionFailed(error.to_string()))?;
235 url.query_pairs_mut()
236 .append_pair(DEFAULT_QUERY_PARAMETER, token.expect("checked is_some"));
237 Ok(url.to_string())
238}
239
240pub(crate) fn hosted_auth_required_error() -> HyperStackError {
241 HyperStackError::WebSocket {
242 message: "Hosted Hyperstack websocket connections require auth.publishable_key, auth.get_token, auth.token_endpoint, or auth.token".to_string(),
243 code: Some(AuthErrorCode::AuthRequired),
244 }
245}
246
247#[derive(Debug, Deserialize)]
248struct JwtPayload {
249 exp: Option<u64>,
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Hosted websocket URL shared by the strategy-resolution tests.
    const HOSTED_URL: &str = "wss://demo.stack.usehyperstack.com";

    /// Unpadded base64url encoding, as used for JWT segments.
    fn encode_base64url(raw: &str) -> String {
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        engine.encode(raw.as_bytes())
    }

    #[test]
    fn publishable_key_on_hosted_url_uses_default_token_endpoint() {
        let strategy = AuthConfig::default()
            .with_publishable_key("hspk_test")
            .resolve_strategy(HOSTED_URL);

        let ResolvedAuthStrategy::TokenEndpoint(endpoint) = strategy else {
            panic!("publishable key on a hosted URL should resolve to a token endpoint");
        };
        assert_eq!(endpoint, DEFAULT_HOSTED_TOKEN_ENDPOINT);
    }

    #[test]
    fn static_token_takes_precedence_over_endpoint_flow() {
        let auth = AuthConfig::default()
            .with_publishable_key("hspk_test")
            .with_token_endpoint("https://custom.example/ws/sessions")
            .with_token("static-token");

        match auth.resolve_strategy(HOSTED_URL) {
            ResolvedAuthStrategy::StaticToken(token) => assert_eq!(token, "static-token"),
            _ => panic!("a static token must win over endpoint-based strategies"),
        }
    }

    #[test]
    fn build_websocket_url_adds_query_token_for_query_transport() {
        let built = build_websocket_url(
            "wss://demo.stack.usehyperstack.com/socket",
            Some("abc123"),
            TokenTransport::QueryParameter,
        );
        let url = built.expect("query auth url should build");

        assert!(url.contains("hs_token=abc123"), "token missing from {url}");
    }

    #[test]
    fn parse_jwt_expiry_reads_exp_claim() {
        let token = [
            encode_base64url(r#"{"alg":"none","typ":"JWT"}"#),
            encode_base64url(r#"{"exp":12345}"#),
            String::from("sig"),
        ]
        .join(".");

        assert_eq!(parse_jwt_expiry(&token), Some(12345));
    }

    #[test]
    fn token_refresh_delay_respects_refresh_buffer() {
        let now = 1_000;
        let comfortably_valid = Some(now + TOKEN_REFRESH_BUFFER_SECONDS + 15);
        let inside_buffer = Some(now + 10);

        assert_eq!(token_refresh_delay(comfortably_valid, now), Some(15));
        assert_eq!(
            token_refresh_delay(inside_buffer, now),
            Some(MIN_REFRESH_DELAY_SECONDS)
        );
    }
}