// ic_stable_memory/primitive/s_box.rs

use crate::encoding::{AsDynSizeBytes, AsFixedSizeBytes};
use crate::mem::s_slice::SSlice;
use crate::primitive::StableType;
use crate::utils::certification::{AsHashTree, AsHashableBytes, HashTree};
use crate::{allocate, deallocate, reallocate, OutOfMemory};
use candid::types::{Serializer, Type, TypeId};
use candid::CandidType;
use serde::{Deserialize, Deserializer};
use std::borrow::Borrow;
use std::cell::UnsafeCell;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::mem::ManuallyDrop;
use std::ops::Deref;
17pub struct SBox<T: AsDynSizeBytes + StableType> {
50 slice: Option<SSlice>,
51 inner: UnsafeCell<Option<T>>,
52 stable_drop_flag: bool,
53}
54
55impl<T: AsDynSizeBytes + StableType> SBox<T> {
56 #[inline]
60 pub fn new(mut it: T) -> Result<Self, T> {
61 let buf = it.as_dyn_size_bytes();
62 if let Ok(slice) = unsafe { allocate(buf.len() as u64) } {
63 unsafe {
64 crate::mem::write_bytes(slice.offset(0), &buf);
65 it.stable_drop_flag_off();
66 }
67
68 Ok(Self {
69 slice: Some(slice),
70 inner: UnsafeCell::new(Some(it)),
71 stable_drop_flag: true,
72 })
73 } else {
74 Err(it)
75 }
76 }
77
78 #[inline]
82 pub fn as_ptr(&self) -> u64 {
83 self.slice.unwrap().as_ptr()
84 }
85
86 #[inline]
88 pub fn into_inner(mut self) -> T {
89 unsafe {
90 self.lazy_read(true);
91 };
92
93 let res = self.inner.get_mut().take().unwrap();
94
95 unsafe {
96 self.stable_drop();
97 self.stable_drop_flag_off();
98 }
99
100 res
101 }
102
103 pub unsafe fn from_ptr(ptr: u64) -> Self {
126 let slice = SSlice::from_ptr(ptr).unwrap();
127
128 Self {
129 stable_drop_flag: false,
130 slice: Some(slice),
131 inner: UnsafeCell::default(),
132 }
133 }
134
135 #[inline]
160 pub fn with<R, F: FnOnce(&mut T) -> R>(&mut self, func: F) -> Result<R, OutOfMemory> {
161 unsafe {
162 self.lazy_read(true);
163
164 let it = self.inner.get_mut().as_mut().unwrap();
165 let res = func(it);
166
167 self.repersist().map(|_| res)
168 }
169 }
170
171 unsafe fn lazy_read(&self, drop_flag: bool) {
172 if let Some(it) = (*self.inner.get()).as_mut() {
173 if drop_flag {
174 it.stable_drop_flag_on();
175 } else {
176 it.stable_drop_flag_off();
177 }
178
179 return;
180 }
181
182 let slice = self.slice.as_ref().unwrap();
183 let mut buf = vec![0u8; slice.get_size_bytes() as usize];
184 unsafe { crate::mem::read_bytes(slice.offset(0), &mut buf) };
185
186 let mut inner = T::from_dyn_size_bytes(&buf);
187 if drop_flag {
188 inner.stable_drop_flag_on();
189 } else {
190 inner.stable_drop_flag_off();
191 }
192
193 *self.inner.get() = Some(inner);
194 }
195
196 fn repersist(&mut self) -> Result<(), OutOfMemory> {
197 let mut slice = self.slice.take().unwrap();
198 let buf = self.inner.get_mut().as_ref().unwrap().as_dyn_size_bytes();
199
200 unsafe { self.inner.get_mut().stable_drop_flag_off() };
201
202 if slice.get_size_bytes() < buf.len() as u64 {
203 match unsafe { reallocate(slice, buf.len() as u64) } {
205 Ok(s) => {
206 slice = s;
207 }
208 Err(e) => {
209 self.slice = Some(slice);
210 return Err(e);
211 }
212 }
213 }
214
215 unsafe { crate::mem::write_bytes(slice.offset(0), &buf) };
216 self.slice = Some(slice);
217
218 Ok(())
219 }
220}
221
222impl<T: AsDynSizeBytes + StableType> AsFixedSizeBytes for SBox<T> {
223 const SIZE: usize = u64::SIZE;
224 type Buf = [u8; u64::SIZE];
225
226 #[inline]
227 fn as_fixed_size_bytes(&self, buf: &mut [u8]) {
228 self.as_ptr().as_fixed_size_bytes(buf)
229 }
230
231 #[inline]
232 fn from_fixed_size_bytes(arr: &[u8]) -> Self {
233 let ptr = u64::from_fixed_size_bytes(arr);
234
235 unsafe { Self::from_ptr(ptr) }
236 }
237}
238
239impl<T: AsDynSizeBytes + StableType> StableType for SBox<T> {
240 #[inline]
241 fn should_stable_drop(&self) -> bool {
242 self.stable_drop_flag
243 }
244
245 #[inline]
246 unsafe fn stable_drop_flag_off(&mut self) {
247 self.stable_drop_flag = false;
248 }
249
250 #[inline]
251 unsafe fn stable_drop_flag_on(&mut self) {
252 self.stable_drop_flag = true;
253 }
254
255 #[inline]
256 unsafe fn stable_drop(&mut self) {
257 deallocate(self.slice.take().unwrap());
258 }
259}
260
261impl<T: AsDynSizeBytes + StableType> Drop for SBox<T> {
262 fn drop(&mut self) {
263 unsafe {
264 if self.should_stable_drop() {
265 self.lazy_read(true);
266 self.stable_drop();
267 }
268 }
269 }
270}
271
272impl<T: AsHashableBytes + AsDynSizeBytes + StableType> AsHashableBytes for SBox<T> {
273 #[inline]
274 fn as_hashable_bytes(&self) -> Vec<u8> {
275 unsafe {
276 self.lazy_read(false);
277
278 (*self.inner.get()).as_ref().unwrap().as_hashable_bytes()
279 }
280 }
281}
282
283impl<T: AsHashTree + AsDynSizeBytes + StableType> AsHashTree for SBox<T> {
284 #[inline]
285 fn root_hash(&self) -> crate::utils::certification::Hash {
286 unsafe {
287 self.lazy_read(false);
288
289 (*self.inner.get()).as_ref().unwrap().root_hash()
290 }
291 }
292
293 #[inline]
294 fn hash_tree(&self) -> HashTree {
295 unsafe {
296 self.lazy_read(false);
297
298 (*self.inner.get()).as_ref().unwrap().hash_tree()
299 }
300 }
301}
302
303impl<T: CandidType + AsDynSizeBytes + StableType> CandidType for SBox<T> {
304 #[inline]
305 fn _ty() -> Type {
306 T::_ty()
307 }
308
309 #[inline]
310 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
311 where
312 S: Serializer,
313 {
314 unsafe {
315 self.lazy_read(false);
316 (*self.inner.get())
317 .as_ref()
318 .unwrap()
319 .idl_serialize(serializer)
320 }
321 }
322}
323
324impl<T: PartialEq + AsDynSizeBytes + StableType> PartialEq for SBox<T> {
325 #[inline]
326 fn eq(&self, other: &Self) -> bool {
327 unsafe {
328 self.lazy_read(false);
329 other.lazy_read(false);
330
331 (*self.inner.get()).eq(&(*other.inner.get()))
332 }
333 }
334}
335
336impl<T: PartialOrd + AsDynSizeBytes + StableType> PartialOrd for SBox<T> {
337 #[inline]
338 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
339 unsafe {
340 self.lazy_read(false);
341 other.lazy_read(false);
342
343 (*self.inner.get()).partial_cmp(&(*other.inner.get()))
344 }
345 }
346}
347
348impl<T: Eq + PartialEq + AsDynSizeBytes + StableType> Eq for SBox<T> {}
349
350impl<T: Ord + PartialOrd + AsDynSizeBytes + StableType> Ord for SBox<T> {
351 #[inline]
352 fn cmp(&self, other: &Self) -> Ordering {
353 unsafe {
354 self.lazy_read(false);
355 other.lazy_read(false);
356
357 (*self.inner.get()).cmp(&(*other.inner.get()))
358 }
359 }
360}
361
362impl<T: Hash + AsDynSizeBytes + StableType> Hash for SBox<T> {
363 #[inline]
364 fn hash<H: Hasher>(&self, state: &mut H) {
365 unsafe {
366 self.lazy_read(false);
367
368 (*self.inner.get()).as_ref().unwrap().hash(state);
369 }
370 }
371}
372
373impl<T: Debug + AsDynSizeBytes + StableType> Debug for SBox<T> {
374 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
375 f.write_str("SBox(")?;
376
377 unsafe {
378 self.lazy_read(false);
379
380 (*self.inner.get()).as_ref().unwrap().fmt(f)?;
381 }
382
383 f.write_str(")")
384 }
385}
386
387impl<T: AsDynSizeBytes + StableType> Borrow<T> for SBox<T> {
388 #[inline]
389 fn borrow(&self) -> &T {
390 unsafe {
391 self.lazy_read(false);
392
393 (*self.inner.get()).as_ref().unwrap()
394 }
395 }
396}
397
398impl<T: AsDynSizeBytes + StableType> Deref for SBox<T> {
399 type Target = T;
400
401 #[inline]
402 fn deref(&self) -> &Self::Target {
403 unsafe {
404 self.lazy_read(false);
405
406 (*self.inner.get()).as_ref().unwrap()
407 }
408 }
409}
410
#[cfg(test)]
mod tests {
    use crate::collections::SVec;
    use crate::primitive::s_box::SBox;
    use crate::{
        _debug_validate_allocator, get_allocated_size, retrieve_custom_data, stable,
        stable_memory_init, store_custom_data,
    };
    use candid::encode_one;
    use std::cmp::Ordering;
    use std::ops::Deref;

