rustapi-extras 0.1.470

Production-ready middleware collection for RustAPI. Includes JWT auth, CORS, Rate Limiting, SQLx integration, and OpenTelemetry observability.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
//! Circuit breaker middleware for resilient service calls
//!
//! This module implements the circuit breaker pattern to prevent cascading failures
//! and give failing services time to recover.
//!
//! # States
//!
//! - **Closed**: Normal operation, requests pass through
//! - **Open**: Too many failures, requests fail fast
//! - **HalfOpen**: Testing if service recovered
//!
//! # Example
//!
//! ```rust,no_run
//! use rustapi_core::RustApi;
//! use rustapi_extras::CircuitBreakerLayer;
//! use std::time::Duration;
//!
//! #[tokio::main]
//! async fn main() {
//!     let app = RustApi::new()
//!         .layer(
//!             CircuitBreakerLayer::new()
//!                 .failure_threshold(5)
//!                 .timeout(Duration::from_secs(30))
//!         )
//!         .run("0.0.0.0:3000")
//!         .await
//!         .unwrap();
//! }
//! ```

use rustapi_core::{
    middleware::{BoxedNext, MiddlewareLayer},
    Request, Response, ResponseBody,
};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// The three states of the circuit breaker state machine.
///
/// Transitions implemented in this module:
/// `Closed -> Open` once the failure counter reaches the configured failure
/// threshold, `Open -> HalfOpen` on the first request that arrives after the
/// configured timeout has elapsed since the last recorded failure,
/// `HalfOpen -> Closed` after enough successful responses, and
/// `HalfOpen -> Open` on any failed response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests are forwarded to the next handler and
    /// failures are counted.
    Closed,
    /// Tripped: requests are rejected immediately with a 503 response until
    /// the timeout has elapsed.
    Open,
    /// Probing: requests are forwarded again to test whether the downstream
    /// service has recovered.
    HalfOpen,
}

/// Tunable thresholds for the circuit breaker.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone` so that a
/// configuration can be logged and compared (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Number of failures (without an intervening success while closed)
    /// that opens the circuit.
    pub failure_threshold: usize,
    /// How long the circuit stays open, measured from the most recently
    /// recorded failure, before the next request is allowed through as a
    /// half-open probe.
    pub timeout: Duration,
    /// Number of successful responses required in the half-open state
    /// before the circuit closes again.
    pub success_threshold: usize,
}

impl Default for CircuitBreakerConfig {
    /// Returns the default configuration: open after 5 failures, stay open
    /// for 60 seconds, and close again after 2 half-open successes.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            timeout: Duration::from_secs(60),
            success_threshold: 2,
        }
    }
}

/// Mutable bookkeeping shared by every clone of the layer.
///
/// The `total_*` fields are lifetime counters that are only cleared by a
/// full reset; `failure_count` and `success_count` drive state transitions
/// and are reset as the circuit changes state.
struct CircuitBreakerState {
    /// Current position in the state machine.
    state: CircuitState,
    /// Failures recorded since the counter was last reset (reset on a
    /// success while closed, and when the circuit closes from half-open).
    failure_count: usize,
    /// Successful responses observed in the current half-open period.
    success_count: usize,
    /// When the most recent failure was recorded; the open timeout is
    /// measured from this instant. `None` until the first failure.
    last_failure_time: Option<Instant>,
    /// Every request seen by the layer, including ones rejected while open.
    total_requests: u64,
    /// Downstream responses recorded as failures (fast-fail rejections are
    /// not counted here).
    total_failures: u64,
    /// Downstream responses with a success status.
    total_successes: u64,
}

impl Default for CircuitBreakerState {
    fn default() -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            total_requests: 0,
            total_failures: 0,
            total_successes: 0,
        }
    }
}

/// Circuit breaker middleware layer.
///
/// Cloning the layer is cheap: the configuration is copied, while the
/// runtime state lives behind an `Arc`, so every clone observes and updates
/// the same circuit.
#[derive(Clone)]
pub struct CircuitBreakerLayer {
    /// Thresholds and timeout controlling state transitions.
    config: CircuitBreakerConfig,
    /// Shared, lock-protected runtime state and counters.
    state: Arc<RwLock<CircuitBreakerState>>,
}

impl CircuitBreakerLayer {
    /// Create a new circuit breaker with default configuration
    pub fn new() -> Self {
        Self {
            config: CircuitBreakerConfig::default(),
            state: Arc::new(RwLock::new(CircuitBreakerState::default())),
        }
    }

    /// Set the failure threshold
    pub fn failure_threshold(mut self, threshold: usize) -> Self {
        self.config.failure_threshold = threshold;
        self
    }

    /// Set the timeout before transitioning to half-open
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.config.timeout = timeout;
        self
    }

    /// Set the success threshold in half-open state
    pub fn success_threshold(mut self, threshold: usize) -> Self {
        self.config.success_threshold = threshold;
        self
    }

    /// Get the current circuit state
    pub async fn get_state(&self) -> CircuitState {
        self.state.read().await.state
    }

    /// Get circuit breaker statistics
    pub async fn get_stats(&self) -> CircuitBreakerStats {
        let state = self.state.read().await;
        CircuitBreakerStats {
            state: state.state,
            total_requests: state.total_requests,
            total_failures: state.total_failures,
            total_successes: state.total_successes,
            failure_count: state.failure_count,
            success_count: state.success_count,
        }
    }

    /// Reset the circuit breaker
    pub async fn reset(&self) {
        let mut state = self.state.write().await;
        *state = CircuitBreakerState::default();
    }
}

impl Default for CircuitBreakerLayer {
    fn default() -> Self {
        Self::new()
    }
}

