sockudo 3.1.0

A simple, fast, and secure WebSocket server for real-time applications.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
// src/rate_limiter/middleware.rs
use crate::rate_limiter::{RateLimitResult, RateLimiter};
use axum::{
    body::Body as AxumBody,
    extract::ConnectInfo,
    http::{HeaderMap, HeaderName, HeaderValue, Request as AxumRequest, StatusCode},
    response::{IntoResponse, Response as AxumResponse},
};
use futures_util::future::BoxFuture;
use hyper::Request as HyperRequest;
use sonic_rs::json;
use std::{
    fmt,
    net::SocketAddr,
    sync::Arc,
    task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
use tracing::{debug, error, warn};

/// Errors raised by a [`KeyExtractor`] while deriving a rate-limit key.
#[derive(Debug)]
pub enum RateLimitMiddlewareError {
    /// Carries a description of an invalid header name used for key extraction.
    InvalidHeaderName(String),
    /// Carries a description of why no key could be extracted.
    ExtractionFailed(String),
}

impl fmt::Display for RateLimitMiddlewareError {
    /// Renders `"<context>: <detail>"`, where the context depends on the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (context, detail) = match self {
            Self::InvalidHeaderName(detail) => ("Invalid header name for key extraction", detail),
            Self::ExtractionFailed(detail) => ("Key extraction failed", detail),
        };
        write!(f, "{context}: {detail}")
    }
}

impl std::error::Error for RateLimitMiddlewareError {}

// Response header names used to report rate-limit state. As `const` items their
// initializers are evaluated at compile time.
/// `Retry-After`: same value as `X-RateLimit-Reset`, set only on rate-limited (429) responses.
const HEADER_RETRY_AFTER: HeaderName = HeaderName::from_static("retry-after");
/// `X-RateLimit-Limit`: the limiter's `limit` value.
const HEADER_LIMIT: HeaderName = HeaderName::from_static("x-ratelimit-limit");
/// `X-RateLimit-Remaining`: the limiter's `remaining` value.
const HEADER_REMAINING: HeaderName = HeaderName::from_static("x-ratelimit-remaining");
/// `X-RateLimit-Reset`: the limiter's `reset_after` value.
const HEADER_RESET: HeaderName = HeaderName::from_static("x-ratelimit-reset");

/// Behavioural switches for the rate-limiting middleware.
#[derive(Debug, Clone)]
pub struct RateLimitOptions {
    /// When `true`, responses that pass the limit check get `X-RateLimit-*` headers.
    pub include_headers: bool,
    /// When `true`, a limiter backend error lets the request through instead of
    /// answering with a 500.
    pub fail_open: bool,
    /// Optional namespace prepended to every key as `"{prefix}:{key}"`.
    pub key_prefix: Option<String>,
}

impl Default for RateLimitOptions {
    /// Headers enabled, fail open, no key prefix.
    fn default() -> Self {
        let key_prefix = None;
        Self {
            key_prefix,
            fail_open: true,
            include_headers: true,
        }
    }
}

/// Tower [`Layer`] that wraps a service with key-based rate limiting.
#[derive(Clone)]
pub struct RateLimitLayer<K> {
    /// Backend that counts requests per key.
    limiter: Arc<dyn RateLimiter>,
    /// Strategy that turns a request into a rate-limit key.
    key_extractor: Arc<K>,
    /// Behavioural switches copied into every wrapped service.
    options: RateLimitOptions,
    /// Optional metrics sink; `None` disables metric reporting.
    metrics: Option<Arc<tokio::sync::Mutex<dyn crate::metrics::MetricsInterface + Send + Sync>>>,
    /// Name of the rate-limit configuration this layer enforces; passed to metrics
    /// as the limiter type. Defaults to `"unknown"`.
    config_name: String,
}

