clone_solana_bpf_loader_program/syscalls/
mem_ops.rs

1use {
2    super::*,
3    clone_solana_program_runtime::invoke_context::SerializedAccountMetadata,
4    solana_sbpf::{error::EbpfError, memory_region::MemoryRegion},
5    std::slice,
6};
7
8fn mem_op_consume(invoke_context: &mut InvokeContext, n: u64) -> Result<(), Error> {
9    let compute_budget = invoke_context.get_compute_budget();
10    let cost = compute_budget.mem_op_base_cost.max(
11        n.checked_div(compute_budget.cpi_bytes_per_unit)
12            .unwrap_or(u64::MAX),
13    );
14    consume_compute_meter(invoke_context, cost)
15}
16
17/// Check that two regions do not overlap.
18pub(crate) fn is_nonoverlapping<N>(src: N, src_len: N, dst: N, dst_len: N) -> bool
19where
20    N: Ord + num_traits::SaturatingSub,
21{
22    // If the absolute distance between the ptrs is at least as big as the size of the other,
23    // they do not overlap.
24    if src > dst {
25        src.saturating_sub(&dst) >= dst_len
26    } else {
27        dst.saturating_sub(&src) >= src_len
28    }
29}
30
31declare_builtin_function!(
32    /// memcpy
33    SyscallMemcpy,
34    fn rust(
35        invoke_context: &mut InvokeContext,
36        dst_addr: u64,
37        src_addr: u64,
38        n: u64,
39        _arg4: u64,
40        _arg5: u64,
41        memory_mapping: &mut MemoryMapping,
42    ) -> Result<u64, Error> {
43        mem_op_consume(invoke_context, n)?;
44
45        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
46            return Err(SyscallError::CopyOverlapping.into());
47        }
48
49        // host addresses can overlap so we always invoke memmove
50        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
51    }
52);
53
54declare_builtin_function!(
55    /// memmove
56    SyscallMemmove,
57    fn rust(
58        invoke_context: &mut InvokeContext,
59        dst_addr: u64,
60        src_addr: u64,
61        n: u64,
62        _arg4: u64,
63        _arg5: u64,
64        memory_mapping: &mut MemoryMapping,
65    ) -> Result<u64, Error> {
66        mem_op_consume(invoke_context, n)?;
67
68        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
69    }
70);
71
72declare_builtin_function!(
73    /// memcmp
74    SyscallMemcmp,
75    fn rust(
76        invoke_context: &mut InvokeContext,
77        s1_addr: u64,
78        s2_addr: u64,
79        n: u64,
80        cmp_result_addr: u64,
81        _arg5: u64,
82        memory_mapping: &mut MemoryMapping,
83    ) -> Result<u64, Error> {
84        mem_op_consume(invoke_context, n)?;
85
86        if invoke_context
87            .get_feature_set()
88            .is_active(&clone_agave_feature_set::bpf_account_data_direct_mapping::id())
89        {
90            let cmp_result = translate_type_mut::<i32>(
91                memory_mapping,
92                cmp_result_addr,
93                invoke_context.get_check_aligned(),
94            )?;
95            let syscall_context = invoke_context.get_syscall_context()?;
96
97            *cmp_result = memcmp_non_contiguous(s1_addr, s2_addr, n, &syscall_context.accounts_metadata, memory_mapping, invoke_context.get_check_aligned())?;
98        } else {
99            let s1 = translate_slice::<u8>(
100                memory_mapping,
101                s1_addr,
102                n,
103                invoke_context.get_check_aligned(),
104            )?;
105            let s2 = translate_slice::<u8>(
106                memory_mapping,
107                s2_addr,
108                n,
109                invoke_context.get_check_aligned(),
110            )?;
111            let cmp_result = translate_type_mut::<i32>(
112                memory_mapping,
113                cmp_result_addr,
114                invoke_context.get_check_aligned(),
115            )?;
116
117            debug_assert_eq!(s1.len(), n as usize);
118            debug_assert_eq!(s2.len(), n as usize);
119            // Safety:
120            // memcmp is marked unsafe since it assumes that the inputs are at least
121            // `n` bytes long. `s1` and `s2` are guaranteed to be exactly `n` bytes
122            // long because `translate_slice` would have failed otherwise.
123            *cmp_result = unsafe { memcmp(s1, s2, n as usize) };
124        }
125
126        Ok(0)
127    }
128);
129
130declare_builtin_function!(
131    /// memset
132    SyscallMemset,
133    fn rust(
134        invoke_context: &mut InvokeContext,
135        dst_addr: u64,
136        c: u64,
137        n: u64,
138        _arg4: u64,
139        _arg5: u64,
140        memory_mapping: &mut MemoryMapping,
141    ) -> Result<u64, Error> {
142        mem_op_consume(invoke_context, n)?;
143
144        if invoke_context
145            .get_feature_set()
146            .is_active(&clone_agave_feature_set::bpf_account_data_direct_mapping::id())
147        {
148            let syscall_context = invoke_context.get_syscall_context()?;
149
150            memset_non_contiguous(dst_addr, c as u8, n, &syscall_context.accounts_metadata, memory_mapping, invoke_context.get_check_aligned())
151        } else {
152            let s = translate_slice_mut::<u8>(
153                memory_mapping,
154                dst_addr,
155                n,
156                invoke_context.get_check_aligned(),
157            )?;
158            s.fill(c as u8);
159            Ok(0)
160        }
161    }
162);
163
164fn memmove(
165    invoke_context: &mut InvokeContext,
166    dst_addr: u64,
167    src_addr: u64,
168    n: u64,
169    memory_mapping: &MemoryMapping,
170) -> Result<u64, Error> {
171    if invoke_context
172        .get_feature_set()
173        .is_active(&clone_agave_feature_set::bpf_account_data_direct_mapping::id())
174    {
175        let syscall_context = invoke_context.get_syscall_context()?;
176
177        memmove_non_contiguous(
178            dst_addr,
179            src_addr,
180            n,
181            &syscall_context.accounts_metadata,
182            memory_mapping,
183            invoke_context.get_check_aligned(),
184        )
185    } else {
186        let dst_ptr = translate_slice_mut::<u8>(
187            memory_mapping,
188            dst_addr,
189            n,
190            invoke_context.get_check_aligned(),
191        )?
192        .as_mut_ptr();
193        let src_ptr = translate_slice::<u8>(
194            memory_mapping,
195            src_addr,
196            n,
197            invoke_context.get_check_aligned(),
198        )?
199        .as_ptr();
200
201        unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
202        Ok(0)
203    }
204}
205
206fn memmove_non_contiguous(
207    dst_addr: u64,
208    src_addr: u64,
209    n: u64,
210    accounts: &[SerializedAccountMetadata],
211    memory_mapping: &MemoryMapping,
212    resize_area: bool,
213) -> Result<u64, Error> {
214    let reverse = dst_addr.wrapping_sub(src_addr) < n;
215    iter_memory_pair_chunks(
216        AccessType::Load,
217        src_addr,
218        AccessType::Store,
219        dst_addr,
220        n,
221        accounts,
222        memory_mapping,
223        reverse,
224        resize_area,
225        |src_host_addr, dst_host_addr, chunk_len| {
226            unsafe { std::ptr::copy(src_host_addr, dst_host_addr as *mut u8, chunk_len) };
227            Ok(0)
228        },
229    )
230}
231
/// Compares the first `n` bytes of `s1` and `s2`.
///
/// Returns `0` when they are equal, otherwise the difference `s1[i] - s2[i]` (as `i32`) at the
/// first index where the bytes differ.
///
/// # Safety
///
/// Both slices must be at least `n` bytes long; the bytes are read without bounds checks.
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    (0..n)
        .find_map(|idx| {
            // SAFETY: the caller guarantees both slices hold at least `n` bytes and `idx < n`.
            let (lhs, rhs) = unsafe { (*s1.get_unchecked(idx), *s2.get_unchecked(idx)) };
            (lhs != rhs).then(|| i32::from(lhs).saturating_sub(i32::from(rhs)))
        })
        .unwrap_or(0)
}
244
245fn memcmp_non_contiguous(
246    src_addr: u64,
247    dst_addr: u64,
248    n: u64,
249    accounts: &[SerializedAccountMetadata],
250    memory_mapping: &MemoryMapping,
251    resize_area: bool,
252) -> Result<i32, Error> {
253    let memcmp_chunk = |s1_addr, s2_addr, chunk_len| {
254        let res = unsafe {
255            let s1 = slice::from_raw_parts(s1_addr, chunk_len);
256            let s2 = slice::from_raw_parts(s2_addr, chunk_len);
257            // Safety:
258            // memcmp is marked unsafe since it assumes that s1 and s2 are exactly chunk_len
259            // long. The whole point of iter_memory_pair_chunks is to find same length chunks
260            // across two memory regions.
261            memcmp(s1, s2, chunk_len)
262        };
263        if res != 0 {
264            return Err(MemcmpError::Diff(res).into());
265        }
266        Ok(0)
267    };
268    match iter_memory_pair_chunks(
269        AccessType::Load,
270        src_addr,
271        AccessType::Load,
272        dst_addr,
273        n,
274        accounts,
275        memory_mapping,
276        false,
277        resize_area,
278        memcmp_chunk,
279    ) {
280        Ok(res) => Ok(res),
281        Err(error) => match error.downcast_ref() {
282            Some(MemcmpError::Diff(diff)) => Ok(*diff),
283            _ => Err(error),
284        },
285    }
286}
287
/// Error used by `memcmp_non_contiguous` to stop a chunked comparison early.
#[derive(Debug)]
enum MemcmpError {
    /// The compared chunks differ; carries the non-zero `memcmp` result.
    Diff(i32),
}

