
Module data_parallel


DataParallelTrainer replicates a model across N replicas, splits each mini-batch evenly among them, runs a forward and backward pass on every replica, aggregates the per-replica loss and gradient norm, and applies an optimizer step.
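
The split and aggregate steps can be sketched as follows. This is only an illustration under stated assumptions: `split_evenly`, `aggregate`, and `ChunkResult` are hypothetical names, and the sample-weighted mean is one plausible aggregation rule, not necessarily the one the trainer uses.

```rust
/// Hypothetical per-replica result; the field layout is an assumption,
/// not the crate's `ReplicaStepResult`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ChunkResult {
    loss: f32,
    grad_norm: f32,
    n_samples: usize,
}

/// Split `batch` into `n_replicas` contiguous chunks whose sizes differ
/// by at most one; the first `len % n_replicas` chunks get the extra item.
fn split_evenly<T: Clone>(batch: &[T], n_replicas: usize) -> Vec<Vec<T>> {
    assert!(n_replicas > 0, "need at least one replica");
    let base = batch.len() / n_replicas;
    let extra = batch.len() % n_replicas;
    let mut chunks = Vec::with_capacity(n_replicas);
    let mut start = 0;
    for i in 0..n_replicas {
        let len = base + usize::from(i < extra);
        chunks.push(batch[start..start + len].to_vec());
        start += len;
    }
    chunks
}

/// Combine per-replica results into one (loss, grad_norm) pair using a
/// mean weighted by chunk size (assumed rule). Returns `None` when no
/// samples were processed.
fn aggregate(results: &[ChunkResult]) -> Option<(f32, f32)> {
    let total: usize = results.iter().map(|r| r.n_samples).sum();
    if total == 0 {
        return None;
    }
    let mut loss = 0.0f32;
    let mut grad_norm = 0.0f32;
    for r in results {
        let weight = r.n_samples as f32 / total as f32;
        loss += weight * r.loss;
        grad_norm += weight * r.grad_norm;
    }
    Some((loss, grad_norm))
}
```

Weighting by chunk size keeps the aggregate equal to the full-batch mean when the batch length is not a multiple of N.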

The trainer is generic over a ReplicaProtocol trait that describes the message contract to a single replica actor. F4.x ships the protocol with a CPU-side host_step that completes a synchronous forward/backward and returns (loss, grad_norm). F5 swaps that for a real GPU forward/backward+AllReduce path; the public surface stays the same.
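
To illustrate why a trait-generic trainer lets the backend change without touching the public surface, here is a minimal sketch. `Replica`, `Trainer`, and `StubHostReplica` are hypothetical stand-ins, not the crate's actual items, and the stub's return values are placeholders rather than a real forward/backward pass.

```rust
/// Hypothetical stand-in for the per-replica contract.
trait Replica {
    /// Process one chunk and return (loss, grad_norm).
    fn step(&mut self, chunk: &[f32]) -> (f32, f32);
}

/// Trainer that depends only on the trait, so any backend implementing
/// `Replica` can be plugged in without changing this type's API.
struct Trainer<R: Replica> {
    replicas: Vec<R>,
}

impl<R: Replica> Trainer<R> {
    fn new(replicas: Vec<R>) -> Self {
        assert!(!replicas.is_empty(), "need at least one replica");
        Trainer { replicas }
    }

    /// Hand each replica one contiguous chunk and collect the replies.
    /// Replicas beyond the number of chunks stay idle for this step.
    fn train_step(&mut self, batch: &[f32]) -> Vec<(f32, f32)> {
        let n = self.replicas.len();
        // Ceiling division; `max(1)` avoids a zero chunk size on empty input.
        let chunk_len = ((batch.len() + n - 1) / n).max(1);
        batch
            .chunks(chunk_len)
            .zip(self.replicas.iter_mut())
            .map(|(chunk, replica)| replica.step(chunk))
            .collect()
    }
}

/// Placeholder CPU backend: "loss" is the chunk mean and "grad_norm" is
/// the chunk length. Both are illustrative values only.
struct StubHostReplica;

impl Replica for StubHostReplica {
    fn step(&mut self, chunk: &[f32]) -> (f32, f32) {
        let mean = chunk.iter().sum::<f32>() / chunk.len() as f32;
        (mean, chunk.len() as f32)
    }
}
```

A GPU-backed type would implement the same trait, and callers of `Trainer::train_step` would not need to change.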

Structs

DataParallelTrainer
ReplicaStepResult
TrainSample
TrainerConfig

Enums

TrainerMsg

Traits

ReplicaProtocol
Per-replica step contract. Each replica receives a chunk of the mini-batch and replies with (loss, grad_norm) for that chunk.
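
A synchronous CPU-side step that satisfies this kind of contract might look like the sketch below. Everything here is assumed for illustration: the trait name `ReplicaStepSketch`, the `Sample` type, and the one-parameter linear model with mean squared error are not taken from the crate.

```rust
/// Hypothetical training sample: one input and one target.
#[derive(Debug, Clone, Copy)]
struct Sample {
    x: f32,
    y: f32,
}

/// Hypothetical per-replica contract: take a chunk, reply with
/// (loss, grad_norm) for that chunk.
trait ReplicaStepSketch {
    fn step(&mut self, chunk: &[Sample]) -> (f32, f32);
}

/// Toy host replica holding a single weight for the model `pred = w * x`.
struct HostReplica {
    weight: f32,
}

impl ReplicaStepSketch for HostReplica {
    fn step(&mut self, chunk: &[Sample]) -> (f32, f32) {
        if chunk.is_empty() {
            return (0.0, 0.0);
        }
        let n = chunk.len() as f32;
        let mut loss_sum = 0.0f32;
        let mut grad_sum = 0.0f32;
        for s in chunk {
            // Forward: prediction error for this sample.
            let err = self.weight * s.x - s.y;
            loss_sum += err * err;
            // Backward: d/dw of err^2 is 2 * err * x.
            grad_sum += 2.0 * err * s.x;
        }
        // Mean squared error and the norm of the mean gradient; with a
        // single parameter the L2 norm is the absolute value.
        (loss_sum / n, (grad_sum / n).abs())
    }
}
```

The step only reports the loss and gradient norm; applying the update is left to the caller, matching the description of the trainer performing the optimizer step.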