//! Runtime / transport / provider taxonomy and per-runtime configuration.
//!
//! Doc references: §3.1 (backend taxonomy), §5.4 (`TransportKind` /
//! `ProviderKind` enums), §10.5 (feature flags).

use std::time::Duration;

use serde::{Deserialize, Serialize};
/// Identifies the runtime *backend* that hosts a model.
///
/// Maps 1:1 to the per-runtime crates listed in §10.1. `Custom(String)`
/// is the escape hatch third-party runtimes use until they're added to
/// the canonical enum.
///
/// The local/remote split over these variants is defined by
/// [`RuntimeKind::is_remote`] / [`RuntimeKind::is_local`] below.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum RuntimeKind {
    /// vLLM backend (local per `is_local`).
    Vllm,
    /// TensorRT backend.
    TensorRt,
    /// `ort` (ONNX Runtime) backend.
    Ort,
    /// Candle backend.
    Candle,
    /// cudarc backend.
    Cudarc,
    /// mistral.rs backend.
    MistralRs,
    /// Locally-hosted Python runtime without a Rust binding (e.g. XTTS,
    /// Bark, diffusers). Doc §2.6. Payload names the runtime.
    Python(String),
    /// Hosted OpenAI API (remote).
    OpenAi,
    /// Hosted Anthropic API (remote).
    Anthropic,
    /// Hosted Google Gemini API (remote).
    Gemini,
    /// LiteLLM proxy (remote).
    LiteLlm,
    /// Third-party runtime not yet in the canonical enum; treated as
    /// local by `is_local`. Payload names the runtime.
    Custom(String),
}

35impl RuntimeKind {
36    pub fn is_remote(&self) -> bool {
37        matches!(
38            self,
39            RuntimeKind::OpenAi | RuntimeKind::Anthropic | RuntimeKind::Gemini | RuntimeKind::LiteLlm
40        )
41    }
42
43    pub fn is_local(&self) -> bool {
44        !self.is_remote()
45    }
46}
47
/// Where the runtime executes — local GPU vs remote network. Read by
/// `PlacementActor` and the worker-spawning logic to decide what kind of
/// `WorkerActor` to spin up. Doc §5.4.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
// Internally tagged: serializes as `{"kind": "local_gpu"}` or
// `{"kind": "remote_network", "provider": ...}`.
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum TransportKind {
    /// Runtime executes on hardware attached to this host.
    LocalGpu,
    /// Runtime is reached over the network; `provider` identifies the
    /// hosted API (see [`ProviderKind`]).
    RemoteNetwork { provider: ProviderKind },
}

/// Hosted API provider behind a [`TransportKind::RemoteNetwork`]
/// transport. Mirrors the remote subset of [`RuntimeKind`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ProviderKind {
    OpenAi,
    Anthropic,
    Gemini,
    LiteLlm,
    /// Escape hatch for providers not yet in the canonical enum;
    /// payload names the provider.
    Custom(String),
}

70impl From<&RuntimeKind> for TransportKind {
71    fn from(kind: &RuntimeKind) -> Self {
72        match kind {
73            RuntimeKind::OpenAi => Self::RemoteNetwork {
74                provider: ProviderKind::OpenAi,
75            },
76            RuntimeKind::Anthropic => Self::RemoteNetwork {
77                provider: ProviderKind::Anthropic,
78            },
79            RuntimeKind::Gemini => Self::RemoteNetwork {
80                provider: ProviderKind::Gemini,
81            },
82            RuntimeKind::LiteLlm => Self::RemoteNetwork {
83                provider: ProviderKind::LiteLlm,
84            },
85            _ => Self::LocalGpu,
86        }
87    }
88}
89
/// Per-deployment runtime configuration. The `runtime` discriminator
/// drives both the backend selection and the shape of the inner config
/// blob. Per-runtime crates each contribute one variant or expose their
/// own `RuntimeConfig`-shaped struct that can be wrapped in `Custom`.
///
/// Every payload is an opaque `serde_json::Value`: this crate never
/// interprets it; the owning runtime crate parses it lazily.
// NOTE(review): unlike the other public enums in this file, this one is
// not `#[non_exhaustive]` — confirm that is intentional.
// NOTE(review): `RuntimeKind::Python(_)` has no corresponding variant
// here; presumably Python runtimes go through `Custom` — verify.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "runtime", rename_all = "snake_case")]
pub enum RuntimeConfig {
    /// vLLM (local Python). Body intentionally opaque here — the real
    /// shape lives in `inference-runtime-vllm` and is parsed lazily.
    Vllm(serde_json::Value),
    /// TensorRT backend config blob.
    TensorRt(serde_json::Value),
    /// `ort` (ONNX Runtime) backend config blob.
    Ort(serde_json::Value),
    /// Candle backend config blob.
    Candle(serde_json::Value),
    /// cudarc backend config blob.
    Cudarc(serde_json::Value),
    /// mistral.rs backend config blob.
    MistralRs(serde_json::Value),
    /// Remote OpenAI / Azure OpenAI. Concrete shape in
    /// `inference-runtime-openai::OpenAiConfig`.
    OpenAi(serde_json::Value),
    /// Remote Anthropic API config blob.
    Anthropic(serde_json::Value),
    /// Remote Gemini API config blob.
    Gemini(serde_json::Value),
    /// Remote LiteLLM proxy config blob.
    LiteLlm(serde_json::Value),
    /// Custom backend (third-party runtime crate).
    Custom {
        /// Name matched against `RuntimeKind::Custom`.
        kind: String,
        /// Opaque config parsed by the third-party crate.
        config: serde_json::Value,
    },
}