impl std::fmt::Display for MemcmpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Irrefutable while `Diff` is the only variant; adding a variant forces a revisit.
        let Self::Diff(diff) = self;
        write!(f, "memcmp diff: {diff}")
    }
}

impl std::error::Error for MemcmpError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // A diff never wraps another error.
        let Self::Diff(_) = self;
        None
    }
}
308
309fn memset_non_contiguous(
310    dst_addr: u64,
311    c: u8,
312    n: u64,
313    accounts: &[SerializedAccountMetadata],
314    memory_mapping: &MemoryMapping,
315    check_aligned: bool,
316) -> Result<u64, Error> {
317    let dst_chunk_iter = MemoryChunkIterator::new(
318        memory_mapping,
319        accounts,
320        AccessType::Store,
321        dst_addr,
322        n,
323        check_aligned,
324    )?;
325    for item in dst_chunk_iter {
326        let (dst_region, dst_vm_addr, dst_len) = item?;
327        let dst_host_addr = Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
328        unsafe { slice::from_raw_parts_mut(dst_host_addr as *mut u8, dst_len).fill(c) }
329    }
330
331    Ok(0)
332}
333
/// Walks the `n_bytes`-long VM ranges starting at `src_addr` and `dst_addr` in lockstep and
/// invokes `fun` on pairs of equally sized host-memory chunks.
///
/// Each range is split into per-region chunks by a [`MemoryChunkIterator`]. Since the two ranges
/// may be split at different offsets, every step processes the shorter of the two current chunks
/// and keeps the remainder of the longer one for the next step.
///
/// * `src_access` / `dst_access` - access type used when looking up regions for each range.
/// * `reverse` - when `true`, chunks are taken from the end of the ranges (`next_back`) and the
///   bytes inside each chunk are handed to `fun` back to front as well.
/// * `resize_area` - forwarded to both chunk iterators.
/// * `fun` - called as `fun(src_host_addr, dst_host_addr, chunk_len)`; an `Err` aborts the walk
///   and is returned to the caller, while `Ok` values are discarded.
///
/// Returns `Ok(T::default())` once either chunk iterator is exhausted.
#[allow(clippy::too_many_arguments)]
fn iter_memory_pair_chunks<T, F>(
    src_access: AccessType,
    src_addr: u64,
    dst_access: AccessType,
    dst_addr: u64,
    n_bytes: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping,
    reverse: bool,
    resize_area: bool,
    mut fun: F,
) -> Result<T, Error>
where
    T: Default,
    F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
{
    let mut src_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        src_access,
        src_addr,
        n_bytes,
        resize_area,
    )
    .map_err(EbpfError::from)?;
    let mut dst_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        dst_access,
        dst_addr,
        n_bytes,
        resize_area,
    )
    .map_err(EbpfError::from)?;

    // The chunk currently being consumed on each side, as (region, vm address, bytes left).
    // `None` means the next chunk still has to be pulled from the iterator.
    let mut src_chunk = None;
    let mut dst_chunk = None;

    // Yields a mutable reference to the current chunk of one side, fetching the next chunk from
    // its iterator when needed. Breaks out of the enclosing loop when the iterator is exhausted
    // and returns early from the function on a translation error.
    macro_rules! memory_chunk {
        ($chunk_iter:ident, $chunk:ident) => {
            if let Some($chunk) = &mut $chunk {
                // Keep processing the current chunk
                $chunk
            } else {
                // This is either the first call or we've processed all the bytes in the current
                // chunk. Move to the next one.
                let chunk = match if reverse {
                    $chunk_iter.next_back()
                } else {
                    $chunk_iter.next()
                } {
                    Some(item) => item?,
                    None => break,
                };
                $chunk.insert(chunk)
            }
        };
    }

    loop {
        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);

        // We always process same-length pairs
        let chunk_len = *src_remaining.min(dst_remaining);

        let (src_host_addr, dst_host_addr) = {
            let (src_addr, dst_addr) = if reverse {
                // When scanning backwards not only we want to scan regions from the end,
                // we want to process the memory within regions backwards as well.
                (
                    src_chunk_addr
                        .saturating_add(*src_remaining as u64)
                        .saturating_sub(chunk_len as u64),
                    dst_chunk_addr
                        .saturating_add(*dst_remaining as u64)
                        .saturating_sub(chunk_len as u64),
                )
            } else {
                (*src_chunk_addr, *dst_chunk_addr)
            };

            // Translate the VM addresses of this pair into host addresses.
            (
                Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
                Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
            )
        };

        // The callback's `Ok` value is dropped; only errors influence the walk.
        fun(
            src_host_addr as *const u8,
            dst_host_addr as *const u8,
            chunk_len,
        )?;

        // Update how many bytes we have left to scan in each chunk
        *src_remaining = src_remaining.saturating_sub(chunk_len);
        *dst_remaining = dst_remaining.saturating_sub(chunk_len);

        if !reverse {
            // We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
            // mode we don't do this since we make progress by decreasing src_len and
            // dst_len.
            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
        }

        // A fully consumed chunk is cleared so the next iteration pulls a fresh one.
        if *src_remaining == 0 {
            src_chunk = None;
        }

        if *dst_remaining == 0 {
            dst_chunk = None;
        }
    }

    Ok(T::default())
}
452
/// Iterator that splits the VM address range `[vm_addr, vm_addr + len)` into chunks, each of
/// which lies inside a single memory region.
///
/// Items are `(region, chunk vm address, chunk length)`. The range can be consumed from the
/// front (`next`) or from the back (`next_back`).
struct MemoryChunkIterator<'a> {
    /// Mapping used to look up the region containing an address.
    memory_mapping: &'a MemoryMapping<'a>,
    /// Account metadata consulted to decide whether a region belongs to an account.
    accounts: &'a [SerializedAccountMetadata],
    /// Access type passed to the region lookup.
    access_type: AccessType,
    /// Start address of the whole request; reported in access violation errors.
    initial_vm_addr: u64,
    /// Start of the part of the range not yet yielded; advanced by `next`.
    vm_addr_start: u64,
    // exclusive end index (start + len, so one past the last valid address)
    vm_addr_end: u64,
    /// Length of the whole request; reported in access violation errors.
    len: u64,
    /// Position in `accounts` reached by the previous call, or `None` before the first call.
    account_index: Option<usize>,
    /// Whether the first region seen was classified as an account region; every later region
    /// must have the same classification or `SyscallError::InvalidLength` is yielded.
    is_account: Option<bool>,
    /// When `true`, a region starting at an account's `vm_data_addr + original_data_len` is
    /// also classified as an account region.
    resize_area: bool,
}
466
467impl<'a> MemoryChunkIterator<'a> {
468    fn new(
469        memory_mapping: &'a MemoryMapping,
470        accounts: &'a [SerializedAccountMetadata],
471        access_type: AccessType,
472        vm_addr: u64,
473        len: u64,
474        resize_area: bool,
475    ) -> Result<MemoryChunkIterator<'a>, EbpfError> {
476        let vm_addr_end = vm_addr.checked_add(len).ok_or(EbpfError::AccessViolation(
477            access_type,
478            vm_addr,
479            len,
480            "unknown",
481        ))?;
482
483        Ok(MemoryChunkIterator {
484            memory_mapping,
485            accounts,
486            access_type,
487            initial_vm_addr: vm_addr,
488            len,
489            vm_addr_start: vm_addr,
490            vm_addr_end,
491            account_index: None,
492            is_account: None,
493            resize_area,
494        })
495    }
496
497    fn region(&mut self, vm_addr: u64) -> Result<&'a MemoryRegion, Error> {
498        match self.memory_mapping.region(self.access_type, vm_addr) {
499            Ok(region) => Ok(region),
500            Err(error) => match error {
501                EbpfError::AccessViolation(access_type, _vm_addr, _len, name) => Err(Box::new(
502                    EbpfError::AccessViolation(access_type, self.initial_vm_addr, self.len, name),
503                )),
504                EbpfError::StackAccessViolation(access_type, _vm_addr, _len, frame) => {
505                    Err(Box::new(EbpfError::StackAccessViolation(
506                        access_type,
507                        self.initial_vm_addr,
508                        self.len,
509                        frame,
510                    )))
511                }
512                _ => Err(error.into()),
513            },
514        }
515    }
516}
517
518impl<'a> Iterator for MemoryChunkIterator<'a> {
519    type Item = Result<(&'a MemoryRegion, u64, usize), Error>;
520
521    fn next(&mut self) -> Option<Self::Item> {
522        if self.vm_addr_start == self.vm_addr_end {
523            return None;
524        }
525
526        let region = match self.region(self.vm_addr_start) {
527            Ok(region) => region,
528            Err(e) => {
529                self.vm_addr_start = self.vm_addr_end;
530                return Some(Err(e));
531            }
532        };
533
534        let region_is_account;
535
536        let mut account_index = self.account_index.unwrap_or_default();
537        self.account_index = Some(account_index);
538
539        loop {
540            if let Some(account) = self.accounts.get(account_index) {
541                let account_addr = account.vm_data_addr;
542                let resize_addr = account_addr.saturating_add(account.original_data_len as u64);
543
544                if resize_addr < region.vm_addr {
545                    // region is after this account, move on next one
546                    account_index = account_index.saturating_add(1);
547                    self.account_index = Some(account_index);
548                } else {
549                    region_is_account = (account.original_data_len != 0 && region.vm_addr == account_addr)
550                        // unaligned programs do not have a resize area
551                        || (self.resize_area && region.vm_addr == resize_addr);
552                    break;
553                }
554            } else {
555                // address is after all the accounts
556                region_is_account = false;
557                break;
558            }
559        }
560
561        if let Some(is_account) = self.is_account {
562            if is_account != region_is_account {
563                return Some(Err(SyscallError::InvalidLength.into()));
564            }
565        } else {
566            self.is_account = Some(region_is_account);
567        }
568
569        let vm_addr = self.vm_addr_start;
570
571        let chunk_len = if region.vm_addr_end <= self.vm_addr_end {
572            // consume the whole region
573            let len = region.vm_addr_end.saturating_sub(self.vm_addr_start);
574            self.vm_addr_start = region.vm_addr_end;
575            len
576        } else {
577            // consume part of the region
578            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
579            self.vm_addr_start = self.vm_addr_end;
580            len
581        };
582
583        Some(Ok((region, vm_addr, chunk_len as usize)))
584    }
585}
586
587impl DoubleEndedIterator for MemoryChunkIterator<'_> {
588    fn next_back(&mut self) -> Option<Self::Item> {
589        if self.vm_addr_start == self.vm_addr_end {
590            return None;
591        }
592
593        let region = match self.region(self.vm_addr_end.saturating_sub(1)) {
594            Ok(region) => region,
595            Err(e) => {
596                self.vm_addr_start = self.vm_addr_end;
597                return Some(Err(e));
598            }
599        };
600
601        let region_is_account;
602
603        let mut account_index = self
604            .account_index
605            .unwrap_or_else(|| self.accounts.len().saturating_sub(1));
606        self.account_index = Some(account_index);
607
608        loop {
609            let Some(account) = self.accounts.get(account_index) else {
610                // address is after all the accounts
611                region_is_account = false;
612                break;
613            };
614
615            let account_addr = account.vm_data_addr;
616            let resize_addr = account_addr.saturating_add(account.original_data_len as u64);
617
618            if account_index > 0 && account_addr > region.vm_addr {
619                account_index = account_index.saturating_sub(1);
620
621                self.account_index = Some(account_index);
622            } else {
623                region_is_account = (account.original_data_len != 0 && region.vm_addr == account_addr)
624                    // unaligned programs do not have a resize area
625                    || (self.resize_area && region.vm_addr == resize_addr);
626                break;
627            }
628        }
629
630        if let Some(is_account) = self.is_account {
631            if is_account != region_is_account {
632                return Some(Err(SyscallError::InvalidLength.into()));
633            }
634        } else {
635            self.is_account = Some(region_is_account);
636        }
637
638        let chunk_len = if region.vm_addr >= self.vm_addr_start {
639            // consume the whole region
640            let len = self.vm_addr_end.saturating_sub(region.vm_addr);
641            self.vm_addr_end = region.vm_addr;
642            len
643        } else {
644            // consume part of the region
645            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
646            self.vm_addr_end = self.vm_addr_start;
647            len
648        };
649
650        Some(Ok((region, self.vm_addr_end, chunk_len as usize)))
651    }
652}
653
654#[cfg(test)]
655#[allow(clippy::indexing_slicing)]
656#[allow(clippy::arithmetic_side_effects)]
657mod tests {
658    use {
659        super::*,
660        assert_matches::assert_matches,
661        solana_sbpf::{ebpf::MM_RODATA_START, program::SBPFVersion},
662        test_case::test_case,
663    };
664
665    fn to_chunk_vec<'a>(
666        iter: impl Iterator<Item = Result<(&'a MemoryRegion, u64, usize), Error>>,
667    ) -> Vec<(u64, usize)> {
668        iter.flat_map(|res| res.map(|(_, vm_addr, len)| (vm_addr, len)))
669            .collect::<Vec<_>>()
670    }
671
672    #[test]
673    #[should_panic(expected = "AccessViolation")]
674    fn test_memory_chunk_iterator_no_regions() {
675        let config = Config {
676            aligned_memory_mapping: false,
677            ..Config::default()
678        };
679        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();
680
681        let mut src_chunk_iter =
682            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, 0, 1, true).unwrap();
683        src_chunk_iter.next().unwrap().unwrap();
684    }
685
686    #[test]
687    #[should_panic(expected = "AccessViolation")]
688    fn test_memory_chunk_iterator_new_out_of_bounds_upper() {
689        let config = Config {
690            aligned_memory_mapping: false,
691            ..Config::default()
692        };
693        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();
694
695        let mut src_chunk_iter =
696            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, u64::MAX, 1, true)
697                .unwrap();
698        src_chunk_iter.next().unwrap().unwrap();
699    }
700
    /// Checks the access violations reported when a request extends below or above a single
    /// 42-byte read-only region, for both forward and reverse iteration. In every case the error
    /// carries the start address and length of the whole request.
    #[test]
    fn test_memory_chunk_iterator_out_of_bounds() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0xFF; 42];
        let memory_mapping = MemoryMapping::new(
            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // check oob at the lower bound on the first next()
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            42,
            true,
        )
        .unwrap();
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_RODATA_START - 1
        );

        // check oob at the upper bound. Since the memory mapping isn't empty,
        // this always happens on the second next().
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START,
            43,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_ok());
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
        );

        // check oob at the upper bound on the first next_back()
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START,
            43,
            true,
        )
        .unwrap()
        .rev();
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
        );

        // check oob at the upper bound on the 2nd next_back()
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            43,
            true,
        )
        .unwrap()
        .rev();
        assert!(src_chunk_iter.next().unwrap().is_ok());
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "unknown") if *addr == MM_RODATA_START - 1
        );
    }
