1use std::cell::RefCell;
6use std::rc::Rc;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::mpsc::{self, Receiver, Sender};
9use std::sync::Arc;
10use std::thread;
11
/// Lifecycle of a stream as observed through a [`StreamHandle`].
#[derive(Debug, Clone)]
pub enum StreamState {
    /// No producer has been started yet; this is also the state after a reset.
    Idle,
    /// A producer has been started and no terminal item has been polled yet.
    Streaming,
    /// The producer reported normal completion.
    Done,
    /// The stream ended abnormally; the payload is the error message.
    Error(String),
}
24
25pub struct StreamHandle<T> {
27 inner: Rc<RefCell<StreamInner<T>>>,
28}
29
30struct StreamInner<T> {
31 accumulated: T,
33 state: StreamState,
35 started: bool,
37 receiver: Option<Receiver<StreamItem<T>>>,
39 wake_flag: Option<Arc<AtomicBool>>,
41}
42
/// Message sent from the producer thread to the polling side.
enum StreamItem<T> {
    /// One streamed item to fold into the accumulated value.
    Value(T),
    /// The producer finished normally; no further items follow.
    Done,
    /// The producer failed; the payload is the error message.
    Error(String),
}
52
53impl<T> Clone for StreamHandle<T>
54where
55 T: Clone,
56{
57 fn clone(&self) -> Self {
58 Self {
59 inner: Rc::clone(&self.inner),
60 }
61 }
62}
63
64impl<T: Clone + Default + 'static> StreamHandle<T> {
65 pub fn new() -> Self {
67 Self {
68 inner: Rc::new(RefCell::new(StreamInner {
69 accumulated: T::default(),
70 state: StreamState::Idle,
71 started: false,
72 receiver: None,
73 wake_flag: None,
74 })),
75 }
76 }
77
78 pub fn with_wake_flag(wake_flag: Arc<AtomicBool>) -> Self {
80 Self {
81 inner: Rc::new(RefCell::new(StreamInner {
82 accumulated: T::default(),
83 state: StreamState::Idle,
84 started: false,
85 receiver: None,
86 wake_flag: Some(wake_flag),
87 })),
88 }
89 }
90
91 pub fn with_initial(initial: T) -> Self {
93 Self {
94 inner: Rc::new(RefCell::new(StreamInner {
95 accumulated: initial,
96 state: StreamState::Idle,
97 started: false,
98 receiver: None,
99 wake_flag: None,
100 })),
101 }
102 }
103
104 pub fn get(&self) -> T {
106 self.inner.borrow().accumulated.clone()
107 }
108
109 pub fn is_loading(&self) -> bool {
111 matches!(
112 self.inner.borrow().state,
113 StreamState::Idle | StreamState::Streaming
114 )
115 }
116
117 pub fn is_streaming(&self) -> bool {
119 matches!(self.inner.borrow().state, StreamState::Streaming)
120 }
121
122 pub fn is_done(&self) -> bool {
124 matches!(self.inner.borrow().state, StreamState::Done)
125 }
126
127 pub fn is_error(&self) -> bool {
129 matches!(self.inner.borrow().state, StreamState::Error(_))
130 }
131
132 pub fn error(&self) -> Option<String> {
134 match &self.inner.borrow().state {
135 StreamState::Error(e) => Some(e.clone()),
136 _ => None,
137 }
138 }
139
140 pub fn state(&self) -> StreamState {
142 self.inner.borrow().state.clone()
143 }
144}
145
146impl<T: Clone + Send + 'static> StreamHandle<T> {
147 pub fn start<F, I>(&self, stream_fn: F)
152 where
153 F: FnOnce() -> I + Send + 'static,
154 I: Iterator<Item = T> + Send + 'static,
155 {
156 let mut inner = self.inner.borrow_mut();
157 if inner.started {
158 return;
159 }
160
161 inner.started = true;
162 inner.state = StreamState::Streaming;
163
164 let (tx, rx): (Sender<StreamItem<T>>, Receiver<StreamItem<T>>) = mpsc::channel();
166 inner.receiver = Some(rx);
167 let wake_flag = inner.wake_flag.clone();
168
169 thread::spawn(move || {
171 let iter = stream_fn();
172 for item in iter {
173 if tx.send(StreamItem::Value(item)).is_err() {
174 return;
176 }
177 if let Some(ref flag) = wake_flag {
178 flag.store(true, Ordering::Release);
179 }
180 }
181 let _ = tx.send(StreamItem::Done);
182 if let Some(ref flag) = wake_flag {
183 flag.store(true, Ordering::Release);
184 }
185 });
186 }
187
188 pub fn start_with_result<F, I>(&self, stream_fn: F)
190 where
191 F: FnOnce() -> Result<I, String> + Send + 'static,
192 I: Iterator<Item = T> + Send + 'static,
193 {
194 let mut inner = self.inner.borrow_mut();
195 if inner.started {
196 return;
197 }
198
199 inner.started = true;
200 inner.state = StreamState::Streaming;
201
202 let (tx, rx): (Sender<StreamItem<T>>, Receiver<StreamItem<T>>) = mpsc::channel();
203 inner.receiver = Some(rx);
204 let wake_flag = inner.wake_flag.clone();
205
206 thread::spawn(move || match stream_fn() {
207 Ok(iter) => {
208 for item in iter {
209 if tx.send(StreamItem::Value(item)).is_err() {
210 return;
211 }
212 if let Some(ref flag) = wake_flag {
213 flag.store(true, Ordering::Release);
214 }
215 }
216 let _ = tx.send(StreamItem::Done);
217 if let Some(ref flag) = wake_flag {
218 flag.store(true, Ordering::Release);
219 }
220 }
221 Err(e) => {
222 let _ = tx.send(StreamItem::Error(e));
223 if let Some(ref flag) = wake_flag {
224 flag.store(true, Ordering::Release);
225 }
226 }
227 });
228 }
229
230 pub fn poll(&self, accumulate: impl Fn(&mut T, T)) -> bool {
233 let mut inner = self.inner.borrow_mut();
234 let mut updated = false;
235
236 if let Some(receiver) = inner.receiver.take() {
238 let mut new_state = None;
240 loop {
241 match receiver.try_recv() {
242 Ok(StreamItem::Value(item)) => {
243 accumulate(&mut inner.accumulated, item);
244 updated = true;
245 }
246 Ok(StreamItem::Done) => {
247 new_state = Some(StreamState::Done);
248 break;
249 }
250 Ok(StreamItem::Error(e)) => {
251 new_state = Some(StreamState::Error(e));
252 break;
253 }
254 Err(mpsc::TryRecvError::Empty) => {
255 break;
256 }
257 Err(mpsc::TryRecvError::Disconnected) => {
258 if !matches!(inner.state, StreamState::Done | StreamState::Error(_)) {
259 new_state = Some(StreamState::Error(
260 "Stream disconnected unexpectedly".to_string(),
261 ));
262 }
263 break;
264 }
265 }
266 }
267
268 if new_state.is_none() || matches!(new_state, Some(StreamState::Streaming)) {
270 inner.receiver = Some(receiver);
271 }
272
273 if let Some(state) = new_state {
274 inner.state = state;
275 }
276 }
277
278 updated
279 }
280
281 pub fn reset(&self)
283 where
284 T: Default,
285 {
286 let mut inner = self.inner.borrow_mut();
287 inner.accumulated = T::default();
288 inner.state = StreamState::Idle;
289 inner.started = false;
290 inner.receiver = None;
291 }
293
294 pub fn reset_with(&self, initial: T) {
296 let mut inner = self.inner.borrow_mut();
297 inner.accumulated = initial;
298 inner.state = StreamState::Idle;
299 inner.started = false;
300 inner.receiver = None;
301 }
303}
304
305impl<T: Clone + Default + 'static> Default for StreamHandle<T> {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311pub type TextStreamHandle = StreamHandle<String>;
314
315impl TextStreamHandle {
316 pub fn poll_text(&self) -> bool {
318 self.poll(|acc, item| acc.push_str(&item))
319 }
320}