enigma_relay/
server.rs

1use std::collections::HashMap;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use actix_web::middleware::Compress;
7use actix_web::web::{self, Data, Json};
8use actix_web::{App, HttpRequest, HttpResponse, HttpServer};
9use base64::Engine;
10use futures::FutureExt;
11use tokio::sync::oneshot;
12
13use crate::config::{RelayConfig, RelayLimits, RelayMode, StorageKind};
14use crate::error::{EnigmaRelayError, Result};
15#[cfg(feature = "metrics")]
16use crate::metrics::RelayMetrics;
17use crate::model::{
18    AckRequest, AckResponse, DeliveryItem, PullRequest, PullResponse, PushRequest, PushResponse,
19};
20use crate::store::{AckItem, DynRelayStore, InboundMessage, QueueMessage};
21use crate::store_mem::MemStore;
22#[cfg(feature = "persistence")]
23use crate::store_sled::SledStore;
24use crate::ttl::{now_millis, run_gc, ShutdownSignal};
25
26pub struct RunningRelay {
27    pub base_url: String,
28    pub shutdown: oneshot::Sender<()>,
29    pub handle: tokio::task::JoinHandle<Result<()>>,
30}
31
32#[derive(Clone)]
33struct AppState {
34    store: DynRelayStore,
35    config: RelayConfig,
36    rate_limiter: RateLimiter,
37    started: Instant,
38    #[cfg(feature = "metrics")]
39    metrics: Arc<RelayMetrics>,
40}
41
/// Identifies one token bucket: a client address paired with a limit label
/// (`"global"` or an endpoint name such as `"push"`).
#[derive(Clone, Eq, PartialEq, Hash)]
struct RateKey {
    /// Peer address the bucket belongs to.
    ip: IpAddr,
    /// Which limit this bucket enforces for that peer.
    label: &'static str,
}
47
/// Token-bucket state for a single [`RateKey`].
#[derive(Clone)]
struct Bucket {
    /// Refill rate in tokens per second.
    rate: f64,
    /// Maximum number of tokens the bucket can hold.
    burst: f64,
    /// Tokens currently available.
    tokens: f64,
    /// Last time `tokens` was refilled.
    last: Instant,
    /// While set and in the future, every request is refused.
    ban_until: Option<Instant>,
}
56
57#[derive(Clone)]
58struct RateLimiter {
59    enabled: bool,
60    ban_seconds: u64,
61    buckets: Arc<parking_lot::Mutex<HashMap<RateKey, Bucket>>>,
62}
63
64impl RateLimiter {
65    fn new(enabled: bool, ban_seconds: u64) -> Self {
66        RateLimiter {
67            enabled,
68            ban_seconds,
69            buckets: Arc::new(parking_lot::Mutex::new(HashMap::new())),
70        }
71    }
72
73    fn allow(&self, ip: IpAddr, label: &'static str, rate: f64, burst: f64) -> bool {
74        if !self.enabled {
75            return true;
76        }
77        let now = Instant::now();
78        let mut guard = self.buckets.lock();
79        let key = RateKey { ip, label };
80        let entry = guard.entry(key).or_insert(Bucket {
81            tokens: burst,
82            last: now,
83            burst,
84            rate,
85            ban_until: None,
86        });
87        if let Some(until) = entry.ban_until {
88            if until > now {
89                return false;
90            }
91            entry.ban_until = None;
92        }
93        let elapsed = now.saturating_duration_since(entry.last).as_secs_f64();
94        entry.tokens = (entry.tokens + elapsed * entry.rate).min(entry.burst);
95        entry.last = now;
96        if entry.tokens >= 1.0 {
97            entry.tokens -= 1.0;
98            true
99        } else {
100            entry.ban_until = Some(now + Duration::from_secs(self.ban_seconds));
101            false
102        }
103    }
104}
105
106pub async fn start(cfg: RelayConfig) -> Result<RunningRelay> {
107    let store: DynRelayStore = match cfg.storage.kind {
108        StorageKind::Memory => Arc::new(MemStore::new()),
109        StorageKind::Sled => {
110            #[cfg(feature = "persistence")]
111            {
112                Arc::new(SledStore::new(&cfg.storage.path)?)
113            }
114            #[cfg(not(feature = "persistence"))]
115            {
116                return Err(EnigmaRelayError::Disabled(
117                    "persistence feature not enabled".to_string(),
118                ));
119            }
120        }
121    };
122    start_with_store(store, cfg).await
123}
124
125pub async fn start_with_store(store: DynRelayStore, cfg: RelayConfig) -> Result<RunningRelay> {
126    cfg.validate()?;
127    let (shutdown_tx, shutdown_rx) = oneshot::channel();
128    let shutdown_signal: ShutdownSignal = shutdown_rx.shared();
129    let rate_limiter = RateLimiter::new(cfg.rate_limit.enabled, cfg.rate_limit.ban_seconds);
130    #[cfg(feature = "metrics")]
131    let metrics = Arc::new(RelayMetrics::default());
132    let state = AppState {
133        store: store.clone(),
134        config: cfg.clone(),
135        rate_limiter: rate_limiter.clone(),
136        started: Instant::now(),
137        #[cfg(feature = "metrics")]
138        metrics: metrics.clone(),
139    };
140    let gc_store = store.clone();
141    let gc_limits = cfg.relay.clone();
142    let gc_signal = shutdown_signal.clone();
143    #[cfg(feature = "metrics")]
144    let gc_task = {
145        let metrics = metrics.clone();
146        tokio::spawn(async move { run_gc(gc_store, gc_limits, gc_signal, Some(metrics)).await })
147    };
148    #[cfg(not(feature = "metrics"))]
149    let gc_task = tokio::spawn(async move { run_gc(gc_store, gc_limits, gc_signal).await });
150    let (server, addr) = build_server(state, &cfg)?;
151    let handle = server.handle();
152    let server_task = tokio::spawn(server);
153    let joined = tokio::spawn(async move {
154        let _ = shutdown_signal.await;
155        handle.stop(true).await;
156        let _ = gc_task.await;
157        let srv = server_task
158            .await
159            .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
160        srv.map_err(|e: std::io::Error| EnigmaRelayError::Internal(e.to_string()))
161    });
162    let scheme = match cfg.mode {
163        RelayMode::Http => "http",
164        RelayMode::Tls => "https",
165    };
166    Ok(RunningRelay {
167        base_url: format!("{}://{}", scheme, addr),
168        shutdown: shutdown_tx,
169        handle: joined,
170    })
171}
172
173fn build_server(
174    state: AppState,
175    cfg: &RelayConfig,
176) -> Result<(actix_web::dev::Server, SocketAddr)> {
177    let state_data = state.clone();
178    let builder = HttpServer::new(move || {
179        App::new()
180            .app_data(Data::new(state_data.clone()))
181            .wrap(Compress::default())
182            .route("/push", web::post().to(push))
183            .route("/pull", web::post().to(pull))
184            .route("/ack", web::post().to(ack))
185            .route("/health", web::get().to(health))
186            .route("/stats", web::get().to(stats))
187    });
188    match cfg.mode {
189        RelayMode::Http => {
190            let listener = std::net::TcpListener::bind(&cfg.address)
191                .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
192            listener
193                .set_nonblocking(true)
194                .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
195            let addr = listener
196                .local_addr()
197                .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
198            let server = builder
199                .listen(listener)
200                .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?
201                .disable_signals()
202                .run();
203            Ok((server, addr))
204        }
205        RelayMode::Tls => {
206            #[cfg(feature = "tls")]
207            {
208                let tls_cfg = cfg
209                    .tls
210                    .clone()
211                    .ok_or_else(|| EnigmaRelayError::Config("missing tls config".to_string()))?;
212                let server_config = build_rustls_config(tls_cfg)?;
213                let listener = std::net::TcpListener::bind(&cfg.address)
214                    .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
215                listener
216                    .set_nonblocking(true)
217                    .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
218                let addr = listener
219                    .local_addr()
220                    .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
221                let server = builder
222                    .listen_rustls_0_23(listener, server_config)
223                    .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?
224                    .disable_signals()
225                    .run();
226                Ok((server, addr))
227            }
228            #[cfg(not(feature = "tls"))]
229            {
230                Err(EnigmaRelayError::Disabled(
231                    "tls feature not enabled".to_string(),
232                ))
233            }
234        }
235    }
236}
237
238async fn push(
239    req: HttpRequest,
240    state: Data<AppState>,
241    body: Json<PushRequest>,
242) -> std::result::Result<HttpResponse, EnigmaRelayError> {
243    enforce_rate(&state, &req, "push")?;
244    #[cfg(feature = "metrics")]
245    state
246        .metrics
247        .push_total
248        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
249    let msg = build_inbound_message(body.into_inner(), &state.config.relay)?;
250    let result = state.store.push(msg, &state.config.relay).await?;
251    #[cfg(feature = "metrics")]
252    {
253        if result.stored {
254            state
255                .metrics
256                .push_stored
257                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
258        }
259        if result.duplicate {
260            state
261                .metrics
262                .duplicates
263                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
264        }
265    }
266    let response = PushResponse {
267        stored: result.stored,
268        duplicate: result.duplicate,
269        queue_len: Some(result.queue_len),
270        queue_bytes: Some(result.queue_bytes),
271    };
272    Ok(HttpResponse::Ok().json(response))
273}
274
275async fn pull(
276    req: HttpRequest,
277    state: Data<AppState>,
278    body: Json<PullRequest>,
279) -> std::result::Result<HttpResponse, EnigmaRelayError> {
280    enforce_rate(&state, &req, "pull")?;
281    #[cfg(feature = "metrics")]
282    state
283        .metrics
284        .pull_total
285        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
286    let pull = body.into_inner();
287    if pull.recipient.trim().is_empty() {
288        return Err(EnigmaRelayError::InvalidInput(
289            "recipient required".to_string(),
290        ));
291    }
292    let max = pull.max.unwrap_or(state.config.relay.pull_batch_max);
293    if max == 0 {
294        return Err(EnigmaRelayError::InvalidInput(
295            "max must be positive".to_string(),
296        ));
297    }
298    let clamped = max.min(state.config.relay.pull_batch_max);
299    let batch = state
300        .store
301        .pull(
302            pull.recipient.as_str(),
303            pull.cursor.clone(),
304            clamped,
305            now_millis(),
306        )
307        .await?;
308    let items: Vec<DeliveryItem> = batch.items.into_iter().map(to_delivery).collect();
309    let response = PullResponse {
310        items,
311        next_cursor: batch.next_cursor,
312        remaining_estimate: batch.remaining_estimate,
313    };
314    Ok(HttpResponse::Ok().json(response))
315}
316
317async fn ack(
318    req: HttpRequest,
319    state: Data<AppState>,
320    body: Json<AckRequest>,
321) -> std::result::Result<HttpResponse, EnigmaRelayError> {
322    enforce_rate(&state, &req, "ack")?;
323    #[cfg(feature = "metrics")]
324    state
325        .metrics
326        .ack_total
327        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
328    let ack = body.into_inner();
329    if ack.recipient.trim().is_empty() {
330        return Err(EnigmaRelayError::InvalidInput(
331            "recipient required".to_string(),
332        ));
333    }
334    let mut items = Vec::new();
335    for entry in ack.ack {
336        items.push(AckItem {
337            message_id: entry.message_id,
338            chunk_index: entry.chunk_index,
339        });
340    }
341    let outcome = state.store.ack(ack.recipient.as_str(), items).await?;
342    #[cfg(feature = "metrics")]
343    state
344        .metrics
345        .ack_deleted
346        .fetch_add(outcome.deleted, std::sync::atomic::Ordering::Relaxed);
347    let response = AckResponse {
348        deleted: outcome.deleted,
349        missing: outcome.missing,
350        remaining: outcome.remaining,
351    };
352    Ok(HttpResponse::Ok().json(response))
353}
354
355async fn health() -> std::result::Result<HttpResponse, EnigmaRelayError> {
356    Ok(HttpResponse::Ok().json(serde_json::json!({"status": "ok"})))
357}
358
359async fn stats(state: Data<AppState>) -> std::result::Result<HttpResponse, EnigmaRelayError> {
360    #[cfg(feature = "metrics")]
361    {
362        let snapshot = state.metrics.snapshot();
363        let uptime_ms = state.started.elapsed().as_millis() as u64;
364        return Ok(HttpResponse::Ok().json(serde_json::json!({
365            "status": "ok",
366            "uptime_ms": uptime_ms,
367            "metrics": snapshot
368        })));
369    }
370    #[cfg(not(feature = "metrics"))]
371    {
372        let uptime_ms = state.started.elapsed().as_millis() as u64;
373        Ok(HttpResponse::Ok().json(serde_json::json!({
374            "status": "ok",
375            "uptime_ms": uptime_ms
376        })))
377    }
378}
379
380fn enforce_rate(state: &AppState, req: &HttpRequest, label: &'static str) -> Result<()> {
381    if !state.rate_limiter.enabled {
382        return Ok(());
383    }
384    let cfg = &state.config.rate_limit;
385    let ip = peer_ip(req);
386    let burst = cfg.burst as f64;
387    let global_ok = state
388        .rate_limiter
389        .allow(ip, "global", cfg.per_ip_rps as f64, burst);
390    let endpoint_rps = match label {
391        "push" => cfg.endpoints.push_rps,
392        "pull" => cfg.endpoints.pull_rps,
393        "ack" => cfg.endpoints.ack_rps,
394        _ => cfg.per_ip_rps,
395    };
396    let endpoint_ok = state
397        .rate_limiter
398        .allow(ip, label, endpoint_rps as f64, burst);
399    if global_ok && endpoint_ok {
400        return Ok(());
401    }
402    #[cfg(feature = "metrics")]
403    state
404        .metrics
405        .rate_limited
406        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
407    Err(EnigmaRelayError::RateLimited)
408}
409
410fn peer_ip(req: &HttpRequest) -> IpAddr {
411    req.peer_addr()
412        .map(|s| s.ip())
413        .unwrap_or(Ipv4Addr::LOCALHOST.into())
414}
415
416fn build_inbound_message(body: PushRequest, limits: &RelayLimits) -> Result<InboundMessage> {
417    if body.recipient.trim().is_empty() {
418        return Err(EnigmaRelayError::InvalidInput(
419            "recipient required".to_string(),
420        ));
421    }
422    if body.meta.chunk_count == 0 {
423        return Err(EnigmaRelayError::InvalidInput(
424            "chunk_count must be positive".to_string(),
425        ));
426    }
427    if body.meta.chunk_index >= body.meta.chunk_count {
428        return Err(EnigmaRelayError::InvalidInput(
429            "chunk_index out of range".to_string(),
430        ));
431    }
432    if body.meta.kind.trim().is_empty() {
433        return Err(EnigmaRelayError::InvalidInput("kind required".to_string()));
434    }
435    let decoded = base64::engine::general_purpose::STANDARD
436        .decode(body.ciphertext_b64.as_bytes())
437        .map_err(|_| EnigmaRelayError::InvalidInput("invalid base64".to_string()))?;
438    let payload_bytes = decoded.len() as u64;
439    if payload_bytes > limits.max_message_bytes {
440        return Err(EnigmaRelayError::InvalidInput(
441            "message too large".to_string(),
442        ));
443    }
444    let arrival_ms = now_millis();
445    let ttl_ms = limits.message_ttl_seconds.saturating_mul(1000);
446    let deadline_ms = arrival_ms.saturating_add(ttl_ms);
447    Ok(InboundMessage {
448        recipient: body.recipient,
449        message_id: body.message_id,
450        ciphertext_b64: body.ciphertext_b64,
451        meta: body.meta,
452        payload_bytes,
453        arrival_ms,
454        deadline_ms,
455    })
456}
457
458fn to_delivery(msg: QueueMessage) -> DeliveryItem {
459    DeliveryItem {
460        recipient: msg.recipient,
461        message_id: msg.message_id,
462        ciphertext_b64: msg.ciphertext_b64,
463        meta: msg.meta,
464        arrival_ms: msg.arrival_ms,
465        deadline_ms: msg.deadline_ms,
466    }
467}
468
/// Builds a rustls server configuration from PEM files.
///
/// The certificate chain is read from `cert_pem_path`. The first private key
/// found in `key_pem_path` is used, whether it is PKCS#8, PKCS#1 (RSA) or SEC1
/// (EC). When `client_ca_pem_path` is set, client certificates are required
/// and verified against those CAs; this needs the `mtls` feature.
///
/// # Errors
/// Returns `EnigmaRelayError::Tls` in these cases:
/// - a file cannot be opened (the message names the path);
/// - PEM parsing fails;
/// - no certificate or no key is found;
/// - rustls rejects the resulting configuration.
///
/// Returns `EnigmaRelayError::Disabled` if a client CA is configured without
/// the `mtls` feature.
#[cfg(feature = "tls")]
fn build_rustls_config(tls: crate::config::TlsConfig) -> Result<rustls::ServerConfig> {
    use rustls::pki_types::CertificateDer;
    #[cfg(feature = "mtls")]
    use rustls::RootCertStore;
    use rustls_pemfile::{certs, private_key};
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    // Maps any displayable failure into the crate's TLS error variant.
    fn tls_err(e: impl std::fmt::Display) -> EnigmaRelayError {
        EnigmaRelayError::Tls(e.to_string())
    }

    // Opens a PEM file for buffered reading, naming the path on failure.
    fn open_pem(path: impl AsRef<Path>) -> Result<BufReader<File>> {
        let path = path.as_ref();
        File::open(path)
            .map(BufReader::new)
            .map_err(|e| tls_err(format!("{}: {}", path.display(), e)))
    }

    let mut cert_reader = open_pem(&tls.cert_pem_path)?;
    let cert_chain: Vec<CertificateDer<'static>> = certs(&mut cert_reader)
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(tls_err)?;
    if cert_chain.is_empty() {
        return Err(tls_err("no certificates found"));
    }

    // `private_key` accepts PKCS#8, PKCS#1 and SEC1 PEM sections and returns
    // the first one; the previous code only understood PKCS#8.
    let mut key_reader = open_pem(&tls.key_pem_path)?;
    let key = private_key(&mut key_reader)
        .map_err(tls_err)?
        .ok_or_else(|| tls_err("no private key found"))?;

    let builder = rustls::ServerConfig::builder();
    let builder = match tls.client_ca_pem_path.as_ref() {
        None => builder.with_no_client_auth(),
        Some(ca_path) => {
            #[cfg(feature = "mtls")]
            {
                let mut ca_reader = open_pem(ca_path)?;
                let mut roots = RootCertStore::empty();
                for ca in certs(&mut ca_reader) {
                    roots.add(ca.map_err(tls_err)?).map_err(tls_err)?;
                }
                let verifier = rustls::server::WebPkiClientVerifier::builder(roots.into())
                    .build()
                    .map_err(tls_err)?;
                builder.with_client_cert_verifier(verifier)
            }
            #[cfg(not(feature = "mtls"))]
            {
                let _ = ca_path;
                return Err(EnigmaRelayError::Disabled(
                    "mtls feature not enabled".to_string(),
                ));
            }
        }
    };
    builder.with_single_cert(cert_chain, key).map_err(tls_err)
}