forge-runtime 0.10.0

Runtime executors and gateway for the Forge framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
//! TLS configuration and listener for the gateway.
//!
//! PEM-encoded certificate and key are loaded from disk at startup.
//! [`bind_listener`] returns a [`GatewayListener`] that implements
//! [`axum::serve::Listener`], so the gateway's single
//! `axum::serve(listener, service).await` hotpath handles both HTTP and HTTPS.

use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Once;
use std::task::{Context, Poll};

use forge_core::error::{ForgeError, Result};
use rustls::ServerConfig;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tokio_rustls::server::TlsStream;

/// Resolved TLS source for the gateway listener.
///
/// Holds only the two file paths; the files themselves are read from disk
/// when [`load_rustls_config`] builds the rustls server configuration.
/// Use [`TlsListenConfig::from_core`] to derive a value from the framework
/// configuration, which guarantees both paths are present together.
#[derive(Debug, Clone)]
pub struct TlsListenConfig {
    /// Path to the PEM-encoded certificate chain.
    pub cert_path: String,
    /// Path to the PEM-encoded private key.
    pub key_path: String,
}

impl TlsListenConfig {
    /// Derive the listener's TLS settings from `forge_core::config::TlsConfig`.
    ///
    /// The certificate and key paths are only meaningful as a pair:
    ///
    /// * both set   -> `Ok(Some(_))`, TLS is enabled;
    /// * both unset -> `Ok(None)`, the gateway serves plain HTTP;
    /// * one set    -> `Err(_)` naming the missing field.
    ///
    /// A `ForgeConfig` assembled in code (tests, embedded use) never passes
    /// through the parse-time validator, so this conversion re-checks the
    /// half-configured case instead of letting it slip through silently.
    ///
    /// This is an inherent constructor rather than a
    /// `TryFrom<...> for Option<TlsListenConfig>` impl because the orphan
    /// rule forbids implementing a foreign trait for a foreign type, even
    /// when a local type appears as a type parameter.
    pub fn from_core(cfg: &forge_core::config::TlsConfig) -> Result<Option<Self>> {
        // Resolve the pair first; every non-TLS outcome leaves early.
        let (cert, key) = match (&cfg.cert_path, &cfg.key_path) {
            (None, None) => return Ok(None),
            (Some(_), None) => {
                return Err(ForgeError::config(
                    "gateway.tls.cert_path is set but gateway.tls.key_path is missing. \
                     Set both to enable TLS, or neither to serve plain HTTP.",
                ));
            }
            (None, Some(_)) => {
                return Err(ForgeError::config(
                    "gateway.tls.key_path is set but gateway.tls.cert_path is missing. \
                     Set both to enable TLS, or neither to serve plain HTTP.",
                ));
            }
            (Some(cert), Some(key)) => (cert, key),
        };

        Ok(Some(Self {
            cert_path: cert.clone(),
            key_path: key.clone(),
        }))
    }
}

/// The TLS listener type produced by [`bind_listener`] when TLS is configured.
///
/// It is the `tls_listener` crate's listener specialised to a tokio
/// [`TcpListener`] as the connection source and a `tokio_rustls`
/// [`TlsAcceptor`] as the handshake driver.
type TlsListener = tls_listener::TlsListener<TcpListener, TlsAcceptor>;

/// Listener handed to `axum::serve`. Single type for both HTTP and HTTPS so
/// the gateway hot path stays one line at every call site.
///
/// Values are created by [`bind_listener`]; which variant you get depends on
/// whether a [`TlsListenConfig`] was supplied.
pub enum GatewayListener {
    /// Plain TCP listener; accepted connections are served as HTTP.
    Plain(TcpListener),
    /// TCP listener wrapped with a rustls acceptor; accepted connections
    /// are yielded as TLS streams.
    Tls(TlsListener),
}

/// Connection IO returned by [`GatewayListener::accept`]. Each variant maps
/// directly to a listener variant; both variants implement `AsyncRead +
/// AsyncWrite`, so axum can drive them uniformly.
pub enum GatewayConn {
    /// Raw TCP stream accepted by [`GatewayListener::Plain`].
    Plain(TcpStream),
    /// Server-side TLS stream accepted by [`GatewayListener::Tls`].
    // NOTE(review): boxed presumably to keep the enum close to the size of
    // the `Plain` variant — confirm the intent before un-boxing.
    Tls(Box<TlsStream<TcpStream>>),
}

// `TcpStream` and `TlsStream<TcpStream>` are both `Unpin`, so `GatewayConn`
// is `Unpin` as well (auto trait). That is what allows the impls below to
// call the safe `Pin::get_mut` and re-pin the inner stream with `Pin::new`.

