// hyperlight_common/arch/amd64/vmem.rs
1/*
2Copyright 2025  The Hyperlight Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15 */
16
17//! x86-64 4-level page table manipulation code.
18//!
19//! This module implements page table setup for x86-64 long mode using 4-level paging:
20//! - PML4 (Page Map Level 4) - bits 47:39 - 512 entries, each covering 512GB
21//! - PDPT (Page Directory Pointer Table) - bits 38:30 - 512 entries, each covering 1GB
22//! - PD (Page Directory) - bits 29:21 - 512 entries, each covering 2MB
23//! - PT (Page Table) - bits 20:12 - 512 entries, each covering 4KB pages
24//!
25//! The code uses an iterator-based approach to walk the page table hierarchy,
26//! allocating intermediate tables as needed and setting appropriate flags on leaf PTEs
27
28use crate::vmem::{
29    BasicMapping, CowMapping, Mapping, MappingKind, TableMovabilityBase, TableOps, TableReadOps,
30    Void,
31};
32
33// Paging Flags
34//
35// See the following links explaining paging:
36//
37// * Intel® 64 and IA-32 Architectures Software Developer’s Manual, Volume 3A: System Programming Guide, Part 1
38//  - Chapter 5 "Paging"
39//
40// https://cdrdv2.intel.com/v1/dl/getContent/671200
41//
42// * AMD64 Architecture Programmer’s Manual, Volume 2: System Programming, Section 5.3: Long-Mode Page Translation
43//
44// https://docs.amd.com/v/u/en-US/24593_3.43
45//
46// Or if you prefer something less formal:
47//
48// * Very basic description: https://stackoverflow.com/a/26945892
49// * More in-depth descriptions: https://wiki.osdev.org/Paging
50//
51
/// Page is Present (bit 0): the entry refers to a valid page or table.
const PAGE_PRESENT: u64 = 1;
/// Page is Read/Write (if not set page is read only so long as the WP bit in CR0 is set to 1 - which it is in Hyperlight)
const PAGE_RW: u64 = 1 << 1;
/// Execute Disable (if this bit is set then data in the page cannot be executed)
const PAGE_NX: u64 = 1 << 63;
/// Mask to extract the physical address from a PTE (bits 51:12)
/// This masks out the lower 12 flag bits AND the upper bits including NX (bit 63)
const PTE_ADDR_MASK: u64 = 0x000F_FFFF_FFFF_F000;
const PAGE_USER_ACCESS_DISABLED: u64 = 0 << 2; // U/S bit not set - supervisor mode only (no code runs in user mode for now)
const PAGE_DIRTY_SET: u64 = 1 << 6; // D - dirty bit
const PAGE_ACCESSED_SET: u64 = 1 << 5; // A - accessed bit
const PAGE_CACHE_ENABLED: u64 = 0 << 4; // PCD - page cache disable bit not set (caching enabled)
const PAGE_WRITE_BACK: u64 = 0 << 3; // PWT - page write-through bit not set (write-back caching)
const PAGE_PAT_WB: u64 = 0 << 7; // PAT - page attribute table index bit (0 for write-back memory when PCD=0, PWT=0)

