pub mod types;
pub use types::{
utils,
AudioMetadata,
CollatedBatch,
DataExample,
DocumentFormat,
DocumentMetadata,
FeatureInput,
FeatureOutput,
ImageFormat,
ImageMetadata,
MultimodalMetadata,
PaddingStrategy,
SpecialToken,
TextMetadata,
};
pub mod feature_extractors;
pub mod data_collators;
pub mod metrics;
pub mod optimizers;
pub use feature_extractors::{
AudioFeatureExtractor, AutoFeatureExtractor, DocumentFeatureExtractor, FeatureExtractor,
FeatureExtractorConfig, GenericFeatureExtractor, VisionFeatureExtractor,
};
pub use data_collators::*;
pub use metrics::*;
pub use optimizers::{
AdamConfig, AdamOptimizer, AdamWConfig, AdamWOptimizer, AutoOptimizer, LearningRateSchedule,
Optimizer, OptimizerGradients, OptimizerUpdate, ScheduledOptimizer,
};
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_padding_strategy_reachable() {
let ps = PaddingStrategy::Longest;
assert!(ps.should_pad(), "PaddingStrategy::Longest should pad");
}
#[test]
fn test_padding_strategy_none_reachable() {
let ps = PaddingStrategy::None;
assert!(!ps.should_pad(), "PaddingStrategy::None should not pad");
}
#[test]
fn test_image_format_reachable() {
let fmt = ImageFormat::Jpeg;
assert_eq!(fmt.extension(), "jpg", "Jpeg extension should be 'jpg'");
}
#[test]
fn test_image_format_mime_type() {
let fmt = ImageFormat::Png;
assert_eq!(fmt.mime_type(), "image/png");
}
#[test]
fn test_document_format_reachable() {
let fmt = DocumentFormat::Pdf;
assert_eq!(fmt.extension(), "pdf");
assert_eq!(fmt.mime_type(), "application/pdf");
}
#[test]
fn test_image_metadata_reachable() {
let meta = ImageMetadata {
width: 640,
height: 480,
channels: 3,
dpi: Some(96),
};
assert_eq!(meta.width, 640);
assert_eq!(meta.height, 480);
}
#[test]
fn test_data_example_reachable() {
let example = DataExample::new(vec![1, 2, 3]);
assert_eq!(example.sequence_length(), 3);
}
#[test]
fn test_collated_batch_reachable() {
let batch = CollatedBatch::new(
vec![vec![1_u32, 2], vec![3_u32, 4]],
vec![vec![1_u32, 1], vec![1_u32, 1]],
2,
2,
);
assert_eq!(batch.batch_size, 2);
assert_eq!(batch.total_tokens(), 4);
}
#[test]
fn test_collated_batch_input_shape() {
let batch = CollatedBatch::new(vec![vec![10_u32, 20, 30]], vec![vec![1_u32, 1, 1]], 1, 3);
let shape = batch.input_shape();
assert_eq!(shape, (1, 3));
}
#[test]
fn test_adam_config_reachable() {
let config = AdamConfig {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
amsgrad: false,
};
let diff = (config.learning_rate - 1e-3).abs();
assert!(diff < 1e-12);
}
#[test]
fn test_adamw_config_reachable() {
let config = AdamWConfig {
learning_rate: 2e-5,
beta1: 0.9,
beta2: 0.999,
weight_decay: 0.01,
eps: 1e-8,
amsgrad: false,
};
let diff = (config.weight_decay - 0.01).abs();
assert!(diff < 1e-12);
}
#[test]
fn test_lr_schedule_constant_reachable() {
let s = LearningRateSchedule::Constant;
assert!(matches!(s, LearningRateSchedule::Constant));
}
#[test]
fn test_optimizer_gradients_reachable() {
let gradients = OptimizerGradients {
parameters: std::collections::HashMap::new(),
parameter_shapes: std::collections::HashMap::new(),
};
assert!(gradients.parameters.is_empty());
}
#[test]
fn test_optimizer_update_reachable() {
let update = OptimizerUpdate {
parameter_updates: std::collections::HashMap::new(),
learning_rate: 1e-4,
step_count: 0,
};
assert_eq!(update.step_count, 0);
let diff = (update.learning_rate - 1e-4).abs();
assert!(diff < 1e-12);
}
#[test]
fn test_adam_optimizer_create_and_get_lr() {
let optimizer = AdamOptimizer::new(AdamConfig {
learning_rate: 3e-4,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
amsgrad: false,
});
let diff = (optimizer.get_lr() - 3e-4).abs();
assert!(diff < 1e-12);
}
#[test]
fn test_adamw_optimizer_create_and_get_lr() {
let optimizer = AdamWOptimizer::new(AdamWConfig {
learning_rate: 2e-5,
beta1: 0.9,
beta2: 0.999,
weight_decay: 0.01,
eps: 1e-8,
amsgrad: false,
});
let diff = (optimizer.get_lr() - 2e-5).abs();
assert!(diff < 1e-12);
}
#[test]
fn test_special_token_reachable() {
let tok = SpecialToken::new("CLS", 0, "[CLS]");
assert!(tok.is_cls_token(), "Should be identified as CLS token");
}
#[test]
fn test_special_token_sep() {
let tok = SpecialToken::new("SEP", 5, "[SEP]");
assert!(tok.is_sep_token(), "Should be identified as SEP token");
}
#[test]
fn test_text_metadata_reachable() {
let meta = TextMetadata::new()
.with_language("en")
.with_encoding("utf-8")
.with_word_count(20);
assert_eq!(meta.language.as_deref(), Some("en"));
assert_eq!(meta.word_count, Some(20));
}
#[test]
fn test_auto_data_collator_for_task_fill_mask() {
let config = serde_json::json!({"pad_token_id": 0, "mask_token_id": 103});
let result = AutoDataCollator::for_task("fill-mask", &config);
assert!(
result.is_ok(),
"fill-mask task should produce a valid collator"
);
}
#[test]
fn test_auto_optimizer_from_config_default_config() {
let config = serde_json::json!({});
let result = AutoOptimizer::from_config(&config);
assert!(result.is_ok(), "from_config with empty JSON should succeed");
if let Ok(opt) = result {
assert!(opt.get_lr() > 0.0, "Optimizer LR should be positive");
}
}
}