// inference_core/runtime.rs
1//! Runtime / transport / provider taxonomy and per-runtime configuration.
2//!
3//! Doc references: §3.1 (backend taxonomy), §5.4 (`TransportKind` /
4//! `ProviderKind` enums), §10.5 (feature flags).
5
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
/// Identifies the runtime *backend* that hosts a model.
///
/// Maps 1:1 to the per-runtime crates listed in §10.1. `Custom(String)`
/// is the escape hatch third-party runtimes use until they're added to
/// the canonical enum.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeKind {
    /// vLLM (local Python serving stack; see `RuntimeConfig::Vllm`).
    Vllm,
    /// NVIDIA TensorRT.
    TensorRt,
    /// ONNX Runtime.
    Ort,
    /// Candle (Rust-native inference).
    Candle,
    /// Direct CUDA kernels via the `cudarc` binding.
    Cudarc,
    /// mistral.rs runtime.
    MistralRs,
    /// Locally-hosted Python runtime without a Rust binding (e.g. XTTS,
    /// Bark, diffusers). Doc §2.6. The payload names the concrete
    /// Python runtime.
    Python(String),
    /// Hosted OpenAI API (remote — see `is_remote`).
    OpenAi,
    /// Hosted Anthropic API (remote).
    Anthropic,
    /// Hosted Google Gemini API (remote).
    Gemini,
    /// LiteLLM proxy/gateway (remote).
    LiteLlm,
    /// Third-party runtime not yet in the canonical enum; treated as
    /// local by `is_remote` and `TransportKind::from`.
    Custom(String),
}
33
34impl RuntimeKind {
35    pub fn is_remote(&self) -> bool {
36        matches!(
37            self,
38            RuntimeKind::OpenAi | RuntimeKind::Anthropic | RuntimeKind::Gemini | RuntimeKind::LiteLlm
39        )
40    }
41
42    pub fn is_local(&self) -> bool {
43        !self.is_remote()
44    }
45}
46
/// Where the runtime executes — local GPU vs remote network. Read by
/// `PlacementActor` and the worker-spawning logic to decide what kind of
/// `WorkerActor` to spin up. Doc §5.4.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum TransportKind {
    /// Executes on a GPU attached to this host.
    LocalGpu,
    /// Executes behind a provider's network API; `provider` identifies
    /// which one (internally tagged as `{"kind": "remote_network", ...}`).
    RemoteNetwork { provider: ProviderKind },
}
56
/// Remote API provider behind a `TransportKind::RemoteNetwork`
/// transport. Mirrors the remote variants of `RuntimeKind`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    OpenAi,
    Anthropic,
    Gemini,
    LiteLlm,
    /// Escape hatch for providers not in the canonical list.
    /// NOTE(review): nothing in this file constructs this variant —
    /// `TransportKind::from` maps `RuntimeKind::Custom` to `LocalGpu`.
    Custom(String),
}
66
67impl From<&RuntimeKind> for TransportKind {
68    fn from(kind: &RuntimeKind) -> Self {
69        match kind {
70            RuntimeKind::OpenAi => Self::RemoteNetwork {
71                provider: ProviderKind::OpenAi,
72            },
73            RuntimeKind::Anthropic => Self::RemoteNetwork {
74                provider: ProviderKind::Anthropic,
75            },
76            RuntimeKind::Gemini => Self::RemoteNetwork {
77                provider: ProviderKind::Gemini,
78            },
79            RuntimeKind::LiteLlm => Self::RemoteNetwork {
80                provider: ProviderKind::LiteLlm,
81            },
82            _ => Self::LocalGpu,
83        }
84    }
85}
86
/// Per-deployment runtime configuration. The `runtime` discriminator
/// drives both the backend selection and the shape of the inner config
/// blob. Per-runtime crates each contribute one variant or expose their
/// own `RuntimeConfig`-shaped struct that can be wrapped in `Custom`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "runtime", rename_all = "snake_case")]
pub enum RuntimeConfig {
    /// vLLM (local Python). Body intentionally opaque here — the real
    /// shape lives in `inference-runtime-vllm` and is parsed lazily.
    Vllm(serde_json::Value),
    /// Opaque TensorRT config blob; parsed by its runtime crate.
    TensorRt(serde_json::Value),
    /// Opaque ONNX Runtime config blob.
    Ort(serde_json::Value),
    /// Opaque Candle config blob.
    Candle(serde_json::Value),
    /// Opaque cudarc config blob.
    Cudarc(serde_json::Value),
    /// Opaque mistral.rs config blob.
    MistralRs(serde_json::Value),
    /// Remote OpenAI / Azure OpenAI. Concrete shape in
    /// `inference-runtime-openai::OpenAiConfig`.
    OpenAi(serde_json::Value),
    /// Remote Anthropic config blob.
    Anthropic(serde_json::Value),
    /// Remote Gemini config blob.
    Gemini(serde_json::Value),
    /// Remote LiteLLM config blob.
    LiteLlm(serde_json::Value),
    /// Custom backend (third-party runtime crate).
    Custom {
        // Identifier echoed into `RuntimeKind::Custom` by `runtime_kind`.
        kind: String,
        config: serde_json::Value,
    },
}
114
115impl RuntimeConfig {
116    pub fn runtime_kind(&self) -> RuntimeKind {
117        match self {
118            RuntimeConfig::Vllm(_) => RuntimeKind::Vllm,
119            RuntimeConfig::TensorRt(_) => RuntimeKind::TensorRt,
120            RuntimeConfig::Ort(_) => RuntimeKind::Ort,
121            RuntimeConfig::Candle(_) => RuntimeKind::Candle,
122            RuntimeConfig::Cudarc(_) => RuntimeKind::Cudarc,
123            RuntimeConfig::MistralRs(_) => RuntimeKind::MistralRs,
124            RuntimeConfig::OpenAi(_) => RuntimeKind::OpenAi,
125            RuntimeConfig::Anthropic(_) => RuntimeKind::Anthropic,
126            RuntimeConfig::Gemini(_) => RuntimeKind::Gemini,
127            RuntimeConfig::LiteLlm(_) => RuntimeKind::LiteLlm,
128            RuntimeConfig::Custom { kind, .. } => RuntimeKind::Custom(kind.clone()),
129        }
130    }
131
132    pub fn transport_kind(&self) -> TransportKind {
133        TransportKind::from(&self.runtime_kind())
134    }
135}
136
/// Circuit-breaker config (doc §3.5, §12.2). One per `(provider,
/// endpoint)`; opens after sustained failures, half-opens after the
/// configured duration to permit a probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Number of failures that trips the breaker open.
    pub failure_threshold: u32,
    /// How long the breaker stays open before moving to half-open.
    /// Serialized as an integer millisecond count via
    /// `humantime_serde_ms`.
    ///
    /// NOTE(review): the ms-serde module's doc cites a TOML key
    /// `open_duration_ms`, but this field serializes under the key
    /// `open_duration` — confirm whether
    /// `#[serde(rename = "open_duration_ms")]` was intended.
    #[serde(with = "humantime_serde_ms")]
    pub open_duration: Duration,
    /// Maximum number of probe requests permitted while half-open.
    pub half_open_max_probes: u32,
}
147
148impl Default for CircuitBreakerConfig {
149    fn default() -> Self {
150        Self {
151            failure_threshold: 10,
152            open_duration: Duration::from_secs(30),
153            half_open_max_probes: 1,
154        }
155    }
156}
157
/// Jitter strategy applied to a backoff/retry delay. The consumer is
/// not visible in this file; variant names appear to follow the common
/// backoff-jitter taxonomy (none / equal / full) — TODO confirm against
/// the retry implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum JitterKind {
    None,
    Equal,
    Full,
}
165
166/// `Duration` (de)serialization in milliseconds — chosen so the doc's
167/// TOML examples (`open_duration_ms = 30_000`) round-trip naturally.
168pub(crate) mod humantime_serde_ms {
169    use std::time::Duration;
170
171    use serde::{Deserialize, Deserializer, Serialize, Serializer};
172
173    pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
174    where
175        S: Serializer,
176    {
177        (d.as_millis() as u64).serialize(s)
178    }
179
180    pub fn deserialize<'de, D>(d: D) -> Result<Duration, D::Error>
181    where
182        D: Deserializer<'de>,
183    {
184        let ms = u64::deserialize(d)?;
185        Ok(Duration::from_millis(ms))
186    }
187}