// hyperlight_common/arch/amd64/vmem.rs
/*
Copyright 2025  The Hyperlight Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
 */
16
//! x86-64 4-level page table manipulation code.
//!
//! This module implements page table setup for x86-64 long mode using 4-level paging:
//! - PML4 (Page Map Level 4) - bits 47:39 - 512 entries, each covering 512GB
//! - PDPT (Page Directory Pointer Table) - bits 38:30 - 512 entries, each covering 1GB
//! - PD (Page Directory) - bits 29:21 - 512 entries, each covering 2MB
//! - PT (Page Table) - bits 20:12 - 512 entries, each covering 4KB pages
//!
//! The code uses an iterator-based approach to walk the page table hierarchy,
//! allocating intermediate tables as needed and setting appropriate flags on leaf PTEs.
27
28use crate::vmem::{
29    BasicMapping, CowMapping, Mapping, MappingKind, TableMovabilityBase, TableOps, TableReadOps,
30    Void,
31};
32
// Paging Flags
//
// See the following links explaining paging:
//
// * Intel® 64 and IA-32 Architectures Software Developer's Manual, Volume 3A:
//   System Programming Guide, Part 1
//   - Chapter 5 "Paging"
//
//   https://cdrdv2.intel.com/v1/dl/getContent/671200
//
// * AMD64 Architecture Programmer's Manual, Volume 2: System Programming,
//   Section 5.3: Long-Mode Page Translation
//
//   https://docs.amd.com/v/u/en-US/24593_3.43
//
// Or if you prefer something less formal:
//
// * Very basic description: https://stackoverflow.com/a/26945892
// * More in-depth descriptions: https://wiki.osdev.org/Paging
//
51
/// Page is Present (bit 0)
const PAGE_PRESENT: u64 = 1;
/// Page is Read/Write (bit 1). If not set the page is read-only, so long as
/// the WP bit in CR0 is set to 1 — which it always is in Hyperlight.
const PAGE_RW: u64 = 1 << 1;
/// Execute Disable (bit 63): if set, data in the page cannot be executed
const PAGE_NX: u64 = 1 << 63;
/// Mask to extract the physical address from a PTE (bits 51:12).
/// This masks out the lower 12 flag bits AND the upper bits including NX (bit 63).
const PTE_ADDR_MASK: u64 = 0x000F_FFFF_FFFF_F000;
/// U/S (bit 2) left clear: supervisor mode only (no code runs in user mode for now)
const PAGE_USER_ACCESS_DISABLED: u64 = 0 << 2;
/// D - dirty bit (bit 6)
const PAGE_DIRTY_SET: u64 = 1 << 6;
/// A - accessed bit (bit 5)
const PAGE_ACCESSED_SET: u64 = 1 << 5;
/// PCD (bit 4) left clear: page caching enabled
const PAGE_CACHE_ENABLED: u64 = 0 << 4;
/// PWT (bit 3) left clear: write-back (not write-through) caching
const PAGE_WRITE_BACK: u64 = 0 << 3;
/// PAT index bit (bit 7) left clear: write-back memory when PCD=0, PWT=0
const PAGE_PAT_WB: u64 = 0 << 7;

// We use various patterns of the available-for-software-use bits
// (AVL, bits 11:9) to represent certain special mappings.
const PTE_AVL_MASK: u64 = 0x0000_0000_0000_0E00;
/// AVL pattern marking a copy-on-write page
const PAGE_AVL_COW: u64 = 1 << 9;

/// Returns [`PAGE_RW`] if `writable` is true, 0 otherwise
#[inline(always)]
const fn page_rw_flag(writable: bool) -> u64 {
    match writable {
        true => PAGE_RW,
        false => 0,
    }
}

