use async_trait::async_trait;
use pingora_core::{Error, ErrorType::CustomCode, Result};
use pingora_load_balancing::health_check::HealthCheck as PingoraHealthCheck;
use pingora_load_balancing::Backend;
use serde::Deserialize;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, trace, warn};
/// Health check that probes an inference server over raw HTTP and, optionally,
/// verifies that specific models are advertised by its models endpoint.
/// Understands both OpenAI-style (`{"data":[{"id":...}]}`) and Ollama-style
/// (`{"models":[{"name":...}]}`) listings (see `parse_models_response`).
pub struct InferenceHealthCheck {
    // HTTP path probed on each backend, e.g. "/v1/models" or "/api/tags".
    endpoint: String,
    // Model ids that must be present in the response; empty disables
    // model verification (a 200 status alone passes the check).
    expected_models: Vec<String>,
    // Applied separately to the TCP connect and to the response read.
    timeout: Duration,
    // Consecutive successes needed to flip a backend to healthy
    // (returned by `health_threshold`). Public so callers can tune it.
    pub consecutive_success: usize,
    // Consecutive failures needed to flip a backend to unhealthy.
    pub consecutive_failure: usize,
}
/// Wire shape of an OpenAI-style models listing (`{"object":"list","data":[...]}`).
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    // The advertised models.
    data: Vec<ModelInfo>,
}
/// One entry in a models listing.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    // Model identifier, e.g. "gpt-4".
    id: String,
    // Object tag (typically "model"); defaulted because some servers omit it.
    // NOTE(review): deserialized but never read — kept for schema tolerance?
    // Consider `#[allow(dead_code)]` or dropping the field.
    #[serde(default)]
    object: String,
}
impl InferenceHealthCheck {
    /// Creates a new inference health check.
    ///
    /// `endpoint` is the HTTP path probed on each backend (e.g. `/v1/models`),
    /// `expected_models` lists model ids that must be served (empty disables
    /// model verification), and `timeout` bounds the connect and the read,
    /// each separately.
    pub fn new(endpoint: String, expected_models: Vec<String>, timeout: Duration) -> Self {
        Self {
            endpoint,
            expected_models,
            timeout,
            // A single observation flips health state in either direction by
            // default; callers can raise these via the public fields.
            consecutive_success: 1,
            consecutive_failure: 1,
        }
    }

    /// Probes one backend: connects, issues a `GET` on the configured
    /// endpoint, and (when `expected_models` is non-empty) verifies the
    /// advertised model list.
    ///
    /// Returns `Err` with a human-readable reason on any failure.
    async fn check_backend(&self, addr: &str) -> Result<(), String> {
        let socket_addr: std::net::SocketAddr = addr
            .parse()
            .map_err(|e| format!("Invalid address '{}': {}", addr, e))?;
        let mut stream = tokio::time::timeout(self.timeout, TcpStream::connect(socket_addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;
        // `Connection: close` lets us read the response to EOF instead of
        // having to honor Content-Length / chunked framing exactly.
        let request = format!(
            "GET {} HTTP/1.1\r\n\
             Host: {}\r\n\
             User-Agent: Zentinel-HealthCheck/1.0\r\n\
             Accept: application/json\r\n\
             Connection: close\r\n\r\n",
            self.endpoint, addr
        );
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;
        // Read until the server closes the connection (or we hit the size
        // cap). A single read() may legally return only the first TCP
        // segment, which would truncate a large model list.
        const MAX_RESPONSE: usize = 65536;
        let mut response: Vec<u8> = Vec::with_capacity(8192);
        tokio::time::timeout(self.timeout, async {
            let mut buf = [0u8; 8192];
            loop {
                let n = stream
                    .read(&mut buf)
                    .await
                    .map_err(|e| format!("Failed to read response: {}", e))?;
                if n == 0 {
                    break; // EOF: peer closed as requested
                }
                response.extend_from_slice(&buf[..n]);
                if response.len() >= MAX_RESPONSE {
                    break; // cap memory for misbehaving backends
                }
            }
            Ok::<(), String>(())
        })
        .await
        .map_err(|_| "Response timeout".to_string())??;
        if response.is_empty() {
            return Err("Empty response".to_string());
        }
        let response_str = String::from_utf8_lossy(&response);
        let status_code = self.parse_status_code(&response_str)?;
        if status_code != 200 {
            return Err(format!("HTTP {} (expected 200)", status_code));
        }
        if self.expected_models.is_empty() {
            trace!(
                addr = %addr,
                endpoint = %self.endpoint,
                "Inference health check passed (no model verification)"
            );
            return Ok(());
        }
        let body = self.extract_body(&response_str)?;
        let models = self.parse_models_response(body)?;
        self.verify_models(&models)?;
        trace!(
            addr = %addr,
            endpoint = %self.endpoint,
            model_count = models.len(),
            expected_models = ?self.expected_models,
            "Inference health check passed"
        );
        Ok(())
    }

    /// Extracts the numeric status code from the HTTP status line
    /// (second whitespace-separated token of the first line).
    fn parse_status_code(&self, response: &str) -> Result<u16, String> {
        response
            .lines()
            .next()
            .and_then(|line| line.split_whitespace().nth(1))
            .and_then(|code| code.parse().ok())
            .ok_or_else(|| "Failed to parse HTTP status".to_string())
    }

    /// Returns the slice after the blank line separating headers from body.
    fn extract_body<'a>(&self, response: &'a str) -> Result<&'a str, String> {
        response
            .find("\r\n\r\n")
            .map(|pos| &response[pos + 4..])
            .ok_or_else(|| "Could not find response body".to_string())
    }