    #[test]
    fn sboxes_work_fine() {
        stable::clear();
        stable_memory_init();

        // Dropping an owned box must free its stable allocation.
        {
            let _sbox = SBox::new(100).unwrap();
        }

        _debug_validate_allocator();
        assert_eq!(get_allocated_size(), 0);

        // Nested boxes round-trip through custom data and free everything on drop.
        {
            let sbox = SBox::new(100).unwrap();
            let o_sbox = SBox::new(sbox).unwrap();
            let mut oo_sbox = SBox::new(o_sbox).unwrap();

            store_custom_data(0, oo_sbox);
            oo_sbox = retrieve_custom_data::<SBox<SBox<i32>>>(0).unwrap();
        }

        _debug_validate_allocator();
        assert_eq!(get_allocated_size(), 0);

        // `into_inner` and `with` cooperate without leaking or double-freeing.
        {
            let mut sbox = SBox::new(100).unwrap();
            let mut o_sbox = SBox::new(sbox).unwrap();
            let oo_sbox = SBox::new(o_sbox).unwrap();

            store_custom_data(0, oo_sbox);
            o_sbox = retrieve_custom_data::<SBox<SBox<i32>>>(0)
                .unwrap()
                .into_inner();

            // BUGFIX: `with` returns `Result<_, OutOfMemory>` — assert success
            // instead of silently discarding a possible persistence failure.
            o_sbox.with(|sbox| *sbox = SBox::new(200).unwrap()).unwrap();

            sbox = o_sbox.into_inner();

            assert_eq!(*sbox, 200);
        }

        _debug_validate_allocator();
        assert_eq!(get_allocated_size(), 0);

        // Comparison, hashing-adjacent and formatting traits behave like `T`'s.
        {
            let sbox1 = SBox::new(10).unwrap();
            let sbox11 = SBox::new(10).unwrap();
            let sbox2 = SBox::new(20).unwrap();

            assert_eq!(sbox1.deref(), &10);
            assert_eq!(*sbox1, 10);

            assert!(sbox1 < sbox2);
            assert!(sbox2 > sbox1);
            assert_eq!(sbox1, sbox11);

            println!("{:?}", sbox1);

            let sbox = SBox::<i32>::new(i32::default()).unwrap();
            assert!(matches!(sbox1.cmp(&sbox), Ordering::Greater));
        }

        _debug_validate_allocator();
        assert_eq!(get_allocated_size(), 0);
    }

    #[test]
    fn complex_nested_structures_work_fine() {
        stable::clear();
        stable_memory_init();

        {
            let mut b = SBox::new(Some(SVec::new())).unwrap();

            // BUGFIX: the `Result` of `with` was silently ignored here; an
            // allocation failure would previously go unnoticed.
            b.with(|it: &mut Option<SVec<u64>>| {
                if let Some(v) = it.as_mut() {
                    v.push(10);
                }
            })
            .unwrap();

            assert_eq!(*b.as_ref().unwrap().get(0).unwrap(), 10);

            store_custom_data(0, b);

            b = retrieve_custom_data(0).unwrap();

            assert_eq!(*b.as_ref().unwrap().get(0).unwrap(), 10);

            // BUGFIX: same here — assert the overwrite actually persisted.
            b.with(|it: &mut Option<SVec<u64>>| {
                *it = None;
            })
            .unwrap();
        }

        _debug_validate_allocator();
        assert_eq!(get_allocated_size(), 0);
    }

    #[test]
    fn serialization_works_fine() {
        stable::clear();
        stable_memory_init();

        {
            let b = SBox::new(String::from("test-test")).unwrap();
            let bytes = encode_one(&b).unwrap();

            assert_eq!(bytes.len(), 17);
        }
    }
}