Skip to main content

consortium_nix/
config.rs

1//! Configuration types for NixOS deployment.
2//!
3//! These types map to the JSON produced by the Nix library's `mkFleet` function.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8
9/// Profile type for a deployment target.
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "kebab-case")]
12pub enum ProfileType {
13    Nixos,
14    NixDarwin,
15}
16
17/// A single deployment target node.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(rename_all = "camelCase")]
20pub struct DeploymentNode {
21    /// Node name (matches nixosConfigurations key).
22    pub name: String,
23    /// Target host (hostname or IP).
24    pub target_host: String,
25    /// SSH user for deployment.
26    pub target_user: String,
27    /// SSH port (None = default 22).
28    pub target_port: Option<u16>,
29    /// System architecture (e.g. "x86_64-linux").
30    pub system: String,
31    /// Profile type.
32    pub profile_type: ProfileType,
33    /// Whether to build on the target itself.
34    pub build_on_target: bool,
35    /// Tags for group-based selection.
36    pub tags: Vec<String>,
37    /// Derivation path for the system toplevel.
38    #[serde(default)]
39    pub drv_path: Option<String>,
40    /// Store path for the system toplevel (after build).
41    #[serde(default)]
42    pub toplevel: Option<String>,
43}
44
45/// A remote builder machine.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47#[serde(rename_all = "camelCase")]
48pub struct Builder {
49    /// Builder hostname or IP.
50    pub host: String,
51    /// SSH user for builder access.
52    pub user: String,
53    /// Maximum concurrent build jobs.
54    pub max_jobs: u32,
55    /// Speed factor (higher = preferred).
56    pub speed_factor: u32,
57    /// Supported system types.
58    pub systems: Vec<String>,
59    /// Supported build features.
60    pub features: Vec<String>,
61    /// Path to SSH identity file for builder access.
62    pub ssh_key: Option<String>,
63    /// SSH protocol (e.g. "ssh-ng").
64    #[serde(default = "default_protocol")]
65    pub protocol: String,
66}
67
/// Serde default for a builder's `protocol` field: `"ssh-ng"`.
fn default_protocol() -> String {
    String::from("ssh-ng")
}
71
72/// Complete fleet configuration produced by Nix evaluation.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FleetConfig {
75    /// Deployment nodes keyed by hostname.
76    pub nodes: HashMap<String, DeploymentNode>,
77    /// Available remote builders keyed by hostname.
78    #[serde(default)]
79    pub builders: HashMap<String, Builder>,
80    /// Flake URI for builds (e.g. "." or "github:user/repo").
81    #[serde(default = "default_flake_uri")]
82    pub flake_uri: String,
83    /// Ansible configuration (if ansible integration is enabled).
84    #[serde(default, rename = "ansibleConfig")]
85    pub ansible_config: Option<AnsibleFleetConfig>,
86    /// Slurm configuration (if slurm integration is enabled).
87    #[serde(default, rename = "slurmConfig")]
88    pub slurm_config: Option<SlurmFleetConfig>,
89    /// Ray configuration (if ray integration is enabled).
90    #[serde(default, rename = "rayConfig")]
91    pub ray_config: Option<RayFleetConfig>,
92    /// SkyPilot configuration (if skypilot integration is enabled).
93    #[serde(default, rename = "skypilotConfig")]
94    pub skypilot_config: Option<SkypilotFleetConfig>,
95}
96
97/// Ansible fleet configuration.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99#[serde(rename_all = "camelCase")]
100pub struct AnsibleFleetConfig {
101    /// Host that runs ansible-playbook (the control node).
102    pub control_node: String,
103    /// Pinned ansible version (e.g. "2.16").
104    #[serde(default)]
105    pub ansible_version: Option<String>,
106    /// Ansible collections to include in the environment.
107    #[serde(default)]
108    pub collections: Vec<String>,
109    /// Path to playbooks directory.
110    #[serde(default)]
111    pub playbook_dir: Option<String>,
112    /// Additional host group assignments beyond tags.
113    #[serde(default)]
114    pub host_groups: HashMap<String, Vec<String>>,
115}
116
117/// Slurm fleet configuration.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119#[serde(rename_all = "camelCase")]
120pub struct SlurmFleetConfig {
121    /// Host for sbatch submission.
122    pub submit_node: String,
123    /// SSH user for submission.
124    pub submit_user: String,
125    /// Host running slurmctld.
126    pub control_node: String,
127    /// Partition definitions.
128    #[serde(default)]
129    pub partitions: HashMap<String, SlurmPartition>,
130}
131
132/// Slurm partition definition.
133#[derive(Debug, Clone, Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct SlurmPartition {
136    /// Nodes in this partition.
137    pub nodes: Vec<String>,
138    /// Whether this is the default partition.
139    #[serde(default)]
140    pub default: bool,
141    /// Maximum job time (e.g. "7-00:00:00").
142    #[serde(default)]
143    pub max_time: Option<String>,
144}
145
146/// Ray fleet configuration.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct RayFleetConfig {
150    /// Ray head node (or dashboard address for K8s).
151    pub head_address: String,
152    /// Ray dashboard port.
153    #[serde(default = "default_ray_port")]
154    pub dashboard_port: u16,
155    /// Whether ray is running on K8s (KubeRay) vs bare metal.
156    #[serde(default)]
157    pub kubernetes: bool,
158    /// Worker groups with their node assignments.
159    #[serde(default)]
160    pub worker_groups: HashMap<String, RayWorkerGroup>,
161}
162
/// Serde default for the Ray `dashboard_port` field: `8265`.
fn default_ray_port() -> u16 {
    const RAY_DASHBOARD_PORT: u16 = 8265;
    RAY_DASHBOARD_PORT
}
166
167/// Ray worker group definition.
168#[derive(Debug, Clone, Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct RayWorkerGroup {
171    /// Nodes in this group.
172    pub nodes: Vec<String>,
173    /// CPUs per worker.
174    #[serde(default)]
175    pub cpus: Option<u32>,
176    /// GPUs per worker.
177    #[serde(default)]
178    pub gpus: Option<u32>,
179    /// Memory per worker (MB).
180    #[serde(default)]
181    pub memory_mb: Option<u32>,
182}
183
184/// SkyPilot fleet configuration.
185#[derive(Debug, Clone, Serialize, Deserialize)]
186#[serde(rename_all = "camelCase")]
187pub struct SkypilotFleetConfig {
188    /// Default cloud provider.
189    pub cloud: String,
190    /// Default region.
191    #[serde(default)]
192    pub region: Option<String>,
193    /// Default instance type.
194    #[serde(default)]
195    pub instance_type: Option<String>,
196}
197
/// Serde default for the fleet's `flake_uri` field: `"."`.
fn default_flake_uri() -> String {
    String::from(".")
}
201
202impl FleetConfig {
203    /// Load fleet configuration from a JSON file.
204    pub fn from_file(path: &Path) -> Result<Self, ConfigError> {
205        let content =
206            std::fs::read_to_string(path).map_err(|e| ConfigError::Io(path.to_path_buf(), e))?;
207        serde_json::from_str(&content).map_err(|e| ConfigError::Parse(path.to_path_buf(), e))
208    }
209
210    /// Load fleet configuration from a JSON string.
211    pub fn from_json(json: &str) -> Result<Self, ConfigError> {
212        serde_json::from_str(json).map_err(|e| ConfigError::Parse(PathBuf::from("<string>"), e))
213    }
214
215    /// Get nodes matching a set of tags (any match).
216    pub fn nodes_by_tags(&self, tags: &[String]) -> Vec<&DeploymentNode> {
217        self.nodes
218            .values()
219            .filter(|n| n.tags.iter().any(|t| tags.contains(t)))
220            .collect()
221    }
222
223    /// Get nodes matching a list of names (supports consortium NodeSet patterns).
224    pub fn nodes_by_names(&self, names: &[String]) -> Vec<&DeploymentNode> {
225        self.nodes
226            .values()
227            .filter(|n| names.contains(&n.name))
228            .collect()
229    }
230
231    /// Get all node names as a sorted vector.
232    pub fn node_names(&self) -> Vec<String> {
233        let mut names: Vec<_> = self.nodes.keys().cloned().collect();
234        names.sort();
235        names
236    }
237
238    /// Get all builder names as a sorted vector.
239    pub fn builder_names(&self) -> Vec<String> {
240        let mut names: Vec<_> = self.builders.keys().cloned().collect();
241        names.sort();
242        names
243    }
244
245    /// Generate a Nix machines file string from the builder pool.
246    pub fn machines_file(&self) -> String {
247        self.builders
248            .values()
249            .map(|b| {
250                let key = b.ssh_key.as_deref().unwrap_or("-");
251                let features = b.features.join(",");
252                let systems = b.systems.join(",");
253                format!(
254                    "{}://{}@{} {} {} {} {} {}",
255                    b.protocol, b.user, b.host, systems, key, b.max_jobs, b.speed_factor, features
256                )
257            })
258            .collect::<Vec<_>>()
259            .join("\n")
260    }
261}
262
263/// Deployment action to perform on targets.
264#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
265#[serde(rename_all = "kebab-case")]
266pub enum DeployAction {
267    /// Activate and set as boot default.
268    Switch,
269    /// Set as boot default without activating.
270    Boot,
271    /// Activate without setting as boot default.
272    Test,
273    /// Check what would change without activating.
274    DryActivate,
275    /// Only build, don't deploy.
276    Build,
277}
278
279impl std::fmt::Display for DeployAction {
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        match self {
282            DeployAction::Switch => write!(f, "switch"),
283            DeployAction::Boot => write!(f, "boot"),
284            DeployAction::Test => write!(f, "test"),
285            DeployAction::DryActivate => write!(f, "dry-activate"),
286            DeployAction::Build => write!(f, "build"),
287        }
288    }
289}
290
291impl std::str::FromStr for DeployAction {
292    type Err = ConfigError;
293
294    fn from_str(s: &str) -> Result<Self, Self::Err> {
295        match s {
296            "switch" => Ok(DeployAction::Switch),
297            "boot" => Ok(DeployAction::Boot),
298            "test" => Ok(DeployAction::Test),
299            "dry-activate" => Ok(DeployAction::DryActivate),
300            "build" => Ok(DeployAction::Build),
301            _ => Err(ConfigError::InvalidAction(s.to_string())),
302        }
303    }
304}
305
306/// A single target in a deployment plan.
307#[derive(Debug, Clone)]
308pub struct DeploymentTarget {
309    /// The node to deploy to.
310    pub node: DeploymentNode,
311    /// Built toplevel store path.
312    pub toplevel_path: String,
313    /// Current system path on the target (if known).
314    pub current_system: Option<String>,
315    /// Whether the closure needs building.
316    pub needs_build: bool,
317    /// Whether the closure needs copying to the target.
318    pub needs_copy: bool,
319}
320
321/// A deployment plan describing what to do.
322#[derive(Debug, Clone)]
323pub struct DeploymentPlan {
324    /// Targets to deploy.
325    pub targets: Vec<DeploymentTarget>,
326    /// Action to perform.
327    pub action: DeployAction,
328    /// Maximum parallel operations (fanout).
329    pub max_parallel: usize,
330}
331
332impl DeploymentPlan {
333    /// Create a new empty deployment plan.
334    pub fn new(action: DeployAction, max_parallel: usize) -> Self {
335        Self {
336            targets: Vec::new(),
337            action,
338            max_parallel,
339        }
340    }
341
342    /// Number of targets that need building.
343    pub fn build_count(&self) -> usize {
344        self.targets.iter().filter(|t| t.needs_build).count()
345    }
346
347    /// Number of targets that need closure copying.
348    pub fn copy_count(&self) -> usize {
349        self.targets.iter().filter(|t| t.needs_copy).count()
350    }
351
352    /// Total number of targets.
353    pub fn target_count(&self) -> usize {
354        self.targets.len()
355    }
356}
357
358/// Errors from configuration loading.
359#[derive(Debug, thiserror::Error)]
360pub enum ConfigError {
361    #[error("failed to read config file {0}: {1}")]
362    Io(PathBuf, std::io::Error),
363    #[error("failed to parse config file {0}: {1}")]
364    Parse(PathBuf, serde_json::Error),
365    #[error("invalid deploy action: {0}")]
366    InvalidAction(String),
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON fixture: two nodes (`hp01`, `mm01`) and one builder (`hp01`).
    fn sample_config_json() -> &'static str {
        r#"{
            "nodes": {
                "hp01": {
                    "name": "hp01",
                    "targetHost": "192.168.1.121",
                    "targetUser": "root",
                    "targetPort": null,
                    "system": "x86_64-linux",
                    "profileType": "nixos",
                    "buildOnTarget": false,
                    "tags": ["build-host", "hpe"]
                },
                "mm01": {
                    "name": "mm01",
                    "targetHost": "192.168.1.111",
                    "targetUser": "root",
                    "targetPort": null,
                    "system": "x86_64-linux",
                    "profileType": "nixos",
                    "buildOnTarget": false,
                    "tags": ["ray"]
                }
            },
            "builders": {
                "hp01": {
                    "host": "192.168.1.121",
                    "user": "root",
                    "maxJobs": 16,
                    "speedFactor": 2,
                    "systems": ["x86_64-linux"],
                    "features": ["big-parallel", "kvm"],
                    "sshKey": null,
                    "protocol": "ssh-ng"
                }
            },
            "flakeUri": "."
        }"#
    }

    /// Parse the fixture; every test that needs a config goes through here.
    fn sample_config() -> FleetConfig {
        FleetConfig::from_json(sample_config_json()).unwrap()
    }

    /// A minimal untagged node used when building a plan by hand.
    fn sample_node() -> DeploymentNode {
        DeploymentNode {
            name: "hp01".to_string(),
            target_host: "192.168.1.121".to_string(),
            target_user: "root".to_string(),
            target_port: None,
            system: "x86_64-linux".to_string(),
            profile_type: ProfileType::Nixos,
            build_on_target: false,
            tags: vec![],
            drv_path: None,
            toplevel: None,
        }
    }

    #[test]
    fn test_parse_fleet_config() {
        let fleet = sample_config();
        assert_eq!(fleet.nodes.len(), 2);
        assert_eq!(fleet.builders.len(), 1);
        assert_eq!(fleet.flake_uri, ".");
    }

    #[test]
    fn test_node_fields() {
        let fleet = sample_config();
        let node = &fleet.nodes["hp01"];
        assert_eq!(node.target_host, "192.168.1.121");
        assert_eq!(node.target_user, "root");
        assert_eq!(node.profile_type, ProfileType::Nixos);
        assert!(!node.build_on_target);
        assert_eq!(node.tags, vec!["build-host", "hpe"]);
    }

    #[test]
    fn test_nodes_by_tags() {
        let fleet = sample_config();
        let wanted = ["build-host".to_string()];
        let build_hosts = fleet.nodes_by_tags(&wanted);
        assert_eq!(build_hosts.len(), 1);
        assert_eq!(build_hosts[0].name, "hp01");
    }

    #[test]
    fn test_node_names_sorted() {
        let names = sample_config().node_names();
        assert_eq!(names, vec!["hp01", "mm01"]);
    }

    #[test]
    fn test_machines_file() {
        let machines = sample_config().machines_file();
        for expected in [
            "ssh-ng://root@192.168.1.121",
            "x86_64-linux",
            "16",
            "big-parallel,kvm",
        ] {
            assert!(machines.contains(expected));
        }
    }

    #[test]
    fn test_deploy_action_display() {
        assert_eq!(DeployAction::Switch.to_string(), "switch");
        assert_eq!(DeployAction::DryActivate.to_string(), "dry-activate");
    }

    #[test]
    fn test_deploy_action_parse() {
        let switch: DeployAction = "switch".parse().unwrap();
        assert_eq!(switch, DeployAction::Switch);

        let dry_activate: DeployAction = "dry-activate".parse().unwrap();
        assert_eq!(dry_activate, DeployAction::DryActivate);

        assert!("invalid".parse::<DeployAction>().is_err());
    }

    #[test]
    fn test_deployment_plan() {
        let target = DeploymentTarget {
            node: sample_node(),
            toplevel_path: "/nix/store/abc-nixos-system".to_string(),
            current_system: Some("/nix/store/old-nixos-system".to_string()),
            needs_build: true,
            needs_copy: true,
        };

        let mut plan = DeploymentPlan::new(DeployAction::Switch, 4);
        plan.targets.push(target);

        assert_eq!(plan.build_count(), 1);
        assert_eq!(plan.copy_count(), 1);
        assert_eq!(plan.target_count(), 1);
    }
}