rlx_flow/blocks/
gdn_scan.rs1use anyhow::Result;
7use rlx_ir::HirGraphExt;
8use rlx_ir::Shape;
9use rlx_ir::hir::HirMut;
10
11use super::BlockStage;
12use crate::context::FlowCtx;
13use crate::value::FlowValue;
14
15#[derive(Debug, Clone)]
17pub struct GdnScanStage {
18 pub state_size: usize,
19 pub out_shape: Shape,
20 pub carry_state: bool,
21 pub state_key: Option<String>,
22}
23
24impl GdnScanStage {
25 pub fn prefill(state_size: usize, out_shape: Shape) -> Self {
26 Self {
27 state_size,
28 out_shape,
29 carry_state: false,
30 state_key: None,
31 }
32 }
33
34 pub fn with_carry(mut self, state_key: impl Into<String>) -> Self {
35 self.carry_state = true;
36 self.state_key = Some(state_key.into());
37 self
38 }
39}
40
41impl BlockStage for GdnScanStage {
42 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
43 let slots = ctx
44 .state
45 .gdn
46 .ok_or_else(|| anyhow::anyhow!("GdnScan requires gdn inputs in FlowState"))?;
47 let carry_state = if self.carry_state {
48 let key = self
49 .state_key
50 .as_deref()
51 .ok_or_else(|| anyhow::anyhow!("GdnScan carry requires state_key"))?;
52 Some(
53 *ctx.state
54 .named
55 .get(key)
56 .ok_or_else(|| anyhow::anyhow!("GdnScan missing carry state `{key}`"))?,
57 )
58 } else {
59 None
60 };
61 let mut gb = HirMut::new(ctx.hir());
62 let id = if let Some(state) = carry_state {
63 gb.gated_delta_net_carry(
64 slots.q,
65 slots.k,
66 slots.v,
67 slots.g,
68 slots.beta,
69 state,
70 self.state_size,
71 self.out_shape.clone(),
72 )
73 } else {
74 gb.gated_delta_net(
75 slots.q,
76 slots.k,
77 slots.v,
78 slots.g,
79 slots.beta,
80 self.state_size,
81 self.out_shape.clone(),
82 )
83 };
84 let _ = input;
85 Ok(Some(ctx.wrap(id, self.out_shape.clone())))
86 }
87}