
Function vmap 

pub fn vmap(
    forward: &Graph,
    batched_input_names: &[&str],
    batch_size: usize,
) -> Graph

Vectorize forward over a leading batch axis.

batched_input_names lists the Op::Input names whose leading axis is the batch axis in the returned graph. Inputs and Params not in the list are shared across the batch: ops that consume them alongside batched values broadcast them on demand.
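As a rough mental model (a sketch under stated assumptions, not this crate's implementation), vmap behaves like calling forward once per batch element and stacking the results: a batched input is sliced along axis 0, while a shared input is passed unchanged to every call. The closure below stands in for a Graph, and the name vmap_reference is hypothetical.

```rust
/// Hypothetical reference semantics of vmap for one batched input and one
/// shared input: call `forward` on each slice along the batch axis, reuse
/// `shared` unchanged for every call, and stack the per-example outputs
/// along a new leading axis.
fn vmap_reference<F>(forward: F, batched: &[Vec<f32>], shared: &[f32]) -> Vec<Vec<f32>>
where
    F: Fn(&[f32], &[f32]) -> Vec<f32>,
{
    batched
        .iter()
        // out[b] = forward(x[b], shared)
        .map(|example| forward(example.as_slice(), shared))
        .collect()
}
```

For example, with forward computing x * w + 1 elementwise, batched inputs [[1, 2], [3, 4]] and a shared w = [10, 0.5] give [[11, 2], [31, 3]].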

The returned graph:

  • Has the same input names as forward. Each batched input gains a leading batch axis, so an input of shape [d0, d1, ...] becomes [batch_size, d0, d1, ...].
  • Has the same number of outputs. Every output gains a leading batch axis of size batch_size (out_axes = 0 is implicit).
  • Has the same set of Op::Param slots; params are always shared.
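The shape contract above can be written as plain bookkeeping. The helpers below are a sketch only: vmapped_input_shapes and vmapped_output_shape are hypothetical names, and shapes are assumed to be Vec<usize>. Listed inputs get batch_size prepended, every unlisted entry (shared inputs and params) keeps its shape, and every output gets batch_size prepended.

```rust
/// Hypothetical helper: shapes of the vmapped graph's inputs. Names listed in
/// `batched_input_names` gain a leading `batch_size` axis; every other entry
/// (shared inputs, params) keeps its shape.
fn vmapped_input_shapes(
    input_shapes: &[(&str, Vec<usize>)],
    batched_input_names: &[&str],
    batch_size: usize,
) -> Vec<(String, Vec<usize>)> {
    input_shapes
        .iter()
        .map(|(name, shape)| {
            let mut new_shape = shape.clone();
            if batched_input_names.contains(name) {
                // [d0, d1, ...] -> [batch_size, d0, d1, ...]
                new_shape.insert(0, batch_size);
            }
            (name.to_string(), new_shape)
        })
        .collect()
}

/// Hypothetical helper: every output gains a leading batch axis (out_axes = 0).
fn vmapped_output_shape(per_example_shape: &[usize], batch_size: usize) -> Vec<usize> {
    let mut shape = Vec::with_capacity(per_example_shape.len() + 1);
    shape.push(batch_size);
    shape.extend_from_slice(per_example_shape);
    shape
}
```

For example, with batched_input_names = ["x"] and batch_size = 32, an input x of shape [784] becomes [32, 784], an unlisted input of shape [10] stays [10], and a per-example output of shape [10] becomes [32, 10].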

§Panics

Panics if forward contains any op that has no vmap rule. Rules are added incrementally, so not every op is covered yet.
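One way to picture this incremental design is a rule table in which a missing entry panics. Everything below (the Op variants, Rule, any_batched, lookup_rule, vmap_op) is hypothetical and only tracks whether each value is batched; it is a sketch, not this crate's code.

```rust
/// Hypothetical op set; the real Op enum is not shown on this page.
#[derive(Debug)]
enum Op {
    Add,
    Relu,
    Sort, // stands in for an op that has no vmap rule yet
}

/// A rule maps "is each operand batched?" to "is the output batched?".
type Rule = fn(&[bool]) -> bool;

/// Elementwise rule: the output is batched if any operand is; shared
/// operands are broadcast against the batched ones.
fn any_batched(operands: &[bool]) -> bool {
    operands.iter().any(|&b| b)
}

/// Returns the rule for `op`, or `None` if no rule has been added yet.
fn lookup_rule(op: &Op) -> Option<Rule> {
    match op {
        Op::Add | Op::Relu => Some(any_batched as Rule),
        Op::Sort => None,
    }
}

/// Applies the rule for `op`, panicking when the op has no rule.
fn vmap_op(op: &Op, operands_batched: &[bool]) -> bool {
    let rule = lookup_rule(op).unwrap_or_else(|| panic!("vmap: no rule for op {:?}", op));
    rule(operands_batched)
}
```

Supporting a new op in this sketch means adding one match arm, which mirrors adding rules one at a time.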