1use std::net::SocketAddr;
2use tokio::net::TcpListener;
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use crate::error::{Error, Result};
7use crate::types::{OxiditeRequest, OxiditeResponse};
8use tower_service::Service;
9use std::error::Error as StdError;
10
11use http_body_util::BodyExt;
12
13use std::task::{Context, Poll};
14
15#[cfg(feature = "http3")]
16pub mod http3_server;
17
18#[cfg(feature = "http3")]
19pub use http3_server::Http3Server;
20
/// Adapter that wraps an Oxidite service `S` so hyper can drive it.
///
/// The wrapped service is stored as the single tuple field; cloning the
/// adapter clones the inner service.
#[derive(Clone)]
pub struct BodyAdapter<S>(S);

impl<S> BodyAdapter<S> {
    /// Wraps `service` without modifying it.
    pub fn new(service: S) -> Self {
        BodyAdapter(service)
    }
}
30
31use futures_util::future::Map;
32use futures_util::FutureExt;
33
34impl<S> Service<hyper::Request<hyper::body::Incoming>> for BodyAdapter<S>
35where
36 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone,
37{
38 type Response = hyper::Response<crate::types::BoxBody>;
39 type Error = Error;
40 type Future = Map<S::Future, fn(std::result::Result<OxiditeResponse, Error>) -> std::result::Result<hyper::Response<crate::types::BoxBody>, Error>>;
41
42 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
43 self.0.poll_ready(cx)
44 }
45
46 fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
47 let req = req.map(|b| b.map_err(|e| e.into()).boxed());
48 fn map_response(res: std::result::Result<OxiditeResponse, Error>) -> std::result::Result<hyper::Response<crate::types::BoxBody>, Error> {
49 match res {
50 Ok(response) => Ok(response.into()),
51 Err(error) => Ok(OxiditeResponse::from(error).into()),
54 }
55 }
56 self.0.call(req).map(map_response)
57 }
58}
59
60
/// Network front-end that serves a single Oxidite service.
///
/// The service is cloned for every accepted connection.
pub struct Server<S> {
    /// Service that handles every request; cloned per connection.
    service: S,
}
64
65impl<S> Server<S>
66where
67 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
68 S::Future: Send + 'static,
69{
70 pub fn new(service: S) -> Self {
71 Self {
72 service,
73 }
74 }
75
76 pub async fn listen(self, addr: SocketAddr) -> Result<()> {
77 let listener = TcpListener::bind(addr).await?;
78 println!("Listening on http://{}", addr);
79
80 loop {
81 let (stream, _) = listener.accept().await?;
82 let io = TokioIo::new(stream);
83 let service = self.service.clone();
84
85 tokio::task::spawn(async move {
86 let service = BodyAdapter::new(service);
87 let hyper_service = TowerToHyperService::new(service);
88
89 if let Err(err) = http1::Builder::new()
90 .serve_connection(io, hyper_service)
91 .await
92 {
93 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
96 if service_err.is_server_error() {
98 eprintln!("Server error: {}", service_err);
99 }
100 } else {
102 let err_msg = err.to_string();
105 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
107 eprintln!("Connection error: {}", err);
108 }
109 }
110 }
111 });
112 }
113 }
114
115 #[cfg(feature = "http3")]
117 pub async fn listen_h3(self, addr: SocketAddr, cert_pem: &str, key_pem: &str) -> Result<()> {
118 use rustls::ServerConfig;
119 use rustls_pemfile::{certs, pkcs8_private_keys};
120 use std::io::Cursor;
121
122 let http1_addr = addr;
124 let http1_service = self.service.clone();
125
126 tokio::spawn(async move {
127 let listener = TcpListener::bind(http1_addr).await.unwrap();
128 println!("HTTP/1.1 server listening on http://{}", http1_addr);
129
130 loop {
131 let (stream, _) = listener.accept().await.unwrap();
132 let io = TokioIo::new(stream);
133 let service = http1_service.clone();
134
135 tokio::task::spawn(async move {
136 let service = BodyAdapter::new(service);
137 let hyper_service = TowerToHyperService::new(service);
138
139 if let Err(err) = http1::Builder::new()
140 .serve_connection(io, hyper_service)
141 .await
142 {
143 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
145 if service_err.is_server_error() {
146 eprintln!("HTTP/1.1 server error: {}", service_err);
147 }
148 } else {
149 let err_msg = err.to_string();
150 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
151 eprintln!("HTTP/1.1 connection error: {}", err);
152 }
153 }
154 }
155 });
156 }
157 });
158
159 let cert_chain = certs(&mut Cursor::new(cert_pem))
161 .collect::<std::result::Result<Vec<_>, _>>()
162 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
163
164 let mut keys = pkcs8_private_keys(&mut Cursor::new(key_pem))
165 .collect::<std::result::Result<Vec<_>, _>>()?;
166
167 if keys.is_empty() {
168 return Err(crate::error::Error::InternalServerError("No private keys found".to_string()));
169 }
170
171 let tls_config = ServerConfig::builder()
172 .with_no_client_auth()
173 .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
174 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
175
176 let http3_server = Http3Server::new(self.service);
177 http3_server.listen(addr, tls_config).await?;
178
179 Ok(())
180 }
181}