780
    /// Exercises the iterator over a single 42-byte read-only region: requests just outside the
    /// region fail, empty requests yield nothing, and requests inside the region yield exactly
    /// one chunk equal to the request, in both iteration directions.
    #[test]
    fn test_memory_chunk_iterator_one() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0xFF; 42];
        let memory_mapping = MemoryMapping::new(
            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // check lower bound
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            1,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_err());

        // check upper bound
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START + 42,
            1,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_err());

        // (vm_addr, len) requests that must succeed; zero-length requests are valid even at the
        // very end of the region.
        for (vm_addr, len) in [
            (MM_RODATA_START, 0),
            (MM_RODATA_START + 42, 0),
            (MM_RODATA_START, 1),
            (MM_RODATA_START, 42),
            (MM_RODATA_START + 41, 1),
        ] {
            for rev in [true, false] {
                let iter = MemoryChunkIterator::new(
                    &memory_mapping,
                    &[],
                    AccessType::Load,
                    vm_addr,
                    len,
                    true,
                )
                .unwrap();
                let res = if rev {
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };
                if len == 0 {
                    // An empty request produces no chunks at all.
                    assert_eq!(res, &[]);
                } else {
                    assert_eq!(res, &[(vm_addr, len as usize)]);
                }
            }
        }
    }
