Skip to main content

Dropout

Struct Dropout 

Source
pub struct Dropout { /* private fields */ }
Expand description

Inverted dropout module.

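Uses a single fused torch::dropout kernel (one autograd node). In training mode, it randomly zeros elements with probability p and scales the surviving elements by 1/(1-p). This keeps each element's expected value equal to its input, since (1-p)·x/(1-p) = x. In eval mode, it is the identity function.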

Implementations§

Source§

impl Dropout

Source

pub fn new(p: f64) -> Self

Create a dropout module with drop probability p, which should lie between 0.0 and 1.0. Call set_training(false), or its shorthand eval(), to disable dropout during inference.

Examples found in repository?
examples/showcase/main.rs (line 546)
515fn build_showcase() -> flodl::Result<Graph> {
516    const B: i64 = 2;  // batch size (>= 2 for BatchNorm training mode)
517    const H: i64 = 8;
518
519    FlowBuilder::from(Linear::new(2, H)?)
520        // input() declares auxiliary graph inputs — forward_multi receives them
521        .input(&["ctx"])
522
523        // Tag names a position in the stream for later reference via .using()
524        .tag("input")
525
526        // .through() chains modules sequentially: stream -> module -> stream
527        .through(GELU)
528        .through(LayerNorm::new(H)?)
529        .through(RmsNorm::new())
530
531        // ContextBlend is a NamedInputModule that reads the "ctx" auxiliary input
532        .through(ContextBlend)
533        .using(&["ctx"])
534
535        // .fork() runs a side-branch: output can be tagged, main stream unchanged
536        .fork(spectral_monitor(H)?)
537        .tag("spectral")
538
539        // .split() forks the stream into parallel branches, .merge() recombines.
540        // modules![] is shorthand for vec![Box::new(...) as Box<dyn Module>, ...]
541        .split(modules![read_head(H)?, read_head(H)?])
542        .merge(MergeOp::Mean)
543
544        // .also() adds a residual connection: output = stream + module(stream)
545        .also(Linear::new(H, H)?)
546        .through(Dropout::new(0.1))
547        .through(SoftClamp::new(0.5, 3.0))
548        .through(Softplus)
549
550        // VarianceGate exercises mean/var/std/expand
551        .through(VarianceGate::new(H))
552
553        // .map().slices(n) decomposes [B,D] -> [B*n,D/n], applies body, recomposes
554        .map(read_head(2)?)
555        .slices(H / 2)
556
557        // Reshape changes tensor dimensions without copying data
558        .through(Reshape::new(&[B * 2, H / 2]))
559
560        // .map().each() applies body independently to each element in a multi-stream
561        .map(Linear::new(H / 2, H / 2)?)
562        .each()
563        .tag("halves")
564
565        // .map().over(tag) iterates over a tagged tensor (backward ref) instead
566        // of the current stream — useful for refining previously computed features
567        .map(Linear::new(H / 2, H / 2)?)
568        .over("halves")
569
570        // .map().batched().each() — fast path: full batch in one call
571        .map(Linear::new(H / 2, H / 2)?)
572        .batched()
573        .each()
574
575        .through(Reshape::new(&[B, H]))
576        .through(ShapeOps::new(B, H))
577        .through(NegSigmoidGate)
578        .through(TransposeRoundTrip)
579
580        // CounterModule overrides reset() — loops auto-call it before iterating
581        .through(CounterModule::new())
582
583        // ChunkRecombine: chunk, relu (Variable op), cat
584        .through(ChunkRecombine)
585
586        // AttentionLikeOps: softmax, select, narrow, index_select
587        .through(AttentionLikeOps::new(H))
588
589        // TopKFilterOps: topk, sort, gather, min, max, pad
590        .through(TopKFilterOps::new(H))
591
592        // RepeatNarrow: repeat
593        .through(RepeatNarrow::new(H))
594
595        // .loop_body().for_n(n) repeats the body n times, feeding output back as input.
596        // silu_block is a sub-graph (Graph implements Module) — graphs compose freely.
597        .loop_body(silu_block(H)?)
598        .for_n(2)
599        .tag("refined")
600
601        // .gate() is soft routing (mixture of experts): all experts run, router
602        // produces weights, outputs are combined. .using() feeds the tagged "input"
603        // stream to the router as a backward reference.
604        .gate(
605            SoftmaxRouter::new(H, 2)?,
606            modules![Linear::new(H, H)?, Linear::new(H, H)?],
607        )
608        .using(&["input"])
609
610        // .switch() is hard routing: router picks one branch, others are skipped.
611        // HeavyPathSelector is a custom NamedInputModule — it receives the "refined"
612        // ref via forward_named() and decides which branch to activate.
613        .switch(
614            HeavyPathSelector,
615            modules![Linear::new(H, H)?, ffn_block(H)?],
616        )
617        .using(&["refined"])
618
619        // Forward reference: .using("memory") reads a tag that doesn't exist yet —
620        // the framework creates a state buffer. .tag("memory") writes to it.
621        // On the first pass, the state is zero (pass-through). On subsequent passes,
622        // StateAdd accumulates: stream + previous_memory.
623        .through(StateAdd)
624        .using(&["memory"])
625        .tag("memory")
626
627        // .while_cond() repeats until the halt module signals stop (or max iterations).
628        // ThresholdHalt stops when the stream's L2 norm exceeds the threshold.
629        .loop_body(Linear::new(H, H)?)
630        .while_cond(ThresholdHalt::new(100.0), 5)
631
632        // .until_cond() is the inverse: repeats until halt signals true.
633        // LearnedHalt has trainable parameters — it learns when to stop.
634        .loop_body(Linear::new(H, H)?)
635        .until_cond(LearnedHalt::new(H)?, 7)
636
637        .through(LogSoftmaxReduce)
638        .through(Linear::new(1, H)?)
639
640        // Split with tag_group: names each branch ("final_heads_0", "final_heads_1")
641        .split(vec![
642            Box::new(Linear::new(H, H)?),
643            Box::new(Linear::new(H, H)?),
644        ])
645        .tag_group("final_heads")
646        .merge(MergeOp::Add)
647
648        // Final projection and output tag for observation
649        .through(Linear::new(H, 2)?)
650        .tag("output")
651        .build()
652}

Trait Implementations§

Source§

impl Module for Dropout

Source§

fn name(&self) -> &str

Human-readable type name used as node ID prefix in graph visualization. Override to return a lowercase identifier (e.g., “linear”, “gelu”).
Source§

fn forward(&self, input: &Variable) -> Result<Variable>

Run the forward pass on input and return the result.
Source§

fn set_training(&self, training: bool)

Set training/eval mode. Affects Dropout, BatchNorm, etc. Override in modules with mode-dependent behavior.
Source§

fn parameters(&self) -> Vec<Parameter>

Return this module’s learnable parameters. Default: recursively collects from sub_modules() with pointer dedup. Leaf modules should override to return their own parameters.
Source§

fn buffers(&self) -> Vec<Buffer>

Return this module’s non-learnable persistent buffers (e.g., running stats). Default: recursively collects from sub_modules() with pointer dedup. Leaf modules should override to return their own buffers.
Source§

fn sub_modules(&self) -> Vec<Rc<dyn Module>>

Return direct child modules for recursive tree walks. Override in composite modules (loops, switches, gates).
Source§

fn move_to_device(&self, _device: Device)

Move all parameters and buffers to the given device. Override in modules like BatchNorm that hold non-parameter state.
Source§

fn train(&self)

Set training mode. Shorthand for set_training(true).
Source§

fn eval(&self)

Set eval mode. Shorthand for set_training(false).
Source§

fn trace(&self) -> Option<Variable>

Return per-iteration side output for loop tracing. Override in loop body modules that capture trajectory data (e.g., attention fixation points). Returns None by default. When Some, the loop executor collects traces accessible via Graph::traces().
Source§

fn as_named_input(&self) -> Option<&dyn NamedInputModule>

Upcast to NamedInputModule for multi-input graphs. Override in types that implement NamedInputModule so they can receive additional named inputs via the graph's using().
Source§

fn structural_hash(&self) -> Option<String>

SHA-256 hex hash of module architecture for checkpoint validation. Override in composite modules (Graph) that compute a deterministic hash from their topology and parameter shapes.
Source§

fn reset(&self)

Reset internal state (e.g. recurrent hidden state) between sequences. Called by loops before iterating to clear stale tensors whose grad_fns may reference freed saved tensors. Override in stateful modules.
Source§

fn detach_state(&self)

Detach internal state from the computation graph (for truncated BPTT). Called between training steps to break gradient chains on state carried across forward passes (e.g., recurrent hidden state). Override in stateful modules.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.