Struct Fold 

pub struct Fold { /* private fields */ }

Reassembles columns into a batched 2D tensor (col2im).

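The counterpart of Unfold: combines the sliding local blocks back into a full spatial tensor. Equivalent to PyTorch’s nn.Fold.

Where blocks overlap (with dilation 1, whenever stride < kernel_size), the overlapping values are summed, not averaged, so Fold is not an exact inverse of Unfold in general. For a perfect roundtrip, the blocks must tile the padded output without overlaps or gaps: use stride == kernel_size with dilation 1, and choose output_size so that output_size + 2 * padding is a multiple of kernel_size along each axis.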

  • Input shape: [N, C * kernel_h * kernel_w, L], where L is the total number of block positions (blocks along H times blocks along W)
  • Output shape: [N, C, output_h, output_w]
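To make the overlap rule concrete, here is a minimal pure-Rust sketch of col2im for a single image and a single channel (N = C = 1). It is not this crate's implementation: the helper name `col2im`, the flat `Vec<f32>` storage, and the row-major [kernel_h * kernel_w, L] column layout are assumptions chosen for illustration, following the shapes listed above.

```rust
/// Hypothetical single-image, single-channel col2im sketch (not this crate's code).
/// `cols` is laid out row-major as [kernel_h * kernel_w, L]: cols[k * L + l].
fn col2im(
    cols: &[f32],
    out: [usize; 2],
    kernel: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> Vec<f32> {
    // Number of block positions along one axis (the usual Unfold/Fold formula).
    let blocks = |d: usize| {
        (out[d] + 2 * padding[d] - dilation[d] * (kernel[d] - 1) - 1) / stride[d] + 1
    };
    let (bh, bw) = (blocks(0), blocks(1));
    let l_total = bh * bw;
    assert_eq!(cols.len(), kernel[0] * kernel[1] * l_total);

    let mut img = vec![0.0f32; out[0] * out[1]];
    for by in 0..bh {
        for bx in 0..bw {
            let l = by * bw + bx; // index of this block along the L axis
            for ky in 0..kernel[0] {
                for kx in 0..kernel[1] {
                    // Position in the padded image, shifted back by the padding.
                    let y = (by * stride[0] + ky * dilation[0]) as isize - padding[0] as isize;
                    let x = (bx * stride[1] + kx * dilation[1]) as isize - padding[1] as isize;
                    // Contributions that land in the padding are dropped (cropped).
                    if y < 0 || x < 0 || y >= out[0] as isize || x >= out[1] as isize {
                        continue;
                    }
                    let k = ky * kernel[1] + kx;
                    // Overlapping blocks accumulate: this is the summation Fold performs.
                    img[y as usize * out[1] + x as usize] += cols[k * l_total + l];
                }
            }
        }
    }
    img
}

fn main() {
    // 2x2 kernel, stride 1, 3x3 output: four overlapping blocks, so folding
    // all-ones columns counts how many blocks cover each pixel.
    let counts = col2im(&vec![1.0f32; 16], [3, 3], [2, 2], [1, 1], [0, 0], [1, 1]);
    println!("{:?}", counts); // [1.0, 2.0, 1.0, 2.0, 4.0, 2.0, 1.0, 2.0, 1.0]
}
```

Folding all-ones columns is a quick way to see the overlap: each output pixel receives one contribution per block that covers it, so the center pixel of the 3x3 output gets 4.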

Parameters

  • output_size – target spatial size [output_h, output_w]
  • kernel_size – size of the sliding window (must match the Unfold)
  • stride – step between consecutive windows (default 1)
  • padding – zero-padding that was applied during Unfold (default 0)
  • dilation – spacing between kernel elements (default 1)
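The block count L in the input shape follows from these parameters. The sketch below assumes this crate uses the same per-axis formula PyTorch documents for nn.Fold, floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1, with L the product over both axes. The helper names are hypothetical, and the channel count C = 3 is an arbitrary choice.

```rust
/// Number of sliding-block positions along one spatial axis
/// (assumed to match PyTorch's nn.Fold / nn.Unfold formula).
fn blocks_along(size: i64, kernel: i64, stride: i64, padding: i64, dilation: i64) -> i64 {
    (size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

/// Expected Fold input shape [C * kernel_h * kernel_w, L], batch dim omitted.
fn expected_cols_shape(
    channels: i64,
    output_size: [i64; 2],
    kernel: [i64; 2],
    stride: [i64; 2],
    padding: [i64; 2],
    dilation: [i64; 2],
) -> [i64; 2] {
    // L is the product of the block counts along H and W.
    let l: i64 = (0..2)
        .map(|d| blocks_along(output_size[d], kernel[d], stride[d], padding[d], dilation[d]))
        .product();
    [channels * kernel[0] * kernel[1], l]
}

fn main() {
    // 32x32 output, 3x3 kernel, stride 2, padding 1, dilation 1, C = 3.
    let shape = expected_cols_shape(3, [32, 32], [3, 3], [2, 2], [1, 1], [1, 1]);
    println!("{:?}", shape); // [27, 256]
}
```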

Example

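// `cols` must come from a matching Unfold: shape [N, C * 3 * 3, 256], since
// floor((32 + 2*1 - (3 - 1) - 1) / 2) + 1 = 16 blocks fit along each axis.
// forward() is a Module trait method, so Module must be in scope.
let fold = Fold::new([32, 32], 3).stride([2, 2]).padding([1, 1]);
let img = fold.forward(&cols)?;  // [N, C, 32, 32]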

Implementations

impl Fold

pub fn new(output_size: [i64; 2], kernel_size: i64) -> Self

Create with target output size [output_h, output_w] and a square kernel.

Stride defaults to 1, padding to 0, dilation to 1.

pub fn with_kernel(output_size: [i64; 2], kernel_size: [i64; 2]) -> Self

Create with target output size and a rectangular kernel [kernel_h, kernel_w].

Stride defaults to 1, padding to 0, dilation to 1.

pub fn dilation(self, dilation: [i64; 2]) -> Self

Set dilation (spacing between kernel elements) as [dH, dW].

pub fn padding(self, padding: [i64; 2]) -> Self

Set zero-padding as [padH, padW] (must match the Unfold that produced the input).

pub fn stride(self, stride: [i64; 2]) -> Self

Set stride of the sliding blocks as [strideH, strideW].
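The constructors and setters above form a consuming builder: each setter takes `self` and returns it, so calls chain. A minimal sketch of that pattern, assuming a hypothetical `FoldConfig` struct in place of Fold's private fields, shows where the defaults come from and how new(output_size, k) reduces to with_kernel(output_size, [k, k]).

```rust
/// Hypothetical stand-in for Fold's configuration (the real fields are private).
#[derive(Debug, Clone, PartialEq)]
struct FoldConfig {
    output_size: [i64; 2],
    kernel_size: [i64; 2],
    stride: [i64; 2],
    padding: [i64; 2],
    dilation: [i64; 2],
}

impl FoldConfig {
    /// Square kernel: delegates to with_kernel with [k, k].
    fn new(output_size: [i64; 2], kernel_size: i64) -> Self {
        Self::with_kernel(output_size, [kernel_size, kernel_size])
    }

    /// Rectangular kernel with the documented defaults:
    /// stride 1, padding 0, dilation 1.
    fn with_kernel(output_size: [i64; 2], kernel_size: [i64; 2]) -> Self {
        Self { output_size, kernel_size, stride: [1, 1], padding: [0, 0], dilation: [1, 1] }
    }

    // Consuming setters: take `self` by value, update one field, return it.
    fn stride(mut self, stride: [i64; 2]) -> Self {
        self.stride = stride;
        self
    }
    fn padding(mut self, padding: [i64; 2]) -> Self {
        self.padding = padding;
        self
    }
    fn dilation(mut self, dilation: [i64; 2]) -> Self {
        self.dilation = dilation;
        self
    }
}

fn main() {
    let a = FoldConfig::new([32, 32], 3).stride([2, 2]).padding([1, 1]);
    let b = FoldConfig::with_kernel([32, 32], [3, 3]).padding([1, 1]).stride([2, 2]);
    assert_eq!(a, b); // setter order does not matter
    println!("{:?}", a.dilation); // [1, 1]
}
```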

Trait Implementations

impl Module for Fold

fn name(&self) -> &str

Human-readable type name used as node ID prefix in graph visualization. Override to return a lowercase identifier (e.g., “linear”, “gelu”).

fn forward(&self, input: &Variable) -> Result<Variable>

Run the forward pass on input and return the result.

fn parameters(&self) -> Vec<Parameter>

Return this module’s learnable parameters. Default: recursively collects from sub_modules() with pointer dedup. Leaf modules should override to return their own parameters.
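The default collection strategy can be sketched with `Rc` pointer identity. The types below (`Param`, `Node`) are hypothetical stand-ins for Parameter and Module, assumed only to illustrate how a parameter shared by two sub-modules is returned once. Fold itself has no learnable weights, so as a leaf it would contribute nothing.

```rust
use std::rc::Rc;

// Hypothetical stand-ins: `Param` plays the role of Parameter and `Node`
// the role of a module; neither is this crate's type.
struct Param {
    name: &'static str,
}

enum Node {
    // A leaf module returns its own parameters.
    Leaf(Vec<Rc<Param>>),
    // A composite module delegates to its sub-modules.
    Composite(Vec<Rc<Node>>),
}

/// Collect parameters recursively, keeping only the first occurrence of each
/// shared `Rc` (pointer dedup via `Rc::ptr_eq`, not value equality).
fn parameters(node: &Node) -> Vec<Rc<Param>> {
    match node {
        Node::Leaf(own) => own.clone(),
        Node::Composite(children) => {
            let mut out: Vec<Rc<Param>> = Vec::new();
            for child in children {
                for p in parameters(child) {
                    if !out.iter().any(|q| Rc::ptr_eq(&p, q)) {
                        out.push(p);
                    }
                }
            }
            out
        }
    }
}

fn main() {
    // One weight tied between two leaves, plus a separate bias.
    let w = Rc::new(Param { name: "w" });
    let bias = Rc::new(Param { name: "bias" });
    let root = Node::Composite(vec![
        Rc::new(Node::Leaf(vec![Rc::clone(&w)])),
        Rc::new(Node::Leaf(vec![Rc::clone(&w), bias])),
    ]);
    let names: Vec<&str> = parameters(&root).iter().map(|p| p.name).collect();
    println!("{:?}", names); // ["w", "bias"]
}
```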

fn buffers(&self) -> Vec<Buffer>

Return this module’s non-learnable persistent buffers (e.g., running stats). Default: recursively collects from sub_modules() with pointer dedup. Leaf modules should override to return their own buffers.

fn sub_modules(&self) -> Vec<Rc<dyn Module>>

Return direct child modules for recursive tree walks. Override in composite modules (loops, switches, gates).

fn move_to_device(&self, _device: Device)

Move all parameters and buffers to the given device. Override in modules like BatchNorm that hold non-parameter state.

fn set_training(&self, _training: bool)

Set training/eval mode. Affects Dropout, BatchNorm, etc. Override in modules with mode-dependent behavior.

fn train(&self)

Set training mode. Shorthand for set_training(true).

fn eval(&self)

Set eval mode. Shorthand for set_training(false).

fn trace(&self) -> Option<Variable>

Return per-iteration side output for loop tracing. Override in loop body modules that capture trajectory data (e.g., attention fixation points). Returns None by default. When Some, the loop executor collects traces accessible via Graph::traces().

fn as_named_input(&self) -> Option<&dyn NamedInputModule>

Upcast to NamedInputModule for multi-input graphs. Override in types that implement NamedInputModule so they can receive additional named inputs through the graph’s using().

fn as_loop_body(&self) -> Option<&dyn LoopBody>

Upcast to LoopBody for loop bodies that publish named per-iteration traces. Override in types that implement LoopBody to enable multi-output trace publishing via TraceEmit::publish. Default returns None, in which case the loop runner falls back to the legacy Module::trace path.

fn as_graph(&self) -> Option<&Graph>

Upcast to Graph for hierarchical tree composition. Override in Graph to enable subgraph nesting with label-path addressing.

fn structural_hash(&self) -> Option<String>

SHA-256 hex hash of module architecture for checkpoint validation. Override in composite modules (Graph) that compute a deterministic hash from their topology and parameter shapes.

fn reset(&self)

Reset internal state (e.g. recurrent hidden state) between sequences. Called by loops before iterating to clear stale tensors whose grad_fns may reference freed saved tensors. Override in stateful modules.

fn detach_state(&self)

Detach internal state from the computation graph (for truncated BPTT). Called between training steps to break gradient chains on state carried across forward passes (e.g., recurrent hidden state). Override in stateful modules.

Auto Trait Implementations

impl Freeze for Fold

impl RefUnwindSafe for Fold

impl Send for Fold

impl Sync for Fold

impl Unpin for Fold

impl UnsafeUnpin for Fold

impl UnwindSafe for Fold

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.