use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use parking_lot::Mutex;
use super::Node;
/// Number of node slots held by one chunk (2^10).
const CHUNK_SIZE: usize = 1 << 10;
/// Number of entries in the chunk pointer table, i.e. the most chunks an
/// arena can ever allocate (2^18).
const MAX_CHUNKS: usize = 1 << 18;
/// Append-only arena of `Node`s stored in fixed-size chunks.
///
/// `push` takes the `chunks` lock for the whole append; `get` and `len` only
/// touch the atomics and never lock.
pub struct NodeArena {
/// Chunk allocations. Locked by `push` while appending and walked by `Drop`
/// to destroy the initialised nodes.
chunks: Mutex<Vec<Box<[MaybeUninit<Node>; CHUNK_SIZE]>>>,
/// One entry per possible chunk (`MAX_CHUNKS` in total): null until `push`
/// allocates that chunk, then the address of the chunk's first slot.
/// Loaded by `get`.
chunk_ptrs: Box<[AtomicPtr<Node>]>,
/// Number of nodes pushed so far. Stored with `Release` at the end of
/// `push` and loaded with `Acquire` by `get` and `len`.
len: AtomicUsize,
}
impl NodeArena {
    /// Creates an empty arena.
    ///
    /// The chunk pointer table is allocated up front with `MAX_CHUNKS` null
    /// entries so that [`get`](Self::get) can index it without locking; the
    /// chunks themselves are allocated lazily by [`push`](Self::push).
    pub fn new() -> Self {
        let chunk_ptrs: Box<[AtomicPtr<Node>]> = (0..MAX_CHUNKS)
            .map(|_| AtomicPtr::new(ptr::null_mut()))
            .collect::<Vec<_>>()
            .into_boxed_slice();
        NodeArena {
            chunks: Mutex::new(Vec::new()),
            chunk_ptrs,
            len: AtomicUsize::new(0),
        }
    }

    /// Appends `node` and returns the index it was stored at.
    ///
    /// `node.idx` is overwritten with that index before the node is stored.
    /// Appends are serialised by the internal lock.
    ///
    /// # Panics
    ///
    /// Panics if the arena already holds `MAX_CHUNKS * CHUNK_SIZE` nodes.
    pub fn push(&self, mut node: Node) -> usize {
        // Compile-time guard: every valid index must fit the `u32` stored in
        // `node.idx`, so the cast below can never truncate.
        const _: () = assert!(MAX_CHUNKS * CHUNK_SIZE <= u32::MAX as usize);

        // `len` and `chunk_ptrs` are only ever stored to while this lock is
        // held, so `Relaxed` loads of them are sufficient on the writer side.
        let mut chunks = self.chunks.lock();
        let idx = self.len.load(Ordering::Relaxed);
        let chunk_idx = idx / CHUNK_SIZE;
        let slot_idx = idx % CHUNK_SIZE;
        assert!(
            chunk_idx < MAX_CHUNKS,
            "NodeArena: exceeded maximum capacity of {} nodes",
            MAX_CHUNKS * CHUNK_SIZE
        );
        node.idx = idx as u32;

        while chunks.len() <= chunk_idx {
            chunks.push(Box::new([const { MaybeUninit::uninit() }; CHUNK_SIZE]));
            let new_idx = chunks.len() - 1;
            // Derive the published pointer from a *mutable* borrow taken after
            // the box has reached its final home in the `Vec`. A pointer
            // obtained through a shared borrow (or before the box is moved)
            // would be invalidated by later writes through the box.
            let base = chunks[new_idx].as_mut_ptr().cast::<Node>();
            self.chunk_ptrs[new_idx].store(base, Ordering::Release);
        }

        let base = self.chunk_ptrs[chunk_idx].load(Ordering::Relaxed);
        debug_assert!(!base.is_null());
        // SAFETY: `base` is the first slot of chunk `chunk_idx`, which was
        // allocated either in the loop above or by an earlier `push` and is
        // kept alive by `chunks` until the arena is dropped. `slot_idx` is
        // below `CHUNK_SIZE`, so the offset stays inside that allocation.
        // Slot `idx` is not yet published (`len` is still `idx`), so no reader
        // can reference it, and the lock excludes other writers. Writing
        // through the same raw pointer that readers use keeps their existing
        // `&Node` borrows of other slots valid.
        unsafe { base.add(slot_idx).write(node) };

        // Publish the slot: pairs with the `Acquire` load in `get`/`len`.
        self.len.store(idx + 1, Ordering::Release);
        idx
    }

