// miraland_bpf_loader_program/syscalls/mem_ops.rs

1use {
2    super::*,
3    solana_rbpf::{error::EbpfError, memory_region::MemoryRegion},
4    std::slice,
5};
6
7fn mem_op_consume(invoke_context: &mut InvokeContext, n: u64) -> Result<(), Error> {
8    let compute_budget = invoke_context.get_compute_budget();
9    let cost = compute_budget.mem_op_base_cost.max(
10        n.checked_div(compute_budget.cpi_bytes_per_unit)
11            .unwrap_or(u64::MAX),
12    );
13    consume_compute_meter(invoke_context, cost)
14}
15
16declare_builtin_function!(
17    /// memcpy
18    SyscallMemcpy,
19    fn rust(
20        invoke_context: &mut InvokeContext,
21        dst_addr: u64,
22        src_addr: u64,
23        n: u64,
24        _arg4: u64,
25        _arg5: u64,
26        memory_mapping: &mut MemoryMapping,
27    ) -> Result<u64, Error> {
28        mem_op_consume(invoke_context, n)?;
29
30        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
31            return Err(SyscallError::CopyOverlapping.into());
32        }
33
34        // host addresses can overlap so we always invoke memmove
35        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
36    }
37);
38
39declare_builtin_function!(
40    /// memmove
41    SyscallMemmove,
42    fn rust(
43        invoke_context: &mut InvokeContext,
44        dst_addr: u64,
45        src_addr: u64,
46        n: u64,
47        _arg4: u64,
48        _arg5: u64,
49        memory_mapping: &mut MemoryMapping,
50    ) -> Result<u64, Error> {
51        mem_op_consume(invoke_context, n)?;
52
53        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
54    }
55);
56
57declare_builtin_function!(
58    /// memcmp
59    SyscallMemcmp,
60    fn rust(
61        invoke_context: &mut InvokeContext,
62        s1_addr: u64,
63        s2_addr: u64,
64        n: u64,
65        cmp_result_addr: u64,
66        _arg5: u64,
67        memory_mapping: &mut MemoryMapping,
68    ) -> Result<u64, Error> {
69        mem_op_consume(invoke_context, n)?;
70
71        if invoke_context
72            .feature_set
73            .is_active(&feature_set::bpf_account_data_direct_mapping::id())
74        {
75            let cmp_result = translate_type_mut::<i32>(
76                memory_mapping,
77                cmp_result_addr,
78                invoke_context.get_check_aligned(),
79            )?;
80            *cmp_result = memcmp_non_contiguous(s1_addr, s2_addr, n, memory_mapping)?;
81        } else {
82            let s1 = translate_slice::<u8>(
83                memory_mapping,
84                s1_addr,
85                n,
86                invoke_context.get_check_aligned(),
87            )?;
88            let s2 = translate_slice::<u8>(
89                memory_mapping,
90                s2_addr,
91                n,
92                invoke_context.get_check_aligned(),
93            )?;
94            let cmp_result = translate_type_mut::<i32>(
95                memory_mapping,
96                cmp_result_addr,
97                invoke_context.get_check_aligned(),
98            )?;
99
100            debug_assert_eq!(s1.len(), n as usize);
101            debug_assert_eq!(s2.len(), n as usize);
102            // Safety:
103            // memcmp is marked unsafe since it assumes that the inputs are at least
104            // `n` bytes long. `s1` and `s2` are guaranteed to be exactly `n` bytes
105            // long because `translate_slice` would have failed otherwise.
106            *cmp_result = unsafe { memcmp(s1, s2, n as usize) };
107        }
108
109        Ok(0)
110    }
111);
112
113declare_builtin_function!(
114    /// memset
115    SyscallMemset,
116    fn rust(
117        invoke_context: &mut InvokeContext,
118        dst_addr: u64,
119        c: u64,
120        n: u64,
121        _arg4: u64,
122        _arg5: u64,
123        memory_mapping: &mut MemoryMapping,
124    ) -> Result<u64, Error> {
125        mem_op_consume(invoke_context, n)?;
126
127        if invoke_context
128            .feature_set
129            .is_active(&feature_set::bpf_account_data_direct_mapping::id())
130        {
131            memset_non_contiguous(dst_addr, c as u8, n, memory_mapping)
132        } else {
133            let s = translate_slice_mut::<u8>(
134                memory_mapping,
135                dst_addr,
136                n,
137                invoke_context.get_check_aligned(),
138            )?;
139            s.fill(c as u8);
140            Ok(0)
141        }
142    }
143);
144
145fn memmove(
146    invoke_context: &mut InvokeContext,
147    dst_addr: u64,
148    src_addr: u64,
149    n: u64,
150    memory_mapping: &MemoryMapping,
151) -> Result<u64, Error> {
152    if invoke_context
153        .feature_set
154        .is_active(&feature_set::bpf_account_data_direct_mapping::id())
155    {
156        memmove_non_contiguous(dst_addr, src_addr, n, memory_mapping)
157    } else {
158        let dst_ptr = translate_slice_mut::<u8>(
159            memory_mapping,
160            dst_addr,
161            n,
162            invoke_context.get_check_aligned(),
163        )?
164        .as_mut_ptr();
165        let src_ptr = translate_slice::<u8>(
166            memory_mapping,
167            src_addr,
168            n,
169            invoke_context.get_check_aligned(),
170        )?
171        .as_ptr();
172
173        unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
174        Ok(0)
175    }
176}
177
178fn memmove_non_contiguous(
179    dst_addr: u64,
180    src_addr: u64,
181    n: u64,
182    memory_mapping: &MemoryMapping,
183) -> Result<u64, Error> {
184    let reverse = dst_addr.wrapping_sub(src_addr) < n;
185    iter_memory_pair_chunks(
186        AccessType::Load,
187        src_addr,
188        AccessType::Store,
189        dst_addr,
190        n,
191        memory_mapping,
192        reverse,
193        |src_host_addr, dst_host_addr, chunk_len| {
194            unsafe { std::ptr::copy(src_host_addr, dst_host_addr as *mut u8, chunk_len) };
195            Ok(0)
196        },
197    )
198}
199
/// Compares the first `n` bytes of `s1` and `s2`.
///
/// Returns `0` when they are equal, otherwise the difference
/// `s1[i] - s2[i]` (as `i32`) at the first index where they differ.
///
/// # Safety
///
/// Both slices must be at least `n` bytes long; no bounds checks are made.
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    let lhs = s1.get_unchecked(..n);
    let rhs = s2.get_unchecked(..n);
    lhs.iter()
        .zip(rhs)
        .find(|(a, b)| a != b)
        .map_or(0, |(a, b)| (*a as i32).saturating_sub(*b as i32))
}
212
213fn memcmp_non_contiguous(
214    src_addr: u64,
215    dst_addr: u64,
216    n: u64,
217    memory_mapping: &MemoryMapping,
218) -> Result<i32, Error> {
219    let memcmp_chunk = |s1_addr, s2_addr, chunk_len| {
220        let res = unsafe {
221            let s1 = slice::from_raw_parts(s1_addr, chunk_len);
222            let s2 = slice::from_raw_parts(s2_addr, chunk_len);
223            // Safety:
224            // memcmp is marked unsafe since it assumes that s1 and s2 are exactly chunk_len
225            // long. The whole point of iter_memory_pair_chunks is to find same length chunks
226            // across two memory regions.
227            memcmp(s1, s2, chunk_len)
228        };
229        if res != 0 {
230            return Err(MemcmpError::Diff(res).into());
231        }
232        Ok(0)
233    };
234    match iter_memory_pair_chunks(
235        AccessType::Load,
236        src_addr,
237        AccessType::Load,
238        dst_addr,
239        n,
240        memory_mapping,
241        false,
242        memcmp_chunk,
243    ) {
244        Ok(res) => Ok(res),
245        Err(error) => match error.downcast_ref() {
246            Some(MemcmpError::Diff(diff)) => Ok(*diff),
247            _ => Err(error),
248        },
249    }
250}
251
/// Internal error used by `memcmp_non_contiguous` to abort chunk iteration as
/// soon as two chunks differ; it is converted back into a regular result there.
#[derive(Debug)]
enum MemcmpError {
    /// The non-zero result `memcmp` produced for the first differing chunk pair.
    Diff(i32),
}
256
257impl std::fmt::Display for MemcmpError {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        match self {
260            MemcmpError::Diff(diff) => write!(f, "memcmp diff: {diff}"),
261        }
262    }
263}
264
265impl std::error::Error for MemcmpError {
266    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
267        match self {
268            MemcmpError::Diff(_) => None,
269        }
270    }
271}
272
273fn memset_non_contiguous(
274    dst_addr: u64,
275    c: u8,
276    n: u64,
277    memory_mapping: &MemoryMapping,
278) -> Result<u64, Error> {
279    let dst_chunk_iter = MemoryChunkIterator::new(memory_mapping, AccessType::Store, dst_addr, n)?;
280    for item in dst_chunk_iter {
281        let (dst_region, dst_vm_addr, dst_len) = item?;
282        let dst_host_addr = Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
283        unsafe { slice::from_raw_parts_mut(dst_host_addr as *mut u8, dst_len).fill(c) }
284    }
285
286    Ok(0)
287}
288
289fn iter_memory_pair_chunks<T, F>(
290    src_access: AccessType,
291    src_addr: u64,
292    dst_access: AccessType,
293    dst_addr: u64,
294    n_bytes: u64,
295    memory_mapping: &MemoryMapping,
296    reverse: bool,
297    mut fun: F,
298) -> Result<T, Error>
299where
300    T: Default,
301    F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
302{
303    let mut src_chunk_iter =
304        MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n_bytes)
305            .map_err(EbpfError::from)?;
306    let mut dst_chunk_iter =
307        MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n_bytes)
308            .map_err(EbpfError::from)?;
309
310    let mut src_chunk = None;
311    let mut dst_chunk = None;
312
313    macro_rules! memory_chunk {
314        ($chunk_iter:ident, $chunk:ident) => {
315            if let Some($chunk) = &mut $chunk {
316                // Keep processing the current chunk
317                $chunk
318            } else {
319                // This is either the first call or we've processed all the bytes in the current
320                // chunk. Move to the next one.
321                let chunk = match if reverse {
322                    $chunk_iter.next_back()
323                } else {
324                    $chunk_iter.next()
325                } {
326                    Some(item) => item?,
327                    None => break,
328                };
329                $chunk.insert(chunk)
330            }
331        };
332    }
333
334    loop {
335        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
336        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);
337
338        // We always process same-length pairs
339        let chunk_len = *src_remaining.min(dst_remaining);
340
341        let (src_host_addr, dst_host_addr) = {
342            let (src_addr, dst_addr) = if reverse {
343                // When scanning backwards not only we want to scan regions from the end,
344                // we want to process the memory within regions backwards as well.
345                (
346                    src_chunk_addr
347                        .saturating_add(*src_remaining as u64)
348                        .saturating_sub(chunk_len as u64),
349                    dst_chunk_addr
350                        .saturating_add(*dst_remaining as u64)
351                        .saturating_sub(chunk_len as u64),
352                )
353            } else {
354                (*src_chunk_addr, *dst_chunk_addr)
355            };
356
357            (
358                Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
359                Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
360            )
361        };
362
363        fun(
364            src_host_addr as *const u8,
365            dst_host_addr as *const u8,
366            chunk_len,
367        )?;
368
369        // Update how many bytes we have left to scan in each chunk
370        *src_remaining = src_remaining.saturating_sub(chunk_len);
371        *dst_remaining = dst_remaining.saturating_sub(chunk_len);
372
373        if !reverse {
374            // We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
375            // mode we don't do this since we make progress by decreasing src_len and
376            // dst_len.
377            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
378            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
379        }
380
381        if *src_remaining == 0 {
382            src_chunk = None;
383        }
384
385        if *dst_remaining == 0 {
386            dst_chunk = None;
387        }
388    }
389
390    Ok(T::default())
391}
392
393struct MemoryChunkIterator<'a> {
394    memory_mapping: &'a MemoryMapping<'a>,
395    access_type: AccessType,
396    initial_vm_addr: u64,
397    vm_addr_start: u64,
398    // exclusive end index (start + len, so one past the last valid address)
399    vm_addr_end: u64,
400    len: u64,
401}
402
403impl<'a> MemoryChunkIterator<'a> {
404    fn new(
405        memory_mapping: &'a MemoryMapping,
406        access_type: AccessType,
407        vm_addr: u64,
408        len: u64,
409    ) -> Result<MemoryChunkIterator<'a>, EbpfError> {
410        let vm_addr_end = vm_addr.checked_add(len).ok_or(EbpfError::AccessViolation(
411            access_type,
412            vm_addr,
413            len,
414            "unknown",
415        ))?;
416        Ok(MemoryChunkIterator {
417            memory_mapping,
418            access_type,
419            initial_vm_addr: vm_addr,
420            len,
421            vm_addr_start: vm_addr,
422            vm_addr_end,
423        })
424    }
425
426    fn region(&mut self, vm_addr: u64) -> Result<&'a MemoryRegion, Error> {
427        match self.memory_mapping.region(self.access_type, vm_addr) {
428            Ok(region) => Ok(region),
429            Err(error) => match error {
430                EbpfError::AccessViolation(access_type, _vm_addr, _len, name) => Err(Box::new(
431                    EbpfError::AccessViolation(access_type, self.initial_vm_addr, self.len, name),
432                )),
433                EbpfError::StackAccessViolation(access_type, _vm_addr, _len, frame) => {
434                    Err(Box::new(EbpfError::StackAccessViolation(
435                        access_type,
436                        self.initial_vm_addr,
437                        self.len,
438                        frame,
439                    )))
440                }
441                _ => Err(error.into()),
442            },
443        }
444    }
445}
446
447impl<'a> Iterator for MemoryChunkIterator<'a> {
448    type Item = Result<(&'a MemoryRegion, u64, usize), Error>;
449
450    fn next(&mut self) -> Option<Self::Item> {
451        if self.vm_addr_start == self.vm_addr_end {
452            return None;
453        }
454
455        let region = match self.region(self.vm_addr_start) {
456            Ok(region) => region,
457            Err(e) => {
458                self.vm_addr_start = self.vm_addr_end;
459                return Some(Err(e));
460            }
461        };
462
463        let vm_addr = self.vm_addr_start;
464
465        let chunk_len = if region.vm_addr_end <= self.vm_addr_end {
466            // consume the whole region
467            let len = region.vm_addr_end.saturating_sub(self.vm_addr_start);
468            self.vm_addr_start = region.vm_addr_end;
469            len
470        } else {
471            // consume part of the region
472            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
473            self.vm_addr_start = self.vm_addr_end;
474            len
475        };
476
477        Some(Ok((region, vm_addr, chunk_len as usize)))
478    }
479}
480
481impl<'a> DoubleEndedIterator for MemoryChunkIterator<'a> {
482    fn next_back(&mut self) -> Option<Self::Item> {
483        if self.vm_addr_start == self.vm_addr_end {
484            return None;
485        }
486
487        let region = match self.region(self.vm_addr_end.saturating_sub(1)) {
488            Ok(region) => region,
489            Err(e) => {
490                self.vm_addr_start = self.vm_addr_end;
491                return Some(Err(e));
492            }
493        };
494
495        let chunk_len = if region.vm_addr >= self.vm_addr_start {
496            // consume the whole region
497            let len = self.vm_addr_end.saturating_sub(region.vm_addr);
498            self.vm_addr_end = region.vm_addr;
499            len
500        } else {
501            // consume part of the region
502            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
503            self.vm_addr_end = self.vm_addr_start;
504            len
505        };
506
507        Some(Ok((region, self.vm_addr_end, chunk_len as usize)))
508    }
509}
510
#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::arithmetic_side_effects)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        solana_rbpf::{ebpf::MM_PROGRAM_START, program::SBPFVersion},
        test_case::test_case,
    };

    /// Config used by every test: default settings with the aligned memory
    /// mapping turned off.
    fn unaligned_config() -> Config {
        Config {
            aligned_memory_mapping: false,
            ..Config::default()
        }
    }

    /// Builds an SBPF v2 memory mapping over `regions`, panicking on failure.
    fn new_mapping<'a>(regions: Vec<MemoryRegion>, config: &'a Config) -> MemoryMapping<'a> {
        MemoryMapping::new(regions, config, &SBPFVersion::V2).unwrap()
    }

    /// Collects the `(vm_addr, len)` of every successfully yielded chunk,
    /// silently dropping errors.
    fn to_chunk_vec<'a>(
        iter: impl Iterator<Item = Result<(&'a MemoryRegion, u64, usize), Error>>,
    ) -> Vec<(u64, usize)> {
        iter.filter_map(|res| res.ok())
            .map(|(_, vm_addr, len)| (vm_addr, len))
            .collect()
    }

    #[test]
    #[should_panic(expected = "AccessViolation")]
    fn test_memory_chunk_iterator_no_regions() {
        let config = unaligned_config();
        let mapping = new_mapping(vec![], &config);

        let mut chunks = MemoryChunkIterator::new(&mapping, AccessType::Load, 0, 1).unwrap();
        chunks.next().unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "AccessViolation")]
    fn test_memory_chunk_iterator_new_out_of_bounds_upper() {
        let config = unaligned_config();
        let mapping = new_mapping(vec![], &config);

        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, u64::MAX, 1).unwrap();
        chunks.next().unwrap().unwrap();
    }

    #[test]
    fn test_memory_chunk_iterator_out_of_bounds() {
        let config = unaligned_config();
        let program = vec![0xFF; 42];
        let mapping = new_mapping(
            vec![MemoryRegion::new_readonly(&program, MM_PROGRAM_START)],
            &config,
        );

        // check oob at the lower bound on the first next()
        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, MM_PROGRAM_START - 1, 42)
                .unwrap();
        assert_matches!(
            chunks.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_PROGRAM_START - 1
        );

        // check oob at the upper bound. Since the memory mapping isn't empty,
        // this always happens on the second next().
        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, MM_PROGRAM_START, 43).unwrap();
        assert!(chunks.next().unwrap().is_ok());
        assert_matches!(
            chunks.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_PROGRAM_START
        );

        // check oob at the upper bound on the first next_back()
        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, MM_PROGRAM_START, 43)
                .unwrap()
                .rev();
        assert_matches!(
            chunks.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_PROGRAM_START
        );

        // check oob at the upper bound on the 2nd next_back()
        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, MM_PROGRAM_START - 1, 43)
                .unwrap()
                .rev();
        assert!(chunks.next().unwrap().is_ok());
        assert_matches!(
            chunks.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "unknown") if *addr == MM_PROGRAM_START - 1
        );
    }

    #[test]
    fn test_memory_chunk_iterator_one() {
        let config = unaligned_config();
        let program = vec![0xFF; 42];
        let mapping = new_mapping(
            vec![MemoryRegion::new_readonly(&program, MM_PROGRAM_START)],
            &config,
        );

        // check lower bound
        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, MM_PROGRAM_START - 1, 1)
                .unwrap();
        assert!(chunks.next().unwrap().is_err());

        // check upper bound
        let mut chunks =
            MemoryChunkIterator::new(&mapping, AccessType::Load, MM_PROGRAM_START + 42, 1)
                .unwrap();
        assert!(chunks.next().unwrap().is_err());

        // ranges fully inside the single region, iterated in both directions
        let cases = [
            (MM_PROGRAM_START, 0),
            (MM_PROGRAM_START + 42, 0),
            (MM_PROGRAM_START, 1),
            (MM_PROGRAM_START, 42),
            (MM_PROGRAM_START + 41, 1),
        ];
        for (vm_addr, len) in cases {
            for rev in [true, false] {
                let iter =
                    MemoryChunkIterator::new(&mapping, AccessType::Load, vm_addr, len).unwrap();
                let res = if rev {
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };
                if len == 0 {
                    assert_eq!(res, &[]);
                } else {
                    assert_eq!(res, &[(vm_addr, len as usize)]);
                }
            }
        }
    }

    #[test]
    fn test_memory_chunk_iterator_two() {
        let config = unaligned_config();
        let first = vec![0x11; 8];
        let second = vec![0x22; 4];
        let mapping = new_mapping(
            vec![
                MemoryRegion::new_readonly(&first, MM_PROGRAM_START),
                MemoryRegion::new_readonly(&second, MM_PROGRAM_START + 8),
            ],
            &config,
        );

        let cases = [
            (MM_PROGRAM_START, 8, vec![(MM_PROGRAM_START, 8)]),
            (
                MM_PROGRAM_START + 7,
                2,
                vec![(MM_PROGRAM_START + 7, 1), (MM_PROGRAM_START + 8, 1)],
            ),
            (MM_PROGRAM_START + 8, 4, vec![(MM_PROGRAM_START + 8, 4)]),
        ];
        for (vm_addr, len, mut expected) in cases {
            for rev in [false, true] {
                let iter =
                    MemoryChunkIterator::new(&mapping, AccessType::Load, vm_addr, len).unwrap();
                let res = if rev {
                    // the forward pass ran first, so flip the expectation once
                    expected.reverse();
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };

                assert_eq!(res, expected);
            }
        }
    }

    #[test]
    fn test_iter_memory_pair_chunks_short() {
        let config = unaligned_config();
        let first = vec![0x11; 8];
        let second = vec![0x22; 4];
        let mapping = new_mapping(
            vec![
                MemoryRegion::new_readonly(&first, MM_PROGRAM_START),
                MemoryRegion::new_readonly(&second, MM_PROGRAM_START + 8),
            ],
            &config,
        );

        // dst is shorter than src
        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_PROGRAM_START,
                AccessType::Load,
                MM_PROGRAM_START + 8,
                8,
                &mapping,
                false,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 8, "program") if *addr == MM_PROGRAM_START + 8
        );

        // src is shorter than dst
        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_PROGRAM_START + 10,
                AccessType::Load,
                MM_PROGRAM_START + 2,
                3,
                &mapping,
                false,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 3, "program") if *addr == MM_PROGRAM_START + 10
        );
    }

    #[test]
    #[should_panic(expected = "AccessViolation(Store, 4294967296, 4")]
    fn test_memmove_non_contiguous_readonly() {
        let config = unaligned_config();
        let first = vec![0x11; 8];
        let second = vec![0x22; 4];
        let mapping = new_mapping(
            vec![
                MemoryRegion::new_readonly(&first, MM_PROGRAM_START),
                MemoryRegion::new_readonly(&second, MM_PROGRAM_START + 8),
            ],
            &config,
        );

        memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 8, 4, &mapping).unwrap();
    }

    #[test_case(&[], (0, 0, 0); "no regions")]
    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
    #[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
    #[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
    #[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
    #[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
    #[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
    #[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
    #[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
    #[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
    #[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
    fn test_memmove_non_contiguous(
        regions: &[usize],
        (src_offset, dst_offset, len): (usize, usize, usize),
    ) {
        let config = unaligned_config();
        let (mem, mapping) = build_memory_mapping(regions, &config);

        // reference result: the same move done on one flat buffer with ptr::copy
        let mut expected_memory = flatten_memory(&mem);
        unsafe {
            std::ptr::copy(
                expected_memory.as_ptr().add(src_offset),
                expected_memory.as_mut_ptr().add(dst_offset),
                len,
            )
        };

        // the move under test, across the region boundaries
        memmove_non_contiguous(
            MM_PROGRAM_START + dst_offset as u64,
            MM_PROGRAM_START + src_offset as u64,
            len as u64,
            &mapping,
        )
        .unwrap();

        // the regions, flattened after the move, must match the reference
        let memory = flatten_memory(&mem);
        assert_eq!(expected_memory, memory);
    }

    #[test]
    #[should_panic(expected = "AccessViolation(Store, 4294967296, 9")]
    fn test_memset_non_contiguous_readonly() {
        let config = unaligned_config();
        let mut writable = vec![0x11; 8];
        let readonly = vec![0x22; 4];
        let mapping = new_mapping(
            vec![
                MemoryRegion::new_writable(&mut writable, MM_PROGRAM_START),
                MemoryRegion::new_readonly(&readonly, MM_PROGRAM_START + 8),
            ],
            &config,
        );

        assert_eq!(
            memset_non_contiguous(MM_PROGRAM_START, 0x33, 9, &mapping).unwrap(),
            0
        );
    }

    #[test]
    fn test_memset_non_contiguous() {
        let config = unaligned_config();
        let mem1 = vec![0x11; 1];
        let mut mem2 = vec![0x22; 2];
        let mut mem3 = vec![0x33; 3];
        let mut mem4 = vec![0x44; 4];
        let mapping = new_mapping(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_PROGRAM_START),
                MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
                MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
                MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
            ],
            &config,
        );

        assert_eq!(
            memset_non_contiguous(MM_PROGRAM_START + 1, 0x55, 7, &mapping).unwrap(),
            0
        );
        // the read-only first region is outside the range and stays untouched
        assert_eq!(&mem1, &[0x11]);
        assert_eq!(&mem2, &[0x55, 0x55]);
        assert_eq!(&mem3, &[0x55, 0x55, 0x55]);
        // only the first two bytes of the last region fall inside the range
        assert_eq!(&mem4, &[0x55, 0x55, 0x44, 0x44]);
    }

    #[test]
    fn test_memcmp_non_contiguous() {
        let config = unaligned_config();
        let mem1 = b"foo".to_vec();
        let mem2 = b"barbad".to_vec();
        let mem3 = b"foobarbad".to_vec();
        let mapping = new_mapping(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_PROGRAM_START),
                MemoryRegion::new_readonly(&mem2, MM_PROGRAM_START + 3),
                MemoryRegion::new_readonly(&mem3, MM_PROGRAM_START + 9),
            ],
            &config,
        );

        // non contiguous src
        assert_eq!(
            memcmp_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 9, 9, &mapping).unwrap(),
            0
        );

        // non contiguous dst
        assert_eq!(
            memcmp_non_contiguous(MM_PROGRAM_START + 10, MM_PROGRAM_START + 1, 8, &mapping)
                .unwrap(),
            0
        );

        // diff
        assert_eq!(
            memcmp_non_contiguous(MM_PROGRAM_START + 1, MM_PROGRAM_START + 11, 5, &mapping)
                .unwrap(),
            unsafe { memcmp(b"oobar", b"obarb", 5) }
        );
    }

    /// Creates one writable region per entry of `regions` (the entry is the
    /// region length), laid out back to back from `MM_PROGRAM_START`.
    /// Byte `x` of region `i` is initialised to `(i * 10 + x) as u8`.
    /// Returns the backing buffers together with the mapping over them.
    fn build_memory_mapping<'a>(
        regions: &[usize],
        config: &'a Config,
    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
        let mut mapped_regions = vec![];
        let mut buffers = Vec::new();
        let mut offset = 0;
        for (i, region_len) in regions.iter().enumerate() {
            let contents = (0..*region_len)
                .map(|x| (i * 10 + x) as u8)
                .collect::<Vec<_>>();
            buffers.push(contents);
            mapped_regions.push(MemoryRegion::new_writable(
                &mut buffers[i],
                MM_PROGRAM_START + offset as u64,
            ));
            offset += *region_len;
        }

        let mapping = MemoryMapping::new(mapped_regions, config, &SBPFVersion::V2).unwrap();

        (buffers, mapping)
    }

    /// Concatenates all region buffers into one contiguous byte vector.
    fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
        mem.concat()
    }
}