#![allow(unused_imports)]
pub mod core;
pub mod stacking;
pub mod builder;
pub mod optimized;
pub mod advanced;
pub mod utils;
pub mod examples;
pub use builder::{CollateBuilder, CollateStrategy};
pub use core::{Collate, DefaultCollate};
pub use stacking::TensorStacker;
pub use utils::{collate_fn, CollateFn};
pub use optimized::{optimized_collate_fn, stack_tensors, OptimizedCollate};
pub use advanced::{
AdaptiveBatchSampler, BucketBatchSampler, CachedCollate, DynamicBatchCollate,
DynamicBatchCollateWrapper, PadCollate,
};
#[cfg(feature = "sparse")]
pub use advanced::{collate_sparse_tensors, MixedCollate, SparseCollate};
pub use examples::{collate_data_label, collate_dict};
/// Smoke tests for the collation types re-exported by this module.
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    /// `DefaultCollate` should accept a batch of identically shaped tensors.
    #[test]
    fn test_default_collate() {
        let batch = vec![
            ones::<f32>(&[3, 4]).expect("operation should succeed"),
            ones::<f32>(&[3, 4]).expect("operation should succeed"),
            ones::<f32>(&[3, 4]).expect("operation should succeed"),
        ];
        let collate = DefaultCollate;
        // `expect` (instead of `assert!(result.is_ok())`) puts the underlying
        // error into the panic message when collation fails.
        let _collated = collate
            .collate(batch)
            .expect("default collation of equally shaped tensors should succeed");
        // TODO(review): also assert the collated shape (presumably [3, 3, 4])
        // once the output type's shape accessor is confirmed.
    }

    /// `CollateFn` should delegate to the wrapped closure; here the closure
    /// sums the batch, so 1 + 2 + 3 + 4 + 5 must yield 15.
    #[test]
    fn test_custom_collate_fn() {
        let collate = CollateFn::new(|batch: Vec<i32>| Ok(batch.iter().sum::<i32>()));
        let result = collate
            .collate(vec![1, 2, 3, 4, 5])
            .expect("collation should succeed");
        assert_eq!(result, 15);
    }

    /// `PadCollate` constructed with a `0.0` padding value should accept a
    /// batch of tensors that already share the same shape.
    #[test]
    fn test_pad_collate() {
        let batch = vec![
            ones::<f32>(&[2, 3]).expect("operation should succeed"),
            ones::<f32>(&[2, 3]).expect("operation should succeed"),
        ];
        let collate = PadCollate::new(0.0f32);
        // Surface the actual error on failure rather than a bare boolean assert.
        let _collated = collate
            .collate(batch)
            .expect("pad collation of equally shaped tensors should succeed");
    }

    /// Converts two dense zero tensors to COO sparse format.
    ///
    /// NOTE(review): despite its name, this test only exercises
    /// `torsh_sparse::sparse_from_dense`; it never calls `SparseCollate` or
    /// `collate_sparse_tensors`. TODO: collate the two sparse tensors once
    /// the expected input type of those APIs is confirmed.
    #[cfg(feature = "sparse")]
    #[test]
    fn test_sparse_collate() {
        use torsh_sparse::SparseFormat;
        use torsh_tensor::creation::zeros;

        let dense1 = zeros::<f32>(&[2, 3]).expect("operation should succeed");
        let dense2 = zeros::<f32>(&[2, 3]).expect("operation should succeed");
        let _sparse1 = torsh_sparse::sparse_from_dense(&dense1, SparseFormat::Coo, None)
            .expect("torsh sparse should succeed");
        let _sparse2 = torsh_sparse::sparse_from_dense(&dense2, SparseFormat::Coo, None)
            .expect("torsh sparse should succeed");
    }
}