1use std::time::{Duration, SystemTime};
4
5use prosa::core::{
6 error::BusError,
7 msg::{InternalMsg, Msg},
8};
9
10use tokio::sync::oneshot;
11
12use tracing::{Level, Span, span};
13
14pub mod adaptor;
16pub mod proc;
18pub(crate) mod service;
20
21#[derive(Debug)]
23pub struct HyperProcMsg<M>
24where
25 M: 'static
26 + std::marker::Send
27 + std::marker::Sync
28 + std::marker::Sized
29 + std::clone::Clone
30 + std::fmt::Debug
31 + prosa_utils::msg::tvf::Tvf
32 + std::default::Default,
33{
34 id: u64,
35 span: Span,
36 service: String,
37 data: Option<M>,
38 begin_time: SystemTime,
39 response_queue: Option<oneshot::Sender<InternalMsg<M>>>,
40}
41
42impl<M> HyperProcMsg<M>
43where
44 M: 'static
45 + std::marker::Send
46 + std::marker::Sync
47 + std::marker::Sized
48 + std::clone::Clone
49 + std::fmt::Debug
50 + prosa_utils::msg::tvf::Tvf
51 + std::default::Default,
52{
53 pub fn new(
55 service: String,
56 data: M,
57 response_queue: oneshot::Sender<InternalMsg<M>>,
58 ) -> HyperProcMsg<M> {
59 let span = span!(Level::INFO, "HyperProcMsg", service = service);
60 HyperProcMsg {
61 id: 0,
62 service,
63 span,
64 data: Some(data),
65 begin_time: SystemTime::now(),
66 response_queue: Some(response_queue),
67 }
68 }
69
70 pub fn set_id(&mut self, id: u64) {
72 self.id = id;
73 }
74
75 pub fn get_response_queue(&mut self) -> Result<oneshot::Sender<InternalMsg<M>>, BusError> {
77 self.response_queue
78 .take()
79 .ok_or(BusError::InternalQueue("HyperProcMsg".to_string()))
80 }
81}
82
83impl<M> Msg<M> for HyperProcMsg<M>
84where
85 M: 'static
86 + std::marker::Send
87 + std::marker::Sync
88 + std::marker::Sized
89 + std::clone::Clone
90 + std::fmt::Debug
91 + prosa_utils::msg::tvf::Tvf
92 + std::default::Default,
93{
94 fn get_id(&self) -> u64 {
95 self.id
96 }
97
98 fn get_service(&self) -> &String {
99 &self.service
100 }
101
102 fn get_span(&self) -> &Span {
103 &self.span
104 }
105
106 fn get_span_mut(&mut self) -> &mut Span {
107 &mut self.span
108 }
109
110 fn enter_span(&self) -> span::Entered<'_> {
111 self.span.enter()
112 }
113
114 fn get_data(&self) -> Result<&M, BusError> {
115 self.data.as_ref().ok_or(BusError::NoData)
116 }
117
118 fn get_data_mut(&mut self) -> Result<&mut M, BusError> {
119 self.data.as_mut().ok_or(BusError::NoData)
120 }
121
122 fn elapsed(&self) -> Duration {
123 self.begin_time.elapsed().unwrap_or(Duration::new(0, 0))
124 }
125
126 fn take_data(&mut self) -> Option<M> {
127 self.data.take()
128 }
129
130 fn take_data_if<P>(&mut self, predicate: P) -> Option<M>
131 where
132 P: FnOnce(&mut M) -> bool,
133 {
134 self.data.take_if(predicate)
135 }
136}
137
138#[cfg(test)]
139mod tests {
140 use bytes::Bytes;
141 use http_body_util::{Full, combinators::BoxBody};
142 use hyper::{Request, Response, StatusCode};
143 use openssl::{
144 asn1::{Asn1Integer, Asn1Time},
145 bn::{BigNum, MsbOption},
146 ec::{Asn1Flag, EcGroup, EcKey},
147 hash::MessageDigest,
148 nid::Nid,
149 pkey::PKey,
150 symm::Cipher,
151 x509::{X509, X509NameBuilder, extension::SubjectAlternativeName},
152 };
153 use prosa::{
154 core::{
155 adaptor::Adaptor,
156 error::ProcError,
157 main::{MainProc, MainRunnable as _},
158 proc::{Proc, ProcConfig as _},
159 settings::settings,
160 },
161 io::listener::ListenerSetting,
162 };
163 use prosa_utils::{
164 config::{
165 ConfigError, os_country,
166 ssl::{SslConfig, Store},
167 },
168 msg::simple_string_tvf::SimpleStringTvf,
169 };
170 use reqwest::Certificate;
171 use serde::Serialize;
172 use std::{
173 env,
174 fs::{self, File},
175 io::{Read as _, Write as _},
176 time::Duration,
177 };
178 use tokio::time;
179 use url::Url;
180
181 use crate::{
182 HyperResp,
183 server::{
184 adaptor::HyperServerAdaptor,
185 proc::{HyperServerProc, HyperServerSettings},
186 },
187 };
188
189 const WAIT_TIME: time::Duration = time::Duration::from_secs(5);
190
191 #[settings]
193 #[derive(Default, Debug, Serialize)]
194 struct HttpTestSettings {
195 server: HyperServerSettings,
196 }
197
198 impl HttpTestSettings {
199 fn new(url: Url, server_ssl: Option<SslConfig>) -> Self {
200 let server = HyperServerSettings::new(
201 ListenerSetting::new(url.clone(), server_ssl),
202 Duration::from_secs(1),
203 );
204 HttpTestSettings {
205 server,
206 ..Default::default()
207 }
208 }
209 }
210
211 #[derive(Adaptor, Clone)]
212 struct ServerTestAdaptor {
213 }
215
216 impl<M> HyperServerAdaptor<M> for ServerTestAdaptor
217 where
218 M: 'static
219 + std::marker::Send
220 + std::marker::Sync
221 + std::marker::Sized
222 + std::clone::Clone
223 + std::fmt::Debug
224 + prosa_utils::msg::tvf::Tvf
225 + std::default::Default,
226 {
227 fn new(
228 _proc: &crate::server::proc::HyperServerProc<M>,
229 _prosa_name: &str,
230 ) -> Result<Self, Box<dyn ProcError + Send + Sync>>
231 where
232 Self: Sized,
233 {
234 Ok(ServerTestAdaptor {})
235 }
236
237 async fn process_http_request(&self, req: Request<hyper::body::Incoming>) -> HyperResp<M> {
238 let resp_msg = if req.version() == hyper::Version::HTTP_2 {
239 "Hello, H2 world"
240 } else {
241 "Hello, world"
242 };
243 let response = Response::builder()
244 .header(
245 hyper::header::SERVER,
246 <ServerTestAdaptor as HyperServerAdaptor<M>>::SERVER_HEADER,
247 )
248 .status(StatusCode::OK)
249 .body(BoxBody::new(Full::new(Bytes::from(resp_msg))))
250 .unwrap();
251
252 HyperResp::HttpResp(response)
253 }
254
255 fn process_srv_response(
256 &self,
257 _resp: M,
258 ) -> hyper::Response<
259 http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>,
260 > {
261 unimplemented!()
262 }
263 }
264
265 async fn run_test(settings: HttpTestSettings, certificate: Option<Certificate>, http2: bool) {
266 let url = settings.server.listener.url.clone();
267
268 let (bus, main) = MainProc::<SimpleStringTvf>::create(&settings);
270
271 let main_task = main.run();
273
274 let http_server_proc =
276 HyperServerProc::<SimpleStringTvf>::create(1, bus.clone(), settings.server);
277 Proc::<ServerTestAdaptor>::run(http_server_proc, String::from("HTTP_SERVER_PROC"));
278
279 std::thread::sleep(Duration::from_secs(1));
281
282 let mut client_builder = reqwest::ClientBuilder::new()
284 .timeout(Duration::from_secs(WAIT_TIME.as_secs()))
285 .use_rustls_tls();
286 if let Some(cert) = certificate {
287 client_builder = client_builder.add_root_certificate(cert);
288 }
289 if http2 {
290 client_builder = client_builder.http2_prior_knowledge();
291 }
292 let client = client_builder.build().unwrap();
293 for _i in 0..20 {
294 let resp = client
295 .get(url.clone())
296 .send()
297 .await
298 .expect("Failed to send request");
299 assert_eq!(resp.status(), StatusCode::OK);
300 let server_header = resp.headers().get(hyper::header::SERVER).unwrap();
301 assert!(server_header.to_str().unwrap().starts_with("ProSA-Hyper/"));
302 }
303
304 bus.stop("ProSA HTTP client server unit test end".into())
305 .await
306 .unwrap();
307
308 main_task.await;
310 }
311
312 #[tokio::test]
313 async fn http_client_server() {
314 let test_settings =
315 HttpTestSettings::new(Url::parse("http://localhost:48080").unwrap(), None);
316
317 run_test(test_settings, None, false).await;
319 }
320
321 fn create_server_cert(key_path: String, cert_path: String) -> Result<SslConfig, ConfigError> {
323 const PASSPHRASE: &str = "prosa_test";
324
325 let mut group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
326 group.set_asn1_flag(Asn1Flag::NAMED_CURVE);
327 let pkey = PKey::from_ec_key(EcKey::generate(&group)?)?;
328 let mut pkey_file =
329 File::create(key_path.clone()).map_err(|e| ConfigError::IoFile(key_path.clone(), e))?;
330 pkey_file
331 .write_all(&pkey.private_key_to_pem_pkcs8_passphrase(
332 Cipher::aes_256_cbc(),
333 PASSPHRASE.as_bytes(),
334 )?)
335 .map_err(|e| ConfigError::IoFile(key_path.clone(), e))?;
336
337 let mut cert = X509::builder()?;
338 cert.set_version(2)?;
339 cert.set_pubkey(&pkey)?;
340
341 let mut serial_bn = BigNum::new()?;
342 serial_bn.pseudo_rand(64, MsbOption::MAYBE_ZERO, true)?;
343 let serial_number = Asn1Integer::from_bn(&serial_bn)?;
344 cert.set_serial_number(&serial_number)?;
345
346 let begin_valid_time =
347 Asn1Time::from_unix(std::time::UNIX_EPOCH.elapsed().unwrap().as_secs() as i64 - 360)?;
348 cert.set_not_before(&begin_valid_time)?;
349 let end_valid_time = Asn1Time::days_from_now(3)?; cert.set_not_after(&end_valid_time)?;
351
352 let mut x509_name = X509NameBuilder::new()?;
353 if let Some(cn) = os_country() {
354 x509_name.append_entry_by_text("C", cn.as_str())?;
355 }
356 x509_name.append_entry_by_text("CN", "ProSA-hyper")?;
357 let x509_name = x509_name.build();
358 cert.set_subject_name(&x509_name)?;
359 cert.set_issuer_name(&x509_name)?;
360
361 let mut subject_alternative_name = SubjectAlternativeName::new();
362 let x509_extension = subject_alternative_name
363 .dns("localhost")
364 .build(&cert.x509v3_context(None, None))?;
365 cert.append_extension2(&x509_extension)?;
366
367 cert.sign(&pkey, MessageDigest::sha256())?;
368
369 let mut cert_file = File::create(cert_path.clone())
370 .map_err(|e| ConfigError::IoFile(cert_path.clone(), e))?;
371 cert_file
372 .write_all(&cert.build().to_pem()?)
373 .map_err(|e| ConfigError::IoFile(cert_path.clone(), e))?;
374
375 Ok(SslConfig::new_cert_key(
376 cert_path,
377 key_path,
378 Some(PASSPHRASE.into()),
379 ))
380 }
381
382 #[tokio::test]
383 async fn https_client_server() {
384 const PROSA_HTTPS_TEST_DIR_NAME: &str = "ProSA_HTTPS";
385 let prosa_temp_dir = env::temp_dir().join(PROSA_HTTPS_TEST_DIR_NAME);
386
387 let _ = fs::remove_dir_all(&prosa_temp_dir);
388 fs::create_dir_all(&prosa_temp_dir).unwrap();
389
390 let key_path = prosa_temp_dir.join("prosa_https.key");
391 let cert_path = prosa_temp_dir.join("prosa_https.pem");
392 let server_ssl_config = create_server_cert(
393 key_path.as_os_str().to_str().unwrap().into(),
394 cert_path.as_os_str().to_str().unwrap().into(),
395 )
396 .unwrap();
397
398 let mut buf = Vec::new();
399 File::open(cert_path.as_os_str().to_str().unwrap())
400 .unwrap()
401 .read_to_end(&mut buf)
402 .unwrap();
403 let client_cert = reqwest::Certificate::from_pem(&buf).unwrap();
404
405 let client_ssl_store = Store::File {
406 path: format!("{}/", prosa_temp_dir.as_os_str().to_str().unwrap()),
407 };
408 let mut client_ssl_config = SslConfig::default();
409 client_ssl_config.set_store(client_ssl_store);
410
411 let test_settings = HttpTestSettings::new(
412 Url::parse("https://localhost:48443").unwrap(),
413 Some(server_ssl_config),
414 );
415
416 run_test(test_settings, Some(client_cert), false).await;
418 }
419
420 #[tokio::test]
421 async fn h2_client_server() {
422 const PROSA_H2_TEST_DIR_NAME: &str = "ProSA_H2";
423 let prosa_temp_dir = env::temp_dir().join(PROSA_H2_TEST_DIR_NAME);
424
425 let _ = fs::remove_dir_all(&prosa_temp_dir);
426 fs::create_dir_all(&prosa_temp_dir).unwrap();
427
428 let key_path = prosa_temp_dir.join("prosa_h2.key");
429 let cert_path = prosa_temp_dir.join("prosa_h2.pem");
430 let mut server_ssl_config = create_server_cert(
431 key_path.as_os_str().to_str().unwrap().into(),
432 cert_path.as_os_str().to_str().unwrap().into(),
433 )
434 .unwrap();
435 server_ssl_config.set_alpn(vec!["h2".into()]);
437
438 let mut buf = Vec::new();
439 File::open(cert_path.as_os_str().to_str().unwrap())
440 .unwrap()
441 .read_to_end(&mut buf)
442 .unwrap();
443 let client_cert = reqwest::Certificate::from_pem(&buf).unwrap();
444
445 let test_settings = HttpTestSettings::new(
446 Url::parse("https://localhost:49443").unwrap(),
447 Some(server_ssl_config),
448 );
449
450 run_test(test_settings, Some(client_cert), true).await;
452 }
453}