Skip to main content

hyperstack_sdk/
auth.rs

1use crate::error::{AuthErrorCode, HyperStackError};
2use base64::Engine as _;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use url::Url;
10
/// Number of seconds subtracted from a token's expiry to obtain the moment at
/// which it is treated as expiring (see `token_is_expiring` and
/// `token_refresh_delay`).
pub const TOKEN_REFRESH_BUFFER_SECONDS: u64 = 60;
/// Lower bound, in seconds, on the delay returned by `token_refresh_delay`.
pub const MIN_REFRESH_DELAY_SECONDS: u64 = 1;
/// Name of the query parameter under which `build_websocket_url` appends the
/// token.
pub const DEFAULT_QUERY_PARAMETER: &str = "hs_token";
/// Token endpoint chosen by `AuthConfig::resolve_strategy` when only a
/// publishable key is configured and the websocket URL is a hosted one.
pub const DEFAULT_HOSTED_TOKEN_ENDPOINT: &str = "https://api.usehyperstack.com/ws/sessions";
/// Host suffix that `is_hosted_hyperstack_websocket_url` uses to recognise a
/// hosted websocket URL.
pub const HOSTED_WEBSOCKET_SUFFIX: &str = ".stack.usehyperstack.com";
16
/// A token string paired with an optional expiry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthToken {
    /// The raw token value.
    pub token: String,
    /// Expiry of the token, if known.
    // NOTE(review): presumably Unix epoch seconds, matching the
    // `now_epoch_seconds` parameter of the refresh helpers — confirm.
    pub expires_at: Option<u64>,
}

impl AuthToken {
    /// Creates a token with no known expiry.
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        AuthToken {
            token,
            expires_at: None,
        }
    }

    /// Returns this token with its expiry set to `expires_at`.
    pub fn with_expiry(self, expires_at: u64) -> Self {
        AuthToken {
            expires_at: Some(expires_at),
            ..self
        }
    }
}

impl From<String> for AuthToken {
    /// Wraps an owned string as a token with no known expiry.
    fn from(value: String) -> Self {
        AuthToken {
            token: value,
            expires_at: None,
        }
    }
}

impl From<&str> for AuthToken {
    /// Copies a borrowed string into a token with no known expiry.
    fn from(value: &str) -> Self {
        AuthToken::from(value.to_owned())
    }
}
48
/// Boxed, pinned, `Send` future that resolves to an [`AuthToken`] or a
/// [`HyperStackError`].
pub type TokenProviderFuture =
    Pin<Box<dyn Future<Output = Result<AuthToken, HyperStackError>> + Send>>;