impl<K> RateLimitLayer<K>
where
    K: KeyExtractor + Clone + Send + Sync + 'static,
{
    /// Creates a layer using [`RateLimitOptions::default`].
    pub fn new(limiter: Arc<dyn RateLimiter>, key_extractor: K) -> Self {
        let defaults = RateLimitOptions::default();
        Self::with_options(limiter, key_extractor, defaults)
    }

    /// Creates a layer with explicit `options`, no metrics, and config name `"unknown"`.
    // NOTE(review): `new` delegates here, so this allow looks unnecessary; kept as-is.
    #[allow(dead_code)]
    pub fn with_options(
        limiter: Arc<dyn RateLimiter>,
        key_extractor: K,
        options: RateLimitOptions,
    ) -> Self {
        let key_extractor = Arc::new(key_extractor);
        Self {
            limiter,
            key_extractor,
            options,
            metrics: None,
            config_name: String::from("unknown"),
        }
    }

    /// Replaces the configuration name reported to metrics.
    pub fn with_config_name(self, config_name: String) -> Self {
        Self {
            config_name,
            ..self
        }
    }

    /// Attaches a metrics sink used to record rate-limit checks and rejections.
    pub fn with_metrics(
        self,
        metrics: Arc<tokio::sync::Mutex<dyn crate::metrics::MetricsInterface + Send + Sync>>,
    ) -> Self {
        Self {
            metrics: Some(metrics),
            ..self
        }
    }
}

impl<S, K> Layer<S> for RateLimitLayer<K>
where
    S: Clone + Send + 'static,
    S: Service<AxumRequest<AxumBody>, Response = AxumResponse>,
    S::Future: Send + 'static,
    K: KeyExtractor + Clone + Send + Sync + 'static,
{
    type Service = RateLimitService<S, K>;

    fn layer(&self, inner: S) -> Self::Service {
        RateLimitService {
            inner,
            limiter: self.limiter.clone(),
            key_extractor: self.key_extractor.clone(),
            options: self.options.clone(),
            metrics: self.metrics.clone(),
            config_name: self.config_name.clone(),
        }
    }
}

/// Service produced by `RateLimitLayer::layer`: consults the rate limiter before
/// forwarding each request to `inner`.
#[derive(Clone)]
pub struct RateLimitService<S, K> {
    // The wrapped service; only called for requests that pass the limit check.
    inner: S,
    // Backend that counts requests per key (shared with the layer).
    limiter: Arc<dyn RateLimiter>,
    // Strategy that turns a request into a rate-limit key (shared with the layer).
    key_extractor: Arc<K>,
    // Header inclusion, fail-open behaviour and optional key prefix.
    options: RateLimitOptions,
    // Optional metrics sink; `None` disables metric reporting.
    metrics: Option<Arc<tokio::sync::Mutex<dyn crate::metrics::MetricsInterface + Send + Sync>>>,
    // Rate-limit configuration name, reported to metrics as the limiter type.
    config_name: String,
}

impl<S, K> Service<AxumRequest<AxumBody>> for RateLimitService<S, K>
where
    S: Clone + Send + 'static,
    S: Service<AxumRequest<AxumBody>, Response = AxumResponse>,
    S::Future: Send + 'static,
    S::Error: IntoResponse + Send,
    K: KeyExtractor + Send + Sync + 'static,
{
    type Response = AxumResponse;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: AxumRequest<AxumBody>) -> Self::Future {
        let limiter = self.limiter.clone();
        let key_extractor = self.key_extractor.clone();
        let options = self.options.clone();
        let metrics = self.metrics.clone();
        let config_name = self.config_name.clone();
        let mut inner = self.inner.clone();

        Box::pin(async move {
            let key = match key_extractor.extract(&req) {
                Ok(k) => k,
                Err(e) => {
                    error!("Failed to extract key for rate limiting: {}", e);
                    return Ok(internal_server_error_response_with_message(
                        "Key extraction failed for rate limiting.",
                    ));
                }
            };

            debug!(key = %key, "Extracted rate limit key");

            let final_key = if let Some(prefix) = &options.key_prefix {
                format!("{prefix}:{key}")
            } else {
                key
            };
            debug!(final_key = %final_key, "Final rate limit key");

            // Use the config name as the primary limiter type
            let primary_limiter_type = &config_name;

            // Track request context for additional granularity
            let path = req.uri().path();
            let request_context = if path.starts_with("/app/") {
                "websocket_upgrade"
            } else if path.starts_with("/apps/") {
                "http_api"
            } else if path.starts_with("/up/") {
                "health_check"
            } else {
                "other"
            };

            // Track rate limit check with config name
            if let Some(ref metrics) = metrics {
                let metrics_locked = metrics.lock().await;
                // Use "global" as app_id for IP-based rate limiting
                metrics_locked.mark_rate_limit_check_with_context(
                    "global",
                    primary_limiter_type,
                    request_context,
                );
            }

            let rate_limit_result = match limiter.increment(&final_key).await {
                Ok(result) => result,
                Err(e) => {
                    error!("Rate limiter backend error for key '{}': {}", final_key, e);
                    if options.fail_open {
                        warn!("{}", "Rate limiter failed open");
                        RateLimitResult {
                            allowed: true,
                            remaining: 0,
                            reset_after: 0,
                            limit: 0,
                        }
                    } else {
                        error!(key = %final_key, "Rate limiter failed closed");
                        return Ok(internal_server_error_response_with_message(
                            "Rate limiter backend unavailable.",
                        ));
                    }
                }
            };

            if !rate_limit_result.allowed {
                debug!(key = %final_key, "Rate limit exceeded for config: {}", config_name);

                // Track rate limit triggered with config name
                if let Some(ref metrics) = metrics {
                    let metrics_locked = metrics.lock().await;
                    metrics_locked.mark_rate_limit_triggered_with_context(
                        "global",
                        primary_limiter_type,
                        request_context,
                    );
                }

                return Ok(rate_limit_error_response(Some(&rate_limit_result)));
            }

            debug!(key = %final_key, "Rate limit check passed");
            let result = inner.call(req).await;

            match result {
                Ok(mut response) => {
                    // Only add headers if include_headers is true AND the limiter didn't fail open with dummy values
                    if options.include_headers && rate_limit_result.limit > 0 {
                        add_rate_limit_headers(response.headers_mut(), &rate_limit_result, false);
                    }
                    Ok(response)
                }
                Err(err) => Err(err),
            }
        })
    }
}

