1use serde::Deserialize;
2use thiserror::Error;
3use tokio_tungstenite::tungstenite::{self, http::Response};
4
5#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct SocketIssue {
7 pub error: String,
8 pub message: String,
9 pub code: Option<AuthErrorCode>,
10 pub retryable: bool,
11 pub retry_after: Option<u64>,
12 pub suggested_action: Option<String>,
13 pub docs_url: Option<String>,
14 pub fatal: bool,
15}
16
17impl std::fmt::Display for SocketIssue {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.write_str(&self.message)
20 }
21}
22
23#[derive(Debug, Deserialize)]
24pub(crate) struct SocketIssuePayload {
25 #[serde(rename = "type")]
26 pub kind: String,
27 pub error: String,
28 pub message: String,
29 pub code: String,
30 pub retryable: bool,
31 #[serde(default)]
32 pub retry_after: Option<u64>,
33 #[serde(default)]
34 pub suggested_action: Option<String>,
35 #[serde(default)]
36 pub docs_url: Option<String>,
37 pub fatal: bool,
38}
39
40impl SocketIssuePayload {
41 pub fn is_socket_issue(&self) -> bool {
42 self.kind == "error"
43 }
44
45 pub fn into_socket_issue(self) -> SocketIssue {
46 SocketIssue {
47 error: self.error,
48 message: self.message,
49 code: AuthErrorCode::from_wire(&self.code),
50 retryable: self.retryable,
51 retry_after: self.retry_after,
52 suggested_action: self.suggested_action,
53 docs_url: self.docs_url,
54 fatal: self.fatal,
55 }
56 }
57}
58
/// Machine-readable error codes reported by the platform for token, origin,
/// API-key, access, rate-limit, quota and internal failures.
///
/// Every variant has a kebab-case wire form (`as_wire`); `from_wire` is the
/// reverse mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthErrorCode {
    TokenMissing,
    TokenExpired,
    TokenInvalidSignature,
    TokenInvalidFormat,
    TokenInvalidIssuer,
    TokenInvalidAudience,
    TokenMissingClaim,
    TokenKeyNotFound,
    OriginMismatch,
    OriginRequired,
    OriginNotAllowed,
    AuthRequired,
    MissingAuthorizationHeader,
    InvalidAuthorizationFormat,
    InvalidApiKey,
    ExpiredApiKey,
    UserNotFound,
    SecretKeyRequired,
    DeploymentAccessDenied,
    RateLimitExceeded,
    WebSocketSessionRateLimitExceeded,
    ConnectionLimitExceeded,
    SubscriptionLimitExceeded,
    SnapshotLimitExceeded,
    EgressLimitExceeded,
    QuotaExceeded,
    InvalidStaticToken,
    InternalError,
}

impl AuthErrorCode {
    /// Every variant, in declaration order.
    ///
    /// `from_wire` searches this table using `as_wire`, so the string mapping
    /// lives in exactly one place. Keep this list in sync when adding variants.
    const ALL: [Self; 28] = [
        Self::TokenMissing,
        Self::TokenExpired,
        Self::TokenInvalidSignature,
        Self::TokenInvalidFormat,
        Self::TokenInvalidIssuer,
        Self::TokenInvalidAudience,
        Self::TokenMissingClaim,
        Self::TokenKeyNotFound,
        Self::OriginMismatch,
        Self::OriginRequired,
        Self::OriginNotAllowed,
        Self::AuthRequired,
        Self::MissingAuthorizationHeader,
        Self::InvalidAuthorizationFormat,
        Self::InvalidApiKey,
        Self::ExpiredApiKey,
        Self::UserNotFound,
        Self::SecretKeyRequired,
        Self::DeploymentAccessDenied,
        Self::RateLimitExceeded,
        Self::WebSocketSessionRateLimitExceeded,
        Self::ConnectionLimitExceeded,
        Self::SubscriptionLimitExceeded,
        Self::SnapshotLimitExceeded,
        Self::EgressLimitExceeded,
        Self::QuotaExceeded,
        Self::InvalidStaticToken,
        Self::InternalError,
    ];

