// vm_rs/vm/manager.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::RwLock;
4
5use super::disk;
6use super::platform::create_platform_driver;
7use crate::config::{VmConfig, VmHandle, VmState};
8use crate::driver::{VmDriver, VmError};
9
/// Manages all VMs on this node. Auto-selects the correct hypervisor driver.
///
/// Thread-safe: uses RwLock for concurrent access to VM state.
pub struct VmManager {
    // Hypervisor backend; chosen per-platform by `create_platform_driver`
    // (or injected directly via `with_driver` in tests).
    driver: Box<dyn VmDriver>,
    // In-memory cache of known VMs, keyed by VM name.
    vms: RwLock<HashMap<String, VmHandle>>,
    // Root directory under which each VM gets its own subdirectory.
    base_dir: PathBuf,
}
18
19impl VmManager {
20    /// Create a new VmManager with auto-detected driver.
21    ///
22    /// `base_dir` is the root directory for VM data (disks, logs, configs).
23    /// Typically `~/.vm-rs/vms/` for development or `/var/lib/vm-rs/` for production.
24    pub fn new(base_dir: PathBuf) -> Result<Self, VmError> {
25        let driver = create_platform_driver()?;
26        std::fs::create_dir_all(&base_dir).map_err(VmError::Io)?;
27        Ok(Self {
28            driver,
29            vms: RwLock::new(HashMap::new()),
30            base_dir,
31        })
32    }
33
34    /// Create a VmManager with a specific driver (useful for testing).
35    pub fn with_driver(driver: Box<dyn VmDriver>, base_dir: PathBuf) -> Result<Self, VmError> {
36        std::fs::create_dir_all(&base_dir).map_err(VmError::Io)?;
37        Ok(Self {
38            driver,
39            vms: RwLock::new(HashMap::new()),
40            base_dir,
41        })
42    }
43
44    /// Base directory for all VM data.
45    pub fn base_dir(&self) -> &Path {
46        &self.base_dir
47    }
48
49    /// Directory for a specific VM's data (logs, disks, seed ISOs).
50    pub fn vm_dir(&self, name: &str) -> PathBuf {
51        self.base_dir.join(name)
52    }
53
54    /// Boot a VM. Creates the VM directory and delegates to the driver.
55    pub fn start(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
56        config.validate()?;
57
58        {
59            let mut vms = self
60                .vms
61                .write()
62                .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
63            if let Some(existing) = vms.get(&config.name) {
64                if !matches!(existing.state, VmState::Stopped | VmState::Failed { .. }) {
65                    return Err(VmError::BootFailed {
66                        name: config.name.clone(),
67                        detail: format!("VM already exists in state: {}", existing.state),
68                    });
69                }
70            }
71            vms.insert(
72                config.name.clone(),
73                VmHandle {
74                    name: config.name.clone(),
75                    namespace: config.namespace.clone(),
76                    state: VmState::Starting,
77                    process: None,
78                    serial_log: config.serial_log.clone(),
79                    machine_id: None,
80                },
81            );
82        }
83
84        let vm_dir = self.vm_dir(&config.name);
85        if let Err(e) = std::fs::create_dir_all(&vm_dir) {
86            let mut vms = self
87                .vms
88                .write()
89                .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
90            vms.remove(&config.name);
91            return Err(VmError::Io(e));
92        }
93
94        tracing::info!(
95            vm = %config.name,
96            cpus = config.cpus,
97            memory_mb = config.memory_mb,
98            "booting VM"
99        );
100
101        let handle = match self.driver.boot(config) {
102            Ok(handle) => handle,
103            Err(err) => {
104                tracing::warn!(vm = %config.name, error = %err, "VM boot failed, cleaning up reservation");
105                let mut vms = self
106                    .vms
107                    .write()
108                    .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
109                vms.remove(&config.name);
110                return Err(err);
111            }
112        };
113
114        let mut vms = self
115            .vms
116            .write()
117            .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
118        vms.insert(config.name.clone(), handle.clone());
119
120        Ok(handle)
121    }
122
123    /// Stop a VM gracefully.
124    pub fn stop(&self, name: &str) -> Result<(), VmError> {
125        tracing::info!(vm = %name, "stopping VM");
126        let handle = self.get_handle(name)?;
127        self.driver.stop(&handle)?;
128        self.update_cached_state(name, VmState::Stopped)
129    }
130
131    /// Force-kill a VM.
132    pub fn kill(&self, name: &str) -> Result<(), VmError> {
133        tracing::info!(vm = %name, "force-killing VM");
134        let handle = self.get_handle(name)?;
135        self.driver.kill(&handle)?;
136        self.update_cached_state(name, VmState::Stopped)
137    }
138
139    /// Stop a VM using a pre-built handle (e.g. restored from persisted metadata).
140    pub fn stop_by_handle(&self, handle: &VmHandle) -> Result<(), VmError> {
141        tracing::info!(vm = %handle.name, "stopping VM by handle");
142        self.driver.stop(handle)?;
143        self.update_cached_state(&handle.name, VmState::Stopped)
144    }
145
146    /// Force-kill a VM using a pre-built handle (e.g. restored from persisted metadata).
147    pub fn kill_by_handle(&self, handle: &VmHandle) -> Result<(), VmError> {
148        tracing::info!(vm = %handle.name, "force-killing VM by handle");
149        self.driver.kill(handle)?;
150        self.update_cached_state(&handle.name, VmState::Stopped)
151    }
152
153    /// Pause a running VM.
154    pub fn pause(&self, name: &str) -> Result<(), VmError> {
155        let handle = self.get_handle(name)?;
156        self.driver.pause(&handle)?;
157        self.update_cached_state(name, VmState::Paused)
158    }
159
160    /// Resume a paused VM.
161    pub fn resume(&self, name: &str) -> Result<(), VmError> {
162        let handle = self.get_handle(name)?;
163        self.driver.resume(&handle)?;
164        let resumed_state = self.driver.state(&handle)?;
165        self.update_cached_state(name, resumed_state)
166    }
167
168    /// Query current state of a VM.
169    pub fn state(&self, name: &str) -> Result<VmState, VmError> {
170        let handle = self.get_handle(name)?;
171        let state = self.driver.state(&handle)?;
172        self.update_cached_state(name, state.clone())?;
173        Ok(state)
174    }
175
176    /// Get the IP address of a ready VM.
177    pub fn get_ip(&self, name: &str) -> Result<Option<String>, VmError> {
178        Ok(self.state(name)?.ip().map(ToOwned::to_owned))
179    }
180
181    /// List all tracked VMs.
182    pub fn list(&self) -> Result<Vec<VmHandle>, VmError> {
183        let vms = self
184            .vms
185            .read()
186            .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
187        Ok(vms.values().cloned().collect())
188    }
189
190    /// Wait for all VMs to emit the readiness marker within the timeout.
191    pub fn wait_all_ready(&self, timeout_secs: u64) -> Result<(), VmError> {
192        let start = std::time::Instant::now();
193        let timeout = std::time::Duration::from_secs(timeout_secs);
194
195        loop {
196            if start.elapsed() > timeout {
197                let pending = self.pending_names(|state| state.is_ready())?;
198                return Err(VmError::Hypervisor(format!(
199                    "timeout waiting for VMs: {}",
200                    pending.join(", ")
201                )));
202            }
203
204            let mut all_ready = true;
205            let names = self.vm_names()?;
206
207            for name in &names {
208                match self.state(name)? {
209                    state if state.is_ready() => {}
210                    VmState::Failed { reason } => {
211                        return Err(VmError::BootFailed {
212                            name: name.clone(),
213                            detail: reason,
214                        });
215                    }
216                    _ => {
217                        all_ready = false;
218                    }
219                }
220            }
221
222            if all_ready {
223                return Ok(());
224            }
225
226            let elapsed = start.elapsed().as_secs();
227            if elapsed > 0 && elapsed.is_multiple_of(10) {
228                let pending = self.pending_names(|state| state.is_ready())?;
229                tracing::info!(
230                    pending = ?pending,
231                    elapsed_secs = elapsed,
232                    "waiting for VMs to become ready"
233                );
234            }
235
236            std::thread::sleep(std::time::Duration::from_secs(1));
237        }
238    }
239
240    /// Create a disk image as a CoW clone of a base image.
241    pub fn clone_disk(base: &Path, target: &Path) -> Result<(), VmError> {
242        disk::clone_disk(base, target)
243    }
244
245    fn get_handle(&self, name: &str) -> Result<VmHandle, VmError> {
246        let vms = self
247            .vms
248            .read()
249            .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
250        vms.get(name).cloned().ok_or_else(|| VmError::NotFound {
251            name: name.to_string(),
252        })
253    }
254
255    fn update_cached_state(&self, name: &str, state: VmState) -> Result<(), VmError> {
256        let mut vms = self
257            .vms
258            .write()
259            .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
260        if let Some(handle) = vms.get_mut(name) {
261            handle.state = state;
262        }
263        Ok(())
264    }
265
266    fn vm_names(&self) -> Result<Vec<String>, VmError> {
267        let vms = self
268            .vms
269            .read()
270            .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
271        Ok(vms.keys().cloned().collect())
272    }
273
274    fn pending_names(&self, predicate: impl Fn(&VmState) -> bool) -> Result<Vec<String>, VmError> {
275        let vms = self
276            .vms
277            .read()
278            .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
279        Ok(vms
280            .iter()
281            .filter(|(_, handle)| !predicate(&handle.state))
282            .map(|(name, _)| name.clone())
283            .collect())
284    }
285}
286
287#[cfg(test)]
288mod tests {
289    use std::path::{Path, PathBuf};
290    use std::sync::atomic::{AtomicUsize, Ordering};
291    use std::sync::{mpsc, Arc, Mutex};
292
293    use super::*;
294
    /// Fake driver whose `boot` blocks until the test sends a release signal,
    /// used to exercise the start() name-reservation race.
    struct BlockingDriver {
        // Number of times boot() has been invoked.
        boot_calls: AtomicUsize,
        // Receiver take()n by the first boot(); recv() blocks until the test
        // sends () on the paired sender.
        release_rx: Mutex<Option<mpsc::Receiver<()>>>,
    }
299
    /// Fake driver whose state() always reports `VmState::Failed`.
    struct FailedStateDriver;
    /// Fake driver whose state() reports `Running` on the first poll and
    /// `Ready` on every poll after that.
    struct ReadyAfterTwoPollsDriver {
        // Count of state() polls performed so far.
        polls: AtomicUsize,
    }
304
    impl VmDriver for Arc<BlockingDriver> {
        /// Records the call, then blocks on the release channel (first call
        /// only — the receiver is take()n out of the Option, so later boots
        /// proceed immediately).
        fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
            self.boot_calls.fetch_add(1, Ordering::SeqCst);
            // NOTE(review): the MutexGuard temporary from lock() may be held
            // across the blocking recv() depending on edition temporary-scope
            // rules; harmless here because nothing else locks release_rx.
            if let Some(rx) = self.release_rx.lock().expect("release lock").take() {
                rx.recv().expect("release boot");
            }
            Ok(VmHandle {
                name: config.name.clone(),
                namespace: config.namespace.clone(),
                state: VmState::Starting,
                process: None,
                serial_log: config.serial_log.clone(),
                machine_id: None,
            })
        }

        // Lifecycle operations are no-ops for this fake.
        fn stop(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn kill(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn state(&self, _handle: &VmHandle) -> Result<VmState, VmError> {
            Ok(VmState::Stopped)
        }
    }
333
334    impl VmDriver for FailedStateDriver {
335        fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
336            Ok(VmHandle {
337                name: config.name.clone(),
338                namespace: config.namespace.clone(),
339                state: VmState::Starting,
340                process: None,
341                serial_log: config.serial_log.clone(),
342                machine_id: None,
343            })
344        }
345
346        fn stop(&self, _handle: &VmHandle) -> Result<(), VmError> {
347            Ok(())
348        }
349
350        fn kill(&self, _handle: &VmHandle) -> Result<(), VmError> {
351            Ok(())
352        }
353
354        fn state(&self, _handle: &VmHandle) -> Result<VmState, VmError> {
355            Ok(VmState::Failed {
356                reason: "crashed".into(),
357            })
358        }
359    }
360
361    impl VmDriver for ReadyAfterTwoPollsDriver {
362        fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
363            Ok(VmHandle {
364                name: config.name.clone(),
365                namespace: config.namespace.clone(),
366                state: VmState::Starting,
367                process: None,
368                serial_log: config.serial_log.clone(),
369                machine_id: None,
370            })
371        }
372
373        fn stop(&self, _handle: &VmHandle) -> Result<(), VmError> {
374            Ok(())
375        }
376
377        fn kill(&self, _handle: &VmHandle) -> Result<(), VmError> {
378            Ok(())
379        }
380
381        fn state(&self, _handle: &VmHandle) -> Result<VmState, VmError> {
382            let poll = self.polls.fetch_add(1, Ordering::SeqCst);
383            if poll == 0 {
384                Ok(VmState::Running)
385            } else {
386                Ok(VmState::Ready {
387                    ip: "10.0.0.2".into(),
388                })
389            }
390        }
391    }
392
    /// Minimal valid VmConfig for tests: 1 CPU, 256 MB, no disks or networks.
    /// `base_dir` only provides the serial-log path.
    fn test_config(base_dir: &Path) -> VmConfig {
        VmConfig {
            name: "test-vm".into(),
            namespace: "tests".into(),
            // Path is never opened by the fake drivers used in these tests.
            kernel: PathBuf::from("/tmp/kernel"),
            initramfs: None,
            root_disk: None,
            data_disk: None,
            seed_iso: None,
            cpus: 1,
            memory_mb: 256,
            networks: vec![],
            shared_dirs: vec![],
            serial_log: base_dir.join("serial.log"),
            cmdline: None,
            netns: None,
            vsock: false,
            machine_id: None,
            efi_variable_store: None,
            rosetta: false,
        }
    }
