use crate::error::{ErrorData, Result};
use crate::instance_catalog::Architecture;
use crate::resource::{ResourceDefinition, ResourceOutputsDefinition, ResourceRef};
use crate::ResourceType;
use alien_error::AlienError;
use bon::Builder;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt::Debug;
/// GPU requirement attached to a [`MachineProfile`] (via its `gpu` field).
///
/// Serialized in camelCase; `gpu_type` is written under the JSON key `"type"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct GpuSpec {
    /// GPU model identifier, serialized as `"type"`.
    ///
    /// Free-form string; nothing here validates it. The tests in this file use values
    /// such as `"nvidia-a10g"`.
    #[serde(rename = "type")]
    pub gpu_type: String,
    /// Number of GPUs requested. NOTE(review): presumably per machine — confirm.
    pub count: u32,
}
/// Machine shape requirements that a [`CapacityGroup`] may carry in its `profile` field.
///
/// Serialized in camelCase; the optional fields are omitted from the output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct MachineProfile {
    /// CPU amount as a string. The tests use `"4.0"`; presumably a decimal core
    /// count — TODO confirm the accepted format.
    pub cpu: String,
    /// Memory size, in bytes (per the field name).
    pub memory_bytes: u64,
    /// Ephemeral storage size, in bytes (per the field name).
    pub ephemeral_storage_bytes: u64,
    /// Required CPU architecture, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<Architecture>,
    /// Optional GPU requirement; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gpu: Option<GpuSpec>,
}
/// An inclusive `[min, max]` range of allowed values together with a `default` choice.
///
/// NOTE(review): nothing in this type enforces `min <= default <= max`; callers are
/// presumably expected to uphold it — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ComputeChoiceRange {
    /// Smallest allowed value (inclusive).
    pub min: u32,
    /// Largest allowed value (inclusive).
    pub max: u32,
    /// Value used when no explicit choice is made.
    pub default: u32,
}
impl ComputeChoiceRange {
    /// Reports whether `value` falls inside the inclusive range `min..=max`.
    ///
    /// The `default` field is not consulted. If `min > max` the range is empty and every
    /// value is rejected.
    pub fn contains(&self, value: u32) -> bool {
        let allowed = self.min..=self.max;
        allowed.contains(&value)
    }
}
/// How the machine count of a capacity group may be chosen.
///
/// Internally tagged by a `"type"` field. Because `rename_all = "camelCase"` applies to
/// enum variants, the tags are `"fixed"` and `"autoscale"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum CapacityGroupScalePolicy {
    /// One machine-count range used for both the minimum and the maximum size.
    Fixed {
        machines: ComputeChoiceRange,
    },
    /// Separate ranges for the group's minimum size and maximum size.
    Autoscale {
        min: ComputeChoiceRange,
        max: ComputeChoiceRange,
    },
}
impl CapacityGroupScalePolicy {
    /// Builds a policy from a concrete `(min_size, max_size)` selection.
    ///
    /// Equal bounds produce `Fixed`; differing bounds produce `Autoscale`. Every range in
    /// the result is pinned to a single value (`min == max == default`). The bounds are not
    /// validated, so `min_size > max_size` still yields an `Autoscale` policy.
    pub fn from_selected_bounds(min_size: u32, max_size: u32) -> Self {
        // A range that admits exactly one value, which is also its default.
        let pinned = |value: u32| ComputeChoiceRange {
            min: value,
            max: value,
            default: value,
        };

        if min_size != max_size {
            return Self::Autoscale {
                min: pinned(min_size),
                max: pinned(max_size),
            };
        }

        Self::Fixed {
            machines: pinned(min_size),
        }
    }

    /// Default minimum group size.
    ///
    /// Returns `machines.default` for `Fixed` and `min.default` for `Autoscale`.
    pub fn default_min_size(&self) -> u32 {
        let lower = match self {
            Self::Fixed { machines } => machines,
            Self::Autoscale { min, .. } => min,
        };
        lower.default
    }

