Skip to main content

rlx_flow/
stream.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Named tensor streams — dual-/multi-stream models without IR in recipes.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9
10use anyhow::Result;
11
12use crate::context::FlowCtx;
13use crate::escape::Emit;
14use crate::stage::FlowStage;
15use crate::value::FlowValue;
16
/// Conventional stream identifiers.
///
/// These are naming conventions only: any string is accepted wherever a stream id is expected.
pub mod id {
    /// Default stream id for single-stream models.
    pub const MAIN: &str = "main";
    /// Stream id conventionally used for the image branch of a dual-stream block.
    pub const IMG: &str = "img";
    /// Stream id conventionally used for the text branch of a dual-stream block.
    pub const TXT: &str = "txt";
}
23
24type DualFn = Arc<
25    dyn Fn(&mut Emit<'_>, FlowValue, FlowValue) -> Result<(FlowValue, FlowValue)> + Send + Sync,
26>;
27
28/// Transform two named streams in place (e.g. FLUX img/txt dual block).
29#[derive(Clone)]
30pub struct DualStreamStage {
31    pub name: String,
32    pub stream_a: String,
33    pub stream_b: String,
34    inner: DualFn,
35}
36
37impl fmt::Debug for DualStreamStage {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        f.debug_struct("DualStreamStage")
40            .field("name", &self.name)
41            .field("stream_a", &self.stream_a)
42            .field("stream_b", &self.stream_b)
43            .finish_non_exhaustive()
44    }
45}
46
47impl DualStreamStage {
48    pub fn new<F>(
49        name: impl Into<String>,
50        stream_a: impl Into<String>,
51        stream_b: impl Into<String>,
52        f: F,
53    ) -> Self
54    where
55        F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> Result<(FlowValue, FlowValue)>
56            + Send
57            + Sync
58            + 'static,
59    {
60        Self {
61            name: name.into(),
62            stream_a: stream_a.into(),
63            stream_b: stream_b.into(),
64            inner: Arc::new(f),
65        }
66    }
67
68    pub fn emit(
69        &self,
70        ctx: &mut FlowCtx<'_>,
71        input: Option<FlowValue>,
72    ) -> Result<Option<FlowValue>> {
73        let a = ctx
74            .state
75            .streams
76            .get(&self.stream_a)
77            .cloned()
78            .ok_or_else(|| anyhow::anyhow!("dual stream missing `{}`", self.stream_a))?;
79        let b = ctx
80            .state
81            .streams
82            .get(&self.stream_b)
83            .cloned()
84            .ok_or_else(|| anyhow::anyhow!("dual stream missing `{}`", self.stream_b))?;
85        let mut emit = Emit::from_ctx(ctx);
86        let (na, nb) = (self.inner)(&mut emit, a, b)?;
87        ctx.state.streams.insert(self.stream_a.clone(), na);
88        ctx.state.streams.insert(self.stream_b.clone(), nb);
89        Ok(input)
90    }
91}
92
/// Stage that copies the active tensor flow into a named stream.
///
/// The flow itself continues to the next stage unchanged.
#[derive(Debug, Clone)]
pub struct StoreStreamStage {
    /// Id of the stream to write; an existing stream with this id is replaced.
    pub name: String,
}
98
99impl StoreStreamStage {
100    pub fn new(name: impl Into<String>) -> Self {
101        Self { name: name.into() }
102    }
103
104    pub fn emit(
105        &self,
106        ctx: &mut FlowCtx<'_>,
107        input: Option<FlowValue>,
108    ) -> Result<Option<FlowValue>> {
109        let v = input.ok_or_else(|| anyhow::anyhow!("StoreStream requires input"))?;
110        ctx.state.streams.insert(self.name.clone(), v.clone());
111        Ok(Some(v))
112    }
113}
114
/// Stage that replaces the active tensor flow with the value of a named stream.
#[derive(Debug, Clone)]
pub struct LoadStreamStage {
    /// Id of the stream to read.
    pub name: String,
}
120
121impl LoadStreamStage {
122    pub fn new(name: impl Into<String>) -> Self {
123        Self { name: name.into() }
124    }
125
126    pub fn emit(
127        &self,
128        ctx: &mut FlowCtx<'_>,
129        input: Option<FlowValue>,
130    ) -> Result<Option<FlowValue>> {
131        let _ = input;
132        ctx.state
133            .streams
134            .get(&self.name)
135            .cloned()
136            .ok_or_else(|| anyhow::anyhow!("LoadStream missing `{}`", self.name))
137            .map(Some)
138    }
139}
140
141#[allow(dead_code)]
142pub(crate) fn stream_snapshot(state: &crate::context::FlowState) -> HashMap<String, FlowValue> {
143    state.streams.clone()
144}
145
146pub fn dual_stream_stage(
147    name: impl Into<String>,
148    stream_a: impl Into<String>,
149    stream_b: impl Into<String>,
150    f: impl Fn(&mut Emit<'_>, FlowValue, FlowValue) -> Result<(FlowValue, FlowValue)>
151    + Send
152    + Sync
153    + 'static,
154) -> FlowStage {
155    FlowStage::DualStream(DualStreamStage::new(name, stream_a, stream_b, f))
156}