Skip to main content

rlx_flow/blocks/
linear.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7
8use super::BlockStage;
9use crate::context::FlowCtx;
10use crate::value::FlowValue;
11
/// Configuration for a linear block stage.
///
/// The stage's `emit` implementation loads one parameter through
/// `ctx.load_param` and feeds it, together with the incoming value, into an
/// `mm` node on the HIR builder.
#[derive(Clone, Debug)]
pub struct LinearStage {
    /// Key handed to `ctx.load_param` to look up the weight parameter.
    pub weight_key: String,
    /// Flag forwarded to `ctx.load_param` alongside the key.
    // NOTE(review): presumably asks for the weight to be transposed on load —
    // confirm against `load_param`.
    pub transpose: bool,
}
17
18impl LinearStage {
19    pub fn new(weight_key: impl Into<String>, transpose: bool) -> Self {
20        Self {
21            weight_key: weight_key.into(),
22            transpose,
23        }
24    }
25}
26
27impl BlockStage for LinearStage {
28    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
29        let w = ctx.load_param(&self.weight_key, self.transpose)?;
30        let mut gb = HirMut::new(ctx.hir());
31        let id = gb.mm(input.id, w);
32        let out_shape = gb.shape(id).clone();
33        Ok(Some(ctx.wrap(id, out_shape)))
34    }
35}