Skip to main content

oxidite_core/
server.rs

1use std::net::SocketAddr;
2use tokio::net::TcpListener;
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use crate::error::{Error, Result};
7use crate::types::{OxiditeRequest, OxiditeResponse};
8use crate::router::CorsConfig;
9use tower_service::Service;
10use std::error::Error as StdError;
11use http::HeaderValue;
12
13use http_body_util::BodyExt;
14
15use std::task::{Context, Poll};
16
17#[cfg(feature = "http3")]
18pub mod http3_server;
19
20#[cfg(feature = "http3")]
21pub use http3_server::Http3Server;
22
23/// Adapter to convert hyper::Request<Incoming> to OxiditeRequest
24#[derive(Clone)]
25pub struct BodyAdapter<S> {
26    inner: S,
27    cors_config: Option<CorsConfig>,
28}
29
30impl<S> BodyAdapter<S> {
31    pub fn new(service: S) -> Self {
32        Self {
33            inner: service,
34            cors_config: None,
35        }
36    }
37
38    pub fn with_cors(mut self, cors_config: Option<CorsConfig>) -> Self {
39        self.cors_config = cors_config;
40        self
41    }
42
43    /// Add CORS headers to a hyper response
44    fn add_cors_to_response(&self, res: &mut hyper::Response<crate::types::BoxBody>) {
45        if let Some(cors) = &self.cors_config {
46            let headers = res.headers_mut();
47            
48            if let Some(origin) = cors.allowed_origins.first() {
49                if let Ok(val) = HeaderValue::from_str(origin) {
50                    headers.insert(http::header::ACCESS_CONTROL_ALLOW_ORIGIN, val);
51                }
52            }
53            
54            // Add Access-Control-Allow-Methods (join all methods with ", ")
55            if !cors.allowed_methods.is_empty() {
56                let methods = cors.allowed_methods.join(", ");
57                if let Ok(val) = HeaderValue::from_str(&methods) {
58                    headers.insert(http::header::ACCESS_CONTROL_ALLOW_METHODS, val);
59                }
60            }
61            
62            // Add Access-Control-Allow-Headers (join all headers with ", ")
63            if !cors.allowed_headers.is_empty() {
64                let headers_list = cors.allowed_headers.join(", ");
65                if let Ok(val) = HeaderValue::from_str(&headers_list) {
66                    headers.insert(http::header::ACCESS_CONTROL_ALLOW_HEADERS, val);
67                }
68            }
69            
70            if cors.allow_credentials {
71                headers.insert(http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
72            }
73            
74            if let Ok(val) = HeaderValue::from_str(&cors.max_age.to_string()) {
75                headers.insert(http::header::ACCESS_CONTROL_MAX_AGE, val);
76            }
77        }
78    }
79}
80
81
82
83use std::pin::Pin;
84
85impl<S> Service<hyper::Request<hyper::body::Incoming>> for BodyAdapter<S>
86where
87    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + 'static,
88    S::Future: Send + 'static,
89{
90    type Response = hyper::Response<crate::types::BoxBody>;
91    type Error = Error;
92    type Future = Pin<Box<dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
93
94    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
95        self.inner.poll_ready(cx)
96    }
97
98    fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
99        let accepts_html = req.headers().get(hyper::header::ACCEPT)
100            .map(|h| h.to_str().unwrap_or("").contains("text/html"))
101            .unwrap_or(false);
102            
103        let req = req.map(|b| b.map_err(|e| e.into()).boxed());
104        let fut = self.inner.call(req);
105        let cors = self.cors_config.clone();
106        
107        Box::pin(async move {
108            match fut.await {
109                Ok(response) => {
110                    let mut hyper_response: hyper::Response<crate::types::BoxBody> = response.into();
111                    // Add CORS headers to successful responses
112                    let adapter = BodyAdapter { inner: (), cors_config: cors };
113                    adapter.add_cors_to_response(&mut hyper_response);
114                    Ok(hyper_response)
115                },
116                Err(error) => {
117                    let env = std::env::var("OXIDITE_ENV").unwrap_or_else(|_| "development".to_string());
118                    
119                    if env == "development" && accepts_html && error.is_server_error() {
120                        use bytes::Bytes;
121                        use http_body_util::Full;
122                        use hyper::header::{CONTENT_TYPE, SERVER};
123                        
124                        let html = crate::error::render_ignition_error(&error);
125                        
126                        let mut res = hyper::Response::builder()
127                            .status(error.status_code())
128                            .header(CONTENT_TYPE, "text/html; charset=utf-8")
129                            .header(SERVER, crate::response::SERVER_HEADER_VALUE)
130                            .body(Full::new(Bytes::from(html)).map_err(|e| match e {}).boxed())
131                            .unwrap();
132                        
133                        // Add CORS headers to error responses too
134                        let adapter = BodyAdapter { inner: (), cors_config: cors };
135                        adapter.add_cors_to_response(&mut res);
136                        Ok(res)
137                    } else {
138                        let mut error_response: hyper::Response<crate::types::BoxBody> = OxiditeResponse::from(error).into();
139                        // Add CORS headers to error responses
140                        let adapter = BodyAdapter { inner: (), cors_config: cors };
141                        adapter.add_cors_to_response(&mut error_response);
142                        Ok(error_response)
143                    }
144                }
145            }
146        })
147    }
148}
149
150
151
152pub struct Server<S> {
153    service: S,
154    addr: Option<SocketAddr>,
155    cors_config: Option<CorsConfig>,
156}
157
158impl<S> Server<S>
159where
160    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
161    S::Future: Send + 'static,
162{
163    pub fn new(service: S) -> Self {
164        Self {
165            service,
166            addr: None,
167            cors_config: None,
168        }
169    }
170
171    pub fn bind(mut self, addr: SocketAddr) -> Self {
172        self.addr = Some(addr);
173        self
174    }
175
176    /// Configure CORS for the server (applies to all responses including errors)
177    pub fn with_cors(mut self, cors_config: CorsConfig) -> Self {
178        self.cors_config = Some(cors_config);
179        self
180    }
181
182    pub async fn run(self) -> Result<()> {
183        let addr = self.addr.unwrap_or_else(|| "127.0.0.1:3000".parse().unwrap());
184        self.listen(addr).await
185    }
186
187    pub async fn listen(self, addr: SocketAddr) -> Result<()> {
188        let listener = TcpListener::bind(addr).await?;
189        println!("Listening on http://{}", addr);
190
191        let cors_config = self.cors_config.clone();
192
193        loop {
194            let (stream, _) = listener.accept().await?;
195            let io = TokioIo::new(stream);
196            let service = self.service.clone();
197            let cors = cors_config.clone();
198
199            tokio::task::spawn(async move {
200                let service = BodyAdapter::new(service).with_cors(cors);
201                let hyper_service = TowerToHyperService::new(service);
202                
203                if let Err(err) = http1::Builder::new()
204                    .serve_connection(io, hyper_service)
205                    .await
206                {
207                    // Only log actual server errors, not client errors like 404
208                    // Check if this is a user service error that we can inspect
209                    if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
210                        // Only log if it's a server error (5xx), not a client error (4xx)
211                        if service_err.is_server_error() {
212                            eprintln!("Server error: {}", service_err);
213                        }
214                        // Client errors (404, etc.) are silently handled - they're expected
215                    } else {
216                        // For non-service errors (connection issues, etc.), log them
217                        // but only if they're not common expected errors
218                        let err_msg = err.to_string();
219                        // Don't log if it's just a client disconnecting or similar
220                        if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
221                            eprintln!("Connection error: {}", err);
222                        }
223                    }
224                }
225            });
226        }
227    }
228
229    /// Listen with both HTTP/1.1 and HTTP/3 support
230    #[cfg(feature = "http3")]
231    pub async fn listen_h3(self, addr: SocketAddr, cert_pem: &str, key_pem: &str) -> Result<()> {
232        use rustls::ServerConfig;  
233        use rustls_pemfile::{certs, pkcs8_private_keys};
234        use std::io::Cursor;
235        
236        let cors_config = self.cors_config.clone();
237        
238        // Setup HTTP/1.1 server in background
239        let http1_addr = addr;
240        let http1_service = self.service.clone();
241        let http1_cors = cors_config.clone();
242        
243        tokio::spawn(async move {
244            let listener = TcpListener::bind(http1_addr).await.unwrap();
245            println!("HTTP/1.1 server listening on http://{}", http1_addr);
246            
247            loop {
248                let (stream, _) = listener.accept().await.unwrap();
249                let io = TokioIo::new(stream);
250                let service = http1_service.clone();
251                let cors = http1_cors.clone();
252                
253                tokio::task::spawn(async move {
254                    let service = BodyAdapter::new(service).with_cors(cors);
255                    let hyper_service = TowerToHyperService::new(service);
256                    
257                    if let Err(err) = http1::Builder::new()
258                        .serve_connection(io, hyper_service)
259                        .await
260                    {
261                        // Only log server errors, not client errors
262                        if let Some(service_err) = err.source().and_then(|e| e.downcast_ref::<Error>()) {
263                            if service_err.is_server_error() {
264                                eprintln!("HTTP/1.1 server error: {}", service_err);
265                            }
266                        } else {
267                            let err_msg = err.to_string();
268                            if !err_msg.contains("NotFound") && !err_msg.contains("connection closed") {
269                                eprintln!("HTTP/1.1 connection error: {}", err);
270                            }
271                        }
272                    }
273                });
274            }
275        });
276        
277        // Setup HTTP/3 server
278        let cert_chain = certs(&mut Cursor::new(cert_pem))
279            .collect::<std::result::Result<Vec<_>, _>>()
280            .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
281        
282        let mut keys = pkcs8_private_keys(&mut Cursor::new(key_pem))
283            .collect::<std::result::Result<Vec<_>, _>>()?;
284        
285        if keys.is_empty() {
286            return Err(crate::error::Error::InternalServerError("No private keys found".to_string()));
287        }
288        
289        let tls_config = ServerConfig::builder()
290            .with_no_client_auth()
291            .with_single_cert(cert_chain, rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)))
292            .map_err(|e| crate::error::Error::InternalServerError(e.to_string()))?;
293        
294        let http3_server = Http3Server::new(self.service);
295        http3_server.listen(addr, tls_config).await?;
296        
297        Ok(())
298    }
299}