1use std::collections::HashMap;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use actix_web::middleware::Compress;
7use actix_web::web::{self, Data, Json};
8use actix_web::{App, HttpRequest, HttpResponse, HttpServer};
9use base64::Engine;
10use futures::FutureExt;
11use tokio::sync::oneshot;
12
13use crate::config::{RelayConfig, RelayLimits, RelayMode, StorageKind};
14use crate::error::{EnigmaRelayError, Result};
15#[cfg(feature = "metrics")]
16use crate::metrics::RelayMetrics;
17use crate::model::{
18 AckRequest, AckResponse, DeliveryItem, PullRequest, PullResponse, PushRequest, PushResponse,
19};
20use crate::store::{AckItem, DynRelayStore, InboundMessage, QueueMessage};
21use crate::store_mem::MemStore;
22#[cfg(feature = "persistence")]
23use crate::store_sled::SledStore;
24use crate::ttl::{now_millis, run_gc, ShutdownSignal};
25
26pub struct RunningRelay {
27 pub base_url: String,
28 pub shutdown: oneshot::Sender<()>,
29 pub handle: tokio::task::JoinHandle<Result<()>>,
30}
31
32#[derive(Clone)]
33struct AppState {
34 store: DynRelayStore,
35 config: RelayConfig,
36 rate_limiter: RateLimiter,
37 started: Instant,
38 #[cfg(feature = "metrics")]
39 metrics: Arc<RelayMetrics>,
40}
41
/// Key of the rate-limiter table: one bucket per client IP and bucket label.
#[derive(Clone, Hash, PartialEq, Eq)]
struct RateKey {
    /// Client address the bucket belongs to.
    ip: IpAddr,
    /// Bucket name, e.g. `"global"` or an endpoint label such as `"push"`.
    label: &'static str,
}
47
/// Token-bucket state for a single [`RateKey`].
#[derive(Clone)]
struct Bucket {
    /// Tokens currently available; admitting a request costs one whole token.
    tokens: f64,
    /// Instant of the most recent refill.
    last: Instant,
    /// Upper bound the refill is clamped to.
    burst: f64,
    /// Refill speed, in tokens per second.
    rate: f64,
    /// While set and still in the future, every request for this key is refused.
    ban_until: Option<Instant>,
}
56
57#[derive(Clone)]
58struct RateLimiter {
59 enabled: bool,
60 ban_seconds: u64,
61 buckets: Arc<parking_lot::Mutex<HashMap<RateKey, Bucket>>>,
62}
63
64impl RateLimiter {
65 fn new(enabled: bool, ban_seconds: u64) -> Self {
66 RateLimiter {
67 enabled,
68 ban_seconds,
69 buckets: Arc::new(parking_lot::Mutex::new(HashMap::new())),
70 }
71 }
72
73 fn allow(&self, ip: IpAddr, label: &'static str, rate: f64, burst: f64) -> bool {
74 if !self.enabled {
75 return true;
76 }
77 let now = Instant::now();
78 let mut guard = self.buckets.lock();
79 let key = RateKey { ip, label };
80 let entry = guard.entry(key).or_insert(Bucket {
81 tokens: burst,
82 last: now,
83 burst,
84 rate,
85 ban_until: None,
86 });
87 if let Some(until) = entry.ban_until {
88 if until > now {
89 return false;
90 }
91 entry.ban_until = None;
92 }
93 let elapsed = now.saturating_duration_since(entry.last).as_secs_f64();
94 entry.tokens = (entry.tokens + elapsed * entry.rate).min(entry.burst);
95 entry.last = now;
96 if entry.tokens >= 1.0 {
97 entry.tokens -= 1.0;
98 true
99 } else {
100 entry.ban_until = Some(now + Duration::from_secs(self.ban_seconds));
101 false
102 }
103 }
104}
105
106pub async fn start(cfg: RelayConfig) -> Result<RunningRelay> {
107 let store: DynRelayStore = match cfg.storage.kind {
108 StorageKind::Memory => Arc::new(MemStore::new()),
109 StorageKind::Sled => {
110 #[cfg(feature = "persistence")]
111 {
112 Arc::new(SledStore::new(&cfg.storage.path)?)
113 }
114 #[cfg(not(feature = "persistence"))]
115 {
116 return Err(EnigmaRelayError::Disabled(
117 "persistence feature not enabled".to_string(),
118 ));
119 }
120 }
121 };
122 start_with_store(store, cfg).await
123}
124
125pub async fn start_with_store(store: DynRelayStore, cfg: RelayConfig) -> Result<RunningRelay> {
126 cfg.validate()?;
127 let (shutdown_tx, shutdown_rx) = oneshot::channel();
128 let shutdown_signal: ShutdownSignal = shutdown_rx.shared();
129 let rate_limiter = RateLimiter::new(cfg.rate_limit.enabled, cfg.rate_limit.ban_seconds);
130 #[cfg(feature = "metrics")]
131 let metrics = Arc::new(RelayMetrics::default());
132 let state = AppState {
133 store: store.clone(),
134 config: cfg.clone(),
135 rate_limiter: rate_limiter.clone(),
136 started: Instant::now(),
137 #[cfg(feature = "metrics")]
138 metrics: metrics.clone(),
139 };
140 let gc_store = store.clone();
141 let gc_limits = cfg.relay.clone();
142 let gc_signal = shutdown_signal.clone();
143 #[cfg(feature = "metrics")]
144 let gc_task = {
145 let metrics = metrics.clone();
146 tokio::spawn(async move { run_gc(gc_store, gc_limits, gc_signal, Some(metrics)).await })
147 };
148 #[cfg(not(feature = "metrics"))]
149 let gc_task = tokio::spawn(async move { run_gc(gc_store, gc_limits, gc_signal).await });
150 let (server, addr) = build_server(state, &cfg)?;
151 let handle = server.handle();
152 let server_task = tokio::spawn(server);
153 let joined = tokio::spawn(async move {
154 let _ = shutdown_signal.await;
155 handle.stop(true).await;
156 let _ = gc_task.await;
157 let srv = server_task
158 .await
159 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
160 srv.map_err(|e: std::io::Error| EnigmaRelayError::Internal(e.to_string()))
161 });
162 let scheme = match cfg.mode {
163 RelayMode::Http => "http",
164 RelayMode::Tls => "https",
165 };
166 Ok(RunningRelay {
167 base_url: format!("{}://{}", scheme, addr),
168 shutdown: shutdown_tx,
169 handle: joined,
170 })
171}
172
173fn build_server(
174 state: AppState,
175 cfg: &RelayConfig,
176) -> Result<(actix_web::dev::Server, SocketAddr)> {
177 let state_data = state.clone();
178 let builder = HttpServer::new(move || {
179 App::new()
180 .app_data(Data::new(state_data.clone()))
181 .wrap(Compress::default())
182 .route("/push", web::post().to(push))
183 .route("/pull", web::post().to(pull))
184 .route("/ack", web::post().to(ack))
185 .route("/health", web::get().to(health))
186 .route("/stats", web::get().to(stats))
187 });
188 match cfg.mode {
189 RelayMode::Http => {
190 let listener = std::net::TcpListener::bind(&cfg.address)
191 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
192 listener
193 .set_nonblocking(true)
194 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
195 let addr = listener
196 .local_addr()
197 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
198 let server = builder
199 .listen(listener)
200 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?
201 .disable_signals()
202 .run();
203 Ok((server, addr))
204 }
205 RelayMode::Tls => {
206 #[cfg(feature = "tls")]
207 {
208 let tls_cfg = cfg
209 .tls
210 .clone()
211 .ok_or_else(|| EnigmaRelayError::Config("missing tls config".to_string()))?;
212 let server_config = build_rustls_config(tls_cfg)?;
213 let listener = std::net::TcpListener::bind(&cfg.address)
214 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
215 listener
216 .set_nonblocking(true)
217 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
218 let addr = listener
219 .local_addr()
220 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?;
221 let server = builder
222 .listen_rustls_0_23(listener, server_config)
223 .map_err(|e| EnigmaRelayError::Internal(e.to_string()))?
224 .disable_signals()
225 .run();
226 Ok((server, addr))
227 }
228 #[cfg(not(feature = "tls"))]
229 {
230 Err(EnigmaRelayError::Disabled(
231 "tls feature not enabled".to_string(),
232 ))
233 }
234 }
235 }
236}
237
238async fn push(
239 req: HttpRequest,
240 state: Data<AppState>,
241 body: Json<PushRequest>,
242) -> std::result::Result<HttpResponse, EnigmaRelayError> {
243 enforce_rate(&state, &req, "push")?;
244 #[cfg(feature = "metrics")]
245 state
246 .metrics
247 .push_total
248 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
249 let msg = build_inbound_message(body.into_inner(), &state.config.relay)?;
250 let result = state.store.push(msg, &state.config.relay).await?;
251 #[cfg(feature = "metrics")]
252 {
253 if result.stored {
254 state
255 .metrics
256 .push_stored
257 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
258 }
259 if result.duplicate {
260 state
261 .metrics
262 .duplicates
263 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
264 }
265 }
266 let response = PushResponse {
267 stored: result.stored,
268 duplicate: result.duplicate,
269 queue_len: Some(result.queue_len),
270 queue_bytes: Some(result.queue_bytes),
271 };
272 Ok(HttpResponse::Ok().json(response))
273}
274
275async fn pull(
276 req: HttpRequest,
277 state: Data<AppState>,
278 body: Json<PullRequest>,
279) -> std::result::Result<HttpResponse, EnigmaRelayError> {
280 enforce_rate(&state, &req, "pull")?;
281 #[cfg(feature = "metrics")]
282 state
283 .metrics
284 .pull_total
285 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
286 let pull = body.into_inner();
287 if pull.recipient.trim().is_empty() {
288 return Err(EnigmaRelayError::InvalidInput(
289 "recipient required".to_string(),
290 ));
291 }
292 let max = pull.max.unwrap_or(state.config.relay.pull_batch_max);
293 if max == 0 {
294 return Err(EnigmaRelayError::InvalidInput(
295 "max must be positive".to_string(),
296 ));
297 }
298 let clamped = max.min(state.config.relay.pull_batch_max);
299 let batch = state
300 .store
301 .pull(
302 pull.recipient.as_str(),
303 pull.cursor.clone(),
304 clamped,
305 now_millis(),
306 )
307 .await?;
308 let items: Vec<DeliveryItem> = batch.items.into_iter().map(to_delivery).collect();
309 let response = PullResponse {
310 items,
311 next_cursor: batch.next_cursor,
312 remaining_estimate: batch.remaining_estimate,
313 };
314 Ok(HttpResponse::Ok().json(response))
315}
316
317async fn ack(
318 req: HttpRequest,
319 state: Data<AppState>,
320 body: Json<AckRequest>,
321) -> std::result::Result<HttpResponse, EnigmaRelayError> {
322 enforce_rate(&state, &req, "ack")?;
323 #[cfg(feature = "metrics")]
324 state
325 .metrics
326 .ack_total
327 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
328 let ack = body.into_inner();
329 if ack.recipient.trim().is_empty() {
330 return Err(EnigmaRelayError::InvalidInput(
331 "recipient required".to_string(),
332 ));
333 }
334 let mut items = Vec::new();
335 for entry in ack.ack {
336 items.push(AckItem {
337 message_id: entry.message_id,
338 chunk_index: entry.chunk_index,
339 });
340 }
341 let outcome = state.store.ack(ack.recipient.as_str(), items).await?;
342 #[cfg(feature = "metrics")]
343 state
344 .metrics
345 .ack_deleted
346 .fetch_add(outcome.deleted, std::sync::atomic::Ordering::Relaxed);
347 let response = AckResponse {
348 deleted: outcome.deleted,
349 missing: outcome.missing,
350 remaining: outcome.remaining,
351 };
352 Ok(HttpResponse::Ok().json(response))
353}
354
355async fn health() -> std::result::Result<HttpResponse, EnigmaRelayError> {
356 Ok(HttpResponse::Ok().json(serde_json::json!({"status": "ok"})))
357}
358
359async fn stats(state: Data<AppState>) -> std::result::Result<HttpResponse, EnigmaRelayError> {
360 #[cfg(feature = "metrics")]
361 {
362 let snapshot = state.metrics.snapshot();
363 let uptime_ms = state.started.elapsed().as_millis() as u64;
364 return Ok(HttpResponse::Ok().json(serde_json::json!({
365 "status": "ok",
366 "uptime_ms": uptime_ms,
367 "metrics": snapshot
368 })));
369 }
370 #[cfg(not(feature = "metrics"))]
371 {
372 let uptime_ms = state.started.elapsed().as_millis() as u64;
373 Ok(HttpResponse::Ok().json(serde_json::json!({
374 "status": "ok",
375 "uptime_ms": uptime_ms
376 })))
377 }
378}
379
380fn enforce_rate(state: &AppState, req: &HttpRequest, label: &'static str) -> Result<()> {
381 if !state.rate_limiter.enabled {
382 return Ok(());
383 }
384 let cfg = &state.config.rate_limit;
385 let ip = peer_ip(req);
386 let burst = cfg.burst as f64;
387 let global_ok = state
388 .rate_limiter
389 .allow(ip, "global", cfg.per_ip_rps as f64, burst);
390 let endpoint_rps = match label {
391 "push" => cfg.endpoints.push_rps,
392 "pull" => cfg.endpoints.pull_rps,
393 "ack" => cfg.endpoints.ack_rps,
394 _ => cfg.per_ip_rps,
395 };
396 let endpoint_ok = state
397 .rate_limiter
398 .allow(ip, label, endpoint_rps as f64, burst);
399 if global_ok && endpoint_ok {
400 return Ok(());
401 }
402 #[cfg(feature = "metrics")]
403 state
404 .metrics
405 .rate_limited
406 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
407 Err(EnigmaRelayError::RateLimited)
408}
409
410fn peer_ip(req: &HttpRequest) -> IpAddr {
411 req.peer_addr()
412 .map(|s| s.ip())
413 .unwrap_or(Ipv4Addr::LOCALHOST.into())
414}
415
416fn build_inbound_message(body: PushRequest, limits: &RelayLimits) -> Result<InboundMessage> {
417 if body.recipient.trim().is_empty() {
418 return Err(EnigmaRelayError::InvalidInput(
419 "recipient required".to_string(),
420 ));
421 }
422 if body.meta.chunk_count == 0 {
423 return Err(EnigmaRelayError::InvalidInput(
424 "chunk_count must be positive".to_string(),
425 ));
426 }
427 if body.meta.chunk_index >= body.meta.chunk_count {
428 return Err(EnigmaRelayError::InvalidInput(
429 "chunk_index out of range".to_string(),
430 ));
431 }
432 if body.meta.kind.trim().is_empty() {
433 return Err(EnigmaRelayError::InvalidInput("kind required".to_string()));
434 }
435 let decoded = base64::engine::general_purpose::STANDARD
436 .decode(body.ciphertext_b64.as_bytes())
437 .map_err(|_| EnigmaRelayError::InvalidInput("invalid base64".to_string()))?;
438 let payload_bytes = decoded.len() as u64;
439 if payload_bytes > limits.max_message_bytes {
440 return Err(EnigmaRelayError::InvalidInput(
441 "message too large".to_string(),
442 ));
443 }
444 let arrival_ms = now_millis();
445 let ttl_ms = limits.message_ttl_seconds.saturating_mul(1000);
446 let deadline_ms = arrival_ms.saturating_add(ttl_ms);
447 Ok(InboundMessage {
448 recipient: body.recipient,
449 message_id: body.message_id,
450 ciphertext_b64: body.ciphertext_b64,
451 meta: body.meta,
452 payload_bytes,
453 arrival_ms,
454 deadline_ms,
455 })
456}
457
458fn to_delivery(msg: QueueMessage) -> DeliveryItem {
459 DeliveryItem {
460 recipient: msg.recipient,
461 message_id: msg.message_id,
462 ciphertext_b64: msg.ciphertext_b64,
463 meta: msg.meta,
464 arrival_ms: msg.arrival_ms,
465 deadline_ms: msg.deadline_ms,
466 }
467}
468
#[cfg(feature = "tls")]
/// Build the rustls server configuration from PEM files.
///
/// Reads the certificate chain from `cert_pem_path` and the first private key
/// (PKCS#8, PKCS#1/RSA or SEC1/EC) from `key_pem_path`. With the `mtls` feature
/// and `client_ca_pem_path` set, client certificates are required and checked
/// against that CA bundle.
///
/// # Errors
/// `Tls` for unreadable or malformed files, an empty chain, a missing key, or a
/// rejected CA/certificate. `Disabled` when a client CA is configured without
/// the `mtls` feature.
fn build_rustls_config(tls: crate::config::TlsConfig) -> Result<rustls::ServerConfig> {
    use rustls::pki_types::CertificateDer;
    use rustls_pemfile::{certs, private_key};
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    /// Wrap any displayable error as a TLS setup error.
    fn tls_err<E: std::fmt::Display>(e: E) -> EnigmaRelayError {
        EnigmaRelayError::Tls(e.to_string())
    }

    /// Open a PEM file for buffered reading.
    fn open_pem(path: impl AsRef<Path>) -> Result<BufReader<File>> {
        File::open(path).map(BufReader::new).map_err(tls_err)
    }

    let mut cert_reader = open_pem(&tls.cert_pem_path)?;
    let mut key_reader = open_pem(&tls.key_pem_path)?;
    let cert_chain: Vec<CertificateDer<'static>> = certs(&mut cert_reader)
        .collect::<std::result::Result<_, _>>()
        .map_err(tls_err)?;
    if cert_chain.is_empty() {
        return Err(EnigmaRelayError::Tls("no certificate found".to_string()));
    }
    // `private_key` accepts every PEM key encoding rustls supports, not just PKCS#8.
    let key = private_key(&mut key_reader)
        .map_err(tls_err)?
        .ok_or_else(|| EnigmaRelayError::Tls("no private key found".to_string()))?;

    let builder = rustls::ServerConfig::builder();
    #[cfg(feature = "mtls")]
    let builder = match tls.client_ca_pem_path.as_ref() {
        Some(ca_path) => {
            let mut ca_reader = open_pem(ca_path)?;
            let mut roots = rustls::RootCertStore::empty();
            for ca in certs(&mut ca_reader) {
                roots.add(ca.map_err(tls_err)?).map_err(tls_err)?;
            }
            let verifier = rustls::server::WebPkiClientVerifier::builder(roots.into())
                .build()
                .map_err(tls_err)?;
            builder.with_client_cert_verifier(verifier)
        }
        None => builder.with_no_client_auth(),
    };
    #[cfg(not(feature = "mtls"))]
    let builder = {
        if tls.client_ca_pem_path.is_some() {
            return Err(EnigmaRelayError::Disabled(
                "mtls feature not enabled".to_string(),
            ));
        }
        builder.with_no_client_auth()
    };
    builder.with_single_cert(cert_chain, key).map_err(tls_err)
}