// prosa_hyper/server/proc.rs

1use std::{env, sync::Arc, time::Duration};
2
3use hyper::server::conn::{http1, http2};
4use hyper_util::rt::{TokioExecutor, TokioIo};
5use opentelemetry::KeyValue;
6use prosa::{
7    core::{
8        adaptor::Adaptor,
9        error::ProcError,
10        msg::{InternalMsg, Msg, RequestMsg},
11        proc::{Proc, ProcBusParam, ProcConfig as _, proc, proc_settings},
12        service::ServiceError,
13    },
14    event::pending::PendingMsgs,
15    io::{SslConfig, listener::ListenerSetting, url_is_ssl},
16};
17
18use serde::{Deserialize, Serialize};
19use tokio::sync::mpsc;
20use tracing::{debug, info, warn};
21use url::Url;
22
23use crate::{H2, server::service::HyperService};
24
25use super::adaptor::HyperServerAdaptor;
26
27/// Hyper server processor settings
28#[proc_settings]
29#[derive(Debug, Deserialize, Serialize, Clone)]
30pub struct HyperServerSettings {
31    /// Listener settings
32    #[serde(default = "HyperServerSettings::default_listener")]
33    pub listener: ListenerSetting,
34    /// Timeout for internal service requests
35    #[serde(default = "HyperServerSettings::default_service_timeout")]
36    pub service_timeout: Duration,
37}
38
39impl HyperServerSettings {
40    fn default_listener() -> ListenerSetting {
41        let mut url = Url::parse("http://0.0.0.0:8080").unwrap();
42        if let Ok(Ok(port)) = env::var("PORT").map(|p| p.parse::<u16>()) {
43            url.set_port(Some(port)).unwrap();
44        }
45
46        ListenerSetting::new(url, None)
47    }
48
49    fn default_service_timeout() -> Duration {
50        Duration::from_millis(800)
51    }
52
53    /// Create a new Hyper Server settings
54    pub fn new(listener: ListenerSetting, service_timeout: Duration) -> HyperServerSettings {
55        HyperServerSettings {
56            listener,
57            service_timeout,
58            ..Default::default()
59        }
60    }
61}
62
// NOTE(review): `#[proc_settings]` rewrites this impl; presumably it adds default values for the
// processor fields that the same macro injects into the struct. Confirm before restructuring the body.
#[proc_settings]
impl Default for HyperServerSettings {
    /// Build the default settings.
    ///
    /// Each field uses the same helper that serves as its serde default, so an empty
    /// configuration and `Default::default()` agree.
    fn default() -> HyperServerSettings {
        HyperServerSettings {
            listener: Self::default_listener(),
            service_timeout: Self::default_service_timeout(),
        }
    }
}
72
/// Hyper server processor
///
/// Accepts HTTP connections on the listener from [`HyperServerSettings`]. HTTP/2 is used when it
/// is negotiated through ALPN; otherwise HTTP/1.1 is used. Each HTTP request is forwarded to the
/// internal service named by the request.
#[proc(settings = HyperServerSettings)]
pub struct HyperServerProc {}
76
77#[proc]
78impl<M, A> Proc<A> for HyperServerProc
79where
80    M: 'static
81        + std::marker::Send
82        + std::marker::Sync
83        + std::marker::Sized
84        + std::clone::Clone
85        + std::fmt::Debug
86        + prosa::core::msg::Tvf
87        + std::default::Default,
88    A: 'static + Adaptor + HyperServerAdaptor<M> + Clone + std::marker::Send + std::marker::Sync,
89{
90    /// Main loop of the processor
91    async fn internal_run(&mut self) -> Result<(), Box<dyn ProcError + Send + Sync>> {
92        // Initiate an adaptor for the hyper server processor
93        let adaptor = A::new(self)?;
94
95        // Add proc main queue (id: 0)
96        self.proc.add_proc().await?;
97
98        // Declare an internal queue for HTTP requests
99        let (http_tx, mut http_rx) = mpsc::channel::<RequestMsg<M>>(2048);
100
101        // Declare a list for pending HTTP request
102        let mut pending_req = PendingMsgs::<RequestMsg<M>, M>::default();
103
104        // Set default protocol to HTTP2
105        if url_is_ssl(&self.settings.listener.url) {
106            if let Some(ssl) = self.settings.listener.ssl.as_mut() {
107                ssl.set_alpn(vec!["h2".into(), "http/1.1".into()]);
108            } else {
109                let mut ssl = SslConfig::default();
110                ssl.set_alpn(vec!["h2".into(), "http/1.1".into()]);
111                self.settings.listener.ssl = Some(ssl);
112            }
113        }
114
115        // Meter to log HTTP reponses
116        let meter = self.get_proc_param().meter("hyper_server");
117        let observable_http_counter = meter
118            .u64_counter("prosa_hyper_srv_count")
119            .with_description("Hyper HTTP server counter")
120            .build();
121        let observable_http_socket = meter
122            .i64_up_down_counter("prosa_hyper_srv_socket")
123            .with_description("Hyper HTTP server socket counter")
124            .build();
125
126        let listener = Arc::new(self.settings.listener.bind().await?);
127        let service_adaptor = Arc::new(adaptor.clone());
128        info!("Listening on {:?}", listener.local_addr());
129        loop {
130            tokio::select! {
131                Some(msg) = self.internal_rx_queue.recv() => {
132                    match msg {
133                        InternalMsg::Request(msg) => panic!(
134                            "The hyper processor {} receive a request {:?}",
135                            self.get_proc_id(),
136                            msg
137                        ),
138                        InternalMsg::Response(mut msg) => {
139                            if let Some(hyper_msg) = pending_req.pull_msg(msg.get_id())
140                                && let Some(data) = msg.take_data()
141                            {
142                                let _ = hyper_msg.return_to_sender(data);
143                            }
144                        }
145                        InternalMsg::Error(mut err_msg) => {
146                            if let Some(hyper_err_msg) = pending_req.pull_msg(err_msg.get_id()) {
147                                let _ = hyper_err_msg.return_error_to_sender(err_msg.take_data(), err_msg.into_err());
148                            }
149                        }
150                        InternalMsg::Command(_) => todo!(),
151                        InternalMsg::Config => todo!(),
152                        InternalMsg::Service(table) => self.service = table,
153                        InternalMsg::Shutdown => {
154                            adaptor.terminate();
155                            self.proc.remove_proc(None).await?;
156                            warn!("The Hyper server processor will shut down");
157                            return Ok(());
158                        }
159                    }
160                },
161                Some(mut http_msg) = http_rx.recv() => {
162                    if let Some(service) = self.service.get_proc_service(http_msg.get_service())
163                        && let Some(http_msg_data) = http_msg.take_data()
164                    {
165                        let request = RequestMsg::new(http_msg.get_service().clone(), http_msg_data, self.proc.get_service_queue().clone());
166                        let request_id = request.get_id();
167                        service.proc_queue.send(InternalMsg::Request(request)).await?;
168                        pending_req.push_with_id(request_id, http_msg, self.settings.service_timeout);
169                    } else {
170                        warn!(
171                            parent: http_msg.get_span(),
172                            code = "503",
173                            "hyper::server::Msg",
174                        );
175                        let data = http_msg.take_data();
176                        let service_name = http_msg.get_service().clone();
177                        let _ = http_msg.return_error_to_sender(data, ServiceError::UnableToReachService(service_name));
178                    }
179                },
180                accept_result = listener.accept_raw() => {
181                    let (stream, addr) = accept_result?;
182
183                    let listener = listener.clone();
184                    let service_adaptor = service_adaptor.clone();
185                    let http_tx = http_tx.clone();
186                    let http_counter = observable_http_counter.clone();
187                    let http_socket = observable_http_socket.clone();
188                    tokio::task::spawn(async move {
189                        match listener.handshake(stream).await {
190                            Ok(stream) => {
191                                let is_http2 = stream.selected_alpn_check(|alpn| alpn == H2);
192
193                                http_socket.add(1, &[KeyValue::new("version", if is_http2 { "HTTP/2" } else { "HTTP/1.1" })]);
194
195                                let io = TokioIo::new(stream);
196                                let service = HyperService::new(service_adaptor, http_tx, http_counter);
197                                if is_http2 {
198                                    if let Err(err) = http2::Builder::new(TokioExecutor::new())
199                                        .serve_connection(
200                                            io,
201                                            service,
202                                        )
203                                        .await
204                                    {
205                                        warn!("Failed to serve http/2 connection[{addr}]: {err:?}");
206                                    }
207                                } else if let Err(err) = http1::Builder::new()
208                                    .serve_connection(
209                                        io,
210                                        service,
211                                    )
212                                    .await
213                                {
214                                    warn!("Failed to serve http/1 connection[{addr}]: {err:?}");
215                                }
216
217                                http_socket.add(-1, &[KeyValue::new("version", if is_http2 { "HTTP/2" } else { "HTTP/1.1" })]);
218                            }
219                            Err(e) => warn!("Failed to handshake with client[{addr}]: {e:?}"),
220                        }
221
222                        debug!("Connection closed {addr}");
223                    });
224                },
225                Some(mut msg) = pending_req.pull(), if !pending_req.is_empty() => {
226                    warn!(parent: msg.get_span(), "Timeout message {:?}", msg);
227                    let data = msg.take_data();
228                    let service_name = msg.get_service().clone();
229                    let _ = msg.return_error_to_sender(
230                        data,
231                        ServiceError::Timeout(
232                            service_name,
233                            self.settings.service_timeout.as_millis() as u64,
234                        ),
235                    );
236                },
237            }
238        }
239    }
240}