use serde::{Deserialize, Serialize};
use trustformers_core::{errors::invalid_config, traits::Config};
/// Configuration for the `"hierarchical"` model architecture.
///
/// Build one with `Default` or one of the factory functions (`hierarchical`, `pyramid`, `tree`,
/// `nested`) and check it with `validate`. Fields that are not read anywhere in this file are
/// consumed by code outside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalConfig {
/// Base hidden width. `get_hidden_size` returns it unchanged unless the type is `Pyramid` and a
/// scaling factor exists for the requested level. `validate` requires it to be a multiple of
/// `num_heads`.
pub hidden_size: usize,
/// Number of hierarchy levels; `validate` rejects 0.
pub num_levels: usize,
/// Number of attention heads; must be non-zero and divide `hidden_size`.
pub num_heads: usize,
/// Base of the per-level reduction: `get_reduction_factor(level)` is `reduction_factor^level`.
/// `validate` rejects 0.
pub reduction_factor: usize,
/// Layers per level; within this file only `estimate_parameters` reads it.
pub num_layers_per_level: usize,
/// Presumably the feed-forward inner width (TODO confirm); within this file only
/// `estimate_parameters` reads it.
pub intermediate_size: usize,
/// Dropout probability; `validate` rejects values below 0.0 or above 1.0.
pub dropout: f32,
/// Attention dropout probability; `validate` rejects values below 0.0 or above 1.0.
pub attention_dropout: f32,
/// Layer-norm epsilon (default `1e-5`). Not read elsewhere in this file.
pub layer_norm_eps: f32,
/// Hierarchy variant. Within this file only `Pyramid` changes behavior (per-level hidden size).
pub hierarchical_type: HierarchicalType,
/// Reduction strategy selector. Not interpreted in this file.
pub reduction_method: ReductionMethod,
/// Aggregation strategy selector. Not interpreted in this file.
pub aggregation_method: AggregationMethod,
/// Sequence length per level, indexed by level. `validate` accepts either an empty list or
/// exactly `num_levels` entries; `get_seq_length` falls back to `512 / 2^level` for levels
/// without an entry.
pub max_seq_lengths: Vec<usize>,
/// Flag, `true` by default. Not read elsewhere in this file.
pub cross_level_residual: bool,
/// Flag, `true` by default. Not read elsewhere in this file.
pub use_position_embeddings: bool,
/// Tree-specific settings; `None` by default, populated by the `tree` factory.
pub tree_config: Option<TreeConfig>,
/// Pyramid-specific settings; `None` by default, populated by the `pyramid` factory.
pub pyramid_config: Option<PyramidConfig>,
/// Nested-specific settings; `None` by default, populated by the `nested` factory.
pub nested_config: Option<NestedConfig>,
}
/// Selects the hierarchy variant of a `HierarchicalConfig`.
///
/// Within this file only `Pyramid` changes behavior (see
/// `HierarchicalConfig::get_hidden_size`); the variants are otherwise interpreted by code
/// outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HierarchicalType {
/// The `Default` choice; also set by `HierarchicalConfig::hierarchical`.
Hierarchical,
/// Set by `HierarchicalConfig::pyramid`; enables per-level hidden-size scaling.
Pyramid,
/// Set by `HierarchicalConfig::tree`.
Tree,
/// Set by `HierarchicalConfig::nested`.
Nested,
/// No factory in this file produces this variant.
Hybrid,
}
/// Selector stored in `HierarchicalConfig::reduction_method`.
///
/// NOTE(review): judging by the names, this picks how the sequence is shortened between
/// levels; the variants are not interpreted anywhere in this file, so confirm against the
/// model code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReductionMethod {
/// The value used by `HierarchicalConfig::default()`.
AveragePooling,
MaxPooling,
LearnablePooling,
StridedConvolution,
AttentionPooling,
TokenMerging,
}
/// Selector stored in `HierarchicalConfig::aggregation_method`.
///
/// NOTE(review): judging by the names, this picks how per-level outputs are combined; the
/// variants are not interpreted anywhere in this file, so confirm against the model code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
Sum,
Concatenation,
/// The value used by `HierarchicalConfig::default()`.
WeightedSum,
AttentionAggregation,
GatedAggregation,
}
/// Tree-specific settings, stored in `HierarchicalConfig::tree_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeConfig {
/// Branching factor of the tree; `HierarchicalConfig::validate` rejects 0. Defaults to 2.
pub branching_factor: usize,
/// Maximum tree depth; `HierarchicalConfig::validate` rejects 0. The `tree` factory also uses
/// this value as `num_levels`. Defaults to 8.
pub max_depth: usize,
/// Flag, `false` by default. Not read elsewhere in this file.
pub learnable_structure: bool,
/// Tree construction selector; defaults to `TreeConstruction::Binary`.
pub tree_construction: TreeConstruction,
}
/// Selector stored in `TreeConfig::tree_construction`.
///
/// The variants are not interpreted anywhere in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeConstruction {
/// The value used by `TreeConfig::default()`.
Binary,
Balanced,
Learned,
SyntaxGuided,
}
/// Pyramid-specific settings, stored in `HierarchicalConfig::pyramid_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyramidConfig {
/// Per-level multipliers, indexed by level. For the `Pyramid` type,
/// `HierarchicalConfig::get_hidden_size` multiplies `hidden_size` by the entry for the level
/// and truncates to `usize`; levels without an entry keep the base size.
/// `HierarchicalConfig::validate` requires the list to be non-empty and rejects entries that
/// are `<= 0.0`.
pub scaling_factors: Vec<f32>,
/// Flag, `true` by default. Not read elsewhere in this file.
pub skip_connections: bool,
/// Upsampling selector; defaults to `UpsamplingMethod::Linear`.
pub upsampling_method: UpsamplingMethod,
/// Flag, `false` by default. Not read elsewhere in this file.
/// NOTE(review): "fpn" presumably stands for feature pyramid network; confirm.
pub use_fpn: bool,
}
/// Selector stored in `PyramidConfig::upsampling_method`.
///
/// The variants are not interpreted anywhere in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UpsamplingMethod {
/// The value used by `PyramidConfig::default()`.
Linear,
TransposedConvolution,
Learned,
PixelShuffle,
}
/// Nested-specific settings, stored in `HierarchicalConfig::nested_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NestedConfig {
/// Number of nested levels; `HierarchicalConfig::validate` rejects 0. The `nested` factory
/// also uses this value as `num_levels`. Defaults to 3.
pub num_nested_levels: usize,
/// Flag, `false` by default. Not read elsewhere in this file.
pub share_parameters: bool,
/// Information-flow selector; defaults to `InformationFlow::Bidirectional`.
pub information_flow: InformationFlow,
/// Flag, `false` by default. Not read elsewhere in this file.
pub progressive_training: bool,
}
/// Selector stored in `NestedConfig::information_flow`.
///
/// The variants are not interpreted anywhere in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InformationFlow {
BottomUp,
TopDown,
/// The value used by `NestedConfig::default()`.
Bidirectional,
SkipConnections,
}
impl Default for HierarchicalConfig {
fn default() -> Self {
Self {
hidden_size: 768,
num_levels: 4,
num_heads: 12,
reduction_factor: 2,
num_layers_per_level: 3,
intermediate_size: 3072,
dropout: 0.1,
attention_dropout: 0.1,
layer_norm_eps: 1e-5,
hierarchical_type: HierarchicalType::Hierarchical,
reduction_method: ReductionMethod::AveragePooling,
aggregation_method: AggregationMethod::WeightedSum,
max_seq_lengths: vec![512, 256, 128, 64],
cross_level_residual: true,
use_position_embeddings: true,
tree_config: None,
pyramid_config: None,
nested_config: None,
}
}
}
impl Default for TreeConfig {
fn default() -> Self {
Self {
branching_factor: 2,
max_depth: 8,
learnable_structure: false,
tree_construction: TreeConstruction::Binary,
}
}
}
impl Default for PyramidConfig {
fn default() -> Self {
Self {
scaling_factors: vec![1.0, 0.5, 0.25, 0.125],
skip_connections: true,
upsampling_method: UpsamplingMethod::Linear,
use_fpn: false,
}
}
}
impl Default for NestedConfig {
fn default() -> Self {
Self {
num_nested_levels: 3,
share_parameters: false,
information_flow: InformationFlow::Bidirectional,
progressive_training: false,
}
}
}
impl HierarchicalConfig {
    /// Sequence length assumed for level 0 whenever no explicit length is configured.
    const BASE_SEQ_LENGTH: usize = 512;

