1use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40use std::time::Instant;
41
42use futures_core::Stream;
43
44use crate::bridge::{PooledRayonTask, PooledTaskCompletion, TaskState};
45use crate::context::current_runtime;
46use crate::mab::{Arm, ComputeHintProvider, DecisionId, FunctionKey, MabScheduler};
47use crate::pool::TypedPool;
48use crate::runtime::LoomRuntimeInner;
49
/// Extension trait adding CPU-offload combinators to any [`Stream`].
///
/// Both combinators keep at most one item in flight at a time, so output
/// order always matches input order.
pub trait ComputeStreamExt: Stream {
    /// Map each item through `f`, running `f` on the runtime's rayon pool.
    ///
    /// The returned stream panics on first poll if used outside a loom
    /// runtime (see `current_runtime` in `poll_next`).
    fn compute_map<F, U>(self, f: F) -> ComputeMap<Self, F, U>
    where
        Self: Sized,
        F: Fn(Self::Item) -> U + Send + Sync + 'static,
        Self::Item: Send + 'static,
        U: Send + 'static;

    /// Like `compute_map`, but lets the MAB scheduler decide per item
    /// whether to run `f` inline or offload it to rayon, using each item's
    /// [`ComputeHintProvider`] hint to inform the choice.
    fn adaptive_map<F, U>(self, f: F) -> AdaptiveMap<Self, F, U>
    where
        Self: Sized,
        F: Fn(Self::Item) -> U + Send + Sync + 'static,
        Self::Item: ComputeHintProvider + Send + 'static,
        U: Send + 'static;
}
162
/// Blanket implementation: every `Stream` gets the combinators for free.
impl<S: Stream> ComputeStreamExt for S {
    fn compute_map<F, U>(self, f: F) -> ComputeMap<Self, F, U>
    where
        Self: Sized,
        F: Fn(Self::Item) -> U + Send + Sync + 'static,
        Self::Item: Send + 'static,
        U: Send + 'static,
    {
        ComputeMap::new(self, f)
    }

    fn adaptive_map<F, U>(self, f: F) -> AdaptiveMap<Self, F, U>
    where
        Self: Sized,
        F: Fn(Self::Item) -> U + Send + Sync + 'static,
        Self::Item: ComputeHintProvider + Send + 'static,
        U: Send + 'static,
    {
        AdaptiveMap::new(self, f)
    }
}
184
/// Stream adapter created by [`ComputeStreamExt::compute_map`].
///
/// Processes one item at a time on the runtime's rayon pool, yielding
/// results in input order.
#[must_use = "streams do nothing unless polled"]
pub struct ComputeMap<S, F, U>
where
    U: Send + 'static,
{
    // The inner stream items are pulled from.
    stream: S,
    // User closure, shared so it can be moved into rayon worker closures.
    f: Arc<F>,
    // Runtime-bound state; `None` until first poll (lazy init inside a runtime).
    state: Option<ComputeMapState<U>>,
}

