use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
/// A token whose removal is deferred: pruning only flips flags on the token
/// and leaves it in place until its owner decides to drop it.
///
/// The prune state (`out_degree`, `active`) is stored in atomics so it can be
/// updated through a shared reference. `frame` and `cost` are fixed at
/// construction time.
#[derive(Debug)]
pub struct SoftToken<T> {
    /// Caller-supplied payload carried by the token.
    data: T,
    /// Out-degree value; `0` makes [`SoftToken::is_pruned`] report `true`.
    out_degree: AtomicU32,
    /// Activity flag; new tokens start active and `soft_prune` clears it.
    active: AtomicBool,
    /// Frame index recorded at construction.
    frame: u32,
    /// Cost recorded at construction; compared against pruning thresholds.
    cost: f32,
}

impl<T> SoftToken<T> {
    /// Builds an active token from its payload, out-degree, frame and cost.
    pub fn new(data: T, out_degree: u32, frame: u32, cost: f32) -> Self {
        let out_degree = AtomicU32::new(out_degree);
        let active = AtomicBool::new(true);
        Self {
            data,
            out_degree,
            active,
            frame,
            cost,
        }
    }

    /// Shared access to the payload.
    pub fn data(&self) -> &T {
        &self.data
    }

    /// Exclusive access to the payload.
    pub fn data_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Current out-degree (acquire load).
    pub fn out_degree(&self) -> u32 {
        self.out_degree.load(Ordering::Acquire)
    }

    /// Overwrites the out-degree (release store).
    pub fn set_out_degree(&self, degree: u32) {
        self.out_degree.store(degree, Ordering::Release);
    }

    /// `true` while the token has not been soft-pruned.
    pub fn is_active(&self) -> bool {
        self.active.load(Ordering::Acquire)
    }

    /// `true` when the token is inactive or its out-degree is zero.
    pub fn is_pruned(&self) -> bool {
        if !self.is_active() {
            return true;
        }
        self.out_degree() == 0
    }

    /// Frame index the token was created with.
    pub fn frame(&self) -> u32 {
        self.frame
    }

    /// Cost the token was created with.
    pub fn cost(&self) -> f32 {
        self.cost
    }

    /// Marks the token as pruned: zeroes the out-degree first, then clears the
    /// activity flag. The token itself stays where it is.
    pub fn soft_prune(&self) {
        self.set_out_degree(0);
        self.active.store(false, Ordering::Release);
    }

    /// `true` when the token's cost is strictly greater than `threshold`.
    pub fn should_prune(&self, threshold: f32) -> bool {
        threshold < self.cost
    }

    /// Soft-prunes the token if its cost exceeds `threshold`.
    ///
    /// Returns whether the cost exceeded the threshold (and therefore whether
    /// `soft_prune` was invoked by this call).
    pub fn prune_if_above(&self, threshold: f32) -> bool {
        let exceeds = self.should_prune(threshold);
        if exceeds {
            self.soft_prune();
        }
        exceeds
    }
}

impl<T: Clone> Clone for SoftToken<T> {
    /// Clones the payload and snapshots the atomic state with relaxed loads.
    fn clone(&self) -> Self {
        let data = self.data.clone();
        let degree_snapshot = self.out_degree.load(Ordering::Relaxed);
        let active_snapshot = self.active.load(Ordering::Relaxed);
        Self {
            data,
            out_degree: AtomicU32::new(degree_snapshot),
            active: AtomicBool::new(active_snapshot),
            frame: self.frame,
            cost: self.cost,
        }
    }
}
/// Tuning parameters for soft pruning.
#[derive(Clone, Copy, Debug)]
pub struct SoftPruneConfig {
    /// Beam width; added to the best cost to obtain the pruning threshold.
    pub beam: f32,
    /// Active-token budget; used as a capacity hint and as the adaptive-beam
    /// target by the manager.
    pub max_active: usize,
    /// Pruned-token ratio at or above which compaction is requested.
    pub compact_threshold: f32,
    /// Minimum number of stored tokens before compaction is considered.
    pub min_tokens_for_compact: usize,
}

impl SoftPruneConfig {
    /// Creates a configuration with the default compaction settings
    /// (`compact_threshold = 0.5`, `min_tokens_for_compact = 1000`).
    pub fn new(beam: f32, max_active: usize) -> Self {
        Self::with_compaction(beam, max_active, 0.5, 1000)
    }

    /// Creates a configuration with every field given explicitly.
    pub fn with_compaction(
        beam: f32,
        max_active: usize,
        compact_threshold: f32,
        min_tokens_for_compact: usize,
    ) -> Self {
        Self {
            beam,
            max_active,
            compact_threshold,
            min_tokens_for_compact,
        }
    }
}

impl Default for SoftPruneConfig {
    /// Beam of `16.0` and `10000` active tokens, default compaction settings.
    fn default() -> Self {
        let (beam, max_active) = (16.0, 10000);
        Self::new(beam, max_active)
    }
}
/// A growable list of [`SoftToken`]s with a running best cost and a cached
/// count of active tokens.
///
/// `active_count` is a cache: it is adjusted when tokens are pushed, beam
/// pruned, cleared or compacted. Pruning a token directly (for example via
/// [`SoftPruneBuffer::get`] followed by `soft_prune`) does not touch it, which
/// is why [`SoftPruneBuffer::actual_active_count`] recounts from the tokens.
#[derive(Debug)]
pub struct SoftPruneBuffer<T> {
    /// Stored tokens, in insertion order.
    tokens: Vec<SoftToken<T>>,
    /// Cached number of active tokens (see the type-level note).
    active_count: AtomicUsize,
    /// Pruning parameters this buffer was created with.
    config: SoftPruneConfig,
    /// Lowest cost seen so far; `f32::INFINITY` while empty / after `clear`.
    best_cost: f32,
}

