1#[cfg(feature = "access-control")]
2use axum::extract::connect_info::Connected;
3use axum::extract::State;
4use axum::http::Method;
5use axum::response::{IntoResponse, Response};
6#[cfg(feature = "access-control")]
7use axum::serve::IncomingStream;
8use axum::Router;
9use config::Config;
10use opentelemetry_sdk::metrics::SdkMeterProvider;
11use rand::Rng;
12use tokio_listener::{Listener, SystemOptions, UserOptions};
13use tower::{Service, ServiceBuilder};
14use tracing::info;
15
16use crate::ohttp_relay::SentinelTag;
17
18#[cfg(feature = "access-control")]
19pub mod access_control;
20pub mod cli;
21pub mod config;
22pub mod db;
23pub mod directory;
24pub mod key_config;
25pub mod metrics;
26pub mod middleware;
27pub mod ohttp_relay;
28
29use crate::metrics::MetricsService;
30use crate::middleware::{track_connections, track_metrics};
31
32#[derive(Clone)]
33struct Services {
34 directory: crate::directory::Service<crate::db::DbServiceAdapter>,
35 relay: crate::ohttp_relay::Service,
36 metrics: MetricsService,
37 #[cfg(feature = "access-control")]
38 geoip: Option<std::sync::Arc<access_control::IpFilter>>,
39}
40
41pub async fn serve(config: Config, meter_provider: Option<SdkMeterProvider>) -> anyhow::Result<()> {
42 let sentinel_tag = generate_sentinel_tag();
43
44 #[cfg(feature = "access-control")]
45 let geoip = init_geoip(&config).await?;
46
47 let directory = init_directory(&config, sentinel_tag).await?;
48
49 let services = Services {
50 directory,
51 relay: crate::ohttp_relay::Service::new(sentinel_tag).await,
52 metrics: MetricsService::new(meter_provider),
53 #[cfg(feature = "access-control")]
54 geoip,
55 };
56
57 let app = build_app(services);
58 #[cfg(feature = "access-control")]
59 let app = app.into_make_service_with_connect_info::<middleware::MaybePeerIp>();
60
61 let listener =
62 Listener::bind(&config.listener, &SystemOptions::default(), &UserOptions::default())
63 .await?;
64 info!("Payjoin service listening on {:?}", listener.local_addr());
65 axum::serve(listener, app).await?;
66
67 Ok(())
68}
69
/// Starts the service on `config.listener` and returns the bound port plus the server task.
///
/// `config.listener` must parse as a TCP socket address; port `0` picks an ephemeral port,
/// which is why the actual port is returned. With `Some(tls_config)` the listener serves
/// HTTPS, and with `None` it serves plain HTTP. `root_store` and `default_gateway` are
/// handed to `ohttp_relay::Service::new_with_roots` (presumably the relay's outbound trust
/// roots and fallback gateway — confirm in that module). Metrics export is disabled.
///
/// # Errors
/// Fails if initialization fails, the address is not TCP, or binding fails. Errors from
/// the running server surface through the returned `JoinHandle`.
#[cfg(feature = "_manual-tls")]
pub async fn serve_manual_tls(
    config: Config,
    tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
    root_store: rustls::RootCertStore,
    default_gateway: Option<crate::ohttp_relay::GatewayUri>,
) -> anyhow::Result<(u16, tokio::task::JoinHandle<anyhow::Result<()>>)> {
    use std::net::SocketAddr;

    let tag = generate_sentinel_tag();

    #[cfg(feature = "access-control")]
    let geoip = init_geoip(&config).await?;

    let directory = init_directory(&config, tag).await?;
    let relay =
        crate::ohttp_relay::Service::new_with_roots(tag, root_store, default_gateway).await;

    let app = build_app(Services {
        directory,
        relay,
        metrics: MetricsService::new(None),
        #[cfg(feature = "access-control")]
        geoip,
    });

    let addr = config
        .listener
        .to_string()
        .parse::<SocketAddr>()
        .map_err(|_| anyhow::anyhow!("TLS mode requires a TCP address (e.g., '[::]:8080')"))?;
    let listener = tokio::net::TcpListener::bind(addr).await?;
    let port = listener.local_addr()?.port();
    let make_service = app.into_make_service_with_connect_info::<SocketAddr>();

    let handle = if let Some(tls) = tls_config {
        info!("Payjoin service listening on port {} with TLS", port);
        tokio::spawn(async move {
            // axum_server takes a std listener, so hand over the already-bound socket.
            let server = axum_server::from_tcp_rustls(listener.into_std()?, tls)?;
            server.serve(make_service).await.map_err(Into::into)
        })
    } else {
        info!("Payjoin service listening on port {} without TLS", port);
        tokio::spawn(async move { axum::serve(listener, make_service).await.map_err(Into::into) })
    };

    Ok((port, handle))
}
137
/// Runs the service on `config.listener` over TLS, with certificates managed through ACME.
///
/// The ACME configuration is taken from `config.acme`, and its state is rooted at
/// `config.storage_dir`. The ACME state's event stream is drained on a background task,
/// and each event or error is logged.
///
/// # Errors
/// Fails if `config.acme` is missing, if `config.listener` is not a TCP socket address, if
/// GeoIP or directory initialization fails, or if the server returns an error.
#[cfg(feature = "acme")]
pub async fn serve_acme(
    config: Config,
    meter_provider: Option<SdkMeterProvider>,
) -> anyhow::Result<()> {
    use std::net::SocketAddr;
    use std::sync::Arc;

    let acme_config = config
        .acme
        .clone()
        .ok_or_else(|| anyhow::anyhow!("ACME configuration is required for serve_acme"))?;

    let sentinel_tag = generate_sentinel_tag();

    #[cfg(feature = "access-control")]
    let geoip = init_geoip(&config).await?;

    let directory = init_directory(&config, sentinel_tag).await?;

    let services = Services {
        directory,
        relay: crate::ohttp_relay::Service::new(sentinel_tag).await,
        metrics: MetricsService::new(meter_provider),
        #[cfg(feature = "access-control")]
        geoip,
    };
    let app = build_app(services);

    let addr: SocketAddr = config
        .listener
        .to_string()
        .parse()
        .map_err(|_| anyhow::anyhow!("ACME mode requires a TCP address (e.g., '[::]:443')"))?;

    let acme = acme_config.into_rustls_config(&config.storage_dir);
    let mut state = acme.state();
    // Certificates are served through the ACME state's resolver, so renewals take effect
    // without rebuilding the rustls config.
    let rustls_config = Arc::new(
        rustls::ServerConfig::builder().with_no_client_auth().with_cert_resolver(state.resolver()),
    );
    let acceptor = state.axum_acceptor(rustls_config);

    // Drain the ACME event stream in the background. NOTE(review): the ACME crate
    // presumably only makes progress (order/renew) while this stream is polled — confirm.
    tokio::spawn(async move {
        use tokio_stream::StreamExt;
        while let Some(event) = state.next().await {
            match event {
                Ok(ok) => info!("ACME event: {:?}", ok),
                Err(err) => tracing::error!("ACME error: {:?}", err),
            }
        }
    });

    info!("Payjoin service listening on {} with ACME TLS", addr);
    axum_server::bind(addr)
        .acceptor(acceptor)
        .serve(app.into_make_service_with_connect_info::<SocketAddr>())
        .await?;
    Ok(())
}
203
204fn generate_sentinel_tag() -> SentinelTag { SentinelTag::new(rand::thread_rng().gen()) }
208
#[cfg(feature = "access-control")]
impl Connected<IncomingStream<'_, Listener>> for middleware::MaybePeerIp {
    /// Captures the peer IP for TCP connections; any non-TCP remote address yields `None`.
    fn connect_info(stream: IncomingStream<'_, Listener>) -> Self {
        if let tokio_listener::SomeSocketAddr::Tcp(addr) = stream.remote_addr() {
            Self(Some(addr.ip()))
        } else {
            Self(None)
        }
    }
}
219
220async fn init_directory(
221 config: &Config,
222 sentinel_tag: SentinelTag,
223) -> anyhow::Result<crate::directory::Service<crate::db::DbServiceAdapter>> {
224 let files_db = crate::db::FilesDb::init(config.timeout, config.storage_dir.clone()).await?;
225 files_db.spawn_background_prune().await;
226 let db = crate::db::DbServiceAdapter::new(files_db);
227
228 let ohttp_keys_dir = config.storage_dir.join("ohttp-keys");
229 let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?;
230
231 let v1 = if config.v1.is_some() {
232 #[cfg(feature = "access-control")]
233 let blocked = init_blocked_addresses(config).await?;
234 #[cfg(not(feature = "access-control"))]
235 let blocked = None;
236 Some(crate::directory::V1::new(blocked))
237 } else {
238 None
239 };
240 Ok(crate::directory::Service::new(db, ohttp_config.into(), sentinel_tag, v1))
241}
242
/// Builds the GeoIP filter from `config.access_control`, or returns `Ok(None)` when access
/// control is not configured.
///
/// # Errors
/// Propagates any failure from `access_control::IpFilter::from_config`.
#[cfg(feature = "access-control")]
async fn init_geoip(
    config: &Config,
) -> anyhow::Result<Option<std::sync::Arc<access_control::IpFilter>>> {
    let Some(ac_config) = config.access_control.as_ref() else {
        return Ok(None);
    };

    let filter = access_control::IpFilter::from_config(ac_config, &config.storage_dir).await?;
    info!("GeoIP access control enabled");
    Ok(Some(std::sync::Arc::new(filter)))
}
256
/// Builds the v1 blocked-address list from the configured file and/or URL.
///
/// Returns `Ok(None)` when v1 is disabled, or when neither `blocked_addresses_path` nor
/// `blocked_addresses_url` is set. Otherwise:
/// - A configured file is loaded first. Failing to read it is a hard error.
/// - If a URL is configured, it is fetched once at startup with a bounded timeout, and the
///   body is written to `<storage_dir>/blocked_addresses_cache.txt`.
/// - If the fetch or the body read fails, the cached copy (when present) is loaded instead.
/// - `access_control::spawn_address_list_updater` is then started with the refresh
///   interval `blocked_addresses_refresh_secs` (default 86400 s).
#[cfg(feature = "access-control")]
async fn init_blocked_addresses(
    config: &Config,
) -> anyhow::Result<Option<crate::directory::BlockedAddresses>> {
    // reqwest's default client has no request timeout. Without a bound, an unresponsive list
    // server could stall startup forever and the cache fallback below would never run.
    // A timeout is treated like any other fetch failure.
    const STARTUP_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    let Some(v1_config) = &config.v1 else {
        return Ok(None);
    };

    if v1_config.blocked_addresses_path.is_none() && v1_config.blocked_addresses_url.is_none() {
        return Ok(None);
    }

    let blocked = match &v1_config.blocked_addresses_path {
        Some(path) => {
            let text = access_control::load_blocked_address_text(path)?;
            let ba = crate::directory::BlockedAddresses::from_address_lines(&text);
            info!("Loaded blocked addresses from {}", path.display());
            ba
        }
        None => crate::directory::BlockedAddresses::empty(),
    };

    if let Some(url) = &v1_config.blocked_addresses_url {
        let cache_path = config.storage_dir.join("blocked_addresses_cache.txt");
        let refresh = std::time::Duration::from_secs(
            v1_config.blocked_addresses_refresh_secs.unwrap_or(86400),
        );

        // A client-construction failure is folded into the fetch result, so it takes the
        // same best-effort cache fallback that a failing `reqwest::get` took.
        let fetched = match reqwest::Client::builder().timeout(STARTUP_FETCH_TIMEOUT).build() {
            Ok(client) => client.get(url).send().await.and_then(|r| r.error_for_status()),
            Err(e) => Err(e),
        };

        match fetched {
            Ok(resp) => match resp.text().await {
                Ok(body) => {
                    if let Err(e) = std::fs::write(&cache_path, &body) {
                        tracing::warn!("Failed to write address cache: {e}");
                    }
                    let count = blocked.update_from_lines(&body).await;
                    info!("Fetched {count} blocked addresses from URL");
                }
                Err(e) => {
                    tracing::warn!("Failed to read address list response: {e}");
                    load_address_cache(&cache_path, &blocked).await;
                }
            },
            Err(e) => {
                tracing::warn!("Failed to fetch address list: {e}");
                load_address_cache(&cache_path, &blocked).await;
            }
        }

        access_control::spawn_address_list_updater(
            url.clone(),
            refresh,
            cache_path,
            blocked.clone(),
        );
    }

    Ok(Some(blocked))
}
320
/// Best-effort load of a previously cached address list into `blocked`.
///
/// Does nothing if `cache_path` does not exist. A read failure is logged as a warning and
/// leaves `blocked` untouched.
#[cfg(feature = "access-control")]
async fn load_address_cache(
    cache_path: &std::path::Path,
    blocked: &crate::directory::BlockedAddresses,
) {
    if !cache_path.exists() {
        return;
    }

    let text = match access_control::load_blocked_address_text(cache_path) {
        Ok(text) => text,
        Err(e) => {
            tracing::warn!("Failed to load address cache: {e}");
            return;
        }
    };

    let count = blocked.update_from_lines(&text).await;
    info!("Loaded {count} blocked addresses from cache");
}
336
337fn init_ohttp_config(
338 ohttp_keys_dir: &std::path::Path,
339) -> anyhow::Result<crate::key_config::ServerKeyConfig> {
340 std::fs::create_dir_all(ohttp_keys_dir)?;
341 match crate::key_config::read_server_config(ohttp_keys_dir) {
342 Ok(config) => Ok(config),
343 Err(_) => {
344 let config = crate::key_config::gen_ohttp_server_config()?;
345 crate::key_config::persist_new_key_config(config.clone(), ohttp_keys_dir)?;
346 Ok(config)
347 }
348 }
349}
350
351fn build_app(services: Services) -> Router {
352 let metrics = services.metrics.clone();
353
354 #[cfg(feature = "access-control")]
355 let geoip = services.geoip.clone();
356
357 #[allow(unused_mut)]
358 let mut router = Router::new()
359 .fallback(route_request)
360 .layer(
361 ServiceBuilder::new()
362 .layer(axum::middleware::from_fn_with_state(metrics.clone(), track_metrics))
363 .layer(axum::middleware::from_fn_with_state(metrics, track_connections)),
364 )
365 .with_state(services);
366
367 #[cfg(feature = "access-control")]
368 {
369 router = router
370 .layer(axum::middleware::from_fn(middleware::check_geoip))
371 .layer(axum::Extension(geoip));
372 }
373
374 router
375}
376
377async fn route_request(
378 State(mut services): State<Services>,
379 req: axum::extract::Request,
380) -> Response {
381 if is_relay_request(&req) {
382 match services.relay.call(req).await {
383 Ok(res) => res.into_response(),
384 Err(e) => (axum::http::StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
385 }
386 } else {
387 match services.directory.call(req).await {
389 Ok(res) => res.into_response(),
390 Err(e) =>
391 (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
392 }
393 }
394}
395
396fn is_relay_request(req: &axum::extract::Request) -> bool {
405 let method = req.method();
406 let path = req.uri().path();
407
408 match (method, path) {
409 (&Method::OPTIONS, _) | (&Method::CONNECT, _) | (&Method::POST, "/") => true,
410 (&Method::POST, p) | (&Method::GET, p)
411 if p.starts_with("/http://") || p.starts_with("/https://") =>
412 true,
413 _ => false,
414 }
415}
416
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use axum_server::tls_rustls::RustlsConfig;
    use opentelemetry_sdk::metrics::{InMemoryMetricExporter, PeriodicReader, SdkMeterProvider};
    use payjoin_test_utils::{http_agent, local_cert_key, wait_for_service_ready};
    use rustls::pki_types::CertificateDer;
    use rustls::RootCertStore;
    use tempfile::tempdir;

    use super::*;
    use crate::metrics::{ACTIVE_CONNECTIONS, HTTP_REQUESTS, TOTAL_CONNECTIONS};

    /// Config on an ephemeral port with storage rooted in `storage`.
    fn test_config(storage: &tempfile::TempDir) -> Config {
        Config::new(
            "[::]:0".parse().expect("valid listener address"),
            storage.path().to_path_buf(),
            Duration::from_secs(2),
            None,
        )
    }

    /// DER-encoded certificate and private key from `local_cert_key`.
    fn test_cert_and_key() -> (Vec<u8>, Vec<u8>) {
        let generated = local_cert_key();
        (generated.cert.der().to_vec(), generated.signing_key.serialize_der())
    }

    /// Starts a TLS instance that trusts `cert_der`.
    ///
    /// Returns its port, its task handle, and the temp dir, which must be kept alive for
    /// the test's duration.
    async fn start_service(
        cert_der: Vec<u8>,
        key_der: Vec<u8>,
    ) -> (u16, tokio::task::JoinHandle<anyhow::Result<()>>, tempfile::TempDir) {
        let storage = tempdir().unwrap();
        let config = test_config(&storage);

        let mut roots = RootCertStore::empty();
        roots.add(CertificateDer::from(cert_der.clone())).unwrap();
        let tls = RustlsConfig::from_der(vec![cert_der], key_der).await.unwrap();

        let (port, handle) = serve_manual_tls(config, Some(tls), roots, None).await.unwrap();
        (port, handle, storage)
    }

    #[tokio::test]
    async fn self_loop_request_is_rejected() {
        let (cert_der, key_der) = test_cert_and_key();
        let (port, _handle, _storage) = start_service(cert_der.clone(), key_der).await;

        let client = Arc::new(http_agent(cert_der).unwrap());
        let base_url = format!("https://localhost:{}", port);
        wait_for_service_ready(&base_url, client.clone()).await.unwrap();

        // Relay through this instance back to itself.
        let ohttp_req_url = format!("{base_url}/{base_url}");
        let response = client
            .post(&ohttp_req_url)
            .header("Content-Type", "message/ohttp-req")
            .body(vec![0u8; 100])
            .send()
            .await
            .expect("request should complete");

        assert_eq!(
            response.status(),
            axum::http::StatusCode::FORBIDDEN,
            "self-loop request should be rejected with 403 Forbidden"
        );
    }

    #[tokio::test]
    async fn cross_instance_request_is_accepted() {
        let (cert_der, key_der) = test_cert_and_key();

        let (relay_port, _relay_handle, _relay_storage) =
            start_service(cert_der.clone(), key_der.clone()).await;
        let (directory_port, _directory_handle, _directory_storage) =
            start_service(cert_der.clone(), key_der).await;

        let client = Arc::new(http_agent(cert_der).unwrap());
        let relay_url = format!("https://localhost:{}", relay_port);
        let directory_url = format!("https://localhost:{}", directory_port);

        for url in [&relay_url, &directory_url] {
            wait_for_service_ready(url, client.clone()).await.unwrap();
        }

        // Each instance generates its own sentinel tag, so relaying to the other one should
        // not be treated as a loop.
        let ohttp_req_url = format!("{}/{}", relay_url, directory_url);
        let response = client
            .post(&ohttp_req_url)
            .header("Content-Type", "message/ohttp-req")
            .body(vec![0u8; 100])
            .send()
            .await
            .expect("request should complete");

        assert_ne!(
            response.status(),
            axum::http::StatusCode::FORBIDDEN,
            "cross-instance request should not be rejected as forbidden"
        );
    }

    #[tokio::test]
    async fn middleware_records_metrics() {
        use axum::body::Body;
        use axum::http::Request;
        use tower::ServiceExt;

        let exporter = InMemoryMetricExporter::default();
        let provider = SdkMeterProvider::builder()
            .with_reader(PeriodicReader::builder(exporter.clone()).build())
            .build();

        let storage = tempdir().unwrap();
        let config = test_config(&storage);

        let tag = generate_sentinel_tag();
        let services = Services {
            directory: init_directory(&config, tag).await.unwrap(),
            relay: crate::ohttp_relay::Service::new(tag).await,
            metrics: MetricsService::new(Some(provider.clone())),
            #[cfg(feature = "access-control")]
            geoip: None,
        };
        let app = build_app(services);

        let request = Request::builder().method("GET").uri("/health").body(Body::empty()).unwrap();
        let response = ServiceExt::<Request<Body>>::oneshot(app, request).await.unwrap();
        assert_eq!(response.status(), 200);

        provider.force_flush().expect("flush failed");

        let finished = exporter.get_finished_metrics().expect("metrics");
        let metric_names: Vec<&str> = finished
            .iter()
            .flat_map(|rm| rm.scope_metrics())
            .flat_map(|sm| sm.metrics())
            .map(|m| m.name())
            .collect();
        assert!(metric_names.contains(&HTTP_REQUESTS), "missing http_request_total");
        assert!(metric_names.contains(&TOTAL_CONNECTIONS), "missing total_connections");
        assert!(metric_names.contains(&ACTIVE_CONNECTIONS), "missing active_connections");
    }
}