use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Identifies the runtime backend a model is served by.
///
/// Serialized with serde's default (externally tagged) enum representation and
/// `snake_case` variant names, e.g. `"vllm"`, `"tensor_rt"`, `"mistral_rs"`,
/// `"open_ai"`, `"lite_llm"`. Data-carrying variants serialize as single-key
/// objects such as `{"python": "..."}` or `{"custom": "..."}`.
///
/// Variant order is significant for non-self-describing serde formats; append new
/// variants rather than reordering.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum RuntimeKind {
    Vllm,
    TensorRt,
    Ort,
    Candle,
    Cudarc,
    MistralRs,
    /// Python-based runtime named by the contained string
    /// (NOTE(review): exact meaning of the string is defined by callers — confirm).
    Python(String),
    // The four variants below are the ones `RuntimeKind::is_remote` reports as remote.
    OpenAi,
    Anthropic,
    Gemini,
    LiteLlm,
    /// Any other runtime, identified by a free-form name.
    Custom(String),
}
impl RuntimeKind {
pub fn is_remote(&self) -> bool {
matches!(
self,
RuntimeKind::OpenAi | RuntimeKind::Anthropic | RuntimeKind::Gemini | RuntimeKind::LiteLlm
)
}
pub fn is_local(&self) -> bool {
!self.is_remote()
}
}
/// Transport class of a runtime: local execution or a remote network provider.
///
/// Internally tagged on `"kind"` with `snake_case` names, e.g.
/// `{"kind": "local_gpu"}` or `{"kind": "remote_network", "provider": "open_ai"}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum TransportKind {
    /// Runtime executes locally.
    LocalGpu,
    /// Runtime is reached over the network through the given provider.
    RemoteNetwork { provider: ProviderKind },
}
/// Remote API provider carried by `TransportKind::RemoteNetwork`.
///
/// Serialized externally tagged with `snake_case` names (`"open_ai"`,
/// `"anthropic"`, `"gemini"`, `"lite_llm"`, or `{"custom": "..."}`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ProviderKind {
    OpenAi,
    Anthropic,
    Gemini,
    LiteLlm,
    /// Provider identified by a free-form name.
    Custom(String),
}
/// Classifies a runtime by how requests reach it.
///
/// Hosted-API runtimes map to `TransportKind::RemoteNetwork` with the matching
/// `ProviderKind`. Every other runtime, including `Python(_)` and `Custom(_)`,
/// maps to `TransportKind::LocalGpu`.
///
/// The match is deliberately exhaustive (no `_` arm). Adding a `RuntimeKind`
/// variant then forces an explicit transport decision here, instead of silently
/// defaulting to `LocalGpu`.
///
/// NOTE(review): `ProviderKind::Custom` is never produced by this conversion;
/// `RuntimeKind::Custom` is treated as local. Confirm this is intended.
impl From<&RuntimeKind> for TransportKind {
    fn from(kind: &RuntimeKind) -> Self {
        match kind {
            RuntimeKind::OpenAi => Self::RemoteNetwork {
                provider: ProviderKind::OpenAi,
            },
            RuntimeKind::Anthropic => Self::RemoteNetwork {
                provider: ProviderKind::Anthropic,
            },
            RuntimeKind::Gemini => Self::RemoteNetwork {
                provider: ProviderKind::Gemini,
            },
            RuntimeKind::LiteLlm => Self::RemoteNetwork {
                provider: ProviderKind::LiteLlm,
            },
            RuntimeKind::Vllm
            | RuntimeKind::TensorRt
            | RuntimeKind::Ort
            | RuntimeKind::Candle
            | RuntimeKind::Cudarc
            | RuntimeKind::MistralRs
            | RuntimeKind::Python(_)
            | RuntimeKind::Custom(_) => Self::LocalGpu,
        }
    }
}
/// Per-runtime configuration, internally tagged on `"runtime"`.
///
/// Built-in variants carry their settings as an opaque JSON value, so the wire
/// form is `{"runtime": "vllm", ...settings}`. `Custom` carries a runtime name
/// plus settings: `{"runtime": "custom", "kind": "...", "config": {...}}`.
///
/// NOTE(review): serde can only serialize internally tagged newtype variants
/// whose payload is map-like. A non-object `serde_json::Value` (string, number,
/// array) is expected to fail at serialization time. Confirm configs are always
/// JSON objects.
/// NOTE(review): there is no variant for `RuntimeKind::Python`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "runtime", rename_all = "snake_case")]
pub enum RuntimeConfig {
    Vllm(serde_json::Value),
    TensorRt(serde_json::Value),
    Ort(serde_json::Value),
    Candle(serde_json::Value),
    Cudarc(serde_json::Value),
    MistralRs(serde_json::Value),
    OpenAi(serde_json::Value),
    Anthropic(serde_json::Value),
    Gemini(serde_json::Value),
    LiteLlm(serde_json::Value),
    /// Runtime not built into this enum.
    Custom {
        /// Name of the runtime; becomes `RuntimeKind::Custom(kind)`.
        kind: String,
        /// Opaque runtime-specific settings.
        config: serde_json::Value,
    },
}
impl RuntimeConfig {
    /// Returns the `RuntimeKind` this configuration targets.
    ///
    /// Each variant maps to the same-named `RuntimeKind`. For `Custom`, the
    /// runtime name is cloned into `RuntimeKind::Custom`.
    pub fn runtime_kind(&self) -> RuntimeKind {
        match self {
            Self::Vllm(_) => RuntimeKind::Vllm,
            Self::TensorRt(_) => RuntimeKind::TensorRt,
            Self::Ort(_) => RuntimeKind::Ort,
            Self::Candle(_) => RuntimeKind::Candle,
            Self::Cudarc(_) => RuntimeKind::Cudarc,
            Self::MistralRs(_) => RuntimeKind::MistralRs,
            Self::OpenAi(_) => RuntimeKind::OpenAi,
            Self::Anthropic(_) => RuntimeKind::Anthropic,
            Self::Gemini(_) => RuntimeKind::Gemini,
            Self::LiteLlm(_) => RuntimeKind::LiteLlm,
            Self::Custom { kind, .. } => RuntimeKind::Custom(kind.clone()),
        }
    }