    /// Returns a reference to the node stored at `idx` without locking.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not smaller than the current length.
    pub fn get(&self, idx: usize) -> &Node {
        let current_len = self.len.load(Ordering::Acquire);
        assert!(
            idx < current_len,
            "NodeArena::get({idx}) out of bounds (len={current_len})"
        );
        let chunk_idx = idx / CHUNK_SIZE;
        let slot_idx = idx % CHUNK_SIZE;
        let chunk_ptr = self.chunk_ptrs[chunk_idx].load(Ordering::Acquire);
        debug_assert!(!chunk_ptr.is_null());
        // SAFETY: `idx < len`, and the `Acquire` load of `len` synchronises
        // with the `Release` store in `push`, which happens only after the
        // chunk pointer was stored and the slot was initialised. Slots are
        // never overwritten or freed before the arena is dropped, and the
        // returned borrow is tied to `&self`.
        unsafe { &*chunk_ptr.add(slot_idx) }
    }

    /// Returns the number of nodes pushed so far.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Acquire)
    }

    /// Returns `true` if no node has been pushed yet.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns an iterator over the nodes present at the time of the call.
    ///
    /// The length is captured once, so nodes pushed afterwards are not
    /// yielded by this iterator.
    pub fn iter(&self) -> NodeArenaIter<'_> {
        NodeArenaIter {
            arena: self,
            current: 0,
            len: self.len(),
        }
    }
}
impl Drop for NodeArena {
fn drop(&mut self) {
let len = *self.len.get_mut();
let chunks = self.chunks.get_mut();
for i in 0..len {
let chunk_idx = i / CHUNK_SIZE;
let slot_idx = i % CHUNK_SIZE;
unsafe {
chunks[chunk_idx][slot_idx].assume_init_drop();
}
}
}
}
impl std::fmt::Debug for NodeArena {
    /// Prints only the current length; the remaining fields are elided
    /// (rendered as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let node_count = self.len();
        let mut builder = f.debug_struct("NodeArena");
        builder.field("len", &node_count);
        builder.finish_non_exhaustive()
    }
}
/// Iterator over the nodes of a `NodeArena`, in index order.
///
/// Created by `NodeArena::iter`, which records the arena length at that
/// moment; iteration stops at that recorded length.
pub struct NodeArenaIter<'a> {
/// Arena the nodes are read from.
arena: &'a NodeArena,
/// Index of the next node to yield.
current: usize,
/// Arena length captured when the iterator was created (exclusive end).
len: usize,
}
impl<'a> Iterator for NodeArenaIter<'a> {
    type Item = &'a Node;

    /// Yields the node at the cursor and advances it, or `None` once the
    /// captured length has been reached.
    fn next(&mut self) -> Option<Self::Item> {
        let cursor = self.current;
        if cursor >= self.len {
            return None;
        }
        // Look the node up before advancing so the cursor only moves once the
        // lookup has succeeded.
        let item = self.arena.get(cursor);
        self.current = cursor + 1;
        Some(item)
    }