impl AsyncRead for GatewayConn {
    /// Forward the read to whichever stream this connection wraps.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `GatewayConn` is `Unpin`, so the pin can be unwrapped into a plain
        // mutable reference and the inner stream re-pinned for the call.
        let conn = self.get_mut();
        match conn {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            // Reach through the `Box` to the TLS stream itself.
            Self::Tls(tls) => Pin::new(&mut **tls).poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for GatewayConn {
    /// Forward the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let conn = self.get_mut();
        match conn {
            Self::Plain(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_write(cx, buf),
        }
    }

    /// Forward the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let conn = self.get_mut();
        match conn {
            Self::Plain(tcp) => Pin::new(tcp).poll_flush(cx),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_flush(cx),
        }
    }

    /// Forward the shutdown to the wrapped stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let conn = self.get_mut();
        match conn {
            Self::Plain(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_shutdown(cx),
        }
    }

    /// Forward the vectored write to the wrapped stream so each transport
    /// can use its own implementation instead of the trait default.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        let conn = self.get_mut();
        match conn {
            Self::Plain(tcp) => Pin::new(tcp).poll_write_vectored(cx, bufs),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_write_vectored(cx, bufs),
        }
    }

    /// Report whether the wrapped stream has an efficient vectored write.
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Plain(tcp) => tcp.is_write_vectored(),
            Self::Tls(tls) => tls.is_write_vectored(),
        }
    }
}

impl axum::serve::Listener for GatewayListener {
    type Io = GatewayConn;
    type Addr = SocketAddr;

    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        match self {
            GatewayListener::Plain(l) => {
                let (io, addr) = axum::serve::Listener::accept(l).await;
                (GatewayConn::Plain(io), addr)
            }
            GatewayListener::Tls(l) => {
                let (io, addr) = axum::serve::Listener::accept(l).await;
                (GatewayConn::Tls(Box::new(io)), addr)
            }
        }
    }

    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        match self {
            GatewayListener::Plain(l) => l.local_addr(),
            GatewayListener::Tls(l) => l.local_addr(),
        }
    }
}

/// Peer socket address of an accepted gateway connection.
///
/// A local newtype over `SocketAddr` is required because the orphan rule
/// only lets this crate implement `axum::extract::connect_info::Connected`
/// for a type it owns. The type is threaded through
/// `into_make_service_with_connect_info::<PeerAddr>()` so that
/// `trusted_proxies` extraction works behind both plain TCP and TLS.
#[derive(Debug, Clone, Copy)]
pub struct PeerAddr(pub SocketAddr);

impl PeerAddr {
    /// Return only the IP part of the peer address, dropping the port.
    pub fn ip(&self) -> std::net::IpAddr {
        let Self(socket) = self;
        socket.ip()
    }
}

impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, GatewayListener>>
    for PeerAddr
{
    fn connect_info(stream: axum::serve::IncomingStream<'_, GatewayListener>) -> Self {
        PeerAddr(*stream.remote_addr())
    }
}

/// Guards the one-time installation of the process-wide rustls provider.
static CRYPTO_PROVIDER_INIT: Once = Once::new();

/// Make `ring` the default rustls crypto provider, at most once per process.
///
/// `rustls` 0.23+ has no implicit provider, so this must run before a
/// `ServerConfig` is built or no cipher suites are available. Calling it
/// again is a no-op thanks to the [`Once`] guard.
fn install_default_crypto_provider() {
    CRYPTO_PROVIDER_INIT.call_once(|| {
        let provider = rustls::crypto::ring::default_provider();
        // The result is deliberately discarded: a failure here means some
        // other code already installed a default provider, which is fine.
        let _ = provider.install_default();
    });
}

/// Build a [`rustls::ServerConfig`] from a [`TlsListenConfig`].
///
/// Returns a [`ForgeError::Config`] for any I/O or parse failure. Failures
/// are surfaced at server startup so operators see them immediately, not at
/// the first HTTPS request.
pub fn load_rustls_config(cfg: &TlsListenConfig) -> Result<Arc<ServerConfig>> {
    install_default_crypto_provider();
    let server_config = build_from_files(&cfg.cert_path, &cfg.key_path)?;
    Ok(Arc::new(server_config))
}

