// loom_rs/stream.rs
//! Stream combinators for processing items via rayon compute threads.
//!
//! This module provides the [`ComputeStreamExt`] trait which extends any [`Stream`]
//! with methods for processing items through the rayon thread pool.
//!
//! # Example
//!
//! ```ignore
//! use loom_rs::{LoomBuilder, ComputeStreamExt};
//! use futures::stream::{self, StreamExt};
//!
//! let runtime = LoomBuilder::new().build()?;
//!
//! runtime.block_on(async {
//!     let results: Vec<_> = stream::iter(0..100)
//!         .compute_map(|n| {
//!             // CPU-intensive work runs on rayon
//!             (0..n).map(|i| i * i).sum::<i64>()
//!         })
//!         .collect()
//!         .await;
//! });
//! ```
//!
//! # Performance
//!
//! The key optimization is reusing the same `TaskState` for all items in the stream,
//! rather than getting/returning from the pool for each item:
//!
//! | Operation | Overhead | Allocations |
//! |-----------|----------|-------------|
//! | Stream creation | ~1us | TaskState (from pool or new) |
//! | Each item | ~100-500ns | 0 bytes (reuses TaskState) |
//! | Stream drop | ~10ns | Returns TaskState to pool |
35
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41use futures_core::Stream;
42
43use crate::bridge::{PooledRayonTask, PooledTaskCompletion, TaskState};
44use crate::context::current_runtime;
45use crate::pool::TypedPool;
46use crate::runtime::LoomRuntimeInner;
47
/// Extension trait for streams that adds compute-based processing methods.
///
/// This trait is automatically implemented for all types that implement [`Stream`]
/// via the blanket impl below, so bringing it into scope is enough to call
/// [`compute_map`](ComputeStreamExt::compute_map) on any stream.
///
/// # Example
///
/// ```ignore
/// use loom_rs::{LoomBuilder, ComputeStreamExt};
/// use futures::stream::{self, StreamExt};
///
/// let runtime = LoomBuilder::new().build()?;
///
/// runtime.block_on(async {
///     let numbers = stream::iter(0..10);
///
///     // Each item is processed on rayon, results stream back in order
///     let results: Vec<_> = numbers
///         .compute_map(|n| n * 2)
///         .collect()
///         .await;
///
///     assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
/// });
/// ```
pub trait ComputeStreamExt: Stream {
    /// Map each stream item through a compute-heavy closure on rayon.
    ///
    /// Items are processed sequentially (one at a time) to preserve
    /// stream ordering and provide natural backpressure: the inner stream
    /// is not polled for a new item while a compute task is in flight.
    ///
    /// # Performance
    ///
    /// Unlike calling `spawn_compute` in a loop, `compute_map` reuses the same
    /// internal `TaskState` for every item, avoiding per-item pool operations:
    ///
    /// - First poll: Gets `TaskState` from pool (or allocates new)
    /// - Each item: ~100-500ns overhead, 0 allocations
    /// - Stream drop: Returns `TaskState` to pool
    ///
    /// # Panics
    ///
    /// Panics if called outside a loom runtime context (i.e., not within `block_on`,
    /// a tokio worker thread, or a rayon worker thread managed by the runtime).
    /// The panic happens on first poll, not at construction time.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use loom_rs::{LoomBuilder, ComputeStreamExt};
    /// use futures::stream::{self, StreamExt};
    ///
    /// let runtime = LoomBuilder::new().build()?;
    ///
    /// runtime.block_on(async {
    ///     let results: Vec<_> = stream::iter(vec!["hello", "world"])
    ///         .compute_map(|s| s.to_uppercase())
    ///         .collect()
    ///         .await;
    ///
    ///     assert_eq!(results, vec!["HELLO", "WORLD"]);
    /// });
    /// ```
    fn compute_map<F, U>(self, f: F) -> ComputeMap<Self, F, U>
    where
        Self: Sized,
        F: Fn(Self::Item) -> U + Send + Sync + 'static,
        Self::Item: Send + 'static,
        U: Send + 'static;
}
116
117impl<S: Stream> ComputeStreamExt for S {
118 fn compute_map<F, U>(self, f: F) -> ComputeMap<Self, F, U>
119 where
120 Self: Sized,
121 F: Fn(Self::Item) -> U + Send + Sync + 'static,
122 Self::Item: Send + 'static,
123 U: Send + 'static,
124 {
125 ComputeMap::new(self, f)
126 }
127}
128
/// A stream adapter that maps items through rayon compute threads.
///
/// Created by the [`compute_map`](ComputeStreamExt::compute_map) method on streams.
/// Items are processed sequentially to preserve ordering.
#[must_use = "streams do nothing unless polled"]
pub struct ComputeMap<S, F, U>
where
    U: Send + 'static,
{
    // The inner stream supplying items to be mapped.
    stream: S,
    // User closure, shared via Arc so each spawned rayon task can hold a
    // clone without requiring `F: Clone`.
    f: Arc<F>,
    // Lazily initialized on first poll. Drop impl on ComputeMapState
    // handles returning TaskState to pool.
    state: Option<ComputeMapState<U>>,
}

