//! Error types, recovery strategies, and reporting utilities for the
//! WebSocket layer.

use crate::transport::{TransportError, ConnectionState};
use crate::rpc::RpcError;
use crate::codec::CodecError;
use std::time::{Duration, Instant};
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use thiserror::Error;

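/// Top-level error type for the WebSocket layer. The transport, RPC, and
/// codec variants wrap their source error together with the [`ErrorContext`]
/// in which it occurred and a suggested [`RecoveryStrategy`].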
#[derive(Debug, Error)]
pub enum LeptosWsError {
    #[error("Transport error: {source}")]
    Transport {
        source: TransportError,
        context: ErrorContext,
        recovery: RecoveryStrategy,
    },

    #[error("RPC error: {source}")]
    Rpc {
        source: RpcError,
        context: ErrorContext,
        recovery: RecoveryStrategy,
    },

    #[error("Codec error: {source}")]
    Codec {
        source: CodecError,
        context: ErrorContext,
        recovery: RecoveryStrategy,
    },

    #[error("Configuration error: {message}")]
    Configuration {
        message: String,
        field: String,
        expected: String,
        actual: String,
    },

    #[error("Security error: {message}")]
    Security {
        message: String,
        threat_level: ThreatLevel,
        context: ErrorContext,
    },

    #[error("Rate limit exceeded: {message}")]
    RateLimit {
        message: String,
        retry_after: Option<Duration>,
        context: ErrorContext,
    },

    #[error("Internal error: {message}")]
    Internal {
        message: String,
        context: ErrorContext,
        should_report: bool,
    },
}

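/// Structured context captured alongside an error: what was being attempted,
/// by which component, on which attempt, plus optional tracing identifiers.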
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorContext {
    pub timestamp: u64,
    pub operation: String,
    pub component: String,
    pub connection_state: Option<ConnectionState>,
    pub attempt_number: u32,
    pub user_data: Option<serde_json::Value>,
    pub session_id: Option<String>,
    pub trace_id: Option<String>,
    pub error_type: Option<ErrorType>,
    pub message: Option<String>,
    pub service: Option<String>,
    pub correlation_id: Option<String>,
    pub metadata: Option<HashMap<String, String>>,
}

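/// Coarse classification of an error, used when serializing errors for
/// logging and reporting.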
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ErrorType {
    Network,
    Timeout,
    Authentication,
    Authorization,
    Validation,
    Serialization,
    Deserialization,
    RateLimit,
    CircuitBreaker,
    ServiceUnavailable,
    Internal,
    Unknown,
    Transport,
    Rpc,
    Codec,
}

impl ErrorContext {
    pub fn new(operation: &str, component: &str) -> Self {
        Self {
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            operation: operation.to_string(),
            component: component.to_string(),
            connection_state: None,
            attempt_number: 1,
            user_data: None,
            session_id: None,
            trace_id: None,
            error_type: None,
            message: None,
            service: None,
            correlation_id: None,
            metadata: None,
        }
    }

    pub fn with_connection_state(mut self, state: ConnectionState) -> Self {
        self.connection_state = Some(state);
        self
    }

    pub fn with_attempt(mut self, attempt: u32) -> Self {
        self.attempt_number = attempt;
        self
    }

    pub fn with_trace_id(mut self, trace_id: String) -> Self {
        self.trace_id = Some(trace_id);
        self
    }

    pub fn with_session_id(mut self, session_id: String) -> Self {
        self.session_id = Some(session_id);
        self
    }
}

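/// Describes how the [`ErrorRecoveryHandler`] should attempt to recover from
/// a failed operation. A strategy is attached to each recoverable error when
/// it is constructed (see the `From` conversions below).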
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryStrategy {
    /// Retry the same operation with exponential backoff.
    Retry {
        max_attempts: u32,
        base_delay: Duration,
        max_delay: Duration,
        jitter: bool,
    },

    /// Re-establish the connection, waiting `delay` between attempts.
    Reconnect {
        max_attempts: u32,
        delay: Duration,
    },

    /// Fall back to one of the listed alternatives.
    Fallback {
        alternatives: Vec<String>,
    },

    /// Continue with reduced functionality for a limited time.
    Degrade {
        reduced_functionality: Vec<String>,
        duration: Duration,
    },

    /// Recovery requires manual intervention.
    Manual {
        instructions: String,
        support_contact: Option<String>,
    },

    /// Recovery is expected to complete on its own.
    Automatic {
        estimated_time: Duration,
        progress_callback: Option<String>,
    },
}

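/// Severity attached to [`LeptosWsError::Security`] errors, ordered from
/// least to most severe.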
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum ThreatLevel {
    Low,
    Medium,
    High,
    Critical,
}

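/// Applies the [`RecoveryStrategy`] attached to an error: retries with
/// exponential backoff, reconnects, or waits out a rate limit. A
/// [`CircuitBreaker`] guards the retry loop so a persistently failing
/// operation is not retried indefinitely.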
pub struct ErrorRecoveryHandler {
    max_retry_attempts: u32,
    base_retry_delay: Duration,
    max_retry_delay: Duration,
    jitter_enabled: bool,
    circuit_breaker: CircuitBreaker,
}

impl ErrorRecoveryHandler {
    pub fn new() -> Self {
        Self {
            max_retry_attempts: 3,
            base_retry_delay: Duration::from_millis(100),
            max_retry_delay: Duration::from_secs(30),
            jitter_enabled: true,
            circuit_breaker: CircuitBreaker::new(),
        }
    }

    pub async fn handle_error<F, R>(
        &mut self,
        error: LeptosWsError,
        operation: F,
    ) -> Result<R, LeptosWsError>
    where
        F: Fn() -> Result<R, LeptosWsError> + Send + Sync,
        R: Send + Sync,
    {
        match &error {
            LeptosWsError::Transport { recovery, .. } => {
                self.handle_transport_recovery(recovery, operation).await
            },
            LeptosWsError::Rpc { recovery, .. } => {
                self.handle_rpc_recovery(recovery, operation).await
            },
            LeptosWsError::RateLimit { retry_after, .. } => {
                self.handle_rate_limit(*retry_after, operation).await
            },
            // Codec, configuration, security, and internal errors have no
            // automatic recovery path; they are returned to the caller as-is.
            _ => Err(error),
        }
    }

    async fn handle_transport_recovery<F, R>(
        &mut self,
        strategy: &RecoveryStrategy,
        operation: F,
    ) -> Result<R, LeptosWsError>
    where
        F: Fn() -> Result<R, LeptosWsError> + Send + Sync,
        R: Send + Sync,
    {
        match strategy {
            RecoveryStrategy::Retry { max_attempts, base_delay, max_delay, jitter } => {
                self.retry_with_backoff(*max_attempts, *base_delay, *max_delay, *jitter, operation).await
            },
            RecoveryStrategy::Reconnect { max_attempts, delay } => {
                self.retry_with_reconnect(*max_attempts, *delay, operation).await
            },
            _ => Err(LeptosWsError::Internal {
                message: "Recovery strategy not implemented".to_string(),
                context: ErrorContext::new("recovery", "error_handler"),
                should_report: true,
            }),
        }
    }

    async fn handle_rpc_recovery<F, R>(
        &mut self,
        strategy: &RecoveryStrategy,
        operation: F,
    ) -> Result<R, LeptosWsError>
    where
        F: Fn() -> Result<R, LeptosWsError> + Send + Sync,
        R: Send + Sync,
    {
        // RPC recovery currently reuses the transport recovery logic.
        self.handle_transport_recovery(strategy, operation).await
    }

    async fn handle_rate_limit<F, R>(
        &mut self,
        retry_after: Option<Duration>,
        operation: F,
    ) -> Result<R, LeptosWsError>
    where
        F: Fn() -> Result<R, LeptosWsError> + Send + Sync,
        R: Send + Sync,
    {
        // Honor the server-provided backoff if present, otherwise wait one
        // second, then retry the operation a single time.
        let delay = retry_after.unwrap_or(Duration::from_secs(1));
        tokio::time::sleep(delay).await;
        operation()
    }

    async fn retry_with_backoff<F, R>(
        &mut self,
        max_attempts: u32,
        base_delay: Duration,
        max_delay: Duration,
        jitter: bool,
        operation: F,
    ) -> Result<R, LeptosWsError>
    where
        F: Fn() -> Result<R, LeptosWsError> + Send + Sync,
        R: Send + Sync,
    {
        let mut attempt = 1;
        let mut delay = base_delay;

        loop {
            if !self.circuit_breaker.allow_request() {
                return Err(LeptosWsError::Internal {
                    message: "Circuit breaker is open".to_string(),
                    context: ErrorContext::new("retry", "error_handler"),
                    should_report: false,
                });
            }

            match operation() {
                Ok(result) => {
                    self.circuit_breaker.record_success();
                    return Ok(result);
                },
                Err(error) => {
                    self.circuit_breaker.record_failure();

                    if attempt >= max_attempts {
                        return Err(error);
                    }

                    // Apply up to ±10% jitter so concurrent clients do not
                    // retry in lockstep.
                    let actual_delay = if jitter {
                        let jitter_amount = delay.as_millis() as f64 * 0.1;
                        let jitter_offset = (rand::random::<f64>() - 0.5) * 2.0 * jitter_amount;
                        Duration::from_millis((delay.as_millis() as f64 + jitter_offset) as u64)
                    } else {
                        delay
                    };

                    tokio::time::sleep(actual_delay).await;

                    // Exponential backoff, capped at `max_delay`.
                    delay = std::cmp::min(delay * 2, max_delay);
                    attempt += 1;
                }
            }
        }
    }

    async fn retry_with_reconnect<F, R>(
        &mut self,
        max_attempts: u32,
        delay: Duration,
        operation: F,
    ) -> Result<R, LeptosWsError>
    where
        F: Fn() -> Result<R, LeptosWsError> + Send + Sync,
        R: Send + Sync,
    {
        let mut last_error = None;

        for _ in 0..max_attempts {
            // Give the transport time to re-establish the connection before
            // retrying the operation.
            tokio::time::sleep(delay).await;

            match operation() {
                Ok(result) => return Ok(result),
                Err(error) => last_error = Some(error),
            }
        }

        // Reached only when every attempt failed (or `max_attempts` was 0).
        Err(last_error.unwrap_or_else(|| LeptosWsError::Internal {
            message: "Reconnect recovery invoked with zero attempts".to_string(),
            context: ErrorContext::new("reconnect", "error_handler"),
            should_report: true,
        }))
    }
}

impl Default for ErrorRecoveryHandler {
    fn default() -> Self {
        Self::new()
    }
}

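/// A minimal circuit breaker: opens after `failure_threshold` consecutive
/// failures, rejects requests while open, and transitions to half-open once
/// `timeout` has elapsed so a single probe request can close it again.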
pub struct CircuitBreaker {
    failure_count: u32,
    success_count: u32,
    last_failure_time: Option<Instant>,
    state: CircuitBreakerState,
    failure_threshold: u32,
    timeout: Duration,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum CircuitBreakerState {
    Closed,
    Open,
    HalfOpen,
}

impl CircuitBreaker {
    pub fn new() -> Self {
        Self {
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            state: CircuitBreakerState::Closed,
            failure_threshold: 5,
            timeout: Duration::from_secs(60),
        }
    }

    pub fn allow_request(&mut self) -> bool {
        match self.state {
            CircuitBreakerState::Closed => true,
            CircuitBreakerState::Open => {
                // After the timeout, let one probe request through by moving
                // to the half-open state.
                if let Some(last_failure) = self.last_failure_time {
                    if Instant::now() - last_failure > self.timeout {
                        self.state = CircuitBreakerState::HalfOpen;
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            },
            CircuitBreakerState::HalfOpen => true,
        }
    }

    pub fn record_success(&mut self) {
        self.success_count += 1;
        self.failure_count = 0;

        // A successful probe closes a half-open breaker.
        if self.state == CircuitBreakerState::HalfOpen {
            self.state = CircuitBreakerState::Closed;
        }
    }

    pub fn record_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(Instant::now());

        if self.failure_count >= self.failure_threshold {
            self.state = CircuitBreakerState::Open;
        }
    }

    pub fn get_state(&self) -> &str {
        match self.state {
            CircuitBreakerState::Closed => "closed",
            CircuitBreakerState::Open => "open",
            CircuitBreakerState::HalfOpen => "half-open",
        }
    }
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self::new()
    }
}

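/// Optional hook for shipping errors to an external reporting endpoint.
/// Reporting stays disabled until [`ErrorReporter::configure`] is called.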
pub struct ErrorReporter {
    enabled: bool,
    endpoint: Option<String>,
    api_key: Option<String>,
}

impl ErrorReporter {
    pub fn new() -> Self {
        Self {
            enabled: false,
            endpoint: None,
            api_key: None,
        }
    }

    pub fn configure(&mut self, endpoint: String, api_key: String) {
        self.endpoint = Some(endpoint);
        self.api_key = Some(api_key);
        self.enabled = true;
    }

    pub async fn report_error(
        &self,
        error: &LeptosWsError,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        if !self.enabled {
            return Ok(());
        }

        let error_data = serde_json::json!({
            "error_type": self.get_error_type(error),
            "message": error.to_string(),
            "timestamp": std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            "context": self.extract_context(error),
        });

        // A full implementation would send `error_data` to the configured
        // endpoint; for now the payload is only logged.
        tracing::error!("Would report error: {}", error_data);

        Ok(())
    }

    fn get_error_type(&self, error: &LeptosWsError) -> &'static str {
        match error {
            LeptosWsError::Transport { .. } => "transport",
            LeptosWsError::Rpc { .. } => "rpc",
            LeptosWsError::Codec { .. } => "codec",
            LeptosWsError::Configuration { .. } => "configuration",
            LeptosWsError::Security { .. } => "security",
            LeptosWsError::RateLimit { .. } => "rate_limit",
            LeptosWsError::Internal { .. } => "internal",
        }
    }

    fn extract_context<'a>(&self, error: &'a LeptosWsError) -> Option<&'a ErrorContext> {
        match error {
            LeptosWsError::Transport { context, .. } => Some(context),
            LeptosWsError::Rpc { context, .. } => Some(context),
            LeptosWsError::Codec { context, .. } => Some(context),
            LeptosWsError::Security { context, .. } => Some(context),
            LeptosWsError::RateLimit { context, .. } => Some(context),
            LeptosWsError::Internal { context, .. } => Some(context),
            // Configuration errors carry no ErrorContext.
            _ => None,
        }
    }
}

impl Default for ErrorReporter {
    fn default() -> Self {
        Self::new()
    }
}

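// Conversions from lower-level errors attach a default context and a
// reasonable default recovery strategy.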
impl From<TransportError> for LeptosWsError {
    fn from(source: TransportError) -> Self {
        LeptosWsError::Transport {
            source,
            context: ErrorContext::new("transport", "transport"),
            recovery: RecoveryStrategy::Retry {
                max_attempts: 3,
                base_delay: Duration::from_millis(100),
                max_delay: Duration::from_secs(10),
                jitter: true,
            },
        }
    }
}

impl From<RpcError> for LeptosWsError {
    fn from(source: RpcError) -> Self {
        LeptosWsError::Rpc {
            source,
            context: ErrorContext::new("rpc", "rpc"),
            recovery: RecoveryStrategy::Retry {
                max_attempts: 2,
                base_delay: Duration::from_millis(50),
                max_delay: Duration::from_secs(5),
                jitter: false,
            },
        }
    }
}

impl From<CodecError> for LeptosWsError {
    fn from(source: CodecError) -> Self {
        LeptosWsError::Codec {
            source,
            context: ErrorContext::new("codec", "codec"),
            recovery: RecoveryStrategy::Manual {
                instructions: "Check message format and codec configuration".to_string(),
                support_contact: None,
            },
        }
    }
}

/// Builds a [`LeptosWsError::Transport`] with a default retry strategy.
/// The expansion references `LeptosWsError`, `ErrorContext`,
/// `RecoveryStrategy`, and `Duration`, so callers must have those names in
/// scope.
#[macro_export]
macro_rules! transport_error {
    ($source:expr, $operation:expr, $component:expr) => {
        LeptosWsError::Transport {
            source: $source,
            context: ErrorContext::new($operation, $component),
            recovery: RecoveryStrategy::Retry {
                max_attempts: 3,
                base_delay: Duration::from_millis(100),
                max_delay: Duration::from_secs(10),
                jitter: true,
            },
        }
    };
}

/// Builds a [`LeptosWsError::Rpc`] with a default retry strategy. The same
/// scoping requirement as `transport_error!` applies.
#[macro_export]
macro_rules! rpc_error {
    ($source:expr, $operation:expr) => {
        LeptosWsError::Rpc {
            source: $source,
            context: ErrorContext::new($operation, "rpc"),
            recovery: RecoveryStrategy::Retry {
                max_attempts: 2,
                base_delay: Duration::from_millis(50),
                max_delay: Duration::from_secs(5),
                jitter: false,
            },
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_context_creation() {
        let context = ErrorContext::new("test_operation", "test_component");
        assert_eq!(context.operation, "test_operation");
        assert_eq!(context.component, "test_component");
        assert_eq!(context.attempt_number, 1);
        assert!(context.timestamp > 0);
    }

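    // Additional coverage (sketch): exercises the builder-style helpers on
    // ErrorContext defined above.
    #[test]
    fn test_error_context_builder() {
        let context = ErrorContext::new("subscribe", "rpc")
            .with_attempt(3)
            .with_trace_id("trace-123".to_string())
            .with_session_id("session-456".to_string());

        assert_eq!(context.attempt_number, 3);
        assert_eq!(context.trace_id.as_deref(), Some("trace-123"));
        assert_eq!(context.session_id.as_deref(), Some("session-456"));
    }
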
    #[test]
    fn test_circuit_breaker() {
        let mut cb = CircuitBreaker::new();

        // A fresh breaker is closed and lets requests through.
        assert!(cb.allow_request());
        assert_eq!(cb.get_state(), "closed");

        // Hitting the failure threshold (5) opens the breaker.
        for _ in 0..5 {
            cb.record_failure();
        }

        assert_eq!(cb.get_state(), "open");
        assert!(!cb.allow_request());

        // A success while open resets the failure count but does not close
        // the breaker; only a successful half-open probe can do that.
        cb.record_success();
        assert_eq!(cb.get_state(), "open");
        assert!(!cb.allow_request());
    }

    #[tokio::test]
    async fn test_error_recovery_basic() {
        let mut handler = ErrorRecoveryHandler::new();
        let attempt_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();

        let operation = move || {
            let count = attempt_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if count < 2 {
                Err(LeptosWsError::Internal {
                    message: "Temporary failure".to_string(),
                    context: ErrorContext::new("test", "test"),
                    should_report: false,
                })
            } else {
                Ok("Success!")
            }
        };

        let error = LeptosWsError::Internal {
            message: "Initial failure".to_string(),
            context: ErrorContext::new("test", "test"),
            should_report: false,
        };

        // Internal errors carry no automatic recovery strategy, so the
        // handler returns the original error without invoking the operation.
        let result = handler.handle_error(error, operation).await;
        assert!(result.is_err());
    }
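
    // Additional coverage (sketch): drives the rate-limit recovery path,
    // which sleeps for `retry_after` and then re-runs the operation once.
    #[tokio::test]
    async fn test_rate_limit_recovery_retries_operation() {
        let mut handler = ErrorRecoveryHandler::new();

        let error = LeptosWsError::RateLimit {
            message: "too many requests".to_string(),
            retry_after: Some(Duration::from_millis(10)),
            context: ErrorContext::new("send", "rate_limiter"),
        };

        let result = handler
            .handle_error(error, || Ok::<_, LeptosWsError>("recovered"))
            .await;
        assert_eq!(result.unwrap(), "recovered");
    }

    // Additional coverage (sketch): an unconfigured reporter is a no-op and
    // never returns an error.
    #[tokio::test]
    async fn test_error_reporter_disabled_is_noop() {
        let reporter = ErrorReporter::new();
        let error = LeptosWsError::Internal {
            message: "nothing to report".to_string(),
            context: ErrorContext::new("report", "test"),
            should_report: false,
        };

        assert!(reporter.report_error(&error).await.is_ok());
    }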
}