/// Unsized `Send + Sync` callable that produces a [`TokenProviderFuture`] each
/// time it is invoked; stored behind an `Arc` in [`AuthConfig`].
pub type TokenProvider = dyn Fn() -> TokenProviderFuture + Send + Sync;
52
/// How the auth token is attached to a websocket connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenTransport {
    /// The token is appended to the websocket URL as a query parameter
    /// (see `build_websocket_url`). This is the default.
    #[default]
    QueryParameter,
    /// `build_websocket_url` leaves the URL untouched for this transport.
    // NOTE(review): presumably the token is then sent as a bearer
    // `Authorization` header by the connection code — confirm at the call site.
    Bearer,
}
59
/// Authentication settings assembled through the `with_*` builder methods.
///
/// `resolve_strategy` decides which of the configured sources is used.
#[derive(Clone, Default)]
pub struct AuthConfig {
    /// Static token; takes precedence over every other source.
    pub(crate) token: Option<String>,
    /// Callback that yields a token on demand.
    pub(crate) get_token: Option<Arc<TokenProvider>>,
    /// Explicitly configured token endpoint.
    pub(crate) token_endpoint: Option<String>,
    /// Publishable key; its presence enables the default hosted token endpoint
    /// for hosted websocket URLs. Masked in the `Debug` output.
    pub(crate) publishable_key: Option<String>,
    /// Header name/value pairs collected via `with_token_endpoint_header`.
    // NOTE(review): presumably sent with requests to the token endpoint —
    // the consumer is not in this file; confirm.
    pub(crate) token_endpoint_headers: HashMap<String, String>,
    /// Selected token transport (defaults to `TokenTransport::QueryParameter`).
    pub(crate) token_transport: TokenTransport,
}
69
70impl fmt::Debug for AuthConfig {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        f.debug_struct("AuthConfig")
73            .field("has_token", &self.token.is_some())
74            .field("has_get_token", &self.get_token.is_some())
75            .field("token_endpoint", &self.token_endpoint)
76            .field(
77                "publishable_key",
78                &self.publishable_key.as_ref().map(|_| "***"),
79            )
80            .field(
81                "token_endpoint_headers",
82                &self.token_endpoint_headers.keys().collect::<Vec<_>>(),
83            )
84            .field("token_transport", &self.token_transport)
85            .finish()
86    }
87}
88
89impl AuthConfig {
90    pub fn with_token(mut self, token: impl Into<String>) -> Self {
91        self.token = Some(token.into());
92        self
93    }
94
95    pub fn with_publishable_key(mut self, publishable_key: impl Into<String>) -> Self {
96        self.publishable_key = Some(publishable_key.into());
97        self
98    }
99
100    pub fn with_token_endpoint(mut self, token_endpoint: impl Into<String>) -> Self {
101        self.token_endpoint = Some(token_endpoint.into());
102        self
103    }
104
105    pub fn with_token_endpoint_header(
106        mut self,
107        key: impl Into<String>,
108        value: impl Into<String>,
109    ) -> Self {
110        self.token_endpoint_headers.insert(key.into(), value.into());
111        self
112    }
113
114    pub fn with_token_transport(mut self, transport: TokenTransport) -> Self {
115        self.token_transport = transport;
116        self
117    }
118
119    pub fn with_token_provider<F, Fut>(mut self, provider: F) -> Self
120    where
121        F: Fn() -> Fut + Send + Sync + 'static,
122        Fut: Future<Output = Result<AuthToken, HyperStackError>> + Send + 'static,
123    {
124        self.get_token = Some(Arc::new(move || Box::pin(provider())));
125        self
126    }
127
128    pub(crate) fn resolve_strategy(&self, websocket_url: &str) -> ResolvedAuthStrategy {
129        if let Some(token) = self.token.clone() {
130            return ResolvedAuthStrategy::StaticToken(token);
131        }
132
133        if let Some(get_token) = self.get_token.clone() {
134            return ResolvedAuthStrategy::TokenProvider(get_token);
135        }
136
137        if let Some(token_endpoint) = self.token_endpoint.clone() {
138            return ResolvedAuthStrategy::TokenEndpoint(token_endpoint);
139        }
140
141        if self.publishable_key.is_some() && is_hosted_hyperstack_websocket_url(websocket_url) {
142            return ResolvedAuthStrategy::TokenEndpoint(DEFAULT_HOSTED_TOKEN_ENDPOINT.to_string());
143        }
144
145        ResolvedAuthStrategy::None
146    }
147
148    pub(crate) fn has_refreshable_auth(&self, websocket_url: &str) -> bool {
149        matches!(
150            self.resolve_strategy(websocket_url),
151            ResolvedAuthStrategy::TokenProvider(_) | ResolvedAuthStrategy::TokenEndpoint(_)
152        )
153    }
154}
155
/// Outcome of `AuthConfig::resolve_strategy`: the single auth source to use.
#[derive(Clone)]
pub(crate) enum ResolvedAuthStrategy {
    /// No usable auth source is configured.
    None,
    /// Use this fixed token.
    StaticToken(String),
    /// Call this provider to obtain a token.
    TokenProvider(Arc<TokenProvider>),
    /// Use the token endpoint held in this string (either the configured one
    /// or `DEFAULT_HOSTED_TOKEN_ENDPOINT`).
    TokenEndpoint(String),
}
163
/// Deserialized token endpoint response.
///
/// The expiry is accepted under either the `expires_at` or the `expiresAt`
/// key; both default to `None` when absent.
#[derive(Debug, Deserialize)]
pub(crate) struct TokenEndpointResponse {
    /// The issued token.
    pub token: String,
    /// Expiry read from the snake_case `expires_at` key.
    #[serde(default)]
    pub expires_at: Option<u64>,
    /// Expiry read from the camelCase `expiresAt` key.
    #[serde(default, rename = "expiresAt")]
    pub expires_at_camel: Option<u64>,
}
172
173impl TokenEndpointResponse {
174    pub fn into_auth_token(self) -> AuthToken {
175        AuthToken {
176            token: self.token,
177            expires_at: self.expires_at.or(self.expires_at_camel),
178        }
179    }
180}
181
/// Serializable request payload carrying the websocket URL.
// NOTE(review): presumably sent as the body of the token endpoint request —
// the sender is not in this file; confirm.
#[derive(Debug, Serialize)]
pub(crate) struct TokenEndpointRequest<'a> {
    /// The websocket URL, serialized under the key `websocket_url`.
    pub websocket_url: &'a str,
}
186
187pub(crate) fn parse_jwt_expiry(token: &str) -> Option<u64> {
188    let mut parts = token.split('.');
189    let _header = parts.next()?;
190    let payload = parts.next()?;
191    let _signature = parts.next()?;
192
193    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
194        .decode(payload.as_bytes())
195        .ok()?;
196    let payload: JwtPayload = serde_json::from_slice(&decoded).ok()?;
197    payload.exp
198}
199
200pub(crate) fn token_is_expiring(expires_at: Option<u64>, now_epoch_seconds: u64) -> bool {
201    match expires_at {
202        Some(exp) => now_epoch_seconds >= exp.saturating_sub(TOKEN_REFRESH_BUFFER_SECONDS),
203        None => false,
204    }
205}
206
207pub(crate) fn token_refresh_delay(expires_at: Option<u64>, now_epoch_seconds: u64) -> Option<u64> {
208    let expires_at = expires_at?;
209    let refresh_at = expires_at.saturating_sub(TOKEN_REFRESH_BUFFER_SECONDS);
210    Some(
211        refresh_at
212            .saturating_sub(now_epoch_seconds)
213            .max(MIN_REFRESH_DELAY_SECONDS),
214    )
215}
216
217pub(crate) fn is_hosted_hyperstack_websocket_url(websocket_url: &str) -> bool {
218    Url::parse(websocket_url)
219        .ok()
220        .and_then(|url| url.host_str().map(str::to_ascii_lowercase))
221        .is_some_and(|host| host.ends_with(HOSTED_WEBSOCKET_SUFFIX))
222}
223
224pub(crate) fn build_websocket_url(
225    websocket_url: &str,
226    token: Option<&str>,
227    transport: TokenTransport,
228) -> Result<String, HyperStackError> {
229    if transport == TokenTransport::Bearer || token.is_none() {
230        return Ok(websocket_url.to_string());
231    }
232
233    let mut url = Url::parse(websocket_url)
234        .map_err(|error| HyperStackError::ConnectionFailed(error.to_string()))?;
235    url.query_pairs_mut()
236        .append_pair(DEFAULT_QUERY_PARAMETER, token.expect("checked is_some"));
237    Ok(url.to_string())
238}
239
240pub(crate) fn hosted_auth_required_error() -> HyperStackError {
241    HyperStackError::WebSocket {
242        message: "Hosted Hyperstack websocket connections require auth.publishable_key, auth.get_token, auth.token_endpoint, or auth.token".to_string(),
243        code: Some(AuthErrorCode::AuthRequired),
244    }
245}
246
/// The subset of JWT payload claims read by `parse_jwt_expiry`.
#[derive(Debug, Deserialize)]
struct JwtPayload {
    /// The `exp` claim; `None` when the payload does not contain it.
    exp: Option<u64>,
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `input` as unpadded base64url, the form `parse_jwt_expiry` decodes.
    fn encode_base64url(input: &str) -> String {
        let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
        engine.encode(input.as_bytes())
    }

