use crate::error::{Result, ZiporaError};
use crate::memory::SecureMemoryPool;
use std::cmp::Ordering;
use std::marker::PhantomData;
use std::sync::Arc;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
/// Cache-line size in bytes assumed by this module; it is the default
/// `LoserTreeConfig::alignment` and sizes the padding of `CacheAlignedNode`.
const CACHE_LINE_SIZE: usize = 64;
/// Tuning knobs for an `EnhancedLoserTree`.
#[derive(Debug, Clone)]
pub struct LoserTreeConfig {
    /// Capacity reserved up front for the tree's node vector.
    pub initial_capacity: usize,
    /// When `true`, construction tries to create a `SecureMemoryPool`.
    pub use_secure_memory: bool,
    /// When `true`, way-vs-way comparisons apply an extra tie-break rule to
    /// equal heads instead of always preferring the first argument.
    pub stable_sort: bool,
    /// Together with a non-zero `prefetch_distance`, enables node prefetching
    /// while the tree is built.
    pub cache_optimized: bool,
    /// Requests SIMD-assisted comparison. The code in this file performs the
    /// same comparator call whether or not this is set.
    pub use_simd: bool,
    /// How many nodes ahead of the current one to prefetch (0 disables it).
    pub prefetch_distance: usize,
    /// Requested alignment in bytes.
    // NOTE(review): no code in this file reads this field — confirm whether it
    // is consumed elsewhere.
    pub alignment: usize,
}
impl Default for LoserTreeConfig {
    /// Returns the default configuration: 64 pre-reserved nodes, secure
    /// memory, stable ordering and cache optimisation all enabled, a prefetch
    /// distance of 2, and cache-line alignment.
    fn default() -> Self {
        // SIMD is on by default only when the crate is built with the `simd`
        // cargo feature.
        let simd_feature_enabled = cfg!(feature = "simd");

        Self {
            initial_capacity: 64,
            use_secure_memory: true,
            stable_sort: true,
            cache_optimized: true,
            use_simd: simd_feature_enabled,
            prefetch_distance: 2,
            alignment: CACHE_LINE_SIZE,
        }
    }
}
/// One internal node of the tournament: the way that lost the match played at
/// this node, plus a sequence index recorded for that way.
///
/// The layout is fixed (`repr(C)`, 8-byte aligned), so a node is two `u32`s
/// occupying exactly 8 bytes.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(8))]
pub struct TournamentNode {
    /// Index of the losing way, stored as `u32`.
    pub loser_way: u32,
    /// Sequence index associated with the losing way, stored as `u32`.
    pub sequence_index: u32,
}
/// A `TournamentNode` aligned to 64 bytes and padded out to `CACHE_LINE_SIZE`
/// bytes.
#[derive(Debug, Clone)]
#[repr(C, align(64))]
pub struct CacheAlignedNode {
    /// The tournament payload.
    pub node: TournamentNode,
    /// Filler bytes: the 8-byte payload plus this padding totals
    /// `CACHE_LINE_SIZE` bytes.
    _padding: [u8; CACHE_LINE_SIZE - 8],
}
impl TournamentNode {
    /// Builds a node from `usize` indices.
    ///
    /// Both values are narrowed with `as u32`, so anything above `u32::MAX`
    /// is truncated rather than rejected.
    pub fn new(loser_way: usize, sequence_index: usize) -> Self {
        let loser_way = loser_way as u32;
        let sequence_index = sequence_index as u32;
        Self {
            loser_way,
            sequence_index,
        }
    }

    /// Returns the stored losing-way index widened back to `usize`.
    pub fn loser_way(&self) -> usize {
        let stored = self.loser_way;
        stored as usize
    }

