Skip to main content

rialo_s_bpf_loader_program/syscalls/
mem_ops.rs

1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3// This file is either (a) original to Subzero Labs, Inc. or (b) derived from the Anza codebase and modified by Subzero Labs, Inc.
4
5use std::slice;
6
7use rialo_s_program_runtime::invoke_context::SerializedAccountMetadata;
8use solana_sbpf::{error::EbpfError, memory_region::MemoryRegion};
9
10use super::*;
11
12fn mem_op_consume(invoke_context: &mut InvokeContext<'_>, n: u64) -> Result<(), Error> {
13    let compute_budget = invoke_context.get_compute_budget();
14    let cost = compute_budget.mem_op_base_cost.max(
15        n.checked_div(compute_budget.cpi_bytes_per_unit)
16            .unwrap_or(u64::MAX),
17    );
18    consume_compute_meter(invoke_context, cost)
19}
20
21declare_builtin_function!(
22    /// memcpy
23    SyscallMemcpy,
24    fn rust(
25        invoke_context: &mut InvokeContext<'_>,
26        dst_addr: u64,
27        src_addr: u64,
28        n: u64,
29        _arg4: u64,
30        _arg5: u64,
31        memory_mapping: &mut MemoryMapping<'_>,
32    ) -> Result<u64, Error> {
33        mem_op_consume(invoke_context, n)?;
34
35        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
36            return Err(SyscallError::CopyOverlapping.into());
37        }
38
39        // host addresses can overlap so we always invoke memmove
40        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
41    }
42);
43
44declare_builtin_function!(
45    /// memmove
46    SyscallMemmove,
47    fn rust(
48        invoke_context: &mut InvokeContext<'_>,
49        dst_addr: u64,
50        src_addr: u64,
51        n: u64,
52        _arg4: u64,
53        _arg5: u64,
54        memory_mapping: &mut MemoryMapping<'_>,
55    ) -> Result<u64, Error> {
56        mem_op_consume(invoke_context, n)?;
57
58        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
59    }
60);
61
62declare_builtin_function!(
63    /// memcmp
64    SyscallMemcmp,
65    fn rust(
66        invoke_context: &mut InvokeContext<'_>,
67        s1_addr: u64,
68        s2_addr: u64,
69        n: u64,
70        cmp_result_addr: u64,
71        _arg5: u64,
72        memory_mapping: &mut MemoryMapping<'_>,
73    ) -> Result<u64, Error> {
74        mem_op_consume(invoke_context, n)?;
75
76        if invoke_context
77            .get_feature_set()
78            .is_active(&rialo_s_feature_set::bpf_account_data_direct_mapping::id())
79        {
80            let cmp_result = translate_type_mut::<i32>(
81                memory_mapping,
82                cmp_result_addr,
83                invoke_context.get_check_aligned(),
84            )?;
85            let syscall_context = invoke_context.get_syscall_context()?;
86
87            *cmp_result = memcmp_non_contiguous(s1_addr, s2_addr, n, &syscall_context.accounts_metadata, memory_mapping, invoke_context.get_check_aligned())?;
88        } else {
89            let s1 = translate_slice::<u8>(
90                memory_mapping,
91                s1_addr,
92                n,
93                invoke_context.get_check_aligned(),
94            )?;
95            let s2 = translate_slice::<u8>(
96                memory_mapping,
97                s2_addr,
98                n,
99                invoke_context.get_check_aligned(),
100            )?;
101            let cmp_result = translate_type_mut::<i32>(
102                memory_mapping,
103                cmp_result_addr,
104                invoke_context.get_check_aligned(),
105            )?;
106
107            debug_assert_eq!(s1.len(), n as usize);
108            debug_assert_eq!(s2.len(), n as usize);
109            // Safety:
110            // memcmp is marked unsafe since it assumes that the inputs are at least
111            // `n` bytes long. `s1` and `s2` are guaranteed to be exactly `n` bytes
112            // long because `translate_slice` would have failed otherwise.
113            *cmp_result = unsafe { memcmp(s1, s2, n as usize) };
114        }
115
116        Ok(0)
117    }
118);
119
120declare_builtin_function!(
121    /// memset
122    SyscallMemset,
123    fn rust(
124        invoke_context: &mut InvokeContext<'_>,
125        dst_addr: u64,
126        c: u64,
127        n: u64,
128        _arg4: u64,
129        _arg5: u64,
130        memory_mapping: &mut MemoryMapping<'_>,
131    ) -> Result<u64, Error> {
132        mem_op_consume(invoke_context, n)?;
133
134        if invoke_context
135            .get_feature_set()
136            .is_active(&rialo_s_feature_set::bpf_account_data_direct_mapping::id())
137        {
138            let syscall_context = invoke_context.get_syscall_context()?;
139
140            memset_non_contiguous(dst_addr, c as u8, n, &syscall_context.accounts_metadata, memory_mapping, invoke_context.get_check_aligned())
141        } else {
142            let s = translate_slice_mut::<u8>(
143                memory_mapping,
144                dst_addr,
145                n,
146                invoke_context.get_check_aligned(),
147            )?;
148            s.fill(c as u8);
149            Ok(0)
150        }
151    }
152);
153
154fn memmove(
155    invoke_context: &mut InvokeContext<'_>,
156    dst_addr: u64,
157    src_addr: u64,
158    n: u64,
159    memory_mapping: &MemoryMapping<'_>,
160) -> Result<u64, Error> {
161    if invoke_context
162        .get_feature_set()
163        .is_active(&rialo_s_feature_set::bpf_account_data_direct_mapping::id())
164    {
165        let syscall_context = invoke_context.get_syscall_context()?;
166
167        memmove_non_contiguous(
168            dst_addr,
169            src_addr,
170            n,
171            &syscall_context.accounts_metadata,
172            memory_mapping,
173            invoke_context.get_check_aligned(),
174        )
175    } else {
176        let dst_ptr = translate_slice_mut::<u8>(
177            memory_mapping,
178            dst_addr,
179            n,
180            invoke_context.get_check_aligned(),
181        )?
182        .as_mut_ptr();
183        let src_ptr = translate_slice::<u8>(
184            memory_mapping,
185            src_addr,
186            n,
187            invoke_context.get_check_aligned(),
188        )?
189        .as_ptr();
190
191        unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
192        Ok(0)
193    }
194}
195
196fn memmove_non_contiguous(
197    dst_addr: u64,
198    src_addr: u64,
199    n: u64,
200    accounts: &[SerializedAccountMetadata],
201    memory_mapping: &MemoryMapping<'_>,
202    resize_area: bool,
203) -> Result<u64, Error> {
204    let reverse = dst_addr.wrapping_sub(src_addr) < n;
205    iter_memory_pair_chunks(
206        AccessType::Load,
207        src_addr,
208        AccessType::Store,
209        dst_addr,
210        n,
211        accounts,
212        memory_mapping,
213        reverse,
214        resize_area,
215        |src_host_addr, dst_host_addr, chunk_len| {
216            unsafe { std::ptr::copy(src_host_addr, dst_host_addr as *mut u8, chunk_len) };
217            Ok(0)
218        },
219    )
220}
221
/// Compares the first `n` bytes of `s1` and `s2`.
///
/// Returns 0 when all `n` byte pairs are equal. Otherwise returns the first
/// mismatching byte of `s1` minus the corresponding byte of `s2`, so the result is
/// in `-255..=255`.
///
/// # Safety
///
/// Both `s1` and `s2` must be at least `n` bytes long: the bytes are read without
/// bounds checks.
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    // Catch contract violations in debug builds instead of reading out of bounds.
    debug_assert!(n <= s1.len() && n <= s2.len());

    for i in 0..n {
        // SAFETY: `i < n`, and the caller guarantees `n <= s1.len()` and
        // `n <= s2.len()`.
        let (a, b) = unsafe { (*s1.get_unchecked(i), *s2.get_unchecked(i)) };
        if a != b {
            // Widening u8 -> i32 is lossless, so the difference cannot overflow.
            return i32::from(a).saturating_sub(i32::from(b));
        }
    }

    0
}
234
235fn memcmp_non_contiguous(
236    src_addr: u64,
237    dst_addr: u64,
238    n: u64,
239    accounts: &[SerializedAccountMetadata],
240    memory_mapping: &MemoryMapping<'_>,
241    resize_area: bool,
242) -> Result<i32, Error> {
243    let memcmp_chunk = |s1_addr, s2_addr, chunk_len| {
244        let res = unsafe {
245            let s1 = slice::from_raw_parts(s1_addr, chunk_len);
246            let s2 = slice::from_raw_parts(s2_addr, chunk_len);
247            // Safety:
248            // memcmp is marked unsafe since it assumes that s1 and s2 are exactly chunk_len
249            // long. The whole point of iter_memory_pair_chunks is to find same length chunks
250            // across two memory regions.
251            memcmp(s1, s2, chunk_len)
252        };
253        if res != 0 {
254            return Err(MemcmpError::Diff(res).into());
255        }
256        Ok(0)
257    };
258    match iter_memory_pair_chunks(
259        AccessType::Load,
260        src_addr,
261        AccessType::Load,
262        dst_addr,
263        n,
264        accounts,
265        memory_mapping,
266        false,
267        resize_area,
268        memcmp_chunk,
269    ) {
270        Ok(res) => Ok(res),
271        Err(error) => match error.downcast_ref() {
272            Some(MemcmpError::Diff(diff)) => Ok(*diff),
273            _ => Err(error),
274        },
275    }
276}
277
/// Error type used by `memcmp_non_contiguous` to carry a non-zero comparison result
/// out of the chunk callback.
#[derive(Debug)]
enum MemcmpError {
    /// The non-zero value returned by `memcmp` for the differing chunk.
    Diff(i32),
}

