1use std::{env, sync::Arc, time::Duration};
2
3use hyper::server::conn::{http1, http2};
4use hyper_util::rt::{TokioExecutor, TokioIo};
5use opentelemetry::KeyValue;
6use prosa::{
7 core::{
8 adaptor::Adaptor,
9 error::ProcError,
10 msg::{ErrorMsg, InternalMsg, Msg, RequestMsg},
11 proc::{Proc, ProcBusParam, ProcConfig as _, proc, proc_settings},
12 service::ServiceError,
13 },
14 event::pending::PendingMsgs,
15 io::{listener::ListenerSetting, url_is_ssl},
16};
17
18use prosa_utils::config::ssl::SslConfig;
19use serde::{Deserialize, Serialize};
20use tokio::sync::mpsc;
21use tracing::{Level, debug, info, span, warn};
22use url::Url;
23
24use crate::{H2, server::service::HyperService};
25
26use super::{HyperProcMsg, adaptor::HyperServerAdaptor};
27
#[proc_settings]
/// Settings of the Hyper HTTP server processor.
///
/// Every field has a serde default, so a configuration may omit any of them.
// NOTE(review): `#[proc_settings]` presumably injects additional fields into this struct
// (the constructor below relies on `..Default::default()`) — confirm against the macro docs.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct HyperServerSettings {
    /// Listener the server binds to (URL plus optional SSL configuration).
    /// Defaults to the value built by `HyperServerSettings::default_listener`.
    #[serde(default = "HyperServerSettings::default_listener")]
    pub listener: ListenerSetting,
    /// Maximum time to wait for an internal service response before the pending
    /// HTTP request is answered with a timeout error.
    /// Defaults to the value built by `HyperServerSettings::default_service_timeout`.
    #[serde(default = "HyperServerSettings::default_service_timeout")]
    pub service_timeout: Duration,
}
39
40impl HyperServerSettings {
41 fn default_listener() -> ListenerSetting {
42 let mut url = Url::parse("http://0.0.0.0:8080").unwrap();
43 if let Ok(Ok(port)) = env::var("PORT").map(|p| p.parse::<u16>()) {
44 url.set_port(Some(port)).unwrap();
45 }
46
47 ListenerSetting::new(url, None)
48 }
49
50 fn default_service_timeout() -> Duration {
51 Duration::from_millis(800)
52 }
53
54 pub fn new(listener: ListenerSetting, service_timeout: Duration) -> HyperServerSettings {
56 HyperServerSettings {
57 listener,
58 service_timeout,
59 ..Default::default()
60 }
61 }
62}
63
// NOTE(review): `#[proc_settings]` processes this impl; it presumably completes the struct
// literal below with defaults for the fields it injects, so the literal's shape is kept as is
// — confirm against the macro docs before restructuring.
#[proc_settings]
impl Default for HyperServerSettings {
    /// Builds the settings from `default_listener` and `default_service_timeout`.
    fn default() -> HyperServerSettings {
        HyperServerSettings {
            listener: Self::default_listener(),
            service_timeout: Self::default_service_timeout(),
        }
    }
}
73
#[proc(settings = HyperServerSettings)]
/// Hyper HTTP server processor, configured by [`HyperServerSettings`].
///
/// The struct declares no field of its own; the members used by its `Proc`
/// implementation (`proc`, `settings`, `service`, `internal_rx_queue`) are not declared
/// here.
// NOTE(review): those members are presumably generated by the `#[proc]` macro — confirm.
pub struct HyperServerProc {}
77
78#[proc]
79impl<M, A> Proc<A> for HyperServerProc
80where
81 M: 'static
82 + std::marker::Send
83 + std::marker::Sync
84 + std::marker::Sized
85 + std::clone::Clone
86 + std::fmt::Debug
87 + prosa_utils::msg::tvf::Tvf
88 + std::default::Default,
89 A: 'static + Adaptor + HyperServerAdaptor<M> + Clone + std::marker::Send + std::marker::Sync,
90{
91 async fn internal_run(&mut self, name: String) -> Result<(), Box<dyn ProcError + Send + Sync>> {
93 let adaptor = A::new(self, &name)?;
95
96 self.proc.add_proc().await?;
98
99 let (http_tx, mut http_rx) = mpsc::channel::<HyperProcMsg<M>>(2048);
101
102 let mut pending_req = PendingMsgs::<HyperProcMsg<M>, M>::default();
104
105 if url_is_ssl(&self.settings.listener.url) {
107 if let Some(ssl) = self.settings.listener.ssl.as_mut() {
108 ssl.set_alpn(vec!["h2".into(), "http/1.1".into()]);
109 } else {
110 let mut ssl = SslConfig::default();
111 ssl.set_alpn(vec!["h2".into(), "http/1.1".into()]);
112 self.settings.listener.ssl = Some(ssl);
113 }
114 }
115
116 let meter = self.get_proc_param().meter("hyper_server");
118 let observable_http_counter = meter
119 .u64_counter("prosa_hyper_srv_count")
120 .with_description("Hyper HTTP counter")
121 .build();
122 let observable_http_socket = meter
123 .i64_up_down_counter("prosa_hyper_srv_socket")
124 .with_description("Hyper HTTP socket counter")
125 .build();
126
127 let listener = Arc::new(self.settings.listener.bind().await?);
128 let service_adaptor = Arc::new(adaptor.clone());
129 info!("Listening on {:?}", listener.local_addr());
130 loop {
131 tokio::select! {
132 Some(msg) = self.internal_rx_queue.recv() => {
133 match msg {
134 InternalMsg::Request(msg) => panic!(
135 "The hyper processor {} receive a request {:?}",
136 self.get_proc_id(),
137 msg
138 ),
139 InternalMsg::Response(msg) => {
140 if let Some(mut hyper_msg) = pending_req.pull_msg(msg.get_id()) {
141 let response_queue = hyper_msg.get_response_queue()?;
142 let _ = response_queue.send(InternalMsg::Response(msg));
143 }
144 }
145 InternalMsg::Error(err_msg) => {
146 if let Some(mut hyper_err_msg) = pending_req.pull_msg(err_msg.get_id()) {
147 let response_queue = hyper_err_msg.get_response_queue()?;
148 let _ = response_queue
149 .send(InternalMsg::Error(err_msg));
150 }
151 }
152 InternalMsg::Command(_) => todo!(),
153 InternalMsg::Config => todo!(),
154 InternalMsg::Service(table) => self.service = table,
155 InternalMsg::Shutdown => {
156 adaptor.terminate();
157 self.proc.remove_proc(None).await?;
158 warn!("The Hyper server processor will shut down");
159 return Ok(());
160 }
161 }
162 },
163 Some(mut http_msg) = http_rx.recv() => {
164 if let Some(service) = self.service.get_proc_service(http_msg.get_service())
165 && let Some(http_msg_data) = http_msg.take_data()
166 {
167 let request = RequestMsg::new(http_msg.get_service().clone(), http_msg_data, self.proc.get_service_queue().clone());
168 let request_id = request.get_id();
169 service.proc_queue.send(InternalMsg::Request(request)).await.unwrap();
170 pending_req.push_with_id(request_id, http_msg, self.settings.service_timeout);
171 } else {
172 let service_name = http_msg.get_service().clone();
173 let response_queue = http_msg.get_response_queue()?;
174 let data = http_msg.take_data();
175 let _ = response_queue.send(InternalMsg::Error(ErrorMsg::new(http_msg, service_name.clone(), span!(Level::WARN, "hyper::server::Msg", code = "503"), data, ServiceError::UnableToReachService(service_name))));
176 }
177 },
178 accept_result = listener.accept_raw() => {
179 let (stream, addr) = accept_result?;
180
181 let listener = listener.clone();
182 let service_adaptor = service_adaptor.clone();
183 let http_tx = http_tx.clone();
184 let http_counter = observable_http_counter.clone();
185 let http_socket = observable_http_socket.clone();
186 tokio::task::spawn(async move {
187 match listener.handshake(stream).await {
188 Ok(stream) => {
189 let is_http2 = stream.selected_alpn_check(|alpn| alpn == H2);
190
191 http_socket.add(1, &[KeyValue::new("version", if is_http2 { "HTTP/2" } else { "HTTP/1.1" })]);
192
193 let io = TokioIo::new(stream);
194 let service = HyperService::new(service_adaptor, http_tx, http_counter);
195 if is_http2 {
196 if let Err(err) = http2::Builder::new(TokioExecutor::new())
197 .serve_connection(
198 io,
199 service,
200 )
201 .await
202 {
203 warn!("Failed to serve http/2 connection[{addr}]: {err:?}");
204 }
205 } else if let Err(err) = http1::Builder::new()
206 .serve_connection(
207 io,
208 service,
209 )
210 .await
211 {
212 warn!("Failed to serve http/1 connection[{addr}]: {err:?}");
213 }
214
215 http_socket.add(-1, &[KeyValue::new("version", if is_http2 { "HTTP/2" } else { "HTTP/1.1" })]);
216 }
217 Err(e) => warn!("Failed to handshake with client[{addr}]: {e:?}"),
218 }
219
220 debug!("Connection closed {addr}");
221 });
222 },
223 Some(mut msg) = pending_req.pull(), if !pending_req.is_empty() => {
224 warn!(parent: msg.get_span(), "Timeout message {:?}", msg);
225
226 let service_name = msg.get_service().clone();
227 let span_msg = msg.get_span().clone();
228 let response_queue = msg.get_response_queue()?;
229 let data = msg.take_data();
230 let _ = response_queue.send(InternalMsg::Error(ErrorMsg::new(msg, service_name.clone(), span_msg, data, ServiceError::Timeout(service_name, self.settings.service_timeout.as_millis() as u64))));
231 },
232 }
233 }
234 }
235}