    /// Returns the stored sequence index widened back to `usize`.
    pub fn sequence_index(&self) -> usize {
        let stored = self.sequence_index;
        stored as usize
    }
}
impl CacheAlignedNode {
    /// Wraps a freshly built `TournamentNode` and zero-fills the padding.
    ///
    /// The arguments are forwarded unchanged to `TournamentNode::new`.
    pub fn new(loser_way: usize, sequence_index: usize) -> Self {
        let node = TournamentNode::new(loser_way, sequence_index);
        Self {
            node,
            _padding: [0u8; CACHE_LINE_SIZE - 8],
        }
    }
}
/// One input "way" of the merge: an iterator plus a one-element lookahead.
struct WayIterator<I, T> {
    /// The underlying source of elements.
    iterator: I,
    /// The buffered head element; `None` once the source has run dry.
    current: Option<T>,
    /// Index given to this way at construction time.
    way_index: usize,
    /// Number of times the head has been replaced via `advance`.
    position: usize,
    /// Set once the source has returned `None`; the source is not polled
    /// again afterwards.
    exhausted: bool,
}
impl<I, T> WayIterator<I, T>
where
    I: Iterator<Item = T>,
    T: Clone,
{
    /// Wraps `iterator`, eagerly pulling its first element into the lookahead
    /// slot. A source that is empty from the start is marked exhausted.
    fn new(mut iterator: I, way_index: usize) -> Self {
        let first = iterator.next();
        Self {
            // Evaluated before `first` is moved into `current` below.
            exhausted: first.is_none(),
            current: first,
            iterator,
            way_index,
            position: 0,
        }
    }

    /// Replaces the head with the next element from the source.
    ///
    /// Does nothing once the way is exhausted. This implementation always
    /// returns `Ok(())`.
    fn advance(&mut self) -> Result<()> {
        if !self.exhausted {
            self.current = self.iterator.next();
            self.position += 1;
            // A `None` from the source permanently marks the way exhausted.
            self.exhausted = self.current.is_none();
        }
        Ok(())
    }

    /// Borrows the current head element, if any.
    fn peek(&self) -> Option<&T> {
        self.current.as_ref()
    }

    /// Reports whether the source has run out of elements.
    fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}