impl<T> SoftPruneBuffer<T> {
    /// Creates an empty buffer without reserving storage.
    pub fn new(config: SoftPruneConfig) -> Self {
        Self::with_capacity(config, 0)
    }

    /// Creates an empty buffer with room for `capacity` tokens.
    pub fn with_capacity(config: SoftPruneConfig, capacity: usize) -> Self {
        let tokens = Vec::with_capacity(capacity);
        Self {
            tokens,
            active_count: AtomicUsize::new(0),
            config,
            best_cost: f32::INFINITY,
        }
    }

    /// The configuration supplied at construction.
    pub fn config(&self) -> &SoftPruneConfig {
        &self.config
    }

    /// Cached active-token count (acquire load).
    pub fn active_count(&self) -> usize {
        self.active_count.load(Ordering::Acquire)
    }

    /// Number of stored tokens, active or not.
    pub fn total_count(&self) -> usize {
        self.tokens.len()
    }

    /// Counts stored tokens for which `is_pruned()` holds (linear scan).
    pub fn pruned_count(&self) -> usize {
        self.tokens
            .iter()
            .map(|tok| usize::from(tok.is_pruned()))
            .sum()
    }

    /// Counts stored tokens for which `is_active()` holds (linear scan).
    pub fn actual_active_count(&self) -> usize {
        self.tokens
            .iter()
            .map(|tok| usize::from(tok.is_active()))
            .sum()
    }

    /// Lowest cost recorded so far.
    pub fn best_cost(&self) -> f32 {
        self.best_cost
    }

    /// Current pruning threshold: best cost plus the configured beam.
    pub fn threshold(&self) -> f32 {
        self.best_cost + self.config.beam
    }

    /// `true` when the buffer holds at least `min_tokens_for_compact` tokens
    /// and the pruned ratio has reached `compact_threshold`.
    pub fn needs_compaction(&self) -> bool {
        let total = self.total_count();
        // The size check comes first so the linear scan is skipped for small buffers.
        total >= self.config.min_tokens_for_compact
            && (self.pruned_count() as f32 / total as f32) >= self.config.compact_threshold
    }

    /// Offers a token to the buffer.
    ///
    /// The best cost is lowered first if the token improves on it; the token is
    /// then rejected (`None`) when its cost exceeds the resulting threshold.
    /// Otherwise it is appended and its index is returned.
    pub fn push(&mut self, token: SoftToken<T>) -> Option<usize> {
        let cost = token.cost();
        if cost < self.best_cost {
            self.best_cost = cost;
        }
        let limit = self.threshold();
        if token.should_prune(limit) {
            return None;
        }
        self.tokens.push(token);
        self.active_count.fetch_add(1, Ordering::AcqRel);
        Some(self.tokens.len() - 1)
    }

    /// Token at `index`, if any.
    pub fn get(&self, index: usize) -> Option<&SoftToken<T>> {
        self.tokens.get(index)
    }

    /// Iterates over `(index, token)` pairs of the currently active tokens.
    pub fn active_tokens(&self) -> impl Iterator<Item = (usize, &SoftToken<T>)> {
        self.tokens
            .iter()
            .enumerate()
            .filter(|entry| entry.1.is_active())
    }

    /// Soft-prunes every active token whose cost exceeds the current threshold
    /// and returns how many tokens were pruned by this call.
    pub fn apply_beam_pruning(&self) -> usize {
        let limit = self.threshold();
        let mut newly_pruned = 0;
        for tok in self.tokens.iter() {
            if !tok.is_active() {
                continue;
            }
            if !tok.prune_if_above(limit) {
                continue;
            }
            self.active_count.fetch_sub(1, Ordering::AcqRel);
            newly_pruned += 1;
        }
        newly_pruned
    }

    /// Lowers the best cost to `new_best` if it is an improvement, then applies
    /// beam pruning. Returns the number of tokens pruned.
    pub fn update_best_and_prune(&mut self, new_best: f32) -> usize {
        if new_best < self.best_cost {
            self.best_cost = new_best;
        }
        self.apply_beam_pruning()
    }

    /// Drops all tokens and resets the cached count and the best cost.
    pub fn clear(&mut self) {
        self.best_cost = f32::INFINITY;
        self.active_count.store(0, Ordering::Release);
        self.tokens.clear();
    }

    /// Prepares the buffer for a new frame; currently identical to `clear`.
    pub fn reset_for_frame(&mut self) {
        self.clear();
    }
}
// None of the methods below clone a payload, so no `T: Clone` bound is
// required; dropping it lets buffers of non-`Clone` payloads be compacted and
// drained too. Existing callers with `T: Clone` are unaffected.
impl<T> SoftPruneBuffer<T> {
    /// Physically removes every inactive token and resynchronises the cached
    /// active count with the number of tokens that remain.
    ///
    /// Returns the number of tokens removed. Because surviving tokens are
    /// shifted down, indices handed out earlier by `push` no longer refer to
    /// the same tokens afterwards.
    pub fn compact(&mut self) -> usize {
        let original_len = self.tokens.len();
        self.tokens.retain(|t| t.is_active());
        let remaining = self.tokens.len();
        // Every token left is active, so the cache can simply be reset to the length.
        self.active_count.store(remaining, Ordering::Release);
        original_len - remaining
    }

