//! Shared import prelude for this crate: a single module of `pub(crate) use`
//! re-exports so sibling modules can pull in the common std / third-party /
//! workspace items from one place instead of repeating these long paths.
//!
//! Several re-exports are gated on cargo features (`cli`, `train`, `cuda`);
//! which combinations are valid is defined by Cargo.toml, not visible here.
// Not every re-export is used by every module/feature combination, so unused
// warnings from this file are expected and suppressed wholesale.
#![allow(unused_imports)]

// --- Rust standard library ---
pub(crate) use std::any::TypeId;
pub(crate) use std::collections::{HashMap, VecDeque};
pub(crate) use std::fs;
pub(crate) use std::io;
pub(crate) use std::path::{Path, PathBuf};
pub(crate) use std::sync::Arc;
pub(crate) use std::sync::Mutex;
pub(crate) use std::sync::atomic::{AtomicBool, Ordering};
pub(crate) use std::time::{SystemTime, UNIX_EPOCH};

// --- General-purpose third-party crates ---
pub(crate) use anyhow::{Context, Result, anyhow};
// clap is only a dependency of the CLI binary, hence the feature gate.
#[cfg(feature = "cli")]
pub(crate) use clap::{Args as ClapArgs, Parser, Subcommand, ValueEnum};
pub(crate) use names::Generator;
pub(crate) use rand::{Rng, SeedableRng, rngs::StdRng, thread_rng};

// --- burn: core ML framework (data, schedulers, modules, optimizers, tensors) ---
pub(crate) use burn::data::dataloader::DataLoader;
pub(crate) use burn::lr_scheduler::{
LrScheduler,
cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig},
exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig},
linear::{LinearLrScheduler, LinearLrSchedulerConfig},
noam::{NoamLrScheduler, NoamLrSchedulerConfig},
step::{StepLrScheduler, StepLrSchedulerConfig},
};
pub(crate) use burn::module::{
AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, Param,
};
pub(crate) use burn::nn::loss::CrossEntropyLossConfig;
pub(crate) use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
pub(crate) use burn::optim::adaptor::OptimizerAdaptor;
pub(crate) use burn::optim::grad_clipping::GradientClippingConfig;
pub(crate) use burn::optim::{
AdamW, AdamWConfig, GradientsAccumulator, GradientsParams, LearningRate,
};
// Renamed to avoid clashing with other `Distribution` types in scope.
pub(crate) use burn::tensor::Distribution as TensorDistribution;
pub(crate) use burn::tensor::activation;
// `Backend` renamed to `BackendTrait`, presumably to free up the bare name
// for a concrete backend alias elsewhere — TODO confirm against usage sites.
pub(crate) use burn::tensor::backend::{AutodiffBackend, Backend as BackendTrait};
pub(crate) use burn::tensor::module::conv2d;
pub(crate) use burn::tensor::ops::{ConvOptions, InterpolateMode};
pub(crate) use burn::tensor::{Int, Tensor, TensorData};

// --- burn backends & training harness (mostly feature-gated) ---
#[cfg(feature = "cli")]
pub(crate) use burn_autodiff::Autodiff;
#[cfg(any(feature = "train", feature = "cli"))]
pub(crate) use burn_ndarray::NdArrayDevice;
pub(crate) use burn_train::metric::{LearningRateMetric, LossMetric};
pub(crate) use burn_train::{
LearnerBuilder, LearningStrategy, TrainOutput, TrainStep, TrainingResult, ValidStep,
};
#[cfg(feature = "cli")]
pub(crate) use burn_wgpu::Wgpu;
// Device type is needed by both the training library and the CLI; the full
// `Wgpu` backend type above is CLI-only.
#[cfg(any(feature = "train", feature = "cli"))]
pub(crate) use burn_wgpu::WgpuDevice;
pub(crate) use tracing::info;
// CUDA backend: compiled in only when the `cuda` feature is on AND we are
// building the CLI or running tests.
#[cfg(all(feature = "cuda", any(feature = "cli", test)))]
pub(crate) use burn_cuda::Cuda;
pub(crate) use burn::record::{BinFileRecorder, FullPrecisionSettings};

// --- workspace crates: DINO model, training pipeline, core vision types, losses ---
pub(crate) use burn_dino::correctness::load_model_from_checkpoint;
pub(crate) use burn_dino::model::dino::{DinoVisionTransformer, DinoVisionTransformerConfig};
#[cfg(feature = "cli")]
pub(crate) use burn_dragon_train::wgpu::init_runtime;
pub(crate) use burn_dragon_train::{
GdpoHardGate, ImagenetteVariant, LearningRateScheduleConfig, OptimizerConfig,
VisionArtifactOutputMode, VisionDatasetConfig,
VisionDatasetDownloadConfig, VisionFoveaSamplingMode, VisionFoveaScatterMode,
VisionFoveaWarpMode, VisionLejepaConfig, VisionLejepaLossConfig, VisionMaeConfig,
VisionPyramidMode, VisionSaccadeConfig, VisionSaccadeInputProjectionCnnConfig,
VisionSaccadeInputProjectionConfig, VisionSaccadeInputProjectionMicroVitConfig,
VisionTeacherConfig, VisionTeacherVariant, VisionTrainingConfig, VisionTrainingHyperparameters,
VisionTrainingModeConfig,
};
pub(crate) use burn_dragon_core::{
DinoFeatureStore, ImageNetAugmentations, ImageNetBatch, ImageNetDataLoader, ImageNetDataset,
ImageNetDatasetConfig, ImageNetSplit, PatchGrid, SpatialPositionalEncodingKind,
VisionAttentionMode, VisionDragonHatchling, VisionDragonHatchlingConfig,
VisionLatentActivation, VisionNormalize, VisionPatchEmbedMode, patchify, unpatchify,
};
pub(crate) use burn_dragon_loss::{
VisionDistillationLossConfig, vision_distillation_loss,
};
// Config loading involves the filesystem; only the CLI needs it.
#[cfg(feature = "cli")]
pub(crate) use burn_dragon_train::load_vision_training_config;
pub(crate) use serde::Serialize;

// --- crate-local training modules (glob re-exports) ---
pub(crate) use crate::train::constants::*;
pub(crate) use crate::train::saccade::*;
pub(crate) use crate::train::vision::*;

// --- burn_dragon_train submodule re-exports (teacher, pipeline, metrics) ---
// NOTE(review): glob imports from an external crate can silently shadow or
// collide with the crate-local globs above — worth confirming rustc raises
// no ambiguity warnings as these crates evolve.
pub(crate) use burn_dragon_train::train::teacher::*;
pub(crate) use burn_dragon_train::train::pipeline::*;
pub(crate) use burn_dragon_train::train::metrics::{
ActionClampRateInput, AdvantageAbsMeanInput, AdvantageStdInput, DeviceMemoryMetric,
DeviceMetric, InvLossInput, LanguageModelOutput, LanguageModelTrainItem, LogProbMeanInput,
LossValue, MemoryCleanupMetric, PolicyEntropyInput, PolicyLossInput, ProbeAccInput,
ProbeLossInput, ReconLossInput, ReconPsnrInput, ScalarMetric, SigRegLossInput,
VisionArtifactInput, VisionArtifactMetric, VisionOutput, VisionTrainItem,
};