1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
//! This module defines an isomorphic wrapper around raw stable memory API from `ic_cdk` crate.
//!
//! When compiled to wasm, each function simply inlines into a call to the same function of raw API.
//! For example, [MemContext::size_pages()] on wasm simply transforms into [ic_cdk::api::stable::stable64_size()].
//!
//! But when compiled to something else, a stable memory emulation is enabled, which allows all APIs
//! continue to work even when running inside a `cargo test`, allocating stable memory on heap. This
//! emulation is pretty accurate in terms of performance. If your algorithm is 4 times slower in stable
//! memory tests than in heap, then it is pretty likely that it will be about 4 times more expensive in a
//! real canister's stable memory than in its heap.
//!
//! This makes it possible to write full-scale tests which use stable memory as their main memory.

use std::cmp::min;

/// Each wasm memory page is 64K in size.
/// All offsets in this module are expressed in bytes; page counts (e.g. in
/// [MemContext::size_pages] and [MemContext::grow]) are expressed in units of this constant.
pub const PAGE_SIZE_BYTES: u64 = 64 * 1024;

/// Indicates that the canister is out of stable memory at this moment.
/// Returned by [MemContext::grow] when the underlying system refuses to
/// allocate the requested number of pages.
#[derive(Debug, Copy, Clone)]
pub struct OutOfMemory;

/// Abstraction over the 64-bit stable memory API, letting the same code run
/// against real canister stable memory (on wasm) or a heap-backed emulation
/// (on every other target).
pub(crate) trait MemContext {
    /// Current memory size, in 64K wasm pages.
    fn size_pages(&self) -> u64;
    /// Grows the memory by `new_pages` pages; returns the previous page count,
    /// or [OutOfMemory] if the memory cannot be grown.
    fn grow(&mut self, new_pages: u64) -> Result<u64, OutOfMemory>;
    /// Reads `buf.len()` bytes starting at byte `offset` into `buf`.
    fn read(&self, offset: u64, buf: &mut [u8]);
    /// Writes all of `buf` starting at byte `offset`.
    fn write(&mut self, offset: u64, buf: &[u8]);
}

/// Zero-sized context that forwards every [MemContext] call to the raw
/// `ic_cdk` stable memory API (only implemented when compiled for wasm).
#[derive(Clone)]
pub(crate) struct StableMemContext;

#[cfg(target_family = "wasm")]
use ic_cdk::api::stable::{stable64_grow, stable64_read, stable64_size, stable64_write};

#[cfg(target_family = "wasm")]
impl MemContext for StableMemContext {
    // Each method is a thin #[inline] wrapper so the trait call compiles down
    // to the corresponding raw `ic_cdk::api::stable` call.
    #[inline]
    fn size_pages(&self) -> u64 {
        stable64_size()
    }

    #[inline]
    fn grow(&mut self, new_pages: u64) -> Result<u64, OutOfMemory> {
        // `stable64_grow` reports failure via its own error type; the details
        // are discarded and collapsed into the unit-like `OutOfMemory`.
        stable64_grow(new_pages).map_err(|_| OutOfMemory)
    }

    #[inline]
    fn read(&self, offset: u64, buf: &mut [u8]) {
        stable64_read(offset, buf)
    }

    #[inline]
    fn write(&mut self, offset: u64, buf: &[u8]) {
        stable64_write(offset, buf)
    }
}

/// Heap-backed stable memory emulation used on non-wasm targets (e.g. inside
/// `cargo test`).
#[derive(Clone)]
pub(crate) struct TestMemContext {
    // One heap-allocated 64K array per wasm page; the Vec's length is the
    // emulated memory size in pages.
    pub pages: Vec<[u8; PAGE_SIZE_BYTES as usize]>,
}

impl TestMemContext {
    const fn default() -> Self {
        Self { pages: Vec::new() }
    }
}

impl MemContext for TestMemContext {
    #[inline]
    fn size_pages(&self) -> u64 {
        self.pages.len() as u64
    }

