use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level deployment configuration: project identity, state-backend
/// settings, the pods to deploy, and optional guardrails.
///
/// Derives `PartialEq` but not `Eq` because `GuardrailsConfig` contains an
/// `f64` (`max_hourly_cost`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct DeployConfig {
    /// Project name, environment and placement settings (required).
    pub project: ProjectConfig,
    /// State-backend settings (required).
    pub state: StateConfig,
    /// Pods to deploy (required; no serde default).
    pub pods: Vec<PodConfig>,
    /// Optional limits; `None` when the key is absent.
    #[serde(default)]
    pub guardrails: Option<GuardrailsConfig>,
}
/// Project-level identity and placement settings.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProjectConfig {
    /// Project name (required). Used as the leading component of generated
    /// names such as `DeployConfig::qualified_name` and `PodConfig::full_name`.
    pub name: String,
    /// Environment label; defaults to `"dev"` when omitted.
    #[serde(default = "default_environment")]
    pub environment: String,
    /// Optional region; `None` when omitted.
    #[serde(default)]
    pub region: Option<String>,
    /// Cloud type; defaults to `CloudType::Secure` when omitted.
    #[serde(default)]
    pub cloud_type: CloudType,
    /// Compute type; defaults to `ComputeType::Gpu` when omitted.
    #[serde(default)]
    pub compute_type: ComputeType,
}
/// Settings for where state is stored.
///
/// Every location field is optional at the serde level. This type does not
/// enforce which fields a given backend needs.
/// NOTE(review): presumably `bucket`/`prefix`/`region` apply to
/// `StateBackend::S3` and `path` to `StateBackend::Local` — confirm against
/// the backend implementation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct StateConfig {
    /// Backend kind (required).
    pub backend: StateBackend,
    /// Optional bucket name; `None` when omitted.
    #[serde(default)]
    pub bucket: Option<String>,
    /// Optional key prefix; `None` when omitted.
    #[serde(default)]
    pub prefix: Option<String>,
    /// Optional region; `None` when omitted.
    #[serde(default)]
    pub region: Option<String>,
    /// Optional path; `None` when omitted.
    #[serde(default)]
    pub path: Option<String>,
}
/// Backend kind selected by `StateConfig::backend`.
///
/// Serialized in lowercase (`"local"`, `"s3"`). `Default` is `Local`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum StateBackend {
    /// Serialized as `"local"`; the `Default` variant.
    #[default]
    Local,
    /// Serialized as `"s3"`.
    S3,
}
/// Cloud type selected by `ProjectConfig::cloud_type`.
///
/// Serialized in uppercase (`"SECURE"`, `"COMMUNITY"`). `Default` is `Secure`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum CloudType {
    /// Serialized as `"SECURE"`; the `Default` variant.
    #[default]
    Secure,
    /// Serialized as `"COMMUNITY"`.
    Community,
}
/// Compute type selected by `ProjectConfig::compute_type`.
///
/// Serialized in uppercase (`"GPU"`, `"CPU"`). `Default` is `Gpu`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum ComputeType {
    /// Serialized as `"GPU"`; the `Default` variant.
    #[default]
    Gpu,
    /// Serialized as `"CPU"`.
    Cpu,
}
/// A single pod to deploy.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct PodConfig {
    /// Pod name (required). Appended to `<project>-<environment>-` by
    /// `full_name`.
    pub name: String,
    /// GPU requirements (required).
    pub gpu: GpuConfig,
    /// Exposed ports, each written as a `PORT/PROTOCOL` string
    /// (e.g. `"8000/http"`); empty when omitted.
    #[serde(default)]
    pub ports: Vec<PortConfig>,
    /// Volume mounts; empty when omitted.
    #[serde(default)]
    pub volumes: Vec<VolumeConfig>,
    /// Container runtime settings (required).
    pub runtime: RuntimeConfig,
    /// Models associated with the pod; empty when omitted.
    #[serde(default)]
    pub models: Vec<ModelConfig>,
    /// Optional health check; `None` when omitted.
    #[serde(default)]
    pub health_check: Option<HealthCheckConfig>,
    /// Free-form key/value tags; empty when omitted.
    #[serde(default)]
    pub tags: HashMap<String, String>,
}
/// GPU requirements for a pod.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GpuConfig {
    /// GPU type identifier (required); serialized under the key `type`.
    #[serde(rename = "type")]
    pub gpu_type: String,
    /// Number of GPUs; defaults to 1 when omitted.
    #[serde(default = "default_gpu_count")]
    pub count: u32,
    /// Optional minimum VRAM, in GB per the field name; `None` when omitted.
    #[serde(default)]
    pub min_vram_gb: Option<u32>,
    /// Alternative GPU type identifiers; empty when omitted.
    /// NOTE(review): presumably tried when `gpu_type` is unavailable and gated
    /// by `GuardrailsConfig::allow_gpu_fallback` — confirm with the caller.
    #[serde(default)]
    pub fallback: Vec<String>,
}
/// A port exposed by a pod.
///
/// Serde does not (de)serialize this as a map of its fields. It goes through
/// the `TryFrom<String>` / `From<PortConfig> for String` conversions, so the
/// wire form is a single `PORT/PROTOCOL` string such as `"8000/http"`.
///
/// NOTE(review): because of that string form, `name` is always `None` after
/// deserialization and is dropped on serialization. Confirm this is
/// intended, or extend the string format to carry it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(try_from = "String", into = "String")]
pub struct PortConfig {
    /// Port number.
    pub port: u16,
    /// Protocol of the port.
    pub protocol: PortProtocol,
    /// Optional label. Not representable in the serde string form (see the
    /// note on this type).
    pub name: Option<String>,
}
/// Protocol of an exposed port. `Default` is `Http`.
///
/// This enum's own serde form is lowercase (`"tcp"`, `"http"`, `"https"`,
/// `"udp"`). `PortConfig` does not use it; `PortConfig` parses and renders
/// its protocol through its own string conversion.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum PortProtocol {
    /// Serialized as `"tcp"`.
    Tcp,
    /// Serialized as `"http"`; the `Default` variant.
    #[default]
    Http,
    /// Serialized as `"https"`.
    Https,
    /// Serialized as `"udp"`.
    Udp,
}
/// A volume attached to a pod.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct VolumeConfig {
    /// Volume name (required).
    pub name: String,
    /// Mount location (required). Presumably a path inside the container —
    /// confirm with the consumer.
    pub mount: String,
    /// Persistence flag; defaults to `true` when omitted.
    #[serde(default = "default_persistent")]
    pub persistent: bool,
    /// Optional size, in GB per the field name; `None` when omitted.
    #[serde(default)]
    pub size_gb: Option<u32>,
}
/// Container runtime settings for a pod.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RuntimeConfig {
    /// Image to run (required).
    pub image: String,
    /// Environment variables; empty when omitted.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Optional command. `None` when omitted, which is distinct from an
    /// explicit empty list.
    #[serde(default)]
    pub command: Option<Vec<String>>,
    /// Optional arguments. `None` when omitted, which is distinct from an
    /// explicit empty list.
    #[serde(default)]
    pub args: Option<Vec<String>>,
}
/// A model associated with a pod.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelConfig {
    /// Model identifier (required).
    pub id: String,
    /// Model source (required). There is no serde default here even though
    /// `ModelProvider` implements `Default`.
    pub provider: ModelProvider,
    /// Optional repository reference; `None` when omitted.
    #[serde(default)]
    pub repo: Option<String>,
    /// Optional load settings; `None` when omitted.
    #[serde(default)]
    pub load: Option<LoadConfig>,
    /// Optional list of component names; `None` when omitted.
    #[serde(default)]
    pub components: Option<Vec<String>>,
}
/// Source of a model (`ModelConfig::provider`).
///
/// Serialized in lowercase (`"huggingface"`, `"bundle"`, `"custom"`).
/// `Default` is `Huggingface`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ModelProvider {
    /// Serialized as `"huggingface"`; the `Default` variant.
    #[default]
    Huggingface,
    /// Serialized as `"bundle"`.
    Bundle,
    /// Serialized as `"custom"`.
    Custom,
}
/// Settings for loading a model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LoadConfig {
    /// Engine name (required).
    pub engine: String,
    /// Optional `quant` setting (presumably a quantization scheme — confirm);
    /// `None` when omitted.
    #[serde(default)]
    pub quant: Option<String>,
    /// Optional maximum sequence length; `None` when omitted.
    #[serde(default)]
    pub max_seq_len: Option<u32>,
    /// Additional options as arbitrary JSON values; empty when omitted.
    #[serde(default)]
    pub options: HashMap<String, serde_json::Value>,
}
/// Health-check settings for a pod.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HealthCheckConfig {
    /// Endpoint used by the health check (required).
    pub endpoint: String,
    /// Port used by the health check (required).
    pub port: u16,
    /// Interval, in seconds per the field name; defaults to 30.
    #[serde(default = "default_health_interval")]
    pub interval_secs: u32,
    /// Timeout, in seconds per the field name; defaults to 5.
    #[serde(default = "default_health_timeout")]
    pub timeout_secs: u32,
    /// Failure threshold; defaults to 3.
    #[serde(default = "default_health_threshold")]
    pub failure_threshold: u32,
}
/// Optional deployment limits.
///
/// This type only declares the limits; nothing visible here enforces them.
/// Only `PartialEq` (no `Eq`) is derived because `max_hourly_cost` is an `f64`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GuardrailsConfig {
    /// Optional hourly cost ceiling (currency/units are not specified here);
    /// `None` when omitted.
    #[serde(default)]
    pub max_hourly_cost: Option<f64>,
    /// Optional GPU cap; `None` when omitted.
    /// NOTE(review): presumably compared against `DeployConfig::total_gpus` —
    /// confirm with the caller.
    #[serde(default)]
    pub max_gpus: Option<u32>,
    /// Optional time-to-live, in hours per the field name; `None` when omitted.
    #[serde(default)]
    pub ttl_hours: Option<u32>,
    /// Whether GPU fallback is permitted; defaults to `false`.
    #[serde(default = "default_allow_fallback")]
    pub allow_gpu_fallback: bool,
}
/// Serde default for `GpuConfig::count`: one GPU.
const fn default_gpu_count() -> u32 {
    1
}
/// Serde default for `VolumeConfig::persistent`: volumes are persistent
/// unless the config says otherwise.
const fn default_persistent() -> bool {
    true
}
/// Serde default for `HealthCheckConfig::interval_secs`: 30.
const fn default_health_interval() -> u32 {
    30
}
/// Serde default for `HealthCheckConfig::timeout_secs`: 5.
const fn default_health_timeout() -> u32 {
    5
}
/// Serde default for `HealthCheckConfig::failure_threshold`: 3.
const fn default_health_threshold() -> u32 {
    3
}
/// Serde default for `GuardrailsConfig::allow_gpu_fallback`: disabled.
/// Same value as `bool::default()`; kept as a named function so the default
/// is explicit at the field.
const fn default_allow_fallback() -> bool {
    false
}
/// Serde default for `ProjectConfig::environment`: the `"dev"` environment.
fn default_environment() -> String {
    "dev".to_owned()
}
impl TryFrom<String> for PortConfig {
type Error = String;
fn try_from(s: String) -> Result<Self, Self::Error> {
Self::parse(&s)
}
}
impl From<PortConfig> for String {
    /// Renders a port as `PORT/PROTOCOL` with a lowercase protocol, which is
    /// the form `PortConfig::parse` accepts. `port.name` is not encoded and
    /// is therefore lost.
    fn from(port: PortConfig) -> Self {
        let suffix = match port.protocol {
            PortProtocol::Tcp => "tcp",
            PortProtocol::Http => "http",
            PortProtocol::Https => "https",
            PortProtocol::Udp => "udp",
        };
        format!("{}/{suffix}", port.port)
    }
}
impl PortConfig {
    /// Parses a `PORT/PROTOCOL` string such as `"8000/http"` or `"22/TCP"`.
    ///
    /// The protocol is lowercased before matching against `tcp`, `http`,
    /// `https` and `udp`. The returned config has `name: None`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - the input does not contain exactly one `/`,
    /// - the port part is not a valid `u16`, or
    /// - the protocol part (after lowercasing) is not a supported protocol.
    ///
    /// The port is validated before the protocol.
    pub fn parse(s: &str) -> Result<Self, String> {
        // Exactly one separator: `split_once` consumes the first `/`, and the
        // remainder must not contain another one.
        let (port_part, protocol_part) = match s.split_once('/') {
            Some((head, tail)) if !tail.contains('/') => (head, tail),
            _ => {
                return Err(format!(
                    "Invalid port format: {s}. Expected format: PORT/PROTOCOL"
                ));
            }
        };

        let port: u16 = port_part
            .parse()
            .map_err(|_| format!("Invalid port number: {port_part}"))?;

        let normalized = protocol_part.to_lowercase();
        let protocol = match normalized.as_str() {
            "tcp" => PortProtocol::Tcp,
            "http" => PortProtocol::Http,
            "https" => PortProtocol::Https,
            "udp" => PortProtocol::Udp,
            unknown => {
                return Err(format!(
                    "Invalid protocol: {unknown}. Expected: tcp, http, https, or udp"
                ));
            }
        };

        Ok(Self::new(port, protocol))
    }

