// loom_rs/stream.rs
//! Stream combinators for processing items via rayon compute threads.
//!
//! This module provides the [`ComputeStreamExt`] trait which extends any [`Stream`]
//! with methods for processing items through the rayon thread pool.
//!
//! # Example
//!
//! ```ignore
//! use loom_rs::{LoomBuilder, ComputeStreamExt};
//! use futures::stream::{self, StreamExt};
//!
//! let runtime = LoomBuilder::new().build()?;
//!
//! runtime.block_on(async {
//!     let results: Vec<_> = stream::iter(0..100)
//!         .compute_map(|n| {
//!             // CPU-intensive work runs on rayon
//!             (0..n).map(|i| i * i).sum::<i64>()
//!         })
//!         .collect()
//!         .await;
//! });
//! ```
//!
//! # Performance
//!
//! The key optimization is reusing the same `TaskState` for all items in the stream,
//! rather than getting/returning from the pool for each item:
//!
//! | Operation | Overhead | Allocations |
//! |-----------|----------|-------------|
//! | Stream creation | ~1us | TaskState (from pool or new) |
//! | Each item | ~100-500ns | 0 bytes (reuses TaskState) |
//! | Stream drop | ~10ns | Returns TaskState to pool |

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_core::Stream;

use crate::bridge::{PooledRayonTask, PooledTaskCompletion, TaskState};
use crate::context::current_runtime;
use crate::pool::TypedPool;
use crate::runtime::LoomRuntimeInner;

