rquickjs_core/allocator/rust.rs

use std::{
    alloc::{self, Layout},
    mem, ptr,
};

use super::Allocator;

/// The largest value QuickJS will allocate is a u64,
/// so all allocated memory must be aligned to the alignment of this largest type.
const ALLOC_ALIGN: usize = mem::align_of::<u64>();

#[derive(Copy, Clone)]
#[repr(transparent)]
struct Header {
    size: usize,
}

const fn max(a: usize, b: usize) -> usize {
    if a < b {
        b
    } else {
        a
    }
}

/// The header needs to take up at least `ALLOC_ALIGN` bytes so that the values placed after it stay aligned.
const HEADER_SIZE: usize = max(mem::size_of::<Header>(), ALLOC_ALIGN);
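
// Every allocation handed out by this allocator is laid out as follows (a sketch of the
// scheme implemented below, not an ABI guarantee):
//
//     [ Header { size } | padding up to HEADER_SIZE | user data, size rounded up to ALLOC_ALIGN ]
//
// The pointer returned to QuickJS points just past the header; `dealloc`, `realloc`
// and `usable_size` step back by `HEADER_SIZE` to recover it.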

/// Round `size` up to the next multiple of [`ALLOC_ALIGN`]
/// (e.g. 1 -> 8 and 12 -> 16 when the alignment is 8).
#[inline]
fn round_size(size: usize) -> usize {
    (size + ALLOC_ALIGN - 1) / ALLOC_ALIGN * ALLOC_ALIGN
}

/// The allocator which uses the Rust global allocator.
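///
/// A minimal usage sketch, assuming the `rust-alloc` feature is enabled and using the
/// same `Runtime`/`Context` API the test module below relies on (the import path is
/// illustrative; adapt it to how you depend on the crate):
///
/// ```ignore
/// use rquickjs::{allocator::RustAllocator, Context, Runtime};
///
/// // Route all QuickJS allocations through the Rust global allocator.
/// let rt = Runtime::new_with_alloc(RustAllocator).unwrap();
/// let ctx = Context::full(&rt).unwrap();
/// ```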
pub struct RustAllocator;

unsafe impl Allocator for RustAllocator {
    fn calloc(&mut self, count: usize, size: usize) -> *mut u8 {
        if count == 0 || size == 0 {
            return ptr::null_mut();
        }

        let total_size = count.checked_mul(size).expect("overflow");
        let total_size = round_size(total_size);

        let alloc_size = HEADER_SIZE + total_size;

        let layout = if let Ok(layout) = Layout::from_size_align(alloc_size, ALLOC_ALIGN) {
            layout
        } else {
            return ptr::null_mut();
        };

        let ptr = unsafe { alloc::alloc_zeroed(layout) };

        if ptr.is_null() {
            return ptr::null_mut();
        }

        unsafe {
            ptr.cast::<Header>().write(Header { size: total_size });
            ptr.add(HEADER_SIZE)
        }
    }

    fn alloc(&mut self, size: usize) -> *mut u8 {
        let size = round_size(size);
        let alloc_size = size + HEADER_SIZE;

        let layout = if let Ok(layout) = Layout::from_size_align(alloc_size, ALLOC_ALIGN) {
            layout
        } else {
            return ptr::null_mut();
        };

        let ptr = unsafe { alloc::alloc(layout) };

        if ptr.is_null() {
            return ptr::null_mut();
        }

        unsafe {
            ptr.cast::<Header>().write(Header { size });
            ptr.add(HEADER_SIZE)
        }
    }

    unsafe fn dealloc(&mut self, ptr: *mut u8) {
        // Step back over the header written in `alloc`/`calloc` to recover the
        // allocation's usable size and rebuild the layout it was allocated with.
        let ptr = ptr.sub(HEADER_SIZE);
        let alloc_size = ptr.cast::<Header>().read().size + HEADER_SIZE;
        let layout = Layout::from_size_align_unchecked(alloc_size, ALLOC_ALIGN);

        alloc::dealloc(ptr, layout);
    }

    unsafe fn realloc(&mut self, ptr: *mut u8, new_size: usize) -> *mut u8 {
        let new_size = round_size(new_size);

        // Recover the old layout from the header before growing or shrinking.
        let ptr = ptr.sub(HEADER_SIZE);
        let alloc_size = ptr.cast::<Header>().read().size + HEADER_SIZE;

        let layout = Layout::from_size_align_unchecked(alloc_size, ALLOC_ALIGN);

        let new_alloc_size = new_size + HEADER_SIZE;

        let ptr = alloc::realloc(ptr, layout, new_alloc_size);

        if ptr.is_null() {
            return ptr::null_mut();
        }

        ptr.cast::<Header>().write(Header { size: new_size });
        ptr.add(HEADER_SIZE)
    }

    unsafe fn usable_size(ptr: *mut u8) -> usize {
        let ptr = ptr.sub(HEADER_SIZE);
        ptr.cast::<Header>().read().size
    }
}

#[cfg(all(test, feature = "rust-alloc"))]
mod test {
    use super::RustAllocator;
    use crate::{allocator::Allocator, Context, Runtime};
    use std::sync::atomic::{AtomicUsize, Ordering};

    static ALLOC_SIZE: AtomicUsize = AtomicUsize::new(0);
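
    // Allocator used only by the test below: it forwards every call to
    // `RustAllocator` and tracks the total usable bytes currently live in
    // `ALLOC_SIZE`, so the test can observe whether the GC frees memory.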
    struct TestAllocator;

    unsafe impl Allocator for TestAllocator {
        fn alloc(&mut self, size: usize) -> *mut u8 {
            unsafe {
                let res = RustAllocator.alloc(size);
                ALLOC_SIZE.fetch_add(RustAllocator::usable_size(res), Ordering::AcqRel);
                res
            }
        }

        fn calloc(&mut self, count: usize, size: usize) -> *mut u8 {
            unsafe {
                let res = RustAllocator.calloc(count, size);
                ALLOC_SIZE.fetch_add(RustAllocator::usable_size(res), Ordering::AcqRel);
                res
            }
        }

        unsafe fn dealloc(&mut self, ptr: *mut u8) {
            ALLOC_SIZE.fetch_sub(RustAllocator::usable_size(ptr), Ordering::AcqRel);
            RustAllocator.dealloc(ptr);
        }

        unsafe fn realloc(&mut self, ptr: *mut u8, new_size: usize) -> *mut u8 {
            if !ptr.is_null() {
                ALLOC_SIZE.fetch_sub(RustAllocator::usable_size(ptr), Ordering::AcqRel);
            }

            let res = RustAllocator.realloc(ptr, new_size);
            if !res.is_null() {
                ALLOC_SIZE.fetch_add(RustAllocator::usable_size(res), Ordering::AcqRel);
            }
            res
        }

        unsafe fn usable_size(ptr: *mut u8) -> usize
        where
            Self: Sized,
        {
            RustAllocator::usable_size(ptr)
        }
    }

    #[test]
    fn test_gc_working_correctly() {
        let rt = Runtime::new_with_alloc(TestAllocator).unwrap();
        let context = Context::full(&rt).unwrap();

        let before = ALLOC_SIZE.load(Ordering::Acquire);

        context.with(|ctx| {
            ctx.eval::<(), _>(
                r#"
                for(let i = 0;i < 100_000;i++){
                    // create recursive structure.
                    const a = () => {
                        if(a){
                            return true
                        }
                        return false
                    };
                }
            "#,
            )
            .unwrap();
        });

        let after = ALLOC_SIZE.load(Ordering::Acquire);
        // Every object takes at least a single byte,
        // so the GC must have collected at least some of the recursive objects if the
        // difference is smaller than the number of objects created.
        assert!(after.saturating_sub(before) < 100_000)
    }
}