    /// Returns the transport class for this configuration.
    ///
    /// Derived from [`runtime_kind`](Self::runtime_kind) via
    /// `From<&RuntimeKind> for TransportKind`.
    pub fn transport_kind(&self) -> TransportKind {
        let kind = self.runtime_kind();
        (&kind).into()
    }
}
/// Tuning parameters for a circuit breaker.
///
/// All fields are required when deserializing. `open_duration` is (de)serialized
/// as an integer number of milliseconds through the `humantime_serde_ms` adapter.
/// Field order is kept stable for order-sensitive serde formats.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Failure count at which the breaker opens. NOTE(review): whether failures
    /// are consecutive or windowed is decided by the breaker implementation, not here.
    pub failure_threshold: u32,
    /// How long the breaker stays open, presumably before moving to half-open.
    /// Wire format: whole milliseconds.
    #[serde(with = "humantime_serde_ms")]
    pub open_duration: Duration,
    /// Upper bound on probe requests while half-open (per the field name; confirm).
    pub half_open_max_probes: u32,
}
impl Default for CircuitBreakerConfig {
    /// Default tuning: a failure threshold of 10, a 30-second open window, and a
    /// single half-open probe.
    fn default() -> Self {
        let failure_threshold = 10;
        let open_duration = Duration::from_secs(30);
        let half_open_max_probes = 1;

        Self {
            failure_threshold,
            open_duration,
            half_open_max_probes,
        }
    }
}
/// Jitter strategy selector, serialized as `"none"`, `"equal"` or `"full"`.
///
/// NOTE(review): no consumer is visible here. The names presumably refer to the
/// usual "equal jitter" / "full jitter" strategies for randomized backoff;
/// confirm against the backoff implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum JitterKind {
    /// No randomization.
    None,
    Equal,
    Full,
}
pub(crate) mod humantime_serde_ms {
use std::time::Duration;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
(d.as_millis() as u64).serialize(s)
}
pub fn deserialize<'de, D>(d: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let ms = u64::deserialize(d)?;
Ok(Duration::from_millis(ms))
}
}