Skip to main content

card_est_array/impls/
slice_estimator_array.rs

/*
 * SPDX-FileCopyrightText: 2024 Matteo Dell'Acqua
 * SPDX-FileCopyrightText: 2025 Sebastiano Vigna
 *
 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
 */

8use super::DefaultEstimator;
9use crate::traits::Word;
10use crate::traits::*;
11use sync_cell_slice::{SyncCell, SyncSlice};
12
13/// An array for estimators implementing a shared [`EstimationLogic`], and whose
14/// backend is a slice.
15///
16/// Note that we need a specific type for arrays of slice backends as one cannot
17/// create a slice of slices.
18pub struct SliceEstimatorArray<L, W = usize, S = Box<[W]>> {
19    pub(super) logic: L,
20    pub(super) backend: S,
21    _marker: std::marker::PhantomData<W>,
22}
23
24/// A view of a [`SliceEstimatorArray`] as a [`SyncEstimatorArray`].
25pub struct SyncSliceEstimatorArray<L, W = usize, S = Box<[W]>> {
26    pub(super) logic: L,
27    pub(super) backend: S,
28    _marker: std::marker::PhantomData<W>,
29}
30
31unsafe impl<L, W, S> Sync for SyncSliceEstimatorArray<L, W, S>
32where
33    L: Sync,
34    W: Sync,
35    S: Sync,
36{
37}
38
39impl<L: SliceEstimationLogic<W> + Sync, W: Word, S: AsRef<[SyncCell<W>]> + Sync>
40    SyncEstimatorArray<L> for SyncSliceEstimatorArray<L, W, S>
41{
42    unsafe fn set(&self, index: usize, content: &L::Backend) {
43        debug_assert_eq!(content.len(), self.logic.backend_len());
44        let offset = index * self.logic.backend_len();
45        for (c, &b) in self.backend.as_ref()[offset..].iter().zip(content) {
46            // SAFETY: we are the only ones writing to this cell
47            unsafe { c.set(b) }
48        }
49    }
50
51    #[inline(always)]
52    fn logic(&self) -> &L {
53        &self.logic
54    }
55
56    unsafe fn get(&self, index: usize, backend: &mut L::Backend) {
57        debug_assert_eq!(backend.len(), self.logic.backend_len());
58        let offset = index * self.logic.backend_len();
59        for (b, c) in backend
60            .iter_mut()
61            .zip(self.backend.as_ref()[offset..].iter())
62        {
63            // SAFETY: we are the only ones reading from this cell
64            *b = unsafe { c.get() }
65        }
66    }
67
68    unsafe fn clear(&self) {
69        self.backend
70            .as_ref()
71            .iter()
72            .for_each(|c| unsafe { c.set(W::ZERO) })
73    }
74
75    #[inline(always)]
76    fn len(&self) -> usize {
77        self.backend.as_ref().len() / self.logic.backend_len()
78    }
79}
80
81impl<L: SliceEstimationLogic<W> + Clone + Sync, W: Word, S: AsMut<[W]>> AsSyncArray<L>
82    for SliceEstimatorArray<L, W, S>
83{
84    type SyncEstimatorArray<'a>
85        = SyncSliceEstimatorArray<L, W, &'a [SyncCell<W>]>
86    where
87        Self: 'a;
88
89    fn as_sync_array(&mut self) -> SyncSliceEstimatorArray<L, W, &[SyncCell<W>]> {
90        SyncSliceEstimatorArray {
91            logic: self.logic.clone(),
92            backend: self.backend.as_mut().as_sync_slice(),
93            _marker: std::marker::PhantomData,
94        }
95    }
96}
97
98impl<L, W, S: AsRef<[W]>> AsRef<[W]> for SliceEstimatorArray<L, W, S> {
99    #[inline(always)]
100    fn as_ref(&self) -> &[W] {
101        self.backend.as_ref()
102    }
103}
104
105impl<L, W, S: AsMut<[W]>> AsMut<[W]> for SliceEstimatorArray<L, W, S> {
106    #[inline(always)]
107    fn as_mut(&mut self) -> &mut [W] {
108        self.backend.as_mut()
109    }
110}
111
112impl<L: SliceEstimationLogic<W>, W: Word> SliceEstimatorArray<L, W, Box<[W]>> {
113    /// Creates a new estimator slice with the provided logic.
114    ///
115    /// # Arguments
116    /// * `logic`: the estimator logic to use.
117    /// * `len`: the number of the estimators in the array.
118    pub fn new(logic: L, len: usize) -> Self {
119        let num_backend_len = logic.backend_len();
120        let backend = vec![W::ZERO; len * num_backend_len].into();
121        Self {
122            logic,
123            backend,
124            _marker: std::marker::PhantomData,
125        }
126    }
127}
128
129impl<L: SliceEstimationLogic<W> + Clone, W: Word, S: AsRef<[W]>> EstimatorArray<L>
130    for SliceEstimatorArray<L, W, S>
131{
132    type Estimator<'a>
133        = DefaultEstimator<L, &'a L, &'a [W]>
134    where
135        Self: 'a;
136
137    #[inline(always)]
138    fn get_backend(&self, index: usize) -> &L::Backend {
139        let offset = index * self.logic.backend_len();
140        &self.backend.as_ref()[offset..][..self.logic.backend_len()]
141    }
142
143    #[inline(always)]
144    fn logic(&self) -> &L {
145        &self.logic
146    }
147
148    #[inline(always)]
149    fn get_estimator(&self, index: usize) -> Self::Estimator<'_> {
150        DefaultEstimator::new(&self.logic, self.get_backend(index))
151    }
152
153    #[inline(always)]
154    fn len(&self) -> usize {
155        let backend = self.backend.as_ref();
156        debug_assert_eq!(backend.len() % self.logic.backend_len(), 0);
157        backend.len() / self.logic.backend_len()
158    }
159}
160
161impl<L: SliceEstimationLogic<W> + Clone, W: Word, S: AsRef<[W]> + AsMut<[W]>> EstimatorArrayMut<L>
162    for SliceEstimatorArray<L, W, S>
163{
164    type EstimatorMut<'a>
165        = DefaultEstimator<L, &'a L, &'a mut [W]>
166    where
167        Self: 'a;
168
169    #[inline(always)]
170    fn get_backend_mut(&mut self, index: usize) -> &mut L::Backend {
171        let offset = index * self.logic.backend_len();
172        &mut self.backend.as_mut()[offset..][..self.logic.backend_len()]
173    }
174
175    #[inline(always)]
176    fn get_estimator_mut(&mut self, index: usize) -> Self::EstimatorMut<'_> {
177        let logic = &self.logic;
178        // We have to extract manually the backend because get_backend_mut
179        // borrows self mutably, but we need to borrow just self.backend.
180        let offset = index * self.logic.backend_len();
181        let backend = &mut self.backend.as_mut()[offset..][..self.logic.backend_len()];
182
183        DefaultEstimator::new(logic, backend)
184    }
185
186    #[inline(always)]
187    fn clear(&mut self) {
188        self.backend.as_mut().iter_mut().for_each(|v| *v = W::ZERO)
189    }
190}