/// Returns [`PAGE_NX`] if `executable` is false (NX = No Execute), 0 otherwise
#[inline(always)]
const fn page_nx_flag(executable: bool) -> u64 {
    match executable {
        true => 0,
        false => PAGE_NX,
    }
}
84
85/// Read a page table entry and return it if the present bit is set
86/// # Safety
87/// The caller must ensure that `entry_ptr` points to a valid page table entry.
88#[inline(always)]
89unsafe fn read_pte_if_present<Op: TableReadOps>(op: &Op, entry_ptr: Op::TableAddr) -> Option<u64> {
90    let pte = unsafe { op.read_entry(entry_ptr) };
91    if (pte & PAGE_PRESENT) != 0 {
92        Some(pte)
93    } else {
94        None
95    }
96}
97
/// Utility function to extract an (inclusive on both ends) bit range
/// from a quadword.
///
/// Valid for any `0 <= LOW_BIT <= HIGH_BIT <= 63`. The previous
/// formulation computed the mask as `(1 << (HIGH_BIT + 1)) - 1`, which
/// overflows the shift for `HIGH_BIT == 63`; the double-shift form below
/// has no such edge case.
#[inline(always)]
fn bits<const HIGH_BIT: u8, const LOW_BIT: u8>(x: u64) -> u64 {
    // Shift the range up against bit 63 to discard everything above
    // HIGH_BIT, then shift down so that LOW_BIT lands at bit 0.
    (x << (63 - HIGH_BIT)) >> (63 - HIGH_BIT + LOW_BIT)
}
104
105/// Helper function to generate a page table entry that points to another table
106#[allow(clippy::identity_op)]
107#[allow(clippy::precedence)]
108fn pte_for_table<Op: TableOps>(table_addr: Op::TableAddr) -> u64 {
109    Op::to_phys(table_addr) |
110        PAGE_ACCESSED_SET | // prevent the CPU writing to the access flag
111        PAGE_CACHE_ENABLED | // leave caching enabled
112        PAGE_WRITE_BACK | // use write-back caching
113        PAGE_USER_ACCESS_DISABLED |// dont allow user access (no code runs in user mode for now)
114        PAGE_RW | // R/W - we don't use block-level permissions
115        PAGE_PRESENT // P   - this entry is present
116}
117
/// This trait is used to select appropriate implementations of
/// [`UpdateParent`] to be used, depending on whether a particular
/// implementation needs the ability to move tables.
pub trait TableMovability<Op: TableReadOps + ?Sized, TableMoveInfo> {
    /// The [`UpdateParent`] implementation to use at the root of the
    /// table hierarchy.
    type RootUpdateParent: UpdateParent<Op, TableMoveInfo = TableMoveInfo>;
    /// Construct the root-level parent updater.
    fn root_update_parent() -> Self::RootUpdateParent;
}
// When tables may move, the root updater must be able to write a new
// root pointer back (see [`UpdateParentRoot`]).
impl<Op: TableOps<TableMovability = crate::vmem::MayMoveTable>> TableMovability<Op, Op::TableAddr>
    for crate::vmem::MayMoveTable
{
    type RootUpdateParent = UpdateParentRoot;
    fn root_update_parent() -> Self::RootUpdateParent {
        UpdateParentRoot {}
    }
}
// When tables can never move, no parent update can ever be required,
// so the updater whose move-info type is uninhabited is used.
impl<Op: TableReadOps> TableMovability<Op, Void> for crate::vmem::MayNotMoveTable {
    type RootUpdateParent = UpdateParentNone;
    fn root_update_parent() -> Self::RootUpdateParent {
        UpdateParentNone {}
    }
}
139
140/// Helper function to write a page table entry, updating the whole
141/// chain of tables back to the root if necessary
142unsafe fn write_entry_updating<
143    Op: TableOps,
144    P: UpdateParent<
145            Op,
146            TableMoveInfo = <Op::TableMovability as TableMovabilityBase<Op>>::TableMoveInfo,
147        >,
148>(
149    op: &Op,
150    parent: P,
151    addr: Op::TableAddr,
152    entry: u64,
153) {
154    if let Some(again) = unsafe { op.write_entry(addr, entry) } {
155        parent.update_parent(op, again);
156    }
157}
158
/// A helper trait that allows us to move a page table (e.g. from the
/// snapshot to the scratch region), keeping track of the context that
/// needs to be updated when that is moved (and potentially
/// recursively updating, if necessary)
///
/// This is done via a trait so that the selected impl knows the exact
/// nesting depth of tables, in order to assist
/// inlining/specialisation in generating efficient code.
///
/// The trait definition only bounds its parameter by
/// [`TableReadOps`], since [`UpdateParentNone`] does not need to be
/// able to actually write to the tables.
pub trait UpdateParent<Op: TableReadOps + ?Sized>: Copy {
    /// The type of the information about a moved table which is
    /// needed in order to update its parent.
    type TableMoveInfo;
    /// The [`UpdateParent`] type that should be used when going down
    /// another level in the table, in order to add the current level
    /// to the chain of ancestors to be updated.
    type ChildType: UpdateParent<Op, TableMoveInfo = Self::TableMoveInfo>;
    /// Record that the table tracked by this updater has moved, rewriting
    /// the parent's entry (and, recursively, any further ancestors if
    /// those writes themselves cause moves).
    fn update_parent(self, op: &Op, new_ptr: Self::TableMoveInfo);
    /// Produce the updater to use for the child table referenced by the
    /// entry at `entry_ptr` within the current table.
    fn for_child_at_entry(self, entry_ptr: Op::TableAddr) -> Self::ChildType;
}
182
/// A struct implementing [`UpdateParent`] that keeps track of the
/// fact that the parent table is itself another table, whose own
/// ancestors may need to be recursively updated
pub struct UpdateParentTable<Op: TableOps, P: UpdateParent<Op>> {
    // Updater for the parent table's own parent (the ancestor chain)
    parent: P,
    // Address of the entry in the parent table that points at this table
    entry_ptr: Op::TableAddr,
}
// Manual Clone/Copy impls: a derive would add an `Op: Clone`/`Op: Copy`
// bound, but only Op::TableAddr is stored (presumably Copy — the
// `*self` body below requires it; confirm against TableReadOps).
impl<Op: TableOps, P: UpdateParent<Op>> Clone for UpdateParentTable<Op, P> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<Op: TableOps, P: UpdateParent<Op>> Copy for UpdateParentTable<Op, P> {}
impl<Op: TableOps, P: UpdateParent<Op>> UpdateParentTable<Op, P> {
    /// Create an updater for the table pointed to by `entry_ptr`, within
    /// the table whose own updater is `parent`.
    fn new(parent: P, entry_ptr: Op::TableAddr) -> Self {
        UpdateParentTable { parent, entry_ptr }
    }
}
impl<
    Op: TableOps<TableMovability = crate::vmem::MayMoveTable>,
    P: UpdateParent<Op, TableMoveInfo = Op::TableAddr>,
> UpdateParent<Op> for UpdateParentTable<Op, P>
{
    type TableMoveInfo = Op::TableAddr;
    type ChildType = UpdateParentTable<Op, Self>;
    // A child table moved: rewrite our entry to point at its new
    // location. That write may itself move this table, in which case
    // write_entry_updating propagates the change to our own ancestors.
    fn update_parent(self, op: &Op, new_ptr: Op::TableAddr) {
        let pte = pte_for_table::<Op>(new_ptr);
        unsafe {
            write_entry_updating(op, self.parent, self.entry_ptr, pte);
        }
    }
    fn for_child_at_entry(self, entry_ptr: Op::TableAddr) -> Self::ChildType {
        Self::ChildType::new(self, entry_ptr)
    }
}
218
/// A struct implementing [`UpdateParent`] that keeps track of the
/// fact that the parent "table" is actually the root (e.g. the value
/// of CR3 in the guest)
#[derive(Copy, Clone)]
pub struct UpdateParentRoot {}
impl<Op: TableOps<TableMovability = crate::vmem::MayMoveTable>> UpdateParent<Op>
    for UpdateParentRoot
{
    type TableMoveInfo = Op::TableAddr;
    type ChildType = UpdateParentTable<Op, Self>;
    // The top-level table moved: ask the backend to update the root
    // pointer itself (e.g. the guest's CR3) — there is no parent entry.
    fn update_parent(self, op: &Op, new_ptr: Op::TableAddr) {
        unsafe {
            op.update_root(new_ptr);
        }
    }
    fn for_child_at_entry(self, entry_ptr: Op::TableAddr) -> Self::ChildType {
        Self::ChildType::new(self, entry_ptr)
    }
}
238
/// A struct implementing [`UpdateParent`] that is impossible to use
/// (since its [`update_parent`] method takes `Void`), used when it is
/// statically known that a table operation cannot result in a need to
/// update ancestors.
#[derive(Copy, Clone)]
pub struct UpdateParentNone {}
impl<Op: TableReadOps> UpdateParent<Op> for UpdateParentNone {
    type TableMoveInfo = Void;
    type ChildType = Self;
    // Void is uninhabited, so this method can never actually be called;
    // the empty match proves that to the compiler.
    fn update_parent(self, _op: &Op, impossible: Void) {
        match impossible {}
    }
    // Descending a level never adds anything to the (empty) ancestor chain.
    fn for_child_at_entry(self, _entry_ptr: Op::TableAddr) -> Self {
        self
    }
}
255
/// A helper structure indicating a mapping operation that needs to be
/// performed
struct MapRequest<Op: TableReadOps, P: UpdateParent<Op>> {
    // Address of the page table to operate on at this level
    table_base: Op::TableAddr,
    // Lowest virtual address covered by the request
    vmin: VirtAddr,
    // Length in bytes of the virtual range starting at vmin
    len: u64,
    // Updater used to fix up ancestor tables if this table moves
    update_parent: P,
}

