oxidite_core/
server.rs

1use std::net::SocketAddr;
2use tokio::net::TcpListener;
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use crate::error::{Error, Result};
7use crate::types::{OxiditeRequest, OxiditeResponse};
8use tower_service::Service;
9
10use http_body_util::BodyExt;
11
12use std::task::{Context, Poll};
13
14#[cfg(feature = "http3")]
15pub mod http3_server;
16
17#[cfg(feature = "http3")]
18pub use http3_server::Http3Server;
19
/// Tower adapter that lets a service written against [`OxiditeRequest`] /
/// [`OxiditeResponse`] be driven directly by hyper.
///
/// The adapter's `Service` implementation boxes the body of each incoming
/// `hyper::Request<Incoming>` before forwarding it to the wrapped service, and
/// converts the wrapped service's response back into a `hyper::Response`.
#[derive(Clone)]
pub struct BodyAdapter<S>(S);

impl<S> BodyAdapter<S> {
    /// Wraps `service` so it can be handed to hyper.
    pub fn new(service: S) -> Self {
        BodyAdapter(service)
    }
}
29
30use futures_util::future::Map;
31use futures_util::FutureExt;
32
33impl<S> Service<hyper::Request<hyper::body::Incoming>> for BodyAdapter<S>
34where
35    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone,
36{
37    type Response = hyper::Response<crate::types::BoxBody>;
38    type Error = Error;
39    type Future = Map<S::Future, fn(std::result::Result<OxiditeResponse, Error>) -> std::result::Result<hyper::Response<crate::types::BoxBody>, Error>>;
40
41    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
42        self.0.poll_ready(cx)
43    }
44
45    fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
46        let req = req.map(|b| b.map_err(|e| e.into()).boxed());
47        fn map_response(res: std::result::Result<OxiditeResponse, Error>) -> std::result::Result<hyper::Response<crate::types::BoxBody>, Error> {
48            res.map(|r| r.into())
49        }
50        self.0.call(req).map(map_response)
51    }
52}
53
54
/// HTTP server front-end that dispatches every request to a Tower `Service`.
pub struct Server<S> {
    /// Request handler; a clone of it is handed to each accepted connection.
    service: S,
}
58
59impl<S> Server<S>
60where
61    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
62    S::Future: Send + 'static,
63{
64    pub fn new(service: S) -> Self {
65        Self {
66            service,
67        }
68    }
69
70    pub async fn listen(self, addr: SocketAddr) -> Result<()> {
71        let listener = TcpListener::bind(addr).await?;
72        println!("Listening on http://{}", addr);
73
74        loop {
75            let (stream, _) = listener.accept().await?;
76            let io = TokioIo::new(stream);
77            let service = self.service.clone();
78
79            tokio::task::spawn(async move {
80                let service = BodyAdapter::new(service);
81                let hyper_service = TowerToHyperService::new(service);
82                
83                if let Err(err) = http1::Builder::new()
84                    .serve_connection(io, hyper_service)
85                    .await
86                {
87                    // This `err` is a `hyper::Error`, not `crate::error::Error`.
88                    // The user's requested logging for `crate::error::Error` types
89                    // is now handled within the `hyper_compatible_service` wrapper.
90                    // This `eprintln` now only catches connection-level `hyper::Error`s.
91                    eprintln!("Error serving connection: {:?}", err);
92                }
93            });
94        }
95    }
96
97    /// Listen with both HTTP/1.1 and HTTP/3 support
98    #[cfg(feature = "http3")]
99    pub async fn listen_h3(self, addr: SocketAddr, cert_pem: &str, key_pem: &str) -> Result<()> {
100        use rustls::ServerConfig;  
101        use rustls_pemfile::{certs, pkcs8_private_keys};
102        use std::io::Cursor;
103        
104        // Setup HTTP/1.1 server in background
105        let http1_addr = addr;
106        let http1_service = self.service.clone();
107        
108        tokio::spawn(async move {
109            let listener = TcpListener::bind(http1_addr).await.unwrap();
110            println!("HTTP/1.1 server listening on http://{}", http1_addr);
111            
112            loop {
113                let (stream, _) = listener.accept().await.unwrap();
114                let io = TokioIo::new(stream);
115                let service = http1_service.clone();
116                
117                tokio::task::spawn(async move {
118                    let service = BodyAdapter::new(service);
119                    let hyper_service = TowerToHyperService::new(service);
120                    
121                    if let Err(err) = http1::Builder::new()
122                        .serve_connection(io, hyper_service)
123                        .await
124                    {
125                        eprintln!("HTTP/1.1 connection error: {:?}", err);
126                    }
127                });
128            }
129        });
130        
131        // Setup HTTP/3 server
132        let cert_chain = certs(&mut Cursor::new(cert_pem))
133            .collect::<std::result::Result<Vec<_>, _>>()
134            .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
135        
136        let mut keys = pkcs8_private_keys(&mut Cursor::new(key_pem))
137            .collect::<std::result::Result<Vec<_>, _>>()?;
138        
139        if keys.is_empty() {
140            return Err(crate::error::Error::InternalServerError("No private keys found".to_string()));
141        }
142        
143        let tls_config = ServerConfig::builder()
144            .with_no_client_auth()
145            .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
146            .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
147        
148        let http3_server = Http3Server::new(self.service);
149        http3_server.listen(addr, tls_config).await?;
150        
151        Ok(())
152    }
153}