    /// Parses a models listing out of `body`, tolerating several schemas:
    /// OpenAI `{"data":[{"id":...}]}`, a bare array of model objects, and
    /// Ollama `{"models":[{"name":...}]}`.
    fn parse_models_response(&self, body: &str) -> Result<Vec<String>, String> {
        // Best-effort de-chunking: a body beginning with a hex digit is
        // assumed to be chunked (JSON objects/arrays start with '{' or '[').
        // NOTE(review): this only reassembles a single-chunk body correctly;
        // intermediate chunk-size lines of a multi-chunk response would leak
        // into the JSON. Acceptable for small health-check payloads.
        let json_body = if body.starts_with(|c: char| c.is_ascii_hexdigit()) {
            body.lines()
                .skip(1)
                .take_while(|line| !line.is_empty() && *line != "0")
                .collect::<Vec<_>>()
                .join("\n")
        } else {
            body.to_string()
        };
        if let Ok(response) = serde_json::from_str::<ModelsResponse>(&json_body) {
            return Ok(response.data.into_iter().map(|m| m.id).collect());
        }
        if let Ok(models) = serde_json::from_str::<Vec<ModelInfo>>(&json_body) {
            return Ok(models.into_iter().map(|m| m.id).collect());
        }
        // Fall back to untyped JSON for servers whose entries don't match
        // ModelInfo exactly (e.g. Ollama uses "name" instead of "id").
        if let Ok(json) = serde_json::from_str::<serde_json::Value>(&json_body) {
            if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
                let models: Vec<String> = data
                    .iter()
                    .filter_map(|m| m.get("id").and_then(|id| id.as_str()))
                    .map(String::from)
                    .collect();
                if !models.is_empty() {
                    return Ok(models);
                }
            }
            if let Some(models_arr) = json.get("models").and_then(|m| m.as_array()) {
                let models: Vec<String> = models_arr
                    .iter()
                    .filter_map(|m| {
                        m.get("id")
                            .or_else(|| m.get("name"))
                            .and_then(|id| id.as_str())
                    })
                    .map(String::from)
                    .collect();
                if !models.is_empty() {
                    return Ok(models);
                }
            }
        }
        // Truncate by characters, not bytes: a byte slice at 200 could split
        // a multi-byte UTF-8 sequence (from_utf8_lossy may emit U+FFFD) and
        // panic.
        let preview: String = json_body.chars().take(200).collect();
        Err(format!(
            "Failed to parse models response. Body preview: {}",
            preview
        ))
    }

    /// Checks every expected model against the advertised list. A model
    /// counts as found on an exact match or a prefix match in either
    /// direction (so "gpt-4" matches "gpt-4-turbo" and vice versa).
    fn verify_models(&self, available_models: &[String]) -> Result<(), String> {
        let mut missing = Vec::new();
        for expected in &self.expected_models {
            let found = available_models
                .iter()
                .any(|m| m == expected || m.starts_with(expected) || expected.starts_with(m));
            if !found {
                missing.push(expected.as_str());
            }
        }
        if missing.is_empty() {
            Ok(())
        } else {
            Err(format!(
                "Missing models: {}. Available: {:?}",
                missing.join(", "),
                available_models
            ))
        }
    }
}
#[async_trait]
impl PingoraHealthCheck for InferenceHealthCheck {
async fn check(&self, backend: &Backend) -> Result<()> {
let addr = backend.addr.to_string();
match self.check_backend(&addr).await {
Ok(()) => {
trace!(
addr = %addr,
endpoint = %self.endpoint,
expected_models = ?self.expected_models,
"Inference health check passed"
);
Ok(())
}
Err(error) => {
debug!(
addr = %addr,
endpoint = %self.endpoint,
error = %error,
"Inference health check failed"
);
Err(Error::explain(
CustomCode("inference health check", 1),
error,
))
}
}
}
fn health_threshold(&self, success: bool) -> usize {
if success {
self.consecutive_success
} else {
self.consecutive_failure
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a check for `endpoint` expecting `expected` models, 5s timeout.
    fn make_check(endpoint: &str, expected: &[&str]) -> InferenceHealthCheck {
        InferenceHealthCheck::new(
            endpoint.to_string(),
            expected.iter().map(|s| s.to_string()).collect(),
            Duration::from_secs(5),
        )
    }

    #[test]
    fn test_parse_openai_models_response() {
        let hc = make_check("/v1/models", &["gpt-4"]);
        let body = r#"{"object":"list","data":[{"id":"gpt-4","object":"model"},{"id":"gpt-3.5-turbo","object":"model"}]}"#;
        let parsed = hc.parse_models_response(body).unwrap();
        assert_eq!(parsed.len(), 2);
        assert!(parsed.contains(&"gpt-4".to_string()));
        assert!(parsed.contains(&"gpt-3.5-turbo".to_string()));
    }

    #[test]
    fn test_parse_ollama_models_response() {
        let hc = make_check("/api/tags", &["llama3"]);
        let body = r#"{"models":[{"name":"llama3:latest"},{"name":"codellama:7b"}]}"#;
        let parsed = hc.parse_models_response(body).unwrap();
        assert_eq!(parsed.len(), 2);
        assert!(parsed.contains(&"llama3:latest".to_string()));
    }

    #[test]
    fn test_verify_models_exact_match() {
        let hc = make_check("/v1/models", &["gpt-4", "gpt-3.5-turbo"]);
        let available: Vec<String> =
            vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()];
        assert!(hc.verify_models(&available).is_ok());
    }

    #[test]
    fn test_verify_models_prefix_match() {
        let hc = make_check("/v1/models", &["gpt-4"]);
        // "gpt-4-turbo" should satisfy the expected "gpt-4" via prefix match.
        let available: Vec<String> =
            vec!["gpt-4-turbo".to_string(), "gpt-3.5-turbo".to_string()];
        assert!(hc.verify_models(&available).is_ok());
    }

    #[test]
    fn test_verify_models_missing() {
        let hc = make_check("/v1/models", &["gpt-4", "claude-3"]);
        let available: Vec<String> =
            vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()];
        let outcome = hc.verify_models(&available);
        assert!(outcome.is_err());
        // The missing model must be named in the error.
        assert!(outcome.unwrap_err().contains("claude-3"));
    }

    #[test]
    fn test_parse_status_code() {
        let hc = make_check("/v1/models", &[]);
        assert_eq!(hc.parse_status_code("HTTP/1.1 200 OK\r\n"), Ok(200));
        assert_eq!(
            hc.parse_status_code("HTTP/1.1 404 Not Found\r\n"),
            Ok(404)
        );
    }

    #[test]
    fn test_extract_body() {
        let hc = make_check("/v1/models", &[]);
        let raw = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"data\":[]}";
        assert_eq!(hc.extract_body(raw).unwrap(), "{\"data\":[]}");
    }
}