    /// Parses a wire code, ignoring surrounding whitespace and ASCII case.
    ///
    /// Returns `None` for codes this client does not recognise.
    pub fn from_wire(code: &str) -> Option<Self> {
        let normalized = code.trim().to_ascii_lowercase();
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_wire() == normalized)
    }

    /// Returns the canonical kebab-case wire form of this code.
    pub fn as_wire(self) -> &'static str {
        match self {
            Self::TokenMissing => "token-missing",
            Self::TokenExpired => "token-expired",
            Self::TokenInvalidSignature => "token-invalid-signature",
            Self::TokenInvalidFormat => "token-invalid-format",
            Self::TokenInvalidIssuer => "token-invalid-issuer",
            Self::TokenInvalidAudience => "token-invalid-audience",
            Self::TokenMissingClaim => "token-missing-claim",
            Self::TokenKeyNotFound => "token-key-not-found",
            Self::OriginMismatch => "origin-mismatch",
            Self::OriginRequired => "origin-required",
            Self::OriginNotAllowed => "origin-not-allowed",
            Self::AuthRequired => "auth-required",
            Self::MissingAuthorizationHeader => "missing-authorization-header",
            Self::InvalidAuthorizationFormat => "invalid-authorization-format",
            Self::InvalidApiKey => "invalid-api-key",
            Self::ExpiredApiKey => "expired-api-key",
            Self::UserNotFound => "user-not-found",
            Self::SecretKeyRequired => "secret-key-required",
            Self::DeploymentAccessDenied => "deployment-access-denied",
            Self::RateLimitExceeded => "rate-limit-exceeded",
            Self::WebSocketSessionRateLimitExceeded => "websocket-session-rate-limit-exceeded",
            Self::ConnectionLimitExceeded => "connection-limit-exceeded",
            Self::SubscriptionLimitExceeded => "subscription-limit-exceeded",
            Self::SnapshotLimitExceeded => "snapshot-limit-exceeded",
            Self::EgressLimitExceeded => "egress-limit-exceeded",
            Self::QuotaExceeded => "quota-exceeded",
            Self::InvalidStaticToken => "invalid-static-token",
            Self::InternalError => "internal-error",
        }
    }

    /// Returns `true` only for `InternalError`, the single code treated as
    /// transient; every other code is considered non-retryable.
    pub fn should_retry(self) -> bool {
        self == Self::InternalError
    }

    /// Returns `true` for token problems that a fresh token may fix.
    ///
    /// Covers expired, badly signed, malformed, wrong-issuer, wrong-audience
    /// and unknown-key tokens.
    pub fn should_refresh_token(self) -> bool {
        let refreshable = [
            Self::TokenExpired,
            Self::TokenInvalidSignature,
            Self::TokenInvalidFormat,
            Self::TokenInvalidIssuer,
            Self::TokenInvalidAudience,
            Self::TokenKeyNotFound,
        ];
        refreshable.contains(&self)
    }
}

impl std::fmt::Display for AuthErrorCode {
    /// Formats the code in its kebab-case wire form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_wire())
    }
}
181
/// Errors produced by the HyperStack WebSocket client.
///
/// The platform-facing variants (`WebSocket`, `HandshakeRejected`,
/// `AuthRequestFailed`, `ServerClosed`, `SocketIssue`) may carry a parsed
/// [`AuthErrorCode`], exposed through `HyperStackError::auth_code`.
#[derive(Error, Debug, Clone)]
pub enum HyperStackError {
    /// No WebSocket URL was available (per the message).
    /// NOTE(review): not constructed in this file — confirm call sites.
    #[error("Missing WebSocket URL")]
    MissingUrl,

