// rlx_flow/blocks/gather_from_input.rs
use anyhow::Result;
use rlx_ir::HirGraphExt;
use rlx_ir::hir::HirMut;

use super::BlockStage;
use crate::context::FlowCtx;
use crate::value::FlowValue;

/// Flow stage that gathers slices of a weight parameter, using a named flow
/// input as the index tensor (e.g. a token-embedding lookup).
///
/// The stage ignores the incoming flow value; the gathered tensor becomes the
/// new flow value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatherFromInputStage {
    /// Key of the entry in the flow state's `inputs` map that supplies the
    /// gather indices.
    pub input_name: String,
    /// Parameter key passed to `FlowCtx::load_param` to obtain the tensor that
    /// is gathered from.
    pub weight_key: String,
    /// Axis of the weight tensor along which the indices select slices.
    pub axis: usize,
}

20impl GatherFromInputStage {
21 pub fn new(input_name: impl Into<String>, weight_key: impl Into<String>, axis: usize) -> Self {
22 Self {
23 input_name: input_name.into(),
24 weight_key: weight_key.into(),
25 axis,
26 }
27 }
28}
29
30impl BlockStage for GatherFromInputStage {
31 fn emit(&self, ctx: &mut FlowCtx<'_>, _input: FlowValue) -> Result<Option<FlowValue>> {
32 let (indices_id, indices_shape) = ctx
33 .state
34 .inputs
35 .get(&self.input_name)
36 .ok_or_else(|| anyhow::anyhow!("GatherFromInput missing input `{}`", self.input_name))?
37 .clone();
38 let embed_w = ctx.load_param(&self.weight_key, false)?;
39 let w_shape = ctx.hir().node(embed_w).shape.clone();
40 let mut dims: Vec<rlx_ir::Dim> = indices_shape.dims().to_vec();
41 dims.push(w_shape.dim(1));
42 let out_shape = rlx_ir::Shape::from_dims(&dims, indices_shape.dtype());
43
44 let mut gb = HirMut::new(ctx.hir());
45 let id = gb.gather_(embed_w, indices_id, self.axis);
46 Ok(Some(ctx.wrap(id, out_shape)))
47 }
48}
49
/// Flow stage that gathers slices of a weight parameter, using a named flow
/// input as the index tensor, and adds the result to the incoming flow value
/// (e.g. adding position embeddings to token embeddings).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatherAddStage {
    /// Key of the entry in the flow state's `inputs` map that supplies the
    /// gather indices.
    pub input_name: String,
    /// Parameter key passed to `FlowCtx::load_param` to obtain the tensor that
    /// is gathered from.
    pub weight_key: String,
    /// Axis of the weight tensor along which the indices select slices.
    pub axis: usize,
}

58impl GatherAddStage {
59 pub fn new(input_name: impl Into<String>, weight_key: impl Into<String>, axis: usize) -> Self {
60 Self {
61 input_name: input_name.into(),
62 weight_key: weight_key.into(),
63 axis,
64 }
65 }
66}
67
68impl BlockStage for GatherAddStage {
69 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
70 let (indices_id, indices_shape) = ctx
71 .state
72 .inputs
73 .get(&self.input_name)
74 .ok_or_else(|| anyhow::anyhow!("GatherAdd missing input `{}`", self.input_name))?
75 .clone();
76 let embed_w = ctx.load_param(&self.weight_key, false)?;
77 let w_shape = ctx.hir().node(embed_w).shape.clone();
78 let mut dims: Vec<rlx_ir::Dim> = indices_shape.dims().to_vec();
79 dims.push(w_shape.dim(1));
80 let out_shape = rlx_ir::Shape::from_dims(&dims, indices_shape.dtype());
81
82 let mut gb = HirMut::new(ctx.hir());
83 let gathered = gb.gather_(embed_w, indices_id, self.axis);
84 let id = gb.add(input.id, gathered);
85 Ok(Some(ctx.wrap(id, out_shape)))
86 }
87}