arcbox_hypervisor/linux/memory.rs

//! Guest memory implementation for Linux KVM.

use std::os::unix::io::RawFd;
use std::sync::RwLock;
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};

use crate::{
    error::HypervisorError,
    memory::{GuestAddress, MemoryRegion, PAGE_SIZE},
    traits::GuestMemory,
    types::DirtyPageInfo,
};

use super::ffi;

/// Guest memory implementation for Linux KVM.
///
/// This manages the guest physical address space using mmap'd memory
/// that is registered with KVM via KVM_SET_USER_MEMORY_REGION.
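///
/// A minimal usage sketch (illustrative only; the slot number is arbitrary
/// and `vm_fd` is assumed to come from a KVM_CREATE_VM ioctl made elsewhere):
///
/// ```ignore
/// let memory = KvmMemory::new(256 * 1024 * 1024)?; // 256MB at guest address 0
/// memory.attach_vm_fd(vm_fd);
/// memory.register_slot(0, 0, memory.size(), memory.host_address() as u64, 0)?;
/// memory.write(GuestAddress::new(0x1000), &[0xAA; 4])?;
/// ```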
pub struct KvmMemory {
    /// Memory regions.
    regions: RwLock<Vec<MappedRegion>>,
    /// Total memory size.
    total_size: u64,
    /// Base host address (for the primary region).
    base_host_addr: *mut u8,
    /// KVM VM fd for dirty logging (set by VM after creation).
    vm_fd: AtomicI32,
    /// Memory slots tracked for dirty logging.
    memory_slots: RwLock<Vec<MemorySlotInfo>>,
    /// Whether dirty page tracking is enabled.
    dirty_tracking_enabled: AtomicBool,
}

/// A mapped memory region with its host backing.
struct MappedRegion {
    /// Guest physical address.
    guest_addr: GuestAddress,
    /// Size in bytes.
    size: u64,
    /// Host virtual address.
    host_addr: *mut u8,
    /// Whether this region is read-only.
    read_only: bool,
    /// Whether this region was allocated by us (vs provided externally).
    owned: bool,
}

/// Memory slot tracking for dirty logging.
struct MemorySlotInfo {
    /// Slot ID.
    slot: u32,
    /// Guest physical address.
    guest_phys_addr: u64,
    /// Size in bytes.
    size: u64,
    /// Host virtual address.
    userspace_addr: u64,
    /// Base flags for the slot (e.g. read-only).
    flags: u32,
}

// SAFETY: The host_addr pointer points to mmap'd memory that is valid
// for the lifetime of the KvmMemory instance.
unsafe impl Send for MappedRegion {}
unsafe impl Sync for MappedRegion {}

// SAFETY: KvmMemory's base_host_addr aliases the mmap'd memory of its primary
// region, which lives as long as the instance; the remaining fields are
// atomics and lock-protected Vecs of MappedRegion (which are Send + Sync).
// The KVM VM fd is thread-safe, as KVM ioctls are designed to be called from
// multiple threads.
unsafe impl Send for KvmMemory {}
unsafe impl Sync for KvmMemory {}

impl KvmMemory {
    /// Creates a new guest memory region.
    ///
    /// # Errors
    ///
    /// Returns an error if memory allocation fails.
    pub fn new(size: u64) -> Result<Self, HypervisorError> {
        // Allocate the main memory region at guest address 0.
        let host_addr = ffi::allocate_memory(size).map_err(|e| {
            HypervisorError::MemoryError(format!("Failed to allocate memory: {}", e))
        })?;

        let region = MappedRegion {
            guest_addr: GuestAddress::new(0),
            size,
            host_addr,
            read_only: false,
            owned: true,
        };

        tracing::debug!("Created guest memory: {}MB", size / (1024 * 1024));

        Ok(Self {
            regions: RwLock::new(vec![region]),
            total_size: size,
            base_host_addr: host_addr,
            vm_fd: AtomicI32::new(-1),
            memory_slots: RwLock::new(Vec::new()),
            dirty_tracking_enabled: AtomicBool::new(false),
        })
    }

    /// Returns the host address of the base memory region.
    ///
    /// This is used when registering memory with KVM.
    pub fn host_address(&self) -> *mut u8 {
        self.base_host_addr
    }

    /// Adds an additional memory region.
    ///
    /// # Errors
    ///
    /// Returns an error if the region overlaps with existing regions or
    /// memory allocation fails.
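    ///
    /// # Example
    ///
    /// A sketch of adding a second RAM bank above the 4GB boundary (the
    /// address and size here are illustrative, mirroring the unit test below):
    ///
    /// ```ignore
    /// let hi_ram = memory.add_region(GuestAddress::new(0x1_0000_0000), 64 * 1024 * 1024)?;
    /// assert!(!hi_ram.is_null());
    /// ```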
    pub fn add_region(
        &self,
        guest_addr: GuestAddress,
        size: u64,
    ) -> Result<*mut u8, HypervisorError> {
        let host_addr = ffi::allocate_memory(size).map_err(|e| {
            HypervisorError::MemoryError(format!("Failed to allocate memory: {}", e))
        })?;

        let new_region = MappedRegion {
            guest_addr,
            size,
            host_addr,
            read_only: false,
            owned: true,
        };

        let mut regions = self
            .regions
            .write()
            .map_err(|_| HypervisorError::MemoryError("Lock poisoned".to_string()))?;

        // Check for overlaps: two half-open ranges intersect iff each starts
        // before the other ends.
        let new_end = guest_addr.raw() + size;
        for region in regions.iter() {
            let existing_end = region.guest_addr.raw() + region.size;
            if guest_addr.raw() < existing_end && new_end > region.guest_addr.raw() {
                // Free the allocated memory before returning an error.
                ffi::free_memory(host_addr, size);
                return Err(HypervisorError::MemoryError(
                    "Region overlaps with existing region".to_string(),
                ));
            }
        }

        regions.push(new_region);

        tracing::debug!(
            "Added memory region at {}: {}MB",
            guest_addr,
            size / (1024 * 1024)
        );

        Ok(host_addr)
    }

    /// Adds an externally allocated memory region.
    ///
    /// The caller is responsible for ensuring the memory remains valid
    /// for the lifetime of this object.
    ///
    /// # Safety
    ///
    /// `host_addr` must point to valid memory of at least `size` bytes
    /// that will remain valid for the lifetime of this `KvmMemory`.
    pub unsafe fn add_external_region(
        &self,
        guest_addr: GuestAddress,
        host_addr: *mut u8,
        size: u64,
        read_only: bool,
    ) -> Result<(), HypervisorError> {
        let new_region = MappedRegion {
            guest_addr,
            size,
            host_addr,
            read_only,
            owned: false, // Not owned by us
        };

        let mut regions = self
            .regions
            .write()
            .map_err(|_| HypervisorError::MemoryError("Lock poisoned".to_string()))?;

        // Check for overlaps
        let new_end = guest_addr.raw() + size;
        for region in regions.iter() {
            let existing_end = region.guest_addr.raw() + region.size;
            if guest_addr.raw() < existing_end && new_end > region.guest_addr.raw() {
                return Err(HypervisorError::MemoryError(
                    "Region overlaps with existing region".to_string(),
                ));
            }
        }

        regions.push(new_region);

        tracing::debug!(
            "Added external memory region at {}: {}MB, read_only={}",
            guest_addr,
            size / (1024 * 1024),
            read_only
        );

        Ok(())
    }

    /// Attaches the KVM VM fd for dirty logging.
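    ///
    /// A sketch of the dirty-tracking lifecycle (illustrative; assumes
    /// `vm_fd` came from a KVM_CREATE_VM ioctl and that slots were
    /// registered via `register_slot`):
    ///
    /// ```ignore
    /// memory.attach_vm_fd(vm_fd);
    /// memory.enable_dirty_tracking()?;       // KVM_MEM_LOG_DIRTY_PAGES on each slot
    /// // ... run vCPUs, let the guest touch memory ...
    /// let dirty = memory.get_dirty_pages()?; // KVM_GET_DIRTY_LOG per slot
    /// memory.disable_dirty_tracking()?;
    /// ```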
    pub fn attach_vm_fd(&self, vm_fd: RawFd) {
        self.vm_fd.store(vm_fd, Ordering::SeqCst);
    }

    /// Registers a memory slot for dirty logging.
    ///
    /// # Errors
    ///
    /// Returns an error if the slot list cannot be updated.
    pub fn register_slot(
        &self,
        slot: u32,
        guest_phys_addr: u64,
        size: u64,
        userspace_addr: u64,
        flags: u32,
    ) -> Result<(), HypervisorError> {
        let mut slots = self
            .memory_slots
            .write()
            .map_err(|_| HypervisorError::SnapshotError("Lock poisoned".to_string()))?;

        if let Some(existing) = slots.iter_mut().find(|s| s.slot == slot) {
            existing.guest_phys_addr = guest_phys_addr;
            existing.size = size;
            existing.userspace_addr = userspace_addr;
            existing.flags = flags;
        } else {
            slots.push(MemorySlotInfo {
                slot,
                guest_phys_addr,
                size,
                userspace_addr,
                flags,
            });
        }

        // If tracking is already on, enable dirty logging for the new or
        // updated slot immediately.
        if self.dirty_tracking_enabled.load(Ordering::SeqCst) {
            let fd = self.vm_fd()?;
            let slot_info = slots
                .iter()
                .find(|s| s.slot == slot)
                .expect("slot was just inserted or updated");
            self.update_dirty_logging(fd, slot_info, true)?;
        }

        Ok(())
    }

    /// Unregisters a memory slot.
    ///
    /// # Errors
    ///
    /// Returns an error if the slot list cannot be updated.
    pub fn unregister_slot(&self, slot: u32) -> Result<(), HypervisorError> {
        let mut slots = self
            .memory_slots
            .write()
            .map_err(|_| HypervisorError::SnapshotError("Lock poisoned".to_string()))?;
        slots.retain(|entry| entry.slot != slot);
        Ok(())
    }

    /// Updates the dirty tracking enabled flag from external callers.
    pub fn set_dirty_tracking_enabled(&self, enabled: bool) {
        self.dirty_tracking_enabled.store(enabled, Ordering::SeqCst);
    }

    fn vm_fd(&self) -> Result<RawFd, HypervisorError> {
        let fd = self.vm_fd.load(Ordering::SeqCst);
        if fd < 0 {
            return Err(HypervisorError::SnapshotError(
                "KVM VM fd not attached".to_string(),
            ));
        }
        Ok(fd)
    }

    fn update_dirty_logging(
        &self,
        fd: RawFd,
        slot: &MemorySlotInfo,
        enable: bool,
    ) -> Result<(), HypervisorError> {
        let flags = if enable {
            slot.flags | ffi::KVM_MEM_LOG_DIRTY_PAGES
        } else {
            slot.flags
        };

        let region = ffi::KvmUserspaceMemoryRegion {
            slot: slot.slot,
            flags,
            guest_phys_addr: slot.guest_phys_addr,
            memory_size: slot.size,
            userspace_addr: slot.userspace_addr,
        };

        let ret = unsafe {
            libc::ioctl(
                fd,
                ffi::KVM_SET_USER_MEMORY_REGION,
                &region as *const _ as libc::c_ulong,
            )
        };

        if ret < 0 {
            return Err(HypervisorError::SnapshotError(format!(
                "Failed to {} dirty logging for slot {}: {}",
                if enable { "enable" } else { "disable" },
                slot.slot,
                std::io::Error::last_os_error()
            )));
        }

        Ok(())
    }

    fn get_dirty_log(&self, fd: RawFd, slot: &MemorySlotInfo) -> Result<Vec<u64>, HypervisorError> {
        // KVM reports one bit per page; round the page count up, then round
        // up again to whole 64-bit words for the bitmap buffer.
        let num_pages = (slot.size + PAGE_SIZE - 1) / PAGE_SIZE;
        let bitmap_size = ((num_pages + 63) / 64) as usize;
        let mut bitmap: Vec<u64> = vec![0; bitmap_size];

        let dirty_log = ffi::KvmDirtyLog {
            slot: slot.slot,
            padding: 0,
            dirty_bitmap: bitmap.as_mut_ptr(),
        };

        let ret = unsafe {
            libc::ioctl(
                fd,
                ffi::KVM_GET_DIRTY_LOG,
                &dirty_log as *const _ as libc::c_ulong,
            )
        };

        if ret < 0 {
            return Err(HypervisorError::SnapshotError(format!(
                "Failed to get dirty log for slot {}: {}",
                slot.slot,
                std::io::Error::last_os_error()
            )));
        }

        Ok(bitmap)
    }

    /// Expands a per-slot dirty bitmap (one bit per page, LSB first within
    /// each word) into per-page descriptors relative to `base_addr`.
    fn parse_dirty_bitmap(bitmap: &[u64], base_addr: u64, size: u64) -> Vec<DirtyPageInfo> {
        let mut pages = Vec::new();
        let num_pages = size / PAGE_SIZE;

        for (word_idx, &word) in bitmap.iter().enumerate() {
            if word == 0 {
                continue;
            }

            for bit_idx in 0..64 {
                if (word >> bit_idx) & 1 != 0 {
                    let page_num = (word_idx as u64 * 64) + bit_idx as u64;
                    if page_num < num_pages {
                        pages.push(DirtyPageInfo {
                            guest_addr: base_addr + page_num * PAGE_SIZE,
                            size: PAGE_SIZE,
                        });
                    }
                }
            }
        }

        pages
    }

    /// Finds the region containing the given address.
    ///
    /// Returns the host pointer at that address, the bytes remaining in the
    /// region, and whether the region is read-only.
    fn find_region(&self, addr: GuestAddress) -> Result<(*mut u8, u64, bool), HypervisorError> {
        let regions = self
            .regions
            .read()
            .map_err(|_| HypervisorError::MemoryError("Lock poisoned".to_string()))?;

        for region in regions.iter() {
            if addr.raw() >= region.guest_addr.raw()
                && addr.raw() < region.guest_addr.raw() + region.size
            {
                let offset = addr.raw() - region.guest_addr.raw();
                let remaining = region.size - offset;
                let ptr = unsafe { region.host_addr.add(offset as usize) };
                return Ok((ptr, remaining, region.read_only));
            }
        }

        Err(HypervisorError::MemoryError(format!(
            "Address {} not mapped",
            addr
        )))
    }

    /// Returns a snapshot of all memory regions.
    pub fn regions(&self) -> Result<Vec<MemoryRegion>, HypervisorError> {
        let regions = self
            .regions
            .read()
            .map_err(|_| HypervisorError::MemoryError("Lock poisoned".to_string()))?;

        Ok(regions
            .iter()
            .map(|r| MemoryRegion {
                guest_addr: r.guest_addr,
                size: r.size,
                host_addr: Some(r.host_addr),
                read_only: r.read_only,
            })
            .collect())
    }

    /// Writes a value to guest memory at the specified address.
    ///
    /// `T` should be plain-old-data; any padding bytes are copied as-is.
    pub fn write_obj<T: Copy>(&self, addr: GuestAddress, val: &T) -> Result<(), HypervisorError> {
        let bytes = unsafe {
            std::slice::from_raw_parts(val as *const T as *const u8, std::mem::size_of::<T>())
        };
        self.write(addr, bytes)
    }

    /// Reads a value from guest memory at the specified address.
    pub fn read_obj<T: Copy + Default>(&self, addr: GuestAddress) -> Result<T, HypervisorError> {
        let mut val = T::default();
        let bytes = unsafe {
            std::slice::from_raw_parts_mut(&mut val as *mut T as *mut u8, std::mem::size_of::<T>())
        };
        self.read(addr, bytes)?;
        Ok(val)
    }

    /// Fills a range of guest memory with a byte value.
    pub fn memset(&self, addr: GuestAddress, val: u8, len: usize) -> Result<(), HypervisorError> {
        let (ptr, remaining, read_only) = self.find_region(addr)?;

        if read_only {
            return Err(HypervisorError::MemoryError(
                "Cannot write to read-only region".to_string(),
            ));
        }

        if len as u64 > remaining {
            return Err(HypervisorError::MemoryError(format!(
                "Memset of {} bytes at {} exceeds region bounds",
                len, addr
            )));
        }

        unsafe {
            std::ptr::write_bytes(ptr, val, len);
        }

        Ok(())
    }
}

impl GuestMemory for KvmMemory {
    fn read(&self, addr: GuestAddress, buf: &mut [u8]) -> Result<(), HypervisorError> {
        let (ptr, remaining, _) = self.find_region(addr)?;

        if buf.len() as u64 > remaining {
            return Err(HypervisorError::MemoryError(format!(
                "Read of {} bytes at {} exceeds region bounds",
                buf.len(),
                addr
            )));
        }

        unsafe {
            std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), buf.len());
        }

        Ok(())
    }

    fn write(&self, addr: GuestAddress, buf: &[u8]) -> Result<(), HypervisorError> {
        let (ptr, remaining, read_only) = self.find_region(addr)?;

        if read_only {
            return Err(HypervisorError::MemoryError(
                "Cannot write to read-only region".to_string(),
            ));
        }

        if buf.len() as u64 > remaining {
            return Err(HypervisorError::MemoryError(format!(
                "Write of {} bytes at {} exceeds region bounds",
                buf.len(),
                addr
            )));
        }

        unsafe {
            std::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len());
        }

        Ok(())
    }

    fn get_host_address(&self, addr: GuestAddress) -> Result<*mut u8, HypervisorError> {
        let (ptr, _, _) = self.find_region(addr)?;
        Ok(ptr)
    }

    fn size(&self) -> u64 {
        self.total_size
    }

    fn enable_dirty_tracking(&mut self) -> Result<(), HypervisorError> {
        if self.dirty_tracking_enabled.load(Ordering::SeqCst) {
            return Ok(());
        }

        let fd = self.vm_fd()?;
        let slots = self
            .memory_slots
            .read()
            .map_err(|_| HypervisorError::SnapshotError("Lock poisoned".to_string()))?;

        for slot in slots.iter() {
            self.update_dirty_logging(fd, slot, true)?;
        }

        self.dirty_tracking_enabled.store(true, Ordering::SeqCst);
        tracing::debug!("Dirty page tracking enabled");
        Ok(())
    }

    fn disable_dirty_tracking(&mut self) -> Result<(), HypervisorError> {
        if !self.dirty_tracking_enabled.load(Ordering::SeqCst) {
            return Ok(());
        }

        let fd = self.vm_fd()?;
        let slots = self
            .memory_slots
            .read()
            .map_err(|_| HypervisorError::SnapshotError("Lock poisoned".to_string()))?;

        for slot in slots.iter() {
            self.update_dirty_logging(fd, slot, false)?;
        }

        self.dirty_tracking_enabled.store(false, Ordering::SeqCst);
        tracing::debug!("Dirty page tracking disabled");
        Ok(())
    }

    fn get_dirty_pages(&mut self) -> Result<Vec<DirtyPageInfo>, HypervisorError> {
        if !self.dirty_tracking_enabled.load(Ordering::SeqCst) {
            return Err(HypervisorError::SnapshotError(
                "Dirty tracking not enabled".to_string(),
            ));
        }

        let fd = self.vm_fd()?;
        let slots = self
            .memory_slots
            .read()
            .map_err(|_| HypervisorError::SnapshotError("Lock poisoned".to_string()))?;

        let mut dirty_pages = Vec::new();
        for slot in slots.iter() {
            let bitmap = self.get_dirty_log(fd, slot)?;
            let pages = Self::parse_dirty_bitmap(&bitmap, slot.guest_phys_addr, slot.size);
            dirty_pages.extend(pages);
        }

        Ok(dirty_pages)
    }

    fn dump_all(&self, buf: &mut [u8]) -> Result<(), HypervisorError> {
        if (buf.len() as u64) < self.total_size {
            return Err(HypervisorError::MemoryError(format!(
                "Buffer too small: {} bytes, need {} bytes",
                buf.len(),
                self.total_size
            )));
        }

        let regions = self
            .regions
            .read()
            .map_err(|_| HypervisorError::MemoryError("Lock poisoned".to_string()))?;

        // Copy each region to the appropriate offset in the buffer.
        for region in regions.iter() {
            let offset = region.guest_addr.raw() as usize;
            let end = offset + region.size as usize;

            if end > buf.len() {
                return Err(HypervisorError::MemoryError(format!(
                    "Region at {} with size {} exceeds buffer",
                    region.guest_addr, region.size
                )));
            }

            unsafe {
                std::ptr::copy_nonoverlapping(
                    region.host_addr,
                    buf[offset..end].as_mut_ptr(),
                    region.size as usize,
                );
            }
        }

        tracing::debug!("Dumped {} bytes of guest memory", self.total_size);
        Ok(())
    }
}

impl Drop for KvmMemory {
    fn drop(&mut self) {
        if let Ok(regions) = self.regions.write() {
            for region in regions.iter() {
                // Only free memory we allocated.
                if region.owned {
                    ffi::free_memory(region.host_addr, region.size);
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_creation() {
        let size = 16 * 1024 * 1024; // 16MB
        let memory = KvmMemory::new(size).unwrap();
        assert_eq!(memory.size(), size);
    }

    #[test]
    fn test_memory_read_write() {
        let size = 16 * 1024 * 1024;
        let memory = KvmMemory::new(size).unwrap();

        // Write some data
        let data = [1u8, 2, 3, 4, 5];
        memory.write(GuestAddress::new(0x1000), &data).unwrap();

        // Read it back
        let mut buf = [0u8; 5];
        memory.read(GuestAddress::new(0x1000), &mut buf).unwrap();
        assert_eq!(buf, data);
    }

    #[test]
    fn test_memory_bounds_check() {
        let size = 1024; // 1KB
        let memory = KvmMemory::new(size).unwrap();

        // Try to read beyond bounds
        let mut buf = [0u8; 16];
        let result = memory.read(GuestAddress::new(size - 8), &mut buf);
        assert!(result.is_err());

        // Try to read from unmapped address
        let result = memory.read(GuestAddress::new(size + 1000), &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_host_address() {
        let size = 16 * 1024 * 1024;
        let memory = KvmMemory::new(size).unwrap();

        let ptr = memory.get_host_address(GuestAddress::new(0x1000)).unwrap();
        assert!(!ptr.is_null());

        // Write via pointer
        unsafe {
            *ptr = 42;
        }

        // Read via GuestMemory
        let mut buf = [0u8; 1];
        memory.read(GuestAddress::new(0x1000), &mut buf).unwrap();
        assert_eq!(buf[0], 42);
    }

    #[test]
    fn test_write_read_obj() {
        let size = 16 * 1024 * 1024;
        let memory = KvmMemory::new(size).unwrap();

        // Write a u64
        let val: u64 = 0x1234_5678_9abc_def0;
        memory.write_obj(GuestAddress::new(0x2000), &val).unwrap();

        // Read it back
        let read_val: u64 = memory.read_obj(GuestAddress::new(0x2000)).unwrap();
        assert_eq!(read_val, val);
    }

    #[test]
    fn test_memset() {
        let size = 16 * 1024 * 1024;
        let memory = KvmMemory::new(size).unwrap();

        // Fill a region
        memory.memset(GuestAddress::new(0x3000), 0xAA, 100).unwrap();

        // Verify
        let mut buf = [0u8; 100];
        memory.read(GuestAddress::new(0x3000), &mut buf).unwrap();
        for &byte in &buf {
            assert_eq!(byte, 0xAA);
        }
    }

    #[test]
    fn test_add_region() {
        let size = 16 * 1024 * 1024;
        let memory = KvmMemory::new(size).unwrap();

        // Add another region at a non-overlapping address
        let region2_addr = GuestAddress::new(0x1_0000_0000); // 4GB
        let region2_size = 8 * 1024 * 1024;
        let ptr = memory.add_region(region2_addr, region2_size).unwrap();
        assert!(!ptr.is_null());

        // Write to the new region
        let data = [0xBB; 10];
        memory.write(region2_addr, &data).unwrap();

        // Read back
        let mut buf = [0u8; 10];
        memory.read(region2_addr, &mut buf).unwrap();
        assert_eq!(buf, data);
    }

    #[test]
    fn test_overlapping_region() {
        let size = 16 * 1024 * 1024;
        let memory = KvmMemory::new(size).unwrap();

        // Try to add an overlapping region (should fail)
        let result = memory.add_region(GuestAddress::new(0x1000), 0x1000);
        assert!(result.is_err());
    }
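
    // Illustrative unit check for the dirty-bitmap decoding helper above
    // (a sketch exercising only the pure function; no KVM access needed).
    #[test]
    fn test_parse_dirty_bitmap() {
        // Mark pages 0, 1, and 64 dirty: bits 0 and 1 of word 0, bit 0 of word 1.
        let bitmap = [0b11u64, 0b1u64];
        let base = 0x10_0000u64;
        let pages = KvmMemory::parse_dirty_bitmap(&bitmap, base, 128 * PAGE_SIZE);

        let addrs: Vec<u64> = pages.iter().map(|p| p.guest_addr).collect();
        assert_eq!(addrs, vec![base, base + PAGE_SIZE, base + 64 * PAGE_SIZE]);
        assert!(pages.iter().all(|p| p.size == PAGE_SIZE));
    }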
}