1use std::net::SocketAddr;
2use tokio::net::TcpListener;
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use crate::error::{Error, Result};
7use crate::types::{OxiditeRequest, OxiditeResponse};
8use tower_service::Service;
9use std::error::Error as StdError;
10
11use http_body_util::BodyExt;
12
13use std::task::{Context, Poll};
14
15#[cfg(feature = "http3")]
16pub mod http3_server;
17
18#[cfg(feature = "http3")]
19pub use http3_server::Http3Server;
20
/// Thin wrapper that lets a crate-level service (taking `OxiditeRequest`) be driven
/// with hyper's native `Request<Incoming>`; the body conversion lives in its
/// `Service` impl.
#[derive(Clone)]
pub struct BodyAdapter<S>(S);

impl<S> BodyAdapter<S> {
    /// Wraps `service` without modifying it.
    pub fn new(service: S) -> Self {
        BodyAdapter(service)
    }
}
30
31use futures_util::future::Map;
32use futures_util::FutureExt;
33
34impl<S> Service<hyper::Request<hyper::body::Incoming>> for BodyAdapter<S>
35where
36 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone,
37{
38 type Response = hyper::Response<crate::types::BoxBody>;
39 type Error = Error;
40 type Future = Map<S::Future, fn(std::result::Result<OxiditeResponse, Error>) -> std::result::Result<hyper::Response<crate::types::BoxBody>, Error>>;
41
42 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
43 self.0.poll_ready(cx)
44 }
45
46 fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
47 let req = req.map(|b| b.map_err(|e| e.into()).boxed());
48 fn map_response(res: std::result::Result<OxiditeResponse, Error>) -> std::result::Result<hyper::Response<crate::types::BoxBody>, Error> {
49 res.map(|r| r.into())
50 }
51 self.0.call(req).map(map_response)
52 }
53}
54
55
/// Front-end that accepts connections and dispatches every request to one
/// service `S`.
pub struct Server<S> {
    // Shared service; a clone is handed to each accepted connection.
    service: S,
}
59
60impl<S> Server<S>
61where
62 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
63 S::Future: Send + 'static,
64{
65 pub fn new(service: S) -> Self {
66 Self {
67 service,
68 }
69 }
70
71 pub async fn listen(self, addr: SocketAddr) -> Result<()> {
72 let listener = TcpListener::bind(addr).await?;
73 println!("Listening on http://{}", addr);
74
75 loop {
76 let (stream, _) = listener.accept().await?;
77 let io = TokioIo::new(stream);
78 let service = self.service.clone();
79
80 tokio::task::spawn(async move {
81 let service = BodyAdapter::new(service);
82 let hyper_service = TowerToHyperService::new(service);
83
84 if let Err(err) = http1::Builder::new()
85 .serve_connection(io, hyper_service)
86 .await
87 {
88 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
91 if service_err.is_server_error() {
93 eprintln!("Server error: {}", service_err);
94 }
95 } else {
97 let err_msg = err.to_string();
100 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
102 eprintln!("Connection error: {}", err);
103 }
104 }
105 }
106 });
107 }
108 }
109
110 #[cfg(feature = "http3")]
112 pub async fn listen_h3(self, addr: SocketAddr, cert_pem: &str, key_pem: &str) -> Result<()> {
113 use rustls::ServerConfig;
114 use rustls_pemfile::{certs, pkcs8_private_keys};
115 use std::io::Cursor;
116
117 let http1_addr = addr;
119 let http1_service = self.service.clone();
120
121 tokio::spawn(async move {
122 let listener = TcpListener::bind(http1_addr).await.unwrap();
123 println!("HTTP/1.1 server listening on http://{}", http1_addr);
124
125 loop {
126 let (stream, _) = listener.accept().await.unwrap();
127 let io = TokioIo::new(stream);
128 let service = http1_service.clone();
129
130 tokio::task::spawn(async move {
131 let service = BodyAdapter::new(service);
132 let hyper_service = TowerToHyperService::new(service);
133
134 if let Err(err) = http1::Builder::new()
135 .serve_connection(io, hyper_service)
136 .await
137 {
138 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
140 if service_err.is_server_error() {
141 eprintln!("HTTP/1.1 server error: {}", service_err);
142 }
143 } else {
144 let err_msg = err.to_string();
145 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
146 eprintln!("HTTP/1.1 connection error: {}", err);
147 }
148 }
149 }
150 });
151 }
152 });
153
154 let cert_chain = certs(&mut Cursor::new(cert_pem))
156 .collect::<std::result::Result<Vec<_>, _>>()
157 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
158
159 let mut keys = pkcs8_private_keys(&mut Cursor::new(key_pem))
160 .collect::<std::result::Result<Vec<_>, _>>()?;
161
162 if keys.is_empty() {
163 return Err(crate::error::Error::InternalServerError("No private keys found".to_string()));
164 }
165
166 let tls_config = ServerConfig::builder()
167 .with_no_client_auth()
168 .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
169 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
170
171 let http3_server = Http3Server::new(self.service);
172 http3_server.listen(addr, tls_config).await?;
173
174 Ok(())
175 }
176}