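// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Block interface traits modeled on Slang interface bounds.
//!
//! Where Slang would use associated types, these traits expose the equivalent
//! information as values returned from methods instead (for example
//! [`KvCacheContract`] and the shape returned by
//! [`FfnStage::intermediate_shape`]).
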
use anyhow::Result;
use rlx_ir::Shape;

use crate::context::FlowCtx;
use crate::stage_contract::{LayerStage, StageArtifacts};
use crate::value::FlowValue;

13/// KV cache tensor shapes exposed by attention blocks (associated type stand-in).
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct KvCacheContract {
16    pub k: Shape,
17    pub v: Shape,
18}
19
20/// Attention block interface: hidden in, hidden out, plus cache contract.
21pub trait AttentionStage: LayerStage {
22    fn cache_contract(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> KvCacheContract;
23
24    fn emit_attention(
25        &self,
26        ctx: &mut FlowCtx<'_>,
27        input: FlowValue,
28    ) -> Result<(FlowValue, StageArtifacts, KvCacheContract)> {
29        let contract = self.cache_contract(ctx, &input.shape);
30        let (value, artifacts) = self.emit_layer(ctx, input)?;
31        Ok((value, artifacts, contract))
32    }
33}
34
35/// FFN block interface (SwiGLU / MLP).
36pub trait FfnStage: LayerStage {
37    /// Intermediate projection width (associated type as shape).
38    fn intermediate_shape(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> Shape;
39}
40
41/// Normalization block interface.
42pub trait NormStage: LayerStage {
43    fn eps(&self) -> f32;
44}