sentinel_common/
inference.rs

//! Inference health check configuration types
//!
//! This module provides configuration structures for enhanced model readiness
//! checks beyond basic HTTP 200 verification. These checks help ensure LLM/AI
//! backends are truly ready to serve requests.
//!
//! # Check Types
//!
//! - [`InferenceProbeConfig`]: Send minimal completion request to verify model responds
//! - [`ModelStatusConfig`]: Check provider-specific model status endpoints
//! - [`QueueDepthConfig`]: Monitor queue depth to detect overloaded backends
//! - [`WarmthDetectionConfig`]: Track latency to detect cold models after idle periods
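//!
//! # Example
//!
//! A minimal sketch of enabling two checks from JSON. The crate/module path
//! here is an assumption for illustration; the field names follow the serde
//! attributes defined below.
//!
//! ```
//! use sentinel_common::inference::InferenceReadinessConfig;
//!
//! let json = r#"{
//!     "queue_depth": { "degraded_threshold": 50, "unhealthy_threshold": 200 },
//!     "warmth_detection": {}
//! }"#;
//! let config: InferenceReadinessConfig = serde_json::from_str(json).unwrap();
//! assert!(config.queue_depth.is_some());
//! assert!(config.inference_probe.is_none());
//! ```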

use serde::{Deserialize, Serialize};

/// Configuration for enhanced inference readiness checks
///
/// All fields are optional; only enabled checks are performed.
/// The base inference health check (models endpoint) always runs first.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceReadinessConfig {
    /// Send minimal inference request to verify model can respond
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inference_probe: Option<InferenceProbeConfig>,

    /// Check provider-specific model status endpoints
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_status: Option<ModelStatusConfig>,

    /// Monitor queue depth from headers or response body
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub queue_depth: Option<QueueDepthConfig>,

    /// Detect cold models after idle periods
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub warmth_detection: Option<WarmthDetectionConfig>,
}

/// Configuration for inference probe health check
///
/// Sends a minimal completion request to verify the model can actually
/// process requests, not just that the server is running.
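///
/// # Example
///
/// A minimal deserialization sketch; the module path and the model name are
/// illustrative assumptions.
///
/// ```
/// use sentinel_common::inference::InferenceProbeConfig;
///
/// // Only `model` is required; everything else falls back to serde defaults.
/// let config: InferenceProbeConfig =
///     serde_json::from_str(r#"{"model": "tiny-model"}"#).unwrap();
/// assert_eq!(config.endpoint, "/v1/completions");
/// assert_eq!(config.max_tokens, 1);
/// ```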
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceProbeConfig {
    /// Endpoint for completion request
    #[serde(default = "default_probe_endpoint")]
    pub endpoint: String,

    /// Model to probe (required)
    pub model: String,

    /// Probe prompt (minimal to reduce cost/latency)
    #[serde(default = "default_probe_prompt")]
    pub prompt: String,

    /// Max tokens in response (keep minimal)
    #[serde(default = "default_probe_max_tokens")]
    pub max_tokens: u32,

    /// Timeout for probe request in seconds
    #[serde(default = "default_probe_timeout")]
    pub timeout_secs: u64,

    /// Mark unhealthy if probe latency exceeds this threshold (ms)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_latency_ms: Option<u64>,
}

/// Configuration for model status endpoint check
///
/// Queries provider-specific status endpoints to verify model readiness.
/// Useful for providers that expose detailed model state information.
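///
/// # Example
///
/// A minimal deserialization sketch; the module path and model name are
/// illustrative assumptions.
///
/// ```
/// use sentinel_common::inference::ModelStatusConfig;
///
/// // `models` is the only required field.
/// let config: ModelStatusConfig =
///     serde_json::from_str(r#"{"models": ["example-model"]}"#).unwrap();
/// assert_eq!(config.endpoint_pattern, "/v1/models/{model}/status");
/// assert_eq!(config.expected_status, "ready");
/// ```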
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelStatusConfig {
    /// Endpoint pattern with `{model}` placeholder
    #[serde(default = "default_status_endpoint")]
    pub endpoint_pattern: String,

    /// Models to check status for
    pub models: Vec<String>,

    /// Expected status value (e.g., "ready", "loaded")
    #[serde(default = "default_expected_status")]
    pub expected_status: String,

    /// JSON path to status field (supports dot notation, e.g., "state.loaded")
    #[serde(default = "default_status_field")]
    pub status_field: String,

    /// Timeout for status request in seconds
    #[serde(default = "default_status_timeout")]
    pub timeout_secs: u64,
}

/// Configuration for queue depth monitoring
///
/// Monitors queue depth to detect overloaded backends before they
/// start timing out or returning errors.
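///
/// # Example
///
/// A sketch with the two required thresholds plus a header source; the module
/// path and header name are illustrative assumptions.
///
/// ```
/// use sentinel_common::inference::QueueDepthConfig;
///
/// let json = r#"{
///     "header": "x-queue-depth",
///     "degraded_threshold": 50,
///     "unhealthy_threshold": 200
/// }"#;
/// let config: QueueDepthConfig = serde_json::from_str(json).unwrap();
/// assert_eq!(config.timeout_secs, 5); // serde default
/// ```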
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct QueueDepthConfig {
    /// Header containing queue depth (e.g., "x-queue-depth")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub header: Option<String>,

    /// JSON field in response body containing queue depth
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub body_field: Option<String>,

    /// Endpoint to query for queue info (defaults to the models endpoint)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub endpoint: Option<String>,

    /// Mark as degraded if queue depth exceeds this threshold
    pub degraded_threshold: u64,

    /// Mark as unhealthy if queue depth exceeds this threshold
    pub unhealthy_threshold: u64,

    /// Timeout for queue check in seconds
    #[serde(default = "default_queue_timeout")]
    pub timeout_secs: u64,
}

/// Configuration for cold model detection
///
/// Tracks request latency to detect when models have gone cold after
/// idle periods. This is a passive check that observes actual requests
/// rather than sending probes.
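///
/// # Example
///
/// Every field has a default, so an empty JSON object deserializes; the
/// module path is an illustrative assumption.
///
/// ```
/// use sentinel_common::inference::WarmthDetectionConfig;
///
/// let config: WarmthDetectionConfig = serde_json::from_str("{}").unwrap();
/// assert_eq!(config.sample_size, 10);
/// assert_eq!(config.idle_cold_timeout_secs, 300); // 5 minutes
/// ```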
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmthDetectionConfig {
    /// Number of requests to sample for baseline latency
    #[serde(default = "default_warmth_sample_size")]
    pub sample_size: u32,

    /// Multiplier for cold detection (latency > baseline * multiplier = cold)
    #[serde(default = "default_cold_threshold_multiplier")]
    pub cold_threshold_multiplier: f64,

    /// Time after which a model is considered potentially cold (seconds)
    #[serde(default = "default_idle_cold_timeout")]
    pub idle_cold_timeout_secs: u64,

    /// Action to take when a cold model is detected
    #[serde(default)]
    pub cold_action: ColdModelAction,
}
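
// Manual `PartialEq`/`Eq`: `f64` does not implement `Eq`, so
// `cold_threshold_multiplier` is compared by bit pattern via `to_bits()`.
// This keeps equality reflexive even for NaN values, at the cost of treating
// `0.0` and `-0.0` as unequal.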
impl PartialEq for WarmthDetectionConfig {
    fn eq(&self, other: &Self) -> bool {
        self.sample_size == other.sample_size
            && self.cold_threshold_multiplier.to_bits() == other.cold_threshold_multiplier.to_bits()
            && self.idle_cold_timeout_secs == other.idle_cold_timeout_secs
            && self.cold_action == other.cold_action
    }
}

impl Eq for WarmthDetectionConfig {}

/// Action to take when a cold model is detected
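///
/// Serialized with `rename_all = "snake_case"`, so configuration files use
/// the strings `"log_only"`, `"mark_degraded"`, and `"mark_unhealthy"`.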
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ColdModelAction {
    /// Log the cold start but continue serving (observability only)
    #[default]
    LogOnly,
    /// Mark as degraded (lower weight in load balancing)
    MarkDegraded,
    /// Mark as unhealthy until warmed up
    MarkUnhealthy,
}

// Default value functions

fn default_probe_endpoint() -> String {
    "/v1/completions".to_string()
}

fn default_probe_prompt() -> String {
    ".".to_string()
}

fn default_probe_max_tokens() -> u32 {
    1
}

fn default_probe_timeout() -> u64 {
    30
}

fn default_status_endpoint() -> String {
    "/v1/models/{model}/status".to_string()
}

fn default_expected_status() -> String {
    "ready".to_string()
}

fn default_status_field() -> String {
    "status".to_string()
}

fn default_status_timeout() -> u64 {
    5
}

fn default_queue_timeout() -> u64 {
    5
}

fn default_warmth_sample_size() -> u32 {
    10
}

fn default_cold_threshold_multiplier() -> f64 {
    3.0
}

fn default_idle_cold_timeout() -> u64 {
    300 // 5 minutes
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_readiness_config_defaults() {
        let config: InferenceReadinessConfig = serde_json::from_str("{}").unwrap();
        assert!(config.inference_probe.is_none());
        assert!(config.model_status.is_none());
        assert!(config.queue_depth.is_none());
        assert!(config.warmth_detection.is_none());
    }

    #[test]
    fn test_inference_probe_config_defaults() {
        let json = r#"{"model": "gpt-4"}"#;
        let config: InferenceProbeConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.endpoint, "/v1/completions");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.prompt, ".");
        assert_eq!(config.max_tokens, 1);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.max_latency_ms.is_none());
    }

    #[test]
    fn test_model_status_config_defaults() {
        let json = r#"{"models": ["gpt-4", "gpt-3.5-turbo"]}"#;
        let config: ModelStatusConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.endpoint_pattern, "/v1/models/{model}/status");
        assert_eq!(config.models, vec!["gpt-4", "gpt-3.5-turbo"]);
        assert_eq!(config.expected_status, "ready");
        assert_eq!(config.status_field, "status");
        assert_eq!(config.timeout_secs, 5);
    }

    #[test]
    fn test_queue_depth_config() {
        let json = r#"{
            "header": "x-queue-depth",
            "degraded_threshold": 50,
            "unhealthy_threshold": 200
        }"#;
        let config: QueueDepthConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.header, Some("x-queue-depth".to_string()));
        assert!(config.body_field.is_none());
        assert_eq!(config.degraded_threshold, 50);
        assert_eq!(config.unhealthy_threshold, 200);
        assert_eq!(config.timeout_secs, 5);
    }

    #[test]
    fn test_warmth_detection_defaults() {
        let json = "{}";
        let config: WarmthDetectionConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.sample_size, 10);
        assert!((config.cold_threshold_multiplier - 3.0).abs() < f64::EPSILON);
        assert_eq!(config.idle_cold_timeout_secs, 300);
        assert_eq!(config.cold_action, ColdModelAction::LogOnly);
    }

    #[test]
    fn test_cold_model_action_serialization() {
        assert_eq!(
            serde_json::to_string(&ColdModelAction::LogOnly).unwrap(),
            r#""log_only""#
        );
        assert_eq!(
            serde_json::to_string(&ColdModelAction::MarkDegraded).unwrap(),
            r#""mark_degraded""#
        );
        assert_eq!(
            serde_json::to_string(&ColdModelAction::MarkUnhealthy).unwrap(),
            r#""mark_unhealthy""#
        );
    }

    #[test]
    fn test_full_config_roundtrip() {
        let config = InferenceReadinessConfig {
            inference_probe: Some(InferenceProbeConfig {
                endpoint: "/v1/completions".to_string(),
                model: "gpt-4".to_string(),
                prompt: ".".to_string(),
                max_tokens: 1,
                timeout_secs: 30,
                max_latency_ms: Some(5000),
            }),
            model_status: None,
            queue_depth: Some(QueueDepthConfig {
                header: Some("x-queue-depth".to_string()),
                body_field: None,
                endpoint: None,
                degraded_threshold: 50,
                unhealthy_threshold: 200,
                timeout_secs: 5,
            }),
            warmth_detection: Some(WarmthDetectionConfig {
                sample_size: 10,
                cold_threshold_multiplier: 3.0,
                idle_cold_timeout_secs: 300,
                cold_action: ColdModelAction::MarkDegraded,
            }),
        };

        let json = serde_json::to_string(&config).unwrap();
        let parsed: InferenceReadinessConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }
}