Trainer

pub struct Trainer;

Primary entry point for training in flodl.

Trainer is the default API for training a model, whether you have one GPU, many GPUs, or no GPU at all. The training loop is identical in all cases: Trainer::setup (or Trainer::builder) configures the model, detects the hardware, and enables distributed training automatically when multiple CUDA devices are available. On a single GPU or CPU it’s a no-op wrapper with zero DDP overhead.

For explicit multi-GPU control (manual gradient sync, custom replica wrapping) use Ddp directly. Ddp::wrap remains the entry point for advanced patterns (GAN, RL, progressive).

§Setup mode (user owns the loop)

Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;

for (x, y) in &train_loader {
    let out = model.forward(&x)?;
    let loss = cross_entropy_loss(&out, &y)?;
    loss.backward()?;
    model.step()?;
}

§Builder mode (framework owns the loop)

let handle = Trainer::builder(model_factory, optim_factory, train_fn)
    .dataset(dataset)
    .batch_size(32)
    .num_epochs(10)
    .run()?;

let state = handle.join()?;

Implementations§

impl Trainer

pub fn setup<F, M, G, O>(model: &Graph, builder: F, optimizer: G) -> Result<()>
where
    F: Fn(Device) -> Result<M>,
    M: Module + 'static,
    G: Fn(&[Parameter]) -> O,
    O: Optimizer + 'static,

One-call setup: auto-detect GPUs, distribute the model, set the optimizer, and enable training mode.

  • Multi-GPU (2+ usable CUDA devices): replicates via Graph::distribute, creates per-replica optimizers, enables training.
  • Single-GPU / CPU: sets optimizer and training mode only (no DDP overhead).

Always prints a diagnostic summary to stderr showing detected hardware.

Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;

for batch in model.epoch(epoch).activate() {
    let out = model.forward_batch(&batch?)?;
    let loss = cross_entropy_loss(&out, &targets)?; // targets for this batch
    loss.backward()?;
    model.step()?;
}

pub fn setup_with<F, M, G, O>(
    model: &Graph,
    builder: F,
    optimizer: G,
    config: DdpConfig,
) -> Result<()>
where
    F: Fn(Device) -> Result<M>,
    M: Module + 'static,
    G: Fn(&[Parameter]) -> O,
    O: Optimizer + 'static,

One-call setup with explicit configuration.

Like setup() but accepts a DdpConfig for controlling averaging cadence, speed hints, and overhead targets.

Trainer::setup_with(&model, builder, optimizer,
    DdpConfig::new().speed_hint(1, 2.3))?;

pub fn builder<F, M, G, O, T>(
    model_factory: F,
    optim_factory: G,
    train_fn: T,
) -> DdpBuilder<F, M, G, O, T>
where
    F: Fn(Device) -> Result<M> + Send + Sync + 'static,
    M: Module + 'static,
    G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
    O: Optimizer + 'static,
    T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,

Create a builder for framework-managed training.

The framework owns the training loop, data pipeline, and epoch management. On multi-GPU hardware, each device gets its own model replica and optimizer, and a coordinator triggers periodic parameter averaging based on the configured ApplyPolicy and AverageBackend. On a single GPU, training runs on the main thread with no coordination; the API is identical in both cases.

Returns a DdpBuilder for fluent configuration. Call .run() to spawn training, then .join() on the returned DdpHandle to block until completion.

§Example
use flodl::*;

let handle = Trainer::builder(
    |dev| model_factory(dev),
    |params| Adam::new(params, 0.001),
    |model, batch| { /* forward + loss */ },
)
.dataset(dataset)
.batch_size(32)
.num_epochs(10)
.policy(ApplyPolicy::Cadence)
.backend(AverageBackend::Nccl)
.run()?;

let state = handle.join()?;

pub fn setup_head<H, F, G, O>(head: &H, head_factory: F, optimizer: G) -> Result<()>
where
    H: HasGraph + 'static,
    F: Fn(Device) -> Result<H> + 'static,
    G: Fn(&[Parameter]) -> O,
    O: Optimizer + 'static,

One-call setup for a task-head wrapper (e.g. flodl-hf’s BertForSequenceClassification). The wrapper must implement HasGraph so Trainer can reach the underlying Graph.

Semantics match Trainer::setup exactly; the only difference is that head_factory builds a fresh wrapper (not a bare Graph) on each replica device. Useful when the training-loop code holds onto the wrapper’s richer surface (compute_loss, predict, attached tokenizer) but still wants transparent 1-or-N-GPU DDP.

let head = DistilBertForSequenceClassification::from_pretrained(repo)?;
let config = head.config().clone();
let num_labels = head.labels().len() as i64;

Trainer::setup_head(
    &head,
    move |dev| DistilBertForSequenceClassification::on_device(&config, num_labels, dev),
    |p| Adam::new(p, 5e-5),
)?;

for (enc, labels) in &batches {
    let loss = head.compute_loss(&enc, &labels)?;
    loss.backward()?;
    head.graph().step()?;
}

pub fn setup_head_with<H, F, G, O>(
    head: &H,
    head_factory: F,
    optimizer: G,
    config: DdpConfig,
) -> Result<()>
where
    H: HasGraph + 'static,
    F: Fn(Device) -> Result<H> + 'static,
    G: Fn(&[Parameter]) -> O,
    O: Optimizer + 'static,

Task-head variant of Trainer::setup_with. Same behaviour as Trainer::setup_head, but takes an explicit DdpConfig.
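A minimal sketch, combining the head and factory from the setup_head example above with the DdpConfig shown for setup_with (the speed_hint values are illustrative, not recommended settings):

```rust
let head = DistilBertForSequenceClassification::from_pretrained(repo)?;
let model_config = head.config().clone();
let num_labels = head.labels().len() as i64;

// Identical to setup_head, plus an explicit DdpConfig as the fourth argument.
Trainer::setup_head_with(
    &head,
    move |dev| DistilBertForSequenceClassification::on_device(&model_config, num_labels, dev),
    |p| Adam::new(p, 5e-5),
    DdpConfig::new().speed_hint(1, 2.3),
)?;
```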

Auto Trait Implementations§

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.
impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.