// Manual Unpin implementation - ComputeMap is Unpin if S is Unpin.
// (Safe impl, no unsafe involved: `poll_next` only ever uses
// `Pin::new(&mut self.stream)`, which itself requires `S: Unpin`.)
impl<S: Unpin, F, U: Send + 'static> Unpin for ComputeMap<S, F, U> {}
147
/// Internal state for ComputeMap, initialized on first poll.
///
/// The Drop impl returns the TaskState to the pool when the stream is dropped.
struct ComputeMapState<U: Send + 'static> {
    // Runtime handle; provides the rayon pool compute tasks are spawned on.
    runtime: Arc<LoomRuntimeInner>,
    // Typed pool this stream's TaskState came from (and is returned to).
    pool: Arc<TypedPool<U>>,
    /// Reused TaskState - no per-item pool operations!
    task_state: Arc<TaskState<U>>,
    /// Currently pending compute task, if any
    pending: Option<PooledRayonTask<U>>,
}
159
160impl<U: Send + 'static> Drop for ComputeMapState<U> {
161 fn drop(&mut self) {
162 // Return TaskState to pool if there's no pending task
163 // (which would still be using it)
164 if self.pending.is_none() {
165 self.task_state.reset();
166 // Clone the Arc before pushing - we need ownership
167 let task_state = Arc::clone(&self.task_state);
168 self.pool.push(task_state);
169 }
170 // If there's a pending task, the TaskState will be dropped with it
171 // This is a rare edge case (stream dropped while compute in flight)
172 }
173}
174
175impl<S, F, U> ComputeMap<S, F, U>
176where
177 U: Send + 'static,
178{
179 fn new(stream: S, f: F) -> Self {
180 Self {
181 stream,
182 f: Arc::new(f),
183 state: None,
184 }
185 }
186}
187
188impl<S, F, U> Stream for ComputeMap<S, F, U>
189where
190 S: Stream + Unpin,
191 S::Item: Send + 'static,
192 F: Fn(S::Item) -> U + Send + Sync + 'static,
193 U: Send + 'static,
194{
195 type Item = U;
196
197 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198 let this = &mut *self;
199
200 // Initialize state on first poll
201 let state = this.state.get_or_insert_with(|| {
202 let runtime = current_runtime().expect("compute_map used outside loom runtime");
203 let pool = runtime.pools.get_or_create::<U>();
204 let task_state = pool.pop().unwrap_or_else(|| Arc::new(TaskState::new()));
205
206 ComputeMapState {
207 runtime,
208 pool,
209 task_state,
210 pending: None,
211 }
212 });
213
214 // If we have a pending task, poll it first
215 if let Some(ref mut pending) = state.pending {
216 match Pin::new(pending).poll(cx) {
217 Poll::Ready(result) => {
218 // Task complete, clear pending
219 state.pending = None;
220 // Reset state for reuse
221 state.task_state.reset();
222 return Poll::Ready(Some(result));
223 }
224 Poll::Pending => {
225 return Poll::Pending;
226 }
227 }
228 }
229
230 // No pending task, poll the inner stream for the next item
231 match Pin::new(&mut this.stream).poll_next(cx) {
232 Poll::Ready(Some(item)) => {
233 // Got an item, spawn compute task
234 let f = Arc::clone(&this.f);
235 let task_state = Arc::clone(&state.task_state);
236
237 // Create the pooled task components
238 let (task, completion): (PooledRayonTask<U>, PooledTaskCompletion<U>) = {
239 // Reuse the same TaskState (already have it)
240 let (task, completion, _state_for_return) = PooledRayonTask::new(task_state);
241 (task, completion)
242 };
243
244 // Spawn on rayon
245 state.runtime.rayon_pool.spawn(move || {
246 let result = f(item);
247 completion.complete(result);
248 });
249
250 // Store pending task and poll it immediately
251 state.pending = Some(task);
252
253 // Poll the pending task - it might already be ready
254 if let Some(ref mut pending) = state.pending {
255 match Pin::new(pending).poll(cx) {
256 Poll::Ready(result) => {
257 state.pending = None;
258 state.task_state.reset();
259 Poll::Ready(Some(result))
260 }
261 Poll::Pending => Poll::Pending,
262 }
263 } else {
264 Poll::Pending
265 }
266 }
267 Poll::Ready(None) => {
268 // Stream exhausted
269 Poll::Ready(None)
270 }
271 Poll::Pending => Poll::Pending,
272 }
273 }
274
275 fn size_hint(&self) -> (usize, Option<usize>) {
276 // We produce the same number of items as the inner stream
277 // Adjust for pending task if any
278 let (lower, upper) = self.stream.size_hint();
279 if self.state.as_ref().is_some_and(|s| s.pending.is_some()) {
280 // We have a pending item that will be produced
281 (lower.saturating_add(1), upper.map(|u| u.saturating_add(1)))
282 } else {
283 (lower, upper)
284 }
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LoomConfig;
    use crate::pool::DEFAULT_POOL_SIZE;
    use crate::runtime::LoomRuntime;
    use futures::stream::{self, StreamExt};

    /// Minimal config for these tests: one tokio thread, two rayon threads.
    fn test_config() -> LoomConfig {
        LoomConfig {
            prefix: "stream-test".to_string(),
            cpuset: None,
            tokio_threads: Some(1),
            rayon_threads: Some(2),
            compute_pool_size: DEFAULT_POOL_SIZE,
            #[cfg(feature = "cuda")]
            cuda_device: None,
        }
    }

    /// Build a runtime from the test config; panics if construction fails.
    fn test_runtime() -> LoomRuntime {
        LoomRuntime::from_config(test_config(), DEFAULT_POOL_SIZE).unwrap()
    }

    #[test]
    fn test_compute_map_basic() {
        test_runtime().block_on(async {
            let doubled: Vec<_> = stream::iter(0..10).compute_map(|n| n * 2).collect().await;
            assert_eq!(doubled, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        });
    }

    #[test]
    fn test_compute_map_preserves_order() {
        test_runtime().block_on(async {
            // Compute time is proportional to the value, so completion order
            // would differ from input order if ordering were not enforced.
            let input = vec![5, 1, 3, 2, 4];
            let output: Vec<_> = stream::iter(input)
                .compute_map(|n| {
                    std::thread::sleep(std::time::Duration::from_micros(n as u64 * 10));
                    n * 10
                })
                .collect()
                .await;
            assert_eq!(output, vec![50, 10, 30, 20, 40]);
        });
    }

    #[test]
    fn test_compute_map_empty_stream() {
        test_runtime().block_on(async {
            let out: Vec<i32> = stream::iter(std::iter::empty::<i32>())
                .compute_map(|n| n * 2)
                .collect()
                .await;
            assert_eq!(out.len(), 0);
        });
    }

    #[test]
    fn test_compute_map_single_item() {
        test_runtime().block_on(async {
            let out: Vec<_> = stream::iter(vec![42]).compute_map(|n| n + 1).collect().await;
            assert_eq!(out, vec![43]);
        });
    }

    #[test]
    fn test_compute_map_with_strings() {
        test_runtime().block_on(async {
            let upper: Vec<_> = stream::iter(vec!["hello", "world"])
                .compute_map(|s| s.to_uppercase())
                .collect()
                .await;
            assert_eq!(upper, vec!["HELLO", "WORLD"]);
        });
    }

    #[test]
    fn test_compute_map_type_conversion() {
        test_runtime().block_on(async {
            // Input i32 range, output formatted Strings.
            let labels: Vec<_> = stream::iter(1..=5)
                .compute_map(|n| format!("item-{}", n))
                .collect()
                .await;
            let expected: Vec<String> = (1..=5).map(|n| format!("item-{}", n)).collect();
            assert_eq!(labels, expected);
        });
    }

    #[test]
    fn test_compute_map_cpu_intensive() {
        test_runtime().block_on(async {
            // A small sum-of-products workload exercises real computation
            // on the rayon side; the same closure computes the expectation.
            let work = |n: i64| (0..1000).map(|i| i * n).sum::<i64>();
            let got: Vec<_> = stream::iter(0..5).compute_map(work).collect().await;
            let want: Vec<i64> = (0..5).map(work).collect();
            assert_eq!(got, want);
        });
    }

    #[test]
    fn test_compute_map_size_hint() {
        test_runtime().block_on(async {
            let s = stream::iter(0..10).compute_map(|n| n * 2);
            assert_eq!(s.size_hint(), (10, Some(10)));
        });
    }

    #[test]
    fn test_compute_map_chained() {
        test_runtime().block_on(async {
            // compute_map composes with ordinary stream combinators.
            let big: Vec<_> = stream::iter(0..10)
                .compute_map(|n| n * 2)
                .filter(|v| futures::future::ready(*v > 10))
                .collect()
                .await;
            assert_eq!(big, vec![12, 14, 16, 18]);
        });
    }
}
427}