/// A helper structure indicating that a particular PTE needs to be
/// modified
struct MapResponse<Op: TableReadOps, P: UpdateParent<Op>> {
    // Address of the specific page table entry to modify
    entry_ptr: Op::TableAddr,
    // Lowest virtual address covered by this entry
    vmin: VirtAddr,
    // Length in bytes covered from vmin by this entry
    len: u64,
    // Updater used to fix up ancestor tables if the containing table moves
    update_parent: P,
}
273
/// Iterator that walks through page table entries at a specific level.
///
/// Given a virtual address range and a table base, this iterator yields
/// `MapResponse` items for each page table entry that needs to be modified.
/// The const generics `HIGH_BIT` and `LOW_BIT` specify which bits of the
/// virtual address are used to index into this level's table.
///
/// For example:
/// - PML4: HIGH_BIT=47, LOW_BIT=39 (9 bits = 512 entries, each covering 512GB)
/// - PDPT: HIGH_BIT=38, LOW_BIT=30 (9 bits = 512 entries, each covering 1GB)
/// - PD:   HIGH_BIT=29, LOW_BIT=21 (9 bits = 512 entries, each covering 2MB)
/// - PT:   HIGH_BIT=20, LOW_BIT=12 (9 bits = 512 entries, each covering 4KB)
struct ModifyPteIterator<
    const HIGH_BIT: u8,
    const LOW_BIT: u8,
    Op: TableReadOps,
    P: UpdateParent<Op>,
