1use std::net::SocketAddr;
4use tokio::net::TcpListener;
5use hyper::server::conn::{http1, http2};
6use hyper_util::rt::TokioIo;
7use hyper_util::service::TowerToHyperService;
8use tokio_rustls::rustls::ServerConfig;
9use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
10use tokio_rustls::TlsAcceptor;
11use std::sync::Arc;
12use std::fs::File;
13use std::io::BufReader;
14use crate::error::{Error, Result};
15use crate::types::{OxiditeRequest, OxiditeResponse};
16use tower_service::Service;
17
18
19use crate::server::BodyAdapter;
20
/// Filesystem locations of the PEM-encoded certificate chain and private key
/// used to serve HTTPS.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the PEM file containing the certificate chain.
    pub cert_path: String,
    /// Path to the PEM file containing the private key.
    pub key_path: String,
}
26
27impl TlsConfig {
28 pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
29 Self {
30 cert_path: cert_path.into(),
31 key_path: key_path.into(),
32 }
33 }
34
35 pub fn load_config(&self) -> Result<ServerConfig> {
37 let certs = load_certs(&self.cert_path)?;
38 let key = load_private_key(&self.key_path)?;
39
40 Ok(ServerConfig::builder()
41 .with_no_client_auth()
42 .with_single_cert(certs, key)
43 .map_err(|e| Error::InternalServerError(e.to_string()))?)
44 }
45}
46
47fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
48 let file = File::open(path).map_err(|e| Error::InternalServerError(format!("Failed to open cert file: {}", e)))?;
49 let mut reader = BufReader::new(file);
50 rustls_pemfile::certs(&mut reader)
51 .map(|res| res.map_err(|e| Error::InternalServerError(format!("Failed to parse cert: {}", e))))
52 .collect::<Result<Vec<_>>>()
53}
54
55fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
56 let file = File::open(path).map_err(|e| Error::InternalServerError(format!("Failed to open key file: {}", e)))?;
57 let mut reader = BufReader::new(file);
58
59 loop {
61 match rustls_pemfile::read_one(&mut reader).map_err(|e| Error::InternalServerError(format!("Failed to parse key: {}", e)))? {
62 Some(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(key.into()),
63 Some(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(key.into()),
64 Some(rustls_pemfile::Item::Sec1Key(key)) => return Ok(key.into()),
65 None => break,
66 _ => {} }
68 }
69
70 Err(Error::InternalServerError("No supported private key found".to_string()))
71}
72
/// HTTP protocol version(s) a [`SecureServer`] will speak on accepted connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpVersion {
    /// Serve every connection as HTTP/1.
    Http1,
    /// Serve every connection as HTTP/2.
    Http2,
    /// Let the server choose the protocol for each connection.
    Auto,
}
80
81pub struct SecureServer<S> {
83 service: S,
84 tls_config: Option<TlsConfig>,
85 http_version: HttpVersion,
86}
87
88impl<S> SecureServer<S>
89where
90 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = Error> + Clone + Send + Sync + 'static,
91 S::Future: Send + 'static,
92{
93 pub fn new(service: S) -> Self {
94 Self {
95 service,
96 tls_config: None,
97 http_version: HttpVersion::Auto,
98 }
99 }
100
101 pub fn with_tls(mut self, tls_config: TlsConfig) -> Self {
103 self.tls_config = Some(tls_config);
104 self
105 }
106
107 pub fn with_http_version(mut self, version: HttpVersion) -> Self {
109 self.http_version = version;
110 self
111 }
112
113 pub async fn listen(self, addr: SocketAddr) -> Result<()> {
115 if let Some(tls_config) = self.tls_config {
116 Self::listen_https(addr, self.service, tls_config, self.http_version).await
117 } else {
118 Self::listen_http(addr, self.service).await
119 }
120 }
121
122 async fn listen_http(addr: SocketAddr, service: S) -> Result<()> {
124 let listener = TcpListener::bind(addr).await?;
125 println!("Listening on http://{}", addr);
126
127 loop {
128 let (stream, _) = listener.accept().await?;
129 let io = TokioIo::new(stream);
130 let service = service.clone();
131
132 tokio::task::spawn(async move {
133 let service = BodyAdapter::new(service);
134 let hyper_service = TowerToHyperService::new(service);
135
136 if let Err(err) = http1::Builder::new()
137 .serve_connection(io, hyper_service)
138 .await
139 {
140 eprintln!("Error serving connection: {:?}", err);
141 }
142 });
143 }
144 }
145
146 async fn listen_https(addr: SocketAddr, service: S, tls_config: TlsConfig, http_version: HttpVersion) -> Result<()> {
148 let server_config = tls_config.load_config()?;
149 let acceptor = TlsAcceptor::from(Arc::new(server_config));
150
151 let listener = TcpListener::bind(addr).await?;
152 println!("Listening on https://{}", addr);
153
154 loop {
155 let (stream, _) = listener.accept().await?;
156 let acceptor = acceptor.clone();
157 let service = service.clone();
158
159 tokio::task::spawn(async move {
160 match acceptor.accept(stream).await {
161 Ok(tls_stream) => {
162 let io = TokioIo::new(tls_stream);
163 let service = BodyAdapter::new(service);
164 let hyper_service = TowerToHyperService::new(service);
165
166 let result = match http_version {
167 HttpVersion::Http1 => {
168 http1::Builder::new()
169 .serve_connection(io, hyper_service)
170 .await
171 }
172 HttpVersion::Http2 => {
173 http2::Builder::new(TokioExecutor)
174 .serve_connection(io, hyper_service)
175 .await
176 }
177 HttpVersion::Auto => {
178 http1::Builder::new()
180 .serve_connection(io, hyper_service)
181 .await
182 }
183 };
184
185 if let Err(err) = result {
186 eprintln!("Error serving TLS connection: {:?}", err);
187 }
188 }
189 Err(err) => {
190 eprintln!("TLS accept error: {:?}", err);
191 }
192 }
193 });
194 }
195 }
196}
197
198#[derive(Clone)]
200struct TokioExecutor;
201
202impl<F> hyper::rt::Executor<F> for TokioExecutor
203where
204 F: std::future::Future + Send + 'static,
205 F::Output: Send + 'static,
206{
207 fn execute(&self, fut: F) {
208 tokio::task::spawn(fut);
209 }
210}