impl std::fmt::Display for MemcmpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let MemcmpError::Diff(diff) = self;
        write!(f, "memcmp diff: {diff}")
    }
}

// There is no underlying cause, so the trait's default `source()` (which returns
// `None`) is used.
impl std::error::Error for MemcmpError {}
298
299fn memset_non_contiguous(
300    dst_addr: u64,
301    c: u8,
302    n: u64,
303    accounts: &[SerializedAccountMetadata],
304    memory_mapping: &MemoryMapping<'_>,
305    check_aligned: bool,
306) -> Result<u64, Error> {
307    let dst_chunk_iter = MemoryChunkIterator::new(
308        memory_mapping,
309        accounts,
310        AccessType::Store,
311        dst_addr,
312        n,
313        check_aligned,
314    )?;
315    for item in dst_chunk_iter {
316        let (dst_region, dst_vm_addr, dst_len) = item?;
317        let dst_host_addr = Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
318        unsafe { slice::from_raw_parts_mut(dst_host_addr as *mut u8, dst_len).fill(c) }
319    }
320
321    Ok(0)
322}
323
324#[allow(clippy::too_many_arguments)]
325fn iter_memory_pair_chunks<T, F>(
326    src_access: AccessType,
327    src_addr: u64,
328    dst_access: AccessType,
329    dst_addr: u64,
330    n_bytes: u64,
331    accounts: &[SerializedAccountMetadata],
332    memory_mapping: &MemoryMapping<'_>,
333    reverse: bool,
334    resize_area: bool,
335    mut fun: F,
336) -> Result<T, Error>
337where
338    T: Default,
339    F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
340{
341    let mut src_chunk_iter = MemoryChunkIterator::new(
342        memory_mapping,
343        accounts,
344        src_access,
345        src_addr,
346        n_bytes,
347        resize_area,
348    )?;
349    let mut dst_chunk_iter = MemoryChunkIterator::new(
350        memory_mapping,
351        accounts,
352        dst_access,
353        dst_addr,
354        n_bytes,
355        resize_area,
356    )?;
357
358    let mut src_chunk = None;
359    let mut dst_chunk = None;
360
361    macro_rules! memory_chunk {
362        ($chunk_iter:ident, $chunk:ident) => {
363            if let Some($chunk) = &mut $chunk {
364                // Keep processing the current chunk
365                $chunk
366            } else {
367                // This is either the first call or we've processed all the bytes in the current
368                // chunk. Move to the next one.
369                let chunk = match if reverse {
370                    $chunk_iter.next_back()
371                } else {
372                    $chunk_iter.next()
373                } {
374                    Some(item) => item?,
375                    None => break,
376                };
377                $chunk.insert(chunk)
378            }
379        };
380    }
381
382    loop {
383        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
384        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);
385
386        // We always process same-length pairs
387        let chunk_len = *src_remaining.min(dst_remaining);
388
389        let (src_host_addr, dst_host_addr) = {
390            let (src_addr, dst_addr) = if reverse {
391                // When scanning backwards not only we want to scan regions from the end,
392                // we want to process the memory within regions backwards as well.
393                (
394                    src_chunk_addr
395                        .saturating_add(*src_remaining as u64)
396                        .saturating_sub(chunk_len as u64),
397                    dst_chunk_addr
398                        .saturating_add(*dst_remaining as u64)
399                        .saturating_sub(chunk_len as u64),
400                )
401            } else {
402                (*src_chunk_addr, *dst_chunk_addr)
403            };
404
405            (
406                Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
407                Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
408            )
409        };
410
411        fun(
412            src_host_addr as *const u8,
413            dst_host_addr as *const u8,
414            chunk_len,
415        )?;
416
417        // Update how many bytes we have left to scan in each chunk
418        *src_remaining = src_remaining.saturating_sub(chunk_len);
419        *dst_remaining = dst_remaining.saturating_sub(chunk_len);
420
421        if !reverse {
422            // We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
423            // mode we don't do this since we make progress by decreasing src_len and
424            // dst_len.
425            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
426            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
427        }
428
429        if *src_remaining == 0 {
430            src_chunk = None;
431        }
432
433        if *dst_remaining == 0 {
434            dst_chunk = None;
435        }
436    }
437
438    Ok(T::default())
439}
440
441struct MemoryChunkIterator<'a> {
442    memory_mapping: &'a MemoryMapping<'a>,
443    accounts: &'a [SerializedAccountMetadata],
444    access_type: AccessType,
445    initial_vm_addr: u64,
446    vm_addr_start: u64,
447    // exclusive end index (start + len, so one past the last valid address)
448    vm_addr_end: u64,
449    len: u64,
450    account_index: Option<usize>,
451    is_account: Option<bool>,
452    resize_area: bool,
453}
454
455impl<'a> MemoryChunkIterator<'a> {
456    fn new(
457        memory_mapping: &'a MemoryMapping<'_>,
458        accounts: &'a [SerializedAccountMetadata],
459        access_type: AccessType,
460        vm_addr: u64,
461        len: u64,
462        resize_area: bool,
463    ) -> Result<MemoryChunkIterator<'a>, EbpfError> {
464        let vm_addr_end = vm_addr.checked_add(len).ok_or(EbpfError::AccessViolation(
465            access_type,
466            vm_addr,
467            len,
468            "unknown",
469        ))?;
470
471        Ok(MemoryChunkIterator {
472            memory_mapping,
473            accounts,
474            access_type,
475            initial_vm_addr: vm_addr,
476            len,
477            vm_addr_start: vm_addr,
478            vm_addr_end,
479            account_index: None,
480            is_account: None,
481            resize_area,
482        })
483    }
484
485    fn region(&mut self, vm_addr: u64) -> Result<&'a MemoryRegion, Error> {
486        match self.memory_mapping.region(self.access_type, vm_addr) {
487            Ok(region) => Ok(region),
488            Err(error) => match error {
489                EbpfError::AccessViolation(access_type, _vm_addr, _len, name) => Err(Box::new(
490                    EbpfError::AccessViolation(access_type, self.initial_vm_addr, self.len, name),
491                )),
492                EbpfError::StackAccessViolation(access_type, _vm_addr, _len, frame) => {
493                    Err(Box::new(EbpfError::StackAccessViolation(
494                        access_type,
495                        self.initial_vm_addr,
496                        self.len,
497                        frame,
498                    )))
499                }
500                _ => Err(error.into()),
501            },
502        }
503    }
504}
505
506impl<'a> Iterator for MemoryChunkIterator<'a> {
507    type Item = Result<(&'a MemoryRegion, u64, usize), Error>;
508
509    fn next(&mut self) -> Option<Self::Item> {
510        if self.vm_addr_start == self.vm_addr_end {
511            return None;
512        }
513
514        let region = match self.region(self.vm_addr_start) {
515            Ok(region) => region,
516            Err(e) => {
517                self.vm_addr_start = self.vm_addr_end;
518                return Some(Err(e));
519            }
520        };
521
522        let region_is_account;
523
524        let mut account_index = self.account_index.unwrap_or_default();
525        self.account_index = Some(account_index);
526
527        loop {
528            if let Some(account) = self.accounts.get(account_index) {
529                let account_addr = account.vm_data_addr;
530                let resize_addr = account_addr.saturating_add(account.original_data_len as u64);
531
532                if resize_addr < region.vm_addr {
533                    // region is after this account, move on next one
534                    account_index = account_index.saturating_add(1);
535                    self.account_index = Some(account_index);
536                } else {
537                    region_is_account = region.vm_addr == account_addr
538                        // unaligned programs do not have a resize area
539                        || (self.resize_area && region.vm_addr == resize_addr);
540                    break;
541                }
542            } else {
543                // address is after all the accounts
544                region_is_account = false;
545                break;
546            }
547        }
548
549        if let Some(is_account) = self.is_account {
550            if is_account != region_is_account {
551                return Some(Err(SyscallError::InvalidLength.into()));
552            }
553        } else {
554            self.is_account = Some(region_is_account);
555        }
556
557        let vm_addr = self.vm_addr_start;
558
559        let chunk_len = if region.vm_addr_end <= self.vm_addr_end {
560            // consume the whole region
561            let len = region.vm_addr_end.saturating_sub(self.vm_addr_start);
562            self.vm_addr_start = region.vm_addr_end;
563            len
564        } else {
565            // consume part of the region
566            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
567            self.vm_addr_start = self.vm_addr_end;
568            len
569        };
570
571        Some(Ok((region, vm_addr, chunk_len as usize)))
572    }
573}
574
575impl DoubleEndedIterator for MemoryChunkIterator<'_> {
576    fn next_back(&mut self) -> Option<Self::Item> {
577        if self.vm_addr_start == self.vm_addr_end {
578            return None;
579        }
580
581        let region = match self.region(self.vm_addr_end.saturating_sub(1)) {
582            Ok(region) => region,
583            Err(e) => {
584                self.vm_addr_start = self.vm_addr_end;
585                return Some(Err(e));
586            }
587        };
588
589        let region_is_account;
590
591        let mut account_index = self
592            .account_index
593            .unwrap_or_else(|| self.accounts.len().saturating_sub(1));
594        self.account_index = Some(account_index);
595
596        loop {
597            let Some(account) = self.accounts.get(account_index) else {
598                // address is after all the accounts
599                region_is_account = false;
600                break;
601            };
602
603            let account_addr = account.vm_data_addr;
604            let resize_addr = account_addr.saturating_add(account.original_data_len as u64);
605
606            if account_index > 0 && account_addr > region.vm_addr {
607                account_index = account_index.saturating_sub(1);
608
609                self.account_index = Some(account_index);
610            } else {
611                region_is_account = region.vm_addr == account_addr
612                    // unaligned programs do not have a resize area
613                    || (self.resize_area && region.vm_addr == resize_addr);
614                break;
615            }
616        }
617
618        if let Some(is_account) = self.is_account {
619            if is_account != region_is_account {
620                return Some(Err(SyscallError::InvalidLength.into()));
621            }
622        } else {
623            self.is_account = Some(region_is_account);
624        }
625
626        let chunk_len = if region.vm_addr >= self.vm_addr_start {
627            // consume the whole region
628            let len = self.vm_addr_end.saturating_sub(region.vm_addr);
629            self.vm_addr_end = region.vm_addr;
630            len
631        } else {
632            // consume part of the region
633            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
634            self.vm_addr_end = self.vm_addr_start;
635            len
636        };
637
638        Some(Ok((region, self.vm_addr_end, chunk_len as usize)))
639    }
640}
641
642#[cfg(test)]
643#[allow(clippy::indexing_slicing)]
644#[allow(clippy::arithmetic_side_effects)]
645mod tests {
646    use assert_matches::assert_matches;
647    use solana_sbpf::{ebpf::MM_RODATA_START, program::SBPFVersion};
648    use test_case::test_case;
649
650    use super::*;
651
652    fn to_chunk_vec<'a>(
653        iter: impl Iterator<Item = Result<(&'a MemoryRegion, u64, usize), Error>>,
654    ) -> Vec<(u64, usize)> {
655        iter.flat_map(|res| res.map(|(_, vm_addr, len)| (vm_addr, len)))
656            .collect::<Vec<_>>()
657    }
658
659    #[test]
660    #[should_panic(expected = "AccessViolation")]
661    fn test_memory_chunk_iterator_no_regions() {
662        let config = Config {
663            aligned_memory_mapping: false,
664            ..Config::default()
665        };
666        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();
667
668        let mut src_chunk_iter =
669            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, 0, 1, true).unwrap();
670        src_chunk_iter.next().unwrap().unwrap();
671    }
672
673    #[test]
674    #[should_panic(expected = "AccessViolation")]
675    fn test_memory_chunk_iterator_new_out_of_bounds_upper() {
676        let config = Config {
677            aligned_memory_mapping: false,
678            ..Config::default()
679        };
680        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();
681
682        let mut src_chunk_iter =
683            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, u64::MAX, 1, true)
684                .unwrap();
685        src_chunk_iter.next().unwrap().unwrap();
686    }
687
688    #[test]
689    fn test_memory_chunk_iterator_out_of_bounds() {
690        let config = Config {
691            aligned_memory_mapping: false,
692            ..Config::default()
693        };
694        let mem1 = vec![0xFF; 42];
695        let memory_mapping = MemoryMapping::new(
696            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
697            &config,
698            SBPFVersion::V3,
699        )
700        .unwrap();
701
702        // check oob at the lower bound on the first next()
703        let mut src_chunk_iter = MemoryChunkIterator::new(
704            &memory_mapping,
705            &[],
706            AccessType::Load,
707            MM_RODATA_START - 1,
708            42,
709            true,
710        )
711        .unwrap();
712        assert_matches!(
713            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
714            EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_RODATA_START - 1
715        );
716
717        // check oob at the upper bound. Since the memory mapping isn't empty,
718        // this always happens on the second next().
719        let mut src_chunk_iter = MemoryChunkIterator::new(
720            &memory_mapping,
721            &[],
722            AccessType::Load,
723            MM_RODATA_START,
724            43,
725            true,
726        )
727        .unwrap();
728        assert!(src_chunk_iter.next().unwrap().is_ok());
729        assert_matches!(
730            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
731            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
732        );
733
734        // check oob at the upper bound on the first next_back()
735        let mut src_chunk_iter = MemoryChunkIterator::new(
736            &memory_mapping,
737            &[],
738            AccessType::Load,
739            MM_RODATA_START,
740            43,
741            true,
742        )
743        .unwrap()
744        .rev();
745        assert_matches!(
746            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
747            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
748        );
749
750        // check oob at the upper bound on the 2nd next_back()
751        let mut src_chunk_iter = MemoryChunkIterator::new(
752            &memory_mapping,
753            &[],
754            AccessType::Load,
755            MM_RODATA_START - 1,
756            43,
757            true,
758        )
759        .unwrap()
760        .rev();
761        assert!(src_chunk_iter.next().unwrap().is_ok());
762        assert_matches!(
763            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
764            EbpfError::AccessViolation(AccessType::Load, addr, 43, "unknown") if *addr == MM_RODATA_START - 1
765        );
766    }
767
768    #[test]
769    fn test_memory_chunk_iterator_one() {
770        let config = Config {
771            aligned_memory_mapping: false,
772            ..Config::default()
773        };
774        let mem1 = vec![0xFF; 42];
775        let memory_mapping = MemoryMapping::new(
776            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
777            &config,
778            SBPFVersion::V3,
779        )
780        .unwrap();
781
782        // check lower bound
783        let mut src_chunk_iter = MemoryChunkIterator::new(
784            &memory_mapping,
785            &[],
786            AccessType::Load,
787            MM_RODATA_START - 1,
788            1,
789            true,
790        )
791        .unwrap();
792        assert!(src_chunk_iter.next().unwrap().is_err());
793
794        // check upper bound
795        let mut src_chunk_iter = MemoryChunkIterator::new(
796            &memory_mapping,
797            &[],
798            AccessType::Load,
799            MM_RODATA_START + 42,
800            1,
801            true,
802        )
803        .unwrap();
804        assert!(src_chunk_iter.next().unwrap().is_err());
805
806        for (vm_addr, len) in [
807            (MM_RODATA_START, 0),
808            (MM_RODATA_START + 42, 0),
809            (MM_RODATA_START, 1),
810            (MM_RODATA_START, 42),
811            (MM_RODATA_START + 41, 1),
812        ] {
813            for rev in [true, false] {
814                let iter = MemoryChunkIterator::new(
815                    &memory_mapping,
816                    &[],
817                    AccessType::Load,
818                    vm_addr,
819                    len,
820                    true,
821                )
822                .unwrap();
823                let res = if rev {
824                    to_chunk_vec(iter.rev())
825                } else {
826                    to_chunk_vec(iter)
827                };
828                if len == 0 {
829                    assert_eq!(res, &[]);
830                } else {
831                    assert_eq!(res, &[(vm_addr, len as usize)]);
832                }
833            }
834        }
835    }
836
837    #[test]
838    fn test_memory_chunk_iterator_two() {
839        let config = Config {
840            aligned_memory_mapping: false,
841            ..Config::default()
842        };
843        let mem1 = vec![0x11; 8];
844        let mem2 = vec![0x22; 4];
845        let memory_mapping = MemoryMapping::new(
846            vec![
847                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
848                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
849            ],
850            &config,
851            SBPFVersion::V3,
852        )
853        .unwrap();
854
855        for (vm_addr, len, mut expected) in [
856            (MM_RODATA_START, 8, vec![(MM_RODATA_START, 8)]),
857            (
858                MM_RODATA_START + 7,
859                2,
860                vec![(MM_RODATA_START + 7, 1), (MM_RODATA_START + 8, 1)],
861            ),
862            (MM_RODATA_START + 8, 4, vec![(MM_RODATA_START + 8, 4)]),
863        ] {
864            for rev in [false, true] {
865                let iter = MemoryChunkIterator::new(
866                    &memory_mapping,
867                    &[],
868                    AccessType::Load,
869                    vm_addr,
870                    len,
871                    true,
872                )
873                .unwrap();
874                let res = if rev {
875                    expected.reverse();
876                    to_chunk_vec(iter.rev())
877                } else {
878                    to_chunk_vec(iter)
879                };
880
881                assert_eq!(res, expected);
882            }
883        }
884    }
885
886    #[test]
887    fn test_iter_memory_pair_chunks_short() {
888        let config = Config {
889            aligned_memory_mapping: false,
890            ..Config::default()
891        };
892        let mem1 = vec![0x11; 8];
893        let mem2 = vec![0x22; 4];
894        let memory_mapping = MemoryMapping::new(
895            vec![
896                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
897                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
898            ],
899            &config,
900            SBPFVersion::V3,
901        )
902        .unwrap();
903
904        // dst is shorter than src
905        assert_matches!(
906            iter_memory_pair_chunks(
907                AccessType::Load,
908                MM_RODATA_START,
909                AccessType::Load,
910                MM_RODATA_START + 8,
911                8,
912                &[],
913                &memory_mapping,
914                false,
915                true,
916                |_src, _dst, _len| Ok::<_, Error>(0),
917            ).unwrap_err().downcast_ref().unwrap(),
918            EbpfError::AccessViolation(AccessType::Load, addr, 8, "program") if *addr == MM_RODATA_START + 8
919        );
920
921        // src is shorter than dst
922        assert_matches!(
923            iter_memory_pair_chunks(
924                AccessType::Load,
925                MM_RODATA_START + 10,
926                AccessType::Load,
927                MM_RODATA_START + 2,
928                3,
929                &[],
930                &memory_mapping,
931                false,
932                true,
933                |_src, _dst, _len| Ok::<_, Error>(0),
934            ).unwrap_err().downcast_ref().unwrap(),
935            EbpfError::AccessViolation(AccessType::Load, addr, 3, "program") if *addr == MM_RODATA_START + 10
936        );
937    }
938
/// Expects `memmove_non_contiguous` to reject a 4-byte move whose destination
/// is a read-only region. The `unwrap` turns the returned error into the
/// panic that `should_panic` matches on.
#[test]
#[should_panic(expected = "AccessViolation(Store, 4294967296, 4")]
fn test_memmove_non_contiguous_readonly() {
    let config = Config {
        aligned_memory_mapping: false,
        ..Config::default()
    };
    let first_region = vec![0x11; 8];
    let second_region = vec![0x22; 4];
    let regions = vec![
        MemoryRegion::new_readonly(&first_region, MM_RODATA_START),
        MemoryRegion::new_readonly(&second_region, MM_RODATA_START + 8),
    ];
    let memory_mapping = MemoryMapping::new(regions, &config, SBPFVersion::V3).unwrap();

    // Both regions are read-only, so the store into the first one must fail.
    let dst_addr = MM_RODATA_START;
    let src_addr = MM_RODATA_START + 8;
    let outcome = memmove_non_contiguous(dst_addr, src_addr, 4, &[], &memory_mapping, true);
    outcome.unwrap();
}
968
969    #[test_case(&[], (0, 0, 0); "no regions")]
970    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
971    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
972    #[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
973    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
974    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
975    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
976    #[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
977    #[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
978    #[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
979    #[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
980    #[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
981    #[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
982    #[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
983    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
984    #[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
985    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
986    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
987    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
988    fn test_memmove_non_contiguous(
989        regions: &[usize],
990        (src_offset, dst_offset, len): (usize, usize, usize),
991    ) {
992        let config = Config {
993            aligned_memory_mapping: false,
994            ..Config::default()
995        };
996        let (mem, memory_mapping) = build_memory_mapping(regions, &config);
997
998        // flatten the memory so we can memmove it with ptr::copy
999        let mut expected_memory = flatten_memory(&mem);
1000        unsafe {
1001            std::ptr::copy(
1002                expected_memory.as_ptr().add(src_offset),
1003                expected_memory.as_mut_ptr().add(dst_offset),
1004                len,
1005            )
1006        };
1007
1008        // do our memmove
1009        memmove_non_contiguous(
1010            MM_RODATA_START + dst_offset as u64,
1011            MM_RODATA_START + src_offset as u64,
1012            len as u64,
1013            &[],
1014            &memory_mapping,
1015            true,
1016        )
1017        .unwrap();
1018
1019        // flatten memory post our memmove
1020        let memory = flatten_memory(&mem);
1021
1022        // compare libc's memmove with ours
1023        assert_eq!(expected_memory, memory);
1024    }
1025
/// Expects `memset_non_contiguous` to fail when a 9-byte fill starting in an
/// 8-byte writable region runs into the read-only region that follows it.
/// The `unwrap` turns the returned error into the panic that `should_panic`
/// matches on.
#[test]
#[should_panic(expected = "AccessViolation(Store, 4294967296, 9")]
fn test_memset_non_contiguous_readonly() {
    let config = Config {
        aligned_memory_mapping: false,
        ..Config::default()
    };
    let mut writable = vec![0x11; 8];
    let readonly = vec![0x22; 4];
    let regions = vec![
        MemoryRegion::new_writable(&mut writable, MM_RODATA_START),
        MemoryRegion::new_readonly(&readonly, MM_RODATA_START + 8),
    ];
    let memory_mapping = MemoryMapping::new(regions, &config, SBPFVersion::V3).unwrap();

    // The ninth byte falls in the read-only region.
    let outcome = memset_non_contiguous(MM_RODATA_START, 0x33, 9, &[], &memory_mapping, true);
    assert_eq!(outcome.unwrap(), 0);
}
1050
/// Fills 7 bytes with `0x55` starting one byte past `MM_RODATA_START` and
/// checks that the fill spans the three writable regions, leaves the leading
/// read-only byte alone, and stops after the second byte of the last region.
#[test]
fn test_memset_non_contiguous() {
    let config = Config {
        aligned_memory_mapping: false,
        ..Config::default()
    };
    let mem1 = vec![0x11; 1];
    let mut mem2 = vec![0x22; 2];
    let mut mem3 = vec![0x33; 3];
    let mut mem4 = vec![0x44; 4];
    // Regions are laid out back to back: 1 + 2 + 3 + 4 bytes.
    let regions = vec![
        MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
        MemoryRegion::new_writable(&mut mem2, MM_RODATA_START + 1),
        MemoryRegion::new_writable(&mut mem3, MM_RODATA_START + 3),
        MemoryRegion::new_writable(&mut mem4, MM_RODATA_START + 6),
    ];
    let memory_mapping = MemoryMapping::new(regions, &config, SBPFVersion::V3).unwrap();

    let outcome = memset_non_contiguous(MM_RODATA_START + 1, 0x55, 7, &[], &memory_mapping, true);
    assert_eq!(outcome.unwrap(), 0);

    // 7 bytes = all of mem2 (2) + all of mem3 (3) + the first 2 of mem4.
    assert_eq!(mem1, [0x11]);
    assert_eq!(mem2, [0x55, 0x55]);
    assert_eq!(mem3, [0x55, 0x55, 0x55]);
    assert_eq!(mem4, [0x55, 0x55, 0x44, 0x44]);
}
1083
/// Compares byte ranges with `memcmp_non_contiguous` across three adjacent
/// read-only regions holding "foo", "barbad" and "foobarbad", covering a
/// split first operand, a split second operand, and a mismatching pair.
#[test]
fn test_memcmp_non_contiguous() {
    let config = Config {
        aligned_memory_mapping: false,
        ..Config::default()
    };
    let mem1 = b"foo".to_vec();
    let mem2 = b"barbad".to_vec();
    let mem3 = b"foobarbad".to_vec();
    let regions = vec![
        MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
        MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 3),
        MemoryRegion::new_readonly(&mem3, MM_RODATA_START + 9),
    ];
    let memory_mapping = MemoryMapping::new(regions, &config, SBPFVersion::V3).unwrap();

    // non contiguous src: "foo" + "barbad" against the single "foobarbad"
    let split_src = memcmp_non_contiguous(
        MM_RODATA_START,
        MM_RODATA_START + 9,
        9,
        &[],
        &memory_mapping,
        true,
    );
    assert_eq!(split_src.unwrap(), 0);

    // non contiguous dst: the same bytes shifted by one, operands swapped
    let split_dst = memcmp_non_contiguous(
        MM_RODATA_START + 10,
        MM_RODATA_START + 1,
        8,
        &[],
        &memory_mapping,
        true,
    );
    assert_eq!(split_dst.unwrap(), 0);

    // diff: "oobar" (spanning the first two regions) against "obarb"
    let mismatch = memcmp_non_contiguous(
        MM_RODATA_START + 1,
        MM_RODATA_START + 11,
        5,
        &[],
        &memory_mapping,
        true,
    )
    .unwrap();
    // NOTE(review): assumes `memcmp` only requires both slices to hold at
    // least `n` bytes, which is true of these two 5-byte literals; confirm
    // against its definition.
    let reference = unsafe { memcmp(b"oobar", b"obarb", 5) };
    assert_eq!(mismatch, reference);
}
1146
1147    fn build_memory_mapping<'a>(
1148        regions: &[usize],
1149        config: &'a Config,
1150    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
1151        let mut regs = vec![];
1152        let mut mem = Vec::new();
1153        let mut offset = 0;
1154        for (i, region_len) in regions.iter().enumerate() {
1155            mem.push(
1156                (0..*region_len)
1157                    .map(|x| (i * 10 + x) as u8)
1158                    .collect::<Vec<_>>(),
1159            );
1160            regs.push(MemoryRegion::new_writable(
1161                &mut mem[i],
1162                MM_RODATA_START + offset as u64,
1163            ));
1164            offset += *region_len;
1165        }
1166
1167        let memory_mapping = MemoryMapping::new(regs, config, SBPFVersion::V3).unwrap();
1168
1169        (mem, memory_mapping)
1170    }
1171
/// Joins every region buffer, in order, into one contiguous byte vector.
fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
    mem.concat()
}
1175}