> {
    // The range/table this iteration is walking
    request: MapRequest<Op, P>,
    // Number of entries yielded so far (0 on the first call to next())
    n: u64,
}
impl<const HIGH_BIT: u8, const LOW_BIT: u8, Op: TableReadOps, P: UpdateParent<Op>> Iterator
    for ModifyPteIterator<HIGH_BIT, LOW_BIT, Op, P>
{
    type Item = MapResponse<Op, P>;
    fn next(&mut self) -> Option<Self::Item> {
        // Each page table entry at this level covers a region of size (1 << LOW_BIT) bytes.
        // For example, at the PT level (LOW_BIT=12), each entry covers 4KB (0x1000 bytes).
        // At the PD level (LOW_BIT=21), each entry covers 2MB (0x200000 bytes).
        //
        // This mask isolates the bits below this level's index bits, used for alignment.
        let lower_bits_mask = (1 << LOW_BIT) - 1;

        // Calculate the virtual address for this iteration.
        // On the first iteration (n=0), start at the requested vmin.
        // On subsequent iterations, advance to the next aligned boundary.
        // This handles the case where vmin isn't aligned to this level's entry size.
        let next_vmin = if self.n == 0 {
            self.request.vmin
        } else {
            // Align to the next boundary by adding one entry's worth
            // and masking off lower bits. Masking off before adding
            // is safe, since n << LOW_BIT must always have zeros in
            // these positions.
            let aligned_min = self.request.vmin & !lower_bits_mask;
            // Use checked_add here because going past the end of the
            // address space counts as "the next one would be out of
            // range"
            aligned_min.checked_add(self.n << LOW_BIT)?
        };

        // Check if we've processed the entire requested range
        if next_vmin >= self.request.vmin + self.request.len {
            return None;
        }

        // Calculate the pointer to this level's page table entry.
        // bits::<HIGH_BIT, LOW_BIT> extracts the relevant index bits from the virtual address.
        // Shift left by 3 (multiply by 8) because each entry is 8 bytes (u64).
        let entry_ptr = Op::entry_addr(
            self.request.table_base,
            bits::<HIGH_BIT, LOW_BIT>(next_vmin) << 3,
        );

        // Calculate how many bytes remain to be mapped from this point
        let len_from_here = self.request.len - (next_vmin - self.request.vmin);

        // Calculate the maximum bytes this single entry can cover.
        // If next_vmin is aligned, this is the full entry size (1 << LOW_BIT).
        // If not aligned (only possible on first iteration), it's the remaining
        // space until the next boundary.
        let max_len = (1 << LOW_BIT) - (next_vmin & lower_bits_mask);

        // The actual length for this entry is the smaller of what's needed vs what fits
        let next_len = core::cmp::min(len_from_here, max_len);

        // Advance iteration counter for next call
        self.n += 1;

        Some(MapResponse {
            entry_ptr,
            vmin: next_vmin,
            len: next_len,
            update_parent: self.request.update_parent,
        })
    }
}
361fn modify_ptes<const HIGH_BIT: u8, const LOW_BIT: u8, Op: TableReadOps, P: UpdateParent<Op>>(
362    r: MapRequest<Op, P>,
363) -> ModifyPteIterator<HIGH_BIT, LOW_BIT, Op, P> {
364    ModifyPteIterator { request: r, n: 0 }
365}
366
367/// Page-mapping callback to allocate a next-level page table if necessary.
368/// # Safety
369/// This function modifies page table data structures, and should not be called concurrently
370/// with any other operations that modify the page tables.
371unsafe fn alloc_pte_if_needed<
372    Op: TableOps,
373    P: UpdateParent<
374            Op,
375            TableMoveInfo = <Op::TableMovability as TableMovabilityBase<Op>>::TableMoveInfo,
376        >,
377>(
378    op: &Op,
379    x: MapResponse<Op, P>,
380) -> MapRequest<Op, P::ChildType>
381where
382    P::ChildType: UpdateParent<Op>,
383{
384    let new_update_parent = x.update_parent.for_child_at_entry(x.entry_ptr);
385    if let Some(pte) = unsafe { read_pte_if_present(op, x.entry_ptr) } {
386        return MapRequest {
387            table_base: Op::from_phys(pte & PTE_ADDR_MASK),
388            vmin: x.vmin,
389            len: x.len,
390            update_parent: new_update_parent,
391        };
392    }
393
394    let page_addr = unsafe { op.alloc_table() };
395
396    let pte = pte_for_table::<Op>(page_addr);
397    unsafe {
398        write_entry_updating(op, x.update_parent, x.entry_ptr, pte);
399    };
400    MapRequest {
401        table_base: page_addr,
402        vmin: x.vmin,
403        len: x.len,
404        update_parent: new_update_parent,
405    }
406}
407
408/// Map a normal memory page
409/// # Safety
410/// This function modifies page table data structures, and should not be called concurrently
411/// with any other operations that modify the page tables.
412#[allow(clippy::identity_op)]
413#[allow(clippy::precedence)]
414unsafe fn map_page<
415    Op: TableOps,
416    P: UpdateParent<
417            Op,
418            TableMoveInfo = <Op::TableMovability as TableMovabilityBase<Op>>::TableMoveInfo,
419        >,
420>(
421    op: &Op,
422    mapping: &Mapping,
423    r: MapResponse<Op, P>,
424) {
425    let pte = match &mapping.kind {
426        MappingKind::Basic(bm) =>
427        // TODO: Support not readable
428        // NOTE: On x86-64, there is no separate "readable" bit in the page table entry.
429        // This means that pages cannot be made write-only or execute-only without also being readable.
430        // All pages that are mapped as writable or executable are also implicitly readable.
431        // If support for "not readable" mappings is required in the future, it would need to be
432        // implemented using additional mechanisms (e.g., page-fault handling or memory protection keys),
433        // but for now, this architectural limitation is accepted.
434        {
435            (mapping.phys_base + (r.vmin - mapping.virt_base)) |
436                page_nx_flag(bm.executable) | // NX - no execute unless allowed
437                PAGE_PAT_WB | // PAT index bit for write-back memory
438                PAGE_DIRTY_SET | // prevent the CPU writing to the dirty bit
439                PAGE_ACCESSED_SET | // prevent the CPU writing to the access flag
440                PAGE_CACHE_ENABLED | // leave caching enabled
441                PAGE_WRITE_BACK | // use write-back caching
442                PAGE_USER_ACCESS_DISABLED | // dont allow user access (no code runs in user mode for now)
443                page_rw_flag(bm.writable) | // R/W - set if writable
444                PAGE_PRESENT // P   - this entry is present
445        }
446        MappingKind::Cow(cm) => {
447            (mapping.phys_base + (r.vmin - mapping.virt_base)) |
448                page_nx_flag(cm.executable) | // NX - no execute unless allowed
449                PAGE_AVL_COW |
450                PAGE_PAT_WB | // PAT index bit for write-back memory
451                PAGE_DIRTY_SET | // prevent the CPU writing to the dirty bit
452                PAGE_ACCESSED_SET | // prevent the CPU writing to the access flag
453                PAGE_CACHE_ENABLED | // leave caching enabled
454                PAGE_WRITE_BACK | // use write-back caching
455                PAGE_USER_ACCESS_DISABLED | // dont allow user access (no code runs in user mode for now)
456                0 | // R/W - Cow page is never writable
457                PAGE_PRESENT // P   - this entry is present
458        }
459    };
460    unsafe {
461        write_entry_updating(op, r.update_parent, r.entry_ptr, pte);
462    }
463}
464
// There are no notable architecture-specific safety considerations
// here, and the general conditions are documented in the
// architecture-independent re-export in vmem.rs
468
/// Maps a contiguous virtual address range to physical memory.
///
/// This function walks the 4-level page table hierarchy (PML4 → PDPT → PD → PT),
/// allocating intermediate tables as needed via `alloc_pte_if_needed`, and finally
/// writing the leaf page table entries with the requested permissions via `map_page`.
///
/// The iterator chain processes each level:
/// 1. PML4 (47:39) - allocate PDPT if needed
/// 2. PDPT (38:30) - allocate PD if needed
/// 3. PD (29:21) - allocate PT if needed
/// 4. PT (20:12) - write final PTE with physical address and flags
#[allow(clippy::missing_safety_doc)]
pub unsafe fn map<Op: TableOps>(op: &Op, mapping: Mapping) {
    modify_ptes::<47, 39, Op, _>(MapRequest {
        table_base: op.root_table(),
        vmin: mapping.virt_base,
        len: mapping.len,
        // The root updater depends on whether this Op allows tables to move
        update_parent: Op::TableMovability::root_update_parent(),
    })
    // Iterators are lazy: nothing happens until for_each(drop) at the end
    // drives the whole chain, one leaf PTE at a time.
    .map(|r| unsafe { alloc_pte_if_needed(op, r) })
    .flat_map(modify_ptes::<38, 30, Op, _>)
    .map(|r| unsafe { alloc_pte_if_needed(op, r) })
    .flat_map(modify_ptes::<29, 21, Op, _>)
    .map(|r| unsafe { alloc_pte_if_needed(op, r) })
    .flat_map(modify_ptes::<20, 12, Op, _>)
    .map(|r| unsafe { map_page(op, &mapping, r) })
    .for_each(drop);
}
497
498/// # Safety
499/// This function traverses page table data structures, and should not
500/// be called concurrently with any other operations that modify the
501/// page table.
502unsafe fn require_pte_exist<Op: TableReadOps, P: UpdateParent<Op>>(
503    op: &Op,
504    x: MapResponse<Op, P>,
505) -> Option<MapRequest<Op, P::ChildType>>
506where
507    P::ChildType: UpdateParent<Op>,
508{
509    unsafe { read_pte_if_present(op, x.entry_ptr) }.map(|pte| MapRequest {
510        table_base: Op::from_phys(pte & PTE_ADDR_MASK),
511        vmin: x.vmin,
512        len: x.len,
513        update_parent: x.update_parent.for_child_at_entry(x.entry_ptr),
514    })
515}
516
// There are no notable architecture-specific safety considerations
// here, and the general conditions are documented in the
// architecture-independent re-export in vmem.rs
520
/// Translates a virtual address range to the physical address pages
/// that back it by walking the page tables.
///
/// Returns an iterator with an entry for each mapped page that
/// intersects the given range.
///
/// This takes AsRef<Op> + Copy so that on targets where the
/// operations have little state (e.g. the guest) the operations state
/// can be copied into the closure(s) in the iterator, allowing for a
/// nicer result lifetime.  On targets like the
/// building-an-original-snapshot portion of the host, where the
/// operations structure owns a large buffer, a reference can instead
/// be passed.
#[allow(clippy::missing_safety_doc)]
pub unsafe fn virt_to_phys<'a, Op: TableReadOps + 'a>(
    op: impl core::convert::AsRef<Op> + Copy + 'a,
    address: u64,
    len: u64,
) -> impl Iterator<Item = Mapping> + 'a {
    // Undo sign-extension: keep only the low VA_BITS bits of the address
    let addr = address & ((1u64 << VA_BITS) - 1);
    // Mask off any sub-page bits
    let vmin = addr & !(PAGE_SIZE as u64 - 1);
    // Calculate the maximum virtual address we need to look at based on the starting
    // address and length ensuring we don't go past the end of the address space
    let vmax = core::cmp::min(addr + len, 1u64 << VA_BITS);
    modify_ptes::<47, 39, Op, _>(MapRequest {
        table_base: op.as_ref().root_table(),
        vmin,
        len: vmax - vmin,
        // Read-only walk: tables never move, so no parent updates occur
        update_parent: UpdateParentNone {},
    })
    .filter_map(move |r| unsafe { require_pte_exist(op.as_ref(), r) })
    .flat_map(modify_ptes::<38, 30, Op, _>)
    .filter_map(move |r| unsafe { require_pte_exist(op.as_ref(), r) })
    .flat_map(modify_ptes::<29, 21, Op, _>)
    .filter_map(move |r| unsafe { require_pte_exist(op.as_ref(), r) })
    .flat_map(modify_ptes::<20, 12, Op, _>)
    .filter_map(move |r| {
        let pte = unsafe { read_pte_if_present(op.as_ref(), r.entry_ptr) }?;
        let phys_addr = pte & PTE_ADDR_MASK;
        // Re-do the sign extension: replicate bit VA_BITS-1 into bits 63:VA_BITS
        let sgn_bit = r.vmin >> (VA_BITS - 1);
        let sgn_bits = 0u64.wrapping_sub(sgn_bit) << VA_BITS;
        let virt_addr = sgn_bits | r.vmin;

        let executable = (pte & PAGE_NX) == 0;
        // The AVL bit pattern distinguishes special (e.g. CoW) mappings
        let avl = pte & PTE_AVL_MASK;
        let kind = if avl == PAGE_AVL_COW {
            MappingKind::Cow(CowMapping {
                readable: true,
                executable,
            })
        } else {
            MappingKind::Basic(BasicMapping {
                readable: true,
                writable: (pte & PAGE_RW) != 0,
                executable,
            })
        };
        Some(Mapping {
            phys_base: phys_addr,
            virt_base: virt_addr,
            len: PAGE_SIZE as u64,
            kind,
        })
    })
}
589
// We use 48-bit (4-level) virtual addresses at the moment; bits above
// this are sign-extension of bit 47.
const VA_BITS: usize = 48;