849
    /// Exercises the iterator over two adjacent read-only regions (8 and 4 bytes): a request
    /// inside one region yields a single chunk, while a request straddling the boundary is split
    /// into one chunk per region. Reverse iteration yields the same chunks in reverse order.
    #[test]
    fn test_memory_chunk_iterator_two() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // Each case is (request address, request length, expected chunks in forward order).
        for (vm_addr, len, mut expected) in [
            (MM_RODATA_START, 8, vec![(MM_RODATA_START, 8)]),
            (
                MM_RODATA_START + 7,
                2,
                vec![(MM_RODATA_START + 7, 1), (MM_RODATA_START + 8, 1)],
            ),
            (MM_RODATA_START + 8, 4, vec![(MM_RODATA_START + 8, 4)]),
        ] {
            // The forward pass runs first, so `expected` is reversed exactly once, for the
            // reverse pass.
            for rev in [false, true] {
                let iter = MemoryChunkIterator::new(
                    &memory_mapping,
                    &[],
                    AccessType::Load,
                    vm_addr,
                    len,
                    true,
                )
                .unwrap();
                let res = if rev {
                    expected.reverse();
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };

                assert_eq!(res, expected);
            }
        }
    }
898
    /// Checks that `iter_memory_pair_chunks` reports an access violation when one of the two
    /// ranges runs past the end of the mapped memory (8 + 4 bytes of adjacent read-only
    /// regions), and that the error identifies the offending range by its start and length.
    #[test]
    fn test_iter_memory_pair_chunks_short() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // dst is shorter than src
        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_RODATA_START,
                AccessType::Load,
                MM_RODATA_START + 8,
                8,
                &[],
                &memory_mapping,
                false,
                true,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 8, "program") if *addr == MM_RODATA_START + 8
        );

        // src is shorter than dst
        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_RODATA_START + 10,
                AccessType::Load,
                MM_RODATA_START + 2,
                3,
                &[],
                &memory_mapping,
                false,
                true,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 3, "program") if *addr == MM_RODATA_START + 10
        );
    }
