use crate::{
memory::{FlexibleMemoryManager, MemoryManager, RefCount},
word::{CompositeWord, Shape, Word, WordIdx, Words, on_all_words},
};
use core::{cell::UnsafeCell, fmt::Debug};
use zeroize::Zeroizing;
/// Interface of a pool that stores *pairs* of composite words addressed by
/// [`WordIdx`] handles.
///
/// Every method is generic over the word type `W` and the number `N` of words
/// that make up one composite word.
pub trait WordPairPool: Sized + Debug + Default {
/// Allocates an entry and returns the index handle that addresses it.
fn alloc<W: Word, const N: usize>(&mut self) -> WordIdx<W, N>;
/// Registers one additional reference to the entry addressed by `idx`.
fn increase_refcount<W: Word, const N: usize>(&mut self, idx: WordIdx<W, N>);
/// Releases one reference to the entry addressed by `idx`.
///
/// NOTE(review): what happens when the count reaches zero is decided by the
/// implementor and is not visible from this trait.
fn decrease_refcount<W: Word, const N: usize>(&mut self, idx: WordIdx<W, N>);
/// Returns the two composite words stored under `idx`.
fn read<W: Word, const N: usize>(&self, idx: WordIdx<W, N>) -> [CompositeWord<W, N>; 2];
/// Stores the two composite words `words` under `idx`.
fn write<W: Word, const N: usize>(
&mut self,
idx: WordIdx<W, N>,
words: [CompositeWord<W, N>; 2],
);
}
/// Backing storage for a word pair pool: two parallel word collections that
/// can be viewed as slices of any word type `W`.
pub trait WordPairSource: Debug + Default {
/// Returns shared slices of word type `W`, one per collection.
fn words<W: Word>(&self) -> [&[W]; 2];
/// Returns mutable slices of word type `W`, one per collection.
fn words_mut<W: Word>(&mut self) -> [&mut [W]; 2];
/// Resizes the `W` storage of both collections to `new_len` elements.
fn resize<W: Word>(&mut self, new_len: usize);
}
/// A [`WordPairSource`] that owns its two `Words` collections.
#[derive(Debug)]
pub struct WordPairSourceWrapper {
// Each collection is wrapped in `zeroize::Zeroizing`, which zeroizes the
// wrapped value when it is dropped.
words: [Zeroizing<Words>; 2],
}
impl WordPairSourceWrapper {
    /// Takes ownership of the two given `Words` collections and wraps each of
    /// them in `Zeroizing`.
    pub fn new(words: [Words; 2]) -> Self {
        Self {
            words: words.map(Zeroizing::new),
        }
    }
}
impl WordPairSource for WordPairSourceWrapper {
    /// Borrows both owned collections as slices of word type `W`.
    fn words<W: Word>(&self) -> [&[W]; 2] {
        let [first, second] = &self.words;
        [first.as_vec::<W>(), second.as_vec::<W>()]
    }

    /// Mutably borrows both owned collections as slices of word type `W`.
    fn words_mut<W: Word>(&mut self) -> [&mut [W]; 2] {
        // Destructuring splits the array borrow so both halves can be handed
        // out mutably at the same time.
        let [first, second] = &mut self.words;
        [first.as_vec_mut::<W>(), second.as_vec_mut::<W>()]
    }

    /// Resizes the `W` vector of each collection to `new_len`, filling any new
    /// slots with `W::ZERO`.
    fn resize<W: Word>(&mut self, new_len: usize) {
        for pool in &mut self.words {
            pool.as_vec_mut::<W>().resize(new_len, W::ZERO);
        }
    }
}
impl Default for WordPairSourceWrapper {
    /// Creates a source backed by two collections obtained from `Words::new()`.
    fn default() -> Self {
        let first = Words::new();
        let second = Words::new();
        Self::new([first, second])
    }
}
/// A [`WordPairPool`] assembled from a word storage (`WS`) and a memory
/// manager (`M`).
#[derive(Debug)]
pub struct WordPairPoolWrapper<WS: WordPairSource, M: MemoryManager> {
// Holds the actual word data that `read` and `write` access.
word_pair_source: WS,
// Hands out indices and receives the reference-count updates.
memory_manager: M,
}
impl<WS: WordPairSource, M: MemoryManager> WordPairPoolWrapper<WS, M> {
    /// Combines an existing word storage and memory manager into a pool.
    pub fn new(word_pair_source: WS, memory_manager: M) -> Self {
        Self {
            word_pair_source,
            memory_manager,
        }
    }
}
impl<WS: WordPairSource, M: MemoryManager> WordPairPool for WordPairPoolWrapper<WS, M> {
    /// Asks the memory manager for a new index, then resizes the `W` storage
    /// to the vector length the manager reports alongside it.
    fn alloc<W: Word, const N: usize>(&mut self) -> WordIdx<W, N> {
        let (idx, vec_len) = self.memory_manager.alloc::<W, N>();
        self.word_pair_source.resize::<W>(vec_len);
        idx
    }