    /// Runs [`compact`](Self::compact) when `needs_compaction()` says so.
    ///
    /// Returns the number of tokens removed, or `0` when no compaction ran.
    pub fn compact_if_needed(&mut self) -> usize {
        if self.needs_compaction() {
            self.compact()
        } else {
            0
        }
    }

    /// Consumes the buffer and returns the payloads of the active tokens, in
    /// storage order. Payloads are moved out, not cloned.
    pub fn into_survivors(self) -> Vec<T> {
        self.tokens
            .into_iter()
            .filter(|t| t.is_active())
            .map(|t| t.data)
            .collect()
    }
}
/// Running counters describing pruning and compaction activity.
#[derive(Clone, Debug, Default)]
pub struct SoftPruneStats {
    /// Tokens recorded via `record_tokens`.
    pub total_tokens: usize,
    /// Tokens recorded as pruned by the beam.
    pub beam_pruned: usize,
    /// Tokens recorded as pruned by a limit.
    pub limit_pruned: usize,
    /// Number of compaction passes recorded.
    pub compactions: usize,
    /// Total tokens removed across all recorded compactions.
    pub compacted_tokens: usize,
}

impl SoftPruneStats {
    /// Creates a statistics record with every counter at zero.
    pub fn new() -> Self {
        Self {
            total_tokens: 0,
            beam_pruned: 0,
            limit_pruned: 0,
            compactions: 0,
            compacted_tokens: 0,
        }
    }

    /// Sum of beam-pruned and limit-pruned tokens.
    pub fn total_pruned(&self) -> usize {
        self.limit_pruned + self.beam_pruned
    }

    /// Pruned tokens divided by total tokens; `0.0` when no tokens were recorded.
    pub fn prune_ratio(&self) -> f64 {
        match self.total_tokens {
            0 => 0.0,
            total => self.total_pruned() as f64 / total as f64,
        }
    }

    /// Average number of tokens removed per compaction; `0.0` when none ran.
    pub fn compaction_efficiency(&self) -> f64 {
        if self.compactions == 0 {
            return 0.0;
        }
        let removed = self.compacted_tokens as f64;
        removed / self.compactions as f64
    }

    /// Adds `count` to the beam-pruned counter.
    pub fn record_beam_prune(&mut self, count: usize) {
        self.beam_pruned += count;
    }

    /// Adds `count` to the limit-pruned counter.
    pub fn record_limit_prune(&mut self, count: usize) {
        self.limit_pruned += count;
    }

    /// Records one compaction pass that removed `tokens_removed` tokens.
    pub fn record_compaction(&mut self, tokens_removed: usize) {
        self.compacted_tokens += tokens_removed;
        self.compactions += 1;
    }

    /// Adds `count` to the total-token counter.
    pub fn record_tokens(&mut self, count: usize) {
        self.total_tokens += count;
    }

    /// Adds every counter of `other` onto `self`.
    pub fn merge(&mut self, other: &SoftPruneStats) {
        self.record_tokens(other.total_tokens);
        self.record_beam_prune(other.beam_pruned);
        self.record_limit_prune(other.limit_pruned);
        // `record_compaction` always counts exactly one pass, so these two are summed directly.
        self.compactions += other.compactions;
        self.compacted_tokens += other.compacted_tokens;
    }
}
/// Histogram of token costs used to pick a cost threshold that keeps roughly
/// `target_active` tokens.
///
/// Bucket counters are atomics so samples can be added through a shared
/// reference; the cost range (`min_cost`, `max_cost`) needs exclusive access.
#[derive(Debug)]
pub struct AdaptiveBeam {
    /// Number of histogram buckets requested at construction.
    num_buckets: usize,
    /// Per-bucket sample counters.
    buckets: Vec<AtomicUsize>,
    /// Lower end of the cost range used by `compute_threshold`.
    min_cost: f32,
    /// Upper end of the cost range used by `compute_threshold`.
    max_cost: f32,
    /// Cumulative sample count at which the threshold is placed.
    target_active: usize,
}

impl AdaptiveBeam {
    /// Creates a histogram with `num_buckets` empty buckets and an empty cost
    /// range (`min_cost = +inf`, `max_cost = -inf`).
    pub fn new(num_buckets: usize, target_active: usize) -> Self {
        Self {
            num_buckets,
            buckets: (0..num_buckets).map(|_| AtomicUsize::new(0)).collect(),
            min_cost: f32::INFINITY,
            max_cost: f32::NEG_INFINITY,
            target_active,
        }
    }

    /// Zeroes every bucket and restores the empty cost range.
    pub fn reset(&mut self) {
        for bucket in &self.buckets {
            bucket.store(0, Ordering::Relaxed);
        }
        self.min_cost = f32::INFINITY;
        self.max_cost = f32::NEG_INFINITY;
    }

    /// Widens the tracked cost range to include `cost`.
    ///
    /// This only updates `min_cost` / `max_cost`; it does not add a sample to
    /// the histogram (use [`add_with_range`](Self::add_with_range) for that).
    pub fn add(&mut self, cost: f32) {
        if cost < self.min_cost {
            self.min_cost = cost;
        }
        if cost > self.max_cost {
            self.max_cost = cost;
        }
    }

    /// Overwrites the tracked cost range.
    pub fn set_range(&mut self, min_cost: f32, max_cost: f32) {
        self.min_cost = min_cost;
        self.max_cost = max_cost;
    }

