Skip to main content

rlx_flow/
stage_contract.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Typed stage contracts — associated artifacts per layer (Slang-style associated types).
17
18use anyhow::Result;
19use rlx_ir::Shape;
20
21use crate::blocks::BlockStage;
22use crate::context::FlowCtx;
23use crate::value::FlowValue;
24
25/// Outputs a layer stage may publish beyond the main hidden tensor.
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct StageArtifacts {
28    pub hidden: Shape,
29    pub side_outputs: Vec<(String, Shape)>,
30}
31
32impl StageArtifacts {
33    pub fn hidden_only(shape: Shape) -> Self {
34        Self {
35            hidden: shape,
36            side_outputs: Vec::new(),
37        }
38    }
39
40    pub fn with_side(mut self, name: impl Into<String>, shape: Shape) -> Self {
41        self.side_outputs.push((name.into(), shape));
42        self
43    }
44}
45
46/// Layer block with an explicit artifact contract (for new blocks and plugins).
47pub trait LayerStage: Send + Sync {
48    fn name(&self) -> &str;
49
50    fn emit_layer(
51        &self,
52        ctx: &mut FlowCtx<'_>,
53        input: FlowValue,
54    ) -> Result<(FlowValue, StageArtifacts)>;
55}
56
/// Bridge existing [`BlockStage`] impls to [`LayerStage`] with hidden-only artifacts.
///
/// The wrapped stage is exposed through the public tuple field. `Debug` and
/// `Clone` are derived, so they are available whenever `S` provides them.
#[derive(Debug, Clone)]
pub struct BlockAsLayer<S>(pub S);
59
60impl<S: BlockStage + Send + Sync> LayerStage for BlockAsLayer<S> {
61    fn name(&self) -> &str {
62        "block"
63    }
64
65    fn emit_layer(
66        &self,
67        ctx: &mut FlowCtx<'_>,
68        input: FlowValue,
69    ) -> Result<(FlowValue, StageArtifacts)> {
70        let out = self.0.emit(ctx, input.clone())?;
71        let value = match out {
72            Some(v) => v,
73            None => input,
74        };
75        Ok((
76            value.clone(),
77            StageArtifacts::hidden_only(value.shape.clone()),
78        ))
79    }
80}