/// K-way merge over boxed iterators, ordered by a comparator of type `F`
/// (a plain `fn(&T, &T) -> Ordering` by default).
pub struct EnhancedLoserTree<T, F = fn(&T, &T) -> Ordering> {
    /// Internal tournament nodes, sized during initialisation.
    tree: Vec<CacheAlignedNode>,
    /// The registered input ways, in insertion order.
    ways: Vec<WayIterator<Box<dyn Iterator<Item = T>>, T>>,
    /// Index of the way whose head is handed out next.
    winner: usize,
    /// Count of registered ways.
    num_ways: usize,
    /// Ordering used to compare the heads of two ways.
    comparator: F,
    /// Configuration supplied at construction time.
    config: LoserTreeConfig,
    /// Secure memory pool, present only when it was requested and could be
    /// created.
    memory_pool: Option<Arc<SecureMemoryPool>>,
    /// Marker tying the struct to its element type `T`.
    _phantom: PhantomData<T>,
}
impl<T> EnhancedLoserTree<T, fn(&T, &T) -> Ordering>
where
    T: Ord + Clone,
{
    /// Creates a tree that orders elements by their natural `Ord` ordering.
    pub fn new(config: LoserTreeConfig) -> Self {
        let natural_order: fn(&T, &T) -> Ordering = |lhs, rhs| lhs.cmp(rhs);
        Self::with_comparator(config, natural_order)
    }
}
impl<T, F> EnhancedLoserTree<T, F>
where
T: Clone,
F: Fn(&T, &T) -> Ordering,
{
/// Creates an empty tree that orders elements with `comparator`.
///
/// When `config.use_secure_memory` is set, a `SecureMemoryPool` with the
/// `medium_secure` configuration is requested. A failure to create it is
/// ignored and the tree is built without a pool.
pub fn with_comparator(config: LoserTreeConfig, comparator: F) -> Self {
    let memory_pool = if config.use_secure_memory {
        // Best effort: a pool creation error simply leaves the pool unset.
        SecureMemoryPool::new(crate::memory::SecurePoolConfig::medium_secure()).ok()
    } else {
        None
    };
    let tree = Vec::with_capacity(config.initial_capacity);

    Self {
        tree,
        ways: Vec::new(),
        winner: 0,
        num_ways: 0,
        comparator,
        config,
        memory_pool,
        _phantom: PhantomData,
    }
}
pub fn add_way<I>(&mut self, iterator: I) -> Result<()>
where
I: Iterator<Item = T> + 'static,
{
let way_index = self.ways.len();
let boxed_iter: Box<dyn Iterator<Item = T>> = Box::new(iterator);
let way_iter = WayIterator::new(boxed_iter, way_index);
self.ways.push(way_iter);
self.num_ways += 1;
Ok(())
}
/// Prepares the tree for merging: sizes the node vector to one node fewer
/// than the number of ways and determines the initial winner.
///
/// # Errors
///
/// Returns an invalid-data error when no ways have been added, and
/// propagates any error from building the tree.
pub fn initialize(&mut self) -> Result<()> {
    if self.ways.is_empty() {
        return Err(ZiporaError::invalid_data("No input ways provided"));
    }
    self.num_ways = self.ways.len();

    // k ways need k - 1 internal nodes; a single way needs none. The node
    // storage is the same regardless of `cache_optimized`, so no branch on
    // that flag is needed here.
    let tree_size = self.num_ways.saturating_sub(1);
    self.tree.resize(tree_size, CacheAlignedNode::new(0, 0));

    self.build_enhanced_tree()
}
/// Picks the initial winner and fills in the tree nodes.
///
/// With zero or one way there is no tournament to play: the winner is way 0
/// and the node vector is left untouched.
fn build_enhanced_tree(&mut self) -> Result<()> {
    if self.num_ways > 1 {
        self.winner = self.find_initial_winner()?;
        self.build_tree_structure()
    } else {
        self.winner = 0;
        Ok(())
    }
}
/// Scans every way and returns the index of the one with the smallest head.
///
/// Ways without a head are skipped. On ties the earliest way wins, because a
/// later head only replaces the leader when it compares strictly less.
/// Returns 0 when no way has a head.
fn find_initial_winner(&self) -> Result<usize> {
    let heads = self
        .ways
        .iter()
        .enumerate()
        .filter_map(|(idx, way)| way.peek().map(|head| (idx, head)));

    let mut leader: Option<(usize, &T)> = None;
    for (idx, head) in heads {
        let replaces_leader = match leader {
            None => true,
            Some((_, best)) => self.compare_optimized(head, best) == Ordering::Less,
        };
        if replaces_leader {
            leader = Some((idx, head));
        }
    }

    Ok(leader.map_or(0, |(idx, _)| idx))
}
/// Compares two elements with the configured comparator.
///
/// No SIMD-specific comparison path exists, so the comparator is called
/// directly without consulting `config.use_simd`; the result is the same
/// whatever that flag is set to.
#[inline]
fn compare_optimized(&self, a: &T, b: &T) -> Ordering {
    (self.comparator)(a, b)
}
/// Plays one match per internal node, in index order, and records the loser
/// (with that way's current position) in the node.
///
/// Children of node `i` sit at `2i + 1` and `2i + 2`. A child index below
/// the node count is an internal node; a child index at or beyond it maps to
/// way `child - node_count`.
// NOTE(review): internal children are resolved through `get_subtree_winner`,
// and nodes with one internal and one leaf child are skipped and keep their
// previous contents — confirm this is intended.
fn build_tree_structure(&mut self) -> Result<()> {
    let way_count = self.num_ways;
    if way_count <= 1 {
        return Ok(());
    }

    let node_count = self.tree.len();
    let wants_prefetch = self.config.cache_optimized && self.config.prefetch_distance > 0;

    for node in 0..node_count {
        let left = 2 * node + 1;
        let right = 2 * node + 2;

        let pairing = if left < node_count && right < node_count {
            // Both children are internal nodes.
            Some((self.get_subtree_winner(left), self.get_subtree_winner(right)))
        } else if left >= node_count {
            // Both children are leaves; translate them to way indices.
            let left_way = left - node_count;
            let right_way = right - node_count;
            if left_way < way_count && right_way < way_count {
                Some((left_way, right_way))
            } else {
                None
            }
        } else {
            // Mixed internal/leaf children: no match is played here.
            None
        };

        let (first, second) = match pairing {
            Some(pair) => pair,
            None => continue,
        };

        let (_winner, loser) = self.compare_competitors(first, second)?;
        if wants_prefetch {
            self.prefetch_next_nodes(node);
        }
        self.tree[node] = CacheAlignedNode::new(loser, self.ways[loser].position);
    }
    Ok(())
}
/// Returns the way treated as the winner of the subtree rooted at `node_idx`.
///
/// The node index is not consulted: the overall winner is returned for every
/// node.
// NOTE(review): looks like a placeholder for a real per-subtree lookup — confirm.
fn get_subtree_winner(&self, node_idx: usize) -> usize {
    // Intentionally unused for now; see the note above.
    let _ = node_idx;
    self.winner
}
/// Plays a match between two ways and returns `(winner, loser)`.
///
/// A way with a head beats a way without one. When neither has a head, or
/// the heads compare equal and `stable_sort` is off, `way1` wins. With
/// `stable_sort` on, equal heads go to the way with the smaller `position`
/// (`way1` on a further tie).
///
/// # Errors
///
/// Returns an out-of-bounds error when either index does not name a way.
// NOTE(review): the stable tie-break compares per-way positions rather than
// way indices — confirm this yields the intended stability.
fn compare_competitors(&self, way1: usize, way2: usize) -> Result<(usize, usize)> {
    let total = self.ways.len();
    let first = self
        .ways
        .get(way1)
        .ok_or_else(|| ZiporaError::out_of_bounds(way1, total))?;
    let second = self
        .ways
        .get(way2)
        .ok_or_else(|| ZiporaError::out_of_bounds(way2, total))?;

    let first_wins = match (first.peek(), second.peek()) {
        (Some(lhs), Some(rhs)) => match self.compare_optimized(lhs, rhs) {
            Ordering::Less => true,
            Ordering::Greater => false,
            Ordering::Equal => !self.config.stable_sort || first.position <= second.position,
        },
        (None, Some(_)) => false,
        (Some(_), None) | (None, None) => true,
    };

    Ok(if first_wins { (way1, way2) } else { (way2, way1) })
}
/// Issues a T0 prefetch hint for the node `prefetch_distance` entries after
/// `current_level`, if such a node exists.
///
/// Does nothing when the distance is 0, when the target index would
/// overflow `usize`, or when it lies past the end of the tree.
#[cfg(target_arch = "x86_64")]
fn prefetch_next_nodes(&self, current_level: usize) {
    let distance = self.config.prefetch_distance;
    if distance == 0 {
        return;
    }

    // `checked_add` keeps an oversized configured distance from overflowing
    // the index (which would panic in debug builds).
    let target = match current_level.checked_add(distance) {
        Some(index) => index,
        None => return,
    };

    if let Some(node) = self.tree.get(target) {
        let node_ptr = node as *const CacheAlignedNode;
        // SAFETY: `node_ptr` is derived from a live reference into
        // `self.tree`, so it points at valid, initialised memory for the
        // duration of the call; the prefetch is only a cache hint.
        unsafe {
            _mm_prefetch(node_ptr as *const i8, _MM_HINT_T0);
        }
    }
}
/// Variant compiled on targets other than x86_64: the body is empty, so no
/// prefetch hint is issued.
#[cfg(not(target_arch = "x86_64"))]
fn prefetch_next_nodes(&self, _current_level: usize) {
}
/// Recomputes the winner by rescanning the heads of all ways.
fn update_winner(&mut self) -> Result<()> {
    let next_winner = self.find_initial_winner()?;
    self.winner = next_winner;
    Ok(())
}
/// Replays the match at every node, in index order, storing each loser (with
/// that way's current position) in the node.
///
/// For node `i` the candidates `2i + 1` and `2i + 2` are used directly as
/// way indices when both are below the way count. Otherwise, when both are
/// below the node count, they are resolved through `get_node_winner`. Nodes
/// matching neither case are left unchanged.
// NOTE(review): nothing in this file calls this method, and the second case
// appears unreachable while the node count stays below the way count — confirm.
fn replay_matches(&mut self) -> Result<()> {
    let way_count = self.ways.len();
    if way_count <= 1 {
        return Ok(());
    }

    let node_count = self.tree.len();
    for node in 0..node_count {
        let (left, right) = (2 * node + 1, 2 * node + 2);

        let contenders = if left < way_count && right < way_count {
            Some((left, right))
        } else if left < node_count && right < node_count {
            Some((self.get_node_winner(left), self.get_node_winner(right)))
        } else {
            None
        };

        if let Some((first, second)) = contenders {
            let (_winner, loser) = self.compare_ways(first, second)?;
            self.tree[node] = CacheAlignedNode::new(loser, self.ways[loser].position);
        }
    }
    Ok(())
}
fn compare_ways(&self, way1: usize, way2: usize) -> Result<(usize, usize)> {
let value1 = self.ways.get(way1)
.ok_or_else(|| ZiporaError::out_of_bounds(way1, self.ways.len()))?
.peek();
let value2 = self.ways.get(way2)
.ok_or_else(|| ZiporaError::out_of_bounds(way2, self.ways.len()))?
.peek();
match (value1, value2) {
(Some(v1), Some(v2)) => {
match (self.comparator)(v1, v2) {
Ordering::Less => Ok((way1, way2)),
Ordering::Greater => Ok((way2, way1)),
Ordering::Equal => {
if self.config.stable_sort {
if self.ways[way1].position <= self.ways[way2].position {
Ok((way1, way2))
} else {
Ok((way2, way1))
}
} else {
Ok((way1, way2))
}
}
}
}
(Some(_), None) => Ok((way1, way2)),
(None, Some(_)) => Ok((way2, way1)),
(None, None) => Ok((way1, way2)), }
}
/// Resolves a node index to a way index: indices inside the tree map to the
/// current overall winner, anything else is returned unchanged.
fn get_node_winner(&self, node_index: usize) -> usize {
    let is_internal = node_index < self.tree.len();
    if is_internal {
        self.winner
    } else {
        node_index
    }
}
/// Borrows the head of the current winning way without consuming it.
///
/// Returns `None` when there are no ways, when the winner index is out of
/// range, or when the winning way has no head.
pub fn peek(&self) -> Option<&T> {
    let leading_way = self.ways.get(self.winner)?;
    leading_way.peek()
}
/// Removes and returns the head of the current winning way, then refills
/// that way and re-selects the winner.
///
/// Returns `Ok(None)` when there are no ways, when the winner index is out
/// of range, or when the winning way has no head left.
///
/// # Errors
///
/// Propagates errors from advancing the way or re-selecting the winner.
pub fn pop(&mut self) -> Result<Option<T>> {
    let head = match self.ways.get_mut(self.winner) {
        Some(way) => {
            // Move the head out instead of cloning it: `advance` overwrites
            // `current` with the next element straight away.
            let head = way.current.take();
            if head.is_some() {
                way.advance()?;
            }
            head
        }
        None => return Ok(None),
    };

    if head.is_some() {
        self.update_winner()?;
    }
    Ok(head)
}
/// Returns `true` when no way has an element left (vacuously true when
/// there are no ways at all).
#[inline]
pub fn is_empty(&self) -> bool {
    !self.ways.iter().any(|way| !way.is_exhausted())
}
/// Initialises the tree and drains every way, in merged order, into `output`.
///
/// Elements are gathered in a temporary vector first and handed to `output`
/// in a single `extend` call, so `output` is left untouched if an error
/// occurs part-way through.
///
/// # Errors
///
/// Returns the error from `initialize` (for example when no ways were
/// added) and propagates any error from `pop`.
pub fn merge_all<O>(&mut self, output: &mut O) -> Result<()>
where
    O: Extend<T>,
{
    self.initialize()?;

    let mut merged = Vec::new();
    // `pop` yields `None` only once the winning way has no head, which after
    // initialisation means every way is exhausted. A separate `is_empty`
    // scan before each pop is therefore unnecessary.
    while let Some(value) = self.pop()? {
        merged.push(value);
    }

    output.extend(merged);
    Ok(())
}
/// Runs the full merge and returns the elements as a new vector.
///
/// # Errors
///
/// Propagates any error from `merge_all`.
pub fn merge_to_vec(&mut self) -> Result<Vec<T>> {
    let mut merged = Vec::new();
    self.merge_all(&mut merged).map(|()| merged)
}
/// Returns the number of input ways currently registered (the length of the
/// ways vector).
pub fn num_ways(&self) -> usize {
self.ways.len()
}
/// Borrows the configuration this tree was constructed with.
pub fn config(&self) -> &LoserTreeConfig {
&self.config
}
}
impl<T, F> Iterator for EnhancedLoserTree<T, F>
where
    T: Clone,
    F: Fn(&T, &T) -> Ordering,
{
    type Item = T;

    /// Yields the next merged element by calling `pop`.
    ///
    /// An error from `pop` ends the iteration in the same way as running out
    /// of elements: `None` is returned.
    fn next(&mut self) -> Option<Self::Item> {
        match self.pop() {
            Ok(item) => item,
            Err(_) => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Merges `inputs` (one way per vector, added in order) through a
    /// default-configured `i32` tree and returns the merged output.
    fn merge_default(inputs: Vec<Vec<i32>>) -> Result<Vec<i32>> {
        let mut tree = EnhancedLoserTree::<i32>::new(LoserTreeConfig::default());
        for input in inputs {
            tree.add_way(input.into_iter())?;
        }
        tree.merge_to_vec()
    }

    /// Builds and initialises a default tree over the ways `[1, 3]` and `[2, 4]`.
    fn initialized_two_way_tree() -> Result<EnhancedLoserTree<i32>> {
        let mut tree = EnhancedLoserTree::<i32>::new(LoserTreeConfig::default());
        tree.add_way(vec![1, 3].into_iter())?;
        tree.add_way(vec![2, 4].into_iter())?;
        tree.initialize()?;
        Ok(tree)
    }

    #[test]
    fn test_tournament_node() {
        let TournamentNode {
            loser_way,
            sequence_index,
        } = TournamentNode::new(5, 10);
        assert_eq!(loser_way, 5);
        assert_eq!(sequence_index, 10);
    }

    #[test]
    fn test_loser_tree_config_default() {
        let defaults = LoserTreeConfig::default();
        assert_eq!(defaults.initial_capacity, 64);
        assert!(defaults.use_secure_memory);
        assert!(defaults.stable_sort);
        assert!(defaults.cache_optimized);
    }

    #[test]
    fn test_empty_tree() {
        let mut tree = EnhancedLoserTree::<i32>::new(LoserTreeConfig::default());
        assert!(tree.is_empty());
        assert_eq!(tree.num_ways(), 0);
        assert!(tree.peek().is_none());
        assert!(tree.pop().unwrap().is_none());
    }

    #[test]
    fn test_single_way() -> Result<()> {
        let merged = merge_default(vec![vec![1, 2, 3]])?;
        assert_eq!(merged, vec![1, 2, 3]);
        Ok(())
    }

    #[test]
    fn test_two_way_merge() -> Result<()> {
        let merged = merge_default(vec![vec![1, 3, 5], vec![2, 4, 6]])?;
        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6]);
        Ok(())
    }

    #[test]
    fn test_three_way_merge() -> Result<()> {
        let merged = merge_default(vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]])?;
        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        Ok(())
    }

    #[test]
    fn test_uneven_lengths() -> Result<()> {
        let merged = merge_default(vec![vec![1], vec![2, 3, 4, 5], vec![6, 7]])?;
        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7]);
        Ok(())
    }

    #[test]
    fn test_empty_ways() -> Result<()> {
        // Kept explicit so the middle way is a genuine `std::iter::empty()`.
        let mut tree = EnhancedLoserTree::<i32>::new(LoserTreeConfig::default());
        tree.add_way(vec![1, 2].into_iter())?;
        tree.add_way(std::iter::empty())?;
        tree.add_way(vec![3, 4].into_iter())?;
        let merged = tree.merge_to_vec()?;
        assert_eq!(merged, vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn test_duplicate_values() -> Result<()> {
        let merged = merge_default(vec![vec![1, 2, 2, 3], vec![2, 2, 4]])?;
        assert_eq!(merged, vec![1, 2, 2, 2, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn test_custom_comparator() -> Result<()> {
        // Reversed comparator: both inputs are descending, and so is the output.
        let descending = |a: &i32, b: &i32| b.cmp(a);
        let mut tree = EnhancedLoserTree::with_comparator(LoserTreeConfig::default(), descending);
        tree.add_way(vec![5, 3, 1].into_iter())?;
        tree.add_way(vec![6, 4, 2].into_iter())?;
        let merged = tree.merge_to_vec()?;
        assert_eq!(merged, vec![6, 5, 4, 3, 2, 1]);
        Ok(())
    }

    #[test]
    fn test_iterator_interface() -> Result<()> {
        let tree = initialized_two_way_tree()?;
        let collected: Vec<_> = tree.collect();
        assert_eq!(collected, vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn test_peek_before_pop() -> Result<()> {
        let mut tree = initialized_two_way_tree()?;
        for expected in [1, 2] {
            assert_eq!(tree.peek(), Some(&expected));
            assert_eq!(tree.pop()?, Some(expected));
        }
        Ok(())
    }

    #[test]
    fn test_large_merge() -> Result<()> {
        // Ten ways of one hundred ascending values each.
        let inputs: Vec<Vec<i32>> = (0..10)
            .map(|way| (0..100).map(|i| way * 100 + i).collect())
            .collect();
        let merged = merge_default(inputs)?;
        assert_eq!(merged.len(), 1000);
        assert!(merged.windows(2).all(|pair| pair[1] >= pair[0]));
        Ok(())
    }
}