use super::node::{Node, NodeMetadata};
use crate::error::{Error, Result};
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Tuning parameters consumed while building a [`RaptorTree`].
#[derive(Debug, Clone)]
pub struct TreeConfig {
    /// Upper bound on the number of summary levels built above the leaves.
    pub max_depth: usize,
    /// Value forwarded to the clustering callback as its second argument.
    pub fanout: usize,
    /// Building stops once a level holds this many nodes or fewer.
    pub min_cluster_size: usize,
}

impl Default for TreeConfig {
    /// Depth 4, fanout 6, minimum cluster size 2.
    fn default() -> Self {
        TreeConfig {
            max_depth: 4,
            fanout: 6,
            min_cluster_size: 2,
        }
    }
}

impl TreeConfig {
    /// Same as [`TreeConfig::default`]; the starting point for the `with_*` builders.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `max_depth` replaced by `depth`.
    pub fn with_max_depth(self, depth: usize) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Returns the configuration with `fanout` replaced.
    pub fn with_fanout(self, fanout: usize) -> Self {
        Self { fanout, ..self }
    }

    /// Returns the configuration with `min_cluster_size` replaced by `size`.
    pub fn with_min_cluster_size(self, size: usize) -> Self {
        Self {
            min_cluster_size: size,
            ..self
        }
    }
}
/// A tree of [`Node`]s kept in one flat vector and grouped into levels.
///
/// `T` is the content type given to `Node::leaf`; `S` is the summary type given to
/// `Node::internal` and defaults to `T`.
#[derive(Debug, Clone)]
pub struct RaptorTree<T, S = T> {
    /// Every node of the tree; lookups index into this vector by node id.
    nodes: Vec<Node<T, S>>,
    /// Node ids per level; index 0 is the leaf level.
    levels: Vec<Vec<usize>>,
    // Stored by `new` but not read by any method visible in this file, hence the
    // lint allowance.
    #[allow(dead_code)]
    config: TreeConfig,
}
impl<T, S> RaptorTree<T, S> {
    /// Creates a tree with no nodes and no levels that stores `config`.
    pub fn new(config: TreeConfig) -> Self {
        Self {
            nodes: Vec::new(),
            levels: Vec::new(),
            config,
        }
    }

    /// Builds a tree bottom-up from `items`.
    ///
    /// Each item becomes a leaf whose id is its position in `items`. Then, for at most
    /// `config.max_depth` rounds, `cluster_fn` is called with the ids of the most recent
    /// level and `config.fanout`, and every cluster that resolves to at least one leaf
    /// becomes one internal node on a new level. Rounds stop early when the current
    /// level holds `config.min_cluster_size` nodes or fewer, or when a round produces
    /// no node.
    ///
    /// `summarize_fn` receives the leaf contents beneath the cluster: the leaf itself
    /// for a leaf id, and every leaf recorded under it for an internal id. Tracking that
    /// coverage is what allows levels above the first to be built, since
    /// `get_leaf_content` returns `Option<&T>` and so presumably yields nothing for
    /// internal nodes. A leaf reachable through several cluster members is passed once
    /// per path; no de-duplication is performed.
    ///
    /// Cluster entries that do not name a node that existed when the round started are
    /// dropped, so stored child lists never reference missing nodes. `leaf_count` in the
    /// node metadata is the number of leaf contents handed to `summarize_fn`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::EmptyInput`] when `items` is empty.
    pub fn build<F, G>(
        items: Vec<T>,
        config: TreeConfig,
        cluster_fn: F,
        summarize_fn: G,
    ) -> Result<Self>
    where
        F: Fn(&[usize], usize) -> Vec<Vec<usize>>,
        G: Fn(&[&T]) -> S,
        T: Clone,
        S: Clone,
    {
        if items.is_empty() {
            return Err(Error::EmptyInput);
        }

        let mut tree = Self::new(config.clone());
        tree.nodes.extend(
            items
                .into_iter()
                .enumerate()
                .map(|(id, item)| Node::leaf(id, item)),
        );
        let leaf_ids: Vec<usize> = (0..tree.nodes.len()).collect();

        // `coverage[id]` lists the leaf ids beneath node `id`; a leaf covers itself.
        // Kept in lockstep with `tree.nodes` so both are indexed by node id.
        let mut coverage: Vec<Vec<usize>> = leaf_ids.iter().map(|&id| Vec::from([id])).collect();

        tree.levels.push(leaf_ids.clone());
        let mut current_ids = leaf_ids;

        for level in 1..=config.max_depth {
            if current_ids.len() <= config.min_cluster_size {
                break;
            }

            // Nodes created during this round must not be accepted as children of
            // their own siblings, so remember where the new ids start.
            let first_new_id = tree.nodes.len();
            let mut level_ids = Vec::new();

            for cluster in cluster_fn(&current_ids, config.fanout) {
                let mut children = Vec::with_capacity(cluster.len());
                let mut covered: Vec<usize> = Vec::new();
                for id in cluster {
                    match coverage.get(id) {
                        Some(leaves) if id < first_new_id => {
                            children.push(id);
                            covered.extend_from_slice(leaves);
                        }
                        _ => {}
                    }
                }

                let contents: Vec<&T> = covered
                    .iter()
                    .filter_map(|&leaf| tree.get_leaf_content(leaf))
                    .collect();
                if contents.is_empty() {
                    continue;
                }

                let summary = summarize_fn(&contents);
                let metadata = NodeMetadata {
                    leaf_count: contents.len(),
                    cluster_method: Some("provided".into()),
                    summary_method: Some("provided".into()),
                };

                // The next free id is always the current node count.
                let id = tree.nodes.len();
                let node = Node::internal(id, summary, level, children).with_metadata(metadata);
                tree.nodes.push(node);
                coverage.push(covered);
                level_ids.push(id);
            }

            if level_ids.is_empty() {
                break;
            }
            tree.levels.push(level_ids.clone());
            current_ids = level_ids;
        }

        Ok(tree)
    }

