1#![deny(
31 bad_style,
32 missing_debug_implementations,
33 missing_docs,
34 overflowing_literals,
35 patterns_in_fns_without_body,
36 trivial_casts,
37 trivial_numeric_casts,
38 unsafe_code,
39 unused,
40 unused_extern_crates,
41 unused_import_braces,
42 unused_qualifications,
43 unused_results
44)]
45
46use std::{convert::Infallible, net::IpAddr, time::Duration};
47
48use http_body_util::Full;
49use hyper::{Method, Request, Response, StatusCode, server::conn::http1, service::service_fn};
50use hyper_util::rt::TokioIo;
51use log::{debug, error, info};
52pub use service_probe_common::ServiceState;
53use snafu::{ResultExt as _, Snafu};
54use tokio::{
55 net::{TcpListener, TcpStream},
56 sync::{RwLock, oneshot},
57 task::JoinHandle,
58};
59
/// Handles kept for the running probe server task so that [`stop_probe`] can shut it down.
struct ProbeTaskHandle {
    /// Sending on (or dropping) this ends the `run_probe_server` accept loop.
    shutdown_sender: oneshot::Sender<()>,
    /// Handle of the spawned `run_probe_server` task, awaited (with a timeout) during shutdown.
    join_handle: JoinHandle<()>,
}
64
/// Service state reported by the probe endpoints.
///
/// Starts as `ServiceState::Up` until `start_probe` or `set_service_state` changes it. It sits
/// behind a synchronous lock that is only taken briefly inside `set_service_state` /
/// `get_service_state`, never across an `.await`.
static SERVICE_STATE: std::sync::RwLock<ServiceState> = std::sync::RwLock::new(ServiceState::Up);
/// Handles of the running probe server task, or `None` when no probe is running.
///
/// This is an async (Tokio) lock because `start_probe` holds it across `.await` points.
static PROBE_TASK_HANDLE: RwLock<Option<ProbeTaskHandle>> = RwLock::const_new(None);

/// How long [`stop_probe`] waits for the probe server task to finish after signalling it to
/// shut down.
pub const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_millis(500);
70
71#[derive(Debug, Snafu)]
73pub enum ProbeStartError {
74 AlreadyStarted,
76
77 SocketUnavailable {
79 source: std::io::Error,
81 },
82}
83
84pub fn set_service_state(state: ServiceState) {
88 let mut state_lock = SERVICE_STATE
89 .write()
90 .expect("rwlock poisoning should be impossible with the implemented control flow");
91 if state != *state_lock {
92 debug!("Service state change: {} to {}.", *state_lock, state);
93 *state_lock = state;
94 }
95}
96
97pub fn get_service_state() -> ServiceState {
101 *SERVICE_STATE
102 .read()
103 .expect("rwlock poisoning should be impossible with the implemented control flow")
104}
105
106pub async fn start_probe<A>(
110 address: A,
111 port: u16,
112 initial_state: ServiceState,
113) -> Result<(), ProbeStartError>
114where
115 A: Into<IpAddr>,
116{
117 let mut probe_task_handle = PROBE_TASK_HANDLE.write().await;
118
119 if probe_task_handle.is_some() {
120 return Err(ProbeStartError::AlreadyStarted);
121 }
122
123 let (shutdown_sender, shutdown_receiver) = oneshot::channel();
124
125 let ip_address: IpAddr = address.into();
126
127 let listener = TcpListener::bind((ip_address, port))
128 .await
129 .context(SocketUnavailableSnafu)?;
130 info!(
131 "Service readiness probe listening on http://{ip_address}:{port}/ with initial state {initial_state}"
132 );
133
134 set_service_state(initial_state);
136 let join_handle = tokio::task::spawn(run_probe_server(listener, shutdown_receiver));
137
138 *probe_task_handle = Some(ProbeTaskHandle {
139 shutdown_sender,
140 join_handle,
141 });
142
143 Ok(())
144}
145
146pub async fn stop_probe() {
152 let Some(ProbeTaskHandle {
153 shutdown_sender,
154 join_handle,
155 }) = PROBE_TASK_HANDLE.write().await.take()
156 else {
157 return;
158 };
159
160 let _ = shutdown_sender.send(());
161
162 debug!("Shutting down service readiness probe");
163
164 if let Err(_elapsed) = tokio::time::timeout(SHUTDOWN_GRACE_PERIOD, join_handle).await {
165 error!("Error shutting down the service readiness probe");
166 }
167}
168
169async fn run_probe_server(listener: TcpListener, mut shutdown_receiver: oneshot::Receiver<()>) {
170 loop {
171 tokio::select! {
172 accept = listener.accept() => {
173 match accept {
174 Ok((stream, _addr)) => {
175 _ = tokio::spawn(handle_accept(stream));
176 }
177 Err(e) => {
178 error!("Error accepting connection for service readiness probe: {e:?}");
179 }
180 }
181 }
182 _ = &mut shutdown_receiver => {
183 return;
184 }
185 }
186 }
187}
188
189async fn handle_accept(stream: TcpStream) {
190 let io = TokioIo::new(stream);
191 let service = service_fn(handle_request);
192
193 if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
194 error!("Error serving connection for service readiness probe: {e:?}");
195 }
196}
197
198async fn handle_request(
199 req: Request<hyper::body::Incoming>,
200) -> Result<Response<Full<&'static [u8]>>, Infallible> {
201 let (status_code, body) = match *req.method() {
202 Method::GET => {
203 let path = req.uri().path();
204 let state = get_service_state();
205 if ["", "/", "/health", "/health/"].contains(&path) {
206 (StatusCode::OK, state.as_str())
207 } else if ["/ready", "/ready/"].contains(&path) {
208 if state == ServiceState::Ready {
209 (StatusCode::OK, state.as_str())
210 } else {
211 (StatusCode::SERVICE_UNAVAILABLE, state.as_str())
212 }
213 } else {
214 (StatusCode::NOT_FOUND, "Not found")
215 }
216 }
217 Method::HEAD => (StatusCode::OK, ""),
218 _ => (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed"),
219 };
220 let mut response = Response::new(Full::new(body.as_bytes()));
221 *response.status_mut() = status_code;
222 Ok(response)
223}