poem_grpc/
health.rs

1use std::{collections::HashMap, sync::Mutex};
2
3use futures_util::StreamExt;
4use poem::{IntoEndpoint, endpoint::BoxEndpoint};
5use tokio::sync::watch::{Receiver, Sender};
6
7use crate::{Code, Request, Response, Service, Status, Streaming};
8
9#[allow(unreachable_pub)]
10#[allow(clippy::derive_partial_eq_without_eq)]
11mod proto {
12    include!(concat!(env!("OUT_DIR"), "/grpc.health.v1.rs"));
13}
14
/// Health of a single gRPC service, as recorded through [`HealthReporter`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServingStatus {
    /// The service is up and currently serving requests.
    Serving,
    /// The service is down and currently not serving requests.
    NotServing,
}
23
24impl ServingStatus {
25    fn to_proto(self) -> proto::health_check_response::ServingStatus {
26        use proto::health_check_response::ServingStatus::*;
27
28        match self {
29            ServingStatus::Serving => Serving,
30            ServingStatus::NotServing => NotServing,
31        }
32    }
33}
34
35type ServiceStatusMap = HashMap<String, ServingStatus>;
36
37struct HealthService {
38    receiver: Receiver<ServiceStatusMap>,
39}
40
41/// A handle providing methods to update the health status of GRPC services
42pub struct HealthReporter {
43    state: Mutex<(ServiceStatusMap, Sender<ServiceStatusMap>)>,
44}
45
46impl HealthReporter {
47    fn set_status<S: Service>(&self, status: ServingStatus) {
48        let mut state = self.state.lock().unwrap();
49        state.0.insert(S::NAME.to_string(), status);
50        let _ = state.1.send(state.0.clone());
51    }
52
53    /// Sets the status of the service implemented by `S` to
54    /// [`ServingStatus::Serving`]
55    pub fn set_serving<S: Service>(&self) {
56        self.set_status::<S>(ServingStatus::Serving);
57    }
58
59    /// Sets the status of the service implemented by `S` to
60    /// [`ServingStatus::NotServing`]
61    pub fn set_not_serving<S: Service>(&self) {
62        self.set_status::<S>(ServingStatus::NotServing);
63    }
64
65    /// Clear the status of the given service.
66    pub fn clear_service_status<S: Service>(&self) {
67        let mut state = self.state.lock().unwrap();
68        state.0.remove(S::NAME);
69        let _ = state.1.send(state.0.clone());
70    }
71}
72
73impl proto::Health for HealthService {
74    async fn check(
75        &self,
76        request: Request<proto::HealthCheckRequest>,
77    ) -> Result<Response<proto::HealthCheckResponse>, Status> {
78        let service_status = self.receiver.borrow();
79        match service_status.get(&request.service) {
80            Some(status) => Ok(Response::new(proto::HealthCheckResponse {
81                status: status.to_proto().into(),
82            })),
83            None => Err(Status::new(Code::NotFound)
84                .with_message(format!("service `{}` not found", request.service))),
85        }
86    }
87
88    async fn watch(
89        &self,
90        request: Request<proto::HealthCheckRequest>,
91    ) -> Result<Response<Streaming<proto::HealthCheckResponse>>, Status> {
92        let mut stream = tokio_stream::wrappers::WatchStream::new(self.receiver.clone());
93        let service_name = request.into_inner().service;
94
95        Ok(Response::new(Streaming::new(async_stream::try_stream! {
96            while let Some(service_status) = stream.next().await {
97                let res = service_status.get(&service_name);
98                let status = res.ok_or_else(|| Status::new(Code::NotFound).with_message(format!("service `{}` not found", service_name)))?
99                    .to_proto()
100                    .into();
101                yield proto::HealthCheckResponse { status };
102            }
103        })))
104    }
105}
106
107/// Create health service and [`HealthReporter`]
108pub fn health_service() -> (
109    impl IntoEndpoint<Endpoint = BoxEndpoint<'static, poem::Response>> + Service,
110    HealthReporter,
111) {
112    let (sender, receiver) = tokio::sync::watch::channel(Default::default());
113
114    (
115        proto::HealthServer::new(HealthService { receiver }),
116        HealthReporter {
117            state: Mutex::new((Default::default(), sender)),
118        },
119    )
120}
121
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use super::*;
    use crate::health::proto::Health;

    /// Generated server type whose `Service::NAME` is the service name probed below.
    type Server = proto::HealthServer<HealthService>;

    /// Generated protobuf serving-status enum.
    type ProtoStatus = proto::health_check_response::ServingStatus;

    /// Builds a service/reporter pair sharing one watch channel, like `health_service()` but
    /// without wrapping the service, so its `Health` methods can be called directly.
    fn create_service() -> (HealthService, HealthReporter) {
        let (sender, receiver) = tokio::sync::watch::channel(ServiceStatusMap::new());
        let reporter = HealthReporter {
            state: Mutex::new((ServiceStatusMap::new(), sender)),
        };
        (HealthService { receiver }, reporter)
    }

    /// A request asking for the health of the health service itself.
    fn request() -> Request<proto::HealthCheckRequest> {
        Request::new(proto::HealthCheckRequest {
            service: <Server as Service>::NAME.to_string(),
        })
    }

    /// The response expected when the probed service has `status`.
    fn response(status: ProtoStatus) -> proto::HealthCheckResponse {
        proto::HealthCheckResponse {
            status: status.into(),
        }
    }

    #[tokio::test]
    async fn check() {
        let (service, reporter) = create_service();

        // Nothing reported yet.
        let res = service.check(request()).await;
        assert_eq!(res.unwrap_err().code(), Code::NotFound);

        reporter.set_serving::<Server>();
        let res = service.check(request()).await;
        assert_eq!(res.unwrap().into_inner(), response(ProtoStatus::Serving));

        reporter.set_not_serving::<Server>();
        let res = service.check(request()).await;
        assert_eq!(res.unwrap().into_inner(), response(ProtoStatus::NotServing));

        reporter.clear_service_status::<Server>();
        let res = service.check(request()).await;
        assert_eq!(res.unwrap_err().code(), Code::NotFound);
    }

    #[tokio::test]
    async fn watch() {
        let (service, reporter) = create_service();

        // Watching a service with no recorded status ends with NOT_FOUND.
        let mut stream = service.watch(request()).await.unwrap();
        assert_eq!(
            stream.next().await.unwrap().unwrap_err().code(),
            Code::NotFound
        );

        reporter.set_serving::<Server>();
        let mut stream = service.watch(request()).await.unwrap();
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            response(ProtoStatus::Serving)
        );

        reporter.set_not_serving::<Server>();
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            response(ProtoStatus::NotServing)
        );

        reporter.clear_service_status::<Server>();
        assert_eq!(
            stream.next().await.unwrap().unwrap_err().code(),
            Code::NotFound
        );
    }
}
226        );
227    }
228}