// We use various patterns of the available-for-software-use bits
// (bits 11:9 of a PTE, per the 0xE00 mask below) to represent certain
// special mappings.
const PTE_AVL_MASK: u64 = 0x0000_0000_0000_0E00;
/// AVL bit pattern marking a copy-on-write mapping (see [`MappingKind::Cow`]).
const PAGE_AVL_COW: u64 = 1 << 9;
72
73/// Returns PAGE_RW if writable is true, 0 otherwise
74#[inline(always)]
75const fn page_rw_flag(writable: bool) -> u64 {
76    if writable { PAGE_RW } else { 0 }
77}
78
79/// Returns PAGE_NX if executable is false (NX = No Execute), 0 otherwise
80#[inline(always)]
81const fn page_nx_flag(executable: bool) -> u64 {
82    if executable { 0 } else { PAGE_NX }
83}
84
85/// Read a page table entry and return it if the present bit is set
86/// # Safety
87/// The caller must ensure that `entry_ptr` points to a valid page table entry.
88#[inline(always)]
89unsafe fn read_pte_if_present<Op: TableReadOps>(op: &Op, entry_ptr: Op::TableAddr) -> Option<u64> {
90    let pte = unsafe { op.read_entry(entry_ptr) };
91    if (pte & PAGE_PRESENT) != 0 {
92        Some(pte)
93    } else {
94        None
95    }
96}
97
/// Utility function to extract an (inclusive on both ends) bit range
/// from a quadword, shifted down so that `LOW_BIT` lands at bit 0.
///
/// Formulated with a right shift of `u64::MAX` (rather than
/// `(1 << (HIGH_BIT + 1)) - 1`) so the shift amount stays in range
/// even for `HIGH_BIT == 63`.
#[inline(always)]
fn bits<const HIGH_BIT: u8, const LOW_BIT: u8>(x: u64) -> u64 {
    // Mask keeping bits 0..=HIGH_BIT; 63 - HIGH_BIT is always a legal
    // shift amount for HIGH_BIT <= 63.
    let mask = u64::MAX >> (63 - HIGH_BIT);
    (x & mask) >> LOW_BIT
}
104
105/// Helper function to generate a page table entry that points to another table
106#[allow(clippy::identity_op)]
107#[allow(clippy::precedence)]
108fn pte_for_table<Op: TableOps>(table_addr: Op::TableAddr) -> u64 {
109    Op::to_phys(table_addr) |
110        PAGE_ACCESSED_SET | // prevent the CPU writing to the access flag
111        PAGE_CACHE_ENABLED | // leave caching enabled
112        PAGE_WRITE_BACK | // use write-back caching
113        PAGE_USER_ACCESS_DISABLED |// dont allow user access (no code runs in user mode for now)
114        PAGE_RW | // R/W - we don't use block-level permissions
115        PAGE_PRESENT // P   - this entry is present
116}
117
/// This trait is used to select appropriate implementations of
/// [`UpdateParent`] to be used, depending on whether a particular
/// implementation needs the ability to move tables.
pub trait TableMovability<Op: TableReadOps + ?Sized, TableMoveInfo> {
    /// The [`UpdateParent`] implementation used at the root of the
    /// page table hierarchy.
    type RootUpdateParent: UpdateParent<Op, TableMoveInfo = TableMoveInfo>;
    /// Construct the root-level [`UpdateParent`] value.
    fn root_update_parent() -> Self::RootUpdateParent;
}
// When tables may move, a relocation of the top-level table must be
// reported back via the root, so the root updater is [`UpdateParentRoot`].
impl<Op: TableOps<TableMovability = crate::vmem::MayMoveTable>> TableMovability<Op, Op::TableAddr>
    for crate::vmem::MayMoveTable
{
    type RootUpdateParent = UpdateParentRoot;
    fn root_update_parent() -> Self::RootUpdateParent {
        UpdateParentRoot {}
    }
}
// When tables can never move, there is never anything to propagate, so
// the unusable [`UpdateParentNone`] (whose move info is the uninhabited
// `Void`) is used.
impl<Op: TableReadOps> TableMovability<Op, Void> for crate::vmem::MayNotMoveTable {
    type RootUpdateParent = UpdateParentNone;
    fn root_update_parent() -> Self::RootUpdateParent {
        UpdateParentNone {}
    }
}
139
140/// Helper function to write a page table entry, updating the whole
141/// chain of tables back to the root if necessary
142unsafe fn write_entry_updating<
143    Op: TableOps,
144    P: UpdateParent<
145            Op,
146            TableMoveInfo = <Op::TableMovability as TableMovabilityBase<Op>>::TableMoveInfo,
147        >,
148>(
149    op: &Op,
150    parent: P,
151    addr: Op::TableAddr,
152    entry: u64,
153) {
154    if let Some(again) = unsafe { op.write_entry(addr, entry) } {
155        parent.update_parent(op, again);
156    }
157}
158
/// A helper trait that allows us to move a page table (e.g. from the
/// snapshot to the scratch region), keeping track of the context that
/// needs to be updated when that is moved (and potentially
/// recursively updating, if necessary)
///
/// This is done via a trait so that the selected impl knows the exact
/// nesting depth of tables, in order to assist
/// inlining/specialisation in generating efficient code.
///
/// The trait definition only bounds its parameter by
/// [`TableReadOps`], since [`UpdateParentNone`] does not need to be
/// able to actually write to the tables.
pub trait UpdateParent<Op: TableReadOps + ?Sized>: Copy {
    /// The type of the information about a moved table which is
    /// needed in order to update its parent.
    type TableMoveInfo;
    /// The [`UpdateParent`] type that should be used when going down
    /// another level in the table, in order to add the current level
    /// to the chain of ancestors to be updated.
    type ChildType: UpdateParent<Op, TableMoveInfo = Self::TableMoveInfo>;
    /// Record that the table this updater is responsible for has moved,
    /// rewriting the referencing entry (and, recursively, any further
    /// ancestors) as needed.
    fn update_parent(self, op: &Op, new_ptr: Self::TableMoveInfo);
    /// Produce the updater to use for the child table referenced by the
    /// entry at `entry_ptr`.
    fn for_child_at_entry(self, entry_ptr: Op::TableAddr) -> Self::ChildType;
}
182
/// A struct implementing [`UpdateParent`] that keeps track of the
/// fact that the parent table is itself another table, whose own
/// ancestors may need to be recursively updated
pub struct UpdateParentTable<Op: TableOps, P: UpdateParent<Op>> {
    // Updater for the parent table's own parent.
    parent: P,
    // Address of the entry (in the parent table) that references the
    // table this updater is responsible for.
    entry_ptr: Op::TableAddr,
}
// NOTE(review): Clone/Copy are written by hand rather than derived —
// presumably to avoid the `Op: Clone`/`Op: Copy` bounds a derive would
// add; confirm before switching to `#[derive(...)]`.
impl<Op: TableOps, P: UpdateParent<Op>> Clone for UpdateParentTable<Op, P> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<Op: TableOps, P: UpdateParent<Op>> Copy for UpdateParentTable<Op, P> {}
impl<Op: TableOps, P: UpdateParent<Op>> UpdateParentTable<Op, P> {
    /// Create an updater for a table referenced by the entry at
    /// `entry_ptr`, with `parent` handling any further propagation.
    fn new(parent: P, entry_ptr: Op::TableAddr) -> Self {
        UpdateParentTable { parent, entry_ptr }
    }
}
impl<
    Op: TableOps<TableMovability = crate::vmem::MayMoveTable>,
    P: UpdateParent<Op, TableMoveInfo = Op::TableAddr>,
