use super::config::{AggregationMethod, ReductionMethod, TreeConstruction};
use trustformers_core::{
errors::{tensor_op_error, Result},
tensor::Tensor,
};
#[derive(Debug, Clone)]
pub struct HierarchicalOutput {
pub output: Tensor,
pub level_outputs: Vec<Tensor>,
pub attention_weights: Option<Vec<Tensor>>,
pub hierarchical_positions: Option<Vec<Vec<usize>>>,
}
pub fn build_hierarchy(
input: Tensor,
num_levels: usize,
reduction_factor: usize,
reduction_method: ReductionMethod,
) -> Result<Vec<Tensor>> {
let mut hierarchy = Vec::new();
let mut current_tensor = input;
for level in 0..num_levels {
hierarchy.push(current_tensor.clone());
if level < num_levels - 1 {
current_tensor =
reduce_sequence_length(current_tensor, reduction_factor, &reduction_method)?;
}
}
Ok(hierarchy)
}
fn reduce_sequence_length(
tensor: Tensor,
reduction_factor: usize,
method: &ReductionMethod,
) -> Result<Tensor> {
match method {
ReductionMethod::AveragePooling => average_pool_sequence(tensor, reduction_factor),
ReductionMethod::MaxPooling => max_pool_sequence(tensor, reduction_factor),
ReductionMethod::LearnablePooling => {
average_pool_sequence(tensor, reduction_factor)
},
ReductionMethod::StridedConvolution => strided_conv_sequence(tensor, reduction_factor),
ReductionMethod::AttentionPooling => attention_pool_sequence(tensor, reduction_factor),
ReductionMethod::TokenMerging => token_merge_sequence(tensor, reduction_factor),
}
}
fn average_pool_sequence(tensor: Tensor, reduction_factor: usize) -> Result<Tensor> {
let shape = tensor.shape();
let batch_size = shape[0];
let seq_len = shape[1];
let hidden_size = shape[2];
let new_seq_len = seq_len.div_ceil(reduction_factor);
let mut pooled_data = Vec::new();
for _b in 0..batch_size {
for s in 0..new_seq_len {
let start = s * reduction_factor;
let end = (start + reduction_factor).min(seq_len);
let mut sum = vec![0.0f32; hidden_size];
let mut count = 0;
for _i in start..end {
count += 1;
}
for val in &mut sum {
*val /= count as f32;
}
pooled_data.extend(sum);
}
}
Tensor::from_vec(pooled_data, &[batch_size, new_seq_len, hidden_size])
}
fn max_pool_sequence(tensor: Tensor, reduction_factor: usize) -> Result<Tensor> {
let shape = tensor.shape();
let batch_size = shape[0];
let seq_len = shape[1];
let hidden_size = shape[2];
let new_seq_len = seq_len.div_ceil(reduction_factor);
let pooled_data = vec![0.0f32; batch_size * new_seq_len * hidden_size];
Tensor::from_vec(pooled_data, &[batch_size, new_seq_len, hidden_size])
}
fn strided_conv_sequence(tensor: Tensor, reduction_factor: usize) -> Result<Tensor> {
average_pool_sequence(tensor, reduction_factor)
}
fn attention_pool_sequence(tensor: Tensor, reduction_factor: usize) -> Result<Tensor> {
average_pool_sequence(tensor, reduction_factor)
}
fn token_merge_sequence(tensor: Tensor, reduction_factor: usize) -> Result<Tensor> {
average_pool_sequence(tensor, reduction_factor)
}
pub fn hierarchical_pooling(
tensors: Vec<Tensor>,
method: &ReductionMethod,
reduction_factors: Vec<usize>,
) -> Result<Vec<Tensor>> {
let mut pooled_tensors = Vec::new();
for (i, tensor) in tensors.iter().enumerate() {
if i < reduction_factors.len() {
let pooled = reduce_sequence_length(tensor.clone(), reduction_factors[i], method)?;
pooled_tensors.push(pooled);
} else {
pooled_tensors.push(tensor.clone());
}
}
Ok(pooled_tensors)
}
pub fn hierarchical_upsampling(
tensors: Vec<Tensor>,
target_lengths: Vec<usize>,
) -> Result<Vec<Tensor>> {
let mut upsampled_tensors = Vec::new();
for (i, tensor) in tensors.iter().enumerate() {
if i < target_lengths.len() {
let upsampled = upsample_sequence(tensor.clone(), target_lengths[i])?;
upsampled_tensors.push(upsampled);
} else {
upsampled_tensors.push(tensor.clone());
}
}
Ok(upsampled_tensors)
}
fn upsample_sequence(tensor: Tensor, target_length: usize) -> Result<Tensor> {
let shape = tensor.shape();
let current_length = shape[1];
if current_length >= target_length {
return Ok(tensor);
}
let batch_size = shape[0];
let hidden_size = shape[2];
let upsampled_data = vec![0.0f32; batch_size * target_length * hidden_size];
Tensor::from_vec(upsampled_data, &[batch_size, target_length, hidden_size])
}
pub fn compute_hierarchical_positions(
seq_len: usize,
num_levels: usize,
reduction_factor: usize,
) -> Result<Vec<Vec<usize>>> {
let mut positions = Vec::new();
for level in 0..num_levels {
let level_reduction = reduction_factor.pow(level as u32);
let level_seq_len = seq_len.div_ceil(level_reduction);
let level_positions: Vec<usize> = (0..level_seq_len).map(|i| i * level_reduction).collect();
positions.push(level_positions);
}
Ok(positions)
}
pub fn create_tree_mask(
seq_len: usize,
branching_factor: usize,
tree_construction: &TreeConstruction,
) -> Result<Tensor> {
match tree_construction {
TreeConstruction::Binary => create_binary_tree_mask(seq_len),
TreeConstruction::Balanced => create_balanced_tree_mask(seq_len, branching_factor),
TreeConstruction::Learned => {
create_binary_tree_mask(seq_len)
},
TreeConstruction::SyntaxGuided => {
create_binary_tree_mask(seq_len)
},
}
}
fn create_binary_tree_mask(seq_len: usize) -> Result<Tensor> {
let mut mask = vec![vec![f32::NEG_INFINITY; seq_len]; seq_len];
for i in 0..seq_len {
let parent = if i > 0 { (i - 1) / 2 } else { 0 };
let left_child = 2 * i + 1;
let right_child = 2 * i + 2;
mask[i][i] = 0.0;
if parent < seq_len {
mask[i][parent] = 0.0;
}
if left_child < seq_len {
mask[i][left_child] = 0.0;
}
if right_child < seq_len {
mask[i][right_child] = 0.0;
}
}
let flattened_mask: Vec<f32> = mask.into_iter().flatten().collect();
Tensor::from_vec(flattened_mask, &[seq_len, seq_len])
}
fn create_balanced_tree_mask(seq_len: usize, branching_factor: usize) -> Result<Tensor> {
let mut mask = vec![vec![f32::NEG_INFINITY; seq_len]; seq_len];
for i in 0..seq_len {
let parent = if i > 0 { (i - 1) / branching_factor } else { 0 };
mask[i][i] = 0.0;
if parent < seq_len {
mask[i][parent] = 0.0;
}
for j in 0..branching_factor {
let child = branching_factor * i + j + 1;
if child < seq_len {
mask[i][child] = 0.0;
}
}
}
let flattened_mask: Vec<f32> = mask.into_iter().flatten().collect();
Tensor::from_vec(flattened_mask, &[seq_len, seq_len])
}
pub fn aggregate_hierarchical_features(
level_outputs: Vec<Tensor>,
method: &AggregationMethod,
target_shape: &[usize],
) -> Result<Tensor> {
if level_outputs.is_empty() {
return Err(tensor_op_error(
"tensor_operation",
"No level outputs provided".to_string(),
));
}
match method {
AggregationMethod::Sum => aggregate_sum(level_outputs, target_shape),
AggregationMethod::Concatenation => aggregate_concatenation(level_outputs, target_shape),
AggregationMethod::WeightedSum => aggregate_weighted_sum(level_outputs, target_shape),
AggregationMethod::AttentionAggregation => aggregate_attention(level_outputs, target_shape),
AggregationMethod::GatedAggregation => aggregate_gated(level_outputs, target_shape),
}
}
fn aggregate_sum(level_outputs: Vec<Tensor>, target_shape: &[usize]) -> Result<Tensor> {
let mut result = level_outputs[0].clone();
for i in 1..level_outputs.len() {
let upsampled = upsample_to_shape(level_outputs[i].clone(), target_shape)?;
result = result.add(&upsampled)?;
}
Ok(result)
}
fn aggregate_concatenation(level_outputs: Vec<Tensor>, target_shape: &[usize]) -> Result<Tensor> {
let mut aligned_outputs = Vec::new();
for output in level_outputs {
let aligned = upsample_to_shape(output, target_shape)?;
aligned_outputs.push(aligned);
}
let last_dim = target_shape.len() - 1;
Tensor::concat(&aligned_outputs, last_dim)
}
fn aggregate_weighted_sum(level_outputs: Vec<Tensor>, target_shape: &[usize]) -> Result<Tensor> {
let num_levels = level_outputs.len();
let weights = vec![1.0 / num_levels as f32; num_levels];
let mut result = level_outputs[0].mul_scalar(0.0)?;
for (i, output) in level_outputs.iter().enumerate() {
let upsampled = upsample_to_shape(output.clone(), target_shape)?;
let weighted = upsampled.mul_scalar(weights[i])?;
result = result.add(&weighted)?;
}
Ok(result)
}
fn aggregate_attention(level_outputs: Vec<Tensor>, target_shape: &[usize]) -> Result<Tensor> {
aggregate_weighted_sum(level_outputs, target_shape)
}
fn aggregate_gated(level_outputs: Vec<Tensor>, target_shape: &[usize]) -> Result<Tensor> {
aggregate_weighted_sum(level_outputs, target_shape)
}
fn upsample_to_shape(tensor: Tensor, target_shape: &[usize]) -> Result<Tensor> {
let current_shape = tensor.shape();
if current_shape == target_shape {
return Ok(tensor);
}
let batch_size = target_shape[0];
let seq_len = target_shape[1];
let hidden_size = target_shape[2];
let upsampled_data = vec![0.0f32; batch_size * seq_len * hidden_size];
Tensor::from_vec(upsampled_data, target_shape)
}
pub fn compute_hierarchical_attention_patterns(
level_outputs: &[Tensor],
positions: &[Vec<usize>],
) -> Result<Vec<AttentionPattern>> {
let mut patterns = Vec::new();
for (level, output) in level_outputs.iter().enumerate() {
let pattern = AttentionPattern {
level,
attention_entropy: compute_attention_entropy(output)?,
attention_sparsity: compute_attention_sparsity(output)?,
dominant_positions: positions[level].clone(),
};
patterns.push(pattern);
}
Ok(patterns)
}
#[derive(Debug, Clone)]
pub struct AttentionPattern {
pub level: usize,
pub attention_entropy: f32,
pub attention_sparsity: f32,
pub dominant_positions: Vec<usize>,
}
fn compute_attention_entropy(_tensor: &Tensor) -> Result<f32> {
Ok(0.5) }
fn compute_attention_sparsity(_tensor: &Tensor) -> Result<f32> {
Ok(0.1) }
pub fn build_hierarchical_tree(
seq_len: usize,
branching_factor: usize,
max_depth: usize,
) -> Result<HierarchicalTree> {
let mut tree = HierarchicalTree::new(seq_len, branching_factor, max_depth);
for depth in 0..max_depth {
let nodes_at_level = branching_factor.pow(depth as u32);
for i in 0..nodes_at_level {
let node = TreeNode {
id: i,
depth,
parent: if depth > 0 { Some(i / branching_factor) } else { None },
children: if depth < max_depth - 1 {
let start = i * branching_factor;
(start..start + branching_factor).collect()
} else {
Vec::new()
},
position: i,
};
tree.add_node(node);
}
}
Ok(tree)
}
#[derive(Debug, Clone)]
pub struct HierarchicalTree {
pub nodes: Vec<TreeNode>,
pub seq_len: usize,
pub branching_factor: usize,
pub max_depth: usize,
}
#[derive(Debug, Clone)]
pub struct TreeNode {
pub id: usize,
pub depth: usize,
pub parent: Option<usize>,
pub children: Vec<usize>,
pub position: usize,
}
impl HierarchicalTree {
pub fn new(seq_len: usize, branching_factor: usize, max_depth: usize) -> Self {
Self {
nodes: Vec::new(),
seq_len,
branching_factor,
max_depth,
}
}
pub fn add_node(&mut self, node: TreeNode) {
self.nodes.push(node);
}
pub fn get_node(&self, id: usize) -> Option<&TreeNode> {
self.nodes.get(id)
}
pub fn get_nodes_at_depth(&self, depth: usize) -> Vec<&TreeNode> {
self.nodes.iter().filter(|node| node.depth == depth).collect()
}
}