    /// Returns whatever `as_leaf` yields for node `id`, or `None` when `id` is out of
    /// range.
    fn get_leaf_content(&self, id: usize) -> Option<&T> {
        self.nodes.get(id).and_then(|n| n.as_leaf())
    }

    /// Returns the node stored at index `id`, or `None` when `id` is out of range.
    pub fn get_node(&self, id: usize) -> Option<&Node<T, S>> {
        self.nodes.get(id)
    }

    /// Returns the nodes recorded for `level` (0 is the leaf level), or `None` when no
    /// such level exists.
    pub fn get_level(&self, level: usize) -> Option<Vec<&Node<T, S>>> {
        self.levels
            .get(level)
            .map(|ids| ids.iter().filter_map(|&id| self.nodes.get(id)).collect())
    }

    /// Number of levels, counting the leaf level; 0 for a tree without nodes.
    pub fn depth(&self) -> usize {
        self.levels.len()
    }

    /// Nodes of level 0; empty when the tree has no levels.
    pub fn leaves(&self) -> Vec<&Node<T, S>> {
        self.get_level(0).unwrap_or_default()
    }

    /// Nodes of the last level that was built; empty when the tree has no levels.
    pub fn roots(&self) -> Vec<&Node<T, S>> {
        self.get_level(self.depth().saturating_sub(1))
            .unwrap_or_default()
    }

    /// Total number of nodes across all levels.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when the tree holds no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Iterates over every node in id order.
    pub fn iter(&self) -> impl Iterator<Item = &Node<T, S>> {
        self.nodes.iter()
    }

    /// All nodes of every level flattened into one vector, in id order.
    pub fn collapsed(&self) -> Vec<&Node<T, S>> {
        self.nodes.iter().collect()
    }

    /// Nodes of `level`, with `level` clamped to the last level that exists; empty
    /// when the tree has no levels.
    pub fn view_at_level(&self, level: usize) -> Vec<&Node<T, S>> {
        let target = level.min(self.depth().saturating_sub(1));
        self.get_level(target).unwrap_or_default()
    }
}
impl<T, S> Default for RaptorTree<T, S> {
fn default() -> Self {
Self::new(TreeConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a depth of 4 and a fanout of 6.
    #[test]
    fn test_tree_config_default() {
        let TreeConfig {
            max_depth, fanout, ..
        } = TreeConfig::default();
        assert_eq!(max_depth, 4);
        assert_eq!(fanout, 6);
    }

    /// A default-constructed tree has no nodes and no levels.
    #[test]
    fn test_empty_tree() {
        let tree = RaptorTree::<String>::default();
        assert_eq!(tree.depth(), 0);
        assert!(tree.is_empty());
    }
}