pub trait ModelArch<D: Driver> {
    // Required method
    fn forward(
        &self,
        driver: &D,
        encodings: &[Encoding],
    ) -> Result<Vec<Vec<f32>>>;
}
Model architecture that composes Driver primitives into a forward pass.
Implementations store their weights (on device) and model config, then orchestrate the driver to execute embedding lookup, encoder layers, pooling, and L2 normalization.
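A minimal sketch of that orchestration follows. It assumes a heavily simplified, hypothetical `Driver` with four primitives (`embed`, `layer`, `mean_pool`, `l2_normalize`), a toy `Encoding` that holds only token ids, and a `Result` alias over `String` errors. None of these signatures come from this page, and the real `Driver` and `Encoding` types are not shown here. The "encoder layer" is an elementwise-scaling stand-in, not attention.

```rust
/// Hypothetical stand-in for the crate's `Result` alias (its error type is not shown here).
pub type Result<T> = std::result::Result<T, String>;

/// Hypothetical stand-in for a tokenized input; only token ids are modeled.
pub struct Encoding {
    pub ids: Vec<usize>,
}

/// Hypothetical, minimal driver: each method is one backend primitive.
pub trait Driver {
    /// Look up one embedding row per token id.
    fn embed(&self, table: &[Vec<f32>], ids: &[usize]) -> Result<Vec<Vec<f32>>>;
    /// Toy stand-in for an encoder layer: scale every hidden dim by `w`.
    fn layer(&self, x: &mut [Vec<f32>], w: &[f32]) -> Result<()>;
    /// Average token vectors into a single vector.
    fn mean_pool(&self, x: &[Vec<f32>]) -> Result<Vec<f32>>;
    /// Rescale a vector to unit L2 norm.
    fn l2_normalize(&self, v: &mut [f32]) -> Result<()>;
}

pub trait ModelArch<D: Driver> {
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> Result<Vec<Vec<f32>>>;
}

/// Hypothetical CPU driver implementing the primitives with plain loops.
pub struct CpuDriver;

impl Driver for CpuDriver {
    fn embed(&self, table: &[Vec<f32>], ids: &[usize]) -> Result<Vec<Vec<f32>>> {
        ids.iter()
            .map(|&id| table.get(id).cloned().ok_or_else(|| format!("token id {id} out of range")))
            .collect()
    }
    fn layer(&self, x: &mut [Vec<f32>], w: &[f32]) -> Result<()> {
        for row in x.iter_mut() {
            for (v, wi) in row.iter_mut().zip(w) {
                *v *= wi;
            }
        }
        Ok(())
    }
    fn mean_pool(&self, x: &[Vec<f32>]) -> Result<Vec<f32>> {
        let first = x.first().ok_or("cannot pool an empty sequence")?;
        let mut out = vec![0.0f32; first.len()];
        for row in x {
            for (o, v) in out.iter_mut().zip(row) {
                *o += v;
            }
        }
        let n = x.len() as f32;
        out.iter_mut().for_each(|o| *o /= n);
        Ok(out)
    }
    fn l2_normalize(&self, v: &mut [f32]) -> Result<()> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            return Err("cannot normalize a zero vector".into());
        }
        v.iter_mut().for_each(|x| *x /= norm);
        Ok(())
    }
}

/// Hypothetical toy architecture: it owns its weights and drives the backend.
pub struct ToyArch {
    pub embeddings: Vec<Vec<f32>>,
    pub layer_weights: Vec<Vec<f32>>, // one weight vector per "encoder layer"
}

impl<D: Driver> ModelArch<D> for ToyArch {
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> Result<Vec<Vec<f32>>> {
        encodings
            .iter()
            .map(|enc| -> Result<Vec<f32>> {
                // 1. Embedding lookup: one hidden vector per token.
                let mut hidden = driver.embed(&self.embeddings, &enc.ids)?;
                // 2. Encoder layers, applied in order.
                for w in &self.layer_weights {
                    driver.layer(&mut hidden, w)?;
                }
                // 3. Pooling: collapse the token axis.
                let mut pooled = driver.mean_pool(&hidden)?;
                // 4. L2 normalization.
                driver.l2_normalize(&mut pooled)?;
                Ok(pooled)
            })
            .collect()
    }
}
```

Because `ToyArch` implements `ModelArch<D>` for every `D: Driver`, the same weights and orchestration code serve any backend; only the driver's primitives change.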
Type parameter
D: Driver, the hardware backend. Architectures are generic over the driver so they can be monomorphized for each backend (Metal, CUDA, CPU).
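To illustrate the monomorphization point, here is a sketch with a hypothetical one-primitive `Driver` and two hypothetical CPU backends (`CpuDriver`, `ScalarDriver`) standing in for Metal, CUDA, and CPU. The `forward` signature is simplified to take and return a single vector. One generic `impl<D: Driver> ModelArch<D>` block covers both backends; the compiler generates a separate, statically dispatched copy of `forward` for each concrete driver type it is used with.

```rust
/// Hypothetical minimal driver trait with a single primitive.
pub trait Driver {
    /// Human-readable backend name.
    fn name(&self) -> &'static str;
    /// Elementwise add of two equal-length vectors.
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32>;
}

/// Hypothetical backends; a real GPU driver would dispatch a kernel here instead.
pub struct CpuDriver;
pub struct ScalarDriver;

impl Driver for CpuDriver {
    fn name(&self) -> &'static str {
        "cpu-iter"
    }
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
}

impl Driver for ScalarDriver {
    fn name(&self) -> &'static str {
        "cpu-loop"
    }
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(a.len());
        for i in 0..a.len().min(b.len()) {
            out.push(a[i] + b[i]);
        }
        out
    }
}

/// Simplified version of the trait: one input vector in, one vector out.
pub trait ModelArch<D: Driver> {
    fn forward(&self, driver: &D, x: &[f32]) -> Vec<f32>;
}

/// Toy architecture holding a bias vector; this single impl serves every backend.
pub struct BiasArch {
    pub bias: Vec<f32>,
}

impl<D: Driver> ModelArch<D> for BiasArch {
    fn forward(&self, driver: &D, x: &[f32]) -> Vec<f32> {
        driver.add(x, &self.bias)
    }
}

/// Each call with a different concrete `D` instantiates its own copy of this function.
pub fn embed_with<D: Driver, A: ModelArch<D>>(arch: &A, driver: &D, x: &[f32]) -> (&'static str, Vec<f32>) {
    (driver.name(), arch.forward(driver, x))
}
```

The alternative, taking `&dyn Driver`, would compile `forward` once but route every primitive call through a vtable; generics trade larger binaries for calls the compiler can inline.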
Required Methods
fn forward(&self, driver: &D, encodings: &[Encoding]) -> Result<Vec<Vec<f32>>>
Run the full forward pass: embeddings -> encoder layers -> pool -> L2 normalize.
Returns one L2-normalized embedding vector per input encoding.
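The last two stages can be sketched in isolation. This assumes mean pooling over tokens (the page does not say which pooling strategy an architecture uses) and returns `None` rather than dividing by zero; `mean_pool`, `l2_normalize`, and `dot` are hypothetical helper names, not part of this API.

```rust
/// Mean-pool token vectors (rows) into one vector; assumes all rows share a length.
pub fn mean_pool(tokens: &[Vec<f32>]) -> Option<Vec<f32>> {
    let dim = tokens.first()?.len();
    let mut out = vec![0.0f32; dim];
    for row in tokens {
        for (o, v) in out.iter_mut().zip(row) {
            *o += v;
        }
    }
    let n = tokens.len() as f32;
    for o in &mut out {
        *o /= n;
    }
    Some(out)
}

/// Divide by the Euclidean norm so the result has length 1; `None` for a zero vector.
pub fn l2_normalize(mut v: Vec<f32>) -> Option<Vec<f32>> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return None;
    }
    for x in &mut v {
        *x /= norm;
    }
    Some(v)
}

/// Dot product; for unit-length vectors this equals cosine similarity.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

Because every returned embedding has unit length, cosine similarity between two of them reduces to their dot product.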
Errors
Returns an error if any driver operation fails (buffer allocation, kernel dispatch, synchronization, etc.).
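A sketch of how such failures surface through a forward pass, assuming a hypothetical `DriverError` enum with one variant per failure class named above and a hypothetical three-call `Driver` (`alloc`, `dispatch`, `sync`). Each `?` hands the first error straight back to the caller, so this version returns either every embedding or none.

```rust
use std::fmt;

/// Hypothetical error variants mirroring the failure classes listed above.
#[derive(Debug, PartialEq)]
pub enum DriverError {
    Allocation(usize),
    Dispatch(&'static str),
    Synchronize,
}

impl fmt::Display for DriverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DriverError::Allocation(n) => write!(f, "failed to allocate {n} floats"),
            DriverError::Dispatch(k) => write!(f, "kernel `{k}` failed to dispatch"),
            DriverError::Synchronize => write!(f, "device synchronization failed"),
        }
    }
}

impl std::error::Error for DriverError {}

/// Hypothetical driver surface: allocate, run a named kernel, wait for completion.
pub trait Driver {
    fn alloc(&self, len: usize) -> Result<Vec<f32>, DriverError>;
    fn dispatch(&self, kernel: &'static str, buf: &mut [f32]) -> Result<(), DriverError>;
    fn sync(&self) -> Result<(), DriverError>;
}

/// Hypothetical test driver that fails one chosen kernel; every kernel doubles the buffer.
pub struct TestDriver {
    pub fail_kernel: Option<&'static str>,
}

impl Driver for TestDriver {
    fn alloc(&self, len: usize) -> Result<Vec<f32>, DriverError> {
        Ok(vec![1.0; len])
    }
    fn dispatch(&self, kernel: &'static str, buf: &mut [f32]) -> Result<(), DriverError> {
        if self.fail_kernel == Some(kernel) {
            return Err(DriverError::Dispatch(kernel));
        }
        buf.iter_mut().for_each(|x| *x *= 2.0);
        Ok(())
    }
    fn sync(&self) -> Result<(), DriverError> {
        Ok(())
    }
}

/// Each `?` returns the first driver error to the caller; no partial batch escapes.
pub fn forward<D: Driver>(driver: &D, batch: usize, dim: usize) -> Result<Vec<Vec<f32>>, DriverError> {
    let mut out = Vec::with_capacity(batch);
    for _ in 0..batch {
        let mut buf = driver.alloc(dim)?;
        driver.dispatch("encoder_layer", &mut buf)?;
        driver.dispatch("pool", &mut buf)?;
        out.push(buf);
    }
    driver.sync()?;
    Ok(out)
}
```

If the crate's `Result` is `anyhow::Result` (this page does not say), a `DriverError` that implements `std::error::Error` converts into it automatically through `?`.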