1use std::net::SocketAddr;
2use tokio::net::TcpListener;
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use crate::error::{Error, Result};
7use crate::types::{OxiditeRequest, OxiditeResponse};
8use tower_service::Service;
9use std::error::Error as StdError;
10
11use http_body_util::BodyExt;
12
13use std::task::{Context, Poll};
14
15#[cfg(feature = "http3")]
16pub mod http3_server;
17
18#[cfg(feature = "http3")]
19pub use http3_server::Http3Server;
20
/// Wraps a framework service so hyper can drive it.
///
/// The wrapped service is kept as the single tuple field; the `Service` impl for
/// `hyper::Request<hyper::body::Incoming>` in this module performs the
/// request/response conversion around it.
#[derive(Clone)]
pub struct BodyAdapter<S>(S);

impl<S> BodyAdapter<S> {
    /// Wraps `service` as-is, without inspecting or modifying it.
    pub fn new(service: S) -> Self {
        BodyAdapter(service)
    }
}
30
31
32
33use std::pin::Pin;
34
35impl<S> Service<hyper::Request<hyper::body::Incoming>> for BodyAdapter<S>
36where
37 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + 'static,
38 S::Future: Send + 'static,
39{
40 type Response = hyper::Response<crate::types::BoxBody>;
41 type Error = Error;
42 type Future = Pin<Box<dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
43
44 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
45 self.0.poll_ready(cx)
46 }
47
48 fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
49 let accepts_html = req.headers().get(hyper::header::ACCEPT)
50 .map(|h| h.to_str().unwrap_or("").contains("text/html"))
51 .unwrap_or(false);
52
53 let req = req.map(|b| b.map_err(|e| e.into()).boxed());
54 let fut = self.0.call(req);
55
56 Box::pin(async move {
57 match fut.await {
58 Ok(response) => Ok(response.into()),
59 Err(error) => {
60 let env = std::env::var("OXIDITE_ENV").unwrap_or_else(|_| "development".to_string());
61
62 if env == "development" && accepts_html && error.is_server_error() {
63 use bytes::Bytes;
64 use http_body_util::Full;
65 use hyper::header::{CONTENT_TYPE, SERVER};
66
67 let html = crate::error::render_ignition_error(&error);
68
69 let res = hyper::Response::builder()
70 .status(error.status_code())
71 .header(CONTENT_TYPE, "text/html; charset=utf-8")
72 .header(SERVER, crate::response::SERVER_HEADER_VALUE)
73 .body(Full::new(Bytes::from(html)).map_err(|e| match e {}).boxed())
74 .unwrap();
75
76 Ok(res)
77 } else {
78 Ok(OxiditeResponse::from(error).into())
79 }
80 }
81 }
82 })
83 }
84}
85
86
87
/// HTTP server entry point that owns the application service.
pub struct Server<S> {
    /// Application service; each accepted connection is handed its own clone.
    service: S,
}
91
92impl<S> Server<S>
93where
94 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
95 S::Future: Send + 'static,
96{
97 pub fn new(service: S) -> Self {
98 Self {
99 service,
100 }
101 }
102
103 pub async fn listen(self, addr: SocketAddr) -> Result<()> {
104 let listener = TcpListener::bind(addr).await?;
105 println!("Listening on http://{}", addr);
106
107 loop {
108 let (stream, _) = listener.accept().await?;
109 let io = TokioIo::new(stream);
110 let service = self.service.clone();
111
112 tokio::task::spawn(async move {
113 let service = BodyAdapter::new(service);
114 let hyper_service = TowerToHyperService::new(service);
115
116 if let Err(err) = http1::Builder::new()
117 .serve_connection(io, hyper_service)
118 .await
119 {
120 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
123 if service_err.is_server_error() {
125 eprintln!("Server error: {}", service_err);
126 }
127 } else {
129 let err_msg = err.to_string();
132 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
134 eprintln!("Connection error: {}", err);
135 }
136 }
137 }
138 });
139 }
140 }
141
142 #[cfg(feature = "http3")]
144 pub async fn listen_h3(self, addr: SocketAddr, cert_pem: &str, key_pem: &str) -> Result<()> {
145 use rustls::ServerConfig;
146 use rustls_pemfile::{certs, pkcs8_private_keys};
147 use std::io::Cursor;
148
149 let http1_addr = addr;
151 let http1_service = self.service.clone();
152
153 tokio::spawn(async move {
154 let listener = TcpListener::bind(http1_addr).await.unwrap();
155 println!("HTTP/1.1 server listening on http://{}", http1_addr);
156
157 loop {
158 let (stream, _) = listener.accept().await.unwrap();
159 let io = TokioIo::new(stream);
160 let service = http1_service.clone();
161
162 tokio::task::spawn(async move {
163 let service = BodyAdapter::new(service);
164 let hyper_service = TowerToHyperService::new(service);
165
166 if let Err(err) = http1::Builder::new()
167 .serve_connection(io, hyper_service)
168 .await
169 {
170 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
172 if service_err.is_server_error() {
173 eprintln!("HTTP/1.1 server error: {}", service_err);
174 }
175 } else {
176 let err_msg = err.to_string();
177 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
178 eprintln!("HTTP/1.1 connection error: {}", err);
179 }
180 }
181 }
182 });
183 }
184 });
185
186 let cert_chain = certs(&mut Cursor::new(cert_pem))
188 .collect::<std::result::Result<Vec<_>, _>>()
189 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
190
191 let mut keys = pkcs8_private_keys(&mut Cursor::new(key_pem))
192 .collect::<std::result::Result<Vec<_>, _>>()?;
193
194 if keys.is_empty() {
195 return Err(crate::error::Error::InternalServerError("No private keys found".to_string()));
196 }
197
198 let tls_config = ServerConfig::builder()
199 .with_no_client_auth()
200 .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
201 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
202
203 let http3_server = Http3Server::new(self.service);
204 http3_server.listen(addr, tls_config).await?;
205
206 Ok(())
207 }
208}