/// Point-in-time snapshot of the circuit breaker's state and counters.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// State of the circuit when the snapshot was taken
    pub state: CircuitState,
    /// Total requests seen by the layer, including requests rejected while
    /// the circuit was open
    pub total_requests: u64,
    /// Total downstream responses recorded as failures
    pub total_failures: u64,
    /// Total downstream responses with a success status
    pub total_successes: u64,
    /// Failures recorded since the failure counter was last reset
    pub failure_count: usize,
    /// Successes counted in the current half-open period
    pub success_count: usize,
}

impl MiddlewareLayer for CircuitBreakerLayer {
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        let config = self.config.clone();
        let state = self.state.clone();

        Box::pin(async move {
            // Check current state
            let mut state_guard = state.write().await;
            state_guard.total_requests += 1;

            match state_guard.state {
                CircuitState::Open => {
                    // Check if timeout has elapsed
                    if let Some(last_failure) = state_guard.last_failure_time {
                        if last_failure.elapsed() >= config.timeout {
                            // Transition to half-open
                            tracing::info!("Circuit breaker transitioning to HalfOpen");
                            state_guard.state = CircuitState::HalfOpen;
                            state_guard.success_count = 0;
                        } else {
                            // Still open, fail fast
                            drop(state_guard);
                            return http::Response::builder()
                                .status(503)
                                .header("Content-Type", "application/json")
                                .body(ResponseBody::Full(http_body_util::Full::new(
                                    bytes::Bytes::from(
                                        serde_json::json!({
                                            "error": {
                                                "type": "service_unavailable",
                                                "message": "Circuit breaker is OPEN"
                                            }
                                        })
                                        .to_string(),
                                    ),
                                )))
                                .unwrap();
                        }
                    }
                }
                CircuitState::HalfOpen => {
                    // Allow request but monitor closely
                }
                CircuitState::Closed => {
                    // Normal operation
                }
            }

            drop(state_guard);

            // Execute request
            let response = next(req).await;

            // Update state based on result
            let mut state_guard = state.write().await;

            // Check if response indicates success (2xx status)
            if response.status().is_success() {
                state_guard.total_successes += 1;

                match state_guard.state {
                    CircuitState::HalfOpen => {
                        state_guard.success_count += 1;
                        if state_guard.success_count >= config.success_threshold {
                            // Transition to closed
                            tracing::info!("Circuit breaker transitioning to Closed");
                            state_guard.state = CircuitState::Closed;
                            state_guard.failure_count = 0;
                            state_guard.success_count = 0;
                        }
                    }
                    CircuitState::Closed => {
                        // Reset failure count on success
                        state_guard.failure_count = 0;
                    }
                    _ => {}
                }
            } else {
                // Non-2xx status is treated as failure
                record_failure(&mut state_guard, &config);
            }

            response
        })
    }

    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}

/// Records one failed response and applies the resulting state transition.
///
/// Always bumps the lifetime and current failure counters and stamps the
/// failure time. A closed circuit opens once the failure counter reaches
/// `config.failure_threshold`; a half-open circuit reopens immediately. A
/// circuit that is already open stays open.
fn record_failure(state: &mut CircuitBreakerState, config: &CircuitBreakerConfig) {
    state.total_failures += 1;
    state.failure_count += 1;
    state.last_failure_time = Some(Instant::now());

    match state.state {
        CircuitState::HalfOpen => {
            // The probe failed: go straight back to open.
            tracing::warn!("Circuit breaker returning to OPEN state");
            state.state = CircuitState::Open;
            state.success_count = 0;
        }
        CircuitState::Closed if state.failure_count >= config.failure_threshold => {
            // Threshold reached: trip the circuit.
            tracing::warn!(
                "Circuit breaker OPENING after {} failures",
                state.failure_count
            );
            state.state = CircuitState::Open;
        }
        // Closed below the threshold, or already open: no transition.
        CircuitState::Closed | CircuitState::Open => {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::sync::Arc;

    /// Builds a downstream handler that always answers with the given
    /// status code and body.
    fn fixed_handler(status: u16, body: &'static str) -> BoxedNext {
        let handler: BoxedNext = Arc::new(move |_req: Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(ResponseBody::Full(http_body_util::Full::new(Bytes::from(
                        body,
                    ))))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        });
        handler
    }

    /// Builds an empty-bodied `GET /` request.
    fn get_root() -> Request {
        let inner = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap();
        Request::from_http_request(inner, Bytes::new())
    }

    #[tokio::test]
    async fn circuit_breaker_opens_after_threshold() {
        let breaker = CircuitBreakerLayer::new()
            .failure_threshold(3)
            .timeout(Duration::from_secs(1));

        // Downstream handler that always fails.
        let failing = fixed_handler(500, "Error");

        // Reach the failure threshold.
        for _ in 0..3 {
            let _ = breaker.call(get_root(), failing.clone()).await;
        }

        // The circuit must have tripped.
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // A further request is rejected without reaching the handler.
        let rejected = breaker.call(get_root(), failing.clone()).await;
        assert_eq!(rejected.status(), 503);
    }

    #[tokio::test]
    async fn circuit_breaker_recovers() {
        let breaker = CircuitBreakerLayer::new()
            .failure_threshold(2)
            .timeout(Duration::from_millis(100))
            .success_threshold(2);

        // Trip the circuit with failing responses.
        let failing = fixed_handler(500, "Error");
        for _ in 0..2 {
            let _ = breaker.call(get_root(), failing.clone()).await;
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // Let the open timeout expire.
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Successful probes should close the circuit again.
        let healthy = fixed_handler(200, "OK");
        for _ in 0..2 {
            let outcome = breaker.call(get_root(), healthy.clone()).await;
            assert!(outcome.status().is_success());
        }

        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }
}