118impl RuntimeConfig {
119    pub fn runtime_kind(&self) -> RuntimeKind {
120        match self {
121            RuntimeConfig::Vllm(_) => RuntimeKind::Vllm,
122            RuntimeConfig::TensorRt(_) => RuntimeKind::TensorRt,
123            RuntimeConfig::Ort(_) => RuntimeKind::Ort,
124            RuntimeConfig::Candle(_) => RuntimeKind::Candle,
125            RuntimeConfig::Cudarc(_) => RuntimeKind::Cudarc,
126            RuntimeConfig::MistralRs(_) => RuntimeKind::MistralRs,
127            RuntimeConfig::OpenAi(_) => RuntimeKind::OpenAi,
128            RuntimeConfig::Anthropic(_) => RuntimeKind::Anthropic,
129            RuntimeConfig::Gemini(_) => RuntimeKind::Gemini,
130            RuntimeConfig::LiteLlm(_) => RuntimeKind::LiteLlm,
131            RuntimeConfig::Custom { kind, .. } => RuntimeKind::Custom(kind.clone()),
132        }
133    }
134
135    pub fn transport_kind(&self) -> TransportKind {
136        TransportKind::from(&self.runtime_kind())
137    }
138}
139
/// Circuit-breaker config (doc §3.5, §12.2). One per `(provider,
/// endpoint)`; opens after sustained failures, half-opens after the
/// configured duration to permit a probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Failure count that trips the breaker open (default 10).
    pub failure_threshold: u32,
    /// How long the breaker stays open before half-opening
    /// (default 30 s). Serialized as integer milliseconds via
    /// `humantime_serde_ms`.
    // NOTE(review): the serde module's doc cites a TOML key
    // `open_duration_ms`, but this field serializes under
    // `open_duration` — confirm the intended key name.
    #[serde(with = "humantime_serde_ms")]
    pub open_duration: Duration,
    /// Maximum probes permitted while half-open (default 1).
    pub half_open_max_probes: u32,
}

151impl Default for CircuitBreakerConfig {
152    fn default() -> Self {
153        Self {
154            failure_threshold: 10,
155            open_duration: Duration::from_secs(30),
156            half_open_max_probes: 1,
157        }
158    }
159}
160
/// Jitter variants for backoff delays. The variant names follow the
/// standard none / equal / full jitter taxonomy; the actual delay
/// computation lives in the retry logic, not in this module —
/// confirm semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum JitterKind {
    /// No jitter applied.
    None,
    /// "Equal" jitter (presumably half fixed, half random — verify
    /// against the retry implementation).
    Equal,
    /// "Full" jitter (presumably fully randomized delay — verify
    /// against the retry implementation).
    Full,
}

170/// `Duration` (de)serialization in milliseconds — chosen so the doc's
171/// TOML examples (`open_duration_ms = 30_000`) round-trip naturally.
172pub(crate) mod humantime_serde_ms {
173    use std::time::Duration;
174
175    use serde::{Deserialize, Deserializer, Serialize, Serializer};
176
177    pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
178    where
179        S: Serializer,
180    {
181        (d.as_millis() as u64).serialize(s)
182    }
183
184    pub fn deserialize<'de, D>(d: D) -> Result<Duration, D::Error>
185    where
186        D: Deserializer<'de>,
187    {
188        let ms = u64::deserialize(d)?;
189        Ok(Duration::from_millis(ms))
190    }
191}