    /// Default maximum group size.
    ///
    /// Returns `machines.default` for `Fixed` and `max.default` for `Autoscale`.
    pub fn default_max_size(&self) -> u32 {
        let upper = match self {
            Self::Fixed { machines } => machines,
            Self::Autoscale { max, .. } => max,
        };
        upper.default
    }
}
/// A named pool of machines within a [`ComputeCluster`].
///
/// Serialized in camelCase; the optional fields are omitted from the output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct CapacityGroup {
    /// Group identifier. `ComputeCluster::validate_update` matches groups across an
    /// update by this id.
    pub group_id: String,
    /// Explicit instance type (the tests use values like `"m7g.xlarge"`).
    /// `validate_update` rejects changing it when both the old and new values are set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instance_type: Option<String>,
    /// Machine shape requirements. NOTE(review): presumably an alternative or
    /// complement to `instance_type` — confirm how the two interact.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub profile: Option<MachineProfile>,
    /// Minimum group size (machine count).
    pub min_size: u32,
    /// Maximum group size (machine count).
    pub max_size: u32,
    /// Optional scale policy. `CapacityGroupScalePolicy::from_selected_bounds` can build
    /// one from a `(min_size, max_size)` pair.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scale_policy: Option<CapacityGroupScalePolicy>,
    /// Nested-virtualization request; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nested_virtualization: Option<bool>,
}
/// A compute cluster resource made up of capacity groups.
///
/// Construct it with `ComputeCluster::new(id)`, zero or more `.capacity_group(..)` calls,
/// an optional `.container_cidr(..)`, and `.build()`. Deserialization rejects unknown
/// fields (`deny_unknown_fields`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
#[builder(start_fn = new)]
pub struct ComputeCluster {
    /// Resource id; the builder takes it as the argument of `new`. Immutable across
    /// updates.
    #[builder(start_fn)]
    pub id: String,
    /// Capacity groups in insertion order. The builder exposes this field directly,
    /// and the hand-written `capacity_group` method appends to it.
    #[builder(field)]
    pub capacity_groups: Vec<CapacityGroup>,
    /// Explicit container CIDR. When `None`, the `container_cidr()` accessor returns
    /// `"10.244.0.0/16"`. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub container_cidr: Option<String>,
}
impl ComputeCluster {
    /// Resource type identifier shared by every compute cluster.
    pub const RESOURCE_TYPE: ResourceType = ResourceType::from_static("compute-cluster");

    /// Container CIDR that `container_cidr()` reports when the field is unset.
    ///
    /// Exposed as a named constant so other code can refer to the default instead of
    /// repeating the literal.
    pub const DEFAULT_CONTAINER_CIDR: &str = "10.244.0.0/16";

    /// Returns the cluster's resource id.
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Returns the effective container CIDR.
    ///
    /// This is the explicitly configured value, or [`Self::DEFAULT_CONTAINER_CIDR`] when
    /// none was set.
    pub fn container_cidr(&self) -> &str {
        self.container_cidr
            .as_deref()
            .unwrap_or(Self::DEFAULT_CONTAINER_CIDR)
    }
}
impl<S: compute_cluster_builder::State> ComputeClusterBuilder<S> {
    /// Appends one capacity group to the cluster being built.
    ///
    /// May be called any number of times, at any builder state; groups keep the order in
    /// which they were added.
    pub fn capacity_group(self, group: CapacityGroup) -> Self {
        let mut builder = self;
        builder.capacity_groups.push(group);
        builder
    }
}
/// Observed state of one capacity group, reported in [`ComputeClusterOutputs`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct CapacityGroupStatus {
    /// Id of the capacity group this status describes.
    pub group_id: String,
    /// Number of machines currently in the group (per the field name).
    pub current_machines: u32,
    /// Target number of machines (per the field name).
    pub desired_machines: u32,
    /// Instance type in use. Unlike `CapacityGroup::instance_type`, it is always present.
    pub instance_type: String,
}
/// Outputs reported for a [`ComputeCluster`] resource.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ComputeClusterOutputs {
    /// Identifier of the provisioned cluster.
    pub cluster_id: String,
    /// Readiness flag. NOTE(review): "horizon" is not defined in this file — presumably
    /// a control-plane component; confirm what readiness it tracks.
    pub horizon_ready: bool,
    /// Per-group status, one entry per capacity group (presumably).
    pub capacity_group_statuses: Vec<CapacityGroupStatus>,
    /// Total machine count. NOTE(review): presumably the sum of `current_machines`
    /// across groups — it is not computed here; confirm.
    pub total_machines: u32,
}
impl ResourceOutputsDefinition for ComputeClusterOutputs {
    /// Outputs belong to the compute-cluster resource type.
    fn get_resource_type(&self) -> ResourceType {
        // Each use of an associated `const` already yields a fresh owned value, so no
        // clone is needed. This matches `ComputeCluster::get_resource_type`.
        ComputeCluster::RESOURCE_TYPE
    }

