agave_syscalls/
mem_ops.rs

1use {super::*, crate::translate_mut};
2
3fn mem_op_consume(invoke_context: &mut InvokeContext, n: u64) -> Result<(), Error> {
4    let compute_cost = invoke_context.get_execution_cost();
5    let cost = compute_cost.mem_op_base_cost.max(
6        n.checked_div(compute_cost.cpi_bytes_per_unit)
7            .unwrap_or(u64::MAX),
8    );
9    consume_compute_meter(invoke_context, cost)
10}
11
12/// Check that two regions do not overlap.
13pub(crate) fn is_nonoverlapping<N>(src: N, src_len: N, dst: N, dst_len: N) -> bool
14where
15    N: Ord + num_traits::SaturatingSub,
16{
17    // If the absolute distance between the ptrs is at least as big as the size of the other,
18    // they do not overlap.
19    if src > dst {
20        src.saturating_sub(&dst) >= dst_len
21    } else {
22        dst.saturating_sub(&src) >= src_len
23    }
24}
25
26declare_builtin_function!(
27    /// memcpy
28    SyscallMemcpy,
29    fn rust(
30        invoke_context: &mut InvokeContext,
31        dst_addr: u64,
32        src_addr: u64,
33        n: u64,
34        _arg4: u64,
35        _arg5: u64,
36        memory_mapping: &mut MemoryMapping,
37    ) -> Result<u64, Error> {
38        mem_op_consume(invoke_context, n)?;
39
40        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
41            return Err(SyscallError::CopyOverlapping.into());
42        }
43
44        // host addresses can overlap so we always invoke memmove
45        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
46    }
47);
48
49declare_builtin_function!(
50    /// memmove
51    SyscallMemmove,
52    fn rust(
53        invoke_context: &mut InvokeContext,
54        dst_addr: u64,
55        src_addr: u64,
56        n: u64,
57        _arg4: u64,
58        _arg5: u64,
59        memory_mapping: &mut MemoryMapping,
60    ) -> Result<u64, Error> {
61        mem_op_consume(invoke_context, n)?;
62
63        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
64    }
65);
66
declare_builtin_function!(
    /// memcmp
    SyscallMemcmp,
    fn rust(
        invoke_context: &mut InvokeContext,
        s1_addr: u64,
        s2_addr: u64,
        n: u64,
        cmp_result_addr: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping,
    ) -> Result<u64, Error> {
        // Charge compute proportional to the number of bytes compared.
        mem_op_consume(invoke_context, n)?;

        // Translate both guest ranges into read-only host slices of `n` bytes.
        let s1 = translate_slice::<u8>(
            memory_mapping,
            s1_addr,
            n,
            invoke_context.get_check_aligned(),
        )?;
        let s2 = translate_slice::<u8>(
            memory_mapping,
            s2_addr,
            n,
            invoke_context.get_check_aligned(),
        )?;

        debug_assert_eq!(s1.len(), n as usize);
        debug_assert_eq!(s2.len(), n as usize);
        // Safety:
        // memcmp is marked unsafe since it assumes that the inputs are at least
        // `n` bytes long. `s1` and `s2` are guaranteed to be exactly `n` bytes
        // long because `translate_slice` would have failed otherwise.
        let result = unsafe { memcmp(s1, s2, n as usize) };

        // The comparison result is written through an out-pointer in guest
        // memory; the syscall's own return value is always 0 on success.
        translate_mut!(
            memory_mapping,
            invoke_context.get_check_aligned(),
            let cmp_result_ref_mut: &mut i32 = map(cmp_result_addr)?;
        );
        *cmp_result_ref_mut = result;

        Ok(0)
    }
);
112
113declare_builtin_function!(
114    /// memset
115    SyscallMemset,
116    fn rust(
117        invoke_context: &mut InvokeContext,
118        dst_addr: u64,
119        c: u64,
120        n: u64,
121        _arg4: u64,
122        _arg5: u64,
123        memory_mapping: &mut MemoryMapping,
124    ) -> Result<u64, Error> {
125        mem_op_consume(invoke_context, n)?;
126
127        translate_mut!(
128            memory_mapping,
129            invoke_context.get_check_aligned(),
130            let s: &mut [u8] = map(dst_addr, n)?;
131        );
132        s.fill(c as u8);
133        Ok(0)
134    }
135);
136
/// Shared implementation of the memcpy/memmove syscalls: copies `n` bytes
/// from `src_addr` to `dst_addr` in guest address space. Overlapping regions
/// are handled correctly (memmove semantics). Returns 0 on success.
fn memmove(
    invoke_context: &mut InvokeContext,
    dst_addr: u64,
    src_addr: u64,
    n: u64,
    memory_mapping: &mut MemoryMapping,
) -> Result<u64, Error> {
    // Translate the destination range as writable host memory.
    translate_mut!(
        memory_mapping,
        invoke_context.get_check_aligned(),
        let dst_ref_mut: &mut [u8] = map(dst_addr, n)?;
    );
    let dst_ptr = dst_ref_mut.as_mut_ptr();
    // Translate the source range read-only and take its raw pointer.
    let src_ptr = translate_slice::<u8>(
        memory_mapping,
        src_addr,
        n,
        invoke_context.get_check_aligned(),
    )?
    .as_ptr();

    // SAFETY: both translations succeeded, so each pointer refers to a host
    // region of at least `n` valid bytes (writable for dst); `std::ptr::copy`
    // permits the regions to overlap.
    unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
    Ok(0)
}
161
/// Byte-wise comparison with C `memcmp` semantics: returns 0 when the first
/// `n` bytes of both slices are equal, otherwise the difference between the
/// first mismatching pair of bytes (as `i32`).
///
/// Marked unsafe since it assumes that the slices are at least `n` bytes long.
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    let mut i = 0;
    while i < n {
        // Unchecked indexing is sound under this function's safety contract.
        let a = *s1.get_unchecked(i);
        let b = *s2.get_unchecked(i);
        if a != b {
            // Differences of two u8 values always fit in i32; saturating_sub
            // keeps the arithmetic lint-clean.
            return i32::from(a).saturating_sub(i32::from(b));
        }
        i = i.wrapping_add(1);
    }

    0
}
174
#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::arithmetic_side_effects)]
mod tests {
    use super::*;

    #[test]
    fn test_is_nonoverlapping() {
        // Fixed src region [10, 13); slide a 3-byte dst region across it.
        assert!((0..8).all(|dst| is_nonoverlapping(10, 3, dst, 3)));
        assert!((8..13).all(|dst| !is_nonoverlapping(10, 3, dst, 3)));
        assert!((13..20).all(|dst| is_nonoverlapping(10, 3, dst, 3)));
        // Near the top of the u8 range, where subtraction would wrap without
        // saturation.
        assert!(is_nonoverlapping::<u8>(255, 3, 254, 1));
        assert!(!is_nonoverlapping::<u8>(255, 2, 254, 3));
    }
}