prosa_hyper/
server.rs

1//! Module to handle HTTP server
2
3use std::time::{Duration, SystemTime};
4
5use prosa::core::{
6    error::BusError,
7    msg::{InternalMsg, Msg},
8};
9
10use tokio::sync::oneshot;
11
12use tracing::{Level, Span, span};
13
14/// Adaptor for Hyper server processor
15pub mod adaptor;
16/// ProSA Hyper server processor
17pub mod proc;
18/// Hyper service definition
19pub(crate) mod service;
20
21/// Hyper processor
22#[derive(Debug)]
23pub struct HyperProcMsg<M>
24where
25    M: 'static
26        + std::marker::Send
27        + std::marker::Sync
28        + std::marker::Sized
29        + std::clone::Clone
30        + std::fmt::Debug
31        + prosa_utils::msg::tvf::Tvf
32        + std::default::Default,
33{
34    id: u64,
35    span: Span,
36    service: String,
37    data: Option<M>,
38    begin_time: SystemTime,
39    response_queue: Option<oneshot::Sender<InternalMsg<M>>>,
40}
41
42impl<M> HyperProcMsg<M>
43where
44    M: 'static
45        + std::marker::Send
46        + std::marker::Sync
47        + std::marker::Sized
48        + std::clone::Clone
49        + std::fmt::Debug
50        + prosa_utils::msg::tvf::Tvf
51        + std::default::Default,
52{
53    /// Create a new Hyper processor message
54    pub fn new(
55        service: String,
56        data: M,
57        response_queue: oneshot::Sender<InternalMsg<M>>,
58    ) -> HyperProcMsg<M> {
59        let span = span!(Level::INFO, "HyperProcMsg", service = service);
60        HyperProcMsg {
61            id: 0,
62            service,
63            span,
64            data: Some(data),
65            begin_time: SystemTime::now(),
66            response_queue: Some(response_queue),
67        }
68    }
69
70    /// Set the ID of the Hyper processor
71    pub fn set_id(&mut self, id: u64) {
72        self.id = id;
73    }
74
75    /// Get the response queue to respond to the message
76    pub fn get_response_queue(&mut self) -> Result<oneshot::Sender<InternalMsg<M>>, BusError> {
77        self.response_queue
78            .take()
79            .ok_or(BusError::InternalQueue("HyperProcMsg".to_string()))
80    }
81}
82
83impl<M> Msg<M> for HyperProcMsg<M>
84where
85    M: 'static
86        + std::marker::Send
87        + std::marker::Sync
88        + std::marker::Sized
89        + std::clone::Clone
90        + std::fmt::Debug
91        + prosa_utils::msg::tvf::Tvf
92        + std::default::Default,
93{
94    fn get_id(&self) -> u64 {
95        self.id
96    }
97
98    fn get_service(&self) -> &String {
99        &self.service
100    }
101
102    fn get_span(&self) -> &Span {
103        &self.span
104    }
105
106    fn get_span_mut(&mut self) -> &mut Span {
107        &mut self.span
108    }
109
110    fn enter_span(&self) -> span::Entered<'_> {
111        self.span.enter()
112    }
113
114    fn get_data(&self) -> Result<&M, BusError> {
115        self.data.as_ref().ok_or(BusError::NoData)
116    }
117
118    fn get_data_mut(&mut self) -> Result<&mut M, BusError> {
119        self.data.as_mut().ok_or(BusError::NoData)
120    }
121
122    fn elapsed(&self) -> Duration {
123        self.begin_time.elapsed().unwrap_or(Duration::new(0, 0))
124    }
125
126    fn take_data(&mut self) -> Option<M> {
127        self.data.take()
128    }
129
130    fn take_data_if<P>(&mut self, predicate: P) -> Option<M>
131    where
132        P: FnOnce(&mut M) -> bool,
133    {
134        self.data.take_if(predicate)
135    }
136}
137
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::{Full, combinators::BoxBody};
    use hyper::{Request, Response, StatusCode};
    use openssl::{
        asn1::{Asn1Integer, Asn1Time},
        bn::{BigNum, MsbOption},
        ec::{Asn1Flag, EcGroup, EcKey},
        hash::MessageDigest,
        nid::Nid,
        pkey::PKey,
        symm::Cipher,
        x509::{X509, X509NameBuilder, extension::SubjectAlternativeName},
    };
    use prosa::{
        core::{
            adaptor::Adaptor,
            error::ProcError,
            main::{MainProc, MainRunnable as _},
            proc::{Proc, ProcConfig as _},
            settings::settings,
        },
        io::listener::ListenerSetting,
    };
    use prosa_utils::{
        config::{ConfigError, os_country, ssl::SslConfig},
        msg::simple_string_tvf::SimpleStringTvf,
    };
    use reqwest::Certificate;
    use serde::Serialize;
    use std::{env, fs, time::Duration};
    use tokio::time;
    use url::Url;

    use crate::{
        HyperResp,
        server::{
            adaptor::HyperServerAdaptor,
            proc::{HyperServerProc, HyperServerSettings},
        },
    };

    /// Timeout applied to every client request
    const WAIT_TIME: time::Duration = time::Duration::from_secs(5);

    /// HTTP settings
    #[settings]
    #[derive(Default, Debug, Serialize)]
    struct HttpTestSettings {
        server: HyperServerSettings,
    }

    impl HttpTestSettings {
        /// Build test settings with a server listening on `url`, optionally over TLS
        fn new(url: Url, server_ssl: Option<SslConfig>) -> Self {
            let server = HyperServerSettings::new(
                ListenerSetting::new(url, server_ssl),
                Duration::from_secs(1),
            );
            HttpTestSettings {
                server,
                ..Default::default()
            }
        }
    }

    /// Test adaptor answering every HTTP request with a static greeting
    #[derive(Adaptor, Clone)]
    struct ServerTestAdaptor {
        // Nothing
    }

    impl<M> HyperServerAdaptor<M> for ServerTestAdaptor
    where
        M: 'static
            + std::marker::Send
            + std::marker::Sync
            + std::marker::Sized
            + std::clone::Clone
            + std::fmt::Debug
            + prosa_utils::msg::tvf::Tvf
            + std::default::Default,
    {
        fn new(
            _proc: &HyperServerProc<M>,
            _prosa_name: &str,
        ) -> Result<Self, Box<dyn ProcError + Send + Sync>>
        where
            Self: Sized,
        {
            Ok(ServerTestAdaptor {})
        }

        /// Reply `200 OK` with a body that tells whether the request came over HTTP/2
        async fn process_http_request(&self, req: Request<hyper::body::Incoming>) -> HyperResp<M> {
            let resp_msg = if req.version() == hyper::Version::HTTP_2 {
                "Hello, H2 world"
            } else {
                "Hello, world"
            };
            let response = Response::builder()
                .header(
                    hyper::header::SERVER,
                    <ServerTestAdaptor as HyperServerAdaptor<M>>::SERVER_HEADER,
                )
                .status(StatusCode::OK)
                .body(BoxBody::new(Full::new(Bytes::from(resp_msg))))
                .unwrap();

            HyperResp::HttpResp(response)
        }

        fn process_srv_response(
            &self,
            _resp: M,
        ) -> hyper::Response<
            http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>,
        > {
            unimplemented!()
        }
    }

    /// Start a ProSA with a Hyper server processor, send it 20 requests and stop it
    ///
    /// - `certificate`: extra root certificate the client must trust (self-signed server)
    /// - `http2`: force the client to speak HTTP/2 with prior knowledge
    async fn run_test(settings: HttpTestSettings, certificate: Option<Certificate>, http2: bool) {
        let url = settings.server.listener.url.clone();

        // Create bus and main processor
        let (bus, main) = MainProc::<SimpleStringTvf>::create(&settings);

        // Launch the main task
        let main_task = main.run();

        // Launch an HTTP server processor
        let http_server_proc =
            HyperServerProc::<SimpleStringTvf>::create(1, bus.clone(), settings.server);
        Proc::<ServerTestAdaptor>::run(http_server_proc, String::from("HTTP_SERVER_PROC"));

        // Wait for processor to start.
        // Async sleep: a blocking `std::thread::sleep` would stall the test runtime thread
        // and every task scheduled on it for the whole duration.
        time::sleep(Duration::from_secs(1)).await;

        // Send request to the server with reqwest
        let mut client_builder = reqwest::ClientBuilder::new()
            .timeout(WAIT_TIME)
            .use_rustls_tls();
        if let Some(cert) = certificate {
            client_builder = client_builder.add_root_certificate(cert);
        }
        if http2 {
            client_builder = client_builder.http2_prior_knowledge();
        }
        let client = client_builder.build().unwrap();
        for _ in 0..20 {
            let resp = client
                .get(url.clone())
                .send()
                .await
                .expect("Failed to send request");
            assert_eq!(resp.status(), StatusCode::OK);
            if http2 {
                // Make sure the H2 test really exercised HTTP/2
                assert_eq!(resp.version(), hyper::Version::HTTP_2);
            }
            let server_header = resp
                .headers()
                .get(hyper::header::SERVER)
                .expect("Missing Server header");
            assert!(server_header.to_str().unwrap().starts_with("ProSA-Hyper/"));
        }

        bus.stop("ProSA HTTP client server unit test end".into())
            .await
            .unwrap();

        // Wait on main task to end
        main_task.await;
    }

    #[tokio::test]
    async fn http_client_server() {
        let test_settings =
            HttpTestSettings::new(Url::parse("http://localhost:48080").unwrap(), None);

        // Run a ProSA to test
        run_test(test_settings, None, false).await;
    }

    /// Create a private key and a self-signed certificate for a test server
    ///
    /// The P-256 EC key is written to `key_path` as PKCS#8 PEM encrypted with a fixed
    /// passphrase. The certificate for `localhost` is written to `cert_path` as PEM and
    /// is valid from 6 minutes ago to 3 days from now.
    ///
    /// Returns the server SSL configuration referencing both files.
    fn create_server_cert(key_path: String, cert_path: String) -> Result<SslConfig, ConfigError> {
        const PASSPHRASE: &str = "prosa_test";

        let mut group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
        group.set_asn1_flag(Asn1Flag::NAMED_CURVE);
        let pkey = PKey::from_ec_key(EcKey::generate(&group)?)?;
        let key_pem = pkey
            .private_key_to_pem_pkcs8_passphrase(Cipher::aes_256_cbc(), PASSPHRASE.as_bytes())?;
        fs::write(&key_path, key_pem).map_err(|e| ConfigError::IoFile(key_path.clone(), e))?;

        let mut cert = X509::builder()?;
        // X509 version field is zero-based: 2 means a v3 certificate (needed for extensions)
        cert.set_version(2)?;
        cert.set_pubkey(&pkey)?;

        let mut serial_bn = BigNum::new()?;
        serial_bn.pseudo_rand(64, MsbOption::MAYBE_ZERO, true)?;
        let serial_number = Asn1Integer::from_bn(&serial_bn)?;
        cert.set_serial_number(&serial_number)?;

        // Backdate the validity start by 6 minutes to tolerate small clock skews
        let begin_valid_time =
            Asn1Time::from_unix(std::time::UNIX_EPOCH.elapsed().unwrap().as_secs() as i64 - 360)?;
        cert.set_not_before(&begin_valid_time)?;
        let end_valid_time = Asn1Time::days_from_now(3)?; // 3 days from now
        cert.set_not_after(&end_valid_time)?;

        let mut x509_name = X509NameBuilder::new()?;
        if let Some(cn) = os_country() {
            x509_name.append_entry_by_text("C", cn.as_str())?;
        }
        x509_name.append_entry_by_text("CN", "ProSA-hyper")?;
        let x509_name = x509_name.build();
        // Self-signed: subject and issuer are the same
        cert.set_subject_name(&x509_name)?;
        cert.set_issuer_name(&x509_name)?;

        let mut subject_alternative_name = SubjectAlternativeName::new();
        let x509_extension = subject_alternative_name
            .dns("localhost")
            .build(&cert.x509v3_context(None, None))?;
        cert.append_extension2(&x509_extension)?;

        cert.sign(&pkey, MessageDigest::sha256())?;

        let cert_pem = cert.build().to_pem()?;
        fs::write(&cert_path, cert_pem).map_err(|e| ConfigError::IoFile(cert_path.clone(), e))?;

        Ok(SslConfig::new_cert_key(
            cert_path,
            key_path,
            Some(PASSPHRASE.into()),
        ))
    }

    /// Prepare TLS material for a test in a fresh `<temp_dir>/<dir_name>` directory
    ///
    /// Generates `<file_stem>.key` and `<file_stem>.pem` there. Returns the server SSL
    /// configuration and the certificate the client has to trust.
    fn prepare_server_cert(dir_name: &str, file_stem: &str) -> (SslConfig, Certificate) {
        let prosa_temp_dir = env::temp_dir().join(dir_name);

        // Start from an empty directory so files from a previous run are never reused
        let _ = fs::remove_dir_all(&prosa_temp_dir);
        fs::create_dir_all(&prosa_temp_dir).unwrap();

        let key_path = prosa_temp_dir.join(format!("{file_stem}.key"));
        let cert_path = prosa_temp_dir.join(format!("{file_stem}.pem"));
        let server_ssl_config = create_server_cert(
            key_path.to_str().unwrap().into(),
            cert_path.to_str().unwrap().into(),
        )
        .unwrap();

        let client_cert = Certificate::from_pem(&fs::read(&cert_path).unwrap()).unwrap();
        (server_ssl_config, client_cert)
    }

    #[tokio::test]
    async fn https_client_server() {
        let (server_ssl_config, client_cert) = prepare_server_cert("ProSA_HTTPS", "prosa_https");

        let test_settings = HttpTestSettings::new(
            Url::parse("https://localhost:48443").unwrap(),
            Some(server_ssl_config),
        );

        // Run a ProSA to test
        run_test(test_settings, Some(client_cert), false).await;
    }

    #[tokio::test]
    async fn h2_client_server() {
        let (mut server_ssl_config, client_cert) = prepare_server_cert("ProSA_H2", "prosa_h2");
        // Need to set the ALPN for server because of inline configuration @see TargetSetting::new
        server_ssl_config.set_alpn(vec!["h2".into()]);

        let test_settings = HttpTestSettings::new(
            Url::parse("https://localhost:49443").unwrap(),
            Some(server_ssl_config),
        );

        // Run a ProSA to test
        run_test(test_settings, Some(client_cert), true).await;
    }
}