Skip to main content

midnight_storage/delta_tracking/
rcmap.rs

1// This file is part of midnight-ledger.
2// Copyright (C) 2025 Midnight Foundation
3// SPDX-License-Identifier: Apache-2.0
4// Licensed under the Apache License, Version 2.0 (the "License");
5// You may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! Reference count map for tracking charged keys in write and delete costing
15use crate::Storable;
16use crate::arena::{ArenaHash, ArenaKey, Opaque, Sp};
17use crate::db::DB;
18use crate::storable::Loader;
19use crate::storage::{Map, default_storage};
20use crate::{self as storage, DefaultDB};
21use derive_where::derive_where;
22use rand::distributions::{Distribution, Standard};
23use serialize::{Deserializable, Serializable, Tagged};
24#[cfg(test)]
25use std::collections::BTreeMap;
26use std::collections::BTreeSet;
27#[cfg(feature = "proptest")]
28use {proptest::prelude::Arbitrary, serialize::NoStrategy, std::marker::PhantomData};
29
30/// A wrapper around `ArenaKey` that ensures the referenced node is persisted.
31///
32/// When stored in the arena, `ArenaKey` reports the wrapped key as its child,
33/// which causes the back-end to keep the referenced node alive as long as the
34/// `ChildRef`.
35//
36// NOTE: Long-term, it would be nice if this could be a wrapper around `Sp<dyn Any>` instead of
37// around an arena key. This would be a safer alternative, as we would not need to make an
38// assumption that the child is allocated in the backend on construction.
39#[derive_where(Debug, PartialEq, Eq)]
40pub struct ChildRef<D: DB> {
41    /// The referenced child
42    pub child: ArenaKey<D::Hasher>,
43}
44
45// NOTE: This used to not be necessary, as creating an Sp of the ref would guarnatee allocation in
46// the backend. With the small nodes optimisation, this is no longer guaranteed, as the backend is
47// only invoked when a parent that isn't a small node is instantiated.
48//
49// However, if the referenced node(s) aren't in the backend, the ref doesn't do its job of keeping
50// these allocated. Therefore, we manually increment its ref count on allocation, and decrement it
51// on deallocation, using the backend `persist`/`unpersist` methods. Note that these are part of
52// what happens during (non-small node) Sp allocation, so this is only replicating a subset of this
53// behaviour. (Technically those are refcount updates instead of persist/unpersist, but the latter
54// are just thin wrappers around refcount updates)
55impl<D: DB> ChildRef<D> {
56    /// Creates a new reference
57    pub fn new(child: ArenaKey<D::Hasher>) -> Self {
58        // this *will* panic if `child` is not already allocated.
59        default_storage::<D>().with_backend(|b| child.refs().iter().for_each(|r| b.persist(r)));
60        Self { child }
61    }
62}
63
64impl<D: DB> Clone for ChildRef<D> {
65    fn clone(&self) -> Self {
66        ChildRef::new(self.child.clone())
67    }
68}
69
70impl<D: DB> Drop for ChildRef<D> {
71    fn drop(&mut self) {
72        default_storage::<D>()
73            .with_backend(|b| self.child.refs().iter().for_each(|r| b.unpersist(r)));
74    }
75}
76
77impl<D: DB> Storable<D> for ChildRef<D> {
78    fn children(&self) -> std::vec::Vec<ArenaKey<D::Hasher>> {
79        vec![self.child.clone()]
80    }
81
82    fn to_binary_repr<W: std::io::Write>(&self, _writer: &mut W) -> Result<(), std::io::Error>
83    where
84        Self: Sized,
85    {
86        Ok(())
87    }
88
89    fn from_binary_repr<R: std::io::Read>(
90        reader: &mut R,
91        children: &mut impl Iterator<Item = ArenaKey<D::Hasher>>,
92        loader: &impl Loader<D>,
93    ) -> Result<Self, std::io::Error>
94    where
95        Self: Sized,
96    {
97        let mut children = children.collect::<Vec<_>>();
98        let mut data = Vec::new();
99        reader.read_to_end(&mut data)?;
100        if children.len() == 1 && data.is_empty() {
101            let child = children.pop().expect("must be present");
102            let mut sp: Sp<Opaque<D>, D> = loader.get(&child)?;
103            sp.persist();
104            let child_ref = Self::new(child);
105            sp.unpersist();
106            Ok(child_ref)
107        } else {
108            Err(std::io::Error::new(
109                std::io::ErrorKind::InvalidData,
110                "Ref should have exactly one child and no data",
111            ))
112        }
113    }
114}
115
116impl<D: DB> Serializable for ChildRef<D> {
117    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
118        self.child.serialize(writer)
119    }
120
121    fn serialized_size(&self) -> usize {
122        self.child.serialized_size()
123    }
124}
125
126impl<D: DB> Deserializable for ChildRef<D> {
127    fn deserialize(reader: &mut impl std::io::Read, recursive_depth: u32) -> std::io::Result<Self> {
128        ArenaKey::<D::Hasher>::deserialize(reader, recursive_depth).map(ChildRef::new)
129    }
130}
131
132impl<D: DB> Distribution<ChildRef<D>> for Standard {
133    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ChildRef<D> {
134        ChildRef::new(ArenaKey::Ref(rng.r#gen()))
135    }
136}
137
138// Manual impl because we don't derive Storable
139impl<D: DB> Tagged for ChildRef<D> {
140    fn tag() -> std::borrow::Cow<'static, str> {
141        "childref[v1]".into()
142    }
143    fn tag_unique_factor() -> String {
144        "children[v1]".into()
145    }
146}
147
148/// Reference count map for tracking charged keys in write and delete costing.
149///
150/// Internally we use `ChildRef` to ensure that nodes for all keys in the `RcMap`
151/// will be persisted as long a the `RcMap` itself is.
152#[derive_where(Debug, Clone, PartialEq, Eq)]
153#[derive(Storable)]
154//#[derive(serde::Serialize, serde::Deserialize, Storable)]
155//#[serde(bound(serialize = "", deserialize = ""))]
156#[storable(db = D)]
157#[tag = "rcmap[v1]"]
158pub struct RcMap<D: DB = DefaultDB> {
159    /// Reference counts for keys with `rc >= 1`
160    #[cfg(feature = "public-internal-structure")]
161    pub rc_ge_1: Map<ArenaHash<D::Hasher>, u64, D>,
162    #[cfg(not(feature = "public-internal-structure"))]
163    rc_ge_1: Map<ArenaHash<D::Hasher>, u64, D>,
164    /// Keys with reference count zero, for efficient garbage collection.
165    ///
166    /// The `ChildRef` here creates storage overhead -- an additional dag node for
167    /// each key -- but the `rc_0` map is expected to be small, so this
168    /// shouldn't matter.
169    #[cfg(feature = "public-internal-structure")]
170    pub rc_0: Map<ArenaKey<D::Hasher>, ChildRef<D>, D>,
171    #[cfg(not(feature = "public-internal-structure"))]
172    rc_0: Map<ArenaKey<D::Hasher>, ChildRef<D>, D>,
173}
174
175impl<D: DB> RcMap<D> {
176    /// Returns true iff the key is charged.
177    pub(crate) fn contains(&self, key: &ArenaKey<D::Hasher>) -> bool {
178        self.get_rc(key).is_some()
179    }
180
181    /// Get the current reference count for a key.
182    /// Returns Some(n) if key is charged (n >= 0), None if key is not in `RcMap`.
183    pub(crate) fn get_rc(&self, key: &ArenaKey<D::Hasher>) -> Option<u64> {
184        if let ArenaKey::Ref(key) = key
185            && let Some(count) = self.rc_ge_1.get(key)
186        {
187            Some(*count)
188        } else if self.rc_0.contains_key(key) {
189            Some(0)
190        } else {
191            None // Key not charged at all
192        }
193    }
194
195    #[must_use]
196    pub(crate) fn ins_root(&self, key: ArenaKey<D::Hasher>) -> Self {
197        RcMap {
198            rc_ge_1: self.rc_ge_1.clone(),
199            rc_0: self.rc_0.insert(key.clone(), ChildRef::new(key.clone())),
200        }
201    }
202
203    #[must_use]
204    pub(crate) fn rm_root(&self, key: &ArenaKey<D::Hasher>) -> Self {
205        RcMap {
206            rc_ge_1: self.rc_ge_1.clone(),
207            rc_0: self.rc_0.remove(key),
208        }
209    }
210
211    /// Increment the reference count for a key.
212    /// Returns `(new_rcmap, new_rc)`.
213    #[must_use]
214    pub(crate) fn modify_rc(&self, key: &ArenaHash<D::Hasher>, updated: u64) -> Self {
215        let curr = self.rc_ge_1.get(key).copied().unwrap_or(0);
216        match (curr, updated) {
217            (0, 0) =>
218            // Final ref count is zero, add to rc_0.
219            {
220                RcMap {
221                    rc_ge_1: self.rc_ge_1.clone(),
222                    rc_0: self.rc_0.insert(
223                        ArenaKey::Ref(key.clone()),
224                        ChildRef::new(ArenaKey::Ref(key.clone())),
225                    ),
226                }
227            }
228            (0, 1..) =>
229            // Key exists with rc = 0, move to rc_ge_1 with count n
230            {
231                RcMap {
232                    rc_ge_1: self.rc_ge_1.insert(key.clone(), updated),
233                    rc_0: self.rc_0.remove(&ArenaKey::Ref(key.clone())),
234                }
235            }
236            (1.., 1..) =>
237            // Key exists with rc_ge_1, update
238            {
239                RcMap {
240                    rc_ge_1: self.rc_ge_1.insert(key.clone(), updated),
241                    rc_0: self.rc_0.clone(),
242                }
243            }
244            (1.., 0) =>
245            // Key exists with rc_ge_1, move to rc_0
246            {
247                RcMap {
248                    rc_ge_1: self.rc_ge_1.remove(key),
249                    rc_0: self.rc_0.insert(
250                        ArenaKey::Ref(key.clone()),
251                        ChildRef::new(ArenaKey::Ref(key.clone())),
252                    ),
253                }
254            }
255        }
256    }
257
258    /// Get all keys that are unreachable (have `rc=0`) and not in the provided set.
259    /// This is used to initialize garbage collection.
260    pub(crate) fn get_unreachable_keys_not_in(
261        &self,
262        roots: &BTreeSet<ArenaKey<D::Hasher>>,
263    ) -> impl Iterator<Item = ArenaKey<D::Hasher>> {
264        self.rc_0.keys().filter(|key| !roots.contains(key))
265    }
266
267    /// Remove a key from the unreachable set (used during garbage collection).
268    /// Returns `Some(updated rc map)` if key was present with `rc == 0`, and
269    /// `None` otherwise.
270    #[must_use]
271    pub(crate) fn remove_unreachable_key(&self, key: &ArenaKey<D::Hasher>) -> Option<Self> {
272        if self.rc_0.contains_key(key) {
273            Some(RcMap {
274                rc_ge_1: self.rc_ge_1.clone(),
275                rc_0: self.rc_0.remove(key),
276            })
277        } else {
278            None
279        }
280    }
281
282    /// Get all charged keys and their reference counts (for testing).
283    #[cfg(test)]
284    pub(crate) fn get_rcs(&self) -> BTreeMap<ArenaKey<D::Hasher>, u64> {
285        let mut result = BTreeMap::new();
286
287        // Add all keys with rc = 0
288        for key in self.rc_0.keys() {
289            result.insert(key.clone(), 0);
290        }
291
292        // Add all keys with rc >= 1
293        for (key, count) in self.rc_ge_1.iter() {
294            result.insert(ArenaKey::Ref(key.clone()), *count);
295        }
296
297        result
298    }
299}
300
301impl<D: DB> Default for RcMap<D> {
302    fn default() -> Self {
303        RcMap {
304            rc_ge_1: Map::new(),
305            rc_0: Map::new(),
306        }
307    }
308}
309
#[cfg(feature = "proptest")]
impl<D: DB> Arbitrary for RcMap<D>
where
    D::Hasher: Arbitrary,
{
    type Parameters = ();
    type Strategy = NoStrategy<RcMap<D>>;

    /// Always returns the `NoStrategy` placeholder; the parameters are ignored.
    fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
        NoStrategy(PhantomData)
    }
}
321
322impl<D: DB> Distribution<RcMap<D>> for Standard {
323    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> RcMap<D> {
324        RcMap {
325            rc_ge_1: rng.r#gen(),
326            rc_0: rng.r#gen(),
327        }
328    }
329}
330
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::arena::Sp;
    use crate::db::InMemoryDB;
    use crate::storable::SMALL_OBJECT_LIMIT;

    // Round-trips a vector of `ChildRef`s through the `Storable` binary
    // representation, to check that the manual `Storable` impl makes sense.
    #[test]
    fn keyref_round_trip_storable() {
        // Allocate a dummy value so there is an arena key to reference
        let val = Sp::<_, InMemoryDB>::new([0u8; 1024]);
        let keyref = ChildRef::<InMemoryDB>::new(val.as_child());

        let _ = Sp::new(keyref.clone());
        // Three `Sp`s wrapping copies of the same `ChildRef`
        let keyrefs = vec![
            Sp::new(keyref.clone()),
            Sp::new(keyref.clone()),
            Sp::new(keyref.clone()),
        ];

        // Serialize ...
        let mut bytes = Vec::new();
        keyrefs.to_binary_repr(&mut bytes).unwrap();

        // ... and deserialize again from the bytes plus the reported children
        let mut reader = bytes.as_slice();
        let mut child_iter = keyrefs.children().into_iter();
        let arena = &crate::storage::default_storage().arena;
        let loader = storage_core::arena::BackendLoader::new(arena, None);
        let deserialized: Vec<Sp<ChildRef<InMemoryDB>, InMemoryDB>> =
            Storable::from_binary_repr(&mut reader, &mut child_iter, &loader).unwrap();

        assert_eq!(keyrefs, deserialized);
    }

    // Helper: collects every key reachable from the `RcMap`'s children,
    // following `Direct` children inline and `Ref` children via the backend.
    pub(crate) fn get_rcmap_descendants<D: DB>(
        rcmap: &RcMap<D>,
    ) -> std::collections::BTreeSet<ArenaKey<D::Hasher>> {
        let mut seen = BTreeSet::new();
        let mut pending = rcmap.children();
        let arena = &crate::storage::default_storage::<D>().arena;
        while let Some(node) = pending.pop() {
            if seen.contains(&node) {
                continue;
            }
            match &node {
                ArenaKey::Direct(direct) => pending.extend(direct.children.iter().cloned()),
                ArenaKey::Ref(hash) => {
                    arena.with_backend(|backend| {
                        let disk_obj = backend.get(hash).expect("Key should exist in backend");
                        pending.extend(disk_obj.children.clone());
                    });
                }
            }
            seen.insert(node);
        }
        seen
    }

    // Keys in `rc_0` must be descendants of the `RcMap` through `ChildRef` storage.
    #[test]
    fn rc_0_keys_are_descendants() {
        let val = Sp::<_, InMemoryDB>::new([42u8; SMALL_OBJECT_LIMIT]);
        let hash = val.root.clone();

        // Setting the count to zero on an empty map puts the key in rc_0
        let rcmap = RcMap::<InMemoryDB>::default().modify_rc(&hash, 0);
        assert!(rcmap.rc_0.contains_key(&ArenaKey::Ref(hash.clone())));

        let descendants = get_rcmap_descendants(&rcmap);
        assert!(
            descendants.contains(&val.as_child()),
            "Key in rc_0 must be a descendant of RcMap"
        );
    }

    // Comprehensive test of RcMap basic operations
    #[test]
    fn rcmap_operations() {
        // Allocates a 1 KiB value and returns it with its key and hash. The
        // `Sp` is returned too, so the caller keeps the allocation alive.
        let alloc_ref = |fill: u8| {
            let sp = Sp::<_, InMemoryDB>::new([fill; 1024]);
            let key = sp.as_child();
            let ArenaKey::Ref(hash) = key.clone() else {
                panic!("testing refs");
            };
            (sp, key, hash)
        };
        let (_val1, key1, hash1) = alloc_ref(1);
        let (_val2, key2, hash2) = alloc_ref(2);
        let (_val3, key3, hash3) = alloc_ref(3);

        let rcmap = RcMap::<InMemoryDB>::default().ins_root(key1.clone());

        // `ins_root` charges the key with rc=0
        assert_eq!(rcmap.get_rc(&key1), Some(0), "get_rc should return 0");
        assert!(rcmap.rc_0.contains_key(&key1), "key1 should be in rc_0 map");
        assert!(
            !rcmap.rc_ge_1.contains_key(&hash1),
            "key1 should not be in rc_ge_1 map"
        );

        // `modify_rc` from 0 to 1 moves the key to rc_ge_1
        let rcmap = rcmap.modify_rc(&hash1, 1);
        assert_eq!(rcmap.get_rc(&key1), Some(1), "get_rc should return 1");
        assert!(
            !rcmap.rc_0.contains_key(&key1),
            "key1 should not be in rc_0 map"
        );
        assert!(
            rcmap.rc_ge_1.contains_key(&hash1),
            "key1 should be in rc_ge_1 map"
        );

        // Raising the count further keeps the key in rc_ge_1
        let rcmap = rcmap.modify_rc(&hash1, 2);
        let rcmap = rcmap.modify_rc(&hash1, 3);
        assert_eq!(rcmap.get_rc(&key1), Some(3), "get_rc should return 3");
        assert!(
            rcmap.rc_ge_1.contains_key(&hash1),
            "key1 should remain in rc_ge_1 map"
        );

        // Lowering the count while it stays >= 1 keeps the key in rc_ge_1
        let rcmap = rcmap.modify_rc(&hash1, 2);
        let rcmap = rcmap.modify_rc(&hash1, 1);
        assert!(
            rcmap.rc_ge_1.contains_key(&hash1),
            "key1 should still be in rc_ge_1 map"
        );

        // `modify_rc` from 1 to 0 moves the key back to rc_0
        let rcmap = rcmap.modify_rc(&hash1, 0);
        assert_eq!(rcmap.get_rc(&key1), Some(0), "get_rc should return 0");
        assert!(
            rcmap.rc_0.contains_key(&key1),
            "key1 should be back in rc_0 map"
        );
        assert!(
            !rcmap.rc_ge_1.contains_key(&hash1),
            "key1 should not be in rc_ge_1 map"
        );

        // `get_rc` on a key that was never charged returns None
        assert_eq!(
            rcmap.get_rc(&key2),
            None,
            "get_rc on nonexistent key should return None"
        );

        // Several keys tracked at once
        let rcmap = rcmap.modify_rc(&hash2, 1);
        let rcmap = rcmap.modify_rc(&hash3, 2);

        // Every key reports its own reference count
        assert_eq!(rcmap.get_rc(&key1), Some(0));
        assert_eq!(rcmap.get_rc(&key2), Some(1));
        assert_eq!(rcmap.get_rc(&key3), Some(2));

        // Every key sits in the expected internal map
        assert!(rcmap.rc_0.contains_key(&key1));
        assert!(rcmap.rc_ge_1.contains_key(&hash2));
        assert!(rcmap.rc_ge_1.contains_key(&hash3));

        // `remove_unreachable_key` on key1 (rc=0) should succeed
        let rcmap_new = rcmap.remove_unreachable_key(&key1);
        assert!(
            rcmap_new.is_some(),
            "remove_unreachable_key should succeed for rc=0 key"
        );
        let rcmap = rcmap_new.unwrap();
        assert!(!rcmap.contains(&key1), "key1 should no longer be in rcmap");
        assert_eq!(
            rcmap.get_rc(&key1),
            None,
            "get_rc should return None for removed key"
        );

        // `remove_unreachable_key` on key2 (rc=1) should fail
        let rcmap_new = rcmap.remove_unreachable_key(&key2);
        assert!(
            rcmap_new.is_none(),
            "remove_unreachable_key should fail for rc>0 key"
        );

        // `remove_unreachable_key` on the already-removed key1 should fail
        let rcmap_new = rcmap.remove_unreachable_key(&key1);
        assert!(
            rcmap_new.is_none(),
            "remove_unreachable_key should fail for nonexistent key"
        );
    }
}