rlx_flow/blocks/
linear.rs1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7
8use super::BlockStage;
9use crate::context::FlowCtx;
10use crate::value::FlowValue;
11
/// Block stage that projects its input through a single weight parameter.
///
/// When emitted, the stage loads the weight named by [`weight_key`](Self::weight_key)
/// from the flow context and appends an `mm` node (presumably a matrix multiply)
/// with the incoming value as the left operand and the weight as the right operand.
#[derive(Clone, Debug)]
pub struct LinearStage {
    /// Key passed to `FlowCtx::load_param` to fetch the weight parameter.
    pub weight_key: String,
    /// Forwarded verbatim as the second argument of `FlowCtx::load_param`;
    /// presumably asks for the weight to be loaded transposed — confirm.
    pub transpose: bool,
}
17
18impl LinearStage {
19 pub fn new(weight_key: impl Into<String>, transpose: bool) -> Self {
20 Self {
21 weight_key: weight_key.into(),
22 transpose,
23 }
24 }
25}
26
27impl BlockStage for LinearStage {
28 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
29 let w = ctx.load_param(&self.weight_key, self.transpose)?;
30 let mut gb = HirMut::new(ctx.hir());
31 let id = gb.mm(input.id, w);
32 let out_shape = gb.shape(id).clone();
33 Ok(Some(ctx.wrap(id, out_shape)))
34 }
35}