// `ComputeMap` never pins its fields structurally (`poll_next` only uses
// `Pin::new` on the `Unpin` inner stream), so it is `Unpin` whenever `S` is.
impl<S: Unpin, F, U: Send + 'static> Unpin for ComputeMap<S, F, U> {}
203
/// Per-stream state created lazily on first poll of [`ComputeMap`].
struct ComputeMapState<U: Send + 'static> {
    // Handle to the runtime whose rayon pool executes the closures.
    runtime: Arc<LoomRuntimeInner>,
    // Pool the `task_state` was taken from and is returned to on drop.
    pool: Arc<TypedPool<U>>,
    // Reusable task state shared with the in-flight rayon task.
    task_state: Arc<TaskState<U>>,
    // The currently in-flight task, if any (at most one at a time).
    pending: Option<PooledRayonTask<U>>,
}
215
216impl<U: Send + 'static> Drop for ComputeMapState<U> {
217 fn drop(&mut self) {
218 if self.pending.is_none() && !std::thread::panicking() {
222 self.task_state.reset();
223 let task_state = Arc::clone(&self.task_state);
225 self.pool.push(task_state);
226 }
227 }
232}
233
234impl<S, F, U> ComputeMap<S, F, U>
235where
236 U: Send + 'static,
237{
238 fn new(stream: S, f: F) -> Self {
239 Self {
240 stream,
241 f: Arc::new(f),
242 state: None,
243 }
244 }
245}
246
247impl<S, F, U> Stream for ComputeMap<S, F, U>
248where
249 S: Stream + Unpin,
250 S::Item: Send + 'static,
251 F: Fn(S::Item) -> U + Send + Sync + 'static,
252 U: Send + 'static,
253{
254 type Item = U;
255
256 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
257 let this = &mut *self;
258
259 let state = this.state.get_or_insert_with(|| {
261 let runtime = current_runtime().expect("compute_map used outside loom runtime");
262 let pool = runtime.inner.pools.get_or_create::<U>();
263 let task_state = pool.pop().unwrap_or_else(|| Arc::new(TaskState::new()));
264
265 ComputeMapState {
266 runtime: runtime.inner,
267 pool,
268 task_state,
269 pending: None,
270 }
271 });
272
273 if let Some(ref mut pending) = state.pending {
275 match Pin::new(pending).poll(cx) {
276 Poll::Ready(result) => {
277 state.pending = None;
279 state.task_state.reset();
281 return Poll::Ready(Some(result));
282 }
283 Poll::Pending => {
284 return Poll::Pending;
285 }
286 }
287 }
288
289 match Pin::new(&mut this.stream).poll_next(cx) {
291 Poll::Ready(Some(item)) => {
292 let f = Arc::clone(&this.f);
294 let task_state = Arc::clone(&state.task_state);
295
296 let (task, completion): (PooledRayonTask<U>, PooledTaskCompletion<U>) = {
298 let (task, completion, _state_for_return) = PooledRayonTask::new(task_state);
300 (task, completion)
301 };
302
303 state.runtime.rayon_pool.spawn(move || {
305 let result = f(item);
306 completion.complete(result);
307 });
308
309 state.pending = Some(task);
311
312 if let Some(ref mut pending) = state.pending {
314 match Pin::new(pending).poll(cx) {
315 Poll::Ready(result) => {
316 state.pending = None;
317 state.task_state.reset();
318 Poll::Ready(Some(result))
319 }
320 Poll::Pending => Poll::Pending,
321 }
322 } else {
323 Poll::Pending
324 }
325 }
326 Poll::Ready(None) => {
327 Poll::Ready(None)
329 }
330 Poll::Pending => Poll::Pending,
331 }
332 }
333
334 fn size_hint(&self) -> (usize, Option<usize>) {
335 let (lower, upper) = self.stream.size_hint();
338 if self.state.as_ref().is_some_and(|s| s.pending.is_some()) {
339 (lower.saturating_add(1), upper.map(|u| u.saturating_add(1)))
341 } else {
342 (lower, upper)
343 }
344 }
345}
346
/// Stream adapter created by [`ComputeStreamExt::adaptive_map`].
///
/// For each item, a MAB scheduler chooses between running the closure
/// inline ([`Arm::InlineTokio`]) or offloading it to rayon
/// ([`Arm::OffloadRayon`]), and is fed the measured latency afterwards.
#[must_use = "streams do nothing unless polled"]
pub struct AdaptiveMap<S, F, U>
where
    U: Send + 'static,
{
    // The inner stream items are pulled from.
    stream: S,
    // User closure, shared so it can be moved into rayon worker closures.
    f: Arc<F>,
    // Stable identity derived from `F`'s type; keys the scheduler's stats.
    function_key: FunctionKey,
    // Runtime-bound state; `None` until first poll (lazy init inside a runtime).
    state: Option<AdaptiveMapState<U>>,
}