    /// Exposes `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Clones these outputs into a new trait object.
    fn box_clone(&self) -> Box<dyn ResourceOutputsDefinition> {
        Box::new(self.clone())
    }

    /// Equal only if `other` is also a `ComputeClusterOutputs` with identical fields.
    fn outputs_eq(&self, other: &dyn ResourceOutputsDefinition) -> bool {
        other.as_any().downcast_ref::<ComputeClusterOutputs>() == Some(self)
    }

    /// Serializes the outputs (camelCase keys) to a JSON value.
    fn to_json_value(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
}
impl ResourceDefinition for ComputeCluster {
    /// Returns [`ComputeCluster::RESOURCE_TYPE`].
    fn get_resource_type(&self) -> ResourceType {
        Self::RESOURCE_TYPE
    }

    /// Returns the cluster's resource id.
    fn id(&self) -> &str {
        &self.id
    }

    /// A compute cluster declares no dependencies on other resources.
    fn get_dependencies(&self) -> Vec<ResourceRef> {
        Vec::new()
    }

    /// Checks that `new_config` is an allowed update of this cluster.
    ///
    /// The update is rejected when any of these hold:
    /// - the `id` changes;
    /// - `container_cidr` was set and the effective CIDR (as returned by the
    ///   `container_cidr()` accessor) changes. This covers switching to a different
    ///   range and also clearing the field, which would fall back to the default range;
    /// - a capacity group present in both configs (matched by `group_id`) has an instance
    ///   type set on both sides, and the two differ.
    ///
    /// Adding or removing groups, changing group sizes, and setting a previously unset
    /// CIDR or instance type are all accepted.
    ///
    /// # Errors
    ///
    /// - `ErrorData::UnexpectedResourceType` if `new_config` is not a `ComputeCluster`.
    /// - `ErrorData::InvalidResourceUpdate` if one of the rules above is violated.
    fn validate_update(&self, new_config: &dyn ResourceDefinition) -> Result<()> {
        let new_cluster = new_config
            .as_any()
            .downcast_ref::<ComputeCluster>()
            .ok_or_else(|| {
                AlienError::new(ErrorData::UnexpectedResourceType {
                    resource_id: self.id.clone(),
                    expected: Self::RESOURCE_TYPE,
                    actual: new_config.get_resource_type(),
                })
            })?;

        if self.id != new_cluster.id {
            return Err(AlienError::new(ErrorData::InvalidResourceUpdate {
                resource_id: self.id.clone(),
                reason: "the 'id' field is immutable".to_string(),
            }));
        }

        // Compare *effective* CIDRs once one has been set explicitly. Dropping the field
        // is then rejected just like switching ranges, because it would silently move the
        // cluster onto the default CIDR. Explicitly restating the default is still fine.
        if self.container_cidr.is_some() && self.container_cidr() != new_cluster.container_cidr()
        {
            return Err(AlienError::new(ErrorData::InvalidResourceUpdate {
                resource_id: self.id.clone(),
                reason: "the 'containerCidr' field is immutable once set".to_string(),
            }));
        }

        for new_group in &new_cluster.capacity_groups {
            // `None` when the group is new; new groups have no prior instance type.
            let old_type = self
                .capacity_groups
                .iter()
                .find(|g| g.group_id == new_group.group_id)
                .and_then(|g| g.instance_type.as_deref());

            // Only a change between two explicit instance types is forbidden.
            if let (Some(old), Some(new)) = (old_type, new_group.instance_type.as_deref()) {
                if old != new {
                    return Err(AlienError::new(ErrorData::InvalidResourceUpdate {
                        resource_id: self.id.clone(),
                        reason: format!(
                            "instance type for capacity group '{}' is immutable",
                            new_group.group_id
                        ),
                    }));
                }
            }
        }

        Ok(())
    }