    /// Returns the fallback sequence length for `level`, i.e. `512 / 2^level`.
    ///
    /// Implemented as a right shift so that levels at or beyond the bit width of `usize` yield
    /// 0 instead of overflowing while computing `2^level`.
    fn default_seq_length(level: usize) -> usize {
        if level < usize::BITS as usize {
            Self::BASE_SEQ_LENGTH >> level
        } else {
            0
        }
    }

    /// Builds the halving schedule `[512, 256, ...]` with exactly `num_levels` entries.
    fn default_seq_lengths(num_levels: usize) -> Vec<usize> {
        (0..num_levels).map(Self::default_seq_length).collect()
    }

    /// Creates a `Hierarchical` configuration with `num_levels` levels.
    ///
    /// `max_seq_lengths` is filled with the halving schedule for `num_levels`; every other
    /// field keeps its `Default` value.
    pub fn hierarchical(hidden_size: usize, num_levels: usize) -> Self {
        Self {
            hidden_size,
            num_levels,
            hierarchical_type: HierarchicalType::Hierarchical,
            max_seq_lengths: Self::default_seq_lengths(num_levels),
            ..Default::default()
        }
    }

    /// Creates a `Pyramid` configuration with `num_levels` levels and the default
    /// `PyramidConfig` attached.
    ///
    /// `max_seq_lengths` is filled with the halving schedule for `num_levels`.
    pub fn pyramid(hidden_size: usize, num_levels: usize) -> Self {
        Self {
            hidden_size,
            num_levels,
            hierarchical_type: HierarchicalType::Pyramid,
            pyramid_config: Some(PyramidConfig::default()),
            max_seq_lengths: Self::default_seq_lengths(num_levels),
            ..Default::default()
        }
    }