    /// Forwards the reference-count increment for `idx` to the memory manager.
    fn increase_refcount<W: Word, const N: usize>(&mut self, idx: WordIdx<W, N>) {
        self.memory_manager.increase_refcount::<W, N>(idx);
    }

    /// Forwards the reference-count decrement for `idx` to the memory manager.
    fn decrease_refcount<W: Word, const N: usize>(&mut self, idx: WordIdx<W, N>) {
        self.memory_manager.decrease_refcount::<W, N>(idx);
    }

    /// Gathers, from each of the two collections, the words at the positions
    /// named by `idx` and assembles them with `CompositeWord::from_le_words`.
    ///
    /// Indexing panics if a position lies outside the current storage.
    fn read<W: Word, const N: usize>(&self, idx: WordIdx<W, N>) -> [CompositeWord<W, N>; 2] {
        let positions = idx.into_array();
        self.word_pair_source
            .words::<W>()
            .map(|pool| CompositeWord::from_le_words(positions.map(|pos| pool[pos])))
    }

    /// Scatters the words of `words[k]` (in the order given by `to_le_words`)
    /// into collection `k` at the positions named by `idx`.
    ///
    /// Indexing panics if a position lies outside the current storage.
    fn write<W: Word, const N: usize>(
        &mut self,
        idx: WordIdx<W, N>,
        words: [CompositeWord<W, N>; 2],
    ) {
        let positions = idx.into_array();
        let pools = self.word_pair_source.words_mut::<W>();
        // Collection 0 receives `words[0]`, collection 1 receives `words[1]`.
        for (pool, composite) in pools.into_iter().zip(words) {
            for (pos, limb) in positions.into_iter().zip(composite.to_le_words()) {
                pool[pos] = limb;
            }
        }
    }
}
impl<WS: WordPairSource, M: MemoryManager> Default for WordPairPoolWrapper<WS, M> {
    /// Builds a pool from `M::new()` and `WS::default()`.
    fn default() -> Self {
        // Field initialisers run in the order written, so the memory manager
        // is still constructed before the word source.
        Self {
            memory_manager: M::new(),
            word_pair_source: WS::default(),
        }
    }
}
/// Word pair pool that owns its storage (`WordPairSourceWrapper`) and uses a
/// `FlexibleMemoryManager` parameterised by the reference-count type `RC`.
// The compiler does not enforce bounds written on a type alias, hence the
// lint; the `RC: RefCount` bound is presumably kept as documentation.
#[allow(type_alias_bounds)]
pub type OwnedFlexibleWordPairPool<RC: RefCount> =
WordPairPoolWrapper<WordPairSourceWrapper, FlexibleMemoryManager<RC>>;
/// Holder for the two global `Words` collections. The `UnsafeCell` lets code
/// in this file obtain mutable access through the shared `static` below.
struct GlobalWordPairs(UnsafeCell<[Words; 2]>);
// SAFETY: NOTE(review): `UnsafeCell` is not `Sync` on its own, so this impl is
// a promise made by the author. No locking is visible in this file; it
// presumably relies on the global storage never being accessed from several
// threads at the same time — confirm against the callers.
unsafe impl Sync for GlobalWordPairs {}
/// The single global storage instance, initialised with two `Words::new()`
/// values.
static GLOBAL_WORD_PAIRS: GlobalWordPairs =
GlobalWordPairs(UnsafeCell::new([Words::new(), Words::new()]));
/// Zero-sized [`WordPairSource`] whose data lives in the `GLOBAL_WORD_PAIRS`
/// static, so every instance refers to the same storage.
#[derive(Debug, Default)]
pub struct GlobalWordPairSource;
impl WordPairSource for GlobalWordPairSource {
    /// Borrows both global collections as slices of word type `W`.
    fn words<W: Word>(&self) -> [&[W]; 2] {
        // SAFETY: NOTE(review): sound only while no mutable reference into
        // `GLOBAL_WORD_PAIRS` is live; nothing in this block enforces that.
        let [first, second] = unsafe { &*GLOBAL_WORD_PAIRS.0.get() };
        [first.as_vec::<W>(), second.as_vec::<W>()]
    }

    /// Mutably borrows both global collections as slices of word type `W`.
    fn words_mut<W: Word>(&mut self) -> [&mut [W]; 2] {
        // SAFETY: NOTE(review): sound only while no other reference into
        // `GLOBAL_WORD_PAIRS` is live; `&mut self` does not guarantee this,
        // because the type is zero-sized and freely constructible.
        let pools = unsafe { &mut *GLOBAL_WORD_PAIRS.0.get() };
        let [first, second] = pools;
        [first.as_vec_mut::<W>(), second.as_vec_mut::<W>()]
    }