// --- Key Extractors ---

/// Derives the string key under which a request is counted by the rate limiter.
///
/// Implementations must be `Send + Sync`: the middleware keeps them in an `Arc`
/// and calls them from inside request futures.
pub trait KeyExtractor: Send + Sync {
    /// Returns the rate-limit key for `request`, or an error if none can be derived.
    fn extract<Body>(
        &self,
        request: &HyperRequest<Body>,
    ) -> Result<String, RateLimitMiddlewareError>;
}

/// Rate-limit key extractor that keys requests by client IP address.
///
/// `trust_hops` is the number of reverse-proxy hops whose forwarding headers are
/// believed. With `0`, neither `X-Forwarded-For` nor `X-Real-IP` is consulted,
/// because a client can set either header to any value. In that case only the
/// socket peer address from axum's `ConnectInfo<SocketAddr>` extension is used.
#[derive(Clone, Debug)]
pub struct IpKeyExtractor {
    trust_hops: usize,
}

impl IpKeyExtractor {
    /// Creates an extractor that trusts `trust_hops` proxy hops.
    pub fn new(trust_hops: usize) -> Self {
        Self { trust_hops }
    }

    /// Resolves the client IP, in priority order:
    /// 1. the `X-Forwarded-For` entry `trust_hops` positions from the right
    ///    (only when `trust_hops > 0`);
    /// 2. `X-Real-IP` (only when `trust_hops > 0`);
    /// 3. the peer address in `ConnectInfo<SocketAddr>`, when present in the
    ///    request extensions.
    ///
    /// Header values that do not parse as an IP address are skipped. Header-derived
    /// addresses are re-rendered from the parsed `IpAddr`, so they share the canonical
    /// text form of the peer-address fallback. Differently written forms of one IPv6
    /// address therefore map to the same key.
    fn get_ip<B>(&self, req: &HyperRequest<B>) -> Option<String> {
        if self.trust_hops > 0 {
            if let Some(ip) = self.forwarded_for_ip(req) {
                return Some(ip);
            }
            if let Some(ip) = Self::real_ip(req) {
                return Some(ip);
            }
        }

        req.extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map(|ConnectInfo(addr)| addr.ip().to_string())
    }

    /// Selects the client entry from `X-Forwarded-For`, if present and valid.
    fn forwarded_for_ip<B>(&self, req: &HyperRequest<B>) -> Option<String> {
        let header = req.headers().get("x-forwarded-for")?.to_str().ok()?;
        let hops: Vec<&str> = header.split(',').map(str::trim).collect();
        // NOTE(review): assumes each trusted proxy appends the address it received the
        // request from (the usual X-Forwarded-For convention), so the client sits
        // `trust_hops` entries from the right. A header with fewer entries than
        // `trust_hops` falls back to its leftmost entry.
        let index = hops.len().saturating_sub(self.trust_hops);
        hops.get(index).and_then(|candidate| Self::canonical_ip(candidate))
    }

    /// Reads `X-Real-IP`, if present and valid.
    fn real_ip<B>(req: &HyperRequest<B>) -> Option<String> {
        let header = req.headers().get("x-real-ip")?.to_str().ok()?;
        Self::canonical_ip(header.trim())
    }

