Skip to main content

rlx_flow/
stream.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Named tensor streams — dual-/multi-stream models without IR in recipes.

18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21
22use anyhow::Result;
23
24use crate::context::FlowCtx;
25use crate::escape::Emit;
26use crate::stage::FlowStage;
27use crate::value::FlowValue;

/// Conventional stream ids shared between recipes.
///
/// These are naming conventions only — any string is accepted as a stream id.
pub mod id {
    /// The `"main"` stream id.
    pub const MAIN: &str = "main";
    /// The `"img"` stream id (e.g. the image side of a FLUX img/txt dual block).
    pub const IMG: &str = "img";
    /// The `"txt"` stream id (e.g. the text side of a FLUX img/txt dual block).
    pub const TXT: &str = "txt";
}

36type DualFn = Arc<
37    dyn Fn(&mut Emit<'_>, FlowValue, FlowValue) -> Result<(FlowValue, FlowValue)> + Send + Sync,
38>;
39
40/// Transform two named streams in place (e.g. FLUX img/txt dual block).
41#[derive(Clone)]
42pub struct DualStreamStage {
43    pub name: String,
44    pub stream_a: String,
45    pub stream_b: String,
46    inner: DualFn,
47}
48
49impl fmt::Debug for DualStreamStage {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        f.debug_struct("DualStreamStage")
52            .field("name", &self.name)
53            .field("stream_a", &self.stream_a)
54            .field("stream_b", &self.stream_b)
55            .finish_non_exhaustive()
56    }
57}
58
59impl DualStreamStage {
60    pub fn new<F>(
61        name: impl Into<String>,
62        stream_a: impl Into<String>,
63        stream_b: impl Into<String>,
64        f: F,
65    ) -> Self
66    where
67        F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> Result<(FlowValue, FlowValue)>
68            + Send
69            + Sync
70            + 'static,
71    {
72        Self {
73            name: name.into(),
74            stream_a: stream_a.into(),
75            stream_b: stream_b.into(),
76            inner: Arc::new(f),
77        }
78    }
79
80    pub fn emit(
81        &self,
82        ctx: &mut FlowCtx<'_>,
83        input: Option<FlowValue>,
84    ) -> Result<Option<FlowValue>> {
85        let a = ctx
86            .state
87            .streams
88            .get(&self.stream_a)
89            .cloned()
90            .ok_or_else(|| anyhow::anyhow!("dual stream missing `{}`", self.stream_a))?;
91        let b = ctx
92            .state
93            .streams
94            .get(&self.stream_b)
95            .cloned()
96            .ok_or_else(|| anyhow::anyhow!("dual stream missing `{}`", self.stream_b))?;
97        let mut emit = Emit::from_ctx(ctx);
98        let (na, nb) = (self.inner)(&mut emit, a, b)?;
99        ctx.state.streams.insert(self.stream_a.clone(), na);
100        ctx.state.streams.insert(self.stream_b.clone(), nb);
101        Ok(input)
102    }
103}
104
/// Stage that copies the active tensor flow into a named stream.
///
/// The active flow itself is left in place.
#[derive(Debug, Clone)]
pub struct StoreStreamStage {
    /// Id of the stream to write.
    pub name: String,
}

111impl StoreStreamStage {
112    pub fn new(name: impl Into<String>) -> Self {
113        Self { name: name.into() }
114    }
115
116    pub fn emit(
117        &self,
118        ctx: &mut FlowCtx<'_>,
119        input: Option<FlowValue>,
120    ) -> Result<Option<FlowValue>> {
121        let v = input.ok_or_else(|| anyhow::anyhow!("StoreStream requires input"))?;
122        ctx.state.streams.insert(self.name.clone(), v.clone());
123        Ok(Some(v))
124    }
125}
126
/// Stage that replaces the active tensor flow with the contents of a named stream.
#[derive(Debug, Clone)]
pub struct LoadStreamStage {
    /// Id of the stream to read.
    pub name: String,
}

133impl LoadStreamStage {
134    pub fn new(name: impl Into<String>) -> Self {
135        Self { name: name.into() }
136    }
137
138    pub fn emit(
139        &self,
140        ctx: &mut FlowCtx<'_>,
141        input: Option<FlowValue>,
142    ) -> Result<Option<FlowValue>> {
143        let _ = input;
144        ctx.state
145            .streams
146            .get(&self.name)
147            .cloned()
148            .ok_or_else(|| anyhow::anyhow!("LoadStream missing `{}`", self.name))
149            .map(Some)
150    }
151}
152
153#[allow(dead_code)]
154pub(crate) fn stream_snapshot(state: &crate::context::FlowState) -> HashMap<String, FlowValue> {
155    state.streams.clone()
156}
157
158pub fn dual_stream_stage(
159    name: impl Into<String>,
160    stream_a: impl Into<String>,
161    stream_b: impl Into<String>,
162    f: impl Fn(&mut Emit<'_>, FlowValue, FlowValue) -> Result<(FlowValue, FlowValue)>
163    + Send
164    + Sync
165    + 'static,
166) -> FlowStage {
167    FlowStage::DualStream(DualStreamStage::new(name, stream_a, stream_b, f))
168}