    /// Creates a `Tree` configuration whose level count equals `max_depth`.
    ///
    /// The attached `TreeConfig` takes `branching_factor` and `max_depth`; its remaining fields
    /// keep their defaults. `max_seq_lengths` is sized to `max_depth` so that the result is
    /// consistent with the length rule enforced by `validate`.
    pub fn tree(hidden_size: usize, branching_factor: usize, max_depth: usize) -> Self {
        Self {
            hidden_size,
            num_levels: max_depth,
            hierarchical_type: HierarchicalType::Tree,
            tree_config: Some(TreeConfig {
                branching_factor,
                max_depth,
                ..Default::default()
            }),
            max_seq_lengths: Self::default_seq_lengths(max_depth),
            ..Default::default()
        }
    }

    /// Creates a `Nested` configuration whose level count equals `num_nested_levels`.
    ///
    /// `max_seq_lengths` is sized to `num_nested_levels` so that the result is consistent with
    /// the length rule enforced by `validate`.
    pub fn nested(hidden_size: usize, num_nested_levels: usize) -> Self {
        Self {
            hidden_size,
            num_levels: num_nested_levels,
            hierarchical_type: HierarchicalType::Nested,
            nested_config: Some(NestedConfig {
                num_nested_levels,
                ..Default::default()
            }),
            max_seq_lengths: Self::default_seq_lengths(num_nested_levels),
            ..Default::default()
        }
    }

    /// Returns the hidden width used at `level`.
    ///
    /// Only the `Pyramid` type scales the width: `hidden_size` is multiplied by the level's
    /// scaling factor and truncated to `usize`. Every other type, a missing `pyramid_config`,
    /// or a level without a scaling factor yields the unscaled `hidden_size`.
    pub fn get_hidden_size(&self, level: usize) -> usize {
        if !matches!(self.hierarchical_type, HierarchicalType::Pyramid) {
            return self.hidden_size;
        }
        self.pyramid_config
            .as_ref()
            .and_then(|pyramid| pyramid.scaling_factors.get(level))
            .map_or(self.hidden_size, |&factor| {
                (self.hidden_size as f32 * factor) as usize
            })
    }