    /// Parses `candidate` as an IP address and returns its canonical text form.
    fn canonical_ip(candidate: &str) -> Option<String> {
        candidate
            .parse::<std::net::IpAddr>()
            .ok()
            .map(|ip| ip.to_string())
    }
}

/// The default extractor uses `trust_hops == 0`, so `X-Forwarded-For` is ignored.
impl Default for IpKeyExtractor {
    fn default() -> Self {
        Self { trust_hops: 0 }
    }
}

impl KeyExtractor for IpKeyExtractor {
    /// Keys the request by client IP. This never returns `Err`. When no address can
    /// be resolved, the request is counted under the shared `"unknown_ip"` key.
    fn extract<B>(&self, req: &HyperRequest<B>) -> Result<String, RateLimitMiddlewareError> {
        let Some(ip) = self.get_ip(req) else {
            warn!(
                "{}",
                "Could not extract IP address for rate limiting, falling back to 'unknown_ip'"
            );
            return Ok(String::from("unknown_ip"));
        };
        Ok(ip)
    }
}

// --- Helper Functions ---

/// Builds the `429 Too Many Requests` JSON response.
///
/// When `limits` is `Some`, the rate-limit headers (including `Retry-After`) are
/// attached. This function does not consult `RateLimitOptions::include_headers`.
///
/// # Panics
/// Only if serializing the static JSON body or building the response fails,
/// which cannot happen for these fixed inputs.
fn rate_limit_error_response(result: Option<&RateLimitResult>) -> AxumResponse {
    let payload = sonic_rs::to_string(&json!({
        "error": "Too Many Requests",
        "message": "Rate limit exceeded. Please try again later.",
    }))
    .expect("Failed to serialize rate limit error response");

    let mut response = axum::response::Response::builder()
        .status(StatusCode::TOO_MANY_REQUESTS)
        .header(axum::http::header::CONTENT_TYPE, "application/json")
        .body(AxumBody::from(payload))
        .expect("Failed to build rate limit error response");

    if let Some(limits) = result {
        add_rate_limit_headers(response.headers_mut(), limits, true);
    }
    response
}

/// Builds a `500 Internal Server Error` JSON response whose `message` field is `message`.
///
/// # Panics
/// Only if JSON serialization or response construction fails.
fn internal_server_error_response_with_message(message: &str) -> AxumResponse {
    let document = json!({
        "error": "Internal Server Error",
        "message": message,
    });
    let serialized = sonic_rs::to_string(&document)
        .expect("Failed to serialize internal server error response");

    axum::response::Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .header(axum::http::header::CONTENT_TYPE, "application/json")
        .body(AxumBody::from(serialized))
        .expect("Failed to build internal server error response")
}

/// Writes `X-RateLimit-Limit`, `X-RateLimit-Remaining` and `X-RateLimit-Reset`
/// into `headers` in that order, replacing any existing values. When
/// `is_rate_limited` is `true`, the reset value is also written as `Retry-After`.
///
/// A value that cannot be converted into a `HeaderValue` is skipped with a warning
/// rather than failing the response.
fn add_rate_limit_headers(
    headers: &mut HeaderMap,
    result: &RateLimitResult,
    is_rate_limited: bool,
) {
    match HeaderValue::try_from(result.limit.to_string()) {
        Ok(limit) => {
            headers.insert(HEADER_LIMIT, limit);
        }
        Err(_) => {
            warn!(
                value = result.limit,
                "Failed to convert rate limit limit value for header X-RateLimit-Limit"
            );
        }
    }

    match HeaderValue::try_from(result.remaining.to_string()) {
        Ok(remaining) => {
            headers.insert(HEADER_REMAINING, remaining);
        }
        Err(_) => {
            warn!(
                value = result.remaining,
                "Failed to convert rate limit remaining value for header X-RateLimit-Remaining"
            );
        }
    }

    match HeaderValue::try_from(result.reset_after.to_string()) {
        Ok(reset) => {
            // The same value serves as `Retry-After` on rejected requests.
            headers.insert(HEADER_RESET, reset.clone());
            if is_rate_limited {
                headers.insert(HEADER_RETRY_AFTER, reset);
            }
        }
        Err(_) => {
            warn!(
                value = result.reset_after,
                "Failed to convert rate limit reset_after value for header X-RateLimit-Reset/Retry-After"
            );
        }
    }
}