Skip to main content

rlx_flow/
stage_interfaces.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
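//! Block interface traits modeled on Slang interface bounds. Where Slang would use an associated type, these traits instead return a value from a method (e.g. `KvCacheContract`, or an intermediate `Shape`).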
17
18use anyhow::Result;
19use rlx_ir::Shape;
20
21use crate::context::FlowCtx;
22use crate::stage_contract::{LayerStage, StageArtifacts};
23use crate::value::FlowValue;
24
25/// KV cache tensor shapes exposed by attention blocks (associated type stand-in).
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct KvCacheContract {
28    pub k: Shape,
29    pub v: Shape,
30}
31
32/// Attention block interface: hidden in, hidden out, plus cache contract.
33pub trait AttentionStage: LayerStage {
34    fn cache_contract(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> KvCacheContract;
35
36    fn emit_attention(
37        &self,
38        ctx: &mut FlowCtx<'_>,
39        input: FlowValue,
40    ) -> Result<(FlowValue, StageArtifacts, KvCacheContract)> {
41        let contract = self.cache_contract(ctx, &input.shape);
42        let (value, artifacts) = self.emit_layer(ctx, input)?;
43        Ok((value, artifacts, contract))
44    }
45}
46
47/// FFN block interface (SwiGLU / MLP).
48pub trait FfnStage: LayerStage {
49    /// Intermediate projection width (associated type as shape).
50    fn intermediate_shape(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> Shape;
51}
52
53/// Normalization block interface.
54pub trait NormStage: LayerStage {
55    fn eps(&self) -> f32;
56}