use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
/// Health state reported by [`HealthServer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthStatus {
    /// The only state that counts as healthy.
    Healthy,
    /// Not healthy; the payload is a human-readable reason.
    Unhealthy(String),
    /// Initial state of a newly created [`HealthServer`]; not considered healthy.
    Starting,
}

impl HealthStatus {
    /// Returns `true` for [`HealthStatus::Healthy`] and `false` for every other variant.
    pub fn is_healthy(&self) -> bool {
        match self {
            HealthStatus::Healthy => true,
            HealthStatus::Unhealthy(_) | HealthStatus::Starting => false,
        }
    }
}
pub struct HealthServer {
addr: SocketAddr,
status: Arc<RwLock<HealthStatus>>,
running: Arc<RwLock<bool>>,
}
impl HealthServer {
pub fn new(addr: SocketAddr) -> Self {
HealthServer {
addr,
status: Arc::new(RwLock::new(HealthStatus::Starting)),
running: Arc::new(RwLock::new(false)),
}
}
pub fn default_port() -> Self {
Self::new("0.0.0.0:8080".parse().unwrap())
}
pub async fn set_status(&self, status: HealthStatus) {
let mut current = self.status.write().await;
*current = status;
}
pub async fn get_status(&self) -> HealthStatus {
self.status.read().await.clone()
}
pub async fn start(&self) -> Result<(), std::io::Error> {
let listener = TcpListener::bind(&self.addr).await?;
{
let mut running = self.running.write().await;
*running = true;
}
let status = self.status.clone();
let running = self.running.clone();
tokio::spawn(async move {
loop {
{
let r = running.read().await;
if !*r {
break;
}
}
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => continue,
};
let status = status.read().await.clone();
let (status_code, body) = match status {
HealthStatus::Healthy => (200, "{\"status\":\"healthy\"}"),
HealthStatus::Unhealthy(ref reason) => (
503,
&format!("{{\"status\":\"unhealthy\",\"reason\":\"{}\"}}", reason)[..],
),
HealthStatus::Starting => (503, "{\"status\":\"starting\"}"),
};
let response = format!(
"HTTP/1.1 {} OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
status_code,
body.len(),
body
);
use tokio::io::AsyncWriteExt;
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.shutdown().await;
}
});
Ok(())
}
pub async fn stop(&self) {
let mut running = self.running.write().await;
*running = false;
}
pub async fn is_running(&self) -> bool {
*self.running.read().await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_status() {
        let healthy = HealthStatus::Healthy;
        assert!(healthy.is_healthy());
        let unhealthy = HealthStatus::Unhealthy("error".into());
        assert!(!unhealthy.is_healthy());
        // `Starting` must not be reported as healthy either.
        assert!(!HealthStatus::Starting.is_healthy());
    }

    #[tokio::test]
    async fn test_health_server_creation() {
        let server = HealthServer::default_port();
        assert!(!server.is_running().await);
        // A fresh server reports `Starting` until told otherwise.
        assert_eq!(server.get_status().await, HealthStatus::Starting);
    }

    #[tokio::test]
    async fn test_set_status() {
        let server = HealthServer::default_port();
        server.set_status(HealthStatus::Healthy).await;
        assert_eq!(server.get_status().await, HealthStatus::Healthy);
        server
            .set_status(HealthStatus::Unhealthy("test".into()))
            .await;
        assert!(!server.get_status().await.is_healthy());
    }

    #[tokio::test]
    async fn test_start_and_stop() {
        // Port 0 lets the OS pick a free loopback port, so this never collides with 8080.
        let server = HealthServer::new("127.0.0.1:0".parse().unwrap());
        server
            .start()
            .await
            .expect("binding an ephemeral loopback port should succeed");
        assert!(server.is_running().await);
        server.stop().await;
        assert!(!server.is_running().await);
    }
}