415
    /// start() must insert a `Starting` reservation BEFORE calling the
    /// driver, so a second concurrent start for the same name fails fast
    /// while the first boot is still in flight.
    #[test]
    fn start_reserves_name_before_driver_boot() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let (release_tx, release_rx) = mpsc::channel();
        let driver = Arc::new(BlockingDriver {
            boot_calls: AtomicUsize::new(0),
            release_rx: Mutex::new(Some(release_rx)),
        });
        let manager = Arc::new(
            VmManager::with_driver(Box::new(driver.clone()), tmp.path().to_path_buf())
                .expect("manager"),
        );
        let config = test_config(tmp.path());

        // First start runs on a background thread and blocks inside boot().
        let manager_clone = Arc::clone(&manager);
        let config_clone = config.clone();
        let boot_thread = std::thread::spawn(move || manager_clone.start(&config_clone));

        // Spin until the driver has entered boot() and parked on the channel.
        while driver.boot_calls.load(Ordering::SeqCst) == 0 {
            std::thread::sleep(std::time::Duration::from_millis(10));
        }

        // The name was reserved before boot, so this must fail immediately.
        let err = manager
            .start(&config)
            .expect_err("second concurrent start should fail");
        assert!(err.to_string().contains("already exists"));

        // Unblock the first boot and confirm it completed successfully.
        release_tx.send(()).expect("release first boot");
        boot_thread
            .join()
            .expect("join")
            .expect("first boot should succeed");

        // Only the first start ever reached the driver.
        assert_eq!(driver.boot_calls.load(Ordering::SeqCst), 1);
    }