    /// Creates an unnamed port entry (`name: None`).
    #[must_use]
    pub const fn new(port: u16, protocol: PortProtocol) -> Self {
        Self {
            port,
            protocol,
            name: None,
        }
    }
}
impl DeployConfig {
    /// Returns `"<project.name>-<project.environment>"`.
    #[must_use]
    pub fn qualified_name(&self) -> String {
        format!("{}-{}", self.project.name, self.project.environment)
    }

    /// Returns the total number of GPUs requested across all pods: the sum of
    /// every pod's `gpu.count`, or 0 when there are no pods.
    ///
    /// The sum saturates at `u32::MAX` instead of overflowing. The counts come
    /// straight from the deserialized config. A plain `sum()` would panic in
    /// debug builds and silently wrap in release builds, reporting a total far
    /// smaller than what was actually requested.
    #[must_use]
    pub fn total_gpus(&self) -> u32 {
        self.pods
            .iter()
            .fold(0_u32, |total, pod| total.saturating_add(pod.gpu.count))
    }

    /// Returns the pod names in declaration order, borrowed from `self`.
    #[must_use]
    pub fn pod_names(&self) -> Vec<&str> {
        self.pods.iter().map(|p| p.name.as_str()).collect()
    }
}
impl PodConfig {
#[must_use]
pub fn full_name(&self, project: &ProjectConfig) -> String {
format!("{}-{}-{}", project.name, project.environment, self.name)
}
#[must_use]
pub fn http_ports(&self) -> Vec<u16> {
self.ports
.iter()
.filter(|p| matches!(p.protocol, PortProtocol::Http | PortProtocol::Https))
.map(|p| p.port)
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_port_config_parse() {
        let port = PortConfig::parse("8000/http").expect("`8000/http` should parse");
        assert_eq!(port.port, 8000);
        assert_eq!(port.protocol, PortProtocol::Http);
        assert_eq!(port.name, None);
    }

    #[test]
    fn test_port_config_parse_tcp() {
        let port = PortConfig::parse("22/tcp").expect("`22/tcp` should parse");
        assert_eq!(port.port, 22);
        assert_eq!(port.protocol, PortProtocol::Tcp);
    }

    #[test]
    fn test_port_config_invalid() {
        assert!(PortConfig::parse("invalid").is_err());
    }

    /// `parse` lowercases the protocol before matching.
    #[test]
    fn test_port_config_protocol_is_case_insensitive() {
        let port = PortConfig::parse("443/HTTPS").expect("`443/HTTPS` should parse");
        assert_eq!(port.port, 443);
        assert_eq!(port.protocol, PortProtocol::Https);
    }

    /// Each rejection branch of `parse`: separator count, port range/format,
    /// and unknown protocol.
    #[test]
    fn test_port_config_rejects_bad_input() {
        assert!(PortConfig::parse("8000/http/extra").is_err());
        assert!(PortConfig::parse("70000/http").is_err());
        assert!(PortConfig::parse("/http").is_err());
        assert!(PortConfig::parse("8000/ftp").is_err());
        assert!(PortConfig::parse("8000/").is_err());
    }

    /// Serde (de)serializes `PortConfig` through these two String conversions,
    /// so rendering and re-parsing an unnamed port must be lossless.
    #[test]
    fn test_port_config_string_round_trip() {
        for protocol in [
            PortProtocol::Tcp,
            PortProtocol::Http,
            PortProtocol::Https,
            PortProtocol::Udp,
        ] {
            let original = PortConfig::new(8080, protocol);
            let rendered = String::from(original.clone());
            let reparsed =
                PortConfig::try_from(rendered).expect("rendered port should re-parse");
            assert_eq!(reparsed, original);
        }
    }
}