// No field is structurally pinned (`poll_next` only uses `Pin::new` on the
// `Unpin` inner stream), so `AdaptiveMap` is `Unpin` whenever `S` is.
impl<S: Unpin, F, U: Send + 'static> Unpin for AdaptiveMap<S, F, U> {}
368
/// An offloaded task awaiting completion, plus the bookkeeping needed to
/// report its latency back to the scheduler when it finishes.
struct AdaptivePending<U: Send + 'static> {
    // Scheduler decision this task's outcome will be attributed to.
    decision_id: DecisionId,
    // When the task was spawned; used to compute elapsed microseconds.
    start_time: Instant,
    // The in-flight rayon task.
    task: PooledRayonTask<U>,
}
375
/// Per-stream state created lazily on first poll of [`AdaptiveMap`].
struct AdaptiveMapState<U: Send + 'static> {
    // Handle to the runtime whose rayon pool executes offloaded closures.
    runtime: Arc<LoomRuntimeInner>,
    // Pool the `task_state` was taken from and is returned to on drop.
    pool: Arc<TypedPool<U>>,
    // Reusable task state shared with the in-flight rayon task.
    task_state: Arc<TaskState<U>>,
    // Bandit scheduler choosing inline vs. offload per item.
    scheduler: MabScheduler,
    // The currently in-flight offloaded task, if any (at most one at a time).
    pending: Option<AdaptivePending<U>>,
}
387
388impl<U: Send + 'static> Drop for AdaptiveMapState<U> {
389 fn drop(&mut self) {
390 if self.pending.is_none() && !std::thread::panicking() {
394 self.task_state.reset();
395 let task_state = Arc::clone(&self.task_state);
396 self.pool.push(task_state);
397 }
398 }
403}
404
405impl<S, F: 'static, U> AdaptiveMap<S, F, U>
406where
407 U: Send + 'static,
408{
409 fn new(stream: S, f: F) -> Self {
410 let function_key = FunctionKey::from_type::<F>();
413
414 Self {
415 stream,
416 f: Arc::new(f),
417 function_key,
418 state: None,
419 }
420 }
421}
422
impl<S, F, U> Stream for AdaptiveMap<S, F, U>
where
    S: Stream + Unpin,
    S::Item: ComputeHintProvider + Send + 'static,
    F: Fn(S::Item) -> U + Send + Sync + 'static,
    U: Send + 'static,
{
    type Item = U;

    /// Poll order: (1) lazily create runtime-bound state, (2) drain any
    /// in-flight offloaded task (reporting its latency to the scheduler),
    /// (3) pull the next item, ask the scheduler for an arm, and either run
    /// the closure inline or offload it to rayon. Only one item is
    /// processed at a time, which preserves input order.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        let state = this.state.get_or_insert_with(|| {
            // Panics outside a loom runtime by design — see trait docs.
            let runtime = current_runtime().expect("adaptive_map used outside loom runtime");
            let pool = runtime.inner.pools.get_or_create::<U>();
            // Reuse a pooled task state when available to avoid allocating.
            let task_state = pool.pop().unwrap_or_else(|| Arc::new(TaskState::new()));

            let scheduler = MabScheduler::with_metrics(
                runtime.inner.mab_knobs.clone(),
                runtime.inner.prometheus_metrics.clone(),
            );

            AdaptiveMapState {
                runtime: runtime.inner,
                pool,
                task_state,
                scheduler,
                pending: None,
            }
        });

        // Finish the in-flight offloaded task before pulling another item.
        if let Some(mut pending) = state.pending.take() {
            match Pin::new(&mut pending.task).poll(cx) {
                Poll::Ready(result) => {
                    // Report offload latency (in microseconds) so the
                    // scheduler can learn from this decision.
                    let elapsed_us = pending.start_time.elapsed().as_secs_f64() * 1_000_000.0;
                    state
                        .scheduler
                        .finish(pending.decision_id, elapsed_us, None);
                    state.task_state.reset();
                    return Poll::Ready(Some(result));
                }
                Poll::Pending => {
                    // Still running: put it back (it was `take`n above).
                    state.pending = Some(pending);
                    return Poll::Pending;
                }
            }
        }

        match Pin::new(&mut this.stream).poll_next(cx) {
            Poll::Ready(Some(item)) => {
                let hint = item.compute_hint();

                // Snapshot runtime load context for the scheduling decision.
                let ctx = state.runtime.prometheus_metrics.collect_context(
                    state.runtime.tokio_threads as u32,
                    state.runtime.rayon_threads as u32,
                );

                let (decision_id, arm) =
                    state
                        .scheduler
                        .choose_with_hint(this.function_key, &ctx, hint);

                let f = Arc::clone(&this.f);

                match arm {
                    Arm::InlineTokio => {
                        // Run synchronously on the current (tokio) thread
                        // and time it for the scheduler's feedback.
                        let start = Instant::now();
                        let result = f(item);
                        let elapsed_us = start.elapsed().as_secs_f64() * 1_000_000.0;

                        state
                            .scheduler
                            .finish(decision_id, elapsed_us, Some(elapsed_us));
                        Poll::Ready(Some(result))
                    }
                    Arm::OffloadRayon => {
                        let task_state = Arc::clone(&state.task_state);
                        let (task, completion): (PooledRayonTask<U>, PooledTaskCompletion<U>) = {
                            let (task, completion, _state_for_return) =
                                PooledRayonTask::new(task_state);
                            (task, completion)
                        };

                        let start_time = Instant::now();

                        // Run the closure on the rayon pool; `completion`
                        // delivers the result back to `task`.
                        state.runtime.rayon_pool.spawn(move || {
                            let result = f(item);
                            completion.complete(result);
                        });

                        let mut pending = AdaptivePending {
                            decision_id,
                            start_time,
                            task,
                        };

                        // Poll once immediately: a fast closure may already
                        // be done, avoiding an extra wakeup round-trip.
                        match Pin::new(&mut pending.task).poll(cx) {
                            Poll::Ready(result) => {
                                let elapsed_us =
                                    pending.start_time.elapsed().as_secs_f64() * 1_000_000.0;
                                state
                                    .scheduler
                                    .finish(pending.decision_id, elapsed_us, None);
                                state.task_state.reset();
                                Poll::Ready(Some(result))
                            }
                            Poll::Pending => {
                                state.pending = Some(pending);
                                Poll::Pending
                            }
                        }
                    }
                }
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Inner stream's hint, plus one when a task is in flight (its result
    /// has been consumed from the inner stream but not yet yielded).
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.stream.size_hint();
        if self.state.as_ref().is_some_and(|s| s.pending.is_some()) {
            (lower.saturating_add(1), upper.map(|u| u.saturating_add(1)))
        } else {
            (lower, upper)
        }
    }
}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LoomConfig;
    use crate::pool::DEFAULT_POOL_SIZE;
    use crate::runtime::LoomRuntime;
    use futures::stream::{self, StreamExt};

    // Minimal runtime config: 1 tokio thread + 2 rayon threads keeps tests
    // light while still exercising the tokio -> rayon handoff.
    fn test_config() -> LoomConfig {
        LoomConfig {
            prefix: "stream-test".to_string(),
            cpuset: None,
            tokio_threads: Some(1),
            rayon_threads: Some(2),
            compute_pool_size: DEFAULT_POOL_SIZE,
            #[cfg(feature = "cuda")]
            cuda_device: None,
            mab_knobs: None,
            calibration: None,
            prometheus_registry: None,
        }
    }

    // Fresh runtime per test; `unwrap` is fine here — failure is a test bug.
    fn test_runtime() -> LoomRuntime {
        LoomRuntime::from_config(test_config()).unwrap()
    }

    // compute_map applies the closure to every item.
    #[test]
    fn test_compute_map_basic() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(0..10).compute_map(|n| n * 2).collect().await;
            assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        });
    }

    // Items with varying processing times still come out in input order,
    // because compute_map keeps only one task in flight at a time.
    #[test]
    fn test_compute_map_preserves_order() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(vec![5, 1, 3, 2, 4])
                .compute_map(|n| {
                    std::thread::sleep(std::time::Duration::from_micros(n as u64 * 10));
                    n * 10
                })
                .collect()
                .await;
            assert_eq!(results, vec![50, 10, 30, 20, 40]);
        });
    }

    // An empty input stream yields an empty output stream.
    #[test]
    fn test_compute_map_empty_stream() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<i32> = stream::iter(std::iter::empty::<i32>())
                .compute_map(|n| n * 2)
                .collect()
                .await;
            assert!(results.is_empty());
        });
    }

    // Single-item boundary case.
    #[test]
    fn test_compute_map_single_item() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(vec![42])
                .compute_map(|n| n + 1)
                .collect()
                .await;
            assert_eq!(results, vec![43]);
        });
    }

    // Non-Copy (heap-allocated) item and output types work.
    #[test]
    fn test_compute_map_with_strings() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(vec!["hello", "world"])
                .compute_map(|s| s.to_uppercase())
                .collect()
                .await;
            assert_eq!(results, vec!["HELLO", "WORLD"]);
        });
    }

    // The closure may change the item type (i32 -> String here).
    #[test]
    fn test_compute_map_type_conversion() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(1..=5)
                .compute_map(|n| format!("item-{}", n))
                .collect()
                .await;
            assert_eq!(
                results,
                vec!["item-1", "item-2", "item-3", "item-4", "item-5"]
            );
        });
    }

    // A heavier closure (summation loop) produces the same result as the
    // equivalent sequential computation.
    #[test]
    fn test_compute_map_cpu_intensive() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(0..5)
                .compute_map(|n| (0..1000).map(|i| i * n).sum::<i64>())
                .collect()
                .await;

            let expected: Vec<i64> = (0..5).map(|n| (0..1000).map(|i| i * n).sum()).collect();
            assert_eq!(results, expected);
        });
    }

    // Before any poll, size_hint just forwards the inner stream's hint
    // (no pending task exists yet).
    #[test]
    fn test_compute_map_size_hint() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let stream = stream::iter(0..10).compute_map(|n| n * 2);
            assert_eq!(stream.size_hint(), (10, Some(10)));
        });
    }

    // compute_map composes with other StreamExt combinators.
    #[test]
    fn test_compute_map_chained() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(0..10)
                .compute_map(|n| n * 2)
                .filter(|n| futures::future::ready(*n > 10))
                .collect()
                .await;
            assert_eq!(results, vec![12, 14, 16, 18]);
        });
    }

    // adaptive_map produces the same results as compute_map regardless of
    // which arm the scheduler picks per item.
    #[test]
    fn test_adaptive_map_basic() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(0..10).adaptive_map(|n| n * 2).collect().await;
            assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        });
    }

    // Ordering holds for adaptive_map too, even when items take different
    // amounts of time.
    #[test]
    fn test_adaptive_map_preserves_order() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(vec![5, 1, 3, 2, 4])
                .adaptive_map(|n| {
                    std::thread::sleep(std::time::Duration::from_micros(n as u64 * 10));
                    n * 10
                })
                .collect()
                .await;
            assert_eq!(results, vec![50, 10, 30, 20, 40]);
        });
    }

    // Empty input boundary case for adaptive_map.
    #[test]
    fn test_adaptive_map_empty_stream() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<i32> = stream::iter(std::iter::empty::<i32>())
                .adaptive_map(|n| n * 2)
                .collect()
                .await;
            assert!(results.is_empty());
        });
    }

    // Items carrying explicit ComputeHints are mapped correctly; hints
    // influence scheduling, never the result values.
    #[test]
    fn test_adaptive_map_with_hint() {
        use crate::mab::{ComputeHint, ComputeHintProvider};

        struct HintedItem {
            value: i32,
            hint: ComputeHint,
        }

        impl ComputeHintProvider for HintedItem {
            fn compute_hint(&self) -> ComputeHint {
                self.hint
            }
        }

        let runtime = test_runtime();
        runtime.block_on(async {
            let items = vec![
                HintedItem {
                    value: 1,
                    hint: ComputeHint::Low,
                },
                HintedItem {
                    value: 2,
                    hint: ComputeHint::High,
                },
                HintedItem {
                    value: 3,
                    hint: ComputeHint::Medium,
                },
            ];

            let results: Vec<_> = stream::iter(items)
                .adaptive_map(|item| item.value * 2)
                .collect()
                .await;

            assert_eq!(results, vec![2, 4, 6]);
        });
    }

    // Many cheap items: only checks correctness of the output, not which
    // arm the scheduler converges to.
    #[test]
    fn test_adaptive_map_learns_from_fast_work() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let results: Vec<_> = stream::iter(0..100)
                .adaptive_map(|n| {
                    n + 1
                })
                .collect()
                .await;

            assert_eq!(results.len(), 100);
            assert_eq!(results[0], 1);
            assert_eq!(results[99], 100);
        });
    }

    // Before any poll, adaptive_map's size_hint forwards the inner hint.
    #[test]
    fn test_adaptive_map_size_hint() {
        let runtime = test_runtime();
        runtime.block_on(async {
            let stream = stream::iter(0..10).adaptive_map(|n| n * 2);
            assert_eq!(stream.size_hint(), (10, Some(10)));
        });
    }
}