    /// Appends `new_pages` zero-filled 64K pages and returns the previous page
    /// count, mirroring the return convention of the wasm grow call. The
    /// heap-backed emulation never refuses an allocation, so this always
    /// returns `Ok`.
    fn grow(&mut self, new_pages: u64) -> Result<u64, OutOfMemory> {
        let prev_pages = self.size_pages();

        for _ in 0..new_pages {
            self.pages.push([0u8; PAGE_SIZE_BYTES as usize]);
        }

        Ok(prev_pages)
    }

    /// Fills `buf` from emulated memory starting at byte `offset`.
    ///
    /// The copy proceeds page by page: each step copies as many bytes as fit
    /// in the remainder of the current page, so buffers that straddle page
    /// boundaries are handled uniformly. This replaces the earlier three-phase
    /// (first page / middle pages / last page) index arithmetic, which also
    /// panicked on a zero-length read positioned exactly at the end of
    /// allocated memory; a zero-length read is now always a no-op.
    ///
    /// Panics if `offset + buf.len()` extends past the allocated memory
    /// (same out-of-bounds behavior as before).
    fn read(&self, offset: u64, buf: &mut [u8]) {
        let mut pos = offset;
        let mut copied = 0usize;

        while copied < buf.len() {
            let page_idx = (pos / PAGE_SIZE_BYTES) as usize;
            let inner_idx = (pos % PAGE_SIZE_BYTES) as usize;
            // Copy up to the end of the current page or up to the end of
            // `buf`, whichever comes first.
            let chunk = min(PAGE_SIZE_BYTES as usize - inner_idx, buf.len() - copied);

            buf[copied..copied + chunk]
                .copy_from_slice(&self.pages[page_idx][inner_idx..inner_idx + chunk]);

            copied += chunk;
            pos += chunk as u64;
        }
    }

    /// Writes all of `buf` into emulated memory starting at byte `offset`.
    ///
    /// Symmetric to [MemContext::read]: one page-chunk loop instead of the
    /// earlier duplicated three-phase arithmetic. Zero-length writes are a
    /// no-op; panics if the target range extends past the allocated memory.
    fn write(&mut self, offset: u64, buf: &[u8]) {
        let mut pos = offset;
        let mut copied = 0usize;

        while copied < buf.len() {
            let page_idx = (pos / PAGE_SIZE_BYTES) as usize;
            let inner_idx = (pos % PAGE_SIZE_BYTES) as usize;
            let chunk = min(PAGE_SIZE_BYTES as usize - inner_idx, buf.len() - copied);

            self.pages[page_idx][inner_idx..inner_idx + chunk]
                .copy_from_slice(&buf[copied..copied + chunk]);

            copied += chunk;
            pos += chunk as u64;
        }
    }
}

#[cfg(target_family = "wasm")]
pub mod stable {
    //! Free-function facade over [StableMemContext]: on wasm, each function
    //! inlines into the corresponding raw stable memory call.

    use crate::utils::mem_context::{MemContext, OutOfMemory, StableMemContext};

    /// Current stable memory size, in 64K pages.
    #[inline]
    pub fn size_pages() -> u64 {
        MemContext::size_pages(&StableMemContext)
    }

    /// Grows stable memory by `new_pages` pages; returns the previous page
    /// count, or [OutOfMemory] on failure.
    #[inline]
    pub fn grow(new_pages: u64) -> Result<u64, OutOfMemory> {
        // StableMemContext is zero-sized, so taking &mut of a fresh temporary
        // is free and carries no state.
        MemContext::grow(&mut StableMemContext, new_pages)
    }

    /// Reads `buf.len()` bytes starting at byte `offset` into `buf`.
    #[inline]
    pub fn read(offset: u64, buf: &mut [u8]) {
        MemContext::read(&StableMemContext, offset, buf)
    }

    /// Writes all of `buf` starting at byte `offset`.
    #[inline]
    pub fn write(offset: u64, buf: &[u8]) {
        MemContext::write(&mut StableMemContext, offset, buf)
    }
}