951
952    #[test]
953    #[should_panic(expected = "AccessViolation(Store, 4294967296, 4")]
954    fn test_memmove_non_contiguous_readonly() {
955        let config = Config {
956            aligned_memory_mapping: false,
957            ..Config::default()
958        };
959        let mem1 = vec![0x11; 8];
960        let mem2 = vec![0x22; 4];
961        let memory_mapping = MemoryMapping::new(
962            vec![
963                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
964                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
965            ],
966            &config,
967            SBPFVersion::V3,
968        )
969        .unwrap();
970
971        memmove_non_contiguous(
972            MM_RODATA_START,
973            MM_RODATA_START + 8,
974            4,
975            &[],
976            &memory_mapping,
977            true,
978        )
979        .unwrap();
980    }
981
982    #[test_case(&[], (0, 0, 0); "no regions")]
983    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
984    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
985    #[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
986    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
987    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
988    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
989    #[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
990    #[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
991    #[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
992    #[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
993    #[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
994    #[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
995    #[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
996    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
997    #[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
998    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
999    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
1000    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
1001    fn test_memmove_non_contiguous(
1002        regions: &[usize],
1003        (src_offset, dst_offset, len): (usize, usize, usize),
1004    ) {
1005        let config = Config {
1006            aligned_memory_mapping: false,
1007            ..Config::default()
1008        };
1009        let (mem, memory_mapping) = build_memory_mapping(regions, &config);
1010
1011        // flatten the memory so we can memmove it with ptr::copy
1012        let mut expected_memory = flatten_memory(&mem);
1013        unsafe {
1014            std::ptr::copy(
1015                expected_memory.as_ptr().add(src_offset),
1016                expected_memory.as_mut_ptr().add(dst_offset),
1017                len,
1018            )
1019        };
1020
1021        // do our memmove
1022        memmove_non_contiguous(
1023            MM_RODATA_START + dst_offset as u64,
1024            MM_RODATA_START + src_offset as u64,
1025            len as u64,
1026            &[],
1027            &memory_mapping,
1028            true,
1029        )
1030        .unwrap();
1031
1032        // flatten memory post our memmove
1033        let memory = flatten_memory(&mem);
1034
1035        // compare libc's memmove with ours
1036        assert_eq!(expected_memory, memory);
1037    }
1038
1039    #[test]
1040    #[should_panic(expected = "AccessViolation(Store, 4294967296, 9")]
1041    fn test_memset_non_contiguous_readonly() {
1042        let config = Config {
1043            aligned_memory_mapping: false,
1044            ..Config::default()
1045        };
1046        let mut mem1 = vec![0x11; 8];
1047        let mem2 = vec![0x22; 4];
1048        let memory_mapping = MemoryMapping::new(
1049            vec![
1050                MemoryRegion::new_writable(&mut mem1, MM_RODATA_START),
1051                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
1052            ],
1053            &config,
1054            SBPFVersion::V3,
1055        )
1056        .unwrap();
1057
1058        assert_eq!(
1059            memset_non_contiguous(MM_RODATA_START, 0x33, 9, &[], &memory_mapping, true).unwrap(),
1060            0
1061        );
1062    }
1063
1064    #[test]
1065    fn test_memset_non_contiguous() {
1066        let config = Config {
1067            aligned_memory_mapping: false,
1068            ..Config::default()
1069        };
1070        let mem1 = vec![0x11; 1];
1071        let mut mem2 = vec![0x22; 2];
1072        let mut mem3 = vec![0x33; 3];
1073        let mut mem4 = vec![0x44; 4];
1074        let memory_mapping = MemoryMapping::new(
1075            vec![
1076                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
1077                MemoryRegion::new_writable(&mut mem2, MM_RODATA_START + 1),
1078                MemoryRegion::new_writable(&mut mem3, MM_RODATA_START + 3),
1079                MemoryRegion::new_writable(&mut mem4, MM_RODATA_START + 6),
1080            ],
1081            &config,
1082            SBPFVersion::V3,
1083        )
1084        .unwrap();
1085
1086        assert_eq!(
1087            memset_non_contiguous(MM_RODATA_START + 1, 0x55, 7, &[], &memory_mapping, true)
1088                .unwrap(),
1089            0
1090        );
1091        assert_eq!(&mem1, &[0x11]);
1092        assert_eq!(&mem2, &[0x55, 0x55]);
1093        assert_eq!(&mem3, &[0x55, 0x55, 0x55]);
1094        assert_eq!(&mem4, &[0x55, 0x55, 0x44, 0x44]);
1095    }
1096
1097    #[test]
1098    fn test_memcmp_non_contiguous() {
1099        let config = Config {
1100            aligned_memory_mapping: false,
1101            ..Config::default()
1102        };
1103        let mem1 = b"foo".to_vec();
1104        let mem2 = b"barbad".to_vec();
1105        let mem3 = b"foobarbad".to_vec();
1106        let memory_mapping = MemoryMapping::new(
1107            vec![
1108                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
1109                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 3),
1110                MemoryRegion::new_readonly(&mem3, MM_RODATA_START + 9),
1111            ],
1112            &config,
1113            SBPFVersion::V3,
1114        )
1115        .unwrap();
1116
1117        // non contiguous src
1118        assert_eq!(
1119            memcmp_non_contiguous(
1120                MM_RODATA_START,
1121                MM_RODATA_START + 9,
1122                9,
1123                &[],
1124                &memory_mapping,
1125                true
1126            )
1127            .unwrap(),
1128            0
1129        );
1130
1131        // non contiguous dst
1132        assert_eq!(
1133            memcmp_non_contiguous(
1134                MM_RODATA_START + 10,
1135                MM_RODATA_START + 1,
1136                8,
1137                &[],
1138                &memory_mapping,
1139                true
1140            )
1141            .unwrap(),
1142            0
1143        );
1144
1145        // diff
1146        assert_eq!(
1147            memcmp_non_contiguous(
1148                MM_RODATA_START + 1,
1149                MM_RODATA_START + 11,
1150                5,
1151                &[],
1152                &memory_mapping,
1153                true
1154            )
1155            .unwrap(),
1156            unsafe { memcmp(b"oobar", b"obarb", 5) }
1157        );
1158    }
1159
1160    fn build_memory_mapping<'a>(
1161        regions: &[usize],
1162        config: &'a Config,
1163    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
1164        let mut regs = vec![];
1165        let mut mem = Vec::new();
1166        let mut offset = 0;
1167        for (i, region_len) in regions.iter().enumerate() {
1168            mem.push(
1169                (0..*region_len)
1170                    .map(|x| (i * 10 + x) as u8)
1171                    .collect::<Vec<_>>(),
1172            );
1173            regs.push(MemoryRegion::new_writable(
1174                &mut mem[i],
1175                MM_RODATA_START + offset as u64,
1176            ));
1177            offset += *region_len;
1178        }
1179
1180        let memory_mapping = MemoryMapping::new(regs, config, SBPFVersion::V3).unwrap();
1181
1182        (mem, memory_mapping)
1183    }
1184
1185    fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
1186        mem.iter().flatten().copied().collect()
1187    }
1188
1189    #[test]
1190    fn test_is_nonoverlapping() {
1191        for dst in 0..8 {
1192            assert!(is_nonoverlapping(10, 3, dst, 3));
1193        }
1194        for dst in 8..13 {
1195            assert!(!is_nonoverlapping(10, 3, dst, 3));
1196        }
1197        for dst in 13..20 {
1198            assert!(is_nonoverlapping(10, 3, dst, 3));
1199        }
1200        assert!(is_nonoverlapping::<u8>(255, 3, 254, 1));
1201        assert!(!is_nonoverlapping::<u8>(255, 2, 254, 3));
1202    }
1203}