//! Multi-stream task scheduler — `cubecl_runtime/stream/scheduler.rs`.

1use crate::{
2    config::streaming::StreamingLogLevel,
3    logging::ServerLogger,
4    stream::{StreamFactory, StreamPool},
5};
6use alloc::{format, sync::Arc, vec, vec::Vec};
7use cubecl_common::stream_id::StreamId;
8
/// Defines a trait for a scheduler stream backend, specifying the types and behavior for task scheduling.
///
/// Implementors provide the concrete task/stream types plus the factory used by the
/// scheduler's [`StreamPool`] to lazily create streams.
pub trait SchedulerStreamBackend {
    /// Type representing a task queued for execution.
    type Task: core::fmt::Debug;
    /// Type representing an execution stream owned by the backend.
    type Stream: core::fmt::Debug;
    /// Type for the stream factory, which creates streams of type `Self::Stream`.
    type Factory: StreamFactory<Stream = Self::Stream>;

    /// Enqueues a task onto a given stream for execution.
    fn enqueue(task: Self::Task, stream: &mut Self::Stream);
    /// Flush the inner stream queue to ensure ordering between different streams.
    fn flush(stream: &mut Self::Stream);
    /// Returns a mutable reference to the stream factory.
    fn factory(&mut self) -> &mut Self::Factory;
}
25
/// Represents a multi-stream scheduler that manages task execution across multiple streams.
#[derive(Debug)]
pub struct SchedulerMultiStream<B: SchedulerStreamBackend> {
    /// Pool of streams managed by the scheduler; streams are created lazily by the marker's factory.
    pool: StreamPool<SchedulerPoolMarker<B>>,
    /// Strategy for scheduling tasks (e.g., `Interleave` or `Sequential`).
    strategy: SchedulerStrategy,
    /// Maximum number of tasks allowed per stream before execution is triggered.
    max_tasks: usize,
    /// Server logger used to report stream flushes and shared-binding events.
    pub logger: Arc<ServerLogger>,
}
38
/// Defines the scheduling strategy for task execution.
#[derive(Debug)]
pub enum SchedulerStrategy {
    /// Tasks from different streams are interleaved during execution.
    Interleave,
    /// Tasks from each stream are executed sequentially, one stream at a time.
    Sequential,
}
47
/// Represents a single stream that holds queued tasks and a backend stream.
#[derive(Debug)]
pub struct Stream<B: SchedulerStreamBackend> {
    /// List of tasks queued for execution in this stream (drained by `flush`).
    tasks: Vec<B::Task>,
    /// The backend stream used for task execution.
    stream: B::Stream,
}
56
57impl<B: SchedulerStreamBackend> Stream<B> {
58    /// Flushes all tasks from the stream, returning them and clearing the internal task list.
59    fn flush(&mut self) -> Vec<B::Task> {
60        let mut returned = Vec::with_capacity(self.tasks.capacity());
61        core::mem::swap(&mut returned, &mut self.tasks);
62        returned
63    }
64}
65
/// Thin wrapper around the backend so the scheduler's pool can act as a `StreamFactory`.
#[derive(Debug)]
struct SchedulerPoolMarker<B: SchedulerStreamBackend> {
    // The backend whose factory is used to create new streams.
    backend: B,
}
70
impl<B: SchedulerStreamBackend> StreamFactory for SchedulerPoolMarker<B> {
    /// The type of stream produced by this factory.
    type Stream = Stream<B>;