    /// Returns the sequence length for `level`.
    ///
    /// Uses the configured entry of `max_seq_lengths` when present, otherwise the halving
    /// fallback `512 / 2^level` (0 once the level is deep enough). Never panics, however large
    /// `level` is.
    pub fn get_seq_length(&self, level: usize) -> usize {
        self.max_seq_lengths
            .get(level)
            .copied()
            .unwrap_or_else(|| Self::default_seq_length(level))
    }

    /// Returns the cumulative reduction at `level`, i.e. `reduction_factor^level`.
    ///
    /// The result saturates at `usize::MAX` instead of overflowing.
    pub fn get_reduction_factor(&self, level: usize) -> usize {
        // Clamp rather than truncate: an exponent above `u32::MAX` saturates for any base >= 2
        // and leaves bases 0 and 1 unchanged, so the clamp never alters the result.
        let exponent = level.min(u32::MAX as usize) as u32;
        self.reduction_factor.saturating_pow(exponent)
    }

    /// Checks the configuration and reports the first violated rule.
    ///
    /// # Errors
    ///
    /// Returns a message when any of the following holds:
    /// - `num_levels`, `reduction_factor` or `num_heads` is 0;
    /// - `hidden_size` is not a multiple of `num_heads`;
    /// - `dropout` or `attention_dropout` lies outside `0.0..=1.0` (NaN included);
    /// - `max_seq_lengths` is non-empty and its length differs from `num_levels`;
    /// - an attached `TreeConfig` has a zero `branching_factor` or `max_depth`;
    /// - an attached `PyramidConfig` has no scaling factors, or one that is NaN or `<= 0.0`;
    /// - an attached `NestedConfig` has zero `num_nested_levels`.
    pub fn validate(&self) -> std::result::Result<(), String> {
        if self.num_levels == 0 {
            return Err("num_levels must be greater than 0".to_string());
        }
        if self.reduction_factor == 0 {
            return Err("reduction_factor must be greater than 0".to_string());
        }
        if self.num_heads == 0 {
            return Err("num_heads must be greater than 0".to_string());
        }
        if !self.hidden_size.is_multiple_of(self.num_heads) {
            return Err("hidden_size must be divisible by num_heads".to_string());
        }
        // `contains` is false for NaN, so NaN is rejected together with out-of-range values.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err("dropout must be between 0.0 and 1.0".to_string());
        }
        if !(0.0..=1.0).contains(&self.attention_dropout) {
            return Err("attention_dropout must be between 0.0 and 1.0".to_string());
        }
        // An empty list is allowed: `get_seq_length` then uses the fallback for every level.
        if !self.max_seq_lengths.is_empty() && self.max_seq_lengths.len() != self.num_levels {
            return Err("max_seq_lengths length must match num_levels".to_string());
        }
        if let Some(tree_config) = &self.tree_config {
            if tree_config.branching_factor == 0 {
                return Err("branching_factor must be greater than 0".to_string());
            }
            if tree_config.max_depth == 0 {
                return Err("max_depth must be greater than 0".to_string());
            }
        }
        if let Some(pyramid_config) = &self.pyramid_config {
            if pyramid_config.scaling_factors.is_empty() {
                return Err("scaling_factors cannot be empty".to_string());
            }
            let has_invalid_factor = pyramid_config
                .scaling_factors
                .iter()
                .any(|factor| factor.is_nan() || *factor <= 0.0);
            if has_invalid_factor {
                return Err("scaling_factors must be positive".to_string());
            }
        }
        if let Some(nested_config) = &self.nested_config {
            if nested_config.num_nested_levels == 0 {
                return Err("num_nested_levels must be greater than 0".to_string());
            }
        }
        Ok(())
    }

    /// Returns a rough parameter-count estimate.
    ///
    /// Each level contributes `4*h*h + 2*h*intermediate_size + 2*h` per layer, where `h` is the
    /// level's hidden size, multiplied by `num_layers_per_level`. The levels are summed, so the
    /// estimate grows linearly with the number of levels.
    pub fn estimate_parameters(&self) -> usize {
        (0..self.num_levels)
            .map(|level| {
                let hidden = self.get_hidden_size(level);
                // NOTE(review): presumably the four attention projection matrices; confirm.
                let square_terms = 4 * hidden * hidden;
                // NOTE(review): presumably the two feed-forward matrices; confirm.
                let feed_forward_terms = 2 * hidden * self.intermediate_size;
                let vector_terms = 2 * hidden;
                (square_terms + feed_forward_terms + vector_terms) * self.num_layers_per_level
            })
            .sum()
    }
}
impl Config for HierarchicalConfig {
    /// Validates the configuration through the `Config` trait.
    ///
    /// Delegates to the inherent `HierarchicalConfig::validate`, so generic callers going
    /// through the trait enforce exactly the same rules as direct callers instead of checking
    /// `num_levels` alone.
    ///
    /// # Errors
    ///
    /// Returns the error built by `invalid_config`, carrying the message of the first rule that
    /// the inherent `validate` reports as violated.
    fn validate(&self) -> trustformers_core::errors::Result<()> {
        // A type-qualified path resolves to the inherent `validate` (inherent associated
        // functions take precedence over trait methods), so this delegates and does not recurse.
        Self::validate(self).map_err(|reason| invalid_config("config_field", reason))
    }

