// telex/stream_state.rs

1//! Stream state management for Telex.
2//!
3//! Provides the `use_stream` hook for handling async streams (e.g., LLM token streaming).
4
5use std::cell::RefCell;
6use std::rc::Rc;
7use std::sync::mpsc::{self, Receiver, Sender};
8use std::thread;
9
/// Lifecycle of a streaming operation.
///
/// Equality is derived so callers can compare states directly
/// (`state == StreamState::Done`) instead of going through `matches!`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StreamState {
    /// No stream has been started yet.
    Idle,
    /// The stream is active and items may still arrive.
    Streaming,
    /// The stream completed successfully.
    Done,
    /// The stream failed; the payload is a human-readable description.
    Error(String),
}
22
23/// Handle for stream state that can be stored and cloned.
24pub struct StreamHandle<T> {
25    inner: Rc<RefCell<StreamInner<T>>>,
26}
27
28struct StreamInner<T> {
29    /// Accumulated values from the stream.
30    accumulated: T,
31    /// Current state of the stream.
32    state: StreamState,
33    /// Whether the stream has been started.
34    started: bool,
35    /// Receiver for stream items.
36    receiver: Option<Receiver<StreamItem<T>>>,
37}
38
/// Message delivered over the stream's channel.
enum StreamItem<T> {
    /// One value yielded by the stream.
    Value(T),
    /// The stream ran to completion.
    Done,
    /// The stream failed with the given message.
    Error(String),
}
48
49impl<T> Clone for StreamHandle<T>
50where
51    T: Clone,
52{
53    fn clone(&self) -> Self {
54        Self {
55            inner: Rc::clone(&self.inner),
56        }
57    }
58}
59
60impl<T: Clone + Default + 'static> StreamHandle<T> {
61    /// Create a new stream handle with default accumulated value.
62    pub fn new() -> Self {
63        Self {
64            inner: Rc::new(RefCell::new(StreamInner {
65                accumulated: T::default(),
66                state: StreamState::Idle,
67                started: false,
68                receiver: None,
69            })),
70        }
71    }
72
73    /// Create a new stream handle with a specific initial value.
74    pub fn with_initial(initial: T) -> Self {
75        Self {
76            inner: Rc::new(RefCell::new(StreamInner {
77                accumulated: initial,
78                state: StreamState::Idle,
79                started: false,
80                receiver: None,
81            })),
82        }
83    }
84
85    /// Get the current accumulated value.
86    pub fn get(&self) -> T {
87        self.inner.borrow().accumulated.clone()
88    }
89
90    /// Check if the stream is currently loading/streaming.
91    pub fn is_loading(&self) -> bool {
92        matches!(
93            self.inner.borrow().state,
94            StreamState::Idle | StreamState::Streaming
95        )
96    }
97
98    /// Check if the stream is actively receiving data.
99    pub fn is_streaming(&self) -> bool {
100        matches!(self.inner.borrow().state, StreamState::Streaming)
101    }
102
103    /// Check if the stream has completed.
104    pub fn is_done(&self) -> bool {
105        matches!(self.inner.borrow().state, StreamState::Done)
106    }
107
108    /// Check if the stream encountered an error.
109    pub fn is_error(&self) -> bool {
110        matches!(self.inner.borrow().state, StreamState::Error(_))
111    }
112
113    /// Get the error message if there was an error.
114    pub fn error(&self) -> Option<String> {
115        match &self.inner.borrow().state {
116            StreamState::Error(e) => Some(e.clone()),
117            _ => None,
118        }
119    }
120
121    /// Get the current stream state.
122    pub fn state(&self) -> StreamState {
123        self.inner.borrow().state.clone()
124    }
125}
126
127impl<T: Clone + Send + 'static> StreamHandle<T> {
128    /// Start the stream if not already started.
129    ///
130    /// The `stream_fn` should be a function that returns an iterator.
131    /// Each item from the iterator will be sent through the channel.
132    pub fn start<F, I>(&self, stream_fn: F)
133    where
134        F: FnOnce() -> I + Send + 'static,
135        I: Iterator<Item = T> + Send + 'static,
136    {
137        let mut inner = self.inner.borrow_mut();
138        if inner.started {
139            return;
140        }
141
142        inner.started = true;
143        inner.state = StreamState::Streaming;
144
145        // Create channel for stream items
146        let (tx, rx): (Sender<StreamItem<T>>, Receiver<StreamItem<T>>) = mpsc::channel();
147        inner.receiver = Some(rx);
148
149        // Spawn thread to run the stream
150        thread::spawn(move || {
151            let iter = stream_fn();
152            for item in iter {
153                if tx.send(StreamItem::Value(item)).is_err() {
154                    // Receiver dropped, stop streaming
155                    return;
156                }
157            }
158            let _ = tx.send(StreamItem::Done);
159        });
160    }
161
162    /// Start the stream with error handling.
163    pub fn start_with_result<F, I>(&self, stream_fn: F)
164    where
165        F: FnOnce() -> Result<I, String> + Send + 'static,
166        I: Iterator<Item = T> + Send + 'static,
167    {
168        let mut inner = self.inner.borrow_mut();
169        if inner.started {
170            return;
171        }
172
173        inner.started = true;
174        inner.state = StreamState::Streaming;
175
176        let (tx, rx): (Sender<StreamItem<T>>, Receiver<StreamItem<T>>) = mpsc::channel();
177        inner.receiver = Some(rx);
178
179        thread::spawn(move || match stream_fn() {
180            Ok(iter) => {
181                for item in iter {
182                    if tx.send(StreamItem::Value(item)).is_err() {
183                        return;
184                    }
185                }
186                let _ = tx.send(StreamItem::Done);
187            }
188            Err(e) => {
189                let _ = tx.send(StreamItem::Error(e));
190            }
191        });
192    }
193
194    /// Poll for new items and update accumulated value.
195    /// Returns true if there were updates.
196    pub fn poll(&self, accumulate: impl Fn(&mut T, T)) -> bool {
197        let mut inner = self.inner.borrow_mut();
198        let mut updated = false;
199
200        // Take receiver temporarily to avoid borrow conflicts
201        if let Some(receiver) = inner.receiver.take() {
202            // Drain all available items
203            let mut new_state = None;
204            loop {
205                match receiver.try_recv() {
206                    Ok(StreamItem::Value(item)) => {
207                        accumulate(&mut inner.accumulated, item);
208                        updated = true;
209                    }
210                    Ok(StreamItem::Done) => {
211                        new_state = Some(StreamState::Done);
212                        break;
213                    }
214                    Ok(StreamItem::Error(e)) => {
215                        new_state = Some(StreamState::Error(e));
216                        break;
217                    }
218                    Err(mpsc::TryRecvError::Empty) => {
219                        break;
220                    }
221                    Err(mpsc::TryRecvError::Disconnected) => {
222                        if !matches!(inner.state, StreamState::Done | StreamState::Error(_)) {
223                            new_state = Some(StreamState::Error(
224                                "Stream disconnected unexpectedly".to_string(),
225                            ));
226                        }
227                        break;
228                    }
229                }
230            }
231
232            // Put receiver back (unless stream is done)
233            if new_state.is_none() || matches!(new_state, Some(StreamState::Streaming)) {
234                inner.receiver = Some(receiver);
235            }
236
237            if let Some(state) = new_state {
238                inner.state = state;
239            }
240        }
241
242        updated
243    }
244
245    /// Reset the stream to allow restarting.
246    pub fn reset(&self)
247    where
248        T: Default,
249    {
250        let mut inner = self.inner.borrow_mut();
251        inner.accumulated = T::default();
252        inner.state = StreamState::Idle;
253        inner.started = false;
254        inner.receiver = None;
255    }
256
257    /// Reset the stream with a specific initial value.
258    pub fn reset_with(&self, initial: T) {
259        let mut inner = self.inner.borrow_mut();
260        inner.accumulated = initial;
261        inner.state = StreamState::Idle;
262        inner.started = false;
263        inner.receiver = None;
264    }
265}
266
267impl<T: Clone + Default + 'static> Default for StreamHandle<T> {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
273/// Convenience handle specifically for text streaming.
274/// Automatically accumulates string tokens.
275pub type TextStreamHandle = StreamHandle<String>;
276
277impl TextStreamHandle {
278    /// Poll and accumulate text by concatenation.
279    pub fn poll_text(&self) -> bool {
280        self.poll(|acc, item| acc.push_str(&item))
281    }
282}