    /// Adds one sample for `cost` to the bucket it falls into within
    /// `[range_min, range_max]`.
    ///
    /// The call is a no-op when the range is empty (`range_max <= range_min`)
    /// or when the histogram has no buckets. Costs outside the range land in
    /// the first or last bucket.
    pub fn add_with_range(&self, cost: f32, range_min: f32, range_max: f32) {
        // A zero-bucket histogram has nowhere to record the sample; without this
        // guard `num_buckets - 1` below would underflow and panic.
        if self.buckets.is_empty() || range_max <= range_min {
            return;
        }
        let last = self.buckets.len() - 1;
        let normalized = (cost - range_min) / (range_max - range_min);
        // Float-to-int `as` casts saturate, so negative values map to bucket 0;
        // values at or beyond the top of the range are clamped to the last bucket.
        let bucket = ((normalized * self.num_buckets as f32) as usize).min(last);
        self.buckets[bucket].fetch_add(1, Ordering::Relaxed);
    }

    /// Returns the upper cost edge of the first bucket at which the cumulative
    /// sample count reaches `target_active`.
    ///
    /// Returns `f32::INFINITY` when the tracked range is empty or when all
    /// samples together stay below the target (nothing needs pruning).
    pub fn compute_threshold(&self) -> f32 {
        if self.max_cost <= self.min_cost {
            return f32::INFINITY;
        }
        let mut cumulative = 0;
        let bucket_width = (self.max_cost - self.min_cost) / self.num_buckets as f32;
        for (i, bucket) in self.buckets.iter().enumerate() {
            cumulative += bucket.load(Ordering::Relaxed);
            if cumulative >= self.target_active {
                return self.min_cost + (i + 1) as f32 * bucket_width;
            }
        }
        f32::INFINITY
    }

    /// Total number of samples across all buckets.
    pub fn total_count(&self) -> usize {
        self.buckets.iter().map(|b| b.load(Ordering::Relaxed)).sum()
    }
}
/// Drives soft pruning across frames using two buffers: the frame being built
/// (`current`) and the one before it (`previous`).
#[derive(Debug)]
pub struct SoftPruneManager<T> {
    /// Buffer receiving tokens for the frame in progress.
    current: SoftPruneBuffer<T>,
    /// Buffer holding the tokens of the preceding frame.
    previous: SoftPruneBuffer<T>,
    /// Index of the frame in progress; stamped onto new tokens.
    frame: u32,
    /// Cost histogram; reset on every frame advance.
    adaptive_beam: AdaptiveBeam,
    /// Accumulated pruning / compaction counters.
    stats: SoftPruneStats,
}

impl<T> SoftPruneManager<T> {
    /// Creates a manager at frame 0.
    ///
    /// Both buffers reserve room for `config.max_active` tokens, and the
    /// adaptive beam is built with 100 buckets and `config.max_active` as its
    /// target.
    pub fn new(config: SoftPruneConfig) -> Self {
        let capacity = config.max_active;
        let current = SoftPruneBuffer::with_capacity(config, capacity);
        let previous = SoftPruneBuffer::with_capacity(config, capacity);
        Self {
            current,
            previous,
            frame: 0,
            adaptive_beam: AdaptiveBeam::new(100, capacity),
            stats: SoftPruneStats::new(),
        }
    }

    /// Index of the frame in progress.
    pub fn frame(&self) -> u32 {
        self.frame
    }

    /// Shared access to the buffer of the frame in progress.
    pub fn current(&self) -> &SoftPruneBuffer<T> {
        &self.current
    }

    /// Exclusive access to the buffer of the frame in progress.
    pub fn current_mut(&mut self) -> &mut SoftPruneBuffer<T> {
        &mut self.current
    }

    /// Shared access to the buffer of the preceding frame.
    pub fn previous(&self) -> &SoftPruneBuffer<T> {
        &self.previous
    }

    /// Counters accumulated so far.
    pub fn stats(&self) -> &SoftPruneStats {
        &self.stats
    }

    /// Wraps `data` in a token stamped with the current frame and offers it to
    /// the current buffer.
    ///
    /// The token is counted in the statistics whether or not the buffer accepts
    /// it. Returns the token's index, or `None` when the buffer rejected it.
    pub fn add_token(&mut self, data: T, out_degree: u32, cost: f32) -> Option<usize> {
        self.stats.record_tokens(1);
        let stamped = SoftToken::new(data, out_degree, self.frame, cost);
        self.current.push(stamped)
    }

    /// Applies beam pruning to the current buffer, records the number of pruned
    /// tokens in the statistics and returns it.
    pub fn apply_pruning(&mut self) -> usize {
        let newly_pruned = self.current.apply_beam_pruning();
        self.stats.record_beam_prune(newly_pruned);
        newly_pruned
    }
}
impl<T: Clone> SoftPruneManager<T> {
    /// Finishes the frame in progress and starts the next one.
    ///
    /// The current buffer is compacted if it asks for it (and the compaction is
    /// recorded), then it becomes `previous`; the old `previous` buffer is
    /// cleared and reused as the new `current`. The adaptive beam is reset and
    /// the frame index is incremented.
    pub fn advance_frame(&mut self) {
        match self.current.compact_if_needed() {
            0 => {}
            removed => self.stats.record_compaction(removed),
        }
        std::mem::swap(&mut self.current, &mut self.previous);
        self.current.reset_for_frame();
        self.adaptive_beam.reset();
        self.frame += 1;
    }

