Skip to main content

microvm_runtime/adapters/
in_memory.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use crate::{
7    error::{VmRuntimeError, VmRuntimeResult},
8    model::{VmStatus, VmView},
9    provider::{VmProvider, VmQuery},
10};
11
12/// In-memory VM provider for development and testing.
13///
14/// Replace with a hypervisor-backed adapter (e.g. Firecracker, Cloud Hypervisor)
15/// for production use.
16#[derive(Debug, Clone, Default)]
17pub struct InMemoryVmProvider {
18    state: Arc<RwLock<HashMap<String, VmRecord>>>,
19}
20
21#[derive(Debug, Clone)]
22struct VmRecord {
23    status: VmStatus,
24    snapshots: Vec<String>,
25}
26
27impl VmRecord {
28    fn view(&self, vm_id: &str) -> VmView {
29        VmView {
30            vm_id: vm_id.to_owned(),
31            status: self.status,
32            snapshots: self.snapshots.clone(),
33        }
34    }
35}
36
37impl VmProvider for InMemoryVmProvider {
38    fn create_vm(&self, vm_id: &str) -> VmRuntimeResult<()> {
39        let mut state = self
40            .state
41            .write()
42            .map_err(|_| VmRuntimeError::StatePoisoned)?;
43
44        if state.contains_key(vm_id) {
45            return Err(VmRuntimeError::VmAlreadyExists(vm_id.to_owned()));
46        }
47
48        state.insert(
49            vm_id.to_owned(),
50            VmRecord {
51                status: VmStatus::Created,
52                snapshots: Vec::new(),
53            },
54        );
55
56        Ok(())
57    }
58
59    fn start_vm(&self, vm_id: &str) -> VmRuntimeResult<()> {
60        let mut state = self
61            .state
62            .write()
63            .map_err(|_| VmRuntimeError::StatePoisoned)?;
64        let record = state
65            .get_mut(vm_id)
66            .ok_or_else(|| VmRuntimeError::VmNotFound(vm_id.to_owned()))?;
67
68        match record.status {
69            VmStatus::Created | VmStatus::Stopped => {
70                record.status = VmStatus::Running;
71                Ok(())
72            }
73            other => Err(VmRuntimeError::InvalidTransition {
74                vm_id: vm_id.to_owned(),
75                from: other.to_string(),
76                to: "running",
77            }),
78        }
79    }
80
81    fn stop_vm(&self, vm_id: &str) -> VmRuntimeResult<()> {
82        let mut state = self
83            .state
84            .write()
85            .map_err(|_| VmRuntimeError::StatePoisoned)?;
86        let record = state
87            .get_mut(vm_id)
88            .ok_or_else(|| VmRuntimeError::VmNotFound(vm_id.to_owned()))?;
89
90        match record.status {
91            VmStatus::Running => {
92                record.status = VmStatus::Stopped;
93                Ok(())
94            }
95            other => Err(VmRuntimeError::InvalidTransition {
96                vm_id: vm_id.to_owned(),
97                from: other.to_string(),
98                to: "stopped",
99            }),
100        }
101    }
102
103    fn snapshot_vm(&self, vm_id: &str, snapshot_id: &str) -> VmRuntimeResult<()> {
104        let mut state = self
105            .state
106            .write()
107            .map_err(|_| VmRuntimeError::StatePoisoned)?;
108        let record = state
109            .get_mut(vm_id)
110            .ok_or_else(|| VmRuntimeError::VmNotFound(vm_id.to_owned()))?;
111
112        if record.status == VmStatus::Destroyed {
113            return Err(VmRuntimeError::InvalidTransition {
114                vm_id: vm_id.to_owned(),
115                from: VmStatus::Destroyed.to_string(),
116                to: "snapshot",
117            });
118        }
119
120        if record
121            .snapshots
122            .iter()
123            .any(|existing| existing == snapshot_id)
124        {
125            return Err(VmRuntimeError::SnapshotAlreadyExists {
126                vm_id: vm_id.to_owned(),
127                snapshot_id: snapshot_id.to_owned(),
128            });
129        }
130
131        record.snapshots.push(snapshot_id.to_owned());
132        Ok(())
133    }
134
135    fn destroy_vm(&self, vm_id: &str) -> VmRuntimeResult<()> {
136        let mut state = self
137            .state
138            .write()
139            .map_err(|_| VmRuntimeError::StatePoisoned)?;
140        let record = state
141            .get_mut(vm_id)
142            .ok_or_else(|| VmRuntimeError::VmNotFound(vm_id.to_owned()))?;
143
144        if record.status == VmStatus::Destroyed {
145            return Err(VmRuntimeError::InvalidTransition {
146                vm_id: vm_id.to_owned(),
147                from: VmStatus::Destroyed.to_string(),
148                to: "destroyed",
149            });
150        }
151
152        record.status = VmStatus::Destroyed;
153        Ok(())
154    }
155}
156
157impl VmQuery for InMemoryVmProvider {
158    fn list_vms(&self) -> VmRuntimeResult<Vec<VmView>> {
159        let state = self
160            .state
161            .read()
162            .map_err(|_| VmRuntimeError::StatePoisoned)?;
163        let mut views = state
164            .iter()
165            .map(|(vm_id, record)| record.view(vm_id))
166            .collect::<Vec<_>>();
167
168        views.sort_by(|a, b| a.vm_id.cmp(&b.vm_id));
169        Ok(views)
170    }
171
172    fn get_vm(&self, vm_id: &str) -> VmRuntimeResult<Option<VmView>> {
173        let state = self
174            .state
175            .read()
176            .map_err(|_| VmRuntimeError::StatePoisoned)?;
177        Ok(state.get(vm_id).map(|record| record.view(vm_id)))
178    }
179
180    fn list_snapshots(&self, vm_id: &str) -> VmRuntimeResult<Option<Vec<String>>> {
181        let state = self
182            .state
183            .read()
184            .map_err(|_| VmRuntimeError::StatePoisoned)?;
185        Ok(state.get(vm_id).map(|record| record.snapshots.clone()))
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives one VM through create, start, snapshot, stop and destroy, then
    /// checks that the destroyed record is still queryable with its snapshot.
    #[test]
    fn full_lifecycle() {
        const VM_ID: &str = "vm-a";
        const SNAPSHOT_ID: &str = "snap-1";

        let provider = InMemoryVmProvider::default();

        provider.create_vm(VM_ID).expect("create");
        provider.start_vm(VM_ID).expect("start");
        provider.snapshot_vm(VM_ID, SNAPSHOT_ID).expect("snapshot");
        provider.stop_vm(VM_ID).expect("stop");
        provider.destroy_vm(VM_ID).expect("destroy");

        let view = provider.get_vm(VM_ID).expect("query").expect("vm exists");

        assert_eq!(view.status, VmStatus::Destroyed);
        assert_eq!(view.snapshots, vec![String::from(SNAPSHOT_ID)]);
    }
}