use anyhow::Result;
use rlx_ir::Shape;
use crate::context::FlowCtx;
use crate::stage_contract::{LayerStage, StageArtifacts};
use crate::value::FlowValue;
/// Pair of shapes returned by `AttentionStage::cache_contract`.
///
/// NOTE(review): the field names suggest these are the shapes of the key and
/// value cache tensors; confirm against the implementors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KvCacheContract {
    /// Shape stored under `k` (presumably the key cache; TODO confirm).
    pub k: Shape,
    /// Shape stored under `v` (presumably the value cache; TODO confirm).
    pub v: Shape,
}
/// A `LayerStage` that can additionally report a [`KvCacheContract`].
pub trait AttentionStage: LayerStage {
    /// Builds the [`KvCacheContract`] for the given `hidden` shape.
    ///
    /// How the resulting shapes are derived from `ctx` and `hidden` is left to
    /// the implementor.
    fn cache_contract(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> KvCacheContract;

    /// Runs `emit_layer` on `input` and returns its output together with the
    /// cache contract computed from `input.shape`.
    ///
    /// The returned tuple is `(value, artifacts, contract)`, where `value` and
    /// `artifacts` come straight from `emit_layer`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `emit_layer`.
    fn emit_attention(
        &self,
        ctx: &mut FlowCtx<'_>,
        input: FlowValue,
    ) -> Result<(FlowValue, StageArtifacts, KvCacheContract)> {
        // The contract is derived first: `input` is handed to `emit_layer` by
        // value below, so its shape has to be read before that call.
        let kv_contract = self.cache_contract(ctx, &input.shape);
        let (output, stage_artifacts) = self.emit_layer(ctx, input)?;
        Ok((output, stage_artifacts, kv_contract))
    }
}
/// A `LayerStage` that can report an intermediate shape.
pub trait FfnStage: LayerStage {
    /// Returns the intermediate [`Shape`] corresponding to the given `hidden`
    /// shape.
    ///
    /// NOTE(review): presumably the shape of the feed-forward block's inner
    /// activation; confirm against the implementors.
    fn intermediate_shape(&self, ctx: &FlowCtx<'_>, hidden: &Shape) -> Shape;
}
/// A `LayerStage` that exposes an `eps` value.
pub trait NormStage: LayerStage {
    /// Returns this stage's `eps` as an `f32`.
    ///
    /// NOTE(review): presumably the epsilon added for numerical stability
    /// during normalization; confirm against the implementors.
    fn eps(&self) -> f32;
}