use std::slice;

use rialo_s_program_runtime::invoke_context::SerializedAccountMetadata;
use solana_sbpf::{error::EbpfError, memory_region::MemoryRegion};

use super::*;

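/// Consumes compute units for a memory operation over `n` bytes: the larger
/// of the base cost and `n / cpi_bytes_per_unit`, saturating to `u64::MAX`
/// when the divisor is zero.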
fn mem_op_consume(invoke_context: &mut InvokeContext<'_>, n: u64) -> Result<(), Error> {
    let compute_budget = invoke_context.get_compute_budget();
    let cost = compute_budget.mem_op_base_cost.max(
        n.checked_div(compute_budget.cpi_bytes_per_unit)
            .unwrap_or(u64::MAX),
    );
    consume_compute_meter(invoke_context, cost)
}

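// memcpy: copies `n` bytes from `src_addr` to `dst_addr`; overlapping ranges
// are rejected with `SyscallError::CopyOverlapping`.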
declare_builtin_function!(
    SyscallMemcpy,
    fn rust(
        invoke_context: &mut InvokeContext<'_>,
        dst_addr: u64,
        src_addr: u64,
        n: u64,
        _arg4: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping<'_>,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
            return Err(SyscallError::CopyOverlapping.into());
        }

        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
    }
);

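// memmove: copies `n` bytes from `src_addr` to `dst_addr`; the ranges may
// overlap.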
declare_builtin_function!(
    SyscallMemmove,
    fn rust(
        invoke_context: &mut InvokeContext<'_>,
        dst_addr: u64,
        src_addr: u64,
        n: u64,
        _arg4: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping<'_>,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
    }
);

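// memcmp: compares `n` bytes at `s1_addr` and `s2_addr` and writes the
// libc-style result (<0, 0, >0) to `cmp_result_addr`.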
declare_builtin_function!(
    SyscallMemcmp,
    fn rust(
        invoke_context: &mut InvokeContext<'_>,
        s1_addr: u64,
        s2_addr: u64,
        n: u64,
        cmp_result_addr: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping<'_>,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        if invoke_context
            .get_feature_set()
            .is_active(&rialo_s_feature_set::bpf_account_data_direct_mapping::id())
        {
            let cmp_result = translate_type_mut::<i32>(
                memory_mapping,
                cmp_result_addr,
                invoke_context.get_check_aligned(),
            )?;
            let syscall_context = invoke_context.get_syscall_context()?;

            *cmp_result = memcmp_non_contiguous(
                s1_addr,
                s2_addr,
                n,
                &syscall_context.accounts_metadata,
                memory_mapping,
                invoke_context.get_check_aligned(),
            )?;
        } else {
            let s1 = translate_slice::<u8>(
                memory_mapping,
                s1_addr,
                n,
                invoke_context.get_check_aligned(),
            )?;
            let s2 = translate_slice::<u8>(
                memory_mapping,
                s2_addr,
                n,
                invoke_context.get_check_aligned(),
            )?;
            let cmp_result = translate_type_mut::<i32>(
                memory_mapping,
                cmp_result_addr,
                invoke_context.get_check_aligned(),
            )?;

            debug_assert_eq!(s1.len(), n as usize);
            debug_assert_eq!(s2.len(), n as usize);
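            // Safety: `memcmp` requires both inputs to be at least `n` bytes
            // long; `s1` and `s2` were translated with length `n`, as the
            // debug assertions above confirm.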
            *cmp_result = unsafe { memcmp(s1, s2, n as usize) };
        }

        Ok(0)
    }
);

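// memset: fills `n` bytes at `dst_addr` with the byte value `c`.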
declare_builtin_function!(
    SyscallMemset,
    fn rust(
        invoke_context: &mut InvokeContext<'_>,
        dst_addr: u64,
        c: u64,
        n: u64,
        _arg4: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping<'_>,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        if invoke_context
            .get_feature_set()
            .is_active(&rialo_s_feature_set::bpf_account_data_direct_mapping::id())
        {
            let syscall_context = invoke_context.get_syscall_context()?;

            memset_non_contiguous(
                dst_addr,
                c as u8,
                n,
                &syscall_context.accounts_metadata,
                memory_mapping,
                invoke_context.get_check_aligned(),
            )
        } else {
            let s = translate_slice_mut::<u8>(
                memory_mapping,
                dst_addr,
                n,
                invoke_context.get_check_aligned(),
            )?;
            s.fill(c as u8);
            Ok(0)
        }
    }
);

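/// Shared implementation of memcpy and memmove: uses the chunked
/// non-contiguous path when account data direct mapping is active, and a
/// plain `std::ptr::copy` over translated host slices otherwise.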
fn memmove(
    invoke_context: &mut InvokeContext<'_>,
    dst_addr: u64,
    src_addr: u64,
    n: u64,
    memory_mapping: &MemoryMapping<'_>,
) -> Result<u64, Error> {
    if invoke_context
        .get_feature_set()
        .is_active(&rialo_s_feature_set::bpf_account_data_direct_mapping::id())
    {
        let syscall_context = invoke_context.get_syscall_context()?;

        memmove_non_contiguous(
            dst_addr,
            src_addr,
            n,
            &syscall_context.accounts_metadata,
            memory_mapping,
            invoke_context.get_check_aligned(),
        )
    } else {
        let dst_ptr = translate_slice_mut::<u8>(
            memory_mapping,
            dst_addr,
            n,
            invoke_context.get_check_aligned(),
        )?
        .as_mut_ptr();
        let src_ptr = translate_slice::<u8>(
            memory_mapping,
            src_addr,
            n,
            invoke_context.get_check_aligned(),
        )?
        .as_ptr();

        unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
        Ok(0)
    }
}

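/// memmove over potentially non-contiguous memory regions, copying chunk by
/// chunk.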
fn memmove_non_contiguous(
    dst_addr: u64,
    src_addr: u64,
    n: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping<'_>,
    resize_area: bool,
) -> Result<u64, Error> {
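    // If the destination begins inside the source range (i.e.
    // 0 <= dst_addr - src_addr < n in wrapping arithmetic), copy back to
    // front so source bytes are not overwritten before they are read.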
    let reverse = dst_addr.wrapping_sub(src_addr) < n;
    iter_memory_pair_chunks(
        AccessType::Load,
        src_addr,
        AccessType::Store,
        dst_addr,
        n,
        accounts,
        memory_mapping,
        reverse,
        resize_area,
        |src_host_addr, dst_host_addr, chunk_len| {
            unsafe { std::ptr::copy(src_host_addr, dst_host_addr as *mut u8, chunk_len) };
            Ok(0)
        },
    )
}

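/// # Safety
///
/// `s1` and `s2` must be at least `n` bytes long: the comparison loop indexes
/// both slices without bounds checks.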
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    for i in 0..n {
        let a = *s1.get_unchecked(i);
        let b = *s2.get_unchecked(i);
        if a != b {
            return (a as i32).saturating_sub(b as i32);
        }
    }

    0
}

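/// memcmp over potentially non-contiguous memory regions. The first chunk
/// that differs aborts iteration by returning `MemcmpError::Diff`, which is
/// unwrapped back into the final comparison result below.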
fn memcmp_non_contiguous(
    src_addr: u64,
    dst_addr: u64,
    n: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping<'_>,
    resize_area: bool,
) -> Result<i32, Error> {
    let memcmp_chunk = |s1_addr, s2_addr, chunk_len| {
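        // Safety: `iter_memory_pair_chunks` only passes host pointers that
        // are valid for reads of `chunk_len` bytes.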
        let res = unsafe {
            let s1 = slice::from_raw_parts(s1_addr, chunk_len);
            let s2 = slice::from_raw_parts(s2_addr, chunk_len);
            memcmp(s1, s2, chunk_len)
        };
        if res != 0 {
            return Err(MemcmpError::Diff(res).into());
        }
        Ok(0)
    };
    match iter_memory_pair_chunks(
        AccessType::Load,
        src_addr,
        AccessType::Load,
        dst_addr,
        n,
        accounts,
        memory_mapping,
        false,
        resize_area,
        memcmp_chunk,
    ) {
        Ok(res) => Ok(res),
        Err(error) => match error.downcast_ref() {
            Some(MemcmpError::Diff(diff)) => Ok(*diff),
            _ => Err(error),
        },
    }
}

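/// Internal sentinel error used by `memcmp_non_contiguous` to stop chunk
/// iteration as soon as a difference is found; it never escapes to callers.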
#[derive(Debug)]
enum MemcmpError {
    Diff(i32),
}

impl std::fmt::Display for MemcmpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MemcmpError::Diff(diff) => write!(f, "memcmp diff: {diff}"),
        }
    }
}

impl std::error::Error for MemcmpError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            MemcmpError::Diff(_) => None,
        }
    }
}

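/// memset over potentially non-contiguous memory regions, filling each chunk
/// of the destination range in turn.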
fn memset_non_contiguous(
    dst_addr: u64,
    c: u8,
    n: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping<'_>,
    check_aligned: bool,
) -> Result<u64, Error> {
    let dst_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        AccessType::Store,
        dst_addr,
        n,
        check_aligned,
    )?;
    for item in dst_chunk_iter {
        let (dst_region, dst_vm_addr, dst_len) = item?;
        let dst_host_addr = Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
        unsafe { slice::from_raw_parts_mut(dst_host_addr as *mut u8, dst_len).fill(c) }
    }

    Ok(0)
}

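/// Drives a pairwise walk over two address ranges of `n_bytes`, invoking
/// `fun` with matching host pointer pairs. `reverse` walks both ranges from
/// the end, which memmove needs for overlapping copies.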
#[allow(clippy::too_many_arguments)]
fn iter_memory_pair_chunks<T, F>(
    src_access: AccessType,
    src_addr: u64,
    dst_access: AccessType,
    dst_addr: u64,
    n_bytes: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping<'_>,
    reverse: bool,
    resize_area: bool,
    mut fun: F,
) -> Result<T, Error>
where
    T: Default,
    F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
{
    let mut src_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        src_access,
        src_addr,
        n_bytes,
        resize_area,
    )?;
    let mut dst_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        dst_access,
        dst_addr,
        n_bytes,
        resize_area,
    )?;

    let mut src_chunk = None;
    let mut dst_chunk = None;

    macro_rules! memory_chunk {
        ($chunk_iter:ident, $chunk:ident) => {
            if let Some($chunk) = &mut $chunk {
                $chunk
            } else {
                let chunk = match if reverse {
                    $chunk_iter.next_back()
                } else {
                    $chunk_iter.next()
                } {
                    Some(item) => item?,
                    None => break,
                };
                $chunk.insert(chunk)
            }
        };
    }

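    // Walk source and destination chunks in lockstep; each pass hands `fun`
    // the largest span that fits in both current chunks.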
    loop {
        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);

        let chunk_len = *src_remaining.min(dst_remaining);

        let (src_host_addr, dst_host_addr) = {
            let (src_addr, dst_addr) = if reverse {
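                // Reverse copy: address the last `chunk_len` bytes still
                // remaining in each chunk.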
                (
                    src_chunk_addr
                        .saturating_add(*src_remaining as u64)
                        .saturating_sub(chunk_len as u64),
                    dst_chunk_addr
                        .saturating_add(*dst_remaining as u64)
                        .saturating_sub(chunk_len as u64),
                )
            } else {
                (*src_chunk_addr, *dst_chunk_addr)
            };

            (
                Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
                Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
            )
        };

        fun(
            src_host_addr as *const u8,
            dst_host_addr as *const u8,
            chunk_len,
        )?;

        *src_remaining = src_remaining.saturating_sub(chunk_len);
        *dst_remaining = dst_remaining.saturating_sub(chunk_len);

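        // Forward copies advance the chunk base addresses; reverse copies
        // take the tail of the remaining range each pass, so the shrinking
        // counters already encode the progress.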
        if !reverse {
            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
        }

        if *src_remaining == 0 {
            src_chunk = None;
        }

        if *dst_remaining == 0 {
            dst_chunk = None;
        }
    }

    Ok(T::default())
}

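/// Iterator that splits a virtual address range into (region, vm_addr, len)
/// chunks at memory region boundaries. All chunks of one range must be either
/// account regions or non-account regions; mixing the two is an error.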
struct MemoryChunkIterator<'a> {
    memory_mapping: &'a MemoryMapping<'a>,
    accounts: &'a [SerializedAccountMetadata],
    access_type: AccessType,
    initial_vm_addr: u64,
    vm_addr_start: u64,
    vm_addr_end: u64,
    len: u64,
    account_index: Option<usize>,
    is_account: Option<bool>,
    resize_area: bool,
}

impl<'a> MemoryChunkIterator<'a> {
    fn new(
        memory_mapping: &'a MemoryMapping<'_>,
        accounts: &'a [SerializedAccountMetadata],
        access_type: AccessType,
        vm_addr: u64,
        len: u64,
        resize_area: bool,
    ) -> Result<MemoryChunkIterator<'a>, EbpfError> {
        let vm_addr_end = vm_addr.checked_add(len).ok_or(EbpfError::AccessViolation(
            access_type,
            vm_addr,
            len,
            "unknown",
        ))?;

        Ok(MemoryChunkIterator {
            memory_mapping,
            accounts,
            access_type,
            initial_vm_addr: vm_addr,
            len,
            vm_addr_start: vm_addr,
            vm_addr_end,
            account_index: None,
            is_account: None,
            resize_area,
        })
    }

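    /// Looks up the region containing `vm_addr`, rewriting any access
    /// violation to report the iterator's full original range rather than the
    /// single failing address.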
    fn region(&mut self, vm_addr: u64) -> Result<&'a MemoryRegion, Error> {
        match self.memory_mapping.region(self.access_type, vm_addr) {
            Ok(region) => Ok(region),
            Err(error) => match error {
                EbpfError::AccessViolation(access_type, _vm_addr, _len, name) => Err(Box::new(
                    EbpfError::AccessViolation(access_type, self.initial_vm_addr, self.len, name),
                )),
                EbpfError::StackAccessViolation(access_type, _vm_addr, _len, frame) => {
                    Err(Box::new(EbpfError::StackAccessViolation(
                        access_type,
                        self.initial_vm_addr,
                        self.len,
                        frame,
                    )))
                }
                _ => Err(error.into()),
            },
        }
    }
}

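// Forward iteration: yields chunks from the start of the range upward; on
// error the iterator is exhausted, so at most one `Err` is ever returned.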
impl<'a> Iterator for MemoryChunkIterator<'a> {
    type Item = Result<(&'a MemoryRegion, u64, usize), Error>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.vm_addr_start == self.vm_addr_end {
            return None;
        }

        let region = match self.region(self.vm_addr_start) {
            Ok(region) => region,
            Err(e) => {
                self.vm_addr_start = self.vm_addr_end;
                return Some(Err(e));
            }
        };

        let region_is_account;

        let mut account_index = self.account_index.unwrap_or_default();
        self.account_index = Some(account_index);

        loop {
            if let Some(account) = self.accounts.get(account_index) {
                let account_addr = account.vm_data_addr;
                let resize_addr = account_addr.saturating_add(account.original_data_len as u64);

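                // Both this account's data region and its resize region start
                // below the current region; move on to the next account.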
                if resize_addr < region.vm_addr {
                    account_index = account_index.saturating_add(1);
                    self.account_index = Some(account_index);
                } else {
                    region_is_account = region.vm_addr == account_addr
                        || (self.resize_area && region.vm_addr == resize_addr);
                    break;
                }
            } else {
                region_is_account = false;
                break;
            }
        }

        if let Some(is_account) = self.is_account {
            if is_account != region_is_account {
                return Some(Err(SyscallError::InvalidLength.into()));
            }
        } else {
            self.is_account = Some(region_is_account);
        }

        let vm_addr = self.vm_addr_start;

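        // Consume the rest of the current region, or stop at `vm_addr_end`
        // when the requested range ends inside this region.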
        let chunk_len = if region.vm_addr_end <= self.vm_addr_end {
            let len = region.vm_addr_end.saturating_sub(self.vm_addr_start);
            self.vm_addr_start = region.vm_addr_end;
            len
        } else {
            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
            self.vm_addr_start = self.vm_addr_end;
            len
        };

        Some(Ok((region, vm_addr, chunk_len as usize)))
    }
}

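// Backward iteration, used by reverse memmove when the destination overlaps
// the source.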
impl DoubleEndedIterator for MemoryChunkIterator<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.vm_addr_start == self.vm_addr_end {
            return None;
        }

        let region = match self.region(self.vm_addr_end.saturating_sub(1)) {
            Ok(region) => region,
            Err(e) => {
                self.vm_addr_start = self.vm_addr_end;
                return Some(Err(e));
            }
        };

        let region_is_account;

        let mut account_index = self
            .account_index
            .unwrap_or_else(|| self.accounts.len().saturating_sub(1));
        self.account_index = Some(account_index);

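        // Scan accounts from the last one backwards, since reverse iteration
        // visits the highest addresses first.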
        loop {
            let Some(account) = self.accounts.get(account_index) else {
                region_is_account = false;
                break;
            };

            let account_addr = account.vm_data_addr;
            let resize_addr = account_addr.saturating_add(account.original_data_len as u64);

            if account_index > 0 && account_addr > region.vm_addr {
                account_index = account_index.saturating_sub(1);
                self.account_index = Some(account_index);
            } else {
                region_is_account = region.vm_addr == account_addr
                    || (self.resize_area && region.vm_addr == resize_addr);
                break;
            }
        }

        if let Some(is_account) = self.is_account {
            if is_account != region_is_account {
                return Some(Err(SyscallError::InvalidLength.into()));
            }
        } else {
            self.is_account = Some(region_is_account);
        }

        let chunk_len = if region.vm_addr >= self.vm_addr_start {
            let len = self.vm_addr_end.saturating_sub(region.vm_addr);
            self.vm_addr_end = region.vm_addr;
            len
        } else {
            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
            self.vm_addr_end = self.vm_addr_start;
            len
        };

        Some(Ok((region, self.vm_addr_end, chunk_len as usize)))
    }
}

#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::arithmetic_side_effects)]
mod tests {
    use assert_matches::assert_matches;
    use solana_sbpf::{ebpf::MM_RODATA_START, program::SBPFVersion};
    use test_case::test_case;

    use super::*;

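    /// Collapses chunk results into (vm_addr, len) pairs, dropping the region
    /// references and silently skipping any errors.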
    fn to_chunk_vec<'a>(
        iter: impl Iterator<Item = Result<(&'a MemoryRegion, u64, usize), Error>>,
    ) -> Vec<(u64, usize)> {
        iter.flat_map(|res| res.map(|(_, vm_addr, len)| (vm_addr, len)))
            .collect::<Vec<_>>()
    }

    #[test]
    #[should_panic(expected = "AccessViolation")]
    fn test_memory_chunk_iterator_no_regions() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();

        let mut src_chunk_iter =
            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, 0, 1, true).unwrap();
        src_chunk_iter.next().unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "AccessViolation")]
    fn test_memory_chunk_iterator_new_out_of_bounds_upper() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();

        let mut src_chunk_iter =
            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, u64::MAX, 1, true)
                .unwrap();
        src_chunk_iter.next().unwrap().unwrap();
    }

    #[test]
    fn test_memory_chunk_iterator_out_of_bounds() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0xFF; 42];
        let memory_mapping = MemoryMapping::new(
            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            42,
            true,
        )
        .unwrap();
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_RODATA_START - 1
        );

        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START,
            43,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_ok());
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
        );

        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START,
            43,
            true,
        )
        .unwrap()
        .rev();
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
        );

        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            43,
            true,
        )
        .unwrap()
        .rev();
        assert!(src_chunk_iter.next().unwrap().is_ok());
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "unknown") if *addr == MM_RODATA_START - 1
        );
    }

    #[test]
    fn test_memory_chunk_iterator_one() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0xFF; 42];
        let memory_mapping = MemoryMapping::new(
            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            1,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_err());

        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START + 42,
            1,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_err());

        for (vm_addr, len) in [
            (MM_RODATA_START, 0),
            (MM_RODATA_START + 42, 0),
            (MM_RODATA_START, 1),
            (MM_RODATA_START, 42),
            (MM_RODATA_START + 41, 1),
        ] {
            for rev in [true, false] {
                let iter = MemoryChunkIterator::new(
                    &memory_mapping,
                    &[],
                    AccessType::Load,
                    vm_addr,
                    len,
                    true,
                )
                .unwrap();
                let res = if rev {
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };
                if len == 0 {
                    assert_eq!(res, &[]);
                } else {
                    assert_eq!(res, &[(vm_addr, len as usize)]);
                }
            }
        }
    }

    #[test]
    fn test_memory_chunk_iterator_two() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        for (vm_addr, len, mut expected) in [
            (MM_RODATA_START, 8, vec![(MM_RODATA_START, 8)]),
            (
                MM_RODATA_START + 7,
                2,
                vec![(MM_RODATA_START + 7, 1), (MM_RODATA_START + 8, 1)],
            ),
            (MM_RODATA_START + 8, 4, vec![(MM_RODATA_START + 8, 4)]),
        ] {
            for rev in [false, true] {
                let iter = MemoryChunkIterator::new(
                    &memory_mapping,
                    &[],
                    AccessType::Load,
                    vm_addr,
                    len,
                    true,
                )
                .unwrap();
                let res = if rev {
                    expected.reverse();
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };

                assert_eq!(res, expected);
            }
        }
    }

    #[test]
    fn test_iter_memory_pair_chunks_short() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_RODATA_START,
                AccessType::Load,
                MM_RODATA_START + 8,
                8,
                &[],
                &memory_mapping,
                false,
                true,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 8, "program") if *addr == MM_RODATA_START + 8
        );

        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_RODATA_START + 10,
                AccessType::Load,
                MM_RODATA_START + 2,
                3,
                &[],
                &memory_mapping,
                false,
                true,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 3, "program") if *addr == MM_RODATA_START + 10
        );
    }

    #[test]
    #[should_panic(expected = "AccessViolation(Store, 4294967296, 4")]
    fn test_memmove_non_contiguous_readonly() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        memmove_non_contiguous(
            MM_RODATA_START,
            MM_RODATA_START + 8,
            4,
            &[],
            &memory_mapping,
            true,
        )
        .unwrap();
    }

    #[test_case(&[], (0, 0, 0); "no regions")]
    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
    #[test_case(&[10], (0, 0, 10); "single region complete overlap")]
    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
    #[test_case(&[3, 5], (0, 5, 2); "two regions no overlap, single source region")]
    #[test_case(&[4, 7], (0, 5, 5); "two regions no overlap, multiple source regions")]
    #[test_case(&[3, 8], (0, 0, 11); "two regions complete overlap")]
    #[test_case(&[2, 9], (3, 0, 5); "two regions partial overlap start")]
    #[test_case(&[3, 9], (1, 2, 5); "two regions partial overlap middle")]
    #[test_case(&[7, 3], (2, 6, 4); "two regions partial overlap end")]
    #[test_case(&[2, 6, 3, 4], (0, 10, 2); "many regions no overlap, single source region")]
    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4); "many regions no overlap, multiple source regions")]
    #[test_case(&[8, 1, 3, 6], (0, 0, 18); "many regions complete overlap")]
    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8); "many regions overlap start")]
    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8); "many regions overlap middle")]
    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8); "many regions overlap end")]
    fn test_memmove_non_contiguous(
        regions: &[usize],
        (src_offset, dst_offset, len): (usize, usize, usize),
    ) {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let (mem, memory_mapping) = build_memory_mapping(regions, &config);

        let mut expected_memory = flatten_memory(&mem);
        unsafe {
            std::ptr::copy(
                expected_memory.as_ptr().add(src_offset),
                expected_memory.as_mut_ptr().add(dst_offset),
                len,
            )
        };

        memmove_non_contiguous(
            MM_RODATA_START + dst_offset as u64,
            MM_RODATA_START + src_offset as u64,
            len as u64,
            &[],
            &memory_mapping,
            true,
        )
        .unwrap();

        let memory = flatten_memory(&mem);

        assert_eq!(expected_memory, memory);
    }

    #[test]
    #[should_panic(expected = "AccessViolation(Store, 4294967296, 9")]
    fn test_memset_non_contiguous_readonly() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mut mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_writable(&mut mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        assert_eq!(
            memset_non_contiguous(MM_RODATA_START, 0x33, 9, &[], &memory_mapping, true).unwrap(),
            0
        );
    }

    #[test]
    fn test_memset_non_contiguous() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 1];
        let mut mem2 = vec![0x22; 2];
        let mut mem3 = vec![0x33; 3];
        let mut mem4 = vec![0x44; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_writable(&mut mem2, MM_RODATA_START + 1),
                MemoryRegion::new_writable(&mut mem3, MM_RODATA_START + 3),
                MemoryRegion::new_writable(&mut mem4, MM_RODATA_START + 6),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        assert_eq!(
            memset_non_contiguous(MM_RODATA_START + 1, 0x55, 7, &[], &memory_mapping, true)
                .unwrap(),
            0
        );
        assert_eq!(&mem1, &[0x11]);
        assert_eq!(&mem2, &[0x55, 0x55]);
        assert_eq!(&mem3, &[0x55, 0x55, 0x55]);
        assert_eq!(&mem4, &[0x55, 0x55, 0x44, 0x44]);
    }

    #[test]
    fn test_memcmp_non_contiguous() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = b"foo".to_vec();
        let mem2 = b"barbad".to_vec();
        let mem3 = b"foobarbad".to_vec();
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 3),
                MemoryRegion::new_readonly(&mem3, MM_RODATA_START + 9),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        assert_eq!(
            memcmp_non_contiguous(
                MM_RODATA_START,
                MM_RODATA_START + 9,
                9,
                &[],
                &memory_mapping,
                true
            )
            .unwrap(),
            0
        );

        assert_eq!(
            memcmp_non_contiguous(
                MM_RODATA_START + 10,
                MM_RODATA_START + 1,
                8,
                &[],
                &memory_mapping,
                true
            )
            .unwrap(),
            0
        );

        assert_eq!(
            memcmp_non_contiguous(
                MM_RODATA_START + 1,
                MM_RODATA_START + 11,
                5,
                &[],
                &memory_mapping,
                true
            )
            .unwrap(),
            unsafe { memcmp(b"oobar", b"obarb", 5) }
        );
    }

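    /// Builds one writable region per length in `regions`, filled with a
    /// deterministic per-region byte pattern, and returns the backing buffers
    /// together with the mapping.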
    fn build_memory_mapping<'a>(
        regions: &[usize],
        config: &'a Config,
    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
        let mut regs = vec![];
        let mut mem = Vec::new();
        let mut offset = 0;
        for (i, region_len) in regions.iter().enumerate() {
            mem.push(
                (0..*region_len)
                    .map(|x| (i * 10 + x) as u8)
                    .collect::<Vec<_>>(),
            );
            regs.push(MemoryRegion::new_writable(
                &mut mem[i],
                MM_RODATA_START + offset as u64,
            ));
            offset += *region_len;
        }

        let memory_mapping = MemoryMapping::new(regs, config, SBPFVersion::V3).unwrap();

        (mem, memory_mapping)
    }

    fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
        mem.iter().flatten().copied().collect()
    }
}