use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Identifier of an allocation (a UUID).
pub type AllocId = Uuid;
/// Identifier of a node, carried as a plain string.
pub type NodeId = String;
/// Identifier of a tenant, carried as a plain string.
pub type TenantId = String;
/// Identifier of a virtual cluster, carried as a plain string.
pub type VClusterId = String;
/// Numeric identifier of a topology group (see `Node::group`, `TopologyGroup::id`).
pub type GroupId = u32;
/// Identifier of a user, carried as a plain string.
pub type UserId = String;
/// Identifier of a launch (a UUID); used by `LaunchSpec::launch_id`.
pub type LaunchId = Uuid;
/// Identifier of a session (a UUID); used by `Session::id`.
pub type SessionId = Uuid;
/// A session record linking a user to an allocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Identifier of this session.
    pub id: SessionId,
    /// Allocation the session refers to.
    pub allocation_id: AllocId,
    /// User recorded for the session.
    pub user: UserId,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
}
/// Complete record of one allocation.
///
/// Combines the request-side description (tenant, environment, resources,
/// lifecycle, data, connectivity, dependencies, ...) with bookkeeping fields
/// such as the current state, timestamps, assigned nodes and counters.
///
/// Fields annotated with `#[serde(default)]` take their `Default` value when
/// they are missing from the serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Allocation {
    /// Identifier of this allocation.
    pub id: AllocId,
    /// Tenant the allocation belongs to.
    pub tenant: TenantId,
    /// Project name (free-form string).
    pub project: String,
    /// Virtual cluster the allocation belongs to.
    pub vcluster: VClusterId,
    /// User associated with the allocation.
    pub user: UserId,
    /// Arbitrary string key/value tags.
    pub tags: HashMap<String, String>,
    /// Single allocation or task group.
    pub allocation_type: AllocationType,
    /// Software environment (images, env patches, mounts, ...).
    pub environment: Environment,
    /// Entrypoint string. NOTE(review): format (command line vs. path) is not
    /// visible in this file — confirm with the consumer.
    pub entrypoint: String,
    /// Requested node count and constraints.
    pub resources: ResourceRequest,
    /// Lifecycle type and preemption class.
    pub lifecycle: Lifecycle,
    /// When the allocation may be requeued.
    pub requeue_policy: RequeuePolicy,
    /// NOTE(review): presumably the upper bound for `requeue_count` — confirm.
    pub max_requeue: u32,
    /// Data mounts and scratch requirements.
    pub data: DataRequirements,
    /// Network domain and exposed service endpoints.
    pub connectivity: Connectivity,
    /// Dependencies on other allocations.
    pub depends_on: Vec<Dependency>,
    /// Checkpoint strategy.
    pub checkpoint: CheckpointStrategy,
    /// Telemetry mode.
    pub telemetry_mode: TelemetryMode,
    /// Optional liveness probe configuration.
    pub liveness_probe: Option<LivenessProbe>,
    /// Current lifecycle state.
    pub state: AllocationState,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Start timestamp, if set.
    pub started_at: Option<DateTime<Utc>>,
    /// Completion timestamp, if set.
    pub completed_at: Option<DateTime<Utc>>,
    /// Nodes assigned to the allocation.
    pub assigned_nodes: Vec<NodeId>,
    /// Identifier of the DAG this allocation is part of, if any.
    pub dag_id: Option<String>,
    /// Exit code, if one was recorded.
    pub exit_code: Option<i32>,
    /// Optional free-form status message.
    pub message: Option<String>,
    /// Requeue counter.
    pub requeue_count: u32,
    /// Preemption counter.
    pub preempted_count: u32,
    /// Flag requesting a resume from checkpoint.
    pub resume_from_checkpoint: bool,
    /// Sensitive-workload flag; `false` when absent from the input.
    #[serde(default)]
    pub sensitive: bool,
    /// Version counter for `state`; `0` when absent from the input.
    /// NOTE(review): presumably used for optimistic concurrency — confirm.
    #[serde(default)]
    pub state_version: u64,
    /// Dispatch retry counter; `0` when absent from the input.
    #[serde(default)]
    pub dispatch_retry_count: u32,
    /// Timestamp of the last completion report; `None` when absent.
    #[serde(default)]
    pub last_completion_report_at: Option<DateTime<Utc>>,
    /// Timestamp of node assignment; `None` when absent.
    #[serde(default)]
    pub assigned_at: Option<DateTime<Utc>>,
    /// Last known phase per node; empty when absent from the input.
    #[serde(default)]
    pub per_node_phase: HashMap<NodeId, CompletionPhase>,
}
/// Lifecycle state of an allocation.
///
/// The permitted moves between states are defined by
/// `AllocationState::can_transition_to`; `Completed`, `Failed` and
/// `Cancelled` are reported as terminal by `AllocationState::is_terminal`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllocationState {
    Pending,
    Staging,
    Running,
    Checkpointing,
    Suspended,
    Completed,
    Failed,
    Cancelled,
}
/// Phase carried by a `CompletionReport` and stored per node in
/// `Allocation::per_node_phase`.
///
/// Each variant maps to the `AllocationState` variant of the same name in
/// `CompletionReport::to_per_node_allocation_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionPhase {
    Staging,
    Running,
    Completed,
    Failed,
}
/// Report of the phase an allocation has reached, optionally with process
/// and exit details.
///
/// NOTE(review): presumably emitted per node (the conversion helper is named
/// "per node") — confirm against the sender.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionReport {
    /// Allocation the report is about.
    pub allocation_id: AllocId,
    /// Phase being reported.
    pub phase: CompletionPhase,
    /// Process id, if any.
    pub pid: Option<u32>,
    /// Exit code, if any.
    pub exit_code: Option<i32>,
    /// Optional free-form explanation.
    pub reason: Option<String>,
}
impl CompletionReport {
    /// Converts the reported phase into the [`AllocationState`] variant that
    /// carries the same name.
    ///
    /// Only `self.phase` is consulted; the other report fields are ignored.
    pub fn to_per_node_allocation_state(&self) -> AllocationState {
        let Self { phase, .. } = self;
        match *phase {
            CompletionPhase::Failed => AllocationState::Failed,
            CompletionPhase::Completed => AllocationState::Completed,
            CompletionPhase::Running => AllocationState::Running,
            CompletionPhase::Staging => AllocationState::Staging,
        }
    }
}
/// Reason codes for a refusal.
///
/// NOTE(review): who refuses what is not visible in this file — presumably a
/// node agent declining a dispatch request; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RefusalReason {
    Busy,
    UnsupportedCapability,
    MalformedRequest,
    AlreadyRunning,
}
/// Kind of runtime: a bare process, a uenv, or Podman.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RuntimeVariant {
    BareProcess,
    Uenv,
    Podman,
}
/// Shape of an allocation: a single unit or a group of tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AllocationType {
    /// One standalone allocation.
    Single,
    /// A group of tasks described by an index range.
    ///
    /// NOTE(review): whether `range_end` is inclusive and how `step == 0` is
    /// treated are not visible in this file — confirm with the consumer.
    TaskGroup {
        range_start: u32,
        range_end: u32,
        step: u32,
        max_concurrent: u32,
    },
}
/// Policy selecting when an allocation may be requeued: never, only on node
/// failure, or always.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RequeuePolicy {
    Never,
    OnNodeFailure,
    Always,
}
/// Kind of image referenced by an `ImageRef`. `Uenv` is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageType {
    #[default]
    Uenv,
    Oci,
}
/// Reference to an image together with its resolved details.
///
/// The derived `Default` yields empty strings, `size_bytes == 0`,
/// `resolve_on_schedule == false` and `ImageType::Uenv`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ImageRef {
    /// Image specification string as given (e.g. the tests use
    /// `"prgenv-gnu/24.11:v1"`).
    pub spec: String,
    /// Kind of image.
    pub image_type: ImageType,
    /// Registry component.
    pub registry: String,
    /// Name component.
    pub name: String,
    /// Version component.
    pub version: String,
    /// Tag as originally supplied.
    pub original_tag: String,
    /// SHA-256 digest string; empty when not (yet) known.
    pub sha256: String,
    /// Image size in bytes.
    pub size_bytes: u64,
    /// Mount point for the image.
    pub mount_point: String,
    /// NOTE(review): presumably defers resolution of the image until
    /// scheduling time (the tests pair it with an empty `sha256`) — confirm.
    pub resolve_on_schedule: bool,
}
/// Operation an `EnvPatch` performs on an environment variable.
///
/// `Set` is the default. The exact effect of each operation is implemented
/// in `apply_env_patches`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum EnvOp {
    Prepend,
    Append,
    #[default]
    Set,
    Unset,
}
/// One edit to apply to an environment-variable map (see `apply_env_patches`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvPatch {
    /// Name of the variable to modify.
    pub variable: String,
    /// Operation to perform.
    pub op: EnvOp,
    /// Value to set, prepend or append; not used by `Unset`.
    pub value: String,
    /// Text placed between the existing and the new value for
    /// `Prepend`/`Append`.
    pub separator: String,
}
impl Default for EnvPatch {
fn default() -> Self {
Self {
variable: String::new(),
op: EnvOp::Set,
value: String::new(),
separator: ":".to_string(),
}
}
}
/// A named view: a description plus the list of environment patches it
/// consists of.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ViewDef {
    /// View name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Environment patches belonging to the view.
    pub patches: Vec<EnvPatch>,
}
/// Descriptive metadata of an image, including the views it offers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ImageMetadata {
    /// Image name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Mount point for the image.
    pub mount_point: String,
    /// Views defined by the image.
    pub views: Vec<ViewDef>,
    /// Name of the default view, if one is designated.
    pub default_view: Option<String>,
}
/// Container-specific settings of an `Environment`.
///
/// The derived `Default` yields empty collections, an empty `workdir` and
/// `writable == false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContainerSpec {
    /// Base environments, as strings. NOTE(review): their format is not
    /// visible in this file — confirm with the consumer.
    pub base_environments: Vec<String>,
    /// Mounts for the container.
    pub mounts: Vec<MountSpec>,
    /// Devices, as strings.
    pub devices: Vec<String>,
    /// Working directory.
    pub workdir: String,
    /// Writable flag.
    pub writable: bool,
    /// Environment variables as ordered `(name, value)` pairs.
    pub env: Vec<(String, String)>,
    /// Annotations as ordered `(key, value)` pairs.
    pub annotations: Vec<(String, String)>,
}
/// A mount described by source, target and an options string.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MountSpec {
    /// What to mount.
    pub source: String,
    /// Where to mount it.
    pub target: String,
    /// Mount options as a single string. NOTE(review): format (e.g.
    /// comma-separated) is not visible here — confirm.
    pub options: String,
}
/// Software environment of an allocation.
///
/// The derived `Default` yields empty collections, no container spec and all
/// flags set to `false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Environment {
    /// Images to make available.
    pub images: Vec<ImageRef>,
    /// Environment-variable patches (see `apply_env_patches`).
    pub env_patches: Vec<EnvPatch>,
    /// Devices, as strings.
    pub devices: Vec<String>,
    /// Mounts.
    pub mounts: Vec<MountSpec>,
    /// Optional container-specific settings.
    pub container: Option<ContainerSpec>,
    /// Writable flag.
    pub writable: bool,
    // The three flags below are policy switches; NOTE(review): where they are
    // enforced is not visible in this file.
    pub sign_required: bool,
    pub scan_required: bool,
    pub approved_bases_only: bool,
}
/// Resources requested by an allocation: how many nodes and under which
/// constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceRequest {
    /// Requested node count (exact or a range).
    pub nodes: NodeCount,
    /// Constraints the nodes must or should satisfy.
    pub constraints: ResourceConstraints,
}
/// Requested number of nodes: either an exact count or a `min`/`max` range.
///
/// No validation (e.g. `min <= max`) is performed by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeCount {
    Exact(u32),
    Range {
        min: u32,
        max: u32,
    },
}
/// Constraints on the nodes selected for an allocation.
///
/// The derived `Default` yields no GPU type, no features, no topology hint,
/// an empty feature-count map, all flags `false` and no memory policy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResourceConstraints {
    /// GPU type, if one is requested.
    pub gpu_type: Option<String>,
    /// Feature names.
    pub features: Vec<String>,
    /// Optional topology placement hint.
    pub topology: Option<TopologyHint>,
    /// Counts keyed by feature name. NOTE(review): presumably a minimum
    /// count per feature — confirm with the scheduler.
    pub feature_counts: HashMap<String, u32>,
    /// `false` when absent from the serialized input.
    #[serde(default)]
    pub require_unified_memory: bool,
    /// `false` when absent from the serialized input.
    #[serde(default)]
    pub prefer_same_numa: bool,
    /// `false` when absent from the serialized input.
    #[serde(default)]
    pub allow_cxl_memory: bool,
    /// Optional memory placement policy.
    pub memory_policy: Option<MemoryPolicy>,
}
/// Placement hint used in `ResourceConstraints::topology`.
///
/// NOTE(review): the variant names suggest packing nodes closely (`Tight`),
/// spreading them out (`Spread`) or no preference (`Any`) — confirm with the
/// scheduler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TopologyHint {
    Tight,
    Spread,
    Any,
}
/// Lifecycle settings of an allocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Lifecycle {
    /// Bounded, unbounded or reactive lifecycle.
    pub lifecycle_type: LifecycleType,
    /// Preemption class. NOTE(review): ordering semantics (whether higher
    /// means more or less preemptible) are not visible here — confirm.
    pub preemption_class: u8,
}
/// Kind of lifecycle an allocation follows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LifecycleType {
    /// Runs with a wall-time duration attached.
    Bounded { walltime: chrono::Duration },
    /// Runs without a wall-time duration.
    Unbounded,
    /// Carries a node range plus a metric/target pair.
    ///
    /// NOTE(review): presumably scales between `min_nodes` and `max_nodes`
    /// to hold `metric` at `target`; the format of both strings is not
    /// visible here — confirm.
    Reactive {
        min_nodes: u32,
        max_nodes: u32,
        metric: String,
        target: String,
    },
}
/// Data-related requirements of an allocation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DataRequirements {
    /// Data mounts.
    pub mounts: Vec<DataMount>,
    /// NOTE(review): presumably adds a site-defined default set of mounts —
    /// confirm with the consumer.
    pub use_defaults: bool,
    /// Per-node scratch request as a string, if any. NOTE(review): the
    /// format (e.g. a size with unit) is not visible here.
    pub scratch_per_node: Option<String>,
}
/// A data mount: what to mount, where, and with which access mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataMount {
    /// Source of the data.
    pub source: String,
    /// Mount target.
    pub target: String,
    /// Read-only or read-write access.
    pub access: DataAccess,
    /// Optional storage tier hint.
    pub tier_hint: Option<StorageTier>,
}
/// Access mode of a `DataMount`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataAccess {
    ReadOnly,
    ReadWrite,
}
/// Storage tier, used as an optional hint on a `DataMount`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StorageTier {
    Hot,
    Warm,
    Cold,
}
/// Network-related settings of an allocation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Connectivity {
    /// Name of the network domain to use, if any (see `NetworkDomain::name`).
    pub network_domain: Option<String>,
    /// Service endpoints to expose.
    pub expose: Vec<ServiceEndpoint>,
}
/// A service endpoint an allocation wants to expose.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceEndpoint {
    /// Endpoint name.
    pub name: String,
    /// Port number.
    pub port: u16,
    /// Optional protocol name. NOTE(review): accepted values are not visible
    /// in this file.
    pub protocol: Option<String>,
}
/// An endpoint as recorded in a `ServiceRegistryEntry`: which allocation and
/// tenant it belongs to and on which nodes/port it is reachable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegisteredEndpoint {
    /// Allocation providing the endpoint.
    pub allocation_id: AllocId,
    /// Tenant owning the endpoint.
    pub tenant: TenantId,
    /// Nodes associated with the endpoint.
    pub nodes: Vec<NodeId>,
    /// Port number.
    pub port: u16,
    /// Optional protocol name.
    pub protocol: Option<String>,
}
/// A service registry entry: the list of endpoints registered under it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ServiceRegistryEntry {
    /// Registered endpoints; empty by default.
    pub endpoints: Vec<RegisteredEndpoint>,
}
/// Liveness probe configuration. Durations are expressed in seconds, as the
/// `_secs` suffixes indicate.
///
/// The hand-written `Default` uses a TCP probe on port 8080 with a 30 s
/// period, 10 s initial delay, failure threshold 3 and 5 s timeout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LivenessProbe {
    /// Kind of probe (TCP or HTTP).
    pub probe_type: ProbeType,
    /// Interval between probes, in seconds.
    pub period_secs: u32,
    /// Delay before the first probe, in seconds.
    pub initial_delay_secs: u32,
    /// NOTE(review): presumably the number of consecutive failures before
    /// the target is considered dead — confirm with the prober.
    pub failure_threshold: u32,
    /// Per-probe timeout, in seconds.
    pub timeout_secs: u32,
}
impl Default for LivenessProbe {
fn default() -> Self {
Self {
probe_type: ProbeType::Tcp { port: 8080 },
period_secs: 30,
initial_delay_secs: 10,
failure_threshold: 3,
timeout_secs: 5,
}
}
}
/// Kind of liveness probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProbeType {
    /// TCP probe against `port`.
    Tcp { port: u16 },
    /// HTTP probe against `port` and `path`.
    Http { port: u16, path: String },
}
/// A named network domain owned by a tenant, with its member allocations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkDomain {
    /// Domain name.
    pub name: String,
    /// Owning tenant.
    pub tenant: TenantId,
    /// NOTE(review): presumably a virtual network identifier (the same field
    /// name appears in `CxiCredentials`) — confirm.
    pub vni: u32,
    /// Current state; legal moves are defined by
    /// `NetworkDomainState::can_transition_to`.
    pub state: NetworkDomainState,
    /// Allocations that are members of the domain.
    pub member_allocations: Vec<AllocId>,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Optional grace deadline. NOTE(review): presumably the time after
    /// which a draining domain is released — confirm.
    pub grace_deadline: Option<DateTime<Utc>>,
}
/// State of a `NetworkDomain`.
///
/// Per `NetworkDomainState::can_transition_to`: `Active` may move to
/// `Draining`; `Draining` may move back to `Active` or on to `Released`;
/// nothing leaves `Released`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NetworkDomainState {
    Active,
    Draining,
    Released,
}
/// A dependency of one allocation on another item.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dependency {
    /// Reference to the item depended upon, as a string. NOTE(review):
    /// whether this is an allocation id or a DAG-local name is not visible
    /// in this file — confirm.
    pub ref_id: String,
    /// Condition that must hold for the dependency.
    pub condition: DependencyCondition,
}
/// Condition attached to a `Dependency`.
///
/// NOTE(review): the exact semantics of each variant (in particular
/// `Corresponding` and `Mutex`) are implemented elsewhere and are not
/// visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DependencyCondition {
    Success,
    Failure,
    Any,
    Corresponding,
    Mutex,
}
/// Checkpoint strategy of an allocation.
///
/// Note that the `None` variant shares its name with `Option::None`; refer
/// to it as `CheckpointStrategy::None` to avoid confusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CheckpointStrategy {
    Auto,
    Manual,
    None,
}
/// Telemetry mode of an allocation. `Prod` is the default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum TelemetryMode {
    #[default]
    Prod,
    /// Debug mode carrying a duration in seconds.
    Debug { duration_seconds: u64 },
    Audit,
}
/// Record of an attach session: who ran which command on which node of an
/// allocation, and when.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttachSession {
    /// Identifier of the session.
    pub session_id: Uuid,
    /// Allocation attached to.
    pub allocation_id: AllocId,
    /// Node attached to.
    pub node_id: NodeId,
    /// User who attached.
    pub user: UserId,
    /// Command string.
    pub command: String,
    /// Start timestamp (UTC).
    pub started_at: DateTime<Utc>,
    /// End timestamp, if set.
    pub ended_at: Option<DateTime<Utc>>,
}
/// One chunk of log output from a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
    /// Node that produced the data.
    pub node_id: NodeId,
    /// Stream the data came from (stdout or stderr).
    pub stream: LogStream,
    /// Raw bytes; no text encoding is imposed by this type.
    pub data: Vec<u8>,
    /// Timestamp of the entry (UTC).
    pub timestamp: DateTime<Utc>,
}
/// Output stream a `LogEntry` originates from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LogStream {
    Stdout,
    Stderr,
}
/// Logging configuration.
///
/// The hand-written `Default` uses a ring buffer size of 64 * 1024 * 1024,
/// enables S3 persistence and sets no retention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig {
    /// Ring buffer size. NOTE(review): presumably in bytes (the default is
    /// written as 64 * 1024 * 1024) — confirm.
    pub ring_buffer_size: u64,
    /// Whether S3 persistence is enabled.
    pub s3_persistence: bool,
    /// Optional retention duration.
    pub retention: Option<chrono::Duration>,
}
impl Default for LogConfig {
fn default() -> Self {
Self {
ring_buffer_size: 64 * 1024 * 1024, s3_persistence: true,
retention: None, }
}
}
/// Point-in-time metrics of one node, including per-GPU snapshots.
///
/// Units are encoded in the field names (`_bytes`, `_bytes_per_sec`, `_us`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeMetricsSnapshot {
    /// Node the snapshot belongs to.
    pub node_id: NodeId,
    /// Time of the snapshot (UTC).
    pub timestamp: DateTime<Utc>,
    /// CPU utilization. NOTE(review): scale (0–1 vs. 0–100) is not visible
    /// in this file — confirm with the producer.
    pub cpu_utilization: f64,
    pub memory_used_bytes: u64,
    pub memory_total_bytes: u64,
    pub network_tx_bytes_per_sec: f64,
    pub network_rx_bytes_per_sec: f64,
    pub io_read_bytes_per_sec: f64,
    pub io_write_bytes_per_sec: f64,
    pub io_latency_p99_us: f64,
    /// One snapshot per GPU.
    pub gpus: Vec<GpuMetricsSnapshot>,
}
/// Point-in-time metrics of one GPU.
///
/// Units are encoded in the field names (`_bytes`, `_watts`, `_celsius`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuMetricsSnapshot {
    /// GPU index.
    pub index: u32,
    /// GPU utilization. NOTE(review): scale (0–1 vs. 0–100) is not visible
    /// in this file.
    pub utilization: f64,
    pub memory_used_bytes: u64,
    pub memory_total_bytes: u64,
    pub power_draw_watts: f64,
    pub temperature_celsius: f64,
    /// ECC error count.
    pub ecc_errors: u64,
}
/// Metrics summary for one allocation at a point in time.
///
/// Units are encoded in the field names. NOTE(review): how per-node values
/// are aggregated (sum vs. mean) is only stated for the two `_mean` fields;
/// confirm the rest with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocationMetricsSummary {
    /// Allocation the summary belongs to.
    pub allocation_id: AllocId,
    /// Time of the summary (UTC).
    pub timestamp: DateTime<Utc>,
    pub gpu_utilization_mean: f64,
    pub cpu_utilization_mean: f64,
    pub memory_used_bytes: u64,
    pub memory_total_bytes: u64,
    pub gpu_memory_used_bytes: u64,
    pub gpu_memory_total_bytes: u64,
    pub network_tx_bytes_per_sec: f64,
    pub network_rx_bytes_per_sec: f64,
    pub io_read_bytes_per_sec: f64,
    pub io_write_bytes_per_sec: f64,
    pub io_latency_p99_us: f64,
}
/// Network diagnostics, including per-node-pair bandwidth measurements.
///
/// Bandwidth fields are in Gbps per their names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkDiagnostics {
    /// NOTE(review): presumably the number of topology groups spanned —
    /// confirm with the producer.
    pub group_span: u32,
    /// Topology groups involved.
    pub groups: Vec<GroupId>,
    /// NOTE(review): meaning of "csig" is not visible in this file.
    pub csig_congestion_avg: f64,
    pub inter_node_bandwidth_gbps: f64,
    pub target_bandwidth_gbps: f64,
    /// Measurements between pairs of nodes.
    pub node_pairs: Vec<NodePairBandwidth>,
    pub nvlink_throughput_gbps: f64,
    /// Network error count.
    pub network_errors: u64,
}
/// Bandwidth and latency measured between two nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodePairBandwidth {
    /// First node of the pair.
    pub source_node: NodeId,
    /// Second node of the pair.
    pub target_node: NodeId,
    /// Bandwidth in Gbps.
    pub bandwidth_gbps: f64,
    /// Latency in microseconds.
    pub latency_us: f64,
}
/// Storage diagnostics: one entry per mount.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageDiagnostics {
    /// Per-mount diagnostics.
    pub mounts: Vec<MountDiagnostics>,
}
/// Diagnostics of a single mount.
///
/// Throughput fields are in Gbps and latency percentiles in microseconds,
/// per their names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MountDiagnostics {
    /// Path of the mount.
    pub mount_path: String,
    /// Mount type as a string. NOTE(review): accepted values are not visible
    /// in this file.
    pub mount_type: String,
    pub read_throughput_gbps: f64,
    pub write_throughput_gbps: f64,
    pub qos_floor_gbps: f64,
    pub latency_p50_us: f64,
    pub latency_p95_us: f64,
    pub latency_p99_us: f64,
    pub iops_read: f64,
    pub iops_write: f64,
    /// Health as a string. NOTE(review): accepted values are not visible in
    /// this file.
    pub health: String,
}
/// Alert about a metric on a node: the observed value, the threshold it is
/// compared with, a severity and a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricAlert {
    /// Node the alert concerns.
    pub node_id: NodeId,
    /// Name of the metric.
    pub metric_name: String,
    /// Observed value.
    pub current_value: f64,
    /// Threshold associated with the alert.
    pub threshold: f64,
    /// Severity of the alert.
    pub severity: AlertSeverity,
    /// Human-readable message.
    pub message: String,
    /// Time of the alert (UTC).
    pub timestamp: DateTime<Utc>,
}
/// Severity of a `MetricAlert`.
///
/// The type does not derive `Ord`, so severities cannot be compared with
/// `<`/`>` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AlertSeverity {
    Info,
    Warning,
    Critical,
}
/// Record of one node: identity, capabilities, state, ownership and
/// bookkeeping.
///
/// Fields annotated with `#[serde(default)]` take their `Default` value when
/// they are missing from the serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Node identifier.
    pub id: NodeId,
    /// Topology group the node belongs to.
    pub group: GroupId,
    /// Hardware capabilities.
    pub capabilities: NodeCapabilities,
    /// Current state; legal moves are defined by
    /// `NodeState::can_transition_to`.
    pub state: NodeState,
    /// Current owner, if the node is owned.
    pub owner: Option<NodeOwnership>,
    /// Optional conformance fingerprint string.
    pub conformance_fingerprint: Option<String>,
    /// Time of the last heartbeat, if any.
    pub last_heartbeat: Option<DateTime<Utc>>,
    /// Version counter for `owner`; `0` when absent. NOTE(review): presumably
    /// used for optimistic concurrency — confirm.
    #[serde(default)]
    pub owner_version: u64,
    /// Address of the node agent; empty when absent.
    #[serde(default)]
    pub agent_address: String,
    /// Counter of consecutive dispatch failures; `0` when absent.
    #[serde(default)]
    pub consecutive_dispatch_failures: u32,
    /// Time the node became degraded; `None` when absent.
    #[serde(default)]
    pub degraded_at: Option<DateTime<Utc>>,
    /// Reattach-in-progress flag; `false` when absent.
    #[serde(default)]
    pub reattach_in_progress: bool,
    /// NOTE(review): presumably the time `reattach_in_progress` was first
    /// set — confirm. `None` when absent.
    #[serde(default)]
    pub reattach_first_set_at: Option<DateTime<Utc>>,
}
/// Hardware capabilities of a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCapabilities {
    /// GPU type, if the node has one recorded.
    pub gpu_type: Option<String>,
    /// Number of GPUs.
    pub gpu_count: u32,
    /// Number of CPU cores.
    pub cpu_cores: u32,
    /// Memory size in GB, per the field name.
    pub memory_gb: u64,
    /// Feature names.
    pub features: Vec<String>,
    /// Optional GPU topology description.
    pub gpu_topology: Option<GpuTopology>,
    /// Optional memory topology description.
    pub memory_topology: Option<MemoryTopology>,
}
/// State of a node.
///
/// `Degraded`, `Down` and `Failed` carry a free-form reason string. Because
/// the type derives `PartialEq`, two such states with different reasons
/// compare unequal. Legal moves are defined by
/// `NodeState::can_transition_to`, which ignores the reason.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
    Unknown,
    Booting,
    Ready,
    Degraded { reason: String },
    Down { reason: String },
    Draining,
    Drained,
    Failed { reason: String },
}
/// Ownership record of a node: the tenant, virtual cluster and allocation
/// holding it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeOwnership {
    /// Owning tenant.
    pub tenant: TenantId,
    /// Owning virtual cluster.
    pub vcluster: VClusterId,
    /// Allocation holding the node.
    pub allocation: AllocId,
    /// User who claimed the node, if any.
    pub claimed_by: Option<UserId>,
    /// NOTE(review): presumably set when the node is on loan from another
    /// virtual cluster (see `VCluster::allow_borrowing`) — confirm.
    pub is_borrowed: bool,
}
/// A tenant with its quota and isolation level.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tenant {
    /// Tenant identifier.
    pub id: TenantId,
    /// Display name.
    pub name: String,
    /// Quota settings.
    pub quota: TenantQuota,
    /// Isolation level.
    pub isolation_level: IsolationLevel,
}
/// Quota settings of a tenant. Optional fields are `None` when no value is
/// configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantQuota {
    /// Maximum number of nodes.
    pub max_nodes: u32,
    /// Fair-share target. NOTE(review): scale (fraction vs. percentage) is
    /// not visible in this file.
    pub fair_share_target: f64,
    /// Optional GPU-hours budget.
    pub gpu_hours_budget: Option<f64>,
    /// Optional node-hours budget; `None` when absent from the input.
    #[serde(default)]
    pub node_hours_budget: Option<f64>,
    /// Optional cap on concurrent allocations.
    pub max_concurrent_allocations: Option<u32>,
    /// Optional burst allowance; `None` when absent from the input.
    /// NOTE(review): unit/scale is not visible in this file.
    #[serde(default)]
    pub burst_allowance: Option<f64>,
}
/// Isolation level of a tenant.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum IsolationLevel {
    Standard,
    Strict,
}
/// A virtual cluster: a tenant-owned scheduling domain with its own
/// scheduler type, cost weights and dedicated nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VCluster {
    /// Virtual cluster identifier.
    pub id: VClusterId,
    /// Display name.
    pub name: String,
    /// Owning tenant.
    pub tenant: TenantId,
    /// Scheduler type used for this virtual cluster.
    pub scheduler_type: SchedulerType,
    /// Weights of the scheduling cost terms.
    pub cost_weights: CostWeights,
    /// Nodes dedicated to this virtual cluster.
    pub dedicated_nodes: Vec<NodeId>,
    /// Whether borrowing is allowed.
    pub allow_borrowing: bool,
    /// Whether lending is allowed.
    pub allow_lending: bool,
}
/// Scheduler type of a `VCluster`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SchedulerType {
    HpcBackfill,
    ServiceBinPack,
    SensitiveReservation,
    InteractiveFifo,
}
/// Weights of the individual scheduling cost terms.
///
/// The hand-written `Default` assigns 0.20 to `priority`, `wait_time` and
/// `fair_share`, 0.15 to `topology`, 0.10 to `data_readiness` and
/// `conformance`, 0.05 to `backlog`, and 0.00 to `energy` and
/// `checkpoint_efficiency`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostWeights {
    pub priority: f64,
    pub wait_time: f64,
    pub fair_share: f64,
    pub topology: f64,
    pub data_readiness: f64,
    pub backlog: f64,
    pub energy: f64,
    pub checkpoint_efficiency: f64,
    pub conformance: f64,
}
impl Default for CostWeights {
fn default() -> Self {
Self {
priority: 0.20,
wait_time: 0.20,
fair_share: 0.20,
topology: 0.15,
data_readiness: 0.10,
backlog: 0.05,
energy: 0.00,
checkpoint_efficiency: 0.00,
conformance: 0.10,
}
}
}
/// GPU topology of a node: its devices and a NIC affinity map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuTopology {
    /// GPU devices.
    pub devices: Vec<GpuDevice>,
    /// NOTE(review): presumably maps a GPU index to the index of its closest
    /// NIC; the direction of the mapping is not visible here — confirm.
    pub nic_affinity: HashMap<u32, u32>,
}
/// One GPU device and its links to peer GPUs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDevice {
    /// Device index.
    pub index: u32,
    /// Vendor.
    pub vendor: GpuVendor,
    /// Model string.
    pub model: String,
    /// Device memory in bytes.
    pub memory_bytes: u64,
    /// Links to other GPUs.
    pub links: Vec<GpuLink>,
}
/// A link from one GPU to a peer GPU.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuLink {
    /// Index of the peer GPU.
    pub peer_index: u32,
    /// Kind of link.
    pub link_type: GpuLinkType,
    /// Link bandwidth in Gbps.
    pub bandwidth_gbps: f64,
}
/// Kind of GPU-to-GPU link.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuLinkType {
    NVLink,
    NVSwitch,
    InfinityFabric,
    PCIe,
}
/// GPU vendor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GpuVendor {
    Nvidia,
    Amd,
}
/// Memory topology of a node: memory domains and the interconnects between
/// them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryTopology {
    /// Memory domains.
    pub domains: Vec<MemoryDomain>,
    /// Interconnects between domains.
    pub interconnects: Vec<MemoryInterconnect>,
    /// Total capacity in bytes. This is a stored value; nothing in this type
    /// keeps it in sync with the per-domain capacities.
    pub total_capacity_bytes: u64,
}
/// One memory domain and the CPUs/GPUs attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryDomain {
    /// Domain identifier (referenced by `MemoryInterconnect`).
    pub id: u32,
    /// Kind of memory.
    pub domain_type: MemoryDomainType,
    /// Capacity in bytes.
    pub capacity_bytes: u64,
    /// NUMA node number, if applicable.
    pub numa_node: Option<u32>,
    /// Attached CPUs, by number.
    pub attached_cpus: Vec<u32>,
    /// Attached GPUs, by number.
    pub attached_gpus: Vec<u32>,
}
/// A link between two memory domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryInterconnect {
    /// Id of one endpoint domain.
    pub domain_a: u32,
    /// Id of the other endpoint domain.
    pub domain_b: u32,
    /// Kind of link.
    pub link_type: MemoryLinkType,
    /// Bandwidth in Gbps.
    pub bandwidth_gbps: f64,
    /// Latency in nanoseconds.
    pub latency_ns: u64,
}
/// Kind of memory backing a `MemoryDomain`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryDomainType {
    Dram,
    Hbm,
    CxlAttached,
    Unified,
}
/// Kind of link used by a `MemoryInterconnect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryLinkType {
    NumaLink,
    CxlSwitch,
    CoherentFabric,
}
/// Memory placement policy, used in `ResourceConstraints::memory_policy`.
///
/// NOTE(review): the variant names resemble the Linux NUMA memory policies
/// (local, interleave, preferred, bind) — confirm how they are applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryPolicy {
    Local,
    Interleave,
    Preferred,
    Bind,
}
/// Cluster topology expressed as a list of groups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopologyModel {
    /// All topology groups.
    pub groups: Vec<TopologyGroup>,
}
/// One topology group: its member nodes and the groups adjacent to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopologyGroup {
    /// Group identifier.
    pub id: GroupId,
    /// Nodes in the group.
    pub nodes: Vec<NodeId>,
    /// Identifiers of adjacent groups.
    pub adjacent_groups: Vec<GroupId>,
}
impl AllocationState {
pub fn can_transition_to(&self, target: &AllocationState) -> bool {
matches!(
(self, target),
(AllocationState::Pending, AllocationState::Staging)
| (AllocationState::Pending, AllocationState::Running)
| (AllocationState::Pending, AllocationState::Cancelled)
| (AllocationState::Pending, AllocationState::Failed)
| (AllocationState::Staging, AllocationState::Running)
| (AllocationState::Staging, AllocationState::Failed)
| (AllocationState::Staging, AllocationState::Cancelled)
| (AllocationState::Running, AllocationState::Checkpointing)
| (AllocationState::Running, AllocationState::Completed)
| (AllocationState::Running, AllocationState::Failed)
| (AllocationState::Running, AllocationState::Cancelled)
| (AllocationState::Checkpointing, AllocationState::Suspended)
| (AllocationState::Checkpointing, AllocationState::Failed)
| (AllocationState::Checkpointing, AllocationState::Cancelled)
| (AllocationState::Suspended, AllocationState::Pending)
| (AllocationState::Suspended, AllocationState::Cancelled)
| (AllocationState::Suspended, AllocationState::Failed)
| (AllocationState::Failed, AllocationState::Pending)
)
}
pub fn is_terminal(&self) -> bool {
matches!(
self,
AllocationState::Completed | AllocationState::Failed | AllocationState::Cancelled
)
}
}
impl NodeState {
pub fn can_transition_to(&self, target: &NodeState) -> bool {
matches!(
(self, target),
(NodeState::Unknown, NodeState::Booting)
| (NodeState::Unknown, NodeState::Failed { .. })
| (NodeState::Booting, NodeState::Ready)
| (NodeState::Booting, NodeState::Failed { .. })
| (NodeState::Ready, NodeState::Degraded { .. })
| (NodeState::Ready, NodeState::Draining)
| (NodeState::Ready, NodeState::Down { .. })
| (NodeState::Degraded { .. }, NodeState::Ready)
| (NodeState::Degraded { .. }, NodeState::Down { .. })
| (NodeState::Degraded { .. }, NodeState::Draining)
| (NodeState::Down { .. }, NodeState::Ready)
| (NodeState::Down { .. }, NodeState::Booting)
| (NodeState::Down { .. }, NodeState::Failed { .. })
| (NodeState::Draining, NodeState::Drained)
| (NodeState::Drained, NodeState::Ready)
| (NodeState::Drained, NodeState::Booting)
| (NodeState::Failed { .. }, NodeState::Booting)
)
}
pub fn is_operational(&self) -> bool {
matches!(self, NodeState::Ready | NodeState::Degraded { .. })
}
}
impl NetworkDomainState {
pub fn can_transition_to(&self, target: &NetworkDomainState) -> bool {
matches!(
(self, target),
(NetworkDomainState::Active, NetworkDomainState::Draining)
| (NetworkDomainState::Draining, NetworkDomainState::Active)
| (NetworkDomainState::Draining, NetworkDomainState::Released)
)
}
}
/// PMI flavour used for a launch. `Pmi2` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PmiMode {
    #[default]
    Pmi2,
    Pmix,
}
/// CXI credentials attached to a launch.
///
/// Note that the derived `Debug` and `Serialize` implementations include
/// `auth_key` verbatim.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CxiCredentials {
    /// NOTE(review): presumably a virtual network identifier (see
    /// `NetworkDomain::vni`) — confirm.
    pub vni: u32,
    /// Authentication key bytes.
    pub auth_key: Vec<u8>,
    /// Service id.
    pub svc_id: u32,
}
/// Information about a peer node in a launch: how to reach it and which
/// ranks it hosts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerInfo {
    /// Peer node.
    pub node_id: NodeId,
    /// gRPC address of the peer.
    pub grpc_address: String,
    /// First rank hosted on the peer.
    pub first_rank: u32,
    /// Number of ranks hosted on the peer.
    pub num_ranks: u32,
}
/// The block of ranks assigned to one node.
///
/// As produced by `RankLayout::compute`, the node hosts `num_ranks`
/// consecutive ranks starting at `first_rank`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeRankAssignment {
    /// Node receiving the ranks.
    pub node_id: NodeId,
    /// First rank on the node.
    pub first_rank: u32,
    /// Number of ranks on the node.
    pub num_ranks: u32,
}
/// Distribution of ranks over nodes; built by `RankLayout::compute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankLayout {
    /// Total number of ranks across all nodes.
    pub total_ranks: u32,
    /// Number of ranks placed on each node.
    pub tasks_per_node: u32,
    /// Per-node rank blocks, in node order.
    pub node_assignments: Vec<NodeRankAssignment>,
}
impl RankLayout {
pub fn compute(nodes: &[NodeId], tasks_per_node: u32) -> Self {
let mut assignments = Vec::with_capacity(nodes.len());
let mut rank = 0u32;
for node_id in nodes {
assignments.push(NodeRankAssignment {
node_id: node_id.clone(),
first_rank: rank,
num_ranks: tasks_per_node,
});
rank += tasks_per_node;
}
RankLayout {
total_ranks: rank,
tasks_per_node,
node_assignments: assignments,
}
}
}
/// Exit status of a single rank.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankExitStatus {
    /// Rank number.
    pub rank: u32,
    /// Exit code, if one was recorded.
    pub exit_code: Option<i32>,
    /// Signal number, if one was recorded.
    pub signal: Option<i32>,
}
/// Everything describing one launch within an allocation: what to run, with
/// which arguments and environment, and how ranks are laid out.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LaunchSpec {
    /// Identifier of the launch.
    pub launch_id: LaunchId,
    /// Allocation the launch belongs to.
    pub allocation_id: AllocId,
    /// Entrypoint string.
    pub entrypoint: String,
    /// Arguments.
    pub args: Vec<String>,
    /// Environment variables.
    pub env: HashMap<String, String>,
    /// Rank distribution over nodes.
    pub rank_layout: RankLayout,
    /// PMI flavour to use.
    pub pmi_mode: PmiMode,
    /// Optional CXI credentials.
    pub cxi_credentials: Option<CxiCredentials>,
}
pub fn apply_env_patches(patches: &[EnvPatch], env: &mut HashMap<String, String>) {
for patch in patches {
match patch.op {
EnvOp::Set => {
env.insert(patch.variable.clone(), patch.value.clone());
}
EnvOp::Unset => {
env.remove(&patch.variable);
}
EnvOp::Prepend => {
let existing = env.get(&patch.variable).cloned().unwrap_or_default();
let new_val = if existing.is_empty() {
patch.value.clone()
} else {
format!("{}{}{}", patch.value, patch.separator, existing)
};
env.insert(patch.variable.clone(), new_val);
}
EnvOp::Append => {
let existing = env.get(&patch.variable).cloned().unwrap_or_default();
let new_val = if existing.is_empty() {
patch.value.clone()
} else {
format!("{}{}{}", existing, patch.separator, patch.value)
};
env.insert(patch.variable.clone(), new_val);
}
}
}
}
/// Verifies that no mount target is repeated and that none lies inside
/// another.
///
/// Targets are compared as plain strings (no normalisation is applied). A
/// target is considered to contain another when it is a prefix of it on a
/// path boundary — i.e. the remainder starts with `/` or the shorter target
/// itself ends with `/` — so `/opt` contains `/opt/env` but not `/optional`.
///
/// # Errors
///
/// Returns the first problem found while scanning targets from shortest to
/// longest: `"duplicate mount target: …"` for identical targets, or
/// `"mount target overlap: … shadows …"` for nested ones.
pub fn check_mount_overlap(targets: &[&str]) -> Result<(), String> {
    // Stable sort by length: any containing path comes before the paths it
    // contains, and equal-length targets keep their input order.
    let mut by_len = targets.to_vec();
    by_len.sort_by_key(|path| path.len());

    for (idx, &outer) in by_len.iter().enumerate() {
        for &inner in &by_len[idx + 1..] {
            if outer == inner {
                return Err(format!("duplicate mount target: {outer}"));
            }
            let nested = match inner.strip_prefix(outer) {
                Some(tail) => tail.starts_with('/') || outer.ends_with('/'),
                None => false,
            };
            if nested {
                return Err(format!("mount target overlap: {outer} shadows {inner}"));
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn environment_default_has_empty_vecs_and_false_bools() {
let env = Environment::default();
assert!(env.images.is_empty());
assert!(env.env_patches.is_empty());
assert!(env.devices.is_empty());
assert!(env.mounts.is_empty());
assert!(env.container.is_none());
assert!(!env.writable);
assert!(!env.sign_required);
assert!(!env.scan_required);
assert!(!env.approved_bases_only);
}
#[test]
fn image_ref_default() {
let img = ImageRef::default();
assert!(img.spec.is_empty());
assert_eq!(img.image_type, ImageType::Uenv);
assert!(img.sha256.is_empty());
assert_eq!(img.size_bytes, 0);
assert!(!img.resolve_on_schedule);
}
#[test]
fn image_ref_deferred_has_empty_sha256() {
let img = ImageRef {
spec: "prgenv-gnu/24.11:v1".into(),
resolve_on_schedule: true,
..ImageRef::default()
};
assert!(img.sha256.is_empty());
assert!(img.resolve_on_schedule);
}
#[test]
fn env_patch_default_separator_is_colon() {
let patch = EnvPatch::default();
assert_eq!(patch.separator, ":");
assert_eq!(patch.op, EnvOp::Set);
}
#[test]
fn container_spec_default() {
let cs = ContainerSpec::default();
assert!(cs.base_environments.is_empty());
assert!(cs.mounts.is_empty());
assert!(cs.devices.is_empty());
assert!(cs.workdir.is_empty());
assert!(!cs.writable);
}
#[test]
fn mount_spec_default() {
let ms = MountSpec::default();
assert!(ms.source.is_empty());
assert!(ms.target.is_empty());
assert!(ms.options.is_empty());
}
#[test]
fn view_def_default() {
let vd = ViewDef::default();
assert!(vd.name.is_empty());
assert!(vd.patches.is_empty());
}
#[test]
fn image_metadata_default() {
let im = ImageMetadata::default();
assert!(im.views.is_empty());
assert!(im.default_view.is_none());
}
#[test]
fn apply_env_patches_set() {
let patches = vec![EnvPatch {
variable: "FOO".into(),
op: EnvOp::Set,
value: "bar".into(),
separator: ":".into(),
}];
let mut env = HashMap::new();
apply_env_patches(&patches, &mut env);
assert_eq!(env.get("FOO").unwrap(), "bar");
}
#[test]
fn apply_env_patches_unset() {
let patches = vec![EnvPatch {
variable: "FOO".into(),
op: EnvOp::Unset,
value: String::new(),
separator: ":".into(),
}];
let mut env = HashMap::new();
env.insert("FOO".into(), "bar".into());
apply_env_patches(&patches, &mut env);
assert!(!env.contains_key("FOO"));
}
#[test]
fn apply_env_patches_prepend() {
let patches = vec![EnvPatch {
variable: "PATH".into(),
op: EnvOp::Prepend,
value: "/new/bin".into(),
separator: ":".into(),
}];
let mut env = HashMap::new();
env.insert("PATH".into(), "/usr/bin".into());
apply_env_patches(&patches, &mut env);
assert_eq!(env.get("PATH").unwrap(), "/new/bin:/usr/bin");
}
#[test]
fn apply_env_patches_append() {
let patches = vec![EnvPatch {
variable: "PATH".into(),
op: EnvOp::Append,
value: "/extra/bin".into(),
separator: ":".into(),
}];
let mut env = HashMap::new();
env.insert("PATH".into(), "/usr/bin".into());
apply_env_patches(&patches, &mut env);
assert_eq!(env.get("PATH").unwrap(), "/usr/bin:/extra/bin");
}
#[test]
fn apply_env_patches_prepend_to_empty() {
let patches = vec![EnvPatch {
variable: "NEW_VAR".into(),
op: EnvOp::Prepend,
value: "/opt/lib".into(),
separator: ":".into(),
}];
let mut env = HashMap::new();
apply_env_patches(&patches, &mut env);
assert_eq!(env.get("NEW_VAR").unwrap(), "/opt/lib");
}
#[test]
fn apply_env_patches_declaration_order() {
let patches = vec![
EnvPatch {
variable: "X".into(),
op: EnvOp::Set,
value: "first".into(),
separator: ":".into(),
},
EnvPatch {
variable: "X".into(),
op: EnvOp::Set,
value: "second".into(),
separator: ":".into(),
},
];
let mut env = HashMap::new();
apply_env_patches(&patches, &mut env);
assert_eq!(env.get("X").unwrap(), "second");
}
#[test]
fn check_mount_overlap_no_overlap() {
assert!(check_mount_overlap(&["/opt", "/usr", "/mnt/data"]).is_ok());
}
#[test]
fn check_mount_overlap_prefix_shadow() {
let result = check_mount_overlap(&["/opt", "/opt/env"]);
assert!(result.is_err());
assert!(result.unwrap_err().contains("shadows"));
}
#[test]
fn check_mount_overlap_duplicate() {
let result = check_mount_overlap(&["/opt", "/opt"]);
assert!(result.is_err());
assert!(result.unwrap_err().contains("duplicate"));
}
#[test]
fn check_mount_overlap_no_false_positive_on_shared_prefix() {
assert!(check_mount_overlap(&["/opt", "/optional"]).is_ok());
}
#[test]
fn check_mount_overlap_empty_list() {
assert!(check_mount_overlap(&[]).is_ok());
}
#[test]
fn pending_can_transition_to_staging() {
assert!(AllocationState::Pending.can_transition_to(&AllocationState::Staging));
}
#[test]
fn pending_can_transition_to_running() {
assert!(AllocationState::Pending.can_transition_to(&AllocationState::Running));
}
#[test]
fn pending_can_transition_to_cancelled() {
assert!(AllocationState::Pending.can_transition_to(&AllocationState::Cancelled));
}
#[test]
fn pending_can_transition_to_failed() {
assert!(AllocationState::Pending.can_transition_to(&AllocationState::Failed));
}
#[test]
fn pending_cannot_transition_to_completed() {
assert!(!AllocationState::Pending.can_transition_to(&AllocationState::Completed));
}
#[test]
fn pending_cannot_transition_to_checkpointing() {
assert!(!AllocationState::Pending.can_transition_to(&AllocationState::Checkpointing));
}
#[test]
fn staging_can_transition_to_running() {
assert!(AllocationState::Staging.can_transition_to(&AllocationState::Running));
}
#[test]
fn staging_can_transition_to_failed() {
assert!(AllocationState::Staging.can_transition_to(&AllocationState::Failed));
}
#[test]
fn staging_can_transition_to_cancelled() {
assert!(AllocationState::Staging.can_transition_to(&AllocationState::Cancelled));
}
#[test]
fn running_can_transition_to_checkpointing() {
assert!(AllocationState::Running.can_transition_to(&AllocationState::Checkpointing));
}
#[test]
fn running_can_transition_to_completed() {
assert!(AllocationState::Running.can_transition_to(&AllocationState::Completed));
}
#[test]
fn running_can_transition_to_failed() {
assert!(AllocationState::Running.can_transition_to(&AllocationState::Failed));
}
#[test]
fn running_can_transition_to_cancelled() {
assert!(AllocationState::Running.can_transition_to(&AllocationState::Cancelled));
}
#[test]
fn running_cannot_transition_to_pending() {
assert!(!AllocationState::Running.can_transition_to(&AllocationState::Pending));
}
#[test]
fn checkpointing_can_transition_to_suspended() {
assert!(AllocationState::Checkpointing.can_transition_to(&AllocationState::Suspended));
}
#[test]
fn checkpointing_can_transition_to_failed() {
assert!(AllocationState::Checkpointing.can_transition_to(&AllocationState::Failed));
}
#[test]
fn checkpointing_can_transition_to_cancelled() {
assert!(AllocationState::Checkpointing.can_transition_to(&AllocationState::Cancelled));
}
#[test]
fn suspended_can_transition_to_pending() {
assert!(AllocationState::Suspended.can_transition_to(&AllocationState::Pending));
}
#[test]
fn suspended_can_transition_to_cancelled() {
assert!(AllocationState::Suspended.can_transition_to(&AllocationState::Cancelled));
}
#[test]
fn suspended_can_transition_to_failed() {
assert!(AllocationState::Suspended.can_transition_to(&AllocationState::Failed));
}
#[test]
fn completed_is_terminal() {
assert!(AllocationState::Completed.is_terminal());
}
#[test]
fn failed_is_terminal() {
assert!(AllocationState::Failed.is_terminal());
}
#[test]
fn cancelled_is_terminal() {
assert!(AllocationState::Cancelled.is_terminal());
}
#[test]
fn running_is_not_terminal() {
assert!(!AllocationState::Running.is_terminal());
}
#[test]
fn terminal_states_cannot_transition_to_anything_except_failed_to_pending() {
let all_states = [
AllocationState::Pending,
AllocationState::Staging,
AllocationState::Running,
AllocationState::Checkpointing,
AllocationState::Suspended,
AllocationState::Completed,
AllocationState::Failed,
AllocationState::Cancelled,
];
for terminal in &[AllocationState::Completed, AllocationState::Cancelled] {
for target in &all_states {
assert!(
!terminal.can_transition_to(target),
"{terminal:?} should not transition to {target:?}"
);
}
}
assert!(AllocationState::Failed.can_transition_to(&AllocationState::Pending));
for target in &all_states {
if *target != AllocationState::Pending {
assert!(
!AllocationState::Failed.can_transition_to(target),
"Failed should not transition to {target:?}"
);
}
}
}
#[test]
fn unknown_can_transition_to_booting() {
assert!(NodeState::Unknown.can_transition_to(&NodeState::Booting));
}
#[test]
fn unknown_can_transition_to_failed() {
assert!(NodeState::Unknown.can_transition_to(&NodeState::Failed {
reason: "hw error".into()
}));
}
#[test]
fn booting_can_transition_to_ready() {
assert!(NodeState::Booting.can_transition_to(&NodeState::Ready));
}
#[test]
fn booting_can_transition_to_failed() {
assert!(NodeState::Booting.can_transition_to(&NodeState::Failed {
reason: "boot failed".into()
}));
}
#[test]
fn ready_can_transition_to_degraded() {
assert!(NodeState::Ready.can_transition_to(&NodeState::Degraded {
reason: "heartbeat missed".into()
}));
}
#[test]
fn ready_can_transition_to_draining() {
assert!(NodeState::Ready.can_transition_to(&NodeState::Draining));
}
#[test]
fn ready_can_transition_to_down() {
assert!(NodeState::Ready.can_transition_to(&NodeState::Down {
reason: "operator".into()
}));
}
#[test]
fn degraded_can_transition_to_ready() {
assert!(NodeState::Degraded {
reason: "fixed".into()
}
.can_transition_to(&NodeState::Ready));
}
#[test]
fn degraded_can_transition_to_down() {
assert!(NodeState::Degraded {
reason: "worsened".into()
}
.can_transition_to(&NodeState::Down {
reason: "confirmed".into()
}));
}
/// A `Degraded` node must be allowed to start `Draining`.
#[test]
fn degraded_can_transition_to_draining() {
    let degraded = NodeState::Degraded {
        reason: "draining".into(),
    };
    assert!(degraded.can_transition_to(&NodeState::Draining));
}
/// A `Down` node must be allowed to return to `Ready`.
#[test]
fn down_can_transition_to_ready() {
    let down = NodeState::Down {
        reason: "operator disabled".into(),
    };
    assert!(down.can_transition_to(&NodeState::Ready));
}
/// A `Down` node must be allowed to move to `Booting`.
#[test]
fn down_can_transition_to_booting() {
    let down = NodeState::Down {
        reason: "reimage".into(),
    };
    assert!(down.can_transition_to(&NodeState::Booting));
}
/// A `Down` node must be allowed to move to `Failed`.
#[test]
fn down_can_transition_to_failed() {
    let down = NodeState::Down {
        reason: "unrecoverable".into(),
    };
    let failed = NodeState::Failed {
        reason: "hw".into(),
    };
    assert!(down.can_transition_to(&failed));
}
/// A `Draining` node must be allowed to reach `Drained`.
#[test]
fn draining_can_transition_to_drained() {
    let (from, to) = (NodeState::Draining, NodeState::Drained);
    assert!(from.can_transition_to(&to));
}
/// A `Draining` node must not be allowed to go straight back to `Ready`.
#[test]
fn draining_cannot_transition_to_ready() {
    let (from, to) = (NodeState::Draining, NodeState::Ready);
    assert!(!from.can_transition_to(&to));
}
/// A `Drained` node must be allowed to return to `Ready`.
#[test]
fn drained_can_transition_to_ready() {
    let (from, to) = (NodeState::Drained, NodeState::Ready);
    assert!(from.can_transition_to(&to));
}
/// A `Drained` node must be allowed to move to `Booting`.
#[test]
fn drained_can_transition_to_booting() {
    let (from, to) = (NodeState::Drained, NodeState::Booting);
    assert!(from.can_transition_to(&to));
}
/// A `Failed` node must be allowed to move to `Booting`.
#[test]
fn failed_node_can_transition_to_booting() {
    let failed = NodeState::Failed {
        reason: "reimage".into(),
    };
    assert!(failed.can_transition_to(&NodeState::Booting));
}
/// `Ready` must report itself as operational.
#[test]
fn ready_is_operational() {
    let state = NodeState::Ready;
    assert!(state.is_operational());
}
/// `Degraded` must still report itself as operational.
#[test]
fn degraded_is_operational() {
    let state = NodeState::Degraded {
        reason: "minor".into(),
    };
    assert!(state.is_operational());
}
/// `Down` must not report itself as operational.
#[test]
fn down_is_not_operational() {
    let state = NodeState::Down {
        reason: "down".into(),
    };
    assert!(!state.is_operational());
}
/// `Draining` must not report itself as operational.
#[test]
fn draining_is_not_operational() {
    let state = NodeState::Draining;
    assert!(!state.is_operational());
}
/// An `Active` network domain must be allowed to start `Draining`.
#[test]
fn active_can_transition_to_draining() {
    let (from, to) = (NetworkDomainState::Active, NetworkDomainState::Draining);
    assert!(from.can_transition_to(&to));
}
/// An `Active` network domain must not jump directly to `Released`.
#[test]
fn active_cannot_transition_to_released() {
    let (from, to) = (NetworkDomainState::Active, NetworkDomainState::Released);
    assert!(!from.can_transition_to(&to));
}
/// A `Draining` network domain must be allowed to return to `Active`.
#[test]
fn draining_can_transition_to_active() {
    let (from, to) = (NetworkDomainState::Draining, NetworkDomainState::Active);
    assert!(from.can_transition_to(&to));
}
/// A `Draining` network domain must be allowed to become `Released`.
#[test]
fn draining_can_transition_to_released() {
    let (from, to) = (NetworkDomainState::Draining, NetworkDomainState::Released);
    assert!(from.can_transition_to(&to));
}
/// A `Released` network domain must reject `Active`, `Draining`, and `Released` as targets.
#[test]
fn released_cannot_transition_to_anything() {
    let released = NetworkDomainState::Released;
    // Same targets, in the same order, as three individual assertions would use.
    let targets = [
        NetworkDomainState::Active,
        NetworkDomainState::Draining,
        NetworkDomainState::Released,
    ];
    for target in targets.iter() {
        assert!(!released.can_transition_to(target));
    }
}
/// Serializes each `AllocationState` variant to JSON and checks that parsing the
/// JSON yields an equal value.
#[test]
fn allocation_state_serde_roundtrip() {
    let variants = [
        AllocationState::Pending,
        AllocationState::Staging,
        AllocationState::Running,
        AllocationState::Checkpointing,
        AllocationState::Suspended,
        AllocationState::Completed,
        AllocationState::Failed,
        AllocationState::Cancelled,
    ];
    for state in variants.iter() {
        let encoded = serde_json::to_string(state).unwrap();
        let decoded = serde_json::from_str::<AllocationState>(&encoded).unwrap();
        assert_eq!(*state, decoded, "roundtrip failed for {state:?}");
    }
}
/// Serializes each `NodeState` variant (including the ones carrying a `reason`)
/// to JSON and checks that parsing the JSON yields an equal value.
#[test]
fn node_state_serde_roundtrip() {
    let variants = [
        NodeState::Unknown,
        NodeState::Booting,
        NodeState::Ready,
        NodeState::Degraded {
            reason: "test".into(),
        },
        NodeState::Down {
            reason: "test".into(),
        },
        NodeState::Draining,
        NodeState::Drained,
        NodeState::Failed {
            reason: "test".into(),
        },
    ];
    for state in variants.iter() {
        let encoded = serde_json::to_string(state).unwrap();
        let decoded = serde_json::from_str::<NodeState>(&encoded).unwrap();
        assert_eq!(*state, decoded, "roundtrip failed for {state:?}");
    }
}
/// The nine weights of `CostWeights::default()` must add up to 1.0 within `1e-10`.
#[test]
fn cost_weights_default_sums_to_one() {
    let weights = CostWeights::default();
    // Summed left to right in this order so the floating-point result is reproducible.
    let terms = [
        weights.priority,
        weights.wait_time,
        weights.fair_share,
        weights.topology,
        weights.data_readiness,
        weights.backlog,
        weights.energy,
        weights.checkpoint_efficiency,
        weights.conformance,
    ];
    let sum: f64 = terms.iter().sum();
    let deviation = (sum - 1.0).abs();
    assert!(
        deviation < 1e-10,
        "default weights sum to {sum}, expected 1.0"
    );
}
/// Round-trips `CostWeights::default()` through JSON and checks every weight.
///
/// All nine weight fields are compared, so a serde attribute that renames or drops
/// a single field cannot slip through unnoticed. Values are compared with an
/// `f64::EPSILON` tolerance rather than exact equality.
#[test]
fn cost_weights_serde_roundtrip() {
    let w = CostWeights::default();
    let json = serde_json::to_string(&w).unwrap();
    let deser: CostWeights = serde_json::from_str(&json).unwrap();

    // (field name, value before serialization, value after deserialization)
    let fields = [
        ("priority", w.priority, deser.priority),
        ("wait_time", w.wait_time, deser.wait_time),
        ("fair_share", w.fair_share, deser.fair_share),
        ("topology", w.topology, deser.topology),
        ("data_readiness", w.data_readiness, deser.data_readiness),
        ("backlog", w.backlog, deser.backlog),
        ("energy", w.energy, deser.energy),
        (
            "checkpoint_efficiency",
            w.checkpoint_efficiency,
            deser.checkpoint_efficiency,
        ),
        ("conformance", w.conformance, deser.conformance),
    ];
    for (name, before, after) in fields {
        assert!(
            (after - before).abs() < f64::EPSILON,
            "{name} changed across roundtrip: {before} -> {after}"
        );
    }
}
/// Serializes each listed `MemoryDomainType` variant to JSON and checks that
/// parsing the JSON yields an equal value.
#[test]
fn memory_domain_type_serde_roundtrip() {
    let variants = [
        MemoryDomainType::Dram,
        MemoryDomainType::Hbm,
        MemoryDomainType::CxlAttached,
        MemoryDomainType::Unified,
    ];
    for t in variants.iter() {
        let encoded = serde_json::to_string(t).unwrap();
        let decoded = serde_json::from_str::<MemoryDomainType>(&encoded).unwrap();
        assert_eq!(*t, decoded, "roundtrip failed for {t:?}");
    }
}
/// Serializes each listed `MemoryLinkType` variant to JSON and checks that
/// parsing the JSON yields an equal value.
#[test]
fn memory_link_type_serde_roundtrip() {
    let variants = [
        MemoryLinkType::NumaLink,
        MemoryLinkType::CxlSwitch,
        MemoryLinkType::CoherentFabric,
    ];
    for t in variants.iter() {
        let encoded = serde_json::to_string(t).unwrap();
        let decoded = serde_json::from_str::<MemoryLinkType>(&encoded).unwrap();
        assert_eq!(*t, decoded, "roundtrip failed for {t:?}");
    }
}
/// Serializes each listed `MemoryPolicy` variant to JSON and checks that
/// parsing the JSON yields an equal value.
#[test]
fn memory_policy_serde_roundtrip() {
    let variants = [
        MemoryPolicy::Local,
        MemoryPolicy::Interleave,
        MemoryPolicy::Preferred,
        MemoryPolicy::Bind,
    ];
    for p in variants.iter() {
        let encoded = serde_json::to_string(p).unwrap();
        let decoded = serde_json::from_str::<MemoryPolicy>(&encoded).unwrap();
        assert_eq!(*p, decoded, "roundtrip failed for {p:?}");
    }
}
/// Round-trips a two-domain, one-interconnect `MemoryTopology` through JSON and
/// checks the domain count, interconnect count, and total capacity afterwards.
#[test]
fn memory_topology_serde_roundtrip() {
    // Domain 0: DRAM with a NUMA node and attached CPUs/GPUs.
    let dram = MemoryDomain {
        id: 0,
        domain_type: MemoryDomainType::Dram,
        capacity_bytes: 128 * 1024 * 1024 * 1024,
        numa_node: Some(0),
        attached_cpus: vec![0, 1, 2, 3],
        attached_gpus: vec![0, 1],
    };
    // Domain 1: CXL-attached memory with no NUMA node and nothing attached.
    let cxl = MemoryDomain {
        id: 1,
        domain_type: MemoryDomainType::CxlAttached,
        capacity_bytes: 256 * 1024 * 1024 * 1024,
        numa_node: None,
        attached_cpus: vec![],
        attached_gpus: vec![],
    };
    // Single link joining domain 0 and domain 1.
    let link = MemoryInterconnect {
        domain_a: 0,
        domain_b: 1,
        link_type: MemoryLinkType::CxlSwitch,
        bandwidth_gbps: 64.0,
        latency_ns: 200,
    };
    let topo = MemoryTopology {
        domains: vec![dram, cxl],
        interconnects: vec![link],
        total_capacity_bytes: 384 * 1024 * 1024 * 1024,
    };

    let encoded = serde_json::to_string(&topo).unwrap();
    let restored = serde_json::from_str::<MemoryTopology>(&encoded).unwrap();

    assert_eq!(restored.domains.len(), 2);
    assert_eq!(restored.interconnects.len(), 1);
    assert_eq!(restored.total_capacity_bytes, topo.total_capacity_bytes);
}
/// `ResourceConstraints::default()` must leave the three memory flags unset and
/// `memory_policy` as `None`.
#[test]
fn resource_constraints_default_has_memory_fields() {
    let ResourceConstraints {
        require_unified_memory,
        prefer_same_numa,
        allow_cxl_memory,
        memory_policy,
        ..
    } = ResourceConstraints::default();
    assert!(!require_unified_memory);
    assert!(!prefer_same_numa);
    assert!(!allow_cxl_memory);
    assert!(memory_policy.is_none());
}
/// Round-trips a `ResourceConstraints` with non-default memory settings through
/// JSON and checks that the four memory-related fields come back unchanged.
#[test]
fn resource_constraints_memory_serde_roundtrip() {
    let original = ResourceConstraints {
        require_unified_memory: true,
        prefer_same_numa: true,
        allow_cxl_memory: false,
        memory_policy: Some(MemoryPolicy::Interleave),
        ..ResourceConstraints::default()
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let restored = serde_json::from_str::<ResourceConstraints>(&encoded).unwrap();

    assert!(restored.require_unified_memory);
    assert!(restored.prefer_same_numa);
    assert!(!restored.allow_cxl_memory);
    assert_eq!(restored.memory_policy, Some(MemoryPolicy::Interleave));
}
/// Property-based counterparts of the unit tests above, driven by `proptest`.
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy that yields one of the eight `AllocationState` variants.
    fn arb_allocation_state() -> impl Strategy<Value = AllocationState> {
        prop_oneof![
            Just(AllocationState::Pending),
            Just(AllocationState::Staging),
            Just(AllocationState::Running),
            Just(AllocationState::Checkpointing),
            Just(AllocationState::Suspended),
            Just(AllocationState::Completed),
            Just(AllocationState::Failed),
            Just(AllocationState::Cancelled),
        ]
    }

    /// Strategy that yields one of the four listed `MemoryDomainType` variants.
    fn arb_memory_domain_type() -> impl Strategy<Value = MemoryDomainType> {
        prop_oneof![
            Just(MemoryDomainType::Dram),
            Just(MemoryDomainType::Hbm),
            Just(MemoryDomainType::CxlAttached),
            Just(MemoryDomainType::Unified),
        ]
    }

    /// Strategy that yields one of the four listed `MemoryPolicy` variants.
    fn arb_memory_policy() -> impl Strategy<Value = MemoryPolicy> {
        prop_oneof![
            Just(MemoryPolicy::Local),
            Just(MemoryPolicy::Interleave),
            Just(MemoryPolicy::Preferred),
            Just(MemoryPolicy::Bind),
        ]
    }

    proptest! {
        // JSON round trip preserves any generated MemoryDomainType.
        #[test]
        fn memory_domain_type_roundtrip(dt in arb_memory_domain_type()) {
            let encoded = serde_json::to_string(&dt).unwrap();
            let decoded = serde_json::from_str::<MemoryDomainType>(&encoded).unwrap();
            prop_assert_eq!(dt, decoded);
        }

        // JSON round trip preserves any generated MemoryPolicy.
        #[test]
        fn memory_policy_roundtrip(p in arb_memory_policy()) {
            let encoded = serde_json::to_string(&p).unwrap();
            let decoded = serde_json::from_str::<MemoryPolicy>(&encoded).unwrap();
            prop_assert_eq!(p, decoded);
        }

        // Completed and Cancelled reject every target; Failed rejects every target
        // other than Pending (the Pending case itself is not asserted here).
        #[test]
        fn terminal_states_block_all_transitions(target in arb_allocation_state()) {
            let sealed = [AllocationState::Completed, AllocationState::Cancelled];
            for terminal in sealed.iter() {
                prop_assert!(!terminal.can_transition_to(&target));
            }
            let is_pending = target == AllocationState::Pending;
            if !is_pending {
                prop_assert!(!AllocationState::Failed.can_transition_to(&target));
            }
        }

        // No allocation state may transition to itself.
        #[test]
        fn no_self_transitions(state in arb_allocation_state()) {
            let loops_back = state.can_transition_to(&state);
            prop_assert!(!loops_back);
        }

        // Builds a NodeCount::Range from an ordered (min, max) pair.
        // NOTE(review): this only constructs the value; nothing about it is asserted.
        #[test]
        fn node_count_range_min_le_max(min in 1u32..1000, max in 1u32..1000) {
            prop_assume!(max >= min);
            let _range = NodeCount::Range { min, max };
        }
    }
}
}