1use std::{
5 collections::{HashMap, VecDeque},
6 fmt,
7 net::SocketAddr,
8 sync::{
9 Arc, Mutex,
10 atomic::{AtomicBool, Ordering},
11 },
12};
13
14use futures_util::future::BoxFuture;
15use http_body_util::Full;
16use hyper::{
17 body::Bytes,
18 client::conn::{http1, http2},
19 http::{
20 header::{InvalidHeaderName, InvalidHeaderValue},
21 method::InvalidMethod,
22 uri::InvalidUri,
23 },
24};
25use hyper_util::rt::{TokioExecutor, TokioIo};
26use rustls::{ClientConfig, ServerConfig, sign};
27use selium_abi::{IoFrame, NetProtocol};
28use selium_kernel::{
29 drivers::{
30 io::IoCapability,
31 net::{NetCapability, TlsClientConfig, TlsServerConfig},
32 },
33 guest_data::GuestError,
34};
35use thiserror::Error;
36use tokio::{
37 io::{AsyncRead, AsyncWrite},
38 net::TcpListener as TokioTcpListener,
39 sync::{Notify, mpsc, oneshot},
40};
41use tracing::debug;
42
43use crate::{
44 client::{connect_stream, read_outbound, write_outbound},
45 server::{read_inbound, run_listener, write_inbound},
46 tls::{
47 build_client_config, build_client_verifier, build_server_config, certified_key_from_config,
48 resolve_alpn,
49 },
50};
51
52pub(crate) type HyperBody = Full<Bytes>;
54pub(crate) type HyperStream = Box<dyn HyperIo + 'static>;
56
57const PENDING_QUEUE: usize = 64;
58
59pub(crate) trait HyperIo: AsyncRead + AsyncWrite + Unpin + Send {}
60
61impl<T> HyperIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
62
/// hyper request handle for one outbound connection.
///
/// `connect` produces `Http1` for `NetProtocol::Http` (via `http1::handshake`)
/// and `Http2` for `NetProtocol::Https` (via `http2::handshake`).
63pub(crate) enum OutboundSender {
    /// HTTP/1 request sender (plain `http` connections).
64 Http1(http1::SendRequest<HyperBody>),
    /// HTTP/2 request sender (`https` connections).
65 Http2(http2::SendRequest<HyperBody>),
66}
67
/// Errors produced by the hyper-based network driver.
///
/// Converted to `GuestError` by the `From<HyperError>` impl in this file:
/// input-related variants become `InvalidArgument`, the rest become
/// `Subsystem` carrying the `Display` text defined here.
///
/// NOTE(review): several variants both interpolate `{0}` in their message and
/// mark the same field `#[source]`, so reporters that walk `source()` print the
/// inner message twice. The `{0}` is kept because guests only receive the
/// `Display` string.
68#[derive(Error, Debug)]
70pub enum HyperError {
    /// `accept` found the listener's pending-request channel closed.
71 #[error("listener closed before any request arrived")]
72 ListenerClosed,
    /// A port value was out of range.
    /// NOTE(review): not constructed in this chunk — confirm where it is raised.
73 #[error("port out of range")]
74 PortRange,
    /// Binding the listening socket failed; also used when converting the bound
    /// std listener into a Tokio listener.
75 #[error("failed to bind listener: {0}")]
76 Bind(#[source] std::io::Error),
    /// Switching the bound std listener to non-blocking mode failed.
77 #[error("failed to mark listener non-blocking: {0}")]
78 NonBlocking(#[source] std::io::Error),
    // The variants from `Connect` through `TransferEncoding` (except `Hyper`)
    // are not constructed in this chunk; presumably they are raised by the
    // `client`, `server` and `tls` helper modules — confirm there.
79 #[error("failed to connect: {0}")]
80 Connect(#[source] std::io::Error),
81 #[error("TLS handshake failed: {0}")]
82 Tls(#[source] std::io::Error),
83 #[error("failed to build TLS config: {0}")]
84 Rustls(#[source] rustls::Error),
85 #[error("failed to parse certificate chain: {0}")]
86 Certificate(String),
87 #[error("failed to parse private key: {0}")]
88 PrivateKey(String),
89 #[error("client certificate provided without private key")]
90 ClientKeyMissing,
91 #[error("client authentication requires a CA bundle")]
92 ClientAuthMissing,
93 #[error("failed to configure client authentication: {0}")]
94 ClientAuth(String),
    /// hyper failed during the outbound HTTP/1 or HTTP/2 handshake.
95 #[error("HTTP connection failed: {0}")]
96 Hyper(#[source] hyper::Error),
97 #[error("failed to build HTTP message: {0}")]
98 Http(#[source] hyper::http::Error),
99 #[error("HTTP parse error: {0}")]
100 HttpParse(String),
101 #[error("HTTP message incomplete")]
102 HttpIncomplete,
103 #[error("invalid header name: {0}")]
104 InvalidHeaderName(#[source] InvalidHeaderName),
105 #[error("invalid header value: {0}")]
106 InvalidHeaderValue(#[source] InvalidHeaderValue),
107 #[error("invalid method: {0}")]
108 InvalidMethod(#[source] InvalidMethod),
109 #[error("invalid URI: {0}")]
110 InvalidUri(#[source] InvalidUri),
111 #[error("invalid status code")]
112 InvalidStatus,
113 #[error("unsupported transfer encoding")]
114 TransferEncoding,
115 #[error("content length mismatch (expected {expected}, got {actual})")]
116 ContentLengthMismatch { expected: usize, actual: usize },
    /// Also returned by the listener registry when a port is already bound for
    /// a different domain (compared ASCII case-insensitively).
117 #[error("host header does not match target domain")]
118 HostMismatch,
119 #[error("response channel closed")]
120 ResponseChannelClosed,
    /// A `std::sync::Mutex` was poisoned (e.g. the listener registry lock).
121 #[error("mutex poisoned")]
122 Lock,
    /// Returned by `IoCapability::new_reader`/`new_writer`, which this driver
    /// does not support.
123 #[error("operation unsupported")]
124 Unsupported,
    /// A port is already bound with a different `ListenerTlsProfile`.
125 #[error("TLS configuration does not match existing listener")]
126 TlsConfigMismatch,
    /// The protocol is not HTTP/HTTPS. Also returned when a port is already
    /// bound for a different protocol (carrying the requested protocol).
127 #[error("unsupported protocol: {protocol:?}")]
128 UnsupportedProtocol { protocol: NetProtocol },
129}
130
/// Listeners keyed by TCP port.
///
/// `HyperDriver::create` goes through `get_or_try_init`, so repeated `create`
/// calls for one port share a single bound socket.
/// NOTE(review): no removal path is visible in this file; entries appear to
/// live as long as the registry.
131struct ListenerRegistry {
132 listeners: Mutex<HashMap<u16, Arc<Listener>>>,
133}
134
/// Comparable summary of the TLS settings a listener was created with.
///
/// Stored on each `Listener` and compared with `PartialEq` when another
/// `create` call targets an already-bound port; any difference yields
/// `HyperError::TlsConfigMismatch`.
135#[derive(Clone, Debug, PartialEq, Eq)]
136struct ListenerTlsProfile {
    /// Certificate bytes: copied from the driver's default `CertifiedKey`, or as
    /// returned by `certified_key_from_config` for an explicit TLS config.
137 cert_chain: Vec<Vec<u8>>,
    /// ALPN protocol identifiers as produced by `resolve_alpn`.
138 alpn: Vec<Vec<u8>>,
    /// Optional PEM bundle passed to `build_client_verifier`.
139 client_ca_pem: Option<Vec<u8>>,
    /// Client-auth flag passed to `build_client_verifier`.
140 require_client_auth: bool,
141}
142
/// A bound TCP listener together with the queue its accept loop fills.
143struct Listener {
    /// Protocol this listener was created for.
144 protocol: NetProtocol,
    /// Domain it was registered for; compared ASCII case-insensitively on reuse.
145 domain: String,
    /// Receiving end of the channel fed by `run_listener`. The async mutex lets
    /// concurrent `accept` calls take turns waiting on it.
146 pending_rx: tokio::sync::Mutex<mpsc::Receiver<PendingRequest>>,
    /// TLS settings used to detect conflicting `create` calls on the same port.
147 tls_profile: ListenerTlsProfile,
148}
149
/// A request handed from a listener's accept loop to `accept`.
150pub(crate) struct PendingRequest {
    /// Raw request bytes; become the initial contents of `InboundState::request`.
    /// NOTE(review): exact framing is produced by `run_listener` — confirm there.
151 pub(crate) request_bytes: Vec<u8>,
    /// Channel on which the guest's response bytes are sent when its
    /// `HttpWriter` is dropped.
152 pub(crate) responder: oneshot::Sender<Vec<u8>>,
    /// Peer address string, returned to the caller of `accept`.
153 pub(crate) remote_addr: String,
154}
155
/// State shared between the reader and writer halves of an outbound connection.
156pub(crate) struct OutboundState {
    /// Protocol the connection was opened with (`Http` or `Https`).
157 pub(crate) protocol: NetProtocol,
    /// Target domain passed to `connect`.
158 pub(crate) domain: String,
    /// Target port passed to `connect`.
159 pub(crate) port: u16,
    /// hyper request handle. NOTE(review): an async mutex, presumably so it can
    /// be held across `.await` in `write_outbound` — confirm.
160 pub(crate) sender: tokio::sync::Mutex<OutboundSender>,
    /// Buffered response bytes awaiting `read_outbound`.
    /// NOTE(review): filled in the `client` module — confirm.
161 pub(crate) response: tokio::sync::Mutex<VecDeque<u8>>,
    /// Wakes waiting readers; notified by `HttpWriter::drop`.
162 pub(crate) response_notify: Notify,
    /// Set to `true` when the `HttpWriter` is dropped.
163 pub(crate) closed: AtomicBool,
164}
165
/// State shared between the reader and writer halves of an accepted request.
166pub(crate) struct InboundState {
    /// Protocol of the listener that accepted the request.
167 pub(crate) protocol: NetProtocol,
    /// Request bytes not yet consumed; initialised from `PendingRequest::request_bytes`.
168 pub(crate) request: Mutex<VecDeque<u8>>,
    /// Response bytes written by the guest; taken and sent via `responder` when
    /// the `HttpWriter` is dropped.
169 pub(crate) response: Mutex<Vec<u8>>,
    /// Response channel; `take`n (so used at most once) by `HttpWriter::drop`.
170 pub(crate) responder: Mutex<Option<oneshot::Sender<Vec<u8>>>>,
171}
172
/// Which side of an HTTP exchange a reader/writer pair is attached to.
/// Both halves of a pair hold clones of the same `Arc` state.
173pub(crate) enum ConnectionKind {
    /// Client side, created by `connect`.
174 Outbound(Arc<OutboundState>),
    /// Server side, created by `accept`.
175 Inbound(Arc<InboundState>),
176}
177
/// Cloneable public handle to a shared listener.
/// Returned by `NetCapability::create` and passed back to `accept`.
178#[derive(Clone)]
180pub struct ListenerHandle {
181 listener: Arc<Listener>,
182}
183
/// hyper-based implementation of `NetCapability` and `IoCapability`.
184pub struct HyperDriver {
    /// Listeners shared by all `create` calls, keyed by port.
186 registry: Arc<ListenerRegistry>,
    /// Raw certificate bytes of the default identity; recorded in the TLS
    /// profile of listeners created without an explicit TLS config.
187 default_cert_chain: Vec<Vec<u8>>,
    /// Server config used by `create` when no TLS config is supplied.
188 default_server_config: Arc<ServerConfig>,
    /// Client config used by `connect` when no TLS config is supplied.
189 default_client_config: Arc<ClientConfig>,
190}
191
/// Read half of an HTTP exchange, obtained from `connect` or `accept`.
192pub struct HttpReader {
194 state: ConnectionKind,
195}
196
/// Write half of an HTTP exchange, obtained from `connect` or `accept`.
/// Dropping it finalises the exchange (see its `Drop` impl).
197pub struct HttpWriter {
199 state: ConnectionKind,
200}
201
202impl fmt::Debug for ListenerHandle {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 f.debug_struct("ListenerHandle")
205 .field("protocol", &self.listener.protocol)
206 .field("domain", &self.listener.domain)
207 .finish()
208 }
209}
210
211impl ListenerRegistry {
212 fn new() -> Self {
213 Self {
214 listeners: Mutex::new(HashMap::new()),
215 }
216 }
217
218 fn get_or_try_init(
219 &self,
220 protocol: NetProtocol,
221 domain: &str,
222 port: u16,
223 tls_profile: ListenerTlsProfile,
224 server_config: Arc<ServerConfig>,
225 ) -> Result<Arc<Listener>, HyperError> {
226 let mut guard = self.listeners.lock().map_err(|_| HyperError::Lock)?;
227 if let Some(listener) = guard.get(&port) {
228 if listener.protocol != protocol {
229 return Err(HyperError::UnsupportedProtocol { protocol });
230 }
231 if !listener.domain.eq_ignore_ascii_case(domain) {
232 return Err(HyperError::HostMismatch);
233 }
234 if listener.tls_profile != tls_profile {
235 return Err(HyperError::TlsConfigMismatch);
236 }
237 return Ok(Arc::clone(listener));
238 }
239
240 let listener = Arc::new(Listener::new(
241 protocol,
242 domain.to_string(),
243 port,
244 tls_profile,
245 server_config,
246 )?);
247 guard.insert(port, Arc::clone(&listener));
248 Ok(listener)
249 }
250}
251
252impl Listener {
253 fn new(
254 protocol: NetProtocol,
255 domain: String,
256 port: u16,
257 tls_profile: ListenerTlsProfile,
258 server_config: Arc<ServerConfig>,
259 ) -> Result<Self, HyperError> {
260 ensure_http_protocol(protocol)?;
261 let addr = SocketAddr::from(([0, 0, 0, 0], port));
262 let std_listener = std::net::TcpListener::bind(addr).map_err(HyperError::Bind)?;
263 std_listener
264 .set_nonblocking(true)
265 .map_err(HyperError::NonBlocking)?;
266 let listener = TokioTcpListener::from_std(std_listener).map_err(HyperError::Bind)?;
267
268 let (pending_tx, pending_rx) = mpsc::channel(PENDING_QUEUE);
269 tokio::spawn(run_listener(
270 listener,
271 protocol,
272 domain.clone(),
273 server_config,
274 pending_tx,
275 ));
276
277 Ok(Self {
278 protocol,
279 domain,
280 pending_rx: tokio::sync::Mutex::new(pending_rx),
281 tls_profile,
282 })
283 }
284}
285
286impl ListenerHandle {
287 fn new(listener: Arc<Listener>) -> Self {
288 Self { listener }
289 }
290
291 pub fn domain(&self) -> &str {
293 &self.listener.domain
294 }
295
296 pub fn protocol(&self) -> NetProtocol {
298 self.listener.protocol
299 }
300}
301
302impl HyperDriver {
303 pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Result<Arc<Self>, HyperError> {
305 let default_cert_chain = certified_key
306 .cert
307 .iter()
308 .map(|cert| cert.as_ref().to_vec())
309 .collect::<Vec<_>>();
310 let client_verifier = build_client_verifier(None, false)?;
311 let default_server_config = build_server_config(
312 Arc::clone(&certified_key),
313 resolve_alpn(NetProtocol::Https, None),
314 client_verifier,
315 )?;
316 let default_client_config = build_client_config(NetProtocol::Https, None)?;
317 Ok(Arc::new(Self {
318 registry: Arc::new(ListenerRegistry::new()),
319 default_cert_chain,
320 default_server_config,
321 default_client_config,
322 }))
323 }
324}
325
326impl HttpReader {
327 fn outbound(state: Arc<OutboundState>) -> Self {
328 Self {
329 state: ConnectionKind::Outbound(state),
330 }
331 }
332
333 fn inbound(state: Arc<InboundState>) -> Self {
334 Self {
335 state: ConnectionKind::Inbound(state),
336 }
337 }
338}
339
340impl HttpWriter {
341 fn outbound(state: Arc<OutboundState>) -> Self {
342 Self {
343 state: ConnectionKind::Outbound(state),
344 }
345 }
346
347 fn inbound(state: Arc<InboundState>) -> Self {
348 Self {
349 state: ConnectionKind::Inbound(state),
350 }
351 }
352}
353
354impl Drop for HttpWriter {
355 fn drop(&mut self) {
356 match &self.state {
357 ConnectionKind::Outbound(state) => {
358 state.closed.store(true, Ordering::SeqCst);
359 state.response_notify.notify_waiters();
360 }
361 ConnectionKind::Inbound(state) => {
362 let response = match state.response.lock() {
363 Ok(mut guard) => std::mem::take(&mut *guard),
364 Err(err) => {
365 debug!(err = %err, "response buffer lock poisoned");
366 Vec::new()
367 }
368 };
369 let responder = match state.responder.lock() {
370 Ok(mut guard) => guard.take(),
371 Err(err) => {
372 debug!(err = %err, "response channel lock poisoned");
373 None
374 }
375 };
376 if let Some(responder) = responder
377 && responder.send(response).is_err()
378 {
379 debug!("response receiver dropped before completion");
380 }
381 }
382 }
383 }
384}
385
386impl NetCapability for HyperDriver {
387 type Handle = ListenerHandle;
388 type Reader = HttpReader;
389 type Writer = HttpWriter;
390 type Error = HyperError;
391
392 fn create(
393 &self,
394 protocol: NetProtocol,
395 domain: &str,
396 port: u16,
397 tls: Option<Arc<TlsServerConfig>>,
398 ) -> BoxFuture<'_, Result<Self::Handle, Self::Error>> {
399 let registry = Arc::clone(&self.registry);
400 let domain = domain.to_string();
401 let default_cert_chain = self.default_cert_chain.clone();
402 let default_server_config = Arc::clone(&self.default_server_config);
403
404 Box::pin(async move {
405 ensure_http_protocol(protocol)?;
406 let (server_config, tls_profile) = match tls.as_ref() {
407 Some(config) => {
408 let alpn = resolve_alpn(protocol, config.alpn.as_ref());
409 let client_verifier = build_client_verifier(
410 config.client_ca_pem.as_ref(),
411 config.require_client_auth,
412 )?;
413 let (certified_key, cert_chain) = certified_key_from_config(config)?;
414 let server_config =
415 build_server_config(certified_key, alpn.clone(), client_verifier)?;
416 let profile = ListenerTlsProfile {
417 cert_chain,
418 alpn,
419 client_ca_pem: config.client_ca_pem.clone(),
420 require_client_auth: config.require_client_auth,
421 };
422 (server_config, profile)
423 }
424 None => {
425 let alpn = resolve_alpn(protocol, None);
426 let profile = ListenerTlsProfile {
427 cert_chain: default_cert_chain,
428 alpn,
429 client_ca_pem: None,
430 require_client_auth: false,
431 };
432 (default_server_config, profile)
433 }
434 };
435 let listener =
436 registry.get_or_try_init(protocol, &domain, port, tls_profile, server_config)?;
437 Ok(ListenerHandle::new(listener))
438 })
439 }
440
441 fn connect(
442 &self,
443 protocol: NetProtocol,
444 domain: &str,
445 port: u16,
446 tls: Option<Arc<TlsClientConfig>>,
447 ) -> BoxFuture<'_, Result<(Self::Reader, Self::Writer, String), Self::Error>> {
448 let domain = domain.to_string();
449 let default_client_config = Arc::clone(&self.default_client_config);
450
451 Box::pin(async move {
452 ensure_http_protocol(protocol)?;
453 let tls = tls.as_deref();
454 let client_config = match tls {
455 Some(config) => build_client_config(protocol, Some(config))?,
456 None => default_client_config,
457 };
458 let stream = connect_stream(protocol, &domain, port, client_config).await?;
459 let stream = TokioIo::new(stream);
460 let sender = match protocol {
461 NetProtocol::Http => {
462 let (sender, connection) =
463 http1::handshake(stream).await.map_err(HyperError::Hyper)?;
464 tokio::spawn(async move {
465 if let Err(err) = connection.await {
466 debug!(err = %err, "http connection terminated");
467 }
468 });
469 OutboundSender::Http1(sender)
470 }
471 NetProtocol::Https => {
472 let (sender, connection) = http2::handshake(TokioExecutor::new(), stream)
473 .await
474 .map_err(HyperError::Hyper)?;
475 tokio::spawn(async move {
476 if let Err(err) = connection.await {
477 debug!(err = %err, "http connection terminated");
478 }
479 });
480 OutboundSender::Http2(sender)
481 }
482 _ => return Err(HyperError::UnsupportedProtocol { protocol }),
483 };
484
485 let state = Arc::new(OutboundState {
486 protocol,
487 domain: domain.clone(),
488 port,
489 sender: tokio::sync::Mutex::new(sender),
490 response: tokio::sync::Mutex::new(VecDeque::new()),
491 response_notify: Notify::new(),
492 closed: AtomicBool::new(false),
493 });
494
495 let reader = HttpReader::outbound(Arc::clone(&state));
496 let writer = HttpWriter::outbound(state);
497 Ok((reader, writer, format!("{domain}:{port}")))
498 })
499 }
500
501 fn accept(
502 &self,
503 handle: &Self::Handle,
504 ) -> BoxFuture<'_, Result<(Self::Reader, Self::Writer, String), Self::Error>> {
505 let listener = Arc::clone(&handle.listener);
506
507 Box::pin(async move {
508 let pending = {
509 let mut guard = listener.pending_rx.lock().await;
510 guard.recv().await
511 }
512 .ok_or(HyperError::ListenerClosed)?;
513
514 let state = Arc::new(InboundState {
515 protocol: listener.protocol,
516 request: Mutex::new(pending.request_bytes.into()),
517 response: Mutex::new(Vec::new()),
518 responder: Mutex::new(Some(pending.responder)),
519 });
520
521 let reader = HttpReader::inbound(Arc::clone(&state));
522 let writer = HttpWriter::inbound(state);
523 Ok((reader, writer, pending.remote_addr))
524 })
525 }
526}
527
528impl IoCapability for HyperDriver {
529 type Handle = ();
530 type Reader = HttpReader;
531 type Writer = HttpWriter;
532 type Error = HyperError;
533
534 fn new_writer(&self, _handle: &Self::Handle) -> Result<Self::Writer, Self::Error> {
535 Err(HyperError::Unsupported)
536 }
537
538 fn new_reader(&self, _handle: &Self::Handle) -> Result<Self::Reader, Self::Error> {
539 Err(HyperError::Unsupported)
540 }
541
542 async fn read(&self, reader: &mut Self::Reader, len: usize) -> Result<IoFrame, Self::Error> {
543 match &reader.state {
544 ConnectionKind::Outbound(state) => read_outbound(state, len).await,
545 ConnectionKind::Inbound(state) => read_inbound(state, len),
546 }
547 }
548
549 async fn write(&self, writer: &mut Self::Writer, bytes: &[u8]) -> Result<(), Self::Error> {
550 match &writer.state {
551 ConnectionKind::Outbound(state) => write_outbound(state, bytes).await,
552 ConnectionKind::Inbound(state) => write_inbound(state, bytes),
553 }
554 }
555}
556
557impl From<HyperError> for GuestError {
558 fn from(value: HyperError) -> Self {
559 match value {
560 HyperError::HttpParse(_) => GuestError::InvalidArgument,
561 HyperError::HttpIncomplete => GuestError::InvalidArgument,
562 HyperError::Certificate(_) => GuestError::InvalidArgument,
563 HyperError::PrivateKey(_) => GuestError::InvalidArgument,
564 HyperError::ClientKeyMissing => GuestError::InvalidArgument,
565 HyperError::ClientAuthMissing => GuestError::InvalidArgument,
566 HyperError::ClientAuth(_) => GuestError::InvalidArgument,
567 HyperError::InvalidHeaderName(_) => GuestError::InvalidArgument,
568 HyperError::InvalidHeaderValue(_) => GuestError::InvalidArgument,
569 HyperError::InvalidMethod(_) => GuestError::InvalidArgument,
570 HyperError::InvalidUri(_) => GuestError::InvalidArgument,
571 HyperError::InvalidStatus => GuestError::InvalidArgument,
572 HyperError::ContentLengthMismatch { .. } => GuestError::InvalidArgument,
573 HyperError::HostMismatch => GuestError::InvalidArgument,
574 HyperError::TlsConfigMismatch => GuestError::InvalidArgument,
575 HyperError::UnsupportedProtocol { .. } => GuestError::InvalidArgument,
576 HyperError::TransferEncoding => GuestError::InvalidArgument,
577 _ => GuestError::Subsystem(value.to_string()),
578 }
579 }
580}
581
582fn ensure_http_protocol(protocol: NetProtocol) -> Result<(), HyperError> {
583 match protocol {
584 NetProtocol::Http | NetProtocol::Https => Ok(()),
585 _ => Err(HyperError::UnsupportedProtocol { protocol }),
586 }
587}