prosa_hyper/
server.rs

1//! Module to handle HTTP server
2
3use std::time::{Duration, SystemTime};
4
5use prosa::core::msg::{InternalMsg, Msg};
6
7use tokio::sync::oneshot;
8
9use tracing::{Level, Span, span};
10
11/// Adaptor for Hyper server processor
12pub mod adaptor;
13/// ProSA Hyper server processor
14pub mod proc;
15/// Hyper service definition
16pub(crate) mod service;
17
18/// Hyper processor
19#[derive(Debug)]
20pub struct HyperProcMsg<M>
21where
22    M: 'static
23        + std::marker::Send
24        + std::marker::Sync
25        + std::marker::Sized
26        + std::clone::Clone
27        + std::fmt::Debug
28        + prosa_utils::msg::tvf::Tvf
29        + std::default::Default,
30{
31    id: u64,
32    span: Span,
33    service: String,
34    data: M,
35    begin_time: SystemTime,
36    response_queue: oneshot::Sender<InternalMsg<M>>,
37}
38
39impl<M> HyperProcMsg<M>
40where
41    M: 'static
42        + std::marker::Send
43        + std::marker::Sync
44        + std::marker::Sized
45        + std::clone::Clone
46        + std::fmt::Debug
47        + prosa_utils::msg::tvf::Tvf
48        + std::default::Default,
49{
50    /// Create a new Hyper processor message
51    pub fn new(
52        service: String,
53        data: M,
54        response_queue: oneshot::Sender<InternalMsg<M>>,
55    ) -> HyperProcMsg<M> {
56        let span = span!(Level::INFO, "HyperProcMsg", service = service);
57        HyperProcMsg {
58            id: 0,
59            service,
60            span,
61            data,
62            begin_time: SystemTime::now(),
63            response_queue,
64        }
65    }
66
67    /// Set the ID of the Hyper processor
68    pub fn set_id(&mut self, id: u64) {
69        self.id = id;
70    }
71}
72
73impl<M> Msg<M> for HyperProcMsg<M>
74where
75    M: 'static
76        + std::marker::Send
77        + std::marker::Sync
78        + std::marker::Sized
79        + std::clone::Clone
80        + std::fmt::Debug
81        + prosa_utils::msg::tvf::Tvf
82        + std::default::Default,
83{
84    fn get_id(&self) -> u64 {
85        self.id
86    }
87
88    fn get_service(&self) -> &String {
89        &self.service
90    }
91
92    fn get_span(&self) -> &Span {
93        &self.span
94    }
95
96    fn get_span_mut(&mut self) -> &mut Span {
97        &mut self.span
98    }
99
100    fn enter_span(&self) -> span::Entered<'_> {
101        self.span.enter()
102    }
103
104    fn get_data(&self) -> &M {
105        &self.data
106    }
107
108    fn get_data_mut(&mut self) -> &mut M {
109        &mut self.data
110    }
111
112    fn elapsed(&self) -> Duration {
113        self.begin_time.elapsed().unwrap_or(Duration::new(0, 0))
114    }
115}
116
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::{Full, combinators::BoxBody};
    use hyper::{Request, Response, StatusCode};
    use openssl::{
        asn1::{Asn1Integer, Asn1Time},
        bn::{BigNum, MsbOption},
        ec::{Asn1Flag, EcGroup, EcKey},
        hash::MessageDigest,
        nid::Nid,
        pkey::PKey,
        symm::Cipher,
        x509::{X509, X509NameBuilder, extension::SubjectAlternativeName},
    };
    use prosa::{
        core::{
            adaptor::Adaptor,
            error::ProcError,
            main::{MainProc, MainRunnable as _},
            proc::{Proc, ProcConfig as _},
            settings::settings,
        },
        io::listener::ListenerSetting,
    };
    use prosa_utils::{
        config::{ConfigError, os_country, ssl::SslConfig},
        msg::simple_string_tvf::SimpleStringTvf,
    };
    use reqwest::Certificate;
    use serde::Serialize;
    use std::{env, fs, time::Duration};
    use tokio::time;
    use url::Url;

    use crate::{
        HyperResp,
        server::{
            adaptor::HyperServerAdaptor,
            proc::{HyperServerProc, HyperServerSettings},
        },
    };

    /// Timeout applied to every client request
    const WAIT_TIME: time::Duration = time::Duration::from_secs(5);

    /// Number of requests sent to the server by each test
    const REQUEST_COUNT: usize = 20;

    /// HTTP settings
    #[settings]
    #[derive(Default, Debug, Serialize)]
    struct HttpTestSettings {
        server: HyperServerSettings,
    }

    impl HttpTestSettings {
        /// Build test settings with a server listening on `url`, with optional TLS.
        fn new(url: Url, server_ssl: Option<SslConfig>) -> Self {
            HttpTestSettings {
                server: HyperServerSettings::new(
                    ListenerSetting::new(url, server_ssl),
                    Duration::from_secs(1),
                ),
                ..Default::default()
            }
        }
    }

    /// Adaptor answering every HTTP request directly, without calling any service
    #[derive(Adaptor, Clone)]
    struct ServerTestAdaptor {
        // Nothing
    }

    impl<M> HyperServerAdaptor<M> for ServerTestAdaptor
    where
        M: 'static
            + std::marker::Send
            + std::marker::Sync
            + std::marker::Sized
            + std::clone::Clone
            + std::fmt::Debug
            + prosa_utils::msg::tvf::Tvf
            + std::default::Default,
    {
        fn new(
            _proc: &HyperServerProc<M>,
            _prosa_name: &str,
        ) -> Result<Self, Box<dyn ProcError + Send + Sync>>
        where
            Self: Sized,
        {
            Ok(ServerTestAdaptor {})
        }

        async fn process_http_request(&self, req: Request<hyper::body::Incoming>) -> HyperResp<M> {
            // Answer differently over HTTP/2 so tests can check which protocol was used
            let resp_msg = if req.version() == hyper::Version::HTTP_2 {
                "Hello, H2 world"
            } else {
                "Hello, world"
            };
            let response = Response::builder()
                .header(
                    hyper::header::SERVER,
                    <ServerTestAdaptor as HyperServerAdaptor<M>>::SERVER_HEADER,
                )
                .status(StatusCode::OK)
                .body(BoxBody::new(Full::new(Bytes::from(resp_msg))))
                .unwrap();

            HyperResp::HttpResp(response)
        }

        fn process_srv_response(
            &self,
            _resp: &M,
        ) -> hyper::Response<
            http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>,
        > {
            // `process_http_request` always answers directly, so no service response
            // ever reaches this adaptor.
            unimplemented!()
        }
    }

    /// Start a ProSA with a Hyper server processor and send `REQUEST_COUNT` GET requests to it.
    ///
    /// * `certificate` - extra root certificate trusted by the client (self-signed server cert)
    /// * `http2` - make the client use HTTP/2 with prior knowledge
    ///
    /// Every response must be `200 OK`, carry a `ProSA-Hyper/` server header and contain one
    /// of the adaptor greetings. With `http2`, it must be the HTTP/2 greeting.
    async fn run_test(settings: HttpTestSettings, certificate: Option<Certificate>, http2: bool) {
        let url = settings.server.listener.url.clone();

        // Create bus and main processor
        let (bus, main) = MainProc::<SimpleStringTvf>::create(&settings);

        // Launch the main task
        let main_task = main.run();

        // Launch an HTTP server processor
        let http_server_proc =
            HyperServerProc::<SimpleStringTvf>::create(1, bus.clone(), settings.server);
        Proc::<ServerTestAdaptor>::run(http_server_proc, String::from("HTTP_SERVER_PROC"));

        // Wait for processor to start. Use the async sleep: a blocking
        // `std::thread::sleep` would stall the test runtime thread meanwhile.
        time::sleep(Duration::from_secs(1)).await;

        // Send requests to the server with reqwest
        let mut client_builder = reqwest::ClientBuilder::new()
            .timeout(WAIT_TIME)
            .use_rustls_tls();
        if let Some(cert) = certificate {
            client_builder = client_builder.add_root_certificate(cert);
        }
        if http2 {
            client_builder = client_builder.http2_prior_knowledge();
        }
        let client = client_builder.build().expect("Failed to build HTTP client");

        for _ in 0..REQUEST_COUNT {
            let resp = client
                .get(url.clone())
                .send()
                .await
                .expect("Failed to send request");
            assert_eq!(resp.status(), StatusCode::OK);
            let server_header = resp
                .headers()
                .get(hyper::header::SERVER)
                .expect("Missing Server header");
            assert!(server_header.to_str().unwrap().starts_with("ProSA-Hyper/"));

            let body = resp.text().await.expect("Failed to read response body");
            if http2 {
                // Prior knowledge forces HTTP/2, answered with the dedicated H2 greeting
                assert_eq!(body, "Hello, H2 world");
            } else {
                assert!(
                    matches!(body.as_str(), "Hello, world" | "Hello, H2 world"),
                    "Unexpected response body: {body}"
                );
            }
        }

        bus.stop("ProSA HTTP client server unit test end".into())
            .await
            .unwrap();

        // Wait on main task to end
        main_task.await;
    }

    #[tokio::test]
    async fn http_client_server() {
        let test_settings =
            HttpTestSettings::new(Url::parse("http://localhost:48080").unwrap(), None);

        // Run a ProSA to test
        run_test(test_settings, None, false).await;
    }

    /// Create a self-signed ECDSA (P-256) certificate valid for `localhost`.
    ///
    /// The private key is written to `key_path` as a PKCS#8 PEM encrypted with AES-256-CBC.
    /// The certificate is written to `cert_path` as PEM.
    ///
    /// # Errors
    /// Returns a [`ConfigError`] if key/certificate generation fails or a file can't be written.
    fn create_server_cert(key_path: String, cert_path: String) -> Result<SslConfig, ConfigError> {
        const PASSPHRASE: &str = "prosa_test";

        let mut group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
        group.set_asn1_flag(Asn1Flag::NAMED_CURVE);
        let pkey = PKey::from_ec_key(EcKey::generate(&group)?)?;
        // Encode before touching the filesystem so a failure doesn't leave an empty key file
        let key_pem = pkey
            .private_key_to_pem_pkcs8_passphrase(Cipher::aes_256_cbc(), PASSPHRASE.as_bytes())?;
        fs::write(&key_path, key_pem).map_err(|e| ConfigError::IoFile(key_path.clone(), e))?;

        let mut cert = X509::builder()?;
        // Version field value 2 means an X.509 v3 certificate
        cert.set_version(2)?;
        cert.set_pubkey(&pkey)?;

        let mut serial_bn = BigNum::new()?;
        serial_bn.pseudo_rand(64, MsbOption::MAYBE_ZERO, true)?;
        let serial_number = Asn1Integer::from_bn(&serial_bn)?;
        cert.set_serial_number(&serial_number)?;

        // Validity starts 360 seconds in the past (presumably to tolerate clock differences)
        let begin_valid_time =
            Asn1Time::from_unix(std::time::UNIX_EPOCH.elapsed().unwrap().as_secs() as i64 - 360)?;
        cert.set_not_before(&begin_valid_time)?;
        let end_valid_time = Asn1Time::days_from_now(3)?; // 3 days from now
        cert.set_not_after(&end_valid_time)?;

        // Self-signed: subject and issuer are the same name
        let mut x509_name = X509NameBuilder::new()?;
        if let Some(cn) = os_country() {
            x509_name.append_entry_by_text("C", cn.as_str())?;
        }
        x509_name.append_entry_by_text("CN", "ProSA-hyper")?;
        let x509_name = x509_name.build();
        cert.set_subject_name(&x509_name)?;
        cert.set_issuer_name(&x509_name)?;

        let mut subject_alternative_name = SubjectAlternativeName::new();
        let x509_extension = subject_alternative_name
            .dns("localhost")
            .build(&cert.x509v3_context(None, None))?;
        cert.append_extension2(&x509_extension)?;

        cert.sign(&pkey, MessageDigest::sha256())?;

        let cert_pem = cert.build().to_pem()?;
        fs::write(&cert_path, cert_pem).map_err(|e| ConfigError::IoFile(cert_path.clone(), e))?;

        Ok(SslConfig::new_cert_key(
            cert_path,
            key_path,
            Some(PASSPHRASE.into()),
        ))
    }

    /// Generate a server key/certificate pair in a fresh `dir_name` temporary directory.
    ///
    /// Files are named `<file_stem>.key` and `<file_stem>.pem`. Returns the server SSL
    /// configuration and the certificate loaded as a client root certificate.
    fn prepare_server_tls(dir_name: &str, file_stem: &str) -> (SslConfig, Certificate) {
        let prosa_temp_dir = env::temp_dir().join(dir_name);

        // Start from a clean directory so leftovers of a previous run can't interfere
        let _ = fs::remove_dir_all(&prosa_temp_dir);
        fs::create_dir_all(&prosa_temp_dir).unwrap();

        let key_path = prosa_temp_dir.join(format!("{file_stem}.key"));
        let cert_path = prosa_temp_dir.join(format!("{file_stem}.pem"));
        let server_ssl_config = create_server_cert(
            key_path.to_str().unwrap().into(),
            cert_path.to_str().unwrap().into(),
        )
        .unwrap();

        let cert_pem = fs::read(&cert_path).unwrap();
        let client_cert = Certificate::from_pem(&cert_pem).unwrap();

        (server_ssl_config, client_cert)
    }

    #[tokio::test]
    async fn https_client_server() {
        const PROSA_HTTPS_TEST_DIR_NAME: &str = "ProSA_HTTPS";
        let (server_ssl_config, client_cert) =
            prepare_server_tls(PROSA_HTTPS_TEST_DIR_NAME, "prosa_https");

        let test_settings = HttpTestSettings::new(
            Url::parse("https://localhost:48443").unwrap(),
            Some(server_ssl_config),
        );

        // Run a ProSA to test
        run_test(test_settings, Some(client_cert), false).await;
    }

    #[tokio::test]
    async fn h2_client_server() {
        const PROSA_H2_TEST_DIR_NAME: &str = "ProSA_H2";
        let (mut server_ssl_config, client_cert) =
            prepare_server_tls(PROSA_H2_TEST_DIR_NAME, "prosa_h2");
        // The server ALPN must be set explicitly because the configuration is built inline
        // (see TargetSetting::new)
        server_ssl_config.set_alpn(vec!["h2".into()]);

        let test_settings = HttpTestSettings::new(
            Url::parse("https://localhost:49443").unwrap(),
            Some(server_ssl_config),
        );

        // Run a ProSA to test
        run_test(test_settings, Some(client_cert), true).await;
    }
}