> UpdateParent<Op> for UpdateParentTable<Op, P>
{
    type TableMoveInfo = Op::TableAddr;
    type ChildType = UpdateParentTable<Op, Self>;
    fn update_parent(self, op: &Op, new_ptr: Op::TableAddr) {
        // The child table moved: rewrite the parent's entry to point at
        // the new location, recursively propagating if that write in
        // turn moves the parent table.
        let pte = pte_for_table::<Op>(new_ptr);
        unsafe {
            write_entry_updating(op, self.parent, self.entry_ptr, pte);
        }
    }
    fn for_child_at_entry(self, entry_ptr: Op::TableAddr) -> Self::ChildType {
        Self::ChildType::new(self, entry_ptr)
    }
}
218
/// A struct implementing [`UpdateParent`] that keeps track of the
/// fact that the parent "table" is actually the root (e.g. the value
/// of CR3 in the guest)
#[derive(Copy, Clone)]
pub struct UpdateParentRoot {}
impl<Op: TableOps<TableMovability = crate::vmem::MayMoveTable>> UpdateParent<Op>
    for UpdateParentRoot
{
    type TableMoveInfo = Op::TableAddr;
    type ChildType = UpdateParentTable<Op, Self>;
    fn update_parent(self, op: &Op, new_ptr: Op::TableAddr) {
        // The top-level table itself moved: record its new location as
        // the new root.
        unsafe {
            op.update_root(new_ptr);
        }
    }
    fn for_child_at_entry(self, entry_ptr: Op::TableAddr) -> Self::ChildType {
        Self::ChildType::new(self, entry_ptr)
    }
}
238
/// A struct implementing [`UpdateParent`] that is impossible to use
/// (since its [`UpdateParent::update_parent`] method takes [`Void`]),
/// used when it is statically known that a table operation cannot
/// result in a need to update ancestors.
#[derive(Copy, Clone)]
pub struct UpdateParentNone {}
impl<Op: TableReadOps> UpdateParent<Op> for UpdateParentNone {
    type TableMoveInfo = Void;
    type ChildType = Self;
    fn update_parent(self, _op: &Op, impossible: Void) {
        // `Void` is uninhabited, so this can never actually be called;
        // the empty match proves that to the compiler.
        match impossible {}
    }
    fn for_child_at_entry(self, _entry_ptr: Op::TableAddr) -> Self {
        self
    }
}
255
/// A helper structure indicating a mapping operation that needs to be
/// performed
struct MapRequest<Op: TableReadOps, P: UpdateParent<Op>> {
    // The table whose entries are to be visited/modified.
    table_base: Op::TableAddr,
    // Start of the virtual address range to process.
    vmin: VirtAddr,
    // Length, in bytes, of the range to process.
    len: u64,
    // Updater used to fix up ancestor tables if a write moves a table.
    update_parent: P,
}
264
/// A helper structure indicating that a particular PTE needs to be
/// modified
struct MapResponse<Op: TableReadOps, P: UpdateParent<Op>> {
    // Address of the specific entry to modify.
    entry_ptr: Op::TableAddr,
    // Start of the virtual address range covered by this entry.
    vmin: VirtAddr,
    // Number of bytes of the request that fall under this entry.
    len: u64,
    // Updater used to fix up ancestor tables if a write moves a table.
    update_parent: P,
}
273
/// Iterator that walks through page table entries at a specific level.
///
/// Given a virtual address range and a table base, this iterator yields
/// `MapResponse` items for each page table entry that needs to be modified.
/// The const generics `HIGH_BIT` and `LOW_BIT` specify which bits of the
/// virtual address are used to index into this level's table.
///
/// For example:
/// - PML4: HIGH_BIT=47, LOW_BIT=39 (9 bits = 512 entries, each covering 512GB)
/// - PDPT: HIGH_BIT=38, LOW_BIT=30 (9 bits = 512 entries, each covering 1GB)
/// - PD:   HIGH_BIT=29, LOW_BIT=21 (9 bits = 512 entries, each covering 2MB)
/// - PT:   HIGH_BIT=20, LOW_BIT=12 (9 bits = 512 entries, each covering 4KB)
struct ModifyPteIterator<
    const HIGH_BIT: u8,
    const LOW_BIT: u8,
    Op: TableReadOps,
    P: UpdateParent<Op>,