48/// Extension trait for streams that adds compute-based processing methods.
49///
50/// This trait is automatically implemented for all types that implement [`Stream`].
51///
52/// # Example
53///
54/// ```ignore
55/// use loom_rs::{LoomBuilder, ComputeStreamExt};
56/// use futures::stream::{self, StreamExt};
57///
58/// let runtime = LoomBuilder::new().build()?;
59///
60/// runtime.block_on(async {
61///     let numbers = stream::iter(0..10);
62///
63///     // Each item is processed on rayon, results stream back in order
64///     let results: Vec<_> = numbers
65///         .compute_map(|n| n * 2)
66///         .collect()
67///         .await;
68///
69///     assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
70/// });
71/// ```
72pub trait ComputeStreamExt: Stream {
73    /// Map each stream item through a compute-heavy closure on rayon.
74    ///
75    /// Items are processed sequentially (one at a time) to preserve
76    /// stream ordering and provide natural backpressure.
77    ///
78    /// # Performance
79    ///
80    /// Unlike calling `spawn_compute` in a loop, `compute_map` reuses the same
81    /// internal `TaskState` for every item, avoiding per-item pool operations:
82    ///
83    /// - First poll: Gets `TaskState` from pool (or allocates new)
84    /// - Each item: ~100-500ns overhead, 0 allocations
85    /// - Stream drop: Returns `TaskState` to pool
86    ///
87    /// # Panics
88    ///
89    /// Panics if called outside a loom runtime context (i.e., not within `block_on`,
90    /// a tokio worker thread, or a rayon worker thread managed by the runtime).
91    ///
92    /// # Example
93    ///
94    /// ```ignore
95    /// use loom_rs::{LoomBuilder, ComputeStreamExt};
96    /// use futures::stream::{self, StreamExt};
97    ///
98    /// let runtime = LoomBuilder::new().build()?;
99    ///
100    /// runtime.block_on(async {
101    ///     let results: Vec<_> = stream::iter(vec!["hello", "world"])
102    ///         .compute_map(|s| s.to_uppercase())
103    ///         .collect()
104    ///         .await;
105    ///
106    ///     assert_eq!(results, vec!["HELLO", "WORLD"]);
107    /// });
108    /// ```
109    fn compute_map<F, U>(self, f: F) -> ComputeMap<Self, F, U>
110    where
111        Self: Sized,
112        F: Fn(Self::Item) -> U + Send + Sync + 'static,
113        Self::Item: Send + 'static,
114        U: Send + 'static;
115}
116
117impl<S: Stream> ComputeStreamExt for S {
118    fn compute_map<F, U>(self, f: F) -> ComputeMap<Self, F, U>
119    where
120        Self: Sized,
121        F: Fn(Self::Item) -> U + Send + Sync + 'static,
122        Self::Item: Send + 'static,
123        U: Send + 'static,
124    {
125        ComputeMap::new(self, f)
126    }
127}
128
129/// A stream adapter that maps items through rayon compute threads.
130///
131/// Created by the [`compute_map`](ComputeStreamExt::compute_map) method on streams.
132/// Items are processed sequentially to preserve ordering.
133#[must_use = "streams do nothing unless polled"]
134pub struct ComputeMap<S, F, U>
135where
136    U: Send + 'static,
137{
138    stream: S,
139    f: Arc<F>,
140    // Lazily initialized on first poll. Drop impl on ComputeMapState
141    // handles returning TaskState to pool.
142    state: Option<ComputeMapState<U>>,
143}
144
145// Manual Unpin implementation - ComputeMap is Unpin if S is Unpin
146impl<S: Unpin, F, U: Send + 'static> Unpin for ComputeMap<S, F, U> {}
147
148/// Internal state for ComputeMap, initialized on first poll.
149///
150/// The Drop impl returns the TaskState to the pool when the stream is dropped.
151struct ComputeMapState<U: Send + 'static> {
152    runtime: Arc<LoomRuntimeInner>,
153    pool: Arc<TypedPool<U>>,
154    /// Reused TaskState - no per-item pool operations!
155    task_state: Arc<TaskState<U>>,
156    /// Currently pending compute task, if any
157    pending: Option<PooledRayonTask<U>>,
158}
159
160impl<U: Send + 'static> Drop for ComputeMapState<U> {
161    fn drop(&mut self) {
162        // Return TaskState to pool if there's no pending task
163        // (which would still be using it)
164        if self.pending.is_none() {
165            self.task_state.reset();
166            // Clone the Arc before pushing - we need ownership
167            let task_state = Arc::clone(&self.task_state);
168            self.pool.push(task_state);
169        }
170        // If there's a pending task, the TaskState will be dropped with it
171        // This is a rare edge case (stream dropped while compute in flight)
172    }
173}
174
175impl<S, F, U> ComputeMap<S, F, U>
176where
177    U: Send + 'static,
178{
179    fn new(stream: S, f: F) -> Self {
180        Self {
181            stream,
182            f: Arc::new(f),
183            state: None,
184        }
185    }
186}
187
188impl<S, F, U> Stream for ComputeMap<S, F, U>
189where
190    S: Stream + Unpin,
191    S::Item: Send + 'static,
192    F: Fn(S::Item) -> U + Send + Sync + 'static,
193    U: Send + 'static,
194{
195    type Item = U;
196
197    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        let this = &mut *self;
199
200        // Initialize state on first poll
201        let state = this.state.get_or_insert_with(|| {
202            let runtime = current_runtime().expect("compute_map used outside loom runtime");
203            let pool = runtime.pools.get_or_create::<U>();
204            let task_state = pool.pop().unwrap_or_else(|| Arc::new(TaskState::new()));
205
206            ComputeMapState {
207                runtime,
208                pool,
209                task_state,
210                pending: None,
211            }
212        });
213
214        // If we have a pending task, poll it first
215        if let Some(ref mut pending) = state.pending {
216            match Pin::new(pending).poll(cx) {
217                Poll::Ready(result) => {
218                    // Task complete, clear pending
219                    state.pending = None;
220                    // Reset state for reuse
221                    state.task_state.reset();
222                    return Poll::Ready(Some(result));
223                }
224                Poll::Pending => {
225                    return Poll::Pending;
226                }
227            }
228        }
229
230        // No pending task, poll the inner stream for the next item
231        match Pin::new(&mut this.stream).poll_next(cx) {
232            Poll::Ready(Some(item)) => {
233                // Got an item, spawn compute task
234                let f = Arc::clone(&this.f);
235                let task_state = Arc::clone(&state.task_state);
236
237                // Create the pooled task components
238                let (task, completion): (PooledRayonTask<U>, PooledTaskCompletion<U>) = {
239                    // Reuse the same TaskState (already have it)
240                    let (task, completion, _state_for_return) = PooledRayonTask::new(task_state);
241                    (task, completion)
242                };
243
244                // Spawn on rayon
245                state.runtime.rayon_pool.spawn(move || {
246                    let result = f(item);
247                    completion.complete(result);
248                });
249
250                // Store pending task and poll it immediately
251                state.pending = Some(task);
252
253                // Poll the pending task - it might already be ready
254                if let Some(ref mut pending) = state.pending {
255                    match Pin::new(pending).poll(cx) {
256                        Poll::Ready(result) => {
257                            state.pending = None;
258                            state.task_state.reset();
259                            Poll::Ready(Some(result))
260                        }
261                        Poll::Pending => Poll::Pending,
262                    }
263                } else {
264                    Poll::Pending
265                }
266            }
267            Poll::Ready(None) => {
268                // Stream exhausted
269                Poll::Ready(None)
270            }
271            Poll::Pending => Poll::Pending,
272        }
273    }
274
275    fn size_hint(&self) -> (usize, Option<usize>) {
276        // We produce the same number of items as the inner stream
277        // Adjust for pending task if any
278        let (lower, upper) = self.stream.size_hint();
279        if self.state.as_ref().is_some_and(|s| s.pending.is_some()) {
280            // We have a pending item that will be produced
281            (lower.saturating_add(1), upper.map(|u| u.saturating_add(1)))
282        } else {
283            (lower, upper)
284        }
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LoomConfig;
    use crate::pool::DEFAULT_POOL_SIZE;
    use crate::runtime::LoomRuntime;
    use futures::stream::{self, StreamExt};

    /// Small 1-tokio / 2-rayon config shared by every test below.
    fn test_config() -> LoomConfig {
        LoomConfig {
            prefix: "stream-test".to_string(),
            cpuset: None,
            tokio_threads: Some(1),
            rayon_threads: Some(2),
            compute_pool_size: DEFAULT_POOL_SIZE,
            #[cfg(feature = "cuda")]
            cuda_device: None,
        }
    }

    fn test_runtime() -> LoomRuntime {
        LoomRuntime::from_config(test_config(), DEFAULT_POOL_SIZE).unwrap()
    }

    #[test]
    fn test_compute_map_basic() {
        let rt = test_runtime();
        rt.block_on(async {
            let doubled: Vec<_> = stream::iter(0..10).compute_map(|n| n * 2).collect().await;
            assert_eq!(doubled, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        });
    }

    #[test]
    fn test_compute_map_preserves_order() {
        let rt = test_runtime();
        rt.block_on(async {
            // Items deliberately sleep for different durations so that any
            // ordering bug would reshuffle the output.
            let input = vec![5, 1, 3, 2, 4];
            let out: Vec<_> = stream::iter(input)
                .compute_map(|n| {
                    std::thread::sleep(std::time::Duration::from_micros(n as u64 * 10));
                    n * 10
                })
                .collect()
                .await;
            assert_eq!(out, vec![50, 10, 30, 20, 40]);
        });
    }

    #[test]
    fn test_compute_map_empty_stream() {
        let rt = test_runtime();
        rt.block_on(async {
            let out: Vec<i32> = stream::iter(std::iter::empty::<i32>())
                .compute_map(|n| n * 2)
                .collect()
                .await;
            assert!(out.is_empty());
        });
    }

    #[test]
    fn test_compute_map_single_item() {
        let rt = test_runtime();
        rt.block_on(async {
            let out: Vec<_> = stream::iter(vec![42]).compute_map(|n| n + 1).collect().await;
            assert_eq!(out, vec![43]);
        });
    }

    #[test]
    fn test_compute_map_with_strings() {
        let rt = test_runtime();
        rt.block_on(async {
            let out: Vec<_> = stream::iter(vec!["hello", "world"])
                .compute_map(|s| s.to_uppercase())
                .collect()
                .await;
            assert_eq!(out, vec!["HELLO", "WORLD"]);
        });
    }

    #[test]
    fn test_compute_map_type_conversion() {
        let rt = test_runtime();
        rt.block_on(async {
            // Input is i32, output is String — the adapter changes item type.
            let out: Vec<_> = stream::iter(1..=5)
                .compute_map(|n| format!("item-{}", n))
                .collect()
                .await;
            assert_eq!(
                out,
                vec!["item-1", "item-2", "item-3", "item-4", "item-5"]
            );
        });
    }

    #[test]
    fn test_compute_map_cpu_intensive() {
        let rt = test_runtime();
        rt.block_on(async {
            // A small but non-trivial computation per item.
            let out: Vec<_> = stream::iter(0..5)
                .compute_map(|n| (0..1000).map(|i| i * n).sum::<i64>())
                .collect()
                .await;

            let expected: Vec<i64> = (0..5).map(|n| (0..1000).map(|i| i * n).sum()).collect();
            assert_eq!(out, expected);
        });
    }

    #[test]
    fn test_compute_map_size_hint() {
        let rt = test_runtime();
        rt.block_on(async {
            let s = stream::iter(0..10).compute_map(|n| n * 2);
            assert_eq!(s.size_hint(), (10, Some(10)));
        });
    }

    #[test]
    fn test_compute_map_chained() {
        let rt = test_runtime();
        rt.block_on(async {
            // compute_map composes with ordinary stream combinators.
            let out: Vec<_> = stream::iter(0..10)
                .compute_map(|n| n * 2)
                .filter(|n| futures::future::ready(*n > 10))
                .collect()
                .await;
            assert_eq!(out, vec![12, 14, 16, 18]);
        });
    }
}