    /// Creates a new stream with an empty task list and a freshly created backend stream.
    fn create(&mut self) -> Self::Stream {
        Stream {
            tasks: Vec::new(),
            // Delegates to the wrapped backend's factory to create the actual compute stream.
            stream: self.backend.factory().create(),
        }
    }
}
84
/// Options for configuring a `SchedulerMultiStream`.
#[derive(Debug)]
pub struct SchedulerMultiStreamOptions {
    /// Maximum number of streams allowed in the pool.
    pub max_streams: u8,
    /// Maximum number of tasks per stream before execution is triggered.
    pub max_tasks: usize,
    /// The scheduling strategy to use.
    pub strategy: SchedulerStrategy,
}
95
impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
    /// Creates a new `SchedulerMultiStream` with the given backend and options.
    pub fn new(
        logger: Arc<ServerLogger>,
        backend: B,
        options: SchedulerMultiStreamOptions,
    ) -> Self {
        Self {
            pool: StreamPool::new(SchedulerPoolMarker { backend }, options.max_streams, 0),
            max_tasks: options.max_tasks,
            strategy: options.strategy,
            logger,
        }
    }

    /// Returns a mutable reference to the backend stream for a given stream ID.
    pub fn stream(&mut self, stream_id: &StreamId) -> &mut B::Stream {
        let stream = self.pool.get_mut(stream_id);
        &mut stream.stream
    }

    /// Registers a task for execution on a specific stream, ensuring stream alignment.
    ///
    /// `args_streams` lists the streams owning the task's arguments; any of those that
    /// differ from `stream_id` are flushed first so cross-stream dependencies are ordered.
    pub fn register(&mut self, stream_id: StreamId, task: B::Task, args_streams: &[StreamId]) {
        // Align streams to ensure dependencies are handled correctly.
        self.align_streams(stream_id, args_streams);

        // Get the stream for the given stream ID and add the task to its queue.
        let current = self.pool.get_mut(&stream_id);
        current.tasks.push(task);

        // If the task queue reaches the maximum, execute the stream.
        if current.tasks.len() >= self.max_tasks {
            self.execute_streams(vec![stream_id]);
        }
    }

    /// Aligns streams by flushing tasks from streams that conflict with the given bindings.
    pub(crate) fn align_streams(&mut self, stream_id: StreamId, args_streams: &[StreamId]) {
        let mut to_flush = Vec::new();
        // Get the index of the target stream.
        let index = self.pool.stream_index(&stream_id);

        // Identify streams that need to be flushed due to conflicting bindings
        // (i.e. argument streams mapped to a different pool slot than the target).
        for arg_stream in args_streams {
            let index_stream = self.pool.stream_index(arg_stream);
            if index != index_stream {
                to_flush.push(*arg_stream);

                self.logger.log_streaming(
                    |level| matches!(level, StreamingLogLevel::Full),
                    || format!("Binding on {} is shared on {}", arg_stream, stream_id),
                );
            }
        }

        // If no streams need flushing, return early.
        if to_flush.is_empty() {
            return;
        }

        self.logger.log_streaming(
            |level| !matches!(level, StreamingLogLevel::Disabled),
            || {
                format!(
                    "Flushing streams {to_flush:?} before registering more tasks on {stream_id}"
                )
            },
        );
        // Execute the streams that need to be flushed.
        self.execute_streams(to_flush);
    }

    /// Executes tasks from the specified streams based on the scheduling strategy.
    pub fn execute_streams(&mut self, stream_ids: Vec<StreamId>) {
        let mut indices = Vec::with_capacity(stream_ids.len());

        // Collect unique stream indices to avoid redundant processing
        // (several ids can map onto the same pool slot).
        for id in stream_ids {
            let index = self.pool.stream_index(&id);
            if !indices.contains(&index) {
                indices.push(index);
            }
        }

        // Create schedules for each stream to be executed.
        let mut schedules = Vec::new();
        for index in indices {
            // SAFETY: `index` was produced by `self.pool.stream_index`, which is
            // expected to map ids onto valid pool slots — TODO confirm with StreamPool docs.
            let stream = unsafe { self.pool.get_mut_index(index) };
            let tasks = stream.flush();
            let num_tasks = tasks.len();

            schedules.push(Schedule {
                tasks: tasks.into_iter(),
                num_tasks,
                stream_index: index,
            });
        }

        // If no schedules were created, return early.
        if schedules.is_empty() {
            return;
        }

        // Execute schedules based on the configured strategy.
        match self.strategy {
            SchedulerStrategy::Interleave => self.execute_schedules_interleave(schedules),
            SchedulerStrategy::Sequential => self.execute_schedules_sequence(schedules),
        }
    }

    /// Executes schedules sequentially, processing each stream's tasks in order.
    fn execute_schedules_sequence(&mut self, schedules: Vec<Schedule<B>>) {
        for schedule in schedules {
            // SAFETY: `stream_index` was obtained from `self.pool.stream_index` when
            // the schedule was built in `execute_streams`, so it refers to a pool slot.
            let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
            for task in schedule.tasks {
                // Enqueue each task on the stream.
                B::enqueue(task, &mut stream.stream);
            }

            // Makes sure the tasks are ordered on the compute queue.
            B::flush(&mut stream.stream);
        }
    }

    /// Executes schedules in an interleaved manner, alternating tasks from different streams.
    ///
    /// We chose the first stream as the one executing the tasks, ensuring proper ordering by
    /// flushing all other streams first and flushing the execution stream at the end.
    /// This way, we ensure that most tasks are actually interleaved on the real compute queue
    /// shared across all streams.
    fn execute_schedules_interleave(&mut self, mut schedules: Vec<Schedule<B>>) {
        // Makes sure the tasks are ordered on the compute queue.
        for schedule in schedules.iter_mut().skip(1) {
            // SAFETY: `stream_index` originates from `self.pool.stream_index` in
            // `execute_streams`, so it refers to a pool slot.
            let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
            B::flush(&mut stream.stream);
        }

        let execution_index = schedules.first().expect("At least one stream").stream_index;
        // SAFETY: same provenance as above — the index came from `stream_index`.
        let stream = unsafe { self.pool.get_mut_index(execution_index) };

        // Find the maximum number of tasks across all schedules.
        let num_tasks_max = schedules
            .iter()
            .map(|s| s.num_tasks)
            .max()
            .expect("At least one schedule");

        // Iterate through tasks round-robin, interleaving one task per schedule each pass.
        for _ in 0..num_tasks_max {
            for schedule in schedules.iter_mut() {
                // If there are tasks remaining in the schedule, enqueue the next one.
                if let Some(task) = schedule.tasks.next() {
                    B::enqueue(task, &mut stream.stream);
                }
            }
        }

        // Making sure all tasks are registered to the queue.
        B::flush(&mut stream.stream);
    }
}
257
/// Represents a schedule for executing tasks on a specific stream.
struct Schedule<B: SchedulerStreamBackend> {
    /// Owning iterator over the tasks to be executed.
    tasks: alloc::vec::IntoIter<B::Task>,
    /// Number of tasks originally captured in the schedule.
    num_tasks: usize,
    /// Index of the stream's slot in the pool.
    stream_index: usize,
}