> {
    // The range/table being walked.
    request: MapRequest<Op, P>,
    // Number of entries yielded so far; entry `n` starts at the
    // aligned-down `vmin` plus `n << LOW_BIT`.
    n: u64,
}
impl<const HIGH_BIT: u8, const LOW_BIT: u8, Op: TableReadOps, P: UpdateParent<Op>> Iterator
    for ModifyPteIterator<HIGH_BIT, LOW_BIT, Op, P>
{
    type Item = MapResponse<Op, P>;
    fn next(&mut self) -> Option<Self::Item> {
        // Each page table entry at this level covers a region of size (1 << LOW_BIT) bytes.
        // For example, at the PT level (LOW_BIT=12), each entry covers 4KB (0x1000 bytes).
        // At the PD level (LOW_BIT=21), each entry covers 2MB (0x200000 bytes).
        //
        // This mask isolates the bits below this level's index bits, used for alignment.
        let lower_bits_mask = (1 << LOW_BIT) - 1;

        // Calculate the virtual address for this iteration.
        // On the first iteration (n=0), start at the requested vmin.
        // On subsequent iterations, advance to the next aligned boundary.
        // This handles the case where vmin isn't aligned to this level's entry size.
        let next_vmin = if self.n == 0 {
            self.request.vmin
        } else {
            // Align to the next boundary by adding one entry's worth
            // and masking off lower bits. Masking off before adding
            // is safe, since n << LOW_BIT must always have zeros in
            // these positions.
            let aligned_min = self.request.vmin & !lower_bits_mask;
            // Use checked_add here because going past the end of the
            // address space counts as "the next one would be out of
            // range"
            aligned_min.checked_add(self.n << LOW_BIT)?
        };

        // Check if we've processed the entire requested range.
        // NOTE(review): this assumes `vmin + len` does not overflow
        // u64; `virt_to_phys` clamps its range to the address-space
        // size before getting here — confirm the same holds for any
        // new caller.
        if next_vmin >= self.request.vmin + self.request.len {
            return None;
        }

        // Calculate the pointer to this level's page table entry.
        // bits::<HIGH_BIT, LOW_BIT> extracts the relevant index bits from the virtual address.
        // Shift left by 3 (multiply by 8) because each entry is 8 bytes (u64).
        let entry_ptr = Op::entry_addr(
            self.request.table_base,
            bits::<HIGH_BIT, LOW_BIT>(next_vmin) << 3,
        );

        // Calculate how many bytes remain to be mapped from this point
        let len_from_here = self.request.len - (next_vmin - self.request.vmin);

        // Calculate the maximum bytes this single entry can cover.
        // If next_vmin is aligned, this is the full entry size (1 << LOW_BIT).
        // If not aligned (only possible on first iteration), it's the remaining
        // space until the next boundary.
        let max_len = (1 << LOW_BIT) - (next_vmin & lower_bits_mask);

        // The actual length for this entry is the smaller of what's needed vs what fits
        let next_len = core::cmp::min(len_from_here, max_len);

        // Advance iteration counter for next call
        self.n += 1;

        Some(MapResponse {
            entry_ptr,
            vmin: next_vmin,
            len: next_len,
            update_parent: self.request.update_parent,
        })
    }
}
/// Construct a [`ModifyPteIterator`] over the entries of the table at
/// `r.table_base` that cover the virtual range described by `r`.
fn modify_ptes<const HIGH_BIT: u8, const LOW_BIT: u8, Op: TableReadOps, P: UpdateParent<Op>>(
    r: MapRequest<Op, P>,
) -> ModifyPteIterator<HIGH_BIT, LOW_BIT, Op, P> {
    ModifyPteIterator { request: r, n: 0 }
}
366
367/// Page-mapping callback to allocate a next-level page table if necessary.
368/// # Safety
369/// This function modifies page table data structures, and should not be called concurrently
370/// with any other operations that modify the page tables.
371unsafe fn alloc_pte_if_needed<
372    Op: TableOps,
373    P: UpdateParent<
374            Op,
375            TableMoveInfo = <Op::TableMovability as TableMovabilityBase<Op>>::TableMoveInfo,
376        >,
377>(
378    op: &Op,
379    x: MapResponse<Op, P>,
380) -> MapRequest<Op, P::ChildType>
381where
382    P::ChildType: UpdateParent<Op>,
383{
384    let new_update_parent = x.update_parent.for_child_at_entry(x.entry_ptr);
385    if let Some(pte) = unsafe { read_pte_if_present(op, x.entry_ptr) } {
386        return MapRequest {
387            table_base: Op::from_phys(pte & PTE_ADDR_MASK),
388            vmin: x.vmin,
389            len: x.len,
390            update_parent: new_update_parent,
391        };
392    }
393
394    let page_addr = unsafe { op.alloc_table() };
395
396    let pte = pte_for_table::<Op>(page_addr);
397    unsafe {
398        write_entry_updating(op, x.update_parent, x.entry_ptr, pte);
399    };
400    MapRequest {
401        table_base: page_addr,
402        vmin: x.vmin,
403        len: x.len,
404        update_parent: new_update_parent,
405    }
406}
407
408/// Map a normal memory page
409/// # Safety
410/// This function modifies page table data structures, and should not be called concurrently
411/// with any other operations that modify the page tables.
412#[allow(clippy::identity_op)]
413#[allow(clippy::precedence)]
414unsafe fn map_page<
415    Op: TableOps,
416    P: UpdateParent<
417            Op,
418            TableMoveInfo = <Op::TableMovability as TableMovabilityBase<Op>>::TableMoveInfo,
419        >,
420>(
421    op: &Op,
422    mapping: &Mapping,
423    r: MapResponse<Op, P>,
424) {
425    let pte = match &mapping.kind {
426        MappingKind::Basic(bm) =>
427        // TODO: Support not readable
428        // NOTE: On x86-64, there is no separate "readable" bit in the page table entry.
429        // This means that pages cannot be made write-only or execute-only without also being readable.
430        // All pages that are mapped as writable or executable are also implicitly readable.
431        // If support for "not readable" mappings is required in the future, it would need to be
432        // implemented using additional mechanisms (e.g., page-fault handling or memory protection keys),
433        // but for now, this architectural limitation is accepted.
434        {
435            (mapping.phys_base + (r.vmin - mapping.virt_base)) |
436                page_nx_flag(bm.executable) | // NX - no execute unless allowed
437                PAGE_PAT_WB | // PAT index bit for write-back memory
438                PAGE_DIRTY_SET | // prevent the CPU writing to the dirty bit
439                PAGE_ACCESSED_SET | // prevent the CPU writing to the access flag
440                PAGE_CACHE_ENABLED | // leave caching enabled
441                PAGE_WRITE_BACK | // use write-back caching
442                PAGE_USER_ACCESS_DISABLED | // dont allow user access (no code runs in user mode for now)
443                page_rw_flag(bm.writable) | // R/W - set if writable
444                PAGE_PRESENT // P   - this entry is present
445        }
446        MappingKind::Cow(cm) => {
447            (mapping.phys_base + (r.vmin - mapping.virt_base)) |
448                page_nx_flag(cm.executable) | // NX - no execute unless allowed
449                PAGE_AVL_COW |
450                PAGE_PAT_WB | // PAT index bit for write-back memory
451                PAGE_DIRTY_SET | // prevent the CPU writing to the dirty bit
452                PAGE_ACCESSED_SET | // prevent the CPU writing to the access flag
453                PAGE_CACHE_ENABLED | // leave caching enabled
454                PAGE_WRITE_BACK | // use write-back caching
455                PAGE_USER_ACCESS_DISABLED | // dont allow user access (no code runs in user mode for now)
456                0 | // R/W - Cow page is never writable
457                PAGE_PRESENT // P   - this entry is present
458        }
459        MappingKind::Unmapped => 0,
460    };
461    unsafe {
462        write_entry_updating(op, r.update_parent, r.entry_ptr, pte);
463    }
464}
465
466// There are no notable architecture-specific safety considerations
467// here, and the general conditions are documented in the
468// architecture-independent re-export in vmem.rs
469
/// Maps a contiguous virtual address range to physical memory.
///
/// This function walks the 4-level page table hierarchy (PML4 → PDPT → PD → PT),
/// allocating intermediate tables as needed via `alloc_pte_if_needed`, and finally
/// writing the leaf page table entries with the requested permissions via `map_page`.
///
/// The iterator chain processes each level:
/// 1. PML4 (47:39) - allocate PDPT if needed
/// 2. PDPT (38:30) - allocate PD if needed
/// 3. PD (29:21) - allocate PT if needed
/// 4. PT (20:12) - write final PTE with physical address and flags
///
/// # Safety
/// This function modifies page table data structures, and must not be
/// called concurrently with any other operation that reads or modifies
/// the page tables (see the architecture-independent re-export in
/// vmem.rs for the general conditions).
#[allow(clippy::missing_safety_doc)]
pub unsafe fn map<Op: TableOps>(op: &Op, mapping: Mapping) {
    modify_ptes::<47, 39, Op, _>(MapRequest {
        table_base: op.root_table(),
        vmin: mapping.virt_base,
        len: mapping.len,
        update_parent: Op::TableMovability::root_update_parent(),
    })
    .map(|r| unsafe { alloc_pte_if_needed(op, r) })
    .flat_map(modify_ptes::<38, 30, Op, _>)
    .map(|r| unsafe { alloc_pte_if_needed(op, r) })
    .flat_map(modify_ptes::<29, 21, Op, _>)
    .map(|r| unsafe { alloc_pte_if_needed(op, r) })
    .flat_map(modify_ptes::<20, 12, Op, _>)
    .map(|r| unsafe { map_page(op, &mapping, r) })
    .for_each(drop);
}
498
499/// # Safety
500/// This function traverses page table data structures, and should not
501/// be called concurrently with any other operations that modify the
502/// page table.
503unsafe fn require_pte_exist<Op: TableReadOps, P: UpdateParent<Op>>(
504    op: &Op,
505    x: MapResponse<Op, P>,
506) -> Option<MapRequest<Op, P::ChildType>>
507where
508    P::ChildType: UpdateParent<Op>,
509{
510    unsafe { read_pte_if_present(op, x.entry_ptr) }.map(|pte| MapRequest {
511        table_base: Op::from_phys(pte & PTE_ADDR_MASK),
512        vmin: x.vmin,
513        len: x.len,
514        update_parent: x.update_parent.for_child_at_entry(x.entry_ptr),
515    })
516}
517
518// There are no notable architecture-specific safety considerations
519// here, and the general conditions are documented in the
520// architecture-independent re-export in vmem.rs
521
522/// Translates a virtual address range to the physical address pages
523/// that back it by walking the page tables.
524///
525/// Returns an iterator with an entry for each mapped page that
526/// intersects the given range.
527///
528/// This takes AsRef<Op> + Copy so that on targets where the
529/// operations have little state (e.g. the guest) the operations state
530/// can be copied into the closure(s) in the iterator, allowing for a
531/// nicer result lifetime.  On targets like the
532/// building-an-original-snapshot portion of the host, where the
533/// operations structure owns a large buffer, a reference can instead
534/// be passed.
535#[allow(clippy::missing_safety_doc)]
536pub unsafe fn virt_to_phys<'a, Op: TableReadOps + 'a>(
537    op: impl core::convert::AsRef<Op> + Copy + 'a,
538    address: u64,
539    len: u64,
540) -> impl Iterator<Item = Mapping> + 'a {
541    // Undo sign-extension
542    let addr = address & ((1u64 << VA_BITS) - 1);
543    // Mask off any sub-page bits
544    let vmin = addr & !(PAGE_SIZE as u64 - 1);
545    // Calculate the maximum virtual address we need to look at based on the starting
546    // address and length ensuring we don't go past the end of the address space
547    let vmax = core::cmp::min(addr + len, 1u64 << VA_BITS);
548    modify_ptes::<47, 39, Op, _>(MapRequest {
549        table_base: op.as_ref().root_table(),
550        vmin,
551        len: vmax - vmin,
552        update_parent: UpdateParentNone {},
553    })
554    .filter_map(move |r| unsafe { require_pte_exist(op.as_ref(), r) })
555    .flat_map(modify_ptes::<38, 30, Op, _>)
556    .filter_map(move |r| unsafe { require_pte_exist(op.as_ref(), r) })
557    .flat_map(modify_ptes::<29, 21, Op, _>)
558    .filter_map(move |r| unsafe { require_pte_exist(op.as_ref(), r) })
559    .flat_map(modify_ptes::<20, 12, Op, _>)
560    .filter_map(move |r| {
561        let pte = unsafe { read_pte_if_present(op.as_ref(), r.entry_ptr) }?;
562        let phys_addr = pte & PTE_ADDR_MASK;
563        // Re-do the sign extension
564        let sgn_bit = r.vmin >> (VA_BITS - 1);
565        let sgn_bits = 0u64.wrapping_sub(sgn_bit) << VA_BITS;
566        let virt_addr = sgn_bits | r.vmin;
567
568        let executable = (pte & PAGE_NX) == 0;
569        let avl = pte & PTE_AVL_MASK;
570        let kind = if avl == PAGE_AVL_COW {
571            MappingKind::Cow(CowMapping {
572                readable: true,
573                executable,
574            })
575        } else {
576            MappingKind::Basic(BasicMapping {
577                readable: true,
578                writable: (pte & PAGE_RW) != 0,
579                executable,
580            })
581        };
582        Some(Mapping {
583            phys_base: phys_addr,
584            virt_base: virt_addr,
585            len: PAGE_SIZE as u64,
586            kind,
587        })
588    })
589}
590
/// Number of significant bits in a virtual address.
const VA_BITS: usize = 48; // We use 48-bit virtual addresses at the moment.

