// Submodule declarations.
//
// Modules declared with a plain `mod` are private to this module; selected
// items from them are re-exported by the `pub use` lines that follow these
// declarations. `callback`, `transformer_trainer` and `tui` are declared
// `pub`, so their own paths are reachable in addition to the re-exports.
mod batch;
pub mod callback;
mod config;
mod curriculum;
mod loss;
mod metrics;
mod trainer;
pub mod transformer_trainer;
pub mod tui;
// Compiled only when building tests.
#[cfg(test)]
mod tests;
// Public re-exports: each group below lifts names from one submodule into
// this module's namespace so they can be imported without the submodule path.

// From the private `batch` module.
pub use batch::Batch;
// From `callback` (also a `pub mod`, so these names are reachable by both paths).
pub use callback::{
CallbackAction, CallbackContext, CallbackManager, CheckpointCallback, EarlyStopping,
ExplainMethod, ExplainabilityCallback, FeatureImportanceResult, LRSchedulerCallback,
MonitorCallback, ProgressCallback, TrainerCallback,
};
// From the private `config` module.
pub use config::{MetricsTracker, TrainConfig};
// From the private `curriculum` module.
pub use curriculum::{
efficiency_score, select_optimal_tier, AdaptiveCurriculum, CurriculumScheduler,
LinearCurriculum, TieredCurriculum,
};
// From the private `loss` module.
pub use loss::{
BCEWithLogitsLoss, CausalLMLoss, CrossEntropyLoss, HuberLoss, L1Loss, LossFn, MSELoss,
SampleWeightedLoss, SmoothL1Loss, WeightedLoss,
};
// From the private `metrics` module.
pub use metrics::{Accuracy, F1Score, Metric, Precision, R2Score, Recall, MAE, RMSE};
// From the private `trainer` module.
pub use trainer::{TrainResult, Trainer};
// From the nested `transformer_trainer::distributed_checkpoint` module.
pub use transformer_trainer::distributed_checkpoint::{
checkpoint_path, hash_weights, should_save_checkpoint, verify_weight_consistency,
CheckpointPhase,
};
// From the nested `transformer_trainer::grad_accumulator` module.
pub use transformer_trainer::grad_accumulator::BLOCK_GRAD_COMPONENTS;
// From the root of `transformer_trainer`; exported unconditionally
// (no `cfg` attribute on this item).
pub use transformer_trainer::{
perplexity,
shard_batches,
tokens_per_second,
BlockGradientSet,
CausalMaskType,
ColumnParallelShard,
CudaTransformerTrainer,
DistributedBackend,
DistributedCheckpointCoordinator,
DistributedRole,
DistributedTrainConfig,
ElasticCoordinator,
LMBatch,
OptimizerShard,
PerBlockGradientAccumulator,
PipelineAction,
PipelineActivationBuffer,
PipelineStage,
RingAttentionSchedule,
RowParallelShard,
SequenceParallelConfig,
SpCommCost,
TensorParallelConfig,
TpCommCost,
TransformerTrainConfig,
TransformerTrainer,
ZeroShardMap,
};
// Re-exported only when the `cuda` Cargo feature is enabled.
// NOTE(review): `CudaTransformerTrainer` in the list above is re-exported
// without this gate — presumably it compiles without the feature; confirm.
#[cfg(feature = "cuda")]
pub use transformer_trainer::{DistributedComm, DistributedCudaTrainer};
// From `tui` (also a `pub mod`, so these names are reachable by both paths).
pub use tui::{
format_duration, sparkline, sparkline_range, Alert, AlertLevel, AndonSystem, DashboardLayout,
FeatureImportanceChart, GradientFlowHeatmap, KalmanEta, LossCurveDisplay, MetricsBuffer,
MonitorConfig, ProgressBar, ReferenceCurve, RefreshPolicy, SeriesSummaryTuple,
TerminalCapabilities, TerminalMode, TerminalMonitorCallback, SPARK_CHARS,
};