451
452    #[test]
453    fn restart_is_allowed_after_failed_state_is_observed() {
454        let tmp = tempfile::tempdir().expect("tempdir");
455        let manager = VmManager::with_driver(Box::new(FailedStateDriver), tmp.path().to_path_buf())
456            .expect("manager");
457        let config = test_config(tmp.path());
458
459        manager.start(&config).expect("first boot");
460        let state = manager.state(&config.name).expect("state query");
461        assert!(matches!(state, VmState::Failed { .. }));
462
463        manager
464            .start(&config)
465            .expect("restart after failed state should be allowed");
466    }
467
468    #[test]
469    fn wait_all_ready_waits_for_ready_not_just_running() {
470        let tmp = tempfile::tempdir().expect("tempdir");
471        let manager = VmManager::with_driver(
472            Box::new(ReadyAfterTwoPollsDriver {
473                polls: AtomicUsize::new(0),
474            }),
475            tmp.path().to_path_buf(),
476        )
477        .expect("manager");
478        let config = test_config(tmp.path());
479
480        manager.start(&config).expect("boot");
481        manager
482            .wait_all_ready(2)
483            .expect("wait_all_ready should wait until ready");
484        assert_eq!(
485            manager.get_ip(&config.name).expect("ip query"),
486            Some("10.0.0.2".into())
487        );
488    }
489}