/// Size in bytes of a single 4KB page.
pub const PAGE_SIZE: usize = 4096;
/// Size in bytes of one page table (512 entries of 8 bytes each).
pub const PAGE_TABLE_SIZE: usize = 4096;
/// A raw 64-bit page table entry.
pub type PageTableEntry = u64;
/// A virtual address.
pub type VirtAddr = u64;
/// A physical address.
pub type PhysAddr = u64;
598
599#[cfg(test)]
600mod tests {
601    use alloc::vec;
602    use alloc::vec::Vec;
603    use core::cell::RefCell;
604
605    use super::*;
606    use crate::vmem::{
607        BasicMapping, Mapping, MappingKind, MayNotMoveTable, PAGE_TABLE_ENTRIES_PER_TABLE,
608        TableOps, TableReadOps, Void,
609    };
610
    /// A mock TableOps implementation for testing that stores page tables in memory
    /// needed because the `GuestPageTableBuffer` is in hyperlight_host which would cause a circular dependency
    struct MockTableOps {
        // All allocated tables; index 0 is the root (PML4).
        tables: RefCell<Vec<[u64; PAGE_TABLE_ENTRIES_PER_TABLE]>>,
    }
616
    // for virt_to_phys, which takes `impl AsRef<Op> + Copy`
    impl core::convert::AsRef<MockTableOps> for MockTableOps {
        fn as_ref(&self) -> &Self {
            self
        }
    }
623
    impl MockTableOps {
        /// Create a mock with just the root (PML4) table allocated.
        fn new() -> Self {
            // Start with one table (the root/PML4)
            Self {
                tables: RefCell::new(vec![[0u64; PAGE_TABLE_ENTRIES_PER_TABLE]]),
            }
        }

        /// Number of tables allocated so far, including the root.
        fn table_count(&self) -> usize {
            self.tables.borrow().len()
        }

        /// Read the raw entry `entry_idx` of table `table_idx`.
        fn get_entry(&self, table_idx: usize, entry_idx: usize) -> u64 {
            self.tables.borrow()[table_idx][entry_idx]
        }
    }
640
    impl TableReadOps for MockTableOps {
        type TableAddr = (usize, usize); // (table_index, entry_index)

        fn entry_addr(addr: Self::TableAddr, entry_offset: u64) -> Self::TableAddr {
            // Convert to physical address, add offset, convert back
            let phys = Self::to_phys(addr) + entry_offset;
            Self::from_phys(phys)
        }

        unsafe fn read_entry(&self, addr: Self::TableAddr) -> u64 {
            self.tables.borrow()[addr.0][addr.1]
        }

        fn to_phys(addr: Self::TableAddr) -> PhysAddr {
            // Each table is 4KB, entries are 8 bytes
            (addr.0 as u64 * PAGE_TABLE_SIZE as u64) + (addr.1 as u64 * 8)
        }

        fn from_phys(addr: PhysAddr) -> Self::TableAddr {
            // Inverse of `to_phys`: split into table index and entry index.
            let table_idx = (addr / PAGE_TABLE_SIZE as u64) as usize;
            let entry_idx = ((addr % PAGE_TABLE_SIZE as u64) / 8) as usize;
            (table_idx, entry_idx)
        }

        fn root_table(&self) -> Self::TableAddr {
            // Table 0 is always the root (allocated in `new`).
            (0, 0)
        }
    }
669
    impl TableOps for MockTableOps {
        // Tables in this mock never move once allocated.
        type TableMovability = MayNotMoveTable;

        unsafe fn alloc_table(&self) -> Self::TableAddr {
            // Append a zeroed table and return its (index, 0) address.
            let mut tables = self.tables.borrow_mut();
            let idx = tables.len();
            tables.push([0u64; PAGE_TABLE_ENTRIES_PER_TABLE]);
            (idx, 0)
        }

        unsafe fn write_entry(&self, addr: Self::TableAddr, entry: u64) -> Option<Void> {
            self.tables.borrow_mut()[addr.0][addr.1] = entry;
            // `None`: a write never causes a table relocation here.
            None
        }

        unsafe fn update_root(&self, impossible: Void) {
            // Unreachable: `Void` is uninhabited.
            match impossible {}
        }
    }
689
690    // ==================== bits() function tests ====================
691
    #[test]
    fn test_bits_extracts_pml4_index() {
        // PML4 uses bits 47:39
        // Address 0x0000_0080_0000_0000 (only bit 39 set) should have PML4 index 1
        let addr: u64 = 0x0000_0080_0000_0000;
        assert_eq!(bits::<47, 39>(addr), 1);
    }

    #[test]
    fn test_bits_extracts_pdpt_index() {
        // PDPT uses bits 38:30
        // Address with PDPT index 1: bit 30 set = 0x4000_0000 (1GB)
        let addr: u64 = 0x4000_0000;
        assert_eq!(bits::<38, 30>(addr), 1);
    }

    #[test]
    fn test_bits_extracts_pd_index() {
        // PD uses bits 29:21
        // Address 0x0000_0000_0020_0000 (2MB, only bit 21 set) should have PD index 1
        let addr: u64 = 0x0000_0000_0020_0000;
        assert_eq!(bits::<29, 21>(addr), 1);
    }

    #[test]
    fn test_bits_extracts_pt_index() {
        // PT uses bits 20:12
        // Address 0x0000_0000_0000_1000 (4KB, only bit 12 set) should have PT index 1
        let addr: u64 = 0x0000_0000_0000_1000;
        assert_eq!(bits::<20, 12>(addr), 1);
    }

    #[test]
    fn test_bits_max_index() {
        // Maximum 9-bit index is 511
        // PML4 index 511 = bits 47:39 all set = 0x0000_FF80_0000_0000
        let addr: u64 = 0x0000_FF80_0000_0000;
        assert_eq!(bits::<47, 39>(addr), 511);
    }
731
732    // ==================== PTE flag tests ====================
733
    #[test]
    fn test_page_rw_flag_writable() {
        // Writable mappings get the R/W bit.
        assert_eq!(page_rw_flag(true), PAGE_RW);
    }

    #[test]
    fn test_page_rw_flag_readonly() {
        // Read-only mappings contribute no flag bits.
        assert_eq!(page_rw_flag(false), 0);
    }

    #[test]
    fn test_page_nx_flag_executable() {
        assert_eq!(page_nx_flag(true), 0); // Executable = no NX bit
    }

    #[test]
    fn test_page_nx_flag_not_executable() {
        // Non-executable mappings get the NX (execute-disable) bit.
        assert_eq!(page_nx_flag(false), PAGE_NX);
    }
753
754    // ==================== map() function tests ====================
755
    #[test]
    fn test_map_single_page() {
        // Map one writable, non-executable 4KB page at virt/phys 0x1000.
        let ops = MockTableOps::new();
        let mapping = Mapping {
            phys_base: 0x1000,
            virt_base: 0x1000,
            len: PAGE_SIZE as u64,
            kind: MappingKind::Basic(BasicMapping {
                readable: true,
                writable: true,
                executable: false,
            }),
        };

        unsafe { map(&ops, mapping) };

        // Should have allocated: PML4(exists) + PDPT + PD + PT = 4 tables
        assert_eq!(ops.table_count(), 4);

        // Check PML4 entry 0 points to PDPT (table 1) with correct flags
        let pml4_entry = ops.get_entry(0, 0);
        assert_ne!(pml4_entry & PAGE_PRESENT, 0, "PML4 entry should be present");
        assert_ne!(pml4_entry & PAGE_RW, 0, "PML4 entry should be writable");

        // Check the leaf PTE has correct flags
        // PT is table 3 (allocation order: PDPT=1, PD=2, PT=3), entry 1
        // (for virt_base 0x1000)
        let pte = ops.get_entry(3, 1);
        assert_ne!(pte & PAGE_PRESENT, 0, "PTE should be present");
        assert_ne!(pte & PAGE_RW, 0, "PTE should be writable");
        assert_ne!(pte & PAGE_NX, 0, "PTE should have NX set (not executable)");
        assert_eq!(pte & PTE_ADDR_MASK, 0x1000, "PTE should map to phys 0x1000");
    }
788
789    #[test]
790    fn test_map_executable_page() {
791        let ops = MockTableOps::new();
792        let mapping = Mapping {
793            phys_base: 0x2000,
794            virt_base: 0x2000,
795            len: PAGE_SIZE as u64,
796            kind: MappingKind::Basic(BasicMapping {
797                readable: true,
798                writable: false,
799                executable: true,
800            }),
801        };
802
803        unsafe { map(&ops, mapping) };
804
805        // PT is table 3, entry 2 (for virt_base 0x2000)
806        let pte = ops.get_entry(3, 2);
807        assert_ne!(pte & PAGE_PRESENT, 0, "PTE should be present");
808        assert_eq!(pte & PAGE_RW, 0, "PTE should be read-only");
809        assert_eq!(pte & PAGE_NX, 0, "PTE should NOT have NX set (executable)");
810    }
811
812    #[test]
813    fn test_map_multiple_pages() {
814        let ops = MockTableOps::new();
815        let mapping = Mapping {
816            phys_base: 0x10000,
817            virt_base: 0x10000,
818            len: 4 * PAGE_SIZE as u64, // 4 pages = 16KB
819            kind: MappingKind::Basic(BasicMapping {
820                readable: true,
821                writable: true,
822                executable: false,
823            }),
824        };
825
826        unsafe { map(&ops, mapping) };
827
828        // Check all 4 PTEs are present
829        for i in 0..4 {
830            let entry_idx = 16 + i; // 0x10000 / 0x1000 = 16
831            let pte = ops.get_entry(3, entry_idx);
832            assert_ne!(pte & PAGE_PRESENT, 0, "PTE {} should be present", i);
833            let expected_phys = 0x10000 + (i as u64 * PAGE_SIZE as u64);
834            assert_eq!(
835                pte & PTE_ADDR_MASK,
836                expected_phys,
837                "PTE {} should map to correct phys addr",
838                i
839            );
840        }
841    }
842
843    #[test]
844    fn test_map_reuses_existing_tables() {
845        let ops = MockTableOps::new();
846
847        // Map first region
848        let mapping1 = Mapping {
849            phys_base: 0x1000,
850            virt_base: 0x1000,
851            len: PAGE_SIZE as u64,
852            kind: MappingKind::Basic(BasicMapping {
853                readable: true,
854                writable: true,
855                executable: false,
856            }),
857        };
858        unsafe { map(&ops, mapping1) };
859        let tables_after_first = ops.table_count();
860
861        // Map second region in same PT (different page)
862        let mapping2 = Mapping {
863            phys_base: 0x5000,
864            virt_base: 0x5000,
865            len: PAGE_SIZE as u64,
866            kind: MappingKind::Basic(BasicMapping {
867                readable: true,
868                writable: true,
869                executable: false,
870            }),
871        };
872        unsafe { map(&ops, mapping2) };
873
874        // Should NOT allocate new tables (reuses existing hierarchy)
875        assert_eq!(
876            ops.table_count(),
877            tables_after_first,
878            "Should reuse existing page tables"
879        );
880    }
881
882    // ==================== virt_to_phys() tests ====================
883
884    #[test]
885    fn test_virt_to_phys_mapped_address() {
886        let ops = MockTableOps::new();
887        let mapping = Mapping {
888            phys_base: 0x1000,
889            virt_base: 0x1000,
890            len: PAGE_SIZE as u64,
891            kind: MappingKind::Basic(BasicMapping {
892                readable: true,
893                writable: true,
894                executable: false,
895            }),
896        };
897
898        unsafe { map(&ops, mapping) };
899
900        let result = unsafe { virt_to_phys(&ops, 0x1000, 1).next() };
901        assert!(result.is_some(), "Should find mapped address");
902        let mapping = result.unwrap();
903        assert_eq!(mapping.phys_base, 0x1000);
904    }
905
906    #[test]
907    fn test_virt_to_phys_unaligned_virt() {
908        let ops = MockTableOps::new();
909        let mapping = Mapping {
910            phys_base: 0x1000,
911            virt_base: 0x1000,
912            len: PAGE_SIZE as u64,
913            kind: MappingKind::Basic(BasicMapping {
914                readable: true,
915                writable: true,
916                executable: false,
917            }),
918        };
919
920        unsafe { map(&ops, mapping) };
921
922        let result = unsafe { virt_to_phys(&ops, 0x1234, 1).next() };
923        assert!(result.is_some(), "Should find mapped address");
924        let mapping = result.unwrap();
925        assert_eq!(mapping.phys_base, 0x1000);
926    }
927
928    #[test]
929    fn test_virt_to_phys_unaligned_virt_and_across_pages_len() {
930        let ops = MockTableOps::new();
931        let mapping = Mapping {
932            phys_base: 0x1000,
933            virt_base: 0x1000,
934            len: 2 * PAGE_SIZE as u64, // 2 page
935            kind: MappingKind::Basic(BasicMapping {
936                readable: true,
937                writable: true,
938                executable: false,
939            }),
940        };
941
942        unsafe { map(&ops, mapping) };
943
944        let mappings = unsafe { virt_to_phys(&ops, 0x1F00, 0x300).collect::<Vec<_>>() };
945        assert_eq!(mappings.len(), 2, "Should return 2 mappings for 2 pages");
946        assert_eq!(mappings[0].phys_base, 0x1000);
947        assert_eq!(mappings[1].phys_base, 0x2000);
948    }
949
950    #[test]
951    fn test_virt_to_phys_unaligned_virt_and_multiple_page_len() {
952        let ops = MockTableOps::new();
953        let mapping = Mapping {
954            phys_base: 0x1000,
955            virt_base: 0x1000,
956            len: PAGE_SIZE as u64 * 2 + 0x200, // 2 page + 512 bytes
957            kind: MappingKind::Basic(BasicMapping {
958                readable: true,
959                writable: true,
960                executable: false,
961            }),
962        };
963
964        unsafe { map(&ops, mapping) };
965
966        let mappings =
967            unsafe { virt_to_phys(&ops, 0x1234, PAGE_SIZE as u64 * 2 + 0x10).collect::<Vec<_>>() };
968        assert_eq!(mappings.len(), 3, "Should return 3 mappings for 3 pages");
969        assert_eq!(mappings[0].phys_base, 0x1000);
970        assert_eq!(mappings[1].phys_base, 0x2000);
971        assert_eq!(mappings[2].phys_base, 0x3000);
972    }
973
974    #[test]
975    fn test_virt_to_phys_perms() {
976        let test = |kind| {
977            let ops = MockTableOps::new();
978            let mapping = Mapping {
979                phys_base: 0x1000,
980                virt_base: 0x1000,
981                len: PAGE_SIZE as u64,
982                kind,
983            };
984            unsafe { map(&ops, mapping) };
985            let result = unsafe { virt_to_phys(&ops, 0x1000, 1).next() };
986            let mapping = result.unwrap();
987            assert_eq!(mapping.kind, kind);
988        };
989        test(MappingKind::Basic(BasicMapping {
990            readable: true,
991            writable: false,
992            executable: false,
993        }));
994        test(MappingKind::Basic(BasicMapping {
995            readable: true,
996            writable: false,
997            executable: true,
998        }));
999        test(MappingKind::Basic(BasicMapping {
1000            readable: true,
1001            writable: true,
1002            executable: false,
1003        }));
1004        test(MappingKind::Basic(BasicMapping {
1005            readable: true,
1006            writable: true,
1007            executable: true,
1008        }));
1009        test(MappingKind::Cow(CowMapping {
1010            readable: true,
1011            executable: false,
1012        }));
1013        test(MappingKind::Cow(CowMapping {
1014            readable: true,
1015            executable: true,
1016        }));
1017    }
1018
1019    #[test]
1020    fn test_virt_to_phys_unmapped_address() {
1021        let ops = MockTableOps::new();
1022        // Don't map anything
1023
1024        let result = unsafe { virt_to_phys(&ops, 0x1000, 1).next() };
1025        assert!(result.is_none(), "Should return None for unmapped address");
1026    }
1027
1028    #[test]
1029    fn test_virt_to_phys_partially_mapped() {
1030        let ops = MockTableOps::new();
1031        let mapping = Mapping {
1032            phys_base: 0x1000,
1033            virt_base: 0x1000,
1034            len: PAGE_SIZE as u64,
1035            kind: MappingKind::Basic(BasicMapping {
1036                readable: true,
1037                writable: true,
1038                executable: false,
1039            }),
1040        };
1041
1042        unsafe { map(&ops, mapping) };
1043
1044        // Query an address in a different PT entry (unmapped)
1045        let result = unsafe { virt_to_phys(&ops, 0x5000, 1).next() };
1046        assert!(
1047            result.is_none(),
1048            "Should return None for unmapped address in same PT"
1049        );
1050    }
1051
1052    // ==================== ModifyPteIterator tests ====================
1053
1054    #[test]
1055    fn test_modify_pte_iterator_single_page() {
1056        let ops = MockTableOps::new();
1057        let request = MapRequest {
1058            table_base: ops.root_table(),
1059            vmin: 0x1000,
1060            len: PAGE_SIZE as u64,
1061            update_parent: UpdateParentNone {},
1062        };
1063
1064        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1065        assert_eq!(responses.len(), 1, "Single page should yield one response");
1066        assert_eq!(responses[0].vmin, 0x1000);
1067        assert_eq!(responses[0].len, PAGE_SIZE as u64);
1068    }
1069
1070    #[test]
1071    fn test_modify_pte_iterator_multiple_pages() {
1072        let ops = MockTableOps::new();
1073        let request = MapRequest {
1074            table_base: ops.root_table(),
1075            vmin: 0x1000,
1076            len: 3 * PAGE_SIZE as u64,
1077            update_parent: UpdateParentNone {},
1078        };
1079
1080        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1081        assert_eq!(responses.len(), 3, "3 pages should yield 3 responses");
1082    }
1083
1084    #[test]
1085    fn test_modify_pte_iterator_zero_length() {
1086        let ops = MockTableOps::new();
1087        let request = MapRequest {
1088            table_base: ops.root_table(),
1089            vmin: 0x1000,
1090            len: 0,
1091            update_parent: UpdateParentNone {},
1092        };
1093
1094        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1095        assert_eq!(responses.len(), 0, "Zero length should yield no responses");
1096    }
1097
1098    #[test]
1099    fn test_modify_pte_iterator_unaligned_start() {
1100        let ops = MockTableOps::new();
1101        // Start at 0x1800 (mid-page), map 0x1000 bytes
1102        // Should cover 0x1800-0x1FFF (first page) and 0x2000-0x27FF (second page)
1103        let request = MapRequest {
1104            table_base: ops.root_table(),
1105            vmin: 0x1800,
1106            len: 0x1000,
1107            update_parent: UpdateParentNone {},
1108        };
1109
1110        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1111        assert_eq!(
1112            responses.len(),
1113            2,
1114            "Unaligned mapping spanning 2 pages should yield 2 responses"
1115        );
1116        assert_eq!(responses[0].vmin, 0x1800);
1117        assert_eq!(responses[0].len, 0x800); // Remaining in first page
1118        assert_eq!(responses[1].vmin, 0x2000);
1119        assert_eq!(responses[1].len, 0x800); // Continuing in second page
1120    }
1121
1122    // ==================== TableOps entry_addr tests ====================
1123
1124    #[test]
1125    fn test_entry_addr_from_table_base() {
1126        // entry_addr is called with a table base (entry_index = 0) and a byte offset
1127        // offset = entry_index * 8, so offset 40 means entry 5
1128        let result = MockTableOps::entry_addr((2, 0), 40);
1129        assert_eq!(result, (2, 5), "Should return (table 2, entry 5)");
1130    }
1131
1132    #[test]
1133    fn test_entry_addr_with_nonzero_base_entry() {
1134        // Even though entry_addr is typically called with entry_index=0,
1135        // it should handle non-zero base correctly by adding the offset
1136        // Base: table 1, entry 10 (phys = 1*4096 + 10*8 = 4176)
1137        // Offset: 16 bytes (2 entries)
1138        // Result phys: 4176 + 16 = 4192 = 1*4096 + 12*8 → (1, 12)
1139        let result = MockTableOps::entry_addr((1, 10), 16);
1140        assert_eq!(result, (1, 12), "Should add offset to base entry");
1141    }
1142
1143    #[test]
1144    fn test_to_phys_from_phys_roundtrip() {
1145        // Verify to_phys and from_phys are inverses
1146        let addr = (3, 42);
1147        let phys = MockTableOps::to_phys(addr);
1148        let back = MockTableOps::from_phys(phys);
1149        assert_eq!(back, addr, "to_phys/from_phys should roundtrip");
1150    }
1151}