/// Bind the gateway listener on `addr`. With `Some(cfg)` the result is a
/// TLS-terminating listener; with `None` it is a plain TCP listener. Both
/// variants implement [`axum::serve::Listener`], so the caller writes
/// `axum::serve(listener, service).await` once.
///
/// The "listening" log line is emitted only after the socket has been bound,
/// and it reports the address actually bound (which differs from `addr`
/// when an ephemeral port such as `:0` was requested).
///
/// # Errors
///
/// Config errors (I/O, parse, invalid key pair) are mapped to
/// [`std::io::Error`] so callers can propagate startup failures uniformly.
/// A failure to bind the TCP socket is returned as the underlying
/// [`std::io::Error`].
pub async fn bind_listener(
    addr: SocketAddr,
    tls: Option<&TlsListenConfig>,
) -> std::io::Result<GatewayListener> {
    match tls {
        Some(cfg) => {
            // Validate the TLS material before touching the socket so a bad
            // certificate never leaves a bound-but-unused port behind.
            let rustls_config = load_rustls_config(cfg).map_err(std::io::Error::other)?;
            let tcp = TcpListener::bind(addr).await?;
            // Fall back to the requested address if the OS cannot report
            // the bound one; logging must never fail startup.
            let bound = tcp.local_addr().unwrap_or(addr);
            tracing::info!(
                addr = %bound,
                cert_path = %cfg.cert_path,
                key_path = %cfg.key_path,
                "Gateway listening with TLS"
            );
            Ok(GatewayListener::Tls(
                tls_listener::builder(TlsAcceptor::from(rustls_config)).listen(tcp),
            ))
        }
        None => {
            let tcp = TcpListener::bind(addr).await?;
            let bound = tcp.local_addr().unwrap_or(addr);
            tracing::info!(addr = %bound, "Gateway listening (HTTP)");
            Ok(GatewayListener::Plain(tcp))
        }
    }
}

/// Read the PEM files at `cert_path` and `key_path` and assemble a
/// server-side rustls configuration that does not request client
/// certificates.
///
/// The certificate file is read first, so when both files are unusable the
/// certificate error is the one reported.
fn build_from_files(cert_path: &str, key_path: &str) -> Result<ServerConfig> {
    let chain = read_pem_certs(cert_path)?;
    let private_key = read_pem_key(key_path)?;

    let builder = ServerConfig::builder().with_no_client_auth();
    match builder.with_single_cert(chain, private_key) {
        Ok(config) => Ok(config),
        Err(e) => Err(ForgeError::config_with(
            "invalid TLS certificate or key",
            e,
        )),
    }
}

/// Read every PEM certificate contained in the file at `path`.
///
/// Fails with a configuration error when the file cannot be opened, when
/// any entry fails to parse, or when the file holds no certificate at all.
fn read_pem_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
    // Opening the file is the first place this can fail.
    let entries = match CertificateDer::pem_file_iter(path) {
        Ok(entries) => entries,
        Err(e) => {
            return Err(ForgeError::config(format!(
                "failed to read PEM certificates from '{path}': {e}"
            )));
        }
    };

    // Collect the chain, stopping at the first entry that does not parse.
    let mut chain: Vec<CertificateDer<'static>> = Vec::new();
    for entry in entries {
        match entry {
            Ok(cert) => chain.push(cert),
            Err(e) => {
                return Err(ForgeError::config(format!(
                    "failed to parse PEM certificates in '{path}': {e}"
                )));
            }
        }
    }

    // A readable file without any certificate is still a misconfiguration.
    if chain.is_empty() {
        return Err(ForgeError::config(format!(
            "no PEM certificates found in '{path}'"
        )));
    }

    Ok(chain)
}

