Skip to main content

rlx_flow/
stage_contract.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Typed stage contracts — associated artifacts per layer (Slang-style associated types).
5
6use anyhow::Result;
7use rlx_ir::Shape;
8
9use crate::blocks::BlockStage;
10use crate::context::FlowCtx;
11use crate::value::FlowValue;
12
13/// Outputs a layer stage may publish beyond the main hidden tensor.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct StageArtifacts {
16    pub hidden: Shape,
17    pub side_outputs: Vec<(String, Shape)>,
18}
19
20impl StageArtifacts {
21    pub fn hidden_only(shape: Shape) -> Self {
22        Self {
23            hidden: shape,
24            side_outputs: Vec::new(),
25        }
26    }
27
28    pub fn with_side(mut self, name: impl Into<String>, shape: Shape) -> Self {
29        self.side_outputs.push((name.into(), shape));
30        self
31    }
32}
33
34/// Layer block with an explicit artifact contract (for new blocks and plugins).
35pub trait LayerStage: Send + Sync {
36    fn name(&self) -> &str;
37
38    fn emit_layer(
39        &self,
40        ctx: &mut FlowCtx<'_>,
41        input: FlowValue,
42    ) -> Result<(FlowValue, StageArtifacts)>;
43}
44
/// Bridges an existing [`BlockStage`] implementation to [`LayerStage`], reporting
/// hidden-only artifacts.
///
/// The wrapped stage is a public field, so the adapter is built as
/// `BlockAsLayer(stage)` and the inner stage is reachable via `.0`. The derived
/// traits apply only when the wrapped stage type implements them.
#[derive(Debug, Clone, Copy, Default)]
pub struct BlockAsLayer<S>(pub S);
47
48impl<S: BlockStage + Send + Sync> LayerStage for BlockAsLayer<S> {
49    fn name(&self) -> &str {
50        "block"
51    }
52
53    fn emit_layer(
54        &self,
55        ctx: &mut FlowCtx<'_>,
56        input: FlowValue,
57    ) -> Result<(FlowValue, StageArtifacts)> {
58        let out = self.0.emit(ctx, input.clone())?;
59        let value = match out {
60            Some(v) => v,
61            None => input,
62        };
63        Ok((
64            value.clone(),
65            StageArtifacts::hidden_only(value.shape.clone()),
66        ))
67    }
68}