1use std::net::SocketAddr;
2use tokio::net::TcpListener;
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use crate::error::{Error, Result};
7use crate::types::{OxiditeRequest, OxiditeResponse};
8use crate::router::CorsConfig;
9use tower_service::Service;
10use std::error::Error as StdError;
11use http::HeaderValue;
12
13use http_body_util::BodyExt;
14
15use std::task::{Context, Poll};
16
17#[cfg(feature = "http3")]
18pub mod http3_server;
19
20#[cfg(feature = "http3")]
21pub use http3_server::Http3Server;
22
23#[derive(Clone)]
25pub struct BodyAdapter<S> {
26 inner: S,
27 cors_config: Option<CorsConfig>,
28}
29
30impl<S> BodyAdapter<S> {
31 pub fn new(service: S) -> Self {
32 Self {
33 inner: service,
34 cors_config: None,
35 }
36 }
37
38 pub fn with_cors(mut self, cors_config: Option<CorsConfig>) -> Self {
39 self.cors_config = cors_config;
40 self
41 }
42
43 fn add_cors_to_response(&self, res: &mut hyper::Response<crate::types::BoxBody>) {
45 if let Some(cors) = &self.cors_config {
46 let headers = res.headers_mut();
47
48 if let Some(origin) = cors.allowed_origins.first() {
49 if let Ok(val) = HeaderValue::from_str(origin) {
50 headers.insert(http::header::ACCESS_CONTROL_ALLOW_ORIGIN, val);
51 }
52 }
53
54 if !cors.allowed_methods.is_empty() {
56 let methods = cors.allowed_methods.join(", ");
57 if let Ok(val) = HeaderValue::from_str(&methods) {
58 headers.insert(http::header::ACCESS_CONTROL_ALLOW_METHODS, val);
59 }
60 }
61
62 if !cors.allowed_headers.is_empty() {
64 let headers_list = cors.allowed_headers.join(", ");
65 if let Ok(val) = HeaderValue::from_str(&headers_list) {
66 headers.insert(http::header::ACCESS_CONTROL_ALLOW_HEADERS, val);
67 }
68 }
69
70 if cors.allow_credentials {
71 headers.insert(http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
72 }
73
74 if let Ok(val) = HeaderValue::from_str(&cors.max_age.to_string()) {
75 headers.insert(http::header::ACCESS_CONTROL_MAX_AGE, val);
76 }
77 }
78 }
79}
80
81
82
83use std::pin::Pin;
84
85impl<S> Service<hyper::Request<hyper::body::Incoming>> for BodyAdapter<S>
86where
87 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + 'static,
88 S::Future: Send + 'static,
89{
90 type Response = hyper::Response<crate::types::BoxBody>;
91 type Error = Error;
92 type Future = Pin<Box<dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
93
94 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
95 self.inner.poll_ready(cx)
96 }
97
98 fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
99 let accepts_html = req.headers().get(hyper::header::ACCEPT)
100 .map(|h| h.to_str().unwrap_or("").contains("text/html"))
101 .unwrap_or(false);
102
103 let req = req.map(|b| b.map_err(|e| e.into()).boxed());
104 let fut = self.inner.call(req);
105 let cors = self.cors_config.clone();
106
107 Box::pin(async move {
108 match fut.await {
109 Ok(response) => {
110 let mut hyper_response: hyper::Response<crate::types::BoxBody> = response.into();
111 let adapter = BodyAdapter { inner: (), cors_config: cors };
113 adapter.add_cors_to_response(&mut hyper_response);
114 Ok(hyper_response)
115 },
116 Err(error) => {
117 let env = std::env::var("OXIDITE_ENV").unwrap_or_else(|_| "development".to_string());
118
119 if env == "development" && accepts_html && error.is_server_error() {
120 use bytes::Bytes;
121 use http_body_util::Full;
122 use hyper::header::{CONTENT_TYPE, SERVER};
123
124 let html = crate::error::render_ignition_error(&error);
125
126 let mut res = hyper::Response::builder()
127 .status(error.status_code())
128 .header(CONTENT_TYPE, "text/html; charset=utf-8")
129 .header(SERVER, crate::response::SERVER_HEADER_VALUE)
130 .body(Full::new(Bytes::from(html)).map_err(|e| match e {}).boxed())
131 .unwrap();
132
133 let adapter = BodyAdapter { inner: (), cors_config: cors };
135 adapter.add_cors_to_response(&mut res);
136 Ok(res)
137 } else {
138 let mut error_response: hyper::Response<crate::types::BoxBody> = OxiditeResponse::from(error).into();
139 let adapter = BodyAdapter { inner: (), cors_config: cors };
141 adapter.add_cors_to_response(&mut error_response);
142 Ok(error_response)
143 }
144 }
145 }
146 })
147 }
148}
149
150
151
152pub struct Server<S> {
153 service: S,
154 addr: Option<SocketAddr>,
155 cors_config: Option<CorsConfig>,
156}
157
158impl<S> Server<S>
159where
160 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
161 S::Future: Send + 'static,
162{
163 pub fn new(service: S) -> Self {
164 Self {
165 service,
166 addr: None,
167 cors_config: None,
168 }
169 }
170
171 pub fn bind(mut self, addr: SocketAddr) -> Self {
172 self.addr = Some(addr);
173 self
174 }
175
176 pub fn with_cors(mut self, cors_config: CorsConfig) -> Self {
178 self.cors_config = Some(cors_config);
179 self
180 }
181
182 pub async fn run(self) -> Result<()> {
183 let addr = self.addr.unwrap_or_else(|| "127.0.0.1:3000".parse().unwrap());
184 self.listen(addr).await
185 }
186
187 pub async fn listen(self, addr: SocketAddr) -> Result<()> {
188 let listener = TcpListener::bind(addr).await?;
189 println!("Listening on http://{}", addr);
190
191 let cors_config = self.cors_config.clone();
192
193 loop {
194 let (stream, _) = listener.accept().await?;
195 let io = TokioIo::new(stream);
196 let service = self.service.clone();
197 let cors = cors_config.clone();
198
199 tokio::task::spawn(async move {
200 let service = BodyAdapter::new(service).with_cors(cors);
201 let hyper_service = TowerToHyperService::new(service);
202
203 if let Err(err) = http1::Builder::new()
204 .serve_connection(io, hyper_service)
205 .await
206 {
207 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
210 if service_err.is_server_error() {
212 eprintln!("Server error: {}", service_err);
213 }
214 } else {
216 let err_msg = err.to_string();
219 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
221 eprintln!("Connection error: {}", err);
222 }
223 }
224 }
225 });
226 }
227 }
228
229 #[cfg(feature = "http3")]
231 pub async fn listen_h3(self, addr: SocketAddr, cert_pem: &str, key_pem: &str) -> Result<()> {
232 use rustls::ServerConfig;
233 use rustls_pemfile::{certs, pkcs8_private_keys};
234 use std::io::Cursor;
235
236 let cors_config = self.cors_config.clone();
237
238 let http1_addr = addr;
240 let http1_service = self.service.clone();
241 let http1_cors = cors_config.clone();
242
243 tokio::spawn(async move {
244 let listener = TcpListener::bind(http1_addr).await.unwrap();
245 println!("HTTP/1.1 server listening on http://{}", http1_addr);
246
247 loop {
248 let (stream, _) = listener.accept().await.unwrap();
249 let io = TokioIo::new(stream);
250 let service = http1_service.clone();
251 let cors = http1_cors.clone();
252
253 tokio::task::spawn(async move {
254 let service = BodyAdapter::new(service).with_cors(cors);
255 let hyper_service = TowerToHyperService::new(service);
256
257 if let Err(err) = http1::Builder::new()
258 .serve_connection(io, hyper_service)
259 .await
260 {
261 if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
263 if service_err.is_server_error() {
264 eprintln!("HTTP/1.1 server error: {}", service_err);
265 }
266 } else {
267 let err_msg = err.to_string();
268 if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
269 eprintln!("HTTP/1.1 connection error: {}", err);
270 }
271 }
272 }
273 });
274 }
275 });
276
277 let cert_chain = certs(&mut Cursor::new(cert_pem))
279 .collect::<std::result::Result<Vec<_>, _>>()
280 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
281
282 let mut keys = pkcs8_private_keys(&mut Cursor::new(key_pem))
283 .collect::<std::result::Result<Vec<_>, _>>()?;
284
285 if keys.is_empty() {
286 return Err(crate::error::Error::InternalServerError("No private keys found".to_string()));
287 }
288
289 let tls_config = ServerConfig::builder()
290 .with_no_client_auth()
291 .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
292 .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
293
294 let http3_server = Http3Server::new(self.service);
295 http3_server.listen(addr, tls_config).await?;
296
297 Ok(())
298 }
299}