#[cfg(not(target_family = "wasm"))]
pub mod stable {
    //! Emulated stable memory backend: every call is served by a thread-local
    //! [TestMemContext], so non-wasm builds (e.g. `cargo test`) see the same
    //! API as a real canister.

    use crate::utils::mem_context::{MemContext, OutOfMemory, TestMemContext};
    use std::cell::RefCell;

    thread_local! {
        static CONTEXT: RefCell<TestMemContext> = RefCell::new(TestMemContext::default());
    }

    /// Drops every emulated page, resetting the memory size to 0.
    /// Test-only helper with no counterpart in the wasm backend.
    #[inline]
    pub fn clear() {
        CONTEXT.with(|ctx| ctx.borrow_mut().pages.clear())
    }

    /// Current emulated memory size, in 64K pages.
    #[inline]
    pub fn size_pages() -> u64 {
        CONTEXT.with(|ctx| ctx.borrow().size_pages())
    }

    /// Grows the emulated memory by `new_pages` pages; returns the previous
    /// page count.
    #[inline]
    pub fn grow(new_pages: u64) -> Result<u64, OutOfMemory> {
        CONTEXT.with(|ctx| ctx.borrow_mut().grow(new_pages))
    }

    /// Reads `buf.len()` bytes starting at byte `offset` into `buf`.
    #[inline]
    pub fn read(offset: u64, buf: &mut [u8]) {
        CONTEXT.with(|ctx| ctx.borrow().read(offset, buf))
    }

    /// Writes all of `buf` starting at byte `offset`.
    #[inline]
    pub fn write(offset: u64, buf: &[u8]) {
        CONTEXT.with(|ctx| ctx.borrow_mut().write(offset, buf))
    }
}

#[cfg(test)]
mod tests {
    use crate::{stable, PAGE_SIZE_BYTES};
    use rand::{thread_rng, Rng};

    /// Sequentially writes randomly-sized buffers, each filled with its
    /// iteration number (mod 256), then after every write reads back *all*
    /// buffers written so far and checks their contents.
    #[test]
    fn random_works_fine() {
        for _ in 0..100 {
            stable::clear();
            stable::grow(1000).unwrap();

            let mut rng = thread_rng();
            let iterations = 500usize;
            // Sizes span 0 to just under two pages, exercising empty,
            // single-page, page-straddling and multi-page transfers.
            let size_range = 0..(u16::MAX as usize * 2);

            let mut sizes = Vec::new();
            let mut cur_ptr = 0;
            for i in 0..iterations {
                let size = rng.gen_range(size_range.clone());
                let buf = vec![(i % 256) as u8; size];

                stable::write(cur_ptr, &buf);

                sizes.push(size);
                cur_ptr += size as u64;

                // Verify every buffer written so far, INCLUDING the one from
                // this iteration (the previous `0..i` bound skipped the newest
                // write, leaving the final iteration's write unverified).
                let mut c_ptr = 0u64;
                for (j, &size) in sizes.iter().enumerate() {
                    let mut buf = vec![0u8; size];

                    stable::read(c_ptr, &mut buf);

                    assert_eq!(buf, vec![(j % 256) as u8; size]);

                    c_ptr += size as u64;
                }
            }
        }
    }

    /// One write covering all 10 pages, then one read deliberately offset so
    /// it straddles page boundaries at both ends.
    #[test]
    fn big_reads_writes_work_fine() {
        stable::clear();
        stable::grow(10).unwrap();

        // Heap-allocate the buffers: two ~640K arrays on the test thread's
        // stack risk overflowing the default 2 MiB test stack.
        let buf = vec![10u8; PAGE_SIZE_BYTES as usize * 10];
        stable::write(0, &buf);

        let mut buf1 = vec![0u8; PAGE_SIZE_BYTES as usize * 10 - 50];
        stable::read(25, &mut buf1);

        assert_eq!(buf[25..PAGE_SIZE_BYTES as usize * 10 - 25], buf1[..]);
    }
}