/// Read the PEM private key stored in the file at `path`.
///
/// Any failure reported by the PEM reader is wrapped in a configuration
/// error that names the offending path.
fn read_pem_key(path: &str) -> Result<PrivateKeyDer<'static>> {
    match PrivateKeyDer::from_pem_file(path) {
        Ok(key) => Ok(key),
        Err(e) => Err(ForgeError::config(format!(
            "failed to read PEM private key from '{path}': {e}"
        ))),
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Build a listener config from two path-like values.
    fn listen_cfg(cert: impl Into<String>, key: impl Into<String>) -> TlsListenConfig {
        TlsListenConfig {
            cert_path: cert.into(),
            key_path: key.into(),
        }
    }

    /// Create a temp file holding `contents`. The returned handle must stay
    /// alive for as long as the path is used, since dropping it deletes the
    /// file.
    fn temp_file_with(contents: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(contents).unwrap();
        file
    }

    /// Build the core TLS config with the given optional paths.
    fn core_cfg(cert: Option<&str>, key: Option<&str>) -> forge_core::config::TlsConfig {
        forge_core::config::TlsConfig {
            cert_path: cert.map(Into::into),
            key_path: key.map(Into::into),
        }
    }

    #[tokio::test]
    async fn from_files_missing_cert_path_errors() {
        let cfg = listen_cfg("/nonexistent/cert.pem", "/nonexistent/key.pem");

        let msg = load_rustls_config(&cfg).unwrap_err().to_string();

        assert!(
            msg.contains("failed to read PEM certificates from '/nonexistent/cert.pem'"),
            "unexpected error: {msg}"
        );
    }

    #[tokio::test]
    async fn from_files_malformed_cert_errors() {
        let cert_file = temp_file_with(b"not a certificate");
        let key_file = temp_file_with(b"not a key");
        let cfg = listen_cfg(
            cert_file.path().to_string_lossy(),
            key_file.path().to_string_lossy(),
        );

        let msg = load_rustls_config(&cfg).unwrap_err().to_string();

        assert!(
            msg.contains("no PEM certificates found"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn from_core_both_set_returns_some() {
        let source = core_cfg(Some("/cert.pem"), Some("/key.pem"));

        let resolved = TlsListenConfig::from_core(&source).unwrap().unwrap();

        assert_eq!(resolved.cert_path, "/cert.pem");
        assert_eq!(resolved.key_path, "/key.pem");
    }

    #[test]
    fn from_core_neither_set_returns_none() {
        let source = forge_core::config::TlsConfig::default();

        let resolved = TlsListenConfig::from_core(&source).unwrap();

        assert!(resolved.is_none());
    }

    #[test]
    fn from_core_only_cert_errors() {
        let source = core_cfg(Some("/cert.pem"), None);

        let err = TlsListenConfig::from_core(&source).unwrap_err();

        assert!(
            err.to_string().contains("key_path is missing"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn from_core_only_key_errors() {
        let source = core_cfg(None, Some("/key.pem"));

        let err = TlsListenConfig::from_core(&source).unwrap_err();

        assert!(
            err.to_string().contains("cert_path is missing"),
            "unexpected error: {err}"
        );
    }
}

/// End-to-end coverage: a real client talks to an `axum::serve` instance
/// whose listener comes from [`bind_listener`], once over TLS and once over
/// plain HTTP. This catches regressions from rustls / tls-listener / axum
/// upgrades that the unit tests cannot see.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tls_handshake_e2e {
    use super::*;
    use axum::{Router, routing::get};
    use std::io::Write;

    /// Generate a self-signed certificate for `localhost` and write the
    /// certificate and key as PEM into two temp files (cert first, key
    /// second). The handles must outlive any use of the paths.
    fn write_cert_and_key() -> (tempfile::NamedTempFile, tempfile::NamedTempFile) {
        let rcgen::CertifiedKey {
            cert: leaf,
            key_pair: signer,
        } = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).expect("rcgen");

        let mut cert_out = tempfile::NamedTempFile::new().expect("cert tempfile");
        cert_out
            .write_all(leaf.pem().as_bytes())
            .expect("write cert");

        let mut key_out = tempfile::NamedTempFile::new().expect("key tempfile");
        key_out
            .write_all(signer.serialize_pem().as_bytes())
            .expect("write key");

        (cert_out, key_out)
    }

    /// Bind on `addr`, serve a single health route in a background task and
    /// return the address that was actually bound.
    async fn serve(addr: SocketAddr, tls: Option<TlsListenConfig>) -> SocketAddr {
        let listener = bind_listener(addr, tls.as_ref()).await.expect("bind");
        let bound = axum::serve::Listener::local_addr(&listener).expect("local_addr");

        let app = Router::new().route("/_api/health", get(|| async { "ok" }));
        tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .expect("serve");
        });

        bound
    }

    /// Request the health route over `scheme` and assert a `200 ok` reply.
    async fn assert_health_ok(client: &reqwest::Client, scheme: &str, bound: SocketAddr) {
        let url = format!("{scheme}://{bound}/_api/health");
        let resp = client.get(&url).send().await.expect("request");
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.text().await.expect("body"), "ok");
    }

    #[tokio::test]
    async fn https_handshake_returns_ok() {
        let (cert_file, key_file) = write_cert_and_key();
        let tls = TlsListenConfig {
            cert_path: cert_file.path().to_string_lossy().into_owned(),
            key_path: key_file.path().to_string_lossy().into_owned(),
        };

        let bound = serve("127.0.0.1:0".parse().unwrap(), Some(tls)).await;

        // The certificate is self-signed, so verification is switched off
        // for this client only.
        let client = reqwest::Client::builder()
            .danger_accept_invalid_certs(true)
            .use_rustls_tls()
            .build()
            .expect("client");

        assert_health_ok(&client, "https", bound).await;
    }

    #[tokio::test]
    async fn http_path_through_same_helper_returns_ok() {
        let bound = serve("127.0.0.1:0".parse().unwrap(), None).await;

        let client = reqwest::Client::builder().build().expect("client");

        assert_health_ok(&client, "http", bound).await;
    }
}