1mod batch;
33pub mod callback;
34mod config;
35mod curriculum;
36pub mod device;
37pub mod gputrain_003;
38pub mod gputrain_004;
39pub mod gputrain_005;
40pub mod gputrain_006;
41pub mod gputrain_007;
42mod loss;
43mod metrics;
44pub mod pretrain;
45pub mod pretrain_real;
46#[cfg(feature = "cuda")]
47pub mod pretrain_real_cuda;
48pub mod shard_reader;
49mod trainer;
50pub mod transformer_trainer;
51pub mod tui;
52
53#[cfg(test)]
54mod tests;
55
56pub use batch::Batch;
57pub use callback::{
58 CallbackAction, CallbackContext, CallbackManager, CheckpointCallback, EarlyStopping,
59 ExplainMethod, ExplainabilityCallback, FeatureImportanceResult, LRSchedulerCallback,
60 MonitorCallback, ProgressCallback, TrainerCallback,
61};
62pub use config::{MetricsTracker, TrainConfig};
63pub use curriculum::{
64 efficiency_score, select_optimal_tier, AdaptiveCurriculum, CurriculumScheduler,
65 LinearCurriculum, TieredCurriculum,
66};
67pub use device::{resolve_device, Device, DeviceError};
68pub use loss::{
69 BCEWithLogitsLoss, CausalLMLoss, CrossEntropyLoss, HuberLoss, L1Loss, LossFn, MSELoss,
70 SampleWeightedLoss, SmoothL1Loss, WeightedLoss,
71};
72pub use metrics::{Accuracy, F1Score, Metric, Precision, R2Score, Recall, MAE, RMSE};
73pub use pretrain::{
74 check_non_divergence, check_numerical_stability, CheckpointFn, EpochArtifact, EpochMetadata,
75 LinearDecaySynthetic, NanAtStepSynthetic, PretrainAbort, PretrainConfig, PretrainLoop,
76 RunStatus, ScriptedVal, StepFn, StepMetrics, TrainingRegime, ValFn, DIVERGENCE_RATIO_LIMIT,
77 EPOCH_ZERO_VAL_LOSS_LIMIT,
78};
79pub use trainer::{TrainResult, Trainer};
80pub use transformer_trainer::distributed_checkpoint::{
81 checkpoint_path, hash_weights, should_save_checkpoint, verify_weight_consistency,
82 CheckpointPhase,
83};
84pub use transformer_trainer::grad_accumulator::BLOCK_GRAD_COMPONENTS;
85pub use transformer_trainer::{
86 perplexity,
87 shard_batches,
89 tokens_per_second,
90 BlockGradientSet,
91 CausalMaskType,
93 ColumnParallelShard,
94 CudaTransformerTrainer,
95 DistributedBackend,
96 DistributedCheckpointCoordinator,
97 DistributedRole,
98 DistributedTrainConfig,
99 ElasticCoordinator,
100 LMBatch,
101 OptimizerShard,
102 PerBlockGradientAccumulator,
103 PipelineAction,
104 PipelineActivationBuffer,
105 PipelineStage,
106 RingAttentionSchedule,
107 RowParallelShard,
108 SequenceParallelConfig,
109 SpCommCost,
110 TensorParallelConfig,
111 TpCommCost,
112 TransformerTrainConfig,
113 TransformerTrainer,
114 ZeroShardMap,
115};
116#[cfg(feature = "cuda")]
117pub use transformer_trainer::{DistributedComm, DistributedCudaTrainer};
118pub use tui::{
119 format_duration, sparkline, sparkline_range, Alert, AlertLevel, AndonSystem, DashboardLayout,
120 FeatureImportanceChart, GradientFlowHeatmap, KalmanEta, LossCurveDisplay, MetricsBuffer,
121 MonitorConfig, ProgressBar, ReferenceCurve, RefreshPolicy, SeriesSummaryTuple,
122 TerminalCapabilities, TerminalMode, TerminalMonitorCallback, SPARK_CHARS,
123};