

Trait ModelArch 

pub trait ModelArch<D: Driver> {
    // Required method
    fn forward(
        &self,
        driver: &D,
        encodings: &[Encoding],
    ) -> Result<Vec<Vec<f32>>>;
}

Model architecture that composes Driver primitives into a forward pass.

Implementations hold their weights (resident on the device) and the model config, and orchestrate the driver to run embedding lookup, the encoder layers, pooling, and L2 normalization.
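A minimal sketch of how an implementation might compose driver primitives. It assumes a hypothetical Driver with three CPU-backed primitives (embedding_lookup, mean_pool, l2_normalize) and simplified stand-ins for Encoding and Result. None of these primitive names come from the real Driver trait, and the encoder layers are omitted, so the toy model's only "weights" are an embedding table.

```rust
// Hypothetical stand-ins: the crate's real Result, Encoding, and Driver differ.
pub type Result<T> = std::result::Result<T, String>;

/// Stand-in for a tokenizer encoding: just the token ids.
pub struct Encoding {
    pub ids: Vec<u32>,
}

/// Hypothetical driver primitives (names invented for illustration).
pub trait Driver {
    /// Look up one embedding row per token id.
    fn embedding_lookup(&self, table: &[Vec<f32>], ids: &[u32]) -> Result<Vec<Vec<f32>>>;
    /// Average the token vectors into a single vector.
    fn mean_pool(&self, hidden: &[Vec<f32>]) -> Result<Vec<f32>>;
    /// Scale a vector in place to unit L2 norm.
    fn l2_normalize(&self, v: &mut [f32]) -> Result<()>;
}

/// Same shape as the documented trait.
pub trait ModelArch<D: Driver> {
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> Result<Vec<Vec<f32>>>;
}

/// Hypothetical CPU backend implementing the primitives above.
pub struct CpuDriver;

impl Driver for CpuDriver {
    fn embedding_lookup(&self, table: &[Vec<f32>], ids: &[u32]) -> Result<Vec<Vec<f32>>> {
        ids.iter()
            .map(|&id| {
                table
                    .get(id as usize)
                    .cloned()
                    .ok_or_else(|| format!("token id {id} out of range"))
            })
            .collect()
    }

    fn mean_pool(&self, hidden: &[Vec<f32>]) -> Result<Vec<f32>> {
        let first = hidden.first().ok_or("empty sequence")?;
        let mut acc = vec![0.0f32; first.len()];
        for row in hidden {
            for (a, x) in acc.iter_mut().zip(row) {
                *a += *x;
            }
        }
        let n = hidden.len() as f32;
        acc.iter_mut().for_each(|a| *a /= n);
        Ok(acc)
    }

    fn l2_normalize(&self, v: &mut [f32]) -> Result<()> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            return Err("cannot normalize a zero vector".into());
        }
        v.iter_mut().for_each(|x| *x /= norm);
        Ok(())
    }
}

/// Toy architecture: the weights are only an embedding table.
pub struct ToyModel {
    pub table: Vec<Vec<f32>>,
}

impl<D: Driver> ModelArch<D> for ToyModel {
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> Result<Vec<Vec<f32>>> {
        let mut out = Vec::with_capacity(encodings.len());
        for enc in encodings {
            let hidden = driver.embedding_lookup(&self.table, &enc.ids)?;
            // A real architecture would run its encoder layers over `hidden` here.
            let mut pooled = driver.mean_pool(&hidden)?;
            driver.l2_normalize(&mut pooled)?;
            out.push(pooled);
        }
        Ok(out)
    }
}
```

Because ToyModel implements ModelArch<D> for every D: Driver, the same forward pass compiles against any backend that provides these primitives.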

Type parameter

D: Driver — the hardware backend. Architectures are generic over the driver so they can be monomorphized for each backend (Metal, CUDA, CPU).
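To illustrate the monomorphization point, here is a sketch with a hypothetical two-method Driver and two made-up backends, CpuDriver and FakeGpuDriver (neither comes from the crate). The generic function is compiled once per concrete driver it is called with, so each driver call is statically dispatched rather than looked up through a trait object.

```rust
/// Hypothetical minimal driver trait; the real Driver has many more primitives.
pub trait Driver {
    fn name(&self) -> &'static str;
    /// Elementwise scale, standing in for a real kernel dispatch.
    fn scale(&self, v: &[f32], k: f32) -> Vec<f32>;
}

pub struct CpuDriver;

/// Hypothetical second backend; a real one would dispatch to Metal or CUDA.
pub struct FakeGpuDriver;

impl Driver for CpuDriver {
    fn name(&self) -> &'static str {
        "cpu"
    }
    fn scale(&self, v: &[f32], k: f32) -> Vec<f32> {
        v.iter().map(|x| x * k).collect()
    }
}

impl Driver for FakeGpuDriver {
    fn name(&self) -> &'static str {
        "fake-gpu"
    }
    fn scale(&self, v: &[f32], k: f32) -> Vec<f32> {
        // Same math on the CPU; only the type differs in this sketch.
        v.iter().map(|x| x * k).collect()
    }
}

/// The compiler emits a separate copy of this function for each concrete D.
pub fn double_all<D: Driver>(driver: &D, v: &[f32]) -> (String, Vec<f32>) {
    (driver.name().to_string(), driver.scale(v, 2.0))
}
```

Taking &dyn Driver instead would compile the function once and dispatch through a vtable at runtime; the generic form trades some binary size for calls the compiler can inline per backend.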

Required Methods


fn forward(&self, driver: &D, encodings: &[Encoding]) -> Result<Vec<Vec<f32>>>

Run the full forward pass: embeddings -> encoder layers -> pool -> L2 normalize.

Returns one L2-normalized embedding vector per input encoding.
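The last two stages of the pass can be sketched as plain functions. This assumes mean pooling over non-padded tokens, with a hypothetical 0/1 mask argument; the docs only say "pool", and an architecture could instead take the first ([CLS]) token's vector.

```rust
/// Mean-pool token vectors (seq_len x hidden) into one vector, skipping
/// positions where mask[i] == 0. The mask parameter is an assumption.
pub fn mean_pool(hidden: &[Vec<f32>], mask: &[u32]) -> Vec<f32> {
    let dim = hidden.first().map_or(0, |row| row.len());
    let mut acc = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (row, &m) in hidden.iter().zip(mask) {
        if m == 0 {
            continue; // padding token: excluded from the average
        }
        for (a, &x) in acc.iter_mut().zip(row) {
            *a += x;
        }
        count += 1.0;
    }
    if count > 0.0 {
        for a in acc.iter_mut() {
            *a /= count;
        }
    }
    acc
}

/// Divide by the Euclidean norm so the vector has unit length.
/// An all-zero vector is left unchanged to avoid dividing by zero.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
```

Because every returned vector has unit length, the cosine similarity between two embeddings reduces to their dot product.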

Errors

Returns an error if any driver operation fails (buffer allocation, kernel dispatch, synchronization, etc.).
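A sketch of this error behavior, assuming a hypothetical DriverError enum and a toy BudgetDriver whose allocation fails past a fixed size and whose dispatch fails on an empty buffer. With the ? operator, the first failing driver call ends the pass and its error is returned, so in this sketch the caller never receives a partial set of embeddings. The crate's actual Result alias and error type may differ.

```rust
use std::fmt;

/// Hypothetical error type covering two of the failure kinds the docs list.
#[derive(Debug, PartialEq)]
pub enum DriverError {
    Alloc { requested: usize, budget: usize },
    Dispatch(&'static str),
}

impl fmt::Display for DriverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DriverError::Alloc { requested, budget } => {
                write!(f, "allocation of {requested} floats exceeds budget of {budget}")
            }
            DriverError::Dispatch(msg) => write!(f, "kernel dispatch failed: {msg}"),
        }
    }
}

impl std::error::Error for DriverError {}

/// Hypothetical driver with a fixed buffer budget.
pub struct BudgetDriver {
    pub budget: usize,
}

impl BudgetDriver {
    pub fn alloc(&self, n: usize) -> Result<Vec<f32>, DriverError> {
        if n > self.budget {
            Err(DriverError::Alloc { requested: n, budget: self.budget })
        } else {
            Ok(vec![0.0; n])
        }
    }

    pub fn dispatch(&self, buf: &mut [f32]) -> Result<(), DriverError> {
        if buf.is_empty() {
            return Err(DriverError::Dispatch("empty buffer"));
        }
        for x in buf.iter_mut() {
            *x += 1.0; // stand-in kernel: add one to every element
        }
        Ok(())
    }
}

/// `?` returns the first driver error immediately; earlier results are dropped.
pub fn forward(driver: &BudgetDriver, lens: &[usize]) -> Result<Vec<Vec<f32>>, DriverError> {
    let mut out = Vec::new();
    for &n in lens {
        let mut buf = driver.alloc(n)?;
        driver.dispatch(&mut buf)?;
        out.push(buf);
    }
    Ok(out)
}
```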

Implementors