pub struct Dropout { /* private fields */ }
Expand description
Inverted dropout module.
Uses a single fused torch::dropout kernel (1 autograd node).
During training: randomly zeros elements with probability p,
scales remaining by 1/(1-p).
During eval: identity function.
Implementations§
Source§impl Dropout
impl Dropout
Source§pub fn new(p: f64) -> Self
pub fn new(p: f64) -> Self
Create a dropout module with drop probability p (0.0 to 1.0).
Use set_training(false) to disable during inference.
Examples found in repository?
examples/showcase/main.rs (line 546)
515 fn build_showcase() -> flodl::Result<Graph> {
516 const B: i64 = 2; // batch size (>= 2 for BatchNorm training mode)
517 const H: i64 = 8;
518
519 FlowBuilder::from(Linear::new(2, H)?)
520 // input() declares auxiliary graph inputs — forward_multi receives them
521 .input(&["ctx"])
522
523 // Tag names a position in the stream for later reference via .using()
524 .tag("input")
525
526 // .through() chains modules sequentially: stream -> module -> stream
527 .through(GELU)
528 .through(LayerNorm::new(H)?)
529 .through(RmsNorm::new())
530
531 // ContextBlend is a NamedInputModule that reads the "ctx" auxiliary input
532 .through(ContextBlend)
533 .using(&["ctx"])
534
535 // .fork() runs a side-branch: output can be tagged, main stream unchanged
536 .fork(spectral_monitor(H)?)
537 .tag("spectral")
538
539 // .split() forks the stream into parallel branches, .merge() recombines.
540 // modules![] is shorthand for vec![Box::new(...) as Box<dyn Module>, ...]
541 .split(modules![read_head(H)?, read_head(H)?])
542 .merge(MergeOp::Mean)
543
544 // .also() adds a residual connection: output = stream + module(stream)
545 .also(Linear::new(H, H)?)
546 .through(Dropout::new(0.1))
547 .through(SoftClamp::new(0.5, 3.0))
548 .through(Softplus)
549
550 // VarianceGate exercises mean/var/std/expand
551 .through(VarianceGate::new(H))
552
553 // .map().slices(n) decomposes [B,D] -> [B*n,D/n], applies body, recomposes
554 .map(read_head(2)?)
555 .slices(H / 2)
556
557 // Reshape changes tensor dimensions without copying data
558 .through(Reshape::new(&[B * 2, H / 2]))
559
560 // .map().each() applies body independently to each element in a multi-stream
561 .map(Linear::new(H / 2, H / 2)?)
562 .each()
563 .tag("halves")
564
565 // .map().over(tag) iterates over a tagged tensor (backward ref) instead
566 // of the current stream — useful for refining previously computed features
567 .map(Linear::new(H / 2, H / 2)?)
568 .over("halves")
569
570 // .map().batched().each() — fast path: full batch in one call
571 .map(Linear::new(H / 2, H / 2)?)
572 .batched()
573 .each()
574
575 .through(Reshape::new(&[B, H]))
576 .through(ShapeOps::new(B, H))
577 .through(NegSigmoidGate)
578 .through(TransposeRoundTrip)
579
580 // CounterModule overrides reset() — loops auto-call it before iterating
581 .through(CounterModule::new())
582
583 // ChunkRecombine: chunk, relu (Variable op), cat
584 .through(ChunkRecombine)
585
586 // AttentionLikeOps: softmax, select, narrow, index_select
587 .through(AttentionLikeOps::new(H))
588
589 // TopKFilterOps: topk, sort, gather, min, max, pad
590 .through(TopKFilterOps::new(H))
591
592 // RepeatNarrow: repeat
593 .through(RepeatNarrow::new(H))
594
595 // .loop_body().for_n(n) repeats the body n times, feeding output back as input.
596 // silu_block is a sub-graph (Graph implements Module) — graphs compose freely.
597 .loop_body(silu_block(H)?)
598 .for_n(2)
599 .tag("refined")
600
601 // .gate() is soft routing (mixture of experts): all experts run, router
602 // produces weights, outputs are combined. .using() feeds the tagged "input"
603 // stream to the router as a backward reference.
604 .gate(
605 SoftmaxRouter::new(H, 2)?,
606 modules![Linear::new(H, H)?, Linear::new(H, H)?],
607 )
608 .using(&["input"])
609
610 // .switch() is hard routing: router picks one branch, others are skipped.
611 // HeavyPathSelector is a custom NamedInputModule — it receives the "refined"
612 // ref via forward_named() and decides which branch to activate.
613 .switch(
614 HeavyPathSelector,
615 modules![Linear::new(H, H)?, ffn_block(H)?],
616 )
617 .using(&["refined"])
618
619 // Forward reference: .using("memory") reads a tag that doesn't exist yet —
620 // the framework creates a state buffer. .tag("memory") writes to it.
621 // On the first pass, the state is zero (pass-through). On subsequent passes,
622 // StateAdd accumulates: stream + previous_memory.
623 .through(StateAdd)
624 .using(&["memory"])
625 .tag("memory")
626
627 // .while_cond() repeats until the halt module signals stop (or max iterations).
628 // ThresholdHalt stops when the stream's L2 norm exceeds the threshold.
629 .loop_body(Linear::new(H, H)?)
630 .while_cond(ThresholdHalt::new(100.0), 5)
631
632 // .until_cond() is the inverse: repeats until halt signals true.
633 // LearnedHalt has trainable parameters — it learns when to stop.
634 .loop_body(Linear::new(H, H)?)
635 .until_cond(LearnedHalt::new(H)?, 7)
636
637 .through(LogSoftmaxReduce)
638 .through(Linear::new(1, H)?)
639
640 // Split with tag_group: names each branch ("final_heads_0", "final_heads_1")
641 .split(vec![
642 Box::new(Linear::new(H, H)?),
643 Box::new(Linear::new(H, H)?),
644 ])
645 .tag_group("final_heads")
646 .merge(MergeOp::Add)
647
648 // Final projection and output tag for observation
649 .through(Linear::new(H, 2)?)
650 .tag("output")
651 .build()
652 }
Trait Implementations§
Source§impl Module for Dropout
impl Module for Dropout
Source§fn name(&self) -> &str
fn name(&self) -> &str
Human-readable type name used as node ID prefix in graph visualization.
Override to return a lowercase identifier (e.g., “linear”, “gelu”).
Source§fn forward(&self, input: &Variable) -> Result<Variable>
fn forward(&self, input: &Variable) -> Result<Variable>
Run the forward pass on
input and return the result.
Source§fn set_training(&self, training: bool)
fn set_training(&self, training: bool)
Set training/eval mode. Affects Dropout, BatchNorm, etc.
Override in modules with mode-dependent behavior.
Source§fn parameters(&self) -> Vec<Parameter>
fn parameters(&self) -> Vec<Parameter>
Return this module’s learnable parameters.
Default: recursively collects from
sub_modules() with pointer dedup.
Leaf modules should override to return their own parameters.
Source§fn buffers(&self) -> Vec<Buffer>
fn buffers(&self) -> Vec<Buffer>
Return this module’s non-learnable persistent buffers (e.g., running stats).
Default: recursively collects from
sub_modules() with pointer dedup.
Leaf modules should override to return their own buffers.
Source§fn sub_modules(&self) -> Vec<Rc<dyn Module>>
fn sub_modules(&self) -> Vec<Rc<dyn Module>>
Return direct child modules for recursive tree walks.
Override in composite modules (loops, switches, gates).
Source§fn move_to_device(&self, _device: Device)
fn move_to_device(&self, _device: Device)
Move all parameters and buffers to the given device.
Override in modules like BatchNorm that hold non-parameter state.
Source§fn trace(&self) -> Option<Variable>
fn trace(&self) -> Option<Variable>
Return per-iteration side output for loop tracing.
Override in loop body modules that capture trajectory data
(e.g., attention fixation points). Returns
None by default.
When Some, the loop executor collects traces accessible via
Graph::traces().
Source§fn as_named_input(&self) -> Option<&dyn NamedInputModule>
fn as_named_input(&self) -> Option<&dyn NamedInputModule>
Upcast to
NamedInputModule for multi-input graphs.
Override in types that implement NamedInputModule to enable
receiving additional named inputs via graph using().
Source§fn structural_hash(&self) -> Option<String>
fn structural_hash(&self) -> Option<String>
SHA-256 hex hash of module architecture for checkpoint validation.
Override in composite modules (Graph) that compute a deterministic
hash from their topology and parameter shapes.
Source§fn reset(&self)
fn reset(&self)
Reset internal state (e.g. recurrent hidden state) between sequences.
Called by loops before iterating to clear stale tensors whose
grad_fns may reference freed saved tensors.
Override in stateful modules.
Source§fn detach_state(&self)
fn detach_state(&self)
Detach internal state from the computation graph (for truncated BPTT).
Called between training steps to break gradient chains on state
carried across forward passes (e.g., recurrent hidden state).
Override in stateful modules.
Auto Trait Implementations§
impl !Freeze for Dropout
impl !RefUnwindSafe for Dropout
impl Send for Dropout
impl !Sync for Dropout
impl Unpin for Dropout
impl UnsafeUnpin for Dropout
impl UnwindSafe for Dropout
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where T: ?Sized,
impl<T> BorrowMut<T> for T
where T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more