    /// Returns the architecture identifier, `"hierarchical"`.
    fn architecture(&self) -> &'static str {
        "hierarchical"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::traits::Config;

    /// Minimal 64-bit linear congruential generator used for deterministic test inputs.
    struct Lcg {
        state: u64,
    }

    impl Lcg {
        /// Creates a generator whose sequence is fully determined by `seed`.
        fn new(seed: u64) -> Self {
            Lcg { state: seed }
        }

        /// Advances the state and returns it.
        fn next(&mut self) -> u64 {
            self.state = self
                .state
                .wrapping_mul(6364136223846793005u64)
                .wrapping_add(1442695040888963407u64);
            self.state
        }

        /// Returns a value in `[0, 1)`.
        ///
        /// Only the top 24 bits are used: that is the `f32` mantissa width, so the integer
        /// converts exactly and the quotient can never round up to 1.0.
        fn next_f32(&mut self) -> f32 {
            (self.next() >> 40) as f32 / (1u64 << 24) as f32
        }
    }

    #[test]
    fn test_default_config_fields() {
        let cfg = HierarchicalConfig::default();
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.num_levels, 4);
        assert_eq!(cfg.num_heads, 12);
        assert_eq!(cfg.reduction_factor, 2);
        assert!(cfg.cross_level_residual);
        assert!(cfg.use_position_embeddings);
    }

    #[test]
    fn test_default_validate_passes() {
        let cfg = HierarchicalConfig::default();
        let result = Config::validate(&cfg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_architecture_name() {
        let cfg = HierarchicalConfig::default();
        assert_eq!(cfg.architecture(), "hierarchical");
    }

    #[test]
    fn test_zero_num_levels_fails_trait_validate() {
        let cfg = HierarchicalConfig {
            num_levels: 0,
            max_seq_lengths: vec![],
            ..HierarchicalConfig::default()
        };
        assert!(Config::validate(&cfg).is_err());
    }

    #[test]
    fn test_validate_method_zero_levels_fails() {
        let cfg = HierarchicalConfig {
            num_levels: 0,
            max_seq_lengths: vec![],
            ..HierarchicalConfig::default()
        };
        // Method syntax on the concrete type picks the inherent `validate`.
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_zero_reduction_factor_fails() {
        let cfg = HierarchicalConfig {
            reduction_factor: 0,
            ..HierarchicalConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_hidden_not_divisible_fails() {
        let cfg = HierarchicalConfig {
            hidden_size: 100,
            num_heads: 12,
            ..HierarchicalConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_dropout_out_of_range_fails() {
        let cfg = HierarchicalConfig {
            dropout: 1.5,
            ..HierarchicalConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_seq_lengths_mismatch_fails() {
        let cfg = HierarchicalConfig {
            num_levels: 4,
            max_seq_lengths: vec![512, 256],
            ..HierarchicalConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_get_hidden_size_non_pyramid() {
        let cfg = HierarchicalConfig::default();
        assert_eq!(cfg.get_hidden_size(0), cfg.hidden_size);
        assert_eq!(cfg.get_hidden_size(2), cfg.hidden_size);
    }

    #[test]
    fn test_get_hidden_size_pyramid() {
        let cfg = HierarchicalConfig::pyramid(768, 4);
        assert_eq!(cfg.get_hidden_size(0), 768);
        assert_eq!(cfg.get_hidden_size(1), 384);
    }

    #[test]
    fn test_get_seq_length_within_bounds() {
        let cfg = HierarchicalConfig::default();
        assert_eq!(cfg.get_seq_length(0), cfg.max_seq_lengths[0]);
        assert_eq!(cfg.get_seq_length(1), cfg.max_seq_lengths[1]);
    }

    #[test]
    fn test_get_seq_length_out_of_bounds_uses_formula() {
        let cfg = HierarchicalConfig::default();
        let lvl = cfg.max_seq_lengths.len() + 2;
        let expected = 512 / (2_usize.pow(lvl as u32));
        assert_eq!(cfg.get_seq_length(lvl), expected);
    }

    #[test]
    fn test_get_reduction_factor_level_zero() {
        let cfg = HierarchicalConfig::default();
        // Any base raised to the power 0 is 1.
        assert_eq!(cfg.get_reduction_factor(0), 1);
    }

    #[test]
    fn test_get_reduction_factor_level_two() {
        let cfg = HierarchicalConfig::default();
        // Default base is 2, so level 2 gives 2^2.
        assert_eq!(cfg.get_reduction_factor(2), 4);
    }

    #[test]
    fn test_hierarchical_factory() {
        let cfg = HierarchicalConfig::hierarchical(512, 3);
        assert_eq!(cfg.hidden_size, 512);
        assert_eq!(cfg.num_levels, 3);
        assert_eq!(cfg.max_seq_lengths.len(), 3);
    }

    #[test]
    fn test_pyramid_factory_sets_pyramid_config() {
        let cfg = HierarchicalConfig::pyramid(768, 4);
        assert!(cfg.pyramid_config.is_some());
        assert_eq!(cfg.num_levels, 4);
    }

    #[test]
    fn test_tree_factory_sets_tree_config() {
        let cfg = HierarchicalConfig::tree(512, 2, 5);
        assert!(cfg.tree_config.is_some());
        if let Some(tc) = &cfg.tree_config {
            assert_eq!(tc.branching_factor, 2);
            assert_eq!(tc.max_depth, 5);
        }
    }

    #[test]
    fn test_nested_factory_sets_nested_config() {
        let cfg = HierarchicalConfig::nested(768, 3);
        assert!(cfg.nested_config.is_some());
        if let Some(nc) = &cfg.nested_config {
            assert_eq!(nc.num_nested_levels, 3);
        }
    }

    #[test]
    fn test_estimate_parameters_nonzero() {
        let cfg = HierarchicalConfig::default();
        let params = cfg.estimate_parameters();
        assert!(params > 0);
    }

    #[test]
    fn test_tree_config_default() {
        let tc = TreeConfig::default();
        assert_eq!(tc.branching_factor, 2);
        assert_eq!(tc.max_depth, 8);
        assert!(!tc.learnable_structure);
    }

    #[test]
    fn test_pyramid_config_default() {
        let pc = PyramidConfig::default();
        assert!(!pc.scaling_factors.is_empty());
        assert!(pc.skip_connections);
        assert!(!pc.use_fpn);
    }

    #[test]
    fn test_nested_config_default() {
        let nc = NestedConfig::default();
        assert_eq!(nc.num_nested_levels, 3);
        assert!(!nc.share_parameters);
        assert!(!nc.progressive_training);
    }

    #[test]
    fn test_lcg_values_in_range() {
        let mut rng = Lcg::new(314159);
        for _ in 0..100 {
            let v = rng.next_f32();
            assert!((0.0..1.0).contains(&v));
        }
    }
}