// rlx_flow/stage_interfaces.rs
use anyhow::Result;
7use rlx_ir::Shape;
8
9use crate::context::FlowCtx;
10use crate::stage_contract::{LayerStage, StageArtifacts};
11use crate::value::FlowValue;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct KvCacheContract {
16 pub k: Shape,
17 pub v: Shape,
18}
19
20pub trait AttentionStage: LayerStage {
22 fn cache_contract(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> KvCacheContract;
23
24 fn emit_attention(
25 &self,
26 ctx: &mut FlowCtx<'_>,
27 input: FlowValue,
28 ) -> Result<(FlowValue, StageArtifacts, KvCacheContract)> {
29 let contract = self.cache_contract(ctx, &input.shape);
30 let (value, artifacts) = self.emit_layer(ctx, input)?;
31 Ok((value, artifacts, contract))
32 }
33}
34
35pub trait FfnStage: LayerStage {
37 fn intermediate_shape(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> Shape;
39}
40
41pub trait NormStage: LayerStage {
43 fn eps(&self) -> f32;
44}