
mofa_plugins/wasm_runtime/memory.rs

//! WASM Memory Management
//!
//! Utilities for managing memory between host and WASM guest

use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};

use super::types::{WasmError, WasmResult};

/// Guest pointer type (32-bit address in WASM linear memory)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GuestPtr(pub u32);

impl GuestPtr {
    pub fn new(addr: u32) -> Self {
        Self(addr)
    }

    pub fn offset(&self, bytes: u32) -> Self {
        Self(self.0.saturating_add(bytes))
    }

    pub fn as_usize(&self) -> usize {
        self.0 as usize
    }

    pub fn is_null(&self) -> bool {
        self.0 == 0
    }
}

impl From<u32> for GuestPtr {
    fn from(addr: u32) -> Self {
        Self(addr)
    }
}

impl From<GuestPtr> for u32 {
    fn from(ptr: GuestPtr) -> Self {
        ptr.0
    }
}
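
// Usage sketch: `offset` saturates rather than wrapping, so pointer math near
// the 4 GiB boundary clamps to u32::MAX instead of silently overflowing.
//
//     let base = GuestPtr::new(0x1000);
//     assert_eq!(base.offset(8).0, 0x1008);
//     assert_eq!(base.offset(u32::MAX).0, u32::MAX); // saturated, not wrapped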

/// Guest slice (pointer + length)
#[derive(Debug, Clone, Copy)]
pub struct GuestSlice {
    pub ptr: GuestPtr,
    pub len: u32,
}

impl GuestSlice {
    pub fn new(ptr: GuestPtr, len: u32) -> Self {
        Self { ptr, len }
    }

    pub fn from_raw(ptr: u32, len: u32) -> Self {
        Self {
            ptr: GuestPtr(ptr),
            len,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    pub fn end(&self) -> GuestPtr {
        self.ptr.offset(self.len)
    }
}
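
// Usage sketch: guest exports commonly hand back a (ptr, len) pair as two
// u32s; `from_raw` packages them and `end()` is the first address past the
// slice.
//
//     let args = GuestSlice::from_raw(0x2000, 64);
//     assert_eq!(args.end().0, 0x2040);
//     assert!(!args.is_empty());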

/// Memory region descriptor
#[derive(Debug, Clone)]
pub struct MemoryRegion {
    /// Start address
    pub start: GuestPtr,
    /// Size in bytes
    pub size: u32,
    /// Is allocated
    pub allocated: bool,
    /// Optional tag for debugging
    pub tag: Option<String>,
}

impl MemoryRegion {
    pub fn new(start: GuestPtr, size: u32) -> Self {
        Self {
            start,
            size,
            allocated: true,
            tag: None,
        }
    }

    pub fn with_tag(mut self, tag: &str) -> Self {
        self.tag = Some(tag.to_string());
        self
    }

    pub fn contains(&self, addr: GuestPtr) -> bool {
        // Subtraction form avoids overflow when start + size exceeds u32::MAX
        addr.0 >= self.start.0 && addr.0 - self.start.0 < self.size
    }

    pub fn end(&self) -> GuestPtr {
        self.start.offset(self.size)
    }
}
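
// Usage sketch: a region is the half-open interval [start, start + size), so
// `end()` itself is not contained.
//
//     let region = MemoryRegion::new(GuestPtr::new(0x1000), 0x100).with_tag("scratch");
//     assert!(region.contains(GuestPtr::new(0x10ff)));
//     assert!(!region.contains(region.end()));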

/// WASM memory abstraction
pub struct WasmMemory {
    /// Memory data (simulated for now, actual implementation uses wasmtime Memory)
    data: Vec<u8>,
    /// Current size in pages (64KB each)
    pages: u32,
    /// Maximum pages
    max_pages: Option<u32>,
    /// Allocated regions tracking
    regions: Vec<MemoryRegion>,
    /// Next allocation address
    heap_base: u32,
}

impl WasmMemory {
    const PAGE_SIZE: u32 = 65536; // 64KB

    pub fn new(initial_pages: u32, max_pages: Option<u32>) -> Self {
        let size = initial_pages as usize * Self::PAGE_SIZE as usize;
        Self {
            data: vec![0u8; size],
            pages: initial_pages,
            max_pages,
            regions: Vec::new(),
            heap_base: Self::PAGE_SIZE, // Reserve first page for stack/globals
        }
    }

    /// Get current size in bytes
    pub fn size(&self) -> u32 {
        self.pages * Self::PAGE_SIZE
    }

    /// Get current size in pages
    pub fn pages(&self) -> u32 {
        self.pages
    }

    /// Grow memory by `delta` pages, returning the previous page count
    pub fn grow(&mut self, delta: u32) -> WasmResult<u32> {
        let new_pages =
            self.pages
                .checked_add(delta)
                .ok_or_else(|| WasmError::AllocationFailed {
                    size: delta.saturating_mul(Self::PAGE_SIZE),
                })?;

        if let Some(max) = self.max_pages
            && new_pages > max
        {
            return Err(WasmError::ResourceLimitExceeded(format!(
                "Memory growth would exceed max pages: {} > {}",
                new_pages, max
            )));
        }

        let old_pages = self.pages;
        let new_size = new_pages as usize * Self::PAGE_SIZE as usize;
        self.data.resize(new_size, 0);
        self.pages = new_pages;

        debug!("Memory grown from {} to {} pages", old_pages, new_pages);
        Ok(old_pages)
    }
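
    // Usage sketch: like the core wasm `memory.grow` instruction, `grow`
    // returns the page count *before* growing, so the caller can locate the
    // base of the newly added pages.
    //
    //     let mut mem = WasmMemory::new(1, Some(4));
    //     let old = mem.grow(2).unwrap();             // old == 1
    //     let new_base = GuestPtr::new(old * 65536);  // first byte of the new region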

    /// Read bytes from memory
    pub fn read(&self, ptr: GuestPtr, len: u32) -> WasmResult<&[u8]> {
        let start = ptr.as_usize();
        let end = start
            .checked_add(len as usize)
            .ok_or(WasmError::MemoryOutOfBounds {
                offset: ptr.0,
                size: len,
            })?;

        if end > self.data.len() {
            return Err(WasmError::MemoryOutOfBounds {
                offset: ptr.0,
                size: len,
            });
        }

        Ok(&self.data[start..end])
    }

    /// Read bytes as a mutable slice
    pub fn read_mut(&mut self, ptr: GuestPtr, len: u32) -> WasmResult<&mut [u8]> {
        let start = ptr.as_usize();
        let end = start
            .checked_add(len as usize)
            .ok_or(WasmError::MemoryOutOfBounds {
                offset: ptr.0,
                size: len,
            })?;

        if end > self.data.len() {
            return Err(WasmError::MemoryOutOfBounds {
                offset: ptr.0,
                size: len,
            });
        }

        Ok(&mut self.data[start..end])
    }

    /// Write bytes to memory
    pub fn write(&mut self, ptr: GuestPtr, data: &[u8]) -> WasmResult<()> {
        let start = ptr.as_usize();
        let end = start
            .checked_add(data.len())
            .ok_or(WasmError::MemoryOutOfBounds {
                offset: ptr.0,
                size: data.len() as u32,
            })?;

        if end > self.data.len() {
            return Err(WasmError::MemoryOutOfBounds {
                offset: ptr.0,
                size: data.len() as u32,
            });
        }

        self.data[start..end].copy_from_slice(data);
        Ok(())
    }

    /// Read a UTF-8 string of `len` bytes from memory
    pub fn read_string(&self, ptr: GuestPtr, len: u32) -> WasmResult<String> {
        let bytes = self.read(ptr, len)?;
        String::from_utf8(bytes.to_vec())
            .map_err(|e| WasmError::SerializationError(format!("Invalid UTF-8: {}", e)))
    }

    /// Write a string to memory
    pub fn write_string(&mut self, ptr: GuestPtr, s: &str) -> WasmResult<()> {
        self.write(ptr, s.as_bytes())
    }

    /// Read a value of type T
    pub fn read_value<T: Copy>(&self, ptr: GuestPtr) -> WasmResult<T> {
        let size = std::mem::size_of::<T>() as u32;
        let bytes = self.read(ptr, size)?;

        // Safety: the slice is bounds-checked above and `read_unaligned` has no
        // alignment requirement. The caller must ensure T is plain-old-data,
        // i.e. valid for any bit pattern (no references, no niche invariants).
        Ok(unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const T) })
    }

    /// Write a value of type T
    pub fn write_value<T: Copy>(&mut self, ptr: GuestPtr, value: T) -> WasmResult<()> {
        let size = std::mem::size_of::<T>();
        let bytes = unsafe { std::slice::from_raw_parts(&value as *const T as *const u8, size) };
        self.write(ptr, bytes)
    }
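
    // Usage sketch (assumes a caller-defined #[repr(C)] plain-old-data type):
    // round-trip a small header through guest memory. `#[repr(C)]` pins the
    // layout so it matches what the guest side expects.
    //
    //     #[repr(C)]
    //     #[derive(Clone, Copy, PartialEq, Debug)]
    //     struct MsgHeader { kind: u32, len: u32 }
    //
    //     let hdr = MsgHeader { kind: 1, len: 256 };
    //     mem.write_value(GuestPtr::new(0x1000), hdr).unwrap();
    //     let back: MsgHeader = mem.read_value(GuestPtr::new(0x1000)).unwrap();
    //     assert_eq!(back, hdr);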

    /// Allocate memory from the heap (bump allocation)
    pub fn alloc(&mut self, size: u32) -> WasmResult<GuestPtr> {
        // Align to 8 bytes (checked so sizes near u32::MAX don't wrap)
        let aligned_size = size
            .checked_add(7)
            .ok_or(WasmError::AllocationFailed { size })?
            & !7;

        // Check if we have enough space
        let end = self
            .heap_base
            .checked_add(aligned_size)
            .ok_or(WasmError::AllocationFailed { size })?;

        if end > self.size() {
            // Grow memory just enough to cover `end`, rounding up to whole pages
            let needed_pages = end.div_ceil(Self::PAGE_SIZE);
            let delta = needed_pages.saturating_sub(self.pages);
            if delta > 0 {
                self.grow(delta)?;
            }
        }

        let ptr = GuestPtr(self.heap_base);
        self.heap_base = end;

        self.regions.push(MemoryRegion::new(ptr, aligned_size));
        debug!("Allocated {} bytes at {:?}", aligned_size, ptr);

        Ok(ptr)
    }
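
    // Usage sketch: allocation is bump-style with sizes rounded up to 8
    // bytes, so consecutive requests land at 8-byte-aligned, strictly
    // increasing addresses; freed space is not reused.
    //
    //     let a = mem.alloc(5).unwrap();   // consumes 8 bytes
    //     let b = mem.alloc(1).unwrap();
    //     assert_eq!(b.0, a.0 + 8);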

    /// Free allocated memory (basic implementation - doesn't reuse space)
    pub fn free(&mut self, ptr: GuestPtr) -> WasmResult<()> {
        if let Some(region) = self
            .regions
            .iter_mut()
            .find(|r| r.start == ptr && r.allocated)
        {
            region.allocated = false;
            debug!("Freed {} bytes at {:?}", region.size, ptr);
            Ok(())
        } else {
            warn!("Attempted to free unallocated memory at {:?}", ptr);
            Ok(()) // Don't error on double-free
        }
    }

    /// Allocate and write data
    pub fn alloc_bytes(&mut self, data: &[u8]) -> WasmResult<GuestSlice> {
        let ptr = self.alloc(data.len() as u32)?;
        self.write(ptr, data)?;
        Ok(GuestSlice::new(ptr, data.len() as u32))
    }

    /// Allocate and write string
    pub fn alloc_string(&mut self, s: &str) -> WasmResult<GuestSlice> {
        self.alloc_bytes(s.as_bytes())
    }
}
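
// Usage sketch: the typical host-to-guest handoff is a single alloc + write,
// after which the (ptr, len) pair can be passed to a guest export.
//
//     let mut mem = WasmMemory::default();
//     let payload = mem.alloc_string("hello").unwrap();
//     // pass (payload.ptr.0, payload.len) to the guest...
//     assert_eq!(mem.read_string(payload.ptr, payload.len).unwrap(), "hello");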

impl Default for WasmMemory {
    fn default() -> Self {
        Self::new(1, Some(256))
    }
}

/// Shared memory buffer for inter-module communication
pub struct SharedMemoryBuffer {
    /// Buffer data
    data: Arc<RwLock<Vec<u8>>>,
    /// Maximum size
    max_size: usize,
}

impl SharedMemoryBuffer {
    pub fn new(max_size: usize) -> Self {
        Self {
            data: Arc::new(RwLock::new(Vec::with_capacity(max_size))),
            max_size,
        }
    }

    pub async fn write(&self, data: &[u8]) -> WasmResult<()> {
        if data.len() > self.max_size {
            return Err(WasmError::ResourceLimitExceeded(format!(
                "Data size {} exceeds buffer max {}",
                data.len(),
                self.max_size
            )));
        }

        let mut buf = self.data.write().await;
        buf.clear();
        buf.extend_from_slice(data);
        Ok(())
    }

    pub async fn read(&self) -> Vec<u8> {
        self.data.read().await.clone()
    }

    pub async fn clear(&self) {
        self.data.write().await.clear();
    }

    pub async fn len(&self) -> usize {
        self.data.read().await.len()
    }

    pub async fn is_empty(&self) -> bool {
        self.data.read().await.is_empty()
    }
}
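
// Usage sketch: `SharedMemoryBuffer` is not `Clone`; to share one buffer
// across tasks, wrap it in an `Arc` (the inner `RwLock` already serializes
// readers and writers).
//
//     let buf = Arc::new(SharedMemoryBuffer::new(1024));
//     let writer = Arc::clone(&buf);
//     tokio::spawn(async move { writer.write(b"ping").await.unwrap(); });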

/// Memory allocator for managing guest memory
pub struct MemoryAllocator {
    /// Free list of blocks
    free_blocks: Vec<MemoryRegion>,
    /// Minimum allocation size
    min_block_size: u32,
    /// Total allocated bytes
    allocated_bytes: u64,
    /// Peak allocated bytes
    peak_bytes: u64,
}

impl MemoryAllocator {
    pub fn new(min_block_size: u32) -> Self {
        Self {
            free_blocks: Vec::new(),
            min_block_size,
            allocated_bytes: 0,
            peak_bytes: 0,
        }
    }

    /// Find a suitable free block or return None
    pub fn find_free_block(&mut self, size: u32) -> Option<GuestPtr> {
        let aligned_size = self.align_size(size);

        // First-fit: take the first block that is large enough. (Finding the
        // index first avoids mutating `free_blocks` while iterating over it.)
        let idx = self
            .free_blocks
            .iter()
            .position(|block| block.size >= aligned_size)?;

        let ptr = self.free_blocks[idx].start;
        let remaining = self.free_blocks[idx].size - aligned_size;

        if remaining >= self.min_block_size {
            // Split the block, keeping the tail on the free list
            self.free_blocks[idx] = MemoryRegion::new(ptr.offset(aligned_size), remaining);
        } else {
            // Use the entire block
            self.free_blocks.remove(idx);
        }

        self.allocated_bytes += aligned_size as u64;
        self.peak_bytes = self.peak_bytes.max(self.allocated_bytes);

        Some(ptr)
    }
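
    // Usage sketch: the scan is first-fit, so it returns the first block that
    // is large enough rather than the tightest match.
    //
    //     let mut alloc = MemoryAllocator::default();
    //     alloc.add_region(MemoryRegion::new(GuestPtr::new(0x1000), 256));
    //     alloc.add_region(MemoryRegion::new(GuestPtr::new(0x4000), 32));
    //     let p = alloc.find_free_block(32).unwrap();
    //     assert_eq!(p.0, 0x1000); // first fit, even though the second block matches exactly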

    /// Return a block to the free list
    pub fn return_block(&mut self, region: MemoryRegion) {
        self.allocated_bytes = self.allocated_bytes.saturating_sub(region.size as u64);

        // Try to coalesce with adjacent blocks
        let mut coalesced = region;
        let mut i = 0;

        while i < self.free_blocks.len() {
            let block = &self.free_blocks[i];

            // Check if blocks are adjacent
            if block.end() == coalesced.start {
                // Block is immediately before
                coalesced = MemoryRegion::new(block.start, block.size + coalesced.size);
                self.free_blocks.remove(i);
                i = 0; // a merge can create new adjacencies; rescan from the start
            } else if coalesced.end() == block.start {
                // Block is immediately after
                coalesced = MemoryRegion::new(coalesced.start, coalesced.size + block.size);
                self.free_blocks.remove(i);
                i = 0;
            } else {
                i += 1;
            }
        }

        self.free_blocks.push(coalesced);
    }
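
    // Usage sketch: adjacent blocks coalesce on return, so a split is undone
    // by freeing the piece that was carved off.
    //
    //     let mut alloc = MemoryAllocator::default();
    //     alloc.add_region(MemoryRegion::new(GuestPtr::new(0x1000), 256));
    //     let p = alloc.find_free_block(32).unwrap();    // splits off [0x1000, 0x1020)
    //     alloc.return_block(MemoryRegion::new(p, 32));  // merges with the remaining tail
    //     assert!(alloc.find_free_block(256).is_some()); // the full block is whole again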

    /// Add initial memory region
    pub fn add_region(&mut self, region: MemoryRegion) {
        self.free_blocks.push(region);
    }

    fn align_size(&self, size: u32) -> u32 {
        // Round up to 8 bytes; saturating so sizes near u32::MAX don't wrap
        size.saturating_add(7) & !7
    }

    pub fn allocated_bytes(&self) -> u64 {
        self.allocated_bytes
    }

    pub fn peak_bytes(&self) -> u64 {
        self.peak_bytes
    }
}

impl Default for MemoryAllocator {
    fn default() -> Self {
        Self::new(16)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_guest_ptr() {
        let ptr = GuestPtr::new(100);
        assert_eq!(ptr.0, 100);

        let offset_ptr = ptr.offset(50);
        assert_eq!(offset_ptr.0, 150);

        assert!(!ptr.is_null());
        assert!(GuestPtr::new(0).is_null());
    }

    #[test]
    fn test_wasm_memory_read_write() {
        let mut mem = WasmMemory::new(1, Some(16));

        // Write some data
        let data = b"Hello, WASM!";
        let ptr = GuestPtr::new(1024);
        mem.write(ptr, data).unwrap();

        // Read it back
        let read = mem.read(ptr, data.len() as u32).unwrap();
        assert_eq!(read, data);
    }

    #[test]
    fn test_wasm_memory_string() {
        let mut mem = WasmMemory::new(1, Some(16));

        let ptr = GuestPtr::new(1024);
        let s = "Hello, 世界!";
        mem.write_string(ptr, s).unwrap();

        let read = mem.read_string(ptr, s.len() as u32).unwrap();
        assert_eq!(read, s);
    }

    #[test]
    fn test_wasm_memory_alloc() {
        let mut mem = WasmMemory::new(1, Some(16));

        let ptr1 = mem.alloc(100).unwrap();
        let ptr2 = mem.alloc(200).unwrap();

        assert_ne!(ptr1.0, ptr2.0);
        assert!(ptr2.0 > ptr1.0);
    }

    #[test]
    fn test_wasm_memory_grow() {
        let mut mem = WasmMemory::new(1, Some(4));
        assert_eq!(mem.pages(), 1);

        let old = mem.grow(2).unwrap();
        assert_eq!(old, 1);
        assert_eq!(mem.pages(), 3);
    }

    #[test]
    fn test_memory_bounds() {
        let mem = WasmMemory::new(1, Some(1));

        // Reading beyond memory should fail
        let result = mem.read(GuestPtr::new(65536), 100);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_shared_buffer() {
        let buf = SharedMemoryBuffer::new(1024);

        buf.write(b"test data").await.unwrap();
        assert_eq!(buf.read().await, b"test data");

        buf.clear().await;
        assert!(buf.is_empty().await);
    }
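
    #[test]
    fn test_memory_allocator_stats() {
        // Illustrative test (a sketch of the allocator's accounting): requests
        // round up to 8 bytes, raise `allocated_bytes` and the sticky peak,
        // and `return_block` coalesces the freed block back into one region.
        let mut alloc = MemoryAllocator::new(16);
        alloc.add_region(MemoryRegion::new(GuestPtr::new(4096), 256));

        let ptr = alloc.find_free_block(100).unwrap(); // aligned up to 104 bytes
        assert_eq!(ptr.0, 4096);
        assert_eq!(alloc.allocated_bytes(), 104);
        assert_eq!(alloc.peak_bytes(), 104);

        alloc.return_block(MemoryRegion::new(ptr, 104));
        assert_eq!(alloc.allocated_bytes(), 0);
        assert_eq!(alloc.peak_bytes(), 104); // peak is a high-water mark
    }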
}