Skip to main content

ax_memory_set/
set.rs

1use alloc::collections::BTreeMap;
2#[allow(unused_imports)] // this is a weird false alarm
3use alloc::vec::Vec;
4use core::fmt;
5
6use ax_memory_addr::{AddrRange, MemoryAddr};
7
8use crate::{MappingBackend, MappingError, MappingResult, MemoryArea};
9
/// A container that maintains memory mappings ([`MemoryArea`]).
///
/// Areas are stored in an ordered map so that the neighbours of an address
/// can be located with range queries.
pub struct MemorySet<B: MappingBackend> {
    /// All tracked areas, keyed by the address they are inserted under
    /// (the start address of the area).
    areas: BTreeMap<B::Addr, MemoryArea<B>>,
}
14
15impl<B: MappingBackend> MemorySet<B> {
16    /// Creates a new memory set.
17    pub const fn new() -> Self {
18        Self {
19            areas: BTreeMap::new(),
20        }
21    }
22
    /// Returns the number of memory areas in the memory set.
    ///
    /// This counts areas (map entries), not bytes or pages.
    pub fn len(&self) -> usize {
        self.areas.len()
    }
27
    /// Returns `true` if the memory set contains no memory areas.
    ///
    /// Equivalent to `self.len() == 0`.
    pub fn is_empty(&self) -> bool {
        self.areas.is_empty()
    }
32
    /// Returns the iterator over all memory areas.
    ///
    /// Areas are yielded in ascending order of their map key.
    pub fn iter(&self) -> impl Iterator<Item = &MemoryArea<B>> {
        self.areas.values()
    }
37
38    /// Returns whether the given address range overlaps with any existing area.
39    pub fn overlaps(&self, range: AddrRange<B::Addr>) -> bool {
40        if let Some((_, before)) = self.areas.range(..range.start).last()
41            && before.va_range().overlaps(range)
42        {
43            return true;
44        }
45        if let Some((_, after)) = self.areas.range(range.start..).next()
46            && after.va_range().overlaps(range)
47        {
48            return true;
49        }
50        false
51    }
52
53    /// Finds the memory area that contains the given address.
54    pub fn find(&self, addr: B::Addr) -> Option<&MemoryArea<B>> {
55        let candidate = self.areas.range(..=addr).last().map(|(_, a)| a);
56        candidate.filter(|a| a.va_range().contains(addr))
57    }
58
59    /// Finds a free area that can accommodate the given size.
60    ///
61    /// The search starts from the given `hint` address, and the area should be
62    /// within the given `limit` range.
63    ///
64    /// # Notes
65    /// The `align` parameter specifies the alignment of the start address and
66    /// the size of the area. The start address of the resulting area will
67    /// be aligned to this value. Also, the size of the area must be a multiple
68    /// of this value.
69    ///
70    /// # Returns
71    /// Returns the start address of the free area. Returns `None` if no such
72    /// area is found.
73    pub fn find_free_area(
74        &self,
75        hint: B::Addr,
76        size: usize,
77        limit: AddrRange<B::Addr>,
78        align: usize,
79    ) -> Option<B::Addr> {
80        if !size.is_multiple_of(align) {
81            // size must be a multiple of align.
82            return None;
83        }
84        // brute force: try each area's end address as the start.
85        let mut last_end: <B as MappingBackend>::Addr = hint.max(limit.start).align_up(align);
86        if let Some((_, area)) = self.areas.range(..last_end).last() {
87            last_end = last_end.max(area.end()).align_up(align);
88        }
89        for (&addr, area) in self.areas.range(last_end..) {
90            if last_end.checked_add(size).is_some_and(|end| end <= addr) {
91                return Some(last_end);
92            }
93            last_end = area.end().align_up(align);
94        }
95        if last_end
96            .checked_add(size)
97            .is_some_and(|end| end <= limit.end)
98        {
99            Some(last_end)
100        } else {
101            None
102        }
103    }
104
105    /// Grows the area containing `addr` by `additional_size` at its end.
106    pub fn extend_area(
107        &mut self,
108        addr: B::Addr,
109        additional_size: usize,
110        page_table: &mut B::PageTable,
111    ) -> MappingResult {
112        if additional_size == 0 {
113            return Ok(());
114        }
115
116        // Find the area containing addr.
117        let area_start = self
118            .areas
119            .range(..=addr)
120            .last()
121            .filter(|(_, a)| a.va_range().contains(addr))
122            .map(|(&start, _)| start)
123            .ok_or(MappingError::InvalidParam)?;
124
125        // Only the next area can conflict with a rightward extension.
126        let area_end = self.areas[&area_start].end();
127        let new_end = area_end
128            .checked_add(additional_size)
129            .ok_or(MappingError::InvalidParam)?;
130        if let Some((_, next)) = self.areas.range(area_end..).next()
131            && new_end > next.start()
132        {
133            return Err(MappingError::AlreadyExists);
134        }
135
136        self.areas
137            .get_mut(&area_start)
138            .unwrap()
139            .grow_right(additional_size, page_table)?;
140        Ok(())
141    }
142
143    /// Add a new memory mapping.
144    ///
145    /// The mapping is represented by a [`MemoryArea`].
146    ///
147    /// If the new area overlaps with any existing area, the behavior is
148    /// determined by the `unmap_overlap` parameter. If it is `true`, the
149    /// overlapped regions will be unmapped first. Otherwise, it returns an
150    /// error.
151    pub fn map(
152        &mut self,
153        area: MemoryArea<B>,
154        page_table: &mut B::PageTable,
155        unmap_overlap: bool,
156    ) -> MappingResult {
157        if area.va_range().is_empty() {
158            return Err(MappingError::InvalidParam);
159        }
160
161        if self.overlaps(area.va_range()) {
162            if unmap_overlap {
163                self.unmap(area.start(), area.size(), page_table)?;
164            } else {
165                return Err(MappingError::AlreadyExists);
166            }
167        }
168
169        area.map_area(page_table)?;
170        assert!(self.areas.insert(area.start(), area).is_none());
171        Ok(())
172    }
173
174    /// Remove memory mappings within the given address range.
175    ///
176    /// All memory areas that are fully contained in the range will be removed
177    /// directly. If the area intersects with the boundary, it will be shrinked.
178    /// If the unmapped range is in the middle of an existing area, it will be
179    /// split into two areas.
180    pub fn unmap(
181        &mut self,
182        start: B::Addr,
183        size: usize,
184        page_table: &mut B::PageTable,
185    ) -> MappingResult {
186        let range =
187            AddrRange::try_from_start_size(start, size).ok_or(MappingError::InvalidParam)?;
188        if range.is_empty() {
189            return Ok(());
190        }
191
192        let end = range.end;
193
194        // Unmap entire areas that are contained by the range.
195        self.areas.retain(|_, area| {
196            if area.va_range().contained_in(range) {
197                area.unmap_area(page_table).unwrap();
198                false
199            } else {
200                true
201            }
202        });
203
204        // Shrink right if the area intersects with the left boundary.
205        if let Some((&before_start, before)) = self.areas.range_mut(..start).last() {
206            let before_end = before.end();
207            if before_end > start {
208                if before_end <= end {
209                    // the unmapped area is at the end of `before`.
210                    before.shrink_right(start.sub_addr(before_start), page_table)?;
211                } else {
212                    // the unmapped area is in the middle `before`, need to split.
213                    let right_part = before.split(end).unwrap();
214                    before.shrink_right(start.sub_addr(before_start), page_table)?;
215                    assert_eq!(right_part.start().into(), Into::<usize>::into(end));
216                    self.areas.insert(end, right_part);
217                }
218            }
219        }
220
221        // Shrink left if the area intersects with the right boundary.
222        if let Some((&after_start, after)) = self.areas.range_mut(start..).next() {
223            let after_end = after.end();
224            if after_start < end {
225                // the unmapped area is at the start of `after`.
226                let mut new_area = self.areas.remove(&after_start).unwrap();
227                new_area.shrink_left(after_end.sub_addr(end), page_table)?;
228                assert_eq!(new_area.start().into(), Into::<usize>::into(end));
229                self.areas.insert(end, new_area);
230            }
231        }
232
233        Ok(())
234    }
235
236    /// Remove memory area metadata without calling the backend's unmap hook.
237    ///
238    /// This is intended for callers that have already moved or detached the
239    /// affected page-table entries and only need to update VMA bookkeeping.
240    pub fn unmap_metadata(&mut self, start: B::Addr, size: usize) -> MappingResult {
241        let range =
242            AddrRange::try_from_start_size(start, size).ok_or(MappingError::InvalidParam)?;
243        if range.is_empty() {
244            return Ok(());
245        }
246
247        let end = range.end;
248
249        self.areas
250            .retain(|_, area| !area.va_range().contained_in(range));
251
252        if let Some((&before_start, before)) = self.areas.range_mut(..start).last() {
253            let before_end = before.end();
254            if before_end > start {
255                if before_end <= end {
256                    before.shrink_right_metadata(start.sub_addr(before_start));
257                } else {
258                    let right_part = before.split(end).unwrap();
259                    before.shrink_right_metadata(start.sub_addr(before_start));
260                    assert_eq!(right_part.start().into(), Into::<usize>::into(end));
261                    self.areas.insert(end, right_part);
262                }
263            }
264        }
265
266        if let Some((&after_start, _)) = self.areas.range(start..).next()
267            && after_start < end
268        {
269            let mut new_area = self.areas.remove(&after_start).unwrap();
270            let after_end = new_area.end();
271            new_area.shrink_left_metadata(after_end.sub_addr(end));
272            assert_eq!(new_area.start().into(), Into::<usize>::into(end));
273            self.areas.insert(end, new_area);
274        }
275
276        Ok(())
277    }
278
279    /// Replaces area metadata without touching page-table entries.
280    pub fn replace_area_metadata(&mut self, area: MemoryArea<B>) -> MappingResult {
281        if area.va_range().is_empty() {
282            return Err(MappingError::InvalidParam);
283        }
284
285        let start = area.start();
286        let end = area.end();
287
288        let old_start = self
289            .areas
290            .range(..=start)
291            .last()
292            .filter(|(_, old)| old.start() <= start && end <= old.end())
293            .map(|(&old_start, _)| old_start)
294            .ok_or(MappingError::InvalidParam)?;
295
296        let mut old_area = self.areas.remove(&old_start).unwrap();
297        if old_start < start {
298            let right_part = old_area.split(start).unwrap();
299            self.areas.insert(old_start, old_area);
300            old_area = right_part;
301        }
302        if old_area.end() > end {
303            let right_part = old_area.split(end).unwrap();
304            self.areas.insert(right_part.start(), right_part);
305        }
306        assert!(self.areas.insert(start, area).is_none());
307        Ok(())
308    }
309
310    /// Remove all memory areas and the underlying mappings.
311    pub fn clear(&mut self, page_table: &mut B::PageTable) -> MappingResult {
312        for (_, area) in self.areas.iter() {
313            area.unmap_area(page_table)?;
314        }
315        self.areas.clear();
316        Ok(())
317    }
318
319    /// Change the flags of memory mappings within the given address range.
320    ///
321    /// `update_flags` is a function that receives old flags and processes
322    /// new flags (e.g., some flags can not be changed through this interface).
323    /// It returns [`None`] if there is no bit to change.
324    ///
325    /// Memory areas will be skipped according to `update_flags`. Memory areas
326    /// that are fully contained in the range or contains the range or
327    /// intersects with the boundary will be handled similarly to `munmap`.
328    pub fn protect(
329        &mut self,
330        start: B::Addr,
331        size: usize,
332        update_flags: impl Fn(B::Flags) -> Option<B::Flags>,
333        page_table: &mut B::PageTable,
334    ) -> MappingResult {
335        let end = start.checked_add(size).ok_or(MappingError::InvalidParam)?;
336        let mut to_insert = Vec::new();
337        for (&area_start, area) in self.areas.iter_mut() {
338            let area_end = area.end();
339
340            if let Some(new_flags) = update_flags(area.flags()) {
341                if area_start >= end {
342                    // [ prot ]
343                    //          [ area ]
344                    break;
345                } else if area_end <= start {
346                    //          [ prot ]
347                    // [ area ]
348                    // Do nothing
349                } else if area_start >= start && area_end <= end {
350                    // [   prot   ]
351                    //   [ area ]
352                    area.protect_area(new_flags, page_table)?;
353                    area.set_flags(new_flags);
354                } else if area_start < start && area_end > end {
355                    //        [ prot ]
356                    // [ left | area | right ]
357                    let mut middle_part = area.split(start).unwrap();
358                    let right_part = middle_part.split(end).unwrap();
359
360                    middle_part.protect_area(new_flags, page_table)?;
361                    middle_part.set_flags(new_flags);
362
363                    to_insert.push((right_part.start(), right_part));
364                    to_insert.push((middle_part.start(), middle_part));
365                } else if area_end > end {
366                    // [    prot ]
367                    //   [  area | right ]
368                    let right_part = area.split(end).unwrap();
369                    area.protect_area(new_flags, page_table)?;
370                    area.set_flags(new_flags);
371
372                    to_insert.push((right_part.start(), right_part));
373                } else {
374                    //        [ prot    ]
375                    // [ left |  area ]
376                    let mut right_part = area.split(start).unwrap();
377                    right_part.protect_area(new_flags, page_table)?;
378                    right_part.set_flags(new_flags);
379
380                    to_insert.push((right_part.start(), right_part));
381                }
382            }
383        }
384        self.areas.extend(to_insert);
385        Ok(())
386    }
387}
388
/// The default memory set is the empty set produced by [`MemorySet::new`].
impl<B: MappingBackend> Default for MemorySet<B> {
    /// Returns an empty memory set.
    fn default() -> Self {
        Self::new()
    }
}
394
395impl<B: MappingBackend> fmt::Debug for MemorySet<B>
396where
397    B::Addr: fmt::Debug,
398    B::Flags: fmt::Debug,
399{
400    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
401        f.debug_list().entries(self.areas.values()).finish()
402    }
403}