// prosa_hyper/server/proc.rs

1use std::{env, sync::Arc, time::Duration};
2
3use hyper::server::conn::{http1, http2};
4use hyper_util::rt::{TokioExecutor, TokioIo};
5use opentelemetry::KeyValue;
6use prosa::{
7    core::{
8        adaptor::Adaptor,
9        error::ProcError,
10        msg::{ErrorMsg, InternalMsg, Msg, RequestMsg},
11        proc::{Proc, ProcBusParam, ProcConfig as _, proc, proc_settings},
12        service::ServiceError,
13    },
14    event::pending::PendingMsgs,
15    io::{listener::ListenerSetting, stream::Stream, url_is_ssl},
16};
17
18use prosa_utils::config::ssl::SslConfig;
19use serde::{Deserialize, Serialize};
20use tokio::sync::mpsc;
21use tracing::{Level, debug, info, span, warn};
22use url::Url;
23
24use crate::{H2, server::service::HyperService};
25
26use super::{HyperProcMsg, adaptor::HyperServerAdaptor};
27
28/// Hyper server processor settings
29#[proc_settings]
30#[derive(Debug, Deserialize, Serialize, Clone)]
31pub struct HyperServerSettings {
32    /// Listener settings
33    #[serde(default = "HyperServerSettings::default_listener")]
34    pub listener: ListenerSetting,
35    /// Timeout for internal service requests
36    #[serde(default = "HyperServerSettings::default_service_timeout")]
37    pub service_timeout: Duration,
38}
39
40impl HyperServerSettings {
41    fn default_listener() -> ListenerSetting {
42        let mut url = Url::parse("http://0.0.0.0:8080").unwrap();
43        if let Ok(Ok(port)) = env::var("PORT").map(|p| p.parse::<u16>()) {
44            url.set_port(Some(port)).unwrap();
45        }
46
47        ListenerSetting::new(url, None)
48    }
49
50    fn default_service_timeout() -> Duration {
51        Duration::from_millis(800)
52    }
53
54    /// Create a new Hyper Server settings
55    pub fn new(listener: ListenerSetting, service_timeout: Duration) -> HyperServerSettings {
56        HyperServerSettings {
57            listener,
58            service_timeout,
59            ..Default::default()
60        }
61    }
62}
63
64#[proc_settings]
65impl Default for HyperServerSettings {
66    fn default() -> HyperServerSettings {
67        HyperServerSettings {
68            listener: Self::default_listener(),
69            service_timeout: Self::default_service_timeout(),
70        }
71    }
72}
73
74/// Hyper server processor
75#[proc(settings = HyperServerSettings)]
76pub struct HyperServerProc {}
77
78#[proc]
79impl<M, A> Proc<A> for HyperServerProc
80where
81    M: 'static
82        + std::marker::Send
83        + std::marker::Sync
84        + std::marker::Sized
85        + std::clone::Clone
86        + std::fmt::Debug
87        + prosa_utils::msg::tvf::Tvf
88        + std::default::Default,
89    A: 'static + Adaptor + HyperServerAdaptor<M> + Clone + std::marker::Send + std::marker::Sync,
90{
91    /// Main loop of the processor
92    async fn internal_run(&mut self, name: String) -> Result<(), Box<dyn ProcError + Send + Sync>> {
93        // Initiate an adaptor for the stub processor
94        let mut adaptor = A::new(self, &name)?;
95
96        // Add proc main queue (id: 0)
97        self.proc.add_proc().await?;
98
99        // Declare an internal queue for HTTP requests
100        let (http_tx, mut http_rx) = mpsc::channel::<HyperProcMsg<M>>(2048);
101
102        // Declare a list for pending HTTP request
103        let mut pending_req = PendingMsgs::<HyperProcMsg<M>, M>::default();
104        let mut message_ref_request = 0;
105
106        // Set default protocol to HTTP2
107        if url_is_ssl(&self.settings.listener.url) {
108            if let Some(ssl) = self.settings.listener.ssl.as_mut() {
109                ssl.set_alpn(vec!["h2".into(), "http/1.1".into()]);
110            } else {
111                let mut ssl = SslConfig::default();
112                ssl.set_alpn(vec!["h2".into(), "http/1.1".into()]);
113                self.settings.listener.ssl = Some(ssl);
114            }
115        }
116
117        // Meter to log HTTP reponses
118        let meter = self.get_proc_param().meter("hyper_server");
119        let observable_http_counter = meter
120            .u64_counter("prosa_hyper_srv_count")
121            .with_description("Hyper HTTP counter")
122            .init();
123        let observable_http_socket = meter
124            .i64_up_down_counter("prosa_hyper_srv_socket")
125            .with_description("Hyper HTTP socket counter")
126            .init();
127
128        let listener = Arc::new(self.settings.listener.bind().await?);
129        let service_adaptor = Arc::new(adaptor.clone());
130        info!("Listening on {:?}", listener.local_addr());
131        loop {
132            tokio::select! {
133                Some(msg) = self.internal_rx_queue.recv() => {
134                    match msg {
135                        InternalMsg::Request(msg) => panic!(
136                            "The hyper processor {} receive a request {:?}",
137                            self.get_proc_id(),
138                            msg
139                        ),
140                        InternalMsg::Response(msg) => {
141                            if let Some(hyper_msg) = pending_req.pull_msg(msg.get_id()) {
142                                let _ = hyper_msg.response_queue.send(InternalMsg::Response(msg));
143                            }
144                        }
145                        InternalMsg::Error(err_msg) => {
146                            if let Some(hyper_err_msg) = pending_req.pull_msg(err_msg.get_id()) {
147                                let _ = hyper_err_msg
148                                    .response_queue
149                                    .send(InternalMsg::Error(err_msg));
150                            }
151                        }
152                        InternalMsg::Command(_) => todo!(),
153                        InternalMsg::Config => todo!(),
154                        InternalMsg::Service(table) => self.service = table,
155                        InternalMsg::Shutdown => {
156                            adaptor.terminate();
157                            self.proc.remove_proc(None).await?;
158                            warn!("The Hyper server processor will shut down");
159                            return Ok(());
160                        }
161                    }
162                },
163                Some(http_msg) = http_rx.recv() => {
164                    let service_name = http_msg.get_service().clone();
165                    if let Some(service) = self.service.get_proc_service(&service_name, message_ref_request) {
166                        debug!("The service is find: {service:?}, send to the internal service");
167                        service.proc_queue.send(InternalMsg::Request(RequestMsg::new(message_ref_request, service_name, http_msg.get_data().clone(), self.proc.get_service_queue().clone()))).await.unwrap();
168                        pending_req.push_with_id(message_ref_request, http_msg, self.settings.service_timeout);
169
170                        message_ref_request += 1;
171                    } else {
172                        let origin_data = http_msg.get_data().clone();
173                        let _ = http_msg.response_queue.send(InternalMsg::Error(ErrorMsg::new(0, service_name.clone(), span!(Level::WARN, "hyper::server::Msg", code = "503"), origin_data, ServiceError::UnableToReachService(service_name))));
174                    }
175                },
176                accept_result = listener.accept_raw() => {
177                    let (stream, addr) = accept_result?;
178
179                    let listener = listener.clone();
180                    let service_adaptor = service_adaptor.clone();
181                    let http_tx = http_tx.clone();
182                    let http_counter = observable_http_counter.clone();
183                    let http_socket = observable_http_socket.clone();
184                    tokio::task::spawn(async move {
185                        match listener.handshake(stream).await {
186                            Ok(stream) => {
187                                let is_http2 = if let Stream::Ssl(ssl) = &stream {
188                                    if let Some(alpn) = ssl.ssl().selected_alpn_protocol() {
189                                        alpn == H2
190                                    } else {
191                                        false
192                                    }
193                                } else {
194                                    false
195                                };
196
197                                http_socket.add(1, &[KeyValue::new("version", if is_http2 { "HTTP/2" } else { "HTTP/1.1" })]);
198
199                                let io = TokioIo::new(stream);
200                                let service = HyperService::new(service_adaptor, http_tx, http_counter);
201                                if is_http2 {
202                                    if let Err(err) = http2::Builder::new(TokioExecutor::new())
203                                        .serve_connection(
204                                            io,
205                                            service,
206                                        )
207                                        .await
208                                    {
209                                        warn!("Failed to serve http/2 connection[{addr}]: {err:?}");
210                                    }
211                                } else if let Err(err) = http1::Builder::new()
212                                    .serve_connection(
213                                        io,
214                                        service,
215                                    )
216                                    .await
217                                {
218                                    warn!("Failed to serve http/1 connection[{addr}]: {err:?}");
219                                }
220
221                                http_socket.add(-1, &[KeyValue::new("version", if is_http2 { "HTTP/2" } else { "HTTP/1.1" })]);
222                            }
223                            Err(e) => warn!("Failed to handshake with client[{addr}]: {e:?}"),
224                        }
225
226                        debug!("Connection closed {addr}");
227                    });
228                },
229                Some(msg) = pending_req.pull(), if !pending_req.is_empty() => {
230                    warn!(parent: msg.get_span(), "Timeout message {:?}", msg);
231
232                    let service_name = msg.get_service().clone();
233                    let span_msg = msg.get_span().clone();
234                    let origin_data = msg.get_data().clone();
235                    let _ = msg.response_queue.send(InternalMsg::Error(ErrorMsg::new(0, service_name.clone(), span_msg, origin_data, ServiceError::Timeout(service_name, self.settings.service_timeout.as_millis() as u64))));
236                },
237            }
238        }
239    }
240}