1use std::{collections::HashMap, sync::Mutex};
2
3use futures_util::StreamExt;
4use poem::{IntoEndpoint, endpoint::BoxEndpoint};
5use tokio::sync::watch::{Receiver, Sender};
6
7use crate::{Code, Request, Response, Service, Status, Streaming};
8
// Bindings for the `grpc.health.v1` protobuf package, generated at build time
// into `OUT_DIR` and pulled in verbatim. The lints below are silenced because
// the generated items are not hand-written and cannot be adjusted here.
#[allow(unreachable_pub)]
#[allow(clippy::derive_partial_eq_without_eq)]
mod proto {
    include!(concat!(env!("OUT_DIR"), "/grpc.health.v1.rs"));
}
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub enum ServingStatus {
18 Serving,
20 NotServing,
22}
23
24impl ServingStatus {
25 fn to_proto(self) -> proto::health_check_response::ServingStatus {
26 use proto::health_check_response::ServingStatus::*;
27
28 match self {
29 ServingStatus::Serving => Serving,
30 ServingStatus::NotServing => NotServing,
31 }
32 }
33}
34
/// Snapshot of every reported service's status, keyed by service name
/// (the `Service::NAME` of the service the status was set for).
type ServiceStatusMap = HashMap<String, ServingStatus>;

/// Server-side implementation of the generated `proto::Health` trait.
///
/// It only reads status snapshots; writes go through the paired
/// `HealthReporter`, which owns the sending half of the same watch channel.
struct HealthService {
    // Receiving half of the watch channel carrying the latest status map.
    receiver: Receiver<ServiceStatusMap>,
}
40
41pub struct HealthReporter {
43 state: Mutex<(ServiceStatusMap, Sender<ServiceStatusMap>)>,
44}
45
46impl HealthReporter {
47 fn set_status<S: Service>(&self, status: ServingStatus) {
48 let mut state = self.state.lock().unwrap();
49 state.0.insert(S::NAME.to_string(), status);
50 let _ = state.1.send(state.0.clone());
51 }
52
53 pub fn set_serving<S: Service>(&self) {
56 self.set_status::<S>(ServingStatus::Serving);
57 }
58
59 pub fn set_not_serving<S: Service>(&self) {
62 self.set_status::<S>(ServingStatus::NotServing);
63 }
64
65 pub fn clear_service_status<S: Service>(&self) {
67 let mut state = self.state.lock().unwrap();
68 state.0.remove(S::NAME);
69 let _ = state.1.send(state.0.clone());
70 }
71}
72
73impl proto::Health for HealthService {
74 async fn check(
75 &self,
76 request: Request<proto::HealthCheckRequest>,
77 ) -> Result<Response<proto::HealthCheckResponse>, Status> {
78 let service_status = self.receiver.borrow();
79 match service_status.get(&request.service) {
80 Some(status) => Ok(Response::new(proto::HealthCheckResponse {
81 status: status.to_proto().into(),
82 })),
83 None => Err(Status::new(Code::NotFound)
84 .with_message(format!("service `{}` not found", request.service))),
85 }
86 }
87
88 async fn watch(
89 &self,
90 request: Request<proto::HealthCheckRequest>,
91 ) -> Result<Response<Streaming<proto::HealthCheckResponse>>, Status> {
92 let mut stream = tokio_stream::wrappers::WatchStream::new(self.receiver.clone());
93 let service_name = request.into_inner().service;
94
95 Ok(Response::new(Streaming::new(async_stream::try_stream! {
96 while let Some(service_status) = stream.next().await {
97 let res = service_status.get(&service_name);
98 let status = res.ok_or_else(|| Status::new(Code::NotFound).with_message(format!("service `{}` not found", service_name)))?
99 .to_proto()
100 .into();
101 yield proto::HealthCheckResponse { status };
102 }
103 })))
104 }
105}
106
107pub fn health_service() -> (
109 impl IntoEndpoint<Endpoint = BoxEndpoint<'static, poem::Response>> + Service,
110 HealthReporter,
111) {
112 let (sender, receiver) = tokio::sync::watch::channel(Default::default());
113
114 (
115 proto::HealthServer::new(HealthService { receiver }),
116 HealthReporter {
117 state: Mutex::new((Default::default(), sender)),
118 },
119 )
120}
121
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use super::*;
    use crate::health::proto::Health;

    /// The generated server type; its `NAME` is the service queried below.
    type Server = proto::HealthServer<HealthService>;
    /// Wire-level status enum used in expected responses.
    type WireStatus = proto::health_check_response::ServingStatus;

    /// Builds a service/reporter pair sharing one channel, without the
    /// endpoint wrapper.
    fn create_service() -> (HealthService, HealthReporter) {
        let (tx, rx) = tokio::sync::watch::channel(ServiceStatusMap::default());
        let reporter = HealthReporter {
            state: Mutex::new((ServiceStatusMap::default(), tx)),
        };
        (HealthService { receiver: rx }, reporter)
    }

    /// A request asking about the health service itself.
    fn own_request() -> Request<proto::HealthCheckRequest> {
        Request::new(proto::HealthCheckRequest {
            service: Server::NAME.to_string(),
        })
    }

    /// The response expected for a given wire-level status.
    fn expected(status: WireStatus) -> proto::HealthCheckResponse {
        proto::HealthCheckResponse {
            status: status.into(),
        }
    }

    #[tokio::test]
    async fn check() {
        let (service, reporter) = create_service();

        // Nothing has been reported yet.
        let err = service.check(own_request()).await.unwrap_err();
        assert_eq!(err.code(), Code::NotFound);

        reporter.set_serving::<Server>();
        let resp = service.check(own_request()).await.unwrap().into_inner();
        assert_eq!(resp, expected(WireStatus::Serving));

        reporter.set_not_serving::<Server>();
        let resp = service.check(own_request()).await.unwrap().into_inner();
        assert_eq!(resp, expected(WireStatus::NotServing));

        reporter.clear_service_status::<Server>();
        let err = service.check(own_request()).await.unwrap_err();
        assert_eq!(err.code(), Code::NotFound);
    }

    #[tokio::test]
    async fn watch() {
        let (service, reporter) = create_service();

        // Watching a service with no status fails straight away.
        let mut stream = service.watch(own_request()).await.unwrap();
        let err = stream.next().await.unwrap().unwrap_err();
        assert_eq!(err.code(), Code::NotFound);

        reporter.set_serving::<Server>();
        let mut stream = service.watch(own_request()).await.unwrap();
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first, expected(WireStatus::Serving));

        reporter.set_not_serving::<Server>();
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second, expected(WireStatus::NotServing));

        reporter.clear_service_status::<Server>();
        let err = stream.next().await.unwrap().unwrap_err();
        assert_eq!(err.code(), Code::NotFound);
    }
}