rlx_flow/stage_contract.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Typed stage contracts — associated artifacts per layer (Slang-style associated types).
17
18use anyhow::Result;
19use rlx_ir::Shape;
20
21use crate::blocks::BlockStage;
22use crate::context::FlowCtx;
23use crate::value::FlowValue;
24
25/// Outputs a layer stage may publish beyond the main hidden tensor.
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct StageArtifacts {
28 pub hidden: Shape,
29 pub side_outputs: Vec<(String, Shape)>,
30}
31
32impl StageArtifacts {
33 pub fn hidden_only(shape: Shape) -> Self {
34 Self {
35 hidden: shape,
36 side_outputs: Vec::new(),
37 }
38 }
39
40 pub fn with_side(mut self, name: impl Into<String>, shape: Shape) -> Self {
41 self.side_outputs.push((name.into(), shape));
42 self
43 }
44}
45
46/// Layer block with an explicit artifact contract (for new blocks and plugins).
47pub trait LayerStage: Send + Sync {
48 fn name(&self) -> &str;
49
50 fn emit_layer(
51 &self,
52 ctx: &mut FlowCtx<'_>,
53 input: FlowValue,
54 ) -> Result<(FlowValue, StageArtifacts)>;
55}
56
57/// Bridge existing [`BlockStage`] impls to [`LayerStage`] with hidden-only artifacts.
58pub struct BlockAsLayer<S>(pub S);
59
60impl<S: BlockStage + Send + Sync> LayerStage for BlockAsLayer<S> {
61 fn name(&self) -> &str {
62 "block"
63 }
64
65 fn emit_layer(
66 &self,
67 ctx: &mut FlowCtx<'_>,
68 input: FlowValue,
69 ) -> Result<(FlowValue, StageArtifacts)> {
70 let out = self.0.emit(ctx, input.clone())?;
71 let value = match out {
72 Some(v) => v,
73 None => input,
74 };
75 Ok((
76 value.clone(),
77 StageArtifacts::hidden_only(value.shape.clone()),
78 ))
79 }
80}