    /// Exposes `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Exposes `self` mutably for downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Clones this definition into a new trait object.
    fn box_clone(&self) -> Box<dyn ResourceDefinition> {
        Box::new(self.clone())
    }

    /// Equal only if `other` is also a `ComputeCluster` with identical fields.
    fn resource_eq(&self, other: &dyn ResourceDefinition) -> bool {
        other.as_any().downcast_ref::<ComputeCluster>() == Some(self)
    }

    /// Serializes the definition (camelCase keys) to a JSON value.
    fn to_json_value(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a capacity group with no profile, scale policy or nested-virtualization
    /// setting.
    fn group(
        group_id: &str,
        instance_type: Option<&str>,
        min_size: u32,
        max_size: u32,
    ) -> CapacityGroup {
        CapacityGroup {
            group_id: group_id.to_string(),
            instance_type: instance_type.map(str::to_string),
            profile: None,
            min_size,
            max_size,
            scale_policy: None,
            nested_virtualization: None,
        }
    }

    /// Builds a cluster with the given id and groups and no explicit container CIDR.
    fn build_cluster(id: &str, groups: Vec<CapacityGroup>) -> ComputeCluster {
        groups
            .into_iter()
            .fold(ComputeCluster::new(id.to_string()), |builder, g| {
                builder.capacity_group(g)
            })
            .build()
    }

    #[test]
    fn test_compute_cluster_creation() {
        let cluster = build_cluster("compute", vec![group("general", Some("m7g.xlarge"), 1, 5)]);
        assert_eq!(cluster.id(), "compute");
        assert_eq!(cluster.capacity_groups.len(), 1);
        assert_eq!(cluster.capacity_groups[0].group_id, "general");
        assert_eq!(cluster.container_cidr(), "10.244.0.0/16");
    }

    #[test]
    fn test_compute_cluster_multiple_capacity_groups() {
        let gpu_group = CapacityGroup {
            profile: Some(MachineProfile {
                cpu: "4.0".to_string(),
                memory_bytes: 17179869184,
                ephemeral_storage_bytes: 214748364800,
                architecture: None,
                gpu: Some(GpuSpec {
                    gpu_type: "nvidia-a10g".to_string(),
                    count: 1,
                }),
            }),
            ..group("gpu", Some("g5.xlarge"), 0, 2)
        };
        let cluster = build_cluster(
            "multi-pool",
            vec![group("general", Some("m7g.xlarge"), 1, 3), gpu_group],
        );
        assert_eq!(cluster.capacity_groups.len(), 2);
        assert_eq!(cluster.capacity_groups[0].group_id, "general");
        assert_eq!(cluster.capacity_groups[1].group_id, "gpu");
        assert!(cluster.capacity_groups[1]
            .profile
            .as_ref()
            .unwrap()
            .gpu
            .is_some());
    }

    #[test]
    fn test_compute_cluster_custom_cidr() {
        let cluster = ComputeCluster::new("custom-net".to_string())
            .container_cidr("172.30.0.0/16".to_string())
            .capacity_group(group("general", None, 1, 5))
            .build();
        assert_eq!(cluster.container_cidr(), "172.30.0.0/16");
    }

    #[test]
    fn test_compute_cluster_validate_update_immutable_id() {
        let cluster1 = build_cluster("cluster-1", vec![group("general", None, 1, 5)]);
        let cluster2 = build_cluster("cluster-2", vec![group("general", None, 1, 5)]);
        let result = cluster1.validate_update(&cluster2);
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_cluster_validate_update_scale_change() {
        let cluster1 = build_cluster("compute", vec![group("general", Some("m7g.xlarge"), 1, 5)]);
        let cluster2 = build_cluster("compute", vec![group("general", Some("m7g.xlarge"), 2, 10)]);
        let result = cluster1.validate_update(&cluster2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_compute_cluster_validate_update_container_cidr_change_rejected() {
        let original = ComputeCluster::new("net".to_string())
            .container_cidr("172.30.0.0/16".to_string())
            .capacity_group(group("general", None, 1, 5))
            .build();
        let changed = ComputeCluster::new("net".to_string())
            .container_cidr("172.31.0.0/16".to_string())
            .capacity_group(group("general", None, 1, 5))
            .build();
        assert!(original.validate_update(&changed).is_err());
        assert!(original.validate_update(&original.clone()).is_ok());
    }

    #[test]
    fn test_compute_cluster_validate_update_instance_type_change_rejected() {
        let original = build_cluster("compute", vec![group("general", Some("m7g.xlarge"), 1, 5)]);
        let changed = build_cluster("compute", vec![group("general", Some("m7g.2xlarge"), 1, 5)]);
        assert!(original.validate_update(&changed).is_err());
    }

    #[test]
    fn test_compute_cluster_validate_update_allows_new_group_and_resolved_instance_type() {
        let original = build_cluster("compute", vec![group("general", None, 1, 5)]);
        let updated = build_cluster(
            "compute",
            vec![
                group("general", Some("m7g.xlarge"), 1, 5),
                group("batch", Some("c7g.large"), 0, 3),
            ],
        );
        assert!(original.validate_update(&updated).is_ok());
    }

    #[test]
    fn test_compute_cluster_serialization() {
        let cluster = build_cluster("test-cluster", vec![group("general", Some("m7g.xlarge"), 1, 5)]);
        let json = serde_json::to_string(&cluster).unwrap();
        let deserialized: ComputeCluster = serde_json::from_str(&json).unwrap();
        assert_eq!(cluster, deserialized);
    }

    #[test]
    fn test_compute_cluster_rejects_unknown_fields() {
        let valid = r#"{"id":"c","capacityGroups":[]}"#;
        assert!(serde_json::from_str::<ComputeCluster>(valid).is_ok());
        let unknown = r#"{"id":"c","capacityGroups":[],"bogus":true}"#;
        assert!(serde_json::from_str::<ComputeCluster>(unknown).is_err());
    }

    #[test]
    fn test_scale_policy_from_selected_bounds() {
        let fixed = CapacityGroupScalePolicy::from_selected_bounds(3, 3);
        assert_eq!(
            fixed,
            CapacityGroupScalePolicy::Fixed {
                machines: ComputeChoiceRange {
                    min: 3,
                    max: 3,
                    default: 3,
                },
            }
        );
        assert_eq!(fixed.default_min_size(), 3);
        assert_eq!(fixed.default_max_size(), 3);

        let autoscale = CapacityGroupScalePolicy::from_selected_bounds(1, 4);
        assert!(matches!(autoscale, CapacityGroupScalePolicy::Autoscale { .. }));
        assert_eq!(autoscale.default_min_size(), 1);
        assert_eq!(autoscale.default_max_size(), 4);
    }

    #[test]
    fn test_scale_policy_serialization_tags() {
        let fixed = CapacityGroupScalePolicy::from_selected_bounds(2, 2);
        let json = serde_json::to_value(&fixed).unwrap();
        assert_eq!(json["type"].as_str(), Some("fixed"));
        assert_eq!(json["machines"]["default"].as_u64(), Some(2));
        let round_tripped: CapacityGroupScalePolicy = serde_json::from_value(json).unwrap();
        assert_eq!(round_tripped, fixed);

        let autoscale = CapacityGroupScalePolicy::from_selected_bounds(1, 3);
        let json = serde_json::to_value(&autoscale).unwrap();
        assert_eq!(json["type"].as_str(), Some("autoscale"));
    }

    #[test]
    fn test_compute_choice_range_contains_is_inclusive() {
        let range = ComputeChoiceRange {
            min: 1,
            max: 3,
            default: 2,
        };
        assert!(range.contains(1));
        assert!(range.contains(3));
        assert!(!range.contains(0));
        assert!(!range.contains(4));
    }
}