Skip to main content

payjoin_mailroom/
lib.rs

1#[cfg(feature = "access-control")]
2use axum::extract::connect_info::Connected;
3use axum::extract::State;
4use axum::http::Method;
5use axum::response::{IntoResponse, Response};
6#[cfg(feature = "access-control")]
7use axum::serve::IncomingStream;
8use axum::Router;
9use config::Config;
10use opentelemetry_sdk::metrics::SdkMeterProvider;
11use rand::Rng;
12use tokio_listener::{Listener, SystemOptions, UserOptions};
13use tower::{Service, ServiceBuilder};
14use tracing::info;
15
16use crate::ohttp_relay::SentinelTag;
17
18#[cfg(feature = "access-control")]
19pub mod access_control;
20pub mod cli;
21pub mod config;
22pub mod db;
23pub mod directory;
24pub mod key_config;
25pub mod metrics;
26pub mod middleware;
27pub mod ohttp_relay;
28
29use crate::metrics::MetricsService;
30use crate::middleware::{track_connections, track_metrics};
31
32#[derive(Clone)]
33struct Services {
34    directory: crate::directory::Service<crate::db::DbServiceAdapter>,
35    relay: crate::ohttp_relay::Service,
36    metrics: MetricsService,
37    #[cfg(feature = "access-control")]
38    geoip: Option<std::sync::Arc<access_control::IpFilter>>,
39}
40
41pub async fn serve(config: Config, meter_provider: Option<SdkMeterProvider>) -> anyhow::Result<()> {
42    let sentinel_tag = generate_sentinel_tag();
43
44    #[cfg(feature = "access-control")]
45    let geoip = init_geoip(&config).await?;
46
47    let directory = init_directory(&config, sentinel_tag).await?;
48
49    let services = Services {
50        directory,
51        relay: crate::ohttp_relay::Service::new(sentinel_tag).await,
52        metrics: MetricsService::new(meter_provider),
53        #[cfg(feature = "access-control")]
54        geoip,
55    };
56
57    let app = build_app(services);
58    #[cfg(feature = "access-control")]
59    let app = app.into_make_service_with_connect_info::<middleware::MaybePeerIp>();
60
61    let listener =
62        Listener::bind(&config.listener, &SystemOptions::default(), &UserOptions::default())
63            .await?;
64    info!("Payjoin service listening on {:?}", listener.local_addr());
65    axum::serve(listener, app).await?;
66
67    Ok(())
68}
69
/// Serves payjoin-mailroom with manual TLS configuration.
///
/// Binds to `config.listener` (use port 0 to let the OS assign a free port) and returns
/// the actual bound port together with the handle of the spawned server task.
///
/// If `tls_config` is provided, the server will use TLS for incoming connections;
/// otherwise it serves plain HTTP. The `root_store` is used for outgoing relay
/// connections to the gateway, and `default_gateway` is handed to the relay.
/// No meter provider is installed in this mode.
///
/// # Errors
/// Fails if initialization fails, if `config.listener` is not a TCP socket address,
/// or if binding the listener fails. Errors while serving surface through the handle.
#[cfg(feature = "_manual-tls")]
pub async fn serve_manual_tls(
    config: Config,
    tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
    root_store: rustls::RootCertStore,
    default_gateway: Option<crate::ohttp_relay::GatewayUri>,
) -> anyhow::Result<(u16, tokio::task::JoinHandle<anyhow::Result<()>>)> {
    use std::net::SocketAddr;

    let tag = generate_sentinel_tag();

    #[cfg(feature = "access-control")]
    let geoip = init_geoip(&config).await?;

    let directory = init_directory(&config, tag).await?;
    let relay =
        crate::ohttp_relay::Service::new_with_roots(tag, root_store, default_gateway).await;

    let router = build_app(Services {
        directory,
        relay,
        metrics: MetricsService::new(None),
        #[cfg(feature = "access-control")]
        geoip,
    });

    // The generic listener address must reduce to a plain TCP socket address here.
    let socket_addr = config
        .listener
        .to_string()
        .parse::<SocketAddr>()
        .map_err(|_| anyhow::anyhow!("TLS mode requires a TCP address (e.g., '[::]:8080')"))?;
    let tcp = tokio::net::TcpListener::bind(socket_addr).await?;
    let port = tcp.local_addr()?.port();
    let make_service = router.into_make_service_with_connect_info::<SocketAddr>();

    let handle = match tls_config {
        Some(tls) => {
            info!("Payjoin service listening on port {} with TLS", port);
            tokio::spawn(async move {
                let std_listener = tcp.into_std()?;
                axum_server::from_tcp_rustls(std_listener, tls)?
                    .serve(make_service)
                    .await
                    .map_err(Into::into)
            })
        }
        None => {
            info!("Payjoin service listening on port {} without TLS", port);
            tokio::spawn(async move { axum::serve(tcp, make_service).await.map_err(Into::into) })
        }
    };

    Ok((port, handle))
}
137
/// Serves payjoin-mailroom with ACME-managed TLS certificates.
///
/// Uses `tokio-rustls-acme` to automatically obtain and renew TLS
/// certificates from Let's Encrypt via the TLS-ALPN-01 challenge.
/// `config.storage_dir` is handed to the ACME configuration (presumably as its
/// cache location — confirm in `into_rustls_config`), and a background task drives
/// the ACME state machine, logging every event and error.
///
/// # Errors
/// Fails if `config.acme` is missing, if `config.listener` is not a TCP socket address,
/// if initialization fails, or if the server stops with an error.
#[cfg(feature = "acme")]
pub async fn serve_acme(
    config: Config,
    meter_provider: Option<SdkMeterProvider>,
) -> anyhow::Result<()> {
    use std::net::SocketAddr;
    use std::sync::Arc;

    let Some(acme_config) = config.acme.clone() else {
        return Err(anyhow::anyhow!("ACME configuration is required for serve_acme"));
    };

    let tag = generate_sentinel_tag();

    #[cfg(feature = "access-control")]
    let geoip = init_geoip(&config).await?;

    let directory = init_directory(&config, tag).await?;
    let relay = crate::ohttp_relay::Service::new(tag).await;

    let router = build_app(Services {
        directory,
        relay,
        metrics: MetricsService::new(meter_provider),
        #[cfg(feature = "access-control")]
        geoip,
    });

    let socket_addr = config
        .listener
        .to_string()
        .parse::<SocketAddr>()
        .map_err(|_| anyhow::anyhow!("ACME mode requires a TCP address (e.g., '[::]:443')"))?;

    let acme = acme_config.into_rustls_config(&config.storage_dir);
    let mut acme_state = acme.state();
    let server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(acme_state.resolver());
    let acceptor = acme_state.axum_acceptor(Arc::new(server_config));

    // Drive ACME cert issuance/renewal in the background until the event stream ends.
    tokio::spawn(async move {
        use tokio_stream::StreamExt;
        while let Some(event) = acme_state.next().await {
            match event {
                Ok(ok) => info!("ACME event: {:?}", ok),
                Err(err) => tracing::error!("ACME error: {:?}", err),
            }
        }
    });

    info!("Payjoin service listening on {} with ACME TLS", socket_addr);
    axum_server::bind(socket_addr)
        .acceptor(acceptor)
        .serve(router.into_make_service_with_connect_info::<SocketAddr>())
        .await?;
    Ok(())
}
203
204/// Generate random sentinel tag at startup.
205/// The relay and directory share this tag in a best-effort attempt
206/// at detecting self loops.
207fn generate_sentinel_tag() -> SentinelTag { SentinelTag::new(rand::thread_rng().gen()) }
208
#[cfg(feature = "access-control")]
impl Connected<IncomingStream<'_, Listener>> for middleware::MaybePeerIp {
    /// Records the peer IP for TCP connections; every other listener kind yields `None`.
    fn connect_info(stream: IncomingStream<'_, Listener>) -> Self {
        if let tokio_listener::SomeSocketAddr::Tcp(addr) = stream.remote_addr() {
            Self(Some(addr.ip()))
        } else {
            Self(None)
        }
    }
}
219
220async fn init_directory(
221    config: &Config,
222    sentinel_tag: SentinelTag,
223) -> anyhow::Result<crate::directory::Service<crate::db::DbServiceAdapter>> {
224    let files_db = crate::db::FilesDb::init(config.timeout, config.storage_dir.clone()).await?;
225    files_db.spawn_background_prune().await;
226    let db = crate::db::DbServiceAdapter::new(files_db);
227
228    let ohttp_keys_dir = config.storage_dir.join("ohttp-keys");
229    let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?;
230
231    let v1 = if config.v1.is_some() {
232        #[cfg(feature = "access-control")]
233        let blocked = init_blocked_addresses(config).await?;
234        #[cfg(not(feature = "access-control"))]
235        let blocked = None;
236        Some(crate::directory::V1::new(blocked))
237    } else {
238        None
239    };
240    Ok(crate::directory::Service::new(db, ohttp_config.into(), sentinel_tag, v1))
241}
242
/// Builds the GeoIP-based IP filter when `config.access_control` is set.
///
/// Returns `Ok(None)` when access control is not configured.
///
/// # Errors
/// Propagates any failure from `IpFilter::from_config`.
#[cfg(feature = "access-control")]
async fn init_geoip(
    config: &Config,
) -> anyhow::Result<Option<std::sync::Arc<access_control::IpFilter>>> {
    let Some(ac_config) = &config.access_control else {
        return Ok(None);
    };

    let filter = access_control::IpFilter::from_config(ac_config, &config.storage_dir).await?;
    info!("GeoIP access control enabled");
    Ok(Some(std::sync::Arc::new(filter)))
}
256
/// Loads the v1 blocked-address list from the configured file and/or URL.
///
/// Returns `Ok(None)` when v1 is disabled, or when neither `blocked_addresses_path`
/// nor `blocked_addresses_url` is configured.
///
/// When a path is configured, its contents seed the list. When a URL is configured:
/// - the list is updated from an initial, time-bounded fetch;
/// - the fetched body is written to `blocked_addresses_cache.txt` under `storage_dir`;
/// - on fetch failure, the list is loaded from that cache file instead;
/// - a background updater is spawned to refresh the list every
///   `blocked_addresses_refresh_secs` (default 86400 s).
///
/// # Errors
/// Returns an error if the configured `blocked_addresses_path` cannot be read.
/// Network and cache failures are logged and tolerated so the service can still start.
#[cfg(feature = "access-control")]
async fn init_blocked_addresses(
    config: &Config,
) -> anyhow::Result<Option<crate::directory::BlockedAddresses>> {
    // Refresh interval used when `blocked_addresses_refresh_secs` is unset (one day).
    const DEFAULT_REFRESH_SECS: u64 = 86_400;
    // Upper bound on the startup fetch (connect through body read). `reqwest::get`
    // has no timeout, so an unresponsive server would otherwise stall startup forever.
    const INITIAL_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    let v1_config = match &config.v1 {
        Some(c) => c,
        None => return Ok(None),
    };

    // Neither file nor URL configured
    if v1_config.blocked_addresses_path.is_none() && v1_config.blocked_addresses_url.is_none() {
        return Ok(None);
    }

    // Load initial addresses from file if available
    let blocked = match &v1_config.blocked_addresses_path {
        Some(path) => {
            let text = access_control::load_blocked_address_text(path)?;
            let ba = crate::directory::BlockedAddresses::from_address_lines(&text);
            info!("Loaded blocked addresses from {}", path.display());
            ba
        }
        None => crate::directory::BlockedAddresses::empty(),
    };

    // If URL configured, try initial fetch and spawn background updater
    if let Some(url) = &v1_config.blocked_addresses_url {
        let cache_path = config.storage_dir.join("blocked_addresses_cache.txt");
        let refresh = std::time::Duration::from_secs(
            v1_config.blocked_addresses_refresh_secs.unwrap_or(DEFAULT_REFRESH_SECS),
        );

        // Any failure, including failing to build the client, falls back to the cache
        // (`reqwest::get` also surfaced client-build errors as a request error).
        let fetched = match reqwest::Client::builder().timeout(INITIAL_FETCH_TIMEOUT).build() {
            Ok(client) => client.get(url).send().await.and_then(|r| r.error_for_status()),
            Err(e) => Err(e),
        };

        // Try initial fetch; fall back to cache on failure
        match fetched {
            Ok(resp) => match resp.text().await {
                Ok(body) => {
                    if let Err(e) = std::fs::write(&cache_path, &body) {
                        tracing::warn!("Failed to write address cache: {e}");
                    }
                    let count = blocked.update_from_lines(&body).await;
                    info!("Fetched {count} blocked addresses from URL");
                }
                Err(e) => {
                    tracing::warn!("Failed to read address list response: {e}");
                    load_address_cache(&cache_path, &blocked).await;
                }
            },
            Err(e) => {
                tracing::warn!("Failed to fetch address list: {e}");
                load_address_cache(&cache_path, &blocked).await;
            }
        }

        access_control::spawn_address_list_updater(
            url.clone(),
            refresh,
            cache_path,
            blocked.clone(),
        );
    }

    Ok(Some(blocked))
}
320
/// Best-effort reload of the blocked-address list from the on-disk cache.
///
/// Does nothing if `cache_path` does not exist. A read failure is logged as a
/// warning instead of being propagated.
#[cfg(feature = "access-control")]
async fn load_address_cache(
    cache_path: &std::path::Path,
    blocked: &crate::directory::BlockedAddresses,
) {
    if !cache_path.exists() {
        return;
    }

    let text = match access_control::load_blocked_address_text(cache_path) {
        Ok(text) => text,
        Err(e) => {
            tracing::warn!("Failed to load address cache: {e}");
            return;
        }
    };

    let count = blocked.update_from_lines(&text).await;
    info!("Loaded {count} blocked addresses from cache");
}
336
337fn init_ohttp_config(
338    ohttp_keys_dir: &std::path::Path,
339) -> anyhow::Result<crate::key_config::ServerKeyConfig> {
340    std::fs::create_dir_all(ohttp_keys_dir)?;
341    match crate::key_config::read_server_config(ohttp_keys_dir) {
342        Ok(config) => Ok(config),
343        Err(_) => {
344            let config = crate::key_config::gen_ohttp_server_config()?;
345            crate::key_config::persist_new_key_config(config.clone(), ohttp_keys_dir)?;
346            Ok(config)
347        }
348    }
349}
350
351fn build_app(services: Services) -> Router {
352    let metrics = services.metrics.clone();
353
354    #[cfg(feature = "access-control")]
355    let geoip = services.geoip.clone();
356
357    #[allow(unused_mut)]
358    let mut router = Router::new()
359        .fallback(route_request)
360        .layer(
361            ServiceBuilder::new()
362                .layer(axum::middleware::from_fn_with_state(metrics.clone(), track_metrics))
363                .layer(axum::middleware::from_fn_with_state(metrics, track_connections)),
364        )
365        .with_state(services);
366
367    #[cfg(feature = "access-control")]
368    {
369        router = router
370            .layer(axum::middleware::from_fn(middleware::check_geoip))
371            .layer(axum::Extension(geoip));
372    }
373
374    router
375}
376
377async fn route_request(
378    State(mut services): State<Services>,
379    req: axum::extract::Request,
380) -> Response {
381    if is_relay_request(&req) {
382        match services.relay.call(req).await {
383            Ok(res) => res.into_response(),
384            Err(e) => (axum::http::StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
385        }
386    } else {
387        // The directory service handles all other requests (including 404)
388        match services.directory.call(req).await {
389            Ok(res) => res.into_response(),
390            Err(e) =>
391                (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
392        }
393    }
394}
395
396/// Determines if a request should be routed to the OHTTP relay service.
397///
398/// Routing rules:
399/// - `(OPTIONS, _)` => CORS preflight handling
400/// - `(CONNECT, _)` => OHTTP bootstrap tunneling
401/// - `(POST, "/")` => relay to default gateway (needed for backwards-compatibility only)
402/// - `(POST, /http(s)://...)` => RFC 9540 opt-in gateway specified in path
403/// - `(GET, /http(s)://...)` => OHTTP bootstrap via WebSocket with opt-in gateway
404fn is_relay_request(req: &axum::extract::Request) -> bool {
405    let method = req.method();
406    let path = req.uri().path();
407
408    match (method, path) {
409        (&Method::OPTIONS, _) | (&Method::CONNECT, _) | (&Method::POST, "/") => true,
410        (&Method::POST, p) | (&Method::GET, p)
411            if p.starts_with("/http://") || p.starts_with("/https://") =>
412            true,
413        _ => false,
414    }
415}
416
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use axum_server::tls_rustls::RustlsConfig;
    use opentelemetry_sdk::metrics::{InMemoryMetricExporter, PeriodicReader, SdkMeterProvider};
    use payjoin_test_utils::{http_agent, local_cert_key, wait_for_service_ready};
    use rustls::pki_types::CertificateDer;
    use rustls::RootCertStore;
    use tempfile::tempdir;

    use super::*;
    use crate::metrics::{ACTIVE_CONNECTIONS, HTTP_REQUESTS, TOTAL_CONNECTIONS};

    /// Test `Config` on an OS-assigned port, storing state under `storage_dir`,
    /// with `Duration::from_secs(2)` as its timeout and no extra options.
    fn test_config(storage_dir: &std::path::Path) -> Config {
        Config::new(
            "[::]:0".parse().expect("valid listener address"),
            storage_dir.to_path_buf(),
            Duration::from_secs(2),
            None,
        )
    }

    /// Starts a TLS instance that trusts (and serves) the given certificate.
    ///
    /// The returned `TempDir` holds the instance's storage; keep it alive while the
    /// service runs.
    async fn start_service(
        cert_der: Vec<u8>,
        key_der: Vec<u8>,
    ) -> (u16, tokio::task::JoinHandle<anyhow::Result<()>>, tempfile::TempDir) {
        let storage = tempdir().unwrap();
        let config = test_config(storage.path());

        let mut roots = RootCertStore::empty();
        roots.add(CertificateDer::from(cert_der.clone())).unwrap();
        let tls = RustlsConfig::from_der(vec![cert_der], key_der).await.unwrap();

        let (port, handle) = serve_manual_tls(config, Some(tls), roots, None).await.unwrap();
        (port, handle, storage)
    }

    #[tokio::test]
    async fn self_loop_request_is_rejected() {
        let cert = local_cert_key();
        let cert_der = cert.cert.der().to_vec();
        let key_der = cert.signing_key.serialize_der();
        let (port, _handle, _storage) = start_service(cert_der.clone(), key_der).await;

        let client = Arc::new(http_agent(cert_der).unwrap());
        let base_url = format!("https://localhost:{port}");
        wait_for_service_ready(&base_url, client.clone()).await.unwrap();

        // Relay through this instance to a gateway URL that points back at the same
        // instance's directory: the shared sentinel tag should expose the loop.
        let self_loop_url = format!("{base_url}/{base_url}");
        let response = client
            .post(&self_loop_url)
            .header("Content-Type", "message/ohttp-req")
            .body(vec![0u8; 100])
            .send()
            .await
            .expect("request should complete");

        assert_eq!(
            response.status(),
            axum::http::StatusCode::FORBIDDEN,
            "self-loop request should be rejected with 403 Forbidden"
        );
    }

    #[tokio::test]
    async fn cross_instance_request_is_accepted() {
        let cert = local_cert_key();
        let cert_der = cert.cert.der().to_vec();
        let key_der = cert.signing_key.serialize_der();

        let (relay_port, _relay_handle, _relay_storage) =
            start_service(cert_der.clone(), key_der.clone()).await;
        let (directory_port, _directory_handle, _directory_storage) =
            start_service(cert_der.clone(), key_der).await;

        let client = Arc::new(http_agent(cert_der).unwrap());
        let relay_url = format!("https://localhost:{relay_port}");
        let directory_url = format!("https://localhost:{directory_port}");
        for url in [&relay_url, &directory_url] {
            wait_for_service_ready(url, client.clone()).await.unwrap();
        }

        // Distinct instances carry distinct sentinel tags, so this must not look like a loop.
        let cross_url = format!("{relay_url}/{directory_url}");
        let response = client
            .post(&cross_url)
            .header("Content-Type", "message/ohttp-req")
            .body(vec![0u8; 100])
            .send()
            .await
            .expect("request should complete");

        // The request may fail for other reasons (invalid OHTTP body), but not due to self-loop.
        assert_ne!(
            response.status(),
            axum::http::StatusCode::FORBIDDEN,
            "cross-instance request should not be rejected as forbidden"
        );
    }

    #[tokio::test]
    async fn middleware_records_metrics() {
        use axum::body::Body;
        use axum::http::Request;
        use tower::ServiceExt;

        let exporter = InMemoryMetricExporter::default();
        let reader = PeriodicReader::builder(exporter.clone()).build();
        let provider = SdkMeterProvider::builder().with_reader(reader).build();

        let storage = tempdir().unwrap();
        let config = test_config(storage.path());

        let tag = generate_sentinel_tag();
        let services = Services {
            directory: init_directory(&config, tag).await.unwrap(),
            relay: crate::ohttp_relay::Service::new(tag).await,
            metrics: MetricsService::new(Some(provider.clone())),
            #[cfg(feature = "access-control")]
            geoip: None,
        };
        let app = build_app(services);

        let health = Request::builder().method("GET").uri("/health").body(Body::empty()).unwrap();
        let response = ServiceExt::<Request<Body>>::oneshot(app, health).await.unwrap();
        assert_eq!(response.status(), 200);

        provider.force_flush().expect("flush failed");

        let finished = exporter.get_finished_metrics().expect("metrics");
        let metric_names: Vec<&str> = finished
            .iter()
            .flat_map(|resource| resource.scope_metrics())
            .flat_map(|scope| scope.metrics())
            .map(|metric| metric.name())
            .collect();
        assert!(metric_names.contains(&HTTP_REQUESTS), "missing http_request_total");
        assert!(metric_names.contains(&TOTAL_CONNECTIONS), "missing total_connections");
        assert!(metric_names.contains(&ACTIVE_CONNECTIONS), "missing active_connections");
    }
}