    #[test]
    fn publishable_key_on_hosted_url_uses_default_token_endpoint() {
        let config = AuthConfig::default().with_publishable_key("hspk_test");

        match config.resolve_strategy("wss://demo.stack.usehyperstack.com") {
            ResolvedAuthStrategy::TokenEndpoint(endpoint) => {
                assert_eq!(endpoint, DEFAULT_HOSTED_TOKEN_ENDPOINT);
            }
            _ => panic!("expected the default hosted token endpoint strategy"),
        }
    }

    #[test]
    fn static_token_takes_precedence_over_endpoint_flow() {
        let config = AuthConfig::default()
            .with_publishable_key("hspk_test")
            .with_token_endpoint("https://custom.example/ws/sessions")
            .with_token("static-token");

        match config.resolve_strategy("wss://demo.stack.usehyperstack.com") {
            ResolvedAuthStrategy::StaticToken(token) => assert_eq!(token, "static-token"),
            _ => panic!("expected the static token strategy"),
        }
    }

    #[test]
    fn build_websocket_url_adds_query_token_for_query_transport() {
        let built = build_websocket_url(
            "wss://demo.stack.usehyperstack.com/socket",
            Some("abc123"),
            TokenTransport::QueryParameter,
        );
        let built = built.expect("query auth url should build");

        assert!(built.contains("hs_token=abc123"));
    }

    #[test]
    fn parse_jwt_expiry_reads_exp_claim() {
        let segments = [
            encode_base64url(r#"{"alg":"none","typ":"JWT"}"#),
            encode_base64url(r#"{"exp":12345}"#),
            "sig".to_string(),
        ];
        let token = segments.join(".");

        assert_eq!(parse_jwt_expiry(&token), Some(12345));
    }

    #[test]
    fn token_refresh_delay_respects_refresh_buffer() {
        let now = 1_000;

        // Expiry far enough out that 15 seconds remain before the refresh point.
        let outside_window = now + TOKEN_REFRESH_BUFFER_SECONDS + 15;
        assert_eq!(token_refresh_delay(Some(outside_window), now), Some(15));

        // Expiry already inside the refresh buffer: the delay is clamped to the minimum.
        let inside_window = now + 10;
        assert_eq!(
            token_refresh_delay(Some(inside_window), now),
            Some(MIN_REFRESH_DELAY_SECONDS)
        );
    }
}