Skip to main content

selium_net_hyper/
driver.rs

1//! Driver wiring for Hyper-backed HTTP connections.
2//! Uses HTTP/1.1 for `NetProtocol::Http` and HTTP/2 for `NetProtocol::Https`.
3
4use std::{
5    collections::{HashMap, VecDeque},
6    fmt,
7    net::SocketAddr,
8    sync::{
9        Arc, Mutex,
10        atomic::{AtomicBool, Ordering},
11    },
12};
13
14use futures_util::future::BoxFuture;
15use http_body_util::Full;
16use hyper::{
17    body::Bytes,
18    client::conn::{http1, http2},
19    http::{
20        header::{InvalidHeaderName, InvalidHeaderValue},
21        method::InvalidMethod,
22        uri::InvalidUri,
23    },
24};
25use hyper_util::rt::{TokioExecutor, TokioIo};
26use rustls::{ClientConfig, ServerConfig, sign};
27use selium_abi::{IoFrame, NetProtocol};
28use selium_kernel::{
29    drivers::{
30        io::IoCapability,
31        net::{NetCapability, TlsClientConfig, TlsServerConfig},
32    },
33    guest_data::GuestError,
34};
35use thiserror::Error;
36use tokio::{
37    io::{AsyncRead, AsyncWrite},
38    net::TcpListener as TokioTcpListener,
39    sync::{Notify, mpsc, oneshot},
40};
41use tracing::debug;
42
43use crate::{
44    client::{connect_stream, read_outbound, write_outbound},
45    server::{read_inbound, run_listener, write_inbound},
46    tls::{
47        build_client_config, build_client_verifier, build_server_config, certified_key_from_config,
48        resolve_alpn,
49    },
50};
51
52/// Body type used for outbound Hyper messages.
53pub(crate) type HyperBody = Full<Bytes>;
54/// IO type used for Hyper streams.
55pub(crate) type HyperStream = Box<dyn HyperIo + 'static>;
56
57const PENDING_QUEUE: usize = 64;
58
59pub(crate) trait HyperIo: AsyncRead + AsyncWrite + Unpin + Send {}
60
61impl<T> HyperIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
62
63pub(crate) enum OutboundSender {
64    Http1(http1::SendRequest<HyperBody>),
65    Http2(http2::SendRequest<HyperBody>),
66}
67
/// Errors produced by the Hyper-backed listener driver.
///
/// Converted into [`GuestError`] by the `From` impl in this module; the
/// `Display` text of variants that are not mapped to `InvalidArgument` is
/// what guests see.
#[derive(Error, Debug)]
pub enum HyperError {
    /// `accept` found the listener's request queue closed: every sending half
    /// (handed to the background listener task) has been dropped.
    #[error("listener closed before any request arrived")]
    ListenerClosed,
    /// A port value was out of range.
    #[error("port out of range")]
    PortRange,
    /// Binding the listening socket failed, or handing it to Tokio failed.
    #[error("failed to bind listener: {0}")]
    Bind(#[source] std::io::Error),
    /// Switching the freshly bound socket to non-blocking mode failed.
    #[error("failed to mark listener non-blocking: {0}")]
    NonBlocking(#[source] std::io::Error),
    /// Establishing an outbound connection failed.
    #[error("failed to connect: {0}")]
    Connect(#[source] std::io::Error),
    /// The TLS handshake failed at the IO level.
    #[error("TLS handshake failed: {0}")]
    Tls(#[source] std::io::Error),
    /// rustls rejected a client or server configuration.
    #[error("failed to build TLS config: {0}")]
    Rustls(#[source] rustls::Error),
    /// A supplied certificate chain could not be parsed.
    #[error("failed to parse certificate chain: {0}")]
    Certificate(String),
    /// A supplied private key could not be parsed.
    #[error("failed to parse private key: {0}")]
    PrivateKey(String),
    /// A client certificate was configured without its private key.
    #[error("client certificate provided without private key")]
    ClientKeyMissing,
    /// Client authentication was requested without a CA bundle to verify against.
    #[error("client authentication requires a CA bundle")]
    ClientAuthMissing,
    /// Building the client-certificate verifier failed.
    #[error("failed to configure client authentication: {0}")]
    ClientAuth(String),
    /// Hyper reported a connection-level failure (e.g. during `handshake`).
    #[error("HTTP connection failed: {0}")]
    Hyper(#[source] hyper::Error),
    /// Building an HTTP request/response with the `http` builders failed.
    #[error("failed to build HTTP message: {0}")]
    Http(#[source] hyper::http::Error),
    /// Raw HTTP bytes could not be parsed.
    #[error("HTTP parse error: {0}")]
    HttpParse(String),
    /// Raw HTTP bytes ended before a complete message was available.
    #[error("HTTP message incomplete")]
    HttpIncomplete,
    /// A header name was not valid.
    #[error("invalid header name: {0}")]
    InvalidHeaderName(#[source] InvalidHeaderName),
    /// A header value was not valid.
    #[error("invalid header value: {0}")]
    InvalidHeaderValue(#[source] InvalidHeaderValue),
    /// The HTTP method was not valid.
    #[error("invalid method: {0}")]
    InvalidMethod(#[source] InvalidMethod),
    /// The request URI was not valid.
    #[error("invalid URI: {0}")]
    InvalidUri(#[source] InvalidUri),
    /// The HTTP status code was not valid.
    #[error("invalid status code")]
    InvalidStatus,
    /// A transfer encoding the driver does not handle was requested.
    #[error("unsupported transfer encoding")]
    TransferEncoding,
    /// A body's length disagreed with its declared `Content-Length`.
    #[error("content length mismatch (expected {expected}, got {actual})")]
    ContentLengthMismatch { expected: usize, actual: usize },
    /// A host did not match the expected domain. Also returned when `create`
    /// targets a port already bound for a different domain.
    #[error("host header does not match target domain")]
    HostMismatch,
    /// The channel that should carry a response was closed.
    #[error("response channel closed")]
    ResponseChannelClosed,
    /// A `std::sync::Mutex` was poisoned (e.g. the listener registry).
    #[error("mutex poisoned")]
    Lock,
    /// The requested operation is not supported by this driver
    /// (e.g. `IoCapability::new_reader` / `new_writer`).
    #[error("operation unsupported")]
    Unsupported,
    /// `create` targeted an already-bound port with different TLS settings.
    #[error("TLS configuration does not match existing listener")]
    TlsConfigMismatch,
    /// The protocol is not `Http`/`Https`, or the port is already bound for a
    /// different protocol.
    #[error("unsupported protocol: {protocol:?}")]
    UnsupportedProtocol { protocol: NetProtocol },
}
130
/// Per-driver table of bound listeners, keyed by TCP port.
///
/// `create` goes through it so that repeated calls for the same port share one
/// bound socket instead of binding twice.
struct ListenerRegistry {
    listeners: Mutex<HashMap<u16, Arc<Listener>>>,
}

/// TLS settings a listener was created with.
///
/// Stored on each [`Listener`] and compared for equality when `create` is
/// called again for an already-bound port; any difference is rejected with
/// [`HyperError::TlsConfigMismatch`].
#[derive(Clone, Debug, PartialEq, Eq)]
struct ListenerTlsProfile {
    // Certificate chain as raw bytes, one entry per certificate.
    cert_chain: Vec<Vec<u8>>,
    // ALPN protocol identifiers resolved for the listener's protocol.
    alpn: Vec<Vec<u8>>,
    // CA bundle (PEM) handed to the client-certificate verifier, if any.
    client_ca_pem: Option<Vec<u8>>,
    // Client-auth flag handed to the client-certificate verifier.
    require_client_auth: bool,
}

/// A bound listener, shared by every [`ListenerHandle`] for its port.
///
/// A background task (spawned when the listener is created) serves the socket
/// and pushes the requests it produces onto `pending_rx`'s channel.
struct Listener {
    // Protocol the listener was created for; reuse requires an exact match.
    protocol: NetProtocol,
    // Domain the listener was created for; reuse requires an ASCII
    // case-insensitive match.
    domain: String,
    // Receiving end of the bounded request queue. An async mutex because
    // `accept` holds it across `recv().await`, so concurrent accepts on one
    // listener are served one at a time.
    pending_rx: tokio::sync::Mutex<mpsc::Receiver<PendingRequest>>,
    // TLS settings used to detect conflicting reuse of the port.
    tls_profile: ListenerTlsProfile,
}

/// A request produced by a listener task, waiting to be picked up by `accept`.
pub(crate) struct PendingRequest {
    /// The request as bytes; becomes the inbound reader's data.
    /// NOTE(review): the exact byte format is decided by `run_listener` — confirm there.
    pub(crate) request_bytes: Vec<u8>,
    /// One-shot channel through which the response bytes are delivered when
    /// the inbound writer is dropped.
    pub(crate) responder: oneshot::Sender<Vec<u8>>,
    /// Peer address, returned to the caller of `accept`.
    pub(crate) remote_addr: String,
}

/// Shared state of a client connection created by `connect`.
pub(crate) struct OutboundState {
    /// Protocol the connection was opened with.
    pub(crate) protocol: NetProtocol,
    /// Target host, as passed to `connect`.
    pub(crate) domain: String,
    /// Target port, as passed to `connect`.
    pub(crate) port: u16,
    /// Hyper request sender for this connection.
    pub(crate) sender: tokio::sync::Mutex<OutboundSender>,
    /// Response bytes buffered for the reader (presumably filled by
    /// `write_outbound` and drained by `read_outbound`; not shown here).
    pub(crate) response: tokio::sync::Mutex<VecDeque<u8>>,
    /// Notification used to wake readers; `notify_waiters` is called on it
    /// when the writer is dropped.
    pub(crate) response_notify: Notify,
    /// Set to `true` when the writer half is dropped.
    pub(crate) closed: AtomicBool,
}

/// Shared state of a server-side exchange handed out by `accept`.
pub(crate) struct InboundState {
    /// Protocol of the listener that received the request.
    pub(crate) protocol: NetProtocol,
    /// Unread request bytes, initialised from `PendingRequest::request_bytes`.
    pub(crate) request: Mutex<VecDeque<u8>>,
    /// Response bytes buffered by the writer; taken and sent when it is dropped.
    pub(crate) response: Mutex<Vec<u8>>,
    /// Channel for delivering `response`; taken (left as `None`) on writer drop.
    pub(crate) responder: Mutex<Option<oneshot::Sender<Vec<u8>>>>,
}

/// Which side of an HTTP exchange a reader/writer pair belongs to.
pub(crate) enum ConnectionKind {
    /// Client side, created by `connect`.
    Outbound(Arc<OutboundState>),
    /// Server side, created by `accept`.
    Inbound(Arc<InboundState>),
}

/// Handle for a bound HTTP listener.
///
/// Cloning is cheap: all clones share the same underlying listener.
#[derive(Clone)]
pub struct ListenerHandle {
    listener: Arc<Listener>,
}

/// Hyper-backed network driver.
pub struct HyperDriver {
    // Listeners bound through this driver, keyed by port.
    registry: Arc<ListenerRegistry>,
    // Certificate chain of the driver's own key; recorded in the TLS profile
    // of listeners created without explicit TLS settings.
    default_cert_chain: Vec<Vec<u8>>,
    // Server config used by `create` when no TLS settings are supplied.
    default_server_config: Arc<ServerConfig>,
    // Client config used by `connect` when no TLS settings are supplied.
    default_client_config: Arc<ClientConfig>,
}

/// Reader side of an HTTP connection.
pub struct HttpReader {
    state: ConnectionKind,
}

/// Writer side of an HTTP connection.
///
/// Dropping it completes the exchange (see its `Drop` impl).
pub struct HttpWriter {
    state: ConnectionKind,
}
201
202impl fmt::Debug for ListenerHandle {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        f.debug_struct("ListenerHandle")
205            .field("protocol", &self.listener.protocol)
206            .field("domain", &self.listener.domain)
207            .finish()
208    }
209}
210
211impl ListenerRegistry {
212    fn new() -> Self {
213        Self {
214            listeners: Mutex::new(HashMap::new()),
215        }
216    }
217
218    fn get_or_try_init(
219        &self,
220        protocol: NetProtocol,
221        domain: &str,
222        port: u16,
223        tls_profile: ListenerTlsProfile,
224        server_config: Arc<ServerConfig>,
225    ) -> Result<Arc<Listener>, HyperError> {
226        let mut guard = self.listeners.lock().map_err(|_| HyperError::Lock)?;
227        if let Some(listener) = guard.get(&port) {
228            if listener.protocol != protocol {
229                return Err(HyperError::UnsupportedProtocol { protocol });
230            }
231            if !listener.domain.eq_ignore_ascii_case(domain) {
232                return Err(HyperError::HostMismatch);
233            }
234            if listener.tls_profile != tls_profile {
235                return Err(HyperError::TlsConfigMismatch);
236            }
237            return Ok(Arc::clone(listener));
238        }
239
240        let listener = Arc::new(Listener::new(
241            protocol,
242            domain.to_string(),
243            port,
244            tls_profile,
245            server_config,
246        )?);
247        guard.insert(port, Arc::clone(&listener));
248        Ok(listener)
249    }
250}
251
252impl Listener {
253    fn new(
254        protocol: NetProtocol,
255        domain: String,
256        port: u16,
257        tls_profile: ListenerTlsProfile,
258        server_config: Arc<ServerConfig>,
259    ) -> Result<Self, HyperError> {
260        ensure_http_protocol(protocol)?;
261        let addr = SocketAddr::from(([0, 0, 0, 0], port));
262        let std_listener = std::net::TcpListener::bind(addr).map_err(HyperError::Bind)?;
263        std_listener
264            .set_nonblocking(true)
265            .map_err(HyperError::NonBlocking)?;
266        let listener = TokioTcpListener::from_std(std_listener).map_err(HyperError::Bind)?;
267
268        let (pending_tx, pending_rx) = mpsc::channel(PENDING_QUEUE);
269        tokio::spawn(run_listener(
270            listener,
271            protocol,
272            domain.clone(),
273            server_config,
274            pending_tx,
275        ));
276
277        Ok(Self {
278            protocol,
279            domain,
280            pending_rx: tokio::sync::Mutex::new(pending_rx),
281            tls_profile,
282        })
283    }
284}
285
286impl ListenerHandle {
287    fn new(listener: Arc<Listener>) -> Self {
288        Self { listener }
289    }
290
291    /// Return the domain bound by the listener.
292    pub fn domain(&self) -> &str {
293        &self.listener.domain
294    }
295
296    /// Return the protocol bound by the listener.
297    pub fn protocol(&self) -> NetProtocol {
298        self.listener.protocol
299    }
300}
301
302impl HyperDriver {
303    /// Create a new driver instance with an already validated certificate and private key.
304    pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Result<Arc<Self>, HyperError> {
305        let default_cert_chain = certified_key
306            .cert
307            .iter()
308            .map(|cert| cert.as_ref().to_vec())
309            .collect::<Vec<_>>();
310        let client_verifier = build_client_verifier(None, false)?;
311        let default_server_config = build_server_config(
312            Arc::clone(&certified_key),
313            resolve_alpn(NetProtocol::Https, None),
314            client_verifier,
315        )?;
316        let default_client_config = build_client_config(NetProtocol::Https, None)?;
317        Ok(Arc::new(Self {
318            registry: Arc::new(ListenerRegistry::new()),
319            default_cert_chain,
320            default_server_config,
321            default_client_config,
322        }))
323    }
324}
325
326impl HttpReader {
327    fn outbound(state: Arc<OutboundState>) -> Self {
328        Self {
329            state: ConnectionKind::Outbound(state),
330        }
331    }
332
333    fn inbound(state: Arc<InboundState>) -> Self {
334        Self {
335            state: ConnectionKind::Inbound(state),
336        }
337    }
338}
339
340impl HttpWriter {
341    fn outbound(state: Arc<OutboundState>) -> Self {
342        Self {
343            state: ConnectionKind::Outbound(state),
344        }
345    }
346
347    fn inbound(state: Arc<InboundState>) -> Self {
348        Self {
349            state: ConnectionKind::Inbound(state),
350        }
351    }
352}
353
354impl Drop for HttpWriter {
355    fn drop(&mut self) {
356        match &self.state {
357            ConnectionKind::Outbound(state) => {
358                state.closed.store(true, Ordering::SeqCst);
359                state.response_notify.notify_waiters();
360            }
361            ConnectionKind::Inbound(state) => {
362                let response = match state.response.lock() {
363                    Ok(mut guard) => std::mem::take(&mut *guard),
364                    Err(err) => {
365                        debug!(err = %err, "response buffer lock poisoned");
366                        Vec::new()
367                    }
368                };
369                let responder = match state.responder.lock() {
370                    Ok(mut guard) => guard.take(),
371                    Err(err) => {
372                        debug!(err = %err, "response channel lock poisoned");
373                        None
374                    }
375                };
376                if let Some(responder) = responder
377                    && responder.send(response).is_err()
378                {
379                    debug!("response receiver dropped before completion");
380                }
381            }
382        }
383    }
384}
385
386impl NetCapability for HyperDriver {
387    type Handle = ListenerHandle;
388    type Reader = HttpReader;
389    type Writer = HttpWriter;
390    type Error = HyperError;
391
392    fn create(
393        &self,
394        protocol: NetProtocol,
395        domain: &str,
396        port: u16,
397        tls: Option<Arc<TlsServerConfig>>,
398    ) -> BoxFuture<'_, Result<Self::Handle, Self::Error>> {
399        let registry = Arc::clone(&self.registry);
400        let domain = domain.to_string();
401        let default_cert_chain = self.default_cert_chain.clone();
402        let default_server_config = Arc::clone(&self.default_server_config);
403
404        Box::pin(async move {
405            ensure_http_protocol(protocol)?;
406            let (server_config, tls_profile) = match tls.as_ref() {
407                Some(config) => {
408                    let alpn = resolve_alpn(protocol, config.alpn.as_ref());
409                    let client_verifier = build_client_verifier(
410                        config.client_ca_pem.as_ref(),
411                        config.require_client_auth,
412                    )?;
413                    let (certified_key, cert_chain) = certified_key_from_config(config)?;
414                    let server_config =
415                        build_server_config(certified_key, alpn.clone(), client_verifier)?;
416                    let profile = ListenerTlsProfile {
417                        cert_chain,
418                        alpn,
419                        client_ca_pem: config.client_ca_pem.clone(),
420                        require_client_auth: config.require_client_auth,
421                    };
422                    (server_config, profile)
423                }
424                None => {
425                    let alpn = resolve_alpn(protocol, None);
426                    let profile = ListenerTlsProfile {
427                        cert_chain: default_cert_chain,
428                        alpn,
429                        client_ca_pem: None,
430                        require_client_auth: false,
431                    };
432                    (default_server_config, profile)
433                }
434            };
435            let listener =
436                registry.get_or_try_init(protocol, &domain, port, tls_profile, server_config)?;
437            Ok(ListenerHandle::new(listener))
438        })
439    }
440
441    fn connect(
442        &self,
443        protocol: NetProtocol,
444        domain: &str,
445        port: u16,
446        tls: Option<Arc<TlsClientConfig>>,
447    ) -> BoxFuture<'_, Result<(Self::Reader, Self::Writer, String), Self::Error>> {
448        let domain = domain.to_string();
449        let default_client_config = Arc::clone(&self.default_client_config);
450
451        Box::pin(async move {
452            ensure_http_protocol(protocol)?;
453            let tls = tls.as_deref();
454            let client_config = match tls {
455                Some(config) => build_client_config(protocol, Some(config))?,
456                None => default_client_config,
457            };
458            let stream = connect_stream(protocol, &domain, port, client_config).await?;
459            let stream = TokioIo::new(stream);
460            let sender = match protocol {
461                NetProtocol::Http => {
462                    let (sender, connection) =
463                        http1::handshake(stream).await.map_err(HyperError::Hyper)?;
464                    tokio::spawn(async move {
465                        if let Err(err) = connection.await {
466                            debug!(err = %err, "http connection terminated");
467                        }
468                    });
469                    OutboundSender::Http1(sender)
470                }
471                NetProtocol::Https => {
472                    let (sender, connection) = http2::handshake(TokioExecutor::new(), stream)
473                        .await
474                        .map_err(HyperError::Hyper)?;
475                    tokio::spawn(async move {
476                        if let Err(err) = connection.await {
477                            debug!(err = %err, "http connection terminated");
478                        }
479                    });
480                    OutboundSender::Http2(sender)
481                }
482                _ => return Err(HyperError::UnsupportedProtocol { protocol }),
483            };
484
485            let state = Arc::new(OutboundState {
486                protocol,
487                domain: domain.clone(),
488                port,
489                sender: tokio::sync::Mutex::new(sender),
490                response: tokio::sync::Mutex::new(VecDeque::new()),
491                response_notify: Notify::new(),
492                closed: AtomicBool::new(false),
493            });
494
495            let reader = HttpReader::outbound(Arc::clone(&state));
496            let writer = HttpWriter::outbound(state);
497            Ok((reader, writer, format!("{domain}:{port}")))
498        })
499    }
500
501    fn accept(
502        &self,
503        handle: &Self::Handle,
504    ) -> BoxFuture<'_, Result<(Self::Reader, Self::Writer, String), Self::Error>> {
505        let listener = Arc::clone(&handle.listener);
506
507        Box::pin(async move {
508            let pending = {
509                let mut guard = listener.pending_rx.lock().await;
510                guard.recv().await
511            }
512            .ok_or(HyperError::ListenerClosed)?;
513
514            let state = Arc::new(InboundState {
515                protocol: listener.protocol,
516                request: Mutex::new(pending.request_bytes.into()),
517                response: Mutex::new(Vec::new()),
518                responder: Mutex::new(Some(pending.responder)),
519            });
520
521            let reader = HttpReader::inbound(Arc::clone(&state));
522            let writer = HttpWriter::inbound(state);
523            Ok((reader, writer, pending.remote_addr))
524        })
525    }
526}
527
528impl IoCapability for HyperDriver {
529    type Handle = ();
530    type Reader = HttpReader;
531    type Writer = HttpWriter;
532    type Error = HyperError;
533
534    fn new_writer(&self, _handle: &Self::Handle) -> Result<Self::Writer, Self::Error> {
535        Err(HyperError::Unsupported)
536    }
537
538    fn new_reader(&self, _handle: &Self::Handle) -> Result<Self::Reader, Self::Error> {
539        Err(HyperError::Unsupported)
540    }
541
542    async fn read(&self, reader: &mut Self::Reader, len: usize) -> Result<IoFrame, Self::Error> {
543        match &reader.state {
544            ConnectionKind::Outbound(state) => read_outbound(state, len).await,
545            ConnectionKind::Inbound(state) => read_inbound(state, len),
546        }
547    }
548
549    async fn write(&self, writer: &mut Self::Writer, bytes: &[u8]) -> Result<(), Self::Error> {
550        match &writer.state {
551            ConnectionKind::Outbound(state) => write_outbound(state, bytes).await,
552            ConnectionKind::Inbound(state) => write_inbound(state, bytes),
553        }
554    }
555}
556
557impl From<HyperError> for GuestError {
558    fn from(value: HyperError) -> Self {
559        match value {
560            HyperError::HttpParse(_) => GuestError::InvalidArgument,
561            HyperError::HttpIncomplete => GuestError::InvalidArgument,
562            HyperError::Certificate(_) => GuestError::InvalidArgument,
563            HyperError::PrivateKey(_) => GuestError::InvalidArgument,
564            HyperError::ClientKeyMissing => GuestError::InvalidArgument,
565            HyperError::ClientAuthMissing => GuestError::InvalidArgument,
566            HyperError::ClientAuth(_) => GuestError::InvalidArgument,
567            HyperError::InvalidHeaderName(_) => GuestError::InvalidArgument,
568            HyperError::InvalidHeaderValue(_) => GuestError::InvalidArgument,
569            HyperError::InvalidMethod(_) => GuestError::InvalidArgument,
570            HyperError::InvalidUri(_) => GuestError::InvalidArgument,
571            HyperError::InvalidStatus => GuestError::InvalidArgument,
572            HyperError::ContentLengthMismatch { .. } => GuestError::InvalidArgument,
573            HyperError::HostMismatch => GuestError::InvalidArgument,
574            HyperError::TlsConfigMismatch => GuestError::InvalidArgument,
575            HyperError::UnsupportedProtocol { .. } => GuestError::InvalidArgument,
576            HyperError::TransferEncoding => GuestError::InvalidArgument,
577            _ => GuestError::Subsystem(value.to_string()),
578        }
579    }
580}
581
582fn ensure_http_protocol(protocol: NetProtocol) -> Result<(), HyperError> {
583    match protocol {
584        NetProtocol::Http | NetProtocol::Https => Ok(()),
585        _ => Err(HyperError::UnsupportedProtocol { protocol }),
586    }
587}