use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tonic::{Request, Response, Status};
use crate::proto::{health_server::Health, HealthCheckRequest, HealthCheckResponse, ServingStatus};
/// Health flags consulted by the gRPC health endpoint.
///
/// Each flag is an [`AtomicBool`], so one instance can be shared across
/// threads and updated through `&self` without any locking. Every load and
/// store uses [`Ordering::SeqCst`].
#[derive(Debug)]
pub struct HealthState {
    query_service_healthy: AtomicBool,
    blob_service_healthy: AtomicBool,
    vector_service_healthy: AtomicBool,
    is_draining: AtomicBool,
}

impl Default for HealthState {
    /// Every service starts healthy and the server starts out not draining.
    fn default() -> Self {
        Self {
            query_service_healthy: AtomicBool::from(true),
            blob_service_healthy: AtomicBool::from(true),
            vector_service_healthy: AtomicBool::from(true),
            is_draining: AtomicBool::from(false),
        }
    }
}

impl HealthState {
    /// Creates a state with all services healthy and draining disabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads a flag with the ordering used throughout this type.
    fn load_flag(flag: &AtomicBool) -> bool {
        flag.load(Ordering::SeqCst)
    }

    /// Writes a flag with the ordering used throughout this type.
    fn store_flag(flag: &AtomicBool, value: bool) {
        flag.store(value, Ordering::SeqCst);
    }

    // --- query service ---------------------------------------------------

    /// Marks the query service as healthy (`true`) or unhealthy (`false`).
    pub fn set_query_service_healthy(&self, healthy: bool) {
        Self::store_flag(&self.query_service_healthy, healthy);
    }

    /// Returns whether the query service is currently marked healthy.
    #[must_use]
    pub fn is_query_service_healthy(&self) -> bool {
        Self::load_flag(&self.query_service_healthy)
    }

    // --- blob service ----------------------------------------------------

    /// Marks the blob service as healthy (`true`) or unhealthy (`false`).
    pub fn set_blob_service_healthy(&self, healthy: bool) {
        Self::store_flag(&self.blob_service_healthy, healthy);
    }

    /// Returns whether the blob service is currently marked healthy.
    #[must_use]
    pub fn is_blob_service_healthy(&self) -> bool {
        Self::load_flag(&self.blob_service_healthy)
    }

    // --- vector services -------------------------------------------------

    /// Marks the vector services as healthy (`true`) or unhealthy (`false`).
    pub fn set_vector_service_healthy(&self, healthy: bool) {
        Self::store_flag(&self.vector_service_healthy, healthy);
    }

    /// Returns whether the vector services are currently marked healthy.
    #[must_use]
    pub fn is_vector_service_healthy(&self) -> bool {
        Self::load_flag(&self.vector_service_healthy)
    }

    // --- draining / aggregate --------------------------------------------

    /// Turns draining mode on or off.
    pub fn set_draining(&self, draining: bool) {
        Self::store_flag(&self.is_draining, draining);
    }

    /// Returns whether the server is currently draining.
    #[must_use]
    pub fn is_draining(&self) -> bool {
        Self::load_flag(&self.is_draining)
    }

    /// Returns `true` only when every service is healthy and the server is
    /// not draining.
    #[must_use]
    pub fn is_all_healthy(&self) -> bool {
        let services_up = self.is_query_service_healthy()
            && self.is_blob_service_healthy()
            && self.is_vector_service_healthy();
        services_up && !self.is_draining()
    }
}
/// gRPC health-check service that reports from a shared [`HealthState`].
///
/// The state is held behind an [`Arc`], so clones of this service all
/// observe and report the same flags.
#[derive(Debug, Clone)]
pub struct HealthServiceImpl {
    state: Arc<HealthState>,
}