    /// Connecting failed; the payload is a free-form description.
    /// NOTE(review): not constructed in this file — confirm call sites.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// Transport-level tungstenite error other than an HTTP handshake
    /// rejection. `from_tungstenite` builds it with `code: None`.
    #[error("WebSocket error: {message}")]
    WebSocket {
        message: String,
        code: Option<AuthErrorCode>,
    },

    /// The server rejected the WebSocket upgrade.
    ///
    /// `status` is the HTTP status. `message` comes from the response body or
    /// the status's canonical reason. `code` comes from the `X-Error-Code`
    /// header or the body's `code` field (see `from_http_response`).
    #[error("WebSocket handshake rejected ({status}): {message}")]
    HandshakeRejected {
        status: u16,
        message: String,
        code: Option<AuthErrorCode>,
    },

    /// A separate authentication request (presumably the token endpoint —
    /// confirm with caller) failed; built by `from_auth_response`.
    #[error("Authentication request failed ({status}): {message}")]
    AuthRequestFailed {
        status: u16,
        message: String,
        code: Option<AuthErrorCode>,
    },

    /// The server closed the socket with a non-empty close reason, parsed by
    /// `from_close_reason` into an optional code and a message.
    #[error("WebSocket closed by server: {message}")]
    ServerClosed {
        message: String,
        code: Option<AuthErrorCode>,
    },

    /// A structured issue reported by the server over the socket.
    #[error("Socket issue: {0}")]
    SocketIssue(SocketIssue),

    /// JSON (de)serialization failed. Only the rendered message of the
    /// `serde_json::Error` is kept (see the `From` impl).
    #[error("JSON serialization error: {0}")]
    Serialization(String),

    /// Reconnection gave up. The payload is presumably the configured attempt
    /// limit or the attempt count — confirm with the reconnect logic.
    #[error("Max reconnection attempts reached ({0})")]
    MaxReconnectAttempts(u32),

    /// The connection was closed (per the message).
    /// NOTE(review): not constructed in this file.
    #[error("Connection closed")]
    ConnectionClosed,

    /// A subscription could not be established; free-form description.
    /// NOTE(review): not constructed in this file.
    #[error("Subscription failed: {0}")]
    SubscriptionFailed(String),

