Skip to main content

rlx_flow/blocks/
gdn_scan.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Gated DeltaNet scan — generic op wrapper (Qwen3.5 trunk, …).
5
6use anyhow::Result;
7use rlx_ir::HirGraphExt;
8use rlx_ir::Shape;
9use rlx_ir::hir::HirMut;
10
11use super::BlockStage;
12use crate::context::FlowCtx;
13use crate::value::FlowValue;
14
15/// Q/K/V/G/Beta tensors must already be shaped `[batch, seq, heads, state]`.
16#[derive(Debug, Clone)]
17pub struct GdnScanStage {
18    pub state_size: usize,
19    pub out_shape: Shape,
20    pub carry_state: bool,
21    pub state_key: Option<String>,
22}
23
24impl GdnScanStage {
25    pub fn prefill(state_size: usize, out_shape: Shape) -> Self {
26        Self {
27            state_size,
28            out_shape,
29            carry_state: false,
30            state_key: None,
31        }
32    }
33
34    pub fn with_carry(mut self, state_key: impl Into<String>) -> Self {
35        self.carry_state = true;
36        self.state_key = Some(state_key.into());
37        self
38    }
39}
40
41impl BlockStage for GdnScanStage {
42    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
43        let slots = ctx
44            .state
45            .gdn
46            .ok_or_else(|| anyhow::anyhow!("GdnScan requires gdn inputs in FlowState"))?;
47        let carry_state = if self.carry_state {
48            let key = self
49                .state_key
50                .as_deref()
51                .ok_or_else(|| anyhow::anyhow!("GdnScan carry requires state_key"))?;
52            Some(
53                *ctx.state
54                    .named
55                    .get(key)
56                    .ok_or_else(|| anyhow::anyhow!("GdnScan missing carry state `{key}`"))?,
57            )
58        } else {
59            None
60        };
61        let mut gb = HirMut::new(ctx.hir());
62        let id = if let Some(state) = carry_state {
63            gb.gated_delta_net_carry(
64                slots.q,
65                slots.k,
66                slots.v,
67                slots.g,
68                slots.beta,
69                state,
70                self.state_size,
71                self.out_shape.clone(),
72            )
73        } else {
74            gb.gated_delta_net(
75                slots.q,
76                slots.k,
77                slots.v,
78                slots.g,
79                slots.beta,
80                self.state_size,
81                self.out_shape.clone(),
82            )
83        };
84        let _ = input;
85        Ok(Some(ctx.wrap(id, self.out_shape.clone())))
86    }
87}