impl Default for HealthServiceImpl {
    /// Same as [`HealthServiceImpl::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl HealthServiceImpl {
    /// Creates a service backed by a brand-new default [`HealthState`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_state(Arc::new(HealthState::default()))
    }

    /// Creates a service that reports from an existing, possibly shared, state.
    ///
    /// Holding another `Arc` clone of `state` lets the caller flip health
    /// flags at runtime and have this service report the change.
    #[must_use]
    pub const fn with_state(state: Arc<HealthState>) -> Self {
        Self { state }
    }

    /// Returns the shared state this service reports from.
    #[must_use]
    pub const fn state(&self) -> &Arc<HealthState> {
        &self.state
    }
}
#[tonic::async_trait]
impl Health for HealthServiceImpl {
    /// Reports the serving status of one named service, or of the whole server.
    ///
    /// Resolution order:
    /// 1. While the server is draining, every request gets `NotServing`,
    ///    whatever service it names.
    /// 2. `neumann.v1.QueryService`, `neumann.v1.BlobService` and the two
    ///    `neumann.vector.v1` services map to their individual health flags.
    /// 3. An empty or absent service name asks about the server as a whole
    ///    and maps to [`HealthState::is_all_healthy`].
    /// 4. Any other name yields `Unspecified` and is logged at `warn` level.
    ///
    /// # Errors
    ///
    /// Never returns `Err`; every outcome is encoded in the response status.
    async fn check(
        &self,
        request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        let service = request.into_inner().service;

        let status = if self.state.is_draining() {
            // Draining overrides everything, including unknown names (which
            // are therefore not logged while draining).
            ServingStatus::NotServing
        } else {
            let serving_if = |healthy: bool| {
                if healthy {
                    ServingStatus::Serving
                } else {
                    ServingStatus::NotServing
                }
            };
            match service.as_deref() {
                Some("neumann.v1.QueryService") => {
                    serving_if(self.state.is_query_service_healthy())
                },
                Some("neumann.v1.BlobService") => serving_if(self.state.is_blob_service_healthy()),
                Some("neumann.vector.v1.PointsService" | "neumann.vector.v1.CollectionsService") => {
                    serving_if(self.state.is_vector_service_healthy())
                },
                Some("") | None => serving_if(self.state.is_all_healthy()),
                Some(unknown) => {
                    // The name is supplied verbatim by the remote client. Debug
                    // formatting escapes newlines and control characters, so a
                    // crafted name cannot forge or split log lines.
                    tracing::warn!("Health check for unknown service: {:?}", unknown);
                    ServingStatus::Unspecified
                },
            }
        };

        Ok(Response::new(HealthCheckResponse {
            status: status.into(),
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const QUERY: &str = "neumann.v1.QueryService";
    const BLOB: &str = "neumann.v1.BlobService";
    const POINTS: &str = "neumann.vector.v1.PointsService";
    const COLLECTIONS: &str = "neumann.vector.v1.CollectionsService";

    /// Issues one health check for `name` and returns the raw status code.
    async fn probe(service: &HealthServiceImpl, name: Option<&str>) -> i32 {
        let request = Request::new(HealthCheckRequest {
            service: name.map(str::to_owned),
        });
        service
            .check(request)
            .await
            .expect("should return response")
            .into_inner()
            .status
    }

    fn serving() -> i32 {
        i32::from(ServingStatus::Serving)
    }

    fn not_serving() -> i32 {
        i32::from(ServingStatus::NotServing)
    }

    /// Captures (query, blob, vector, all) health in one comparable value.
    fn snapshot(state: &HealthState) -> (bool, bool, bool, bool) {
        (
            state.is_query_service_healthy(),
            state.is_blob_service_healthy(),
            state.is_vector_service_healthy(),
            state.is_all_healthy(),
        )
    }

    #[tokio::test]
    async fn test_health_check_no_service() {
        let service = HealthServiceImpl::new();
        assert_eq!(probe(&service, None).await, serving());
    }

    #[tokio::test]
    async fn test_health_check_empty_service() {
        let service = HealthServiceImpl::new();
        assert_eq!(probe(&service, Some("")).await, serving());
    }

    #[tokio::test]
    async fn test_health_check_query_service() {
        let service = HealthServiceImpl::new();
        assert_eq!(probe(&service, Some(QUERY)).await, serving());
    }

    #[tokio::test]
    async fn test_health_check_blob_service() {
        let service = HealthServiceImpl::new();
        assert_eq!(probe(&service, Some(BLOB)).await, serving());
    }

    #[tokio::test]
    async fn test_health_check_unknown_service() {
        let service = HealthServiceImpl::new();
        assert_eq!(
            probe(&service, Some("unknown.Service")).await,
            i32::from(ServingStatus::Unspecified)
        );
    }

    #[tokio::test]
    async fn test_health_state_set_unhealthy() {
        let state = Arc::new(HealthState::new());
        let service = HealthServiceImpl::with_state(Arc::clone(&state));
        assert!(state.is_query_service_healthy());
        assert!(state.is_blob_service_healthy());

        state.set_query_service_healthy(false);

        assert_eq!(probe(&service, Some(QUERY)).await, not_serving());
        assert_eq!(probe(&service, None).await, not_serving());
        assert_eq!(probe(&service, Some(BLOB)).await, serving());
    }

    #[test]
    fn test_health_state_operations() {
        let state = HealthState::new();
        assert_eq!(snapshot(&state), (true, true, true, true));

        state.set_query_service_healthy(false);
        assert_eq!(snapshot(&state), (false, true, true, false));

        state.set_query_service_healthy(true);
        state.set_blob_service_healthy(false);
        assert_eq!(snapshot(&state), (true, false, true, false));

        state.set_blob_service_healthy(true);
        state.set_vector_service_healthy(false);
        assert_eq!(snapshot(&state), (true, true, false, false));

        state.set_vector_service_healthy(true);
        assert!(state.is_all_healthy());
    }

    #[test]
    fn test_health_state_draining() {
        let state = HealthState::new();
        assert!(!state.is_draining());
        assert!(state.is_all_healthy());

        for draining in [true, false] {
            state.set_draining(draining);
            assert_eq!(state.is_draining(), draining);
            assert_eq!(state.is_all_healthy(), !draining);
        }
    }

    #[tokio::test]
    async fn test_health_check_vector_service() {
        let service = HealthServiceImpl::new();
        for name in [POINTS, COLLECTIONS] {
            assert_eq!(probe(&service, Some(name)).await, serving());
        }
    }

    #[tokio::test]
    async fn test_health_check_vector_service_unhealthy() {
        let state = Arc::new(HealthState::new());
        let service = HealthServiceImpl::with_state(Arc::clone(&state));
        state.set_vector_service_healthy(false);

        assert_eq!(probe(&service, Some(POINTS)).await, not_serving());
        assert_eq!(probe(&service, None).await, not_serving());
    }

    #[tokio::test]
    async fn test_health_check_when_draining() {
        let state = Arc::new(HealthState::new());
        let service = HealthServiceImpl::with_state(Arc::clone(&state));
        assert_eq!(probe(&service, None).await, serving());

        state.set_draining(true);

        assert_eq!(probe(&service, None).await, not_serving());
        assert_eq!(probe(&service, Some(QUERY)).await, not_serving());
    }

    #[test]
    fn test_health_service_default() {
        let service = HealthServiceImpl::default();
        assert!(service.state().is_all_healthy());
    }

    #[test]
    fn test_health_service_state_accessor() {
        let state = Arc::new(HealthState::new());
        let service = HealthServiceImpl::with_state(Arc::clone(&state));
        assert!(Arc::ptr_eq(service.state(), &state));
    }
}