    /// An internal channel operation failed; free-form description.
    /// NOTE(review): not constructed in this file.
    #[error("Channel error: {0}")]
    ChannelError(String),
}
234
/// JSON error body returned by the platform for rejected HTTP requests,
/// e.g. `{"error":"Publishable key requires Origin header","code":"origin-required"}`.
///
/// Both fields are optional. Keys not listed here are ignored (serde's
/// default behaviour without `deny_unknown_fields`).
#[derive(Debug, Deserialize)]
struct ErrorPayload {
    /// Human-readable error message.
    error: Option<String>,
    /// Wire error code, parsed with `AuthErrorCode::from_wire`.
    code: Option<String>,
}
240
241impl HyperStackError {
242 pub fn auth_code(&self) -> Option<AuthErrorCode> {
243 match self {
244 Self::WebSocket { code, .. }
245 | Self::HandshakeRejected { code, .. }
246 | Self::AuthRequestFailed { code, .. }
247 | Self::ServerClosed { code, .. } => *code,
248 Self::SocketIssue(issue) => issue.code,
249 _ => None,
250 }
251 }
252
253 pub fn socket_issue(&self) -> Option<&SocketIssue> {
254 match self {
255 Self::SocketIssue(issue) => Some(issue),
256 _ => None,
257 }
258 }
259
260 pub fn should_retry(&self) -> bool {
261 match self {
262 Self::HandshakeRejected { status, code, .. }
263 | Self::AuthRequestFailed { status, code, .. } => code
264 .map(AuthErrorCode::should_retry)
265 .unwrap_or(*status >= 500),
266 Self::ServerClosed { code, .. } | Self::WebSocket { code, .. } => {
267 code.map(AuthErrorCode::should_retry).unwrap_or(true)
268 }
269 Self::SocketIssue(issue) => issue.retryable,
270 Self::ConnectionFailed(_) | Self::ConnectionClosed => true,
271 Self::MissingUrl
272 | Self::Serialization(_)
273 | Self::MaxReconnectAttempts(_)
274 | Self::SubscriptionFailed(_)
275 | Self::ChannelError(_) => false,
276 }
277 }
278
279 pub fn should_refresh_token(&self) -> bool {
280 self.auth_code()
281 .map(AuthErrorCode::should_refresh_token)
282 .unwrap_or(false)
283 }
284
285 pub(crate) fn from_tungstenite(error: tungstenite::Error) -> Self {
286 match error {
287 tungstenite::Error::Http(response) => Self::from_http_response(response),
288 other => Self::WebSocket {
289 message: other.to_string(),
290 code: None,
291 },
292 }
293 }
294
295 pub(crate) fn from_http_response(response: Response<Option<Vec<u8>>>) -> Self {
296 let status = response.status().as_u16();
297 let header_code = response
298 .headers()
299 .get("X-Error-Code")
300 .and_then(|value| value.to_str().ok())
301 .and_then(AuthErrorCode::from_wire);
302 let (body_message, body_code) = parse_error_payload(response.body().as_deref());
303 let code = header_code.or(body_code);
304
305 let message = body_message.unwrap_or_else(|| {
306 response
307 .status()
308 .canonical_reason()
309 .unwrap_or("WebSocket handshake rejected")
310 .to_string()
311 });
312
313 Self::HandshakeRejected {
314 status,
315 message,
316 code,
317 }
318 }
319
320 pub(crate) fn from_auth_response(
321 status: u16,
322 header_code: Option<&str>,
323 body: Option<&[u8]>,
324 fallback_message: Option<&str>,
325 ) -> Self {
326 let header_code = header_code.and_then(AuthErrorCode::from_wire);
327 let (body_message, body_code) = parse_error_payload(body);
328 let code = header_code.or(body_code);
329 let message = body_message.unwrap_or_else(|| {
330 fallback_message
331 .unwrap_or("Authentication request failed")
332 .to_string()
333 });
334
335 Self::AuthRequestFailed {
336 status,
337 message,
338 code,
339 }
340 }
341
342 pub(crate) fn from_close_reason(reason: &str) -> Option<Self> {
343 let trimmed = reason.trim();
344 if trimmed.is_empty() {
345 return None;
346 }
347
348 let (code, message) = parse_close_reason(trimmed);
349 Some(Self::ServerClosed { code, message })
350 }
351
352 pub(crate) fn from_socket_issue(issue: SocketIssue) -> Self {
353 Self::SocketIssue(issue)
354 }
355}
356
357impl From<serde_json::Error> for HyperStackError {
358 fn from(value: serde_json::Error) -> Self {
359 Self::Serialization(value.to_string())
360 }
361}
362
363impl From<tungstenite::Error> for HyperStackError {
364 fn from(value: tungstenite::Error) -> Self {
365 Self::from_tungstenite(value)
366 }
367}
368
369fn parse_error_payload(body: Option<&[u8]>) -> (Option<String>, Option<AuthErrorCode>) {
370 let Some(body) = body.filter(|value| !value.is_empty()) else {
371 return (None, None);
372 };
373
374 if let Ok(payload) = serde_json::from_slice::<ErrorPayload>(body) {
375 let code = payload.code.as_deref().and_then(AuthErrorCode::from_wire);
376 let message = payload.error.map(|value| value.trim().to_string());
377 return (message.filter(|value| !value.is_empty()), code);
378 }
379
380 let message = String::from_utf8_lossy(body).trim().to_string();
381 if message.is_empty() {
382 (None, None)
383 } else {
384 (Some(message), None)
385 }
386}
387
388fn parse_close_reason(reason: &str) -> (Option<AuthErrorCode>, String) {
389 if let Some((wire_code, message)) = reason.split_once(':') {
390 let code = AuthErrorCode::from_wire(wire_code);
391 let message = message.trim();
392
393 if code.is_some() && !message.is_empty() {
394 return (code, message.to_string());
395 }
396 }
397
398 (None, reason.trim().to_string())
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    // A rejected upgrade carrying both an `X-Error-Code` header and a JSON
    // body must surface as `HandshakeRejected` with the HTTP status and the
    // parsed code. `OriginRequired` is not a retryable code.
    #[test]
    fn parses_platform_handshake_rejection() {
        let response = Response::builder()
            .status(403)
            .header("X-Error-Code", "origin-required")
            .body(Some(
                br#"{"error":"Publishable key requires Origin header","code":"origin-required"}"#
                    .to_vec(),
            ))
            .expect("response should build");

        let error = HyperStackError::from_http_response(response);
        assert!(matches!(
            error,
            HyperStackError::HandshakeRejected {
                status: 403,
                code: Some(AuthErrorCode::OriginRequired),
                ..
            }
        ));
        assert!(!error.should_retry());
    }

    // An auth-endpoint failure (header code + JSON body) maps to
    // `AuthRequestFailed`. With a recognised code present, the code alone
    // decides retryability (only `InternalError` retries), so this is not
    // retried.
    #[test]
    fn parses_token_endpoint_error_response() {
        let error = HyperStackError::from_auth_response(
            429,
            Some("websocket-session-rate-limit-exceeded"),
            Some(
                br#"{"error":"WebSocket session mint rate limit exceeded","code":"websocket-session-rate-limit-exceeded"}"#,
            ),
            Some("Too Many Requests"),
        );

        assert!(matches!(
            error,
            HyperStackError::AuthRequestFailed {
                status: 429,
                code: Some(AuthErrorCode::WebSocketSessionRateLimitExceeded),
                ..
            }
        ));
        assert!(!error.should_retry());
    }

    // A `<code>: <message>` close reason yields `ServerClosed` with the parsed
    // code. The rate-limit code must not be retried.
    #[test]
    fn parses_rate_limit_close_reason() {
        let error = HyperStackError::from_close_reason(
            "websocket-session-rate-limit-exceeded: WebSocket session mint rate limit exceeded",
        )
        .expect("close reason should parse");

        assert!(matches!(
            error,
            HyperStackError::ServerClosed {
                code: Some(AuthErrorCode::WebSocketSessionRateLimitExceeded),
                ..
            }
        ));
        assert!(!error.should_retry());
    }

    // A close reason without a recognised code prefix is kept verbatim as
    // the message, with `code: None`.
    #[test]
    fn parses_unknown_close_reason_without_code() {
        let error = HyperStackError::from_close_reason("server maintenance")
            .expect("non-empty reason should be preserved");

        assert!(matches!(
            error,
            HyperStackError::ServerClosed {
                code: None,
                ref message,
            } if message == "server maintenance"
        ));
    }

    // `SocketIssue` errors defer to the issue's own `retryable` flag and
    // expose the issue through `socket_issue()`.
    #[test]
    fn socket_issue_error_uses_issue_retryability() {
        let error = HyperStackError::from_socket_issue(SocketIssue {
            error: "subscription-limit-exceeded".to_string(),
            message: "subscription limit exceeded".to_string(),
            code: Some(AuthErrorCode::SubscriptionLimitExceeded),
            retryable: false,
            retry_after: None,
            suggested_action: Some("unsubscribe first".to_string()),
            docs_url: None,
            fatal: false,
        });

        assert!(!error.should_retry());
        assert!(
            matches!(error.socket_issue(), Some(issue) if issue.message == "subscription limit exceeded")
        );
    }
}