/// Size in bytes of a 4KB page
pub const PAGE_SIZE: usize = 4096;
/// Size in bytes of a single page table (one page)
pub const PAGE_TABLE_SIZE: usize = 4096;
/// A raw 64-bit page table entry
pub type PageTableEntry = u64;
/// A guest virtual address
pub type VirtAddr = u64;
/// A guest physical address
pub type PhysAddr = u64;
597
598#[cfg(test)]
599mod tests {
600    use alloc::vec;
601    use alloc::vec::Vec;
602    use core::cell::RefCell;
603
604    use super::*;
605    use crate::vmem::{
606        BasicMapping, Mapping, MappingKind, MayNotMoveTable, PAGE_TABLE_ENTRIES_PER_TABLE,
607        TableOps, TableReadOps, Void,
608    };
609
    /// A mock TableOps implementation for testing that stores page tables in memory
    /// needed because the `GuestPageTableBuffer` is in hyperlight_host which would cause a circular dependency
    struct MockTableOps {
        // All allocated tables; index 0 is always the root (PML4)
        tables: RefCell<Vec<[u64; PAGE_TABLE_ENTRIES_PER_TABLE]>>,
    }

    // for virt_to_phys (which takes impl AsRef<Op> + Copy)
    impl core::convert::AsRef<MockTableOps> for MockTableOps {
        fn as_ref(&self) -> &Self {
            self
        }
    }

    impl MockTableOps {
        /// Create a mock with just the root (PML4) table allocated.
        fn new() -> Self {
            // Start with one table (the root/PML4)
            Self {
                tables: RefCell::new(vec![[0u64; PAGE_TABLE_ENTRIES_PER_TABLE]]),
            }
        }

        /// Number of tables allocated so far, including the root.
        fn table_count(&self) -> usize {
            self.tables.borrow().len()
        }

        /// Read entry `entry_idx` of table `table_idx`.
        fn get_entry(&self, table_idx: usize, entry_idx: usize) -> u64 {
            self.tables.borrow()[table_idx][entry_idx]
        }
    }
639
    impl TableReadOps for MockTableOps {
        // Addresses are (table_index, entry_index) pairs into `tables`
        type TableAddr = (usize, usize);

        fn entry_addr(addr: Self::TableAddr, entry_offset: u64) -> Self::TableAddr {
            // Convert to physical address, add offset, convert back
            let phys = Self::to_phys(addr) + entry_offset;
            Self::from_phys(phys)
        }

        unsafe fn read_entry(&self, addr: Self::TableAddr) -> u64 {
            self.tables.borrow()[addr.0][addr.1]
        }

        fn to_phys(addr: Self::TableAddr) -> PhysAddr {
            // Each table is 4KB, entries are 8 bytes
            (addr.0 as u64 * PAGE_TABLE_SIZE as u64) + (addr.1 as u64 * 8)
        }

        fn from_phys(addr: PhysAddr) -> Self::TableAddr {
            // Inverse of to_phys: split into table index and entry index
            let table_idx = (addr / PAGE_TABLE_SIZE as u64) as usize;
            let entry_idx = ((addr % PAGE_TABLE_SIZE as u64) / 8) as usize;
            (table_idx, entry_idx)
        }

        fn root_table(&self) -> Self::TableAddr {
            // Table 0 is always the root (PML4)
            (0, 0)
        }
    }
668
    impl TableOps for MockTableOps {
        // Tables never move in this mock, so writes can never require
        // a parent update and update_root is unreachable
        type TableMovability = MayNotMoveTable;

        unsafe fn alloc_table(&self) -> Self::TableAddr {
            // Append a fresh zeroed table; its Vec index becomes its address
            let mut tables = self.tables.borrow_mut();
            let idx = tables.len();
            tables.push([0u64; PAGE_TABLE_ENTRIES_PER_TABLE]);
            (idx, 0)
        }

        unsafe fn write_entry(&self, addr: Self::TableAddr, entry: u64) -> Option<Void> {
            self.tables.borrow_mut()[addr.0][addr.1] = entry;
            // None: the write never causes the table to move
            None
        }

        unsafe fn update_root(&self, impossible: Void) {
            // Void is uninhabited; MayNotMoveTable means this cannot happen
            match impossible {}
        }
    }
688
689    // ==================== bits() function tests ====================
690
691    #[test]
692    fn test_bits_extracts_pml4_index() {
693        // PML4 uses bits 47:39
694        // Address 0x0000_0080_0000_0000 should have PML4 index 1
695        let addr: u64 = 0x0000_0080_0000_0000;
696        assert_eq!(bits::<47, 39>(addr), 1);
697    }
698
699    #[test]
700    fn test_bits_extracts_pdpt_index() {
701        // PDPT uses bits 38:30
702        // Address with PDPT index 1: bit 30 set = 0x4000_0000 (1GB)
703        let addr: u64 = 0x4000_0000;
704        assert_eq!(bits::<38, 30>(addr), 1);
705    }
706
707    #[test]
708    fn test_bits_extracts_pd_index() {
709        // PD uses bits 29:21
710        // Address 0x0000_0000_0020_0000 (2MB) should have PD index 1
711        let addr: u64 = 0x0000_0000_0020_0000;
712        assert_eq!(bits::<29, 21>(addr), 1);
713    }
714
715    #[test]
716    fn test_bits_extracts_pt_index() {
717        // PT uses bits 20:12
718        // Address 0x0000_0000_0000_1000 (4KB) should have PT index 1
719        let addr: u64 = 0x0000_0000_0000_1000;
720        assert_eq!(bits::<20, 12>(addr), 1);
721    }
722
723    #[test]
724    fn test_bits_max_index() {
725        // Maximum 9-bit index is 511
726        // PML4 index 511 = bits 47:39 all set = 0x0000_FF80_0000_0000
727        let addr: u64 = 0x0000_FF80_0000_0000;
728        assert_eq!(bits::<47, 39>(addr), 511);
729    }
730
731    // ==================== PTE flag tests ====================
732
733    #[test]
734    fn test_page_rw_flag_writable() {
735        assert_eq!(page_rw_flag(true), PAGE_RW);
736    }
737
738    #[test]
739    fn test_page_rw_flag_readonly() {
740        assert_eq!(page_rw_flag(false), 0);
741    }
742
743    #[test]
744    fn test_page_nx_flag_executable() {
745        assert_eq!(page_nx_flag(true), 0); // Executable = no NX bit
746    }
747
748    #[test]
749    fn test_page_nx_flag_not_executable() {
750        assert_eq!(page_nx_flag(false), PAGE_NX);
751    }
752
    // ==================== map() function tests ====================

    #[test]
    fn test_map_single_page() {
        let ops = MockTableOps::new();
        // One writable, non-executable 4KB page, identity-mapped at 0x1000
        let mapping = Mapping {
            phys_base: 0x1000,
            virt_base: 0x1000,
            len: PAGE_SIZE as u64,
            kind: MappingKind::Basic(BasicMapping {
                readable: true,
                writable: true,
                executable: false,
            }),
        };

        unsafe { map(&ops, mapping) };

        // Should have allocated: PML4(exists) + PDPT + PD + PT = 4 tables
        assert_eq!(ops.table_count(), 4);

        // Check PML4 entry 0 points to PDPT (table 1) with correct flags
        let pml4_entry = ops.get_entry(0, 0);
        assert_ne!(pml4_entry & PAGE_PRESENT, 0, "PML4 entry should be present");
        assert_ne!(pml4_entry & PAGE_RW, 0, "PML4 entry should be writable");

        // Check the leaf PTE has correct flags
        // PT is table 3, entry 1 (for virt_base 0x1000)
        let pte = ops.get_entry(3, 1);
        assert_ne!(pte & PAGE_PRESENT, 0, "PTE should be present");
        assert_ne!(pte & PAGE_RW, 0, "PTE should be writable");
        assert_ne!(pte & PAGE_NX, 0, "PTE should have NX set (not executable)");
        assert_eq!(pte & PTE_ADDR_MASK, 0x1000, "PTE should map to phys 0x1000");
    }
787
788    #[test]
789    fn test_map_executable_page() {
790        let ops = MockTableOps::new();
791        let mapping = Mapping {
792            phys_base: 0x2000,
793            virt_base: 0x2000,
794            len: PAGE_SIZE as u64,
795            kind: MappingKind::Basic(BasicMapping {
796                readable: true,
797                writable: false,
798                executable: true,
799            }),
800        };
801
802        unsafe { map(&ops, mapping) };
803
804        // PT is table 3, entry 2 (for virt_base 0x2000)
805        let pte = ops.get_entry(3, 2);
806        assert_ne!(pte & PAGE_PRESENT, 0, "PTE should be present");
807        assert_eq!(pte & PAGE_RW, 0, "PTE should be read-only");
808        assert_eq!(pte & PAGE_NX, 0, "PTE should NOT have NX set (executable)");
809    }
810
811    #[test]
812    fn test_map_multiple_pages() {
813        let ops = MockTableOps::new();
814        let mapping = Mapping {
815            phys_base: 0x10000,
816            virt_base: 0x10000,
817            len: 4 * PAGE_SIZE as u64, // 4 pages = 16KB
818            kind: MappingKind::Basic(BasicMapping {
819                readable: true,
820                writable: true,
821                executable: false,
822            }),
823        };
824
825        unsafe { map(&ops, mapping) };
826
827        // Check all 4 PTEs are present
828        for i in 0..4 {
829            let entry_idx = 16 + i; // 0x10000 / 0x1000 = 16
830            let pte = ops.get_entry(3, entry_idx);
831            assert_ne!(pte & PAGE_PRESENT, 0, "PTE {} should be present", i);
832            let expected_phys = 0x10000 + (i as u64 * PAGE_SIZE as u64);
833            assert_eq!(
834                pte & PTE_ADDR_MASK,
835                expected_phys,
836                "PTE {} should map to correct phys addr",
837                i
838            );
839        }
840    }
841
842    #[test]
843    fn test_map_reuses_existing_tables() {
844        let ops = MockTableOps::new();
845
846        // Map first region
847        let mapping1 = Mapping {
848            phys_base: 0x1000,
849            virt_base: 0x1000,
850            len: PAGE_SIZE as u64,
851            kind: MappingKind::Basic(BasicMapping {
852                readable: true,
853                writable: true,
854                executable: false,
855            }),
856        };
857        unsafe { map(&ops, mapping1) };
858        let tables_after_first = ops.table_count();
859
860        // Map second region in same PT (different page)
861        let mapping2 = Mapping {
862            phys_base: 0x5000,
863            virt_base: 0x5000,
864            len: PAGE_SIZE as u64,
865            kind: MappingKind::Basic(BasicMapping {
866                readable: true,
867                writable: true,
868                executable: false,
869            }),
870        };
871        unsafe { map(&ops, mapping2) };
872
873        // Should NOT allocate new tables (reuses existing hierarchy)
874        assert_eq!(
875            ops.table_count(),
876            tables_after_first,
877            "Should reuse existing page tables"
878        );
879    }
880
881    // ==================== virt_to_phys() tests ====================
882
883    #[test]
884    fn test_virt_to_phys_mapped_address() {
885        let ops = MockTableOps::new();
886        let mapping = Mapping {
887            phys_base: 0x1000,
888            virt_base: 0x1000,
889            len: PAGE_SIZE as u64,
890            kind: MappingKind::Basic(BasicMapping {
891                readable: true,
892                writable: true,
893                executable: false,
894            }),
895        };
896
897        unsafe { map(&ops, mapping) };
898
899        let result = unsafe { virt_to_phys(&ops, 0x1000, 1).next() };
900        assert!(result.is_some(), "Should find mapped address");
901        let mapping = result.unwrap();
902        assert_eq!(mapping.phys_base, 0x1000);
903    }
904
905    #[test]
906    fn test_virt_to_phys_unaligned_virt() {
907        let ops = MockTableOps::new();
908        let mapping = Mapping {
909            phys_base: 0x1000,
910            virt_base: 0x1000,
911            len: PAGE_SIZE as u64,
912            kind: MappingKind::Basic(BasicMapping {
913                readable: true,
914                writable: true,
915                executable: false,
916            }),
917        };
918
919        unsafe { map(&ops, mapping) };
920
921        let result = unsafe { virt_to_phys(&ops, 0x1234, 1).next() };
922        assert!(result.is_some(), "Should find mapped address");
923        let mapping = result.unwrap();
924        assert_eq!(mapping.phys_base, 0x1000);
925    }
926
927    #[test]
928    fn test_virt_to_phys_unaligned_virt_and_across_pages_len() {
929        let ops = MockTableOps::new();
930        let mapping = Mapping {
931            phys_base: 0x1000,
932            virt_base: 0x1000,
933            len: 2 * PAGE_SIZE as u64, // 2 page
934            kind: MappingKind::Basic(BasicMapping {
935                readable: true,
936                writable: true,
937                executable: false,
938            }),
939        };
940
941        unsafe { map(&ops, mapping) };
942
943        let mappings = unsafe { virt_to_phys(&ops, 0x1F00, 0x300).collect::<Vec<_>>() };
944        assert_eq!(mappings.len(), 2, "Should return 2 mappings for 2 pages");
945        assert_eq!(mappings[0].phys_base, 0x1000);
946        assert_eq!(mappings[1].phys_base, 0x2000);
947    }
948
949    #[test]
950    fn test_virt_to_phys_unaligned_virt_and_multiple_page_len() {
951        let ops = MockTableOps::new();
952        let mapping = Mapping {
953            phys_base: 0x1000,
954            virt_base: 0x1000,
955            len: PAGE_SIZE as u64 * 2 + 0x200, // 2 page + 512 bytes
956            kind: MappingKind::Basic(BasicMapping {
957                readable: true,
958                writable: true,
959                executable: false,
960            }),
961        };
962
963        unsafe { map(&ops, mapping) };
964
965        let mappings =
966            unsafe { virt_to_phys(&ops, 0x1234, PAGE_SIZE as u64 * 2 + 0x10).collect::<Vec<_>>() };
967        assert_eq!(mappings.len(), 3, "Should return 3 mappings for 3 pages");
968        assert_eq!(mappings[0].phys_base, 0x1000);
969        assert_eq!(mappings[1].phys_base, 0x2000);
970        assert_eq!(mappings[2].phys_base, 0x3000);
971    }
972
973    #[test]
974    fn test_virt_to_phys_perms() {
975        let test = |kind| {
976            let ops = MockTableOps::new();
977            let mapping = Mapping {
978                phys_base: 0x1000,
979                virt_base: 0x1000,
980                len: PAGE_SIZE as u64,
981                kind,
982            };
983            unsafe { map(&ops, mapping) };
984            let result = unsafe { virt_to_phys(&ops, 0x1000, 1).next() };
985            let mapping = result.unwrap();
986            assert_eq!(mapping.kind, kind);
987        };
988        test(MappingKind::Basic(BasicMapping {
989            readable: true,
990            writable: false,
991            executable: false,
992        }));
993        test(MappingKind::Basic(BasicMapping {
994            readable: true,
995            writable: false,
996            executable: true,
997        }));
998        test(MappingKind::Basic(BasicMapping {
999            readable: true,
1000            writable: true,
1001            executable: false,
1002        }));
1003        test(MappingKind::Basic(BasicMapping {
1004            readable: true,
1005            writable: true,
1006            executable: true,
1007        }));
1008        test(MappingKind::Cow(CowMapping {
1009            readable: true,
1010            executable: false,
1011        }));
1012        test(MappingKind::Cow(CowMapping {
1013            readable: true,
1014            executable: true,
1015        }));
1016    }
1017
1018    #[test]
1019    fn test_virt_to_phys_unmapped_address() {
1020        let ops = MockTableOps::new();
1021        // Don't map anything
1022
1023        let result = unsafe { virt_to_phys(&ops, 0x1000, 1).next() };
1024        assert!(result.is_none(), "Should return None for unmapped address");
1025    }
1026
1027    #[test]
1028    fn test_virt_to_phys_partially_mapped() {
1029        let ops = MockTableOps::new();
1030        let mapping = Mapping {
1031            phys_base: 0x1000,
1032            virt_base: 0x1000,
1033            len: PAGE_SIZE as u64,
1034            kind: MappingKind::Basic(BasicMapping {
1035                readable: true,
1036                writable: true,
1037                executable: false,
1038            }),
1039        };
1040
1041        unsafe { map(&ops, mapping) };
1042
1043        // Query an address in a different PT entry (unmapped)
1044        let result = unsafe { virt_to_phys(&ops, 0x5000, 1).next() };
1045        assert!(
1046            result.is_none(),
1047            "Should return None for unmapped address in same PT"
1048        );
1049    }
1050
1051    // ==================== ModifyPteIterator tests ====================
1052
1053    #[test]
1054    fn test_modify_pte_iterator_single_page() {
1055        let ops = MockTableOps::new();
1056        let request = MapRequest {
1057            table_base: ops.root_table(),
1058            vmin: 0x1000,
1059            len: PAGE_SIZE as u64,
1060            update_parent: UpdateParentNone {},
1061        };
1062
1063        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1064        assert_eq!(responses.len(), 1, "Single page should yield one response");
1065        assert_eq!(responses[0].vmin, 0x1000);
1066        assert_eq!(responses[0].len, PAGE_SIZE as u64);
1067    }
1068
1069    #[test]
1070    fn test_modify_pte_iterator_multiple_pages() {
1071        let ops = MockTableOps::new();
1072        let request = MapRequest {
1073            table_base: ops.root_table(),
1074            vmin: 0x1000,
1075            len: 3 * PAGE_SIZE as u64,
1076            update_parent: UpdateParentNone {},
1077        };
1078
1079        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1080        assert_eq!(responses.len(), 3, "3 pages should yield 3 responses");
1081    }
1082
1083    #[test]
1084    fn test_modify_pte_iterator_zero_length() {
1085        let ops = MockTableOps::new();
1086        let request = MapRequest {
1087            table_base: ops.root_table(),
1088            vmin: 0x1000,
1089            len: 0,
1090            update_parent: UpdateParentNone {},
1091        };
1092
1093        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1094        assert_eq!(responses.len(), 0, "Zero length should yield no responses");
1095    }
1096
1097    #[test]
1098    fn test_modify_pte_iterator_unaligned_start() {
1099        let ops = MockTableOps::new();
1100        // Start at 0x1800 (mid-page), map 0x1000 bytes
1101        // Should cover 0x1800-0x1FFF (first page) and 0x2000-0x27FF (second page)
1102        let request = MapRequest {
1103            table_base: ops.root_table(),
1104            vmin: 0x1800,
1105            len: 0x1000,
1106            update_parent: UpdateParentNone {},
1107        };
1108
1109        let responses: Vec<_> = modify_ptes::<20, 12, MockTableOps, _>(request).collect();
1110        assert_eq!(
1111            responses.len(),
1112            2,
1113            "Unaligned mapping spanning 2 pages should yield 2 responses"
1114        );
1115        assert_eq!(responses[0].vmin, 0x1800);
1116        assert_eq!(responses[0].len, 0x800); // Remaining in first page
1117        assert_eq!(responses[1].vmin, 0x2000);
1118        assert_eq!(responses[1].len, 0x800); // Continuing in second page
1119    }
1120
1121    // ==================== TableOps entry_addr tests ====================
1122
1123    #[test]
1124    fn test_entry_addr_from_table_base() {
1125        // entry_addr is called with a table base (entry_index = 0) and a byte offset
1126        // offset = entry_index * 8, so offset 40 means entry 5
1127        let result = MockTableOps::entry_addr((2, 0), 40);
1128        assert_eq!(result, (2, 5), "Should return (table 2, entry 5)");
1129    }
1130
1131    #[test]
1132    fn test_entry_addr_with_nonzero_base_entry() {
1133        // Even though entry_addr is typically called with entry_index=0,
1134        // it should handle non-zero base correctly by adding the offset
1135        // Base: table 1, entry 10 (phys = 1*4096 + 10*8 = 4176)
1136        // Offset: 16 bytes (2 entries)
1137        // Result phys: 4176 + 16 = 4192 = 1*4096 + 12*8 → (1, 12)
1138        let result = MockTableOps::entry_addr((1, 10), 16);
1139        assert_eq!(result, (1, 12), "Should add offset to base entry");
1140    }
1141
1142    #[test]
1143    fn test_to_phys_from_phys_roundtrip() {
1144        // Verify to_phys and from_phys are inverses
1145        let addr = (3, 42);
1146        let phys = MockTableOps::to_phys(addr);
1147        let back = MockTableOps::from_phys(phys);
1148        assert_eq!(back, addr, "to_phys/from_phys should roundtrip");
1149    }
1150}