    /// Clones the payloads of the active tokens in the current buffer, in
    /// storage order.
    pub fn survivors(&self) -> Vec<T> {
        let mut alive = Vec::new();
        for (_, token) in self.current.active_tokens() {
            alive.push(token.data().clone());
        }
        alive
    }
}
// Unit tests exercising each type in this file with small, hand-picked inputs.
#[cfg(test)]
mod tests {
use super::*;
// A freshly built token reports the values it was given and starts active.
#[test]
fn test_soft_token_creation() {
let token = SoftToken::new(42, 5, 0, 1.5);
assert_eq!(*token.data(), 42);
assert_eq!(token.out_degree(), 5);
assert_eq!(token.frame(), 0);
assert!((token.cost() - 1.5).abs() < 1e-6);
assert!(token.is_active());
assert!(!token.is_pruned());
}
// soft_prune clears the active flag and zeroes the out-degree.
#[test]
fn test_soft_token_pruning() {
let token = SoftToken::new(42, 5, 0, 1.5);
assert!(token.is_active());
token.soft_prune();
assert!(!token.is_active());
assert!(token.is_pruned());
assert_eq!(token.out_degree(), 0);
}
// prune_if_above only prunes when the cost (10.0) is above the threshold.
#[test]
fn test_soft_token_threshold_pruning() {
let token = SoftToken::new(42, 5, 0, 10.0);
assert!(!token.prune_if_above(15.0)); assert!(token.is_active());
assert!(token.prune_if_above(5.0)); assert!(!token.is_active());
}
// SoftPruneConfig::new stores the beam and max_active it was given.
#[test]
fn test_soft_prune_config() {
let config = SoftPruneConfig::new(10.0, 5000);
assert!((config.beam - 10.0).abs() < 1e-6);
assert_eq!(config.max_active, 5000);
}
// With best cost 1.0 and beam 10.0 the threshold is 11.0, so the token
// costing 15.0 is rejected by push and never stored.
#[test]
fn test_soft_prune_buffer() {
let config = SoftPruneConfig::new(10.0, 100);
let mut buffer = SoftPruneBuffer::new(config);
let idx1 = buffer.push(SoftToken::new(1, 3, 0, 1.0));
let idx2 = buffer.push(SoftToken::new(2, 2, 0, 2.0));
let idx3 = buffer.push(SoftToken::new(3, 1, 0, 15.0));
assert!(idx1.is_some());
assert!(idx2.is_some());
assert!(idx3.is_none());
assert_eq!(buffer.active_count(), 2);
assert_eq!(buffer.total_count(), 2);
}
// Lowering the best cost to 0.5 yields a threshold of 5.5; every stored cost
// (1.0, 3.0, 4.0) is below it, so nothing is pruned.
#[test]
fn test_soft_prune_buffer_beam_pruning() {
let config = SoftPruneConfig::new(5.0, 100);
let mut buffer = SoftPruneBuffer::new(config);
buffer.push(SoftToken::new(1, 3, 0, 1.0));
buffer.push(SoftToken::new(2, 2, 0, 3.0));
buffer.push(SoftToken::new(3, 1, 0, 4.0));
assert_eq!(buffer.active_count(), 3);
let pruned = buffer.update_best_and_prune(0.5);
assert_eq!(pruned, 0);
assert_eq!(buffer.active_count(), 3);
}
// A soft-pruned token stays in the buffer until compact() removes it.
#[test]
fn test_soft_prune_buffer_compact() {
let config = SoftPruneConfig::new(10.0, 100);
let mut buffer = SoftPruneBuffer::new(config);
buffer.push(SoftToken::new(1, 3, 0, 1.0));
buffer.push(SoftToken::new(2, 2, 0, 2.0));
buffer.push(SoftToken::new(3, 1, 0, 3.0));
buffer
.get(1)
.expect("gpu/soft_prune.rs: required value was None/Err")
.soft_prune();
assert_eq!(buffer.total_count(), 3);
let removed = buffer.compact();
assert_eq!(removed, 1);
assert_eq!(buffer.total_count(), 2);
}
// Counters accumulate and prune_ratio is (20 + 10) / 100.
#[test]
fn test_soft_prune_stats() {
let mut stats = SoftPruneStats::new();
stats.record_tokens(100);
stats.record_beam_prune(20);
stats.record_limit_prune(10);
stats.record_compaction(15);
assert_eq!(stats.total_tokens, 100);
assert_eq!(stats.total_pruned(), 30);
assert!((stats.prune_ratio() - 0.3).abs() < 1e-6);
assert_eq!(stats.compactions, 1);
assert_eq!(stats.compacted_tokens, 15);
}
// 100 evenly spread samples in 10 buckets with a target of 50 place the
// threshold near the middle of the range.
#[test]
fn test_adaptive_beam() {
let mut beam = AdaptiveBeam::new(10, 50);
beam.set_range(0.0, 100.0);
for i in 0..100 {
beam.add_with_range(i as f32, 0.0, 100.0);
}
let threshold = beam.compute_threshold();
assert!(threshold > 40.0 && threshold < 60.0);
}
// advance_frame moves the filled buffer to `previous` and leaves an empty
// `current`, bumping the frame index.
#[test]
fn test_soft_prune_manager() {
let config = SoftPruneConfig::new(10.0, 100);
let mut manager = SoftPruneManager::new(config);
manager.add_token(1, 3, 1.0);
manager.add_token(2, 2, 2.0);
manager.add_token(3, 1, 3.0);
assert_eq!(manager.current().active_count(), 3);
assert_eq!(manager.frame(), 0);
manager.advance_frame();
assert_eq!(manager.frame(), 1);
assert_eq!(manager.current().active_count(), 0);
assert_eq!(manager.previous().active_count(), 3);
}
// survivors() skips the token that was soft-pruned through a shared reference.
#[test]
fn test_soft_prune_manager_survivors() {
let config = SoftPruneConfig::new(10.0, 100);
let mut manager = SoftPruneManager::new(config);
manager.add_token(1, 3, 1.0);
manager.add_token(2, 2, 2.0);
manager
.current()
.get(0)
.expect("gpu/soft_prune.rs: required value was None/Err")
.soft_prune();
let survivors = manager.survivors();
assert_eq!(survivors.len(), 1);
assert_eq!(survivors[0], 2);
}
// With min_tokens_for_compact = 4 and compact_threshold = 0.5, compaction is
// requested once 6 of the 10 stored tokens have been pruned.
#[test]
fn test_needs_compaction() {
let config = SoftPruneConfig::with_compaction(10.0, 100, 0.5, 4);
let mut buffer = SoftPruneBuffer::new(config);
for i in 0..10 {
buffer.push(SoftToken::new(i, 3, 0, i as f32));
}
assert!(!buffer.needs_compaction());
for i in 0..6 {
buffer
.get(i)
.expect("gpu/soft_prune.rs: required value was None/Err")
.soft_prune();
}
assert!(buffer.needs_compaction()); }
}
// Property-based tests (proptest) checking the same invariants as the unit
// tests over randomly generated inputs.
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
// --- SoftToken properties ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
// A new token with non-zero out-degree is active, not pruned, and echoes its inputs.
#[test]
fn soft_token_new_is_active(
data in any::<i32>(),
out_degree in 1u32..100,
frame in 0u32..1000,
cost in -1000.0f32..1000.0
) {
let token = SoftToken::new(data, out_degree, frame, cost);
prop_assert!(token.is_active());
prop_assert!(!token.is_pruned());
prop_assert_eq!(*token.data(), data);
prop_assert_eq!(token.out_degree(), out_degree);
prop_assert_eq!(token.frame(), frame);
prop_assert!((token.cost() - cost).abs() < 1e-6);
}
// soft_prune always leaves the token inactive, pruned and with out-degree 0.
#[test]
fn soft_prune_makes_inactive(
data in any::<i32>(),
out_degree in 1u32..100,
frame in 0u32..1000,
cost in -1000.0f32..1000.0
) {
let token = SoftToken::new(data, out_degree, frame, cost);
token.soft_prune();
prop_assert!(!token.is_active());
prop_assert!(token.is_pruned());
prop_assert_eq!(token.out_degree(), 0);
}
// prune_if_above prunes exactly when cost > threshold.
#[test]
fn prune_if_above_correct(
cost in 0.0f32..100.0,
threshold in 0.0f32..100.0
) {
let token = SoftToken::new(42, 5, 0, cost);
let was_pruned = token.prune_if_above(threshold);
if cost > threshold {
prop_assert!(was_pruned);
prop_assert!(!token.is_active());
} else {
prop_assert!(!was_pruned);
prop_assert!(token.is_active());
}
}
// Cloning copies every field, including the activity flag.
#[test]
fn soft_token_clone_preserves_fields(
data in any::<i32>(),
out_degree in 0u32..100,
frame in 0u32..1000,
cost in -1000.0f32..1000.0
) {
let token = SoftToken::new(data, out_degree, frame, cost);
let cloned = token.clone();
prop_assert_eq!(*cloned.data(), data);
prop_assert_eq!(cloned.out_degree(), out_degree);
prop_assert_eq!(cloned.frame(), frame);
prop_assert!((cloned.cost() - cost).abs() < 1e-6);
prop_assert_eq!(cloned.is_active(), token.is_active());
}
// A token built with out-degree 0 counts as pruned even though it is active.
#[test]
fn zero_out_degree_is_pruned(
data in any::<i32>(),
frame in 0u32..1000,
cost in -1000.0f32..1000.0
) {
let token = SoftToken::new(data, 0, frame, cost);
prop_assert!(token.is_pruned());
}
}
// --- SoftPruneConfig properties ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// new() stores beam and max_active unchanged.
#[test]
fn config_preserves_settings(
beam in 0.1f32..100.0,
max_active in 1usize..10000
) {
let config = SoftPruneConfig::new(beam, max_active);
prop_assert!((config.beam - beam).abs() < 1e-6);
prop_assert_eq!(config.max_active, max_active);
}
// with_compaction() stores all four settings unchanged.
#[test]
fn config_custom_compaction(
beam in 0.1f32..100.0,
max_active in 1usize..10000,
compact_threshold in 0.0f32..1.0,
min_tokens in 0usize..1000
) {
let config = SoftPruneConfig::with_compaction(
beam,
max_active,
compact_threshold,
min_tokens,
);
prop_assert!((config.beam - beam).abs() < 1e-6);
prop_assert_eq!(config.max_active, max_active);
prop_assert!((config.compact_threshold - compact_threshold).abs() < 1e-6);
prop_assert_eq!(config.min_tokens_for_compact, min_tokens);
}
}
// --- SoftPruneBuffer properties ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// Costs stay below 2.0 while the beam is at least 10.0, so every push is accepted.
#[test]
fn buffer_push_increases_active_count(
num_tokens in 1usize..20,
beam in 10.0f32..100.0
) {
let config = SoftPruneConfig::new(beam, 1000);
let mut buffer = SoftPruneBuffer::new(config);
for i in 0..num_tokens {
buffer.push(SoftToken::new(i as i32, 5, 0, i as f32 * 0.1));
}
prop_assert_eq!(buffer.active_count(), num_tokens);
prop_assert_eq!(buffer.total_count(), num_tokens);
}
// A token costing more than best + beam is rejected by push.
#[test]
fn buffer_rejects_high_cost_tokens(
base_cost in 0.0f32..10.0,
beam in 1.0f32..5.0
) {
let config = SoftPruneConfig::new(beam, 1000);
let mut buffer = SoftPruneBuffer::new(config);
let idx1 = buffer.push(SoftToken::new(1, 5, 0, base_cost));
prop_assert!(idx1.is_some());
let high_cost = base_cost + beam + 1.0;
let idx2 = buffer.push(SoftToken::new(2, 5, 0, high_cost));
prop_assert!(idx2.is_none());
prop_assert_eq!(buffer.active_count(), 1);
}
// compact() removes exactly the tokens that were soft-pruned beforehand.
#[test]
fn buffer_compact_removes_pruned(
num_tokens in 5usize..20,
num_to_prune in 1usize..5
) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut buffer = SoftPruneBuffer::<i32>::new(config);
for i in 0..num_tokens {
buffer.push(SoftToken::new(i as i32, 5, 0, i as f32));
}
let actual_prune = num_to_prune.min(num_tokens);
for i in 0..actual_prune {
if let Some(token) = buffer.get(i) {
token.soft_prune();
}
}
let removed = buffer.compact();
prop_assert_eq!(removed, actual_prune);
prop_assert_eq!(buffer.total_count(), num_tokens - actual_prune);
}
// into_survivors() returns one payload per token that is still active;
// duplicate prune indices are only counted once.
#[test]
fn buffer_into_survivors_only_active(
num_tokens in 2usize..10,
prune_indices in proptest::collection::vec(0usize..10, 0..5)
) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut buffer = SoftPruneBuffer::new(config);
for i in 0..num_tokens {
buffer.push(SoftToken::new(i as i32, 5, 0, i as f32 * 0.5));
}
let mut pruned_count = 0;
for &idx in &prune_indices {
if idx < num_tokens {
if let Some(token) = buffer.get(idx) {
if token.is_active() {
token.soft_prune();
pruned_count += 1;
}
}
}
}
let survivors = buffer.into_survivors();
prop_assert_eq!(survivors.len(), num_tokens - pruned_count);
}
// best_cost() equals the minimum of all pushed costs (the beam of 200 exceeds
// the cost range, so no push is rejected).
#[test]
fn buffer_tracks_best_cost(costs in proptest::collection::vec(0.0f32..100.0, 1..20)) {
let config = SoftPruneConfig::new(200.0, 1000);
let mut buffer = SoftPruneBuffer::new(config);
for (i, &cost) in costs.iter().enumerate() {
buffer.push(SoftToken::new(i as i32, 5, 0, cost));
}
let expected_best = costs.iter().cloned().fold(f32::INFINITY, f32::min);
prop_assert!((buffer.best_cost() - expected_best).abs() < 1e-6);
}
// clear() empties the buffer and restores the initial best cost.
#[test]
fn buffer_clear_resets(num_tokens in 1usize..20) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut buffer = SoftPruneBuffer::new(config);
for i in 0..num_tokens {
buffer.push(SoftToken::new(i as i32, 5, 0, i as f32));
}
buffer.clear();
prop_assert_eq!(buffer.active_count(), 0);
prop_assert_eq!(buffer.total_count(), 0);
prop_assert!(buffer.best_cost().is_infinite());
}
}
// --- SoftPruneStats properties ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
// total_pruned() is the sum of both prune counters.
#[test]
fn stats_total_pruned_correct(
beam_pruned in 0usize..1000,
limit_pruned in 0usize..1000
) {
let mut stats = SoftPruneStats::new();
stats.record_beam_prune(beam_pruned);
stats.record_limit_prune(limit_pruned);
prop_assert_eq!(stats.total_pruned(), beam_pruned + limit_pruned);
}
// When pruned counts are capped at the total, the ratio stays within [0, 1].
#[test]
fn stats_prune_ratio_bounded(
total_tokens in 1usize..1000,
beam_pruned in 0usize..500,
limit_pruned in 0usize..500
) {
let mut stats = SoftPruneStats::new();
stats.record_tokens(total_tokens);
stats.record_beam_prune(beam_pruned.min(total_tokens));
stats.record_limit_prune(limit_pruned.min(total_tokens - beam_pruned.min(total_tokens)));
let ratio = stats.prune_ratio();
prop_assert!(ratio >= 0.0);
prop_assert!(ratio <= 1.0);
}
// With no tokens recorded the ratio is 0 regardless of the prune counters.
#[test]
fn stats_prune_ratio_zero_tokens(
beam_pruned in 0usize..100,
limit_pruned in 0usize..100
) {
let mut stats = SoftPruneStats::new();
stats.record_beam_prune(beam_pruned);
stats.record_limit_prune(limit_pruned);
prop_assert!((stats.prune_ratio() - 0.0).abs() < 1e-10);
}
// merge() adds every counter; each side recorded one compaction, hence 2.
#[test]
fn stats_merge_correct(
total1 in 0usize..500,
beam1 in 0usize..250,
limit1 in 0usize..250,
compact1 in 0usize..100,
total2 in 0usize..500,
beam2 in 0usize..250,
limit2 in 0usize..250,
compact2 in 0usize..100
) {
let mut stats1 = SoftPruneStats::new();
stats1.record_tokens(total1);
stats1.record_beam_prune(beam1);
stats1.record_limit_prune(limit1);
stats1.record_compaction(compact1);
let mut stats2 = SoftPruneStats::new();
stats2.record_tokens(total2);
stats2.record_beam_prune(beam2);
stats2.record_limit_prune(limit2);
stats2.record_compaction(compact2);
stats1.merge(&stats2);
prop_assert_eq!(stats1.total_tokens, total1 + total2);
prop_assert_eq!(stats1.beam_pruned, beam1 + beam2);
prop_assert_eq!(stats1.limit_pruned, limit1 + limit2);
prop_assert_eq!(stats1.compactions, 2);
prop_assert_eq!(stats1.compacted_tokens, compact1 + compact2);
}
// With zero compactions recorded the efficiency is 0 even if compacted_tokens is set.
#[test]
fn stats_compaction_efficiency_zero(compacted in 0usize..100) {
let mut stats = SoftPruneStats::new();
stats.compacted_tokens = compacted;
prop_assert!((stats.compaction_efficiency() - 0.0).abs() < 1e-10);
}
}
// --- AdaptiveBeam properties ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// reset() zeroes the histogram and restores the empty (+inf, -inf) range.
#[test]
fn adaptive_beam_reset_clears(
num_buckets in 5usize..50,
target in 10usize..100,
num_adds in 1usize..50
) {
let mut beam = AdaptiveBeam::new(num_buckets, target);
beam.set_range(0.0, 100.0);
for i in 0..num_adds {
beam.add_with_range(i as f32, 0.0, 100.0);
}
beam.reset();
prop_assert_eq!(beam.total_count(), 0);
prop_assert!(beam.min_cost.is_infinite());
prop_assert!(beam.max_cost.is_infinite() && beam.max_cost.is_sign_negative());
}
// Every in-range sample lands in some bucket.
#[test]
fn adaptive_beam_total_count_correct(
num_buckets in 5usize..50,
target in 10usize..100,
costs in proptest::collection::vec(0.0f32..100.0, 1..50)
) {
let beam = AdaptiveBeam::new(num_buckets, target);
for &cost in &costs {
beam.add_with_range(cost, 0.0, 100.0);
}
prop_assert_eq!(beam.total_count(), costs.len());
}
// When fewer samples than the target exist, the threshold is infinite.
#[test]
fn adaptive_beam_threshold_infinity_when_fits(
num_buckets in 5usize..20,
num_costs in 1usize..10
) {
let target = num_costs + 10;
let mut beam = AdaptiveBeam::new(num_buckets, target);
beam.set_range(0.0, 100.0);
for i in 0..num_costs {
beam.add_with_range(i as f32 * 5.0, 0.0, 100.0);
}
let threshold = beam.compute_threshold();
prop_assert!(threshold.is_infinite());
}
// A finite threshold always lies within the sampled cost range.
#[test]
fn adaptive_beam_threshold_in_range(
num_buckets in 10usize..50,
costs in proptest::collection::vec(0.0f32..100.0, 20..100)
) {
let target = costs.len() / 3;
let mut beam = AdaptiveBeam::new(num_buckets, target);
let min_cost = costs.iter().cloned().fold(f32::INFINITY, f32::min);
let max_cost = costs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
if max_cost > min_cost {
beam.set_range(min_cost, max_cost);
for &cost in &costs {
beam.add_with_range(cost, min_cost, max_cost);
}
let threshold = beam.compute_threshold();
if !threshold.is_infinite() {
prop_assert!(threshold >= min_cost);
prop_assert!(threshold <= max_cost + 1e-6);
}
}
}
}
// --- SoftPruneManager properties ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// Each advance_frame() call increments the frame index by one.
#[test]
fn manager_advance_frame_increments(num_advances in 1usize..10) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut manager = SoftPruneManager::<i32>::new(config);
for i in 0..num_advances {
prop_assert_eq!(manager.frame(), i as u32);
manager.advance_frame();
}
prop_assert_eq!(manager.frame(), num_advances as u32);
}
// After advancing, the old current buffer is available as `previous`.
#[test]
fn manager_frame_swap(num_tokens in 1usize..10) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut manager = SoftPruneManager::new(config);
for i in 0..num_tokens {
manager.add_token(i as i32, 5, i as f32);
}
let current_count_before = manager.current().active_count();
manager.advance_frame();
prop_assert_eq!(manager.previous().active_count(), current_count_before);
prop_assert_eq!(manager.current().active_count(), 0);
}
// survivors() omits a token that was soft-pruned directly.
#[test]
fn manager_survivors_only_active(
num_tokens in 2usize..10,
prune_first in any::<bool>()
) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut manager = SoftPruneManager::new(config);
for i in 0..num_tokens {
manager.add_token(i as i32, 5, i as f32);
}
let expected_survivors = if prune_first {
manager.current().get(0).expect("gpu/soft_prune.rs: required value was None/Err").soft_prune();
num_tokens - 1
} else {
num_tokens
};
let survivors = manager.survivors();
prop_assert_eq!(survivors.len(), expected_survivors);
}
// Every add_token() call is counted in total_tokens.
#[test]
fn manager_stats_track_tokens(num_tokens in 1usize..20) {
let config = SoftPruneConfig::new(100.0, 1000);
let mut manager = SoftPruneManager::new(config);
for i in 0..num_tokens {
manager.add_token(i as i32, 5, i as f32);
}
prop_assert_eq!(manager.stats().total_tokens, num_tokens);
}
// beam_pruned reflects only what apply_pruning() reports; tokens rejected at
// push time (cost above the beam of 5.0) are not added to it.
#[test]
fn manager_apply_pruning_updates_stats(num_tokens in 5usize..20) {
let beam = 5.0f32;
let config = SoftPruneConfig::new(beam, 1000);
let mut manager = SoftPruneManager::new(config);
for i in 0..num_tokens {
manager.add_token(i as i32, 5, i as f32);
}
let pruned = manager.apply_pruning();
prop_assert_eq!(manager.stats().beam_pruned, pruned);
}
}
}