    /// Exact number of nodes still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.len - self.current;
        (left, Some(left))
    }
}
// `size_hint` returns identical lower and upper bounds (`len - current`),
// so the default `ExactSizeIterator::len` is accurate.
impl ExactSizeIterator for NodeArenaIter<'_> {}
// Unit tests for `NodeArena`: indexing, chunk boundaries, iteration, and
// multi-threaded push/read usage.
#[cfg(test)]
mod tests {
use super::*;
use crate::arena::cfr::NodeData;
// A single push lands at index 0 and `get(0)` returns a node whose `idx` is 0.
#[test]
fn test_push_and_get() {
let arena = NodeArena::new();
let root = Node::new_root();
let idx = arena.push(root);
assert_eq!(idx, 0);
assert_eq!(arena.len(), 1);
let node = arena.get(0);
assert_eq!(node.idx, 0);
}
// Sequential pushes return consecutive indices, and each stored node's `idx`
// matches the index it was stored at.
#[test]
fn test_multiple_pushes() {
let arena = NodeArena::new();
for i in 0..100 {
let node = Node::new(0, 0, NodeData::Chance);
let idx = arena.push(node);
assert_eq!(idx, i);
}
assert_eq!(arena.len(), 100);
for i in 0..100 {
let node = arena.get(i);
assert_eq!(node.idx, i as u32);
}
}
// Pushes `CHUNK_SIZE + 10` nodes so a second chunk is needed, then checks the
// last slot of the first chunk and the first slot of the second.
#[test]
fn test_cross_chunk_boundary() {
let arena = NodeArena::new();
for _ in 0..(CHUNK_SIZE + 10) {
let node = Node::new(0, 0, NodeData::Chance);
arena.push(node);
}
assert_eq!(arena.len(), CHUNK_SIZE + 10);
let last_in_first_chunk = arena.get(CHUNK_SIZE - 1);
assert_eq!(last_in_first_chunk.idx, (CHUNK_SIZE - 1) as u32);
let first_in_second_chunk = arena.get(CHUNK_SIZE);
assert_eq!(first_in_second_chunk.idx, CHUNK_SIZE as u32);
}
// `iter` yields the nodes in index order.
#[test]
fn test_iter() {
let arena = NodeArena::new();
for _ in 0..5 {
let node = Node::new(0, 0, NodeData::Chance);
arena.push(node);
}
let indices: Vec<u32> = arena.iter().map(|n| n.idx).collect();
assert_eq!(indices, vec![0, 1, 2, 3, 4]);
}
// `get` on an empty arena panics with the "out of bounds" message.
#[test]
#[should_panic(expected = "out of bounds")]
fn test_get_out_of_bounds() {
let arena = NodeArena::new();
arena.get(0);
}
// Four threads read the same 100 pre-pushed nodes through a shared `Arc`.
#[test]
fn test_concurrent_reads() {
use std::sync::Arc;
let arena = Arc::new(NodeArena::new());
for _ in 0..100 {
let node = Node::new(0, 0, NodeData::Chance);
arena.push(node);
}
let handles: Vec<_> = (0..4)
.map(|_| {
let arena = arena.clone();
std::thread::spawn(move || {
for i in 0..100 {
let node = arena.get(i);
assert_eq!(node.idx, i as u32);
}
})
})
.collect();
for h in handles {
h.join().unwrap();
}
}
// A writer thread pushes 50 more nodes while this thread reads the first 50,
// which were pushed before the writer was spawned.
#[test]
fn test_concurrent_push_and_read() {
use std::sync::Arc;
let arena = Arc::new(NodeArena::new());
for i in 0..50 {
let node = Node::new(0, 0, NodeData::Chance);
let idx = arena.push(node);
assert_eq!(idx, i);
}
let arena_writer = arena.clone();
let writer = std::thread::spawn(move || {
for _ in 50..100 {
let node = Node::new(0, 0, NodeData::Chance);
arena_writer.push(node);
}
});
for i in 0..50 {
let node = arena.get(i);
assert_eq!(node.idx, i as u32);
}
writer.join().unwrap();
assert_eq!(arena.len(), 100);
}
// Compile-time check: the helper only type-checks if `NodeArena` is both
// `Send` and `Sync`.
#[test]
fn test_node_arena_is_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<NodeArena>();
}
// The iterator keeps the length captured at creation: two pushes made
// afterwards (on this same thread) are not yielded.
#[test]
fn test_iter_snapshot_ignores_concurrent_pushes() {
let arena = NodeArena::new();
for _ in 0..5 {
arena.push(Node::new(0, 0, NodeData::Chance));
}
let iter = arena.iter();
arena.push(Node::new(0, 0, NodeData::Chance));
arena.push(Node::new(0, 0, NodeData::Chance));
assert_eq!(iter.len(), 5);
assert_eq!(iter.count(), 5);
assert_eq!(arena.len(), 7);
}
// Four threads push 250 nodes each; afterwards the stored `idx` values must
// all be distinct and the total must be 1000.
#[test]
fn test_concurrent_pushes() {
use std::sync::Arc;
let arena = Arc::new(NodeArena::new());
let num_threads = 4;
let pushes_per_thread = 250;
let handles: Vec<_> = (0..num_threads)
.map(|_| {
let arena = arena.clone();
std::thread::spawn(move || {
for _ in 0..pushes_per_thread {
let node = Node::new(0, 0, NodeData::Chance);
arena.push(node);
}
})
})
.collect();
for h in handles {
h.join().unwrap();
}
assert_eq!(arena.len(), num_threads * pushes_per_thread);
let mut indices: Vec<u32> = arena.iter().map(|n| n.idx).collect();
// Sorting then deduplicating leaves one entry per distinct index.
indices.sort();
indices.dedup();
assert_eq!(indices.len(), num_threads * pushes_per_thread);
}
}