    /// Resizes the `W` vector of each global collection to `new_len`, filling
    /// any new slots with `W::ZERO`.
    fn resize<W: Word>(&mut self, new_len: usize) {
        // SAFETY: NOTE(review): same exclusivity requirement as `words_mut`.
        let pools = unsafe { &mut *GLOBAL_WORD_PAIRS.0.get() };
        for pool in pools.iter_mut() {
            pool.as_vec_mut::<W>().resize(new_len, W::ZERO);
        }
    }
}
/// Word pair pool backed by the global storage (`GlobalWordPairSource`) and a
/// `FlexibleMemoryManager` parameterised by the reference-count type `RC`.
// The compiler does not enforce bounds written on a type alias, hence the
// lint; the `RC: RefCount` bound is presumably kept as documentation.
#[allow(type_alias_bounds)]
pub type GlobalFlexibleWordPairPool<RC: RefCount> =
WordPairPoolWrapper<GlobalWordPairSource, FlexibleMemoryManager<RC>>;
impl<RC: RefCount> GlobalFlexibleWordPairPool<RC> {
    /// Returns the shape reported by the global word storage.
    ///
    /// # Panics
    ///
    /// Panics if the two global `Words` collections report different shapes.
    pub fn shape() -> Shape {
        // SAFETY: NOTE(review): sound only while no other reference into
        // `GLOBAL_WORD_PAIRS` is in use; nothing here enforces that. A shared
        // reference would suffice if `Words::shape` takes `&self` — confirm.
        let pools = unsafe { &mut *GLOBAL_WORD_PAIRS.0.get() };
        let shape0 = pools[0].shape();
        let shape1 = pools[1].shape();
        assert_eq!(
            shape0, shape1,
            "Shape mismatch between word pools in global word pair pool."
        );
        shape0
    }

    /// Returns the capacity reported by the global word storage.
    ///
    /// # Panics
    ///
    /// Panics if the two global `Words` collections report different
    /// capacities.
    pub fn capacity() -> Shape {
        // SAFETY: NOTE(review): same exclusivity requirement as in `shape`.
        let pools = unsafe { &mut *GLOBAL_WORD_PAIRS.0.get() };
        let capacity0 = pools[0].capacity();
        let capacity1 = pools[1].capacity();
        assert_eq!(
            capacity0, capacity1,
            "Capacity mismatch between word pools in global word pair pool."
        );
        capacity0
    }

    /// Makes sure that, for every word type, both global collections can hold
    /// at least the number of elements given by `capacity` without
    /// reallocating.
    ///
    /// # Panics
    ///
    /// Panics if the two global collections report different shapes (see
    /// [`Self::shape`]).
    pub fn reserve(capacity: Shape) {
        // `reserve_exact(additional)` guarantees room for `len + additional`
        // elements, so the shortfall has to be measured against the current
        // *length*. Measuring it against the current capacity under-reserves
        // whenever the vectors already hold spare capacity.
        let current_len = Self::shape();
        let additional = capacity.zip(&current_len, |desired, current| {
            if desired < current {
                0
            } else {
                desired - current
            }
        });
        on_all_words!(W, {
            // SAFETY: NOTE(review): sound only while no other reference into
            // `GLOBAL_WORD_PAIRS` is in use; nothing here enforces that.
            unsafe {
                (&mut *GLOBAL_WORD_PAIRS.0.get())[0]
                    .as_vec_mut::<W>()
                    .reserve_exact(*additional.as_value::<W>());
                (&mut *GLOBAL_WORD_PAIRS.0.get())[1]
                    .as_vec_mut::<W>()
                    .reserve_exact(*additional.as_value::<W>());
            };
        });
    }

    /// Resizes, for every word type, both global collections to the length
    /// given by `new_len`, filling any new slots with `Word::ZERO`.
    ///
    /// # Panics
    ///
    /// Panics if the resulting shape differs from `new_len`.
    pub fn resize(new_len: Shape) {
        on_all_words!(W, {
            // SAFETY: NOTE(review): sound only while no other reference into
            // `GLOBAL_WORD_PAIRS` is in use; nothing here enforces that.
            unsafe {
                (&mut *GLOBAL_WORD_PAIRS.0.get())[0]
                    .as_vec_mut::<W>()
                    .resize(*new_len.as_value::<W>(), <W as Word>::ZERO);
                (&mut *GLOBAL_WORD_PAIRS.0.get())[1]
                    .as_vec_mut::<W>()
                    .resize(*new_len.as_value::<W>(), <W as Word>::ZERO);
            };
        });
        // Sanity check that every word type ended up with the requested length.
        assert_eq!(Self::shape(), new_len);
    }
}