cubecl_runtime/stream/scheduler.rs

use crate::{
    config::streaming::StreamingLogLevel,
    logging::ServerLogger,
    server::Binding,
    stream::{StreamFactory, StreamPool},
};
use alloc::sync::Arc;
use cubecl_common::stream_id::StreamId;

/// Defines a trait for a scheduler stream backend, specifying the types and behavior for task scheduling.
pub trait SchedulerStreamBackend {
    /// Type representing a task.
    type Task: core::fmt::Debug;
    /// Type representing a stream.
    type Stream: core::fmt::Debug;
    /// Type for the stream factory, which creates streams of type `Self::Stream`.
    type Factory: StreamFactory<Stream = Self::Stream>;

    /// Enqueues a task onto a given stream for execution.
    fn enqueue(task: Self::Task, stream: &mut Self::Stream);
    /// Flushes the inner stream queue to ensure ordering between different streams.
    fn flush(stream: &mut Self::Stream);
    /// Returns a mutable reference to the stream factory.
    fn factory(&mut self) -> &mut Self::Factory;
}
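
// Illustrative sketch only: a toy in-memory backend showing how the pieces of
// `SchedulerStreamBackend` fit together. `VecBackend`, `VecFactory`, and the
// `u32` task type are hypothetical names for this example, not part of cubecl;
// a real backend would wrap an actual device queue. Compiled only under tests.
#[cfg(test)]
#[allow(dead_code)]
mod backend_example {
    use super::*;

    #[derive(Debug, Default)]
    struct VecFactory;

    impl StreamFactory for VecFactory {
        type Stream = Vec<u32>;

        fn create(&mut self) -> Self::Stream {
            Vec::new()
        }
    }

    #[derive(Debug, Default)]
    struct VecBackend {
        factory: VecFactory,
    }

    impl SchedulerStreamBackend for VecBackend {
        type Task = u32;
        type Stream = Vec<u32>;
        type Factory = VecFactory;

        fn enqueue(task: Self::Task, stream: &mut Self::Stream) {
            // "Executing" a task here simply records it in submission order.
            stream.push(task);
        }

        fn flush(_stream: &mut Self::Stream) {
            // Nothing to reorder for an in-memory stream.
        }

        fn factory(&mut self) -> &mut Self::Factory {
            &mut self.factory
        }
    }
}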

/// Represents a multi-stream scheduler that manages task execution across multiple streams.
#[derive(Debug)]
pub struct SchedulerMultiStream<B: SchedulerStreamBackend> {
    /// Pool of streams managed by the scheduler.
    pool: StreamPool<SchedulerPoolMarker<B>>,
    /// Strategy for scheduling tasks (e.g., Interleave or Sequential).
    strategy: SchedulerStrategy,
    /// Maximum number of tasks allowed per stream before execution is triggered.
    max_tasks: usize,
    /// Server logger.
    pub logger: Arc<ServerLogger>,
}

/// Defines the scheduling strategy for task execution.
#[derive(Debug)]
pub enum SchedulerStrategy {
    /// Tasks from different streams are interleaved during execution.
    Interleave,
    /// Tasks from each stream are executed sequentially.
    Sequential,
}
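
// Ordering sketch (illustrative): with two streams holding tasks A = [a1, a2, a3]
// and B = [b1, b2], `Sequential` submits a1, a2, a3, b1, b2 to the shared compute
// queue, while `Interleave` submits a1, b1, a2, b2, a3.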

/// Represents a single stream that holds tasks and a backend stream.
#[derive(Debug)]
pub struct Stream<B: SchedulerStreamBackend> {
    /// List of tasks queued for execution in this stream.
    tasks: Vec<B::Task>,
    /// The backend stream used for task execution.
    stream: B::Stream,
}

impl<B: SchedulerStreamBackend> Stream<B> {
    /// Flushes all tasks from the stream, returning them and clearing the internal task list.
    fn flush(&mut self) -> Vec<B::Task> {
        let mut returned = Vec::with_capacity(self.tasks.capacity());
        core::mem::swap(&mut returned, &mut self.tasks);
        returned
    }
}

#[derive(Debug)]
struct SchedulerPoolMarker<B: SchedulerStreamBackend> {
    backend: B,
}

impl<B: SchedulerStreamBackend> StreamFactory for SchedulerPoolMarker<B> {
    // The type of stream produced by this factory.
    type Stream = Stream<B>;

    // Creates a new stream with an empty task list and a backend stream.
    fn create(&mut self) -> Self::Stream {
        Stream {
            tasks: Vec::new(),
            // Uses the backend's factory to create a new stream.
            stream: self.backend.factory().create(),
        }
    }
}

/// Options for configuring a `SchedulerMultiStream`.
#[derive(Debug)]
pub struct SchedulerMultiStreamOptions {
    /// Maximum number of streams allowed in the pool.
    pub max_streams: u8,
    /// Maximum number of tasks per stream before execution is triggered.
    pub max_tasks: usize,
    /// The scheduling strategy to use.
    pub strategy: SchedulerStrategy,
}
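
// A minimal configuration sketch; the values below are illustrative assumptions,
// not defaults taken from cubecl. Compiled only under tests.
#[cfg(test)]
mod options_example {
    use super::*;

    #[test]
    fn build_options() {
        let options = SchedulerMultiStreamOptions {
            max_streams: 4,
            max_tasks: 32,
            strategy: SchedulerStrategy::Interleave,
        };
        assert_eq!(options.max_streams, 4);
        assert_eq!(options.max_tasks, 32);
    }
}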

impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
    /// Creates a new `SchedulerMultiStream` with the given backend and options.
    pub fn new(
        logger: Arc<ServerLogger>,
        backend: B,
        options: SchedulerMultiStreamOptions,
    ) -> Self {
        Self {
            pool: StreamPool::new(SchedulerPoolMarker { backend }, options.max_streams, 0),
            max_tasks: options.max_tasks,
            strategy: options.strategy,
            logger,
        }
    }

    /// Returns a mutable reference to the backend stream for a given stream ID.
    pub fn stream(&mut self, stream_id: &StreamId) -> &mut B::Stream {
        let stream = self.pool.get_mut(stream_id);
        &mut stream.stream
    }

    /// Registers a task for execution on a specific stream, ensuring stream alignment.
    pub fn register<'a>(
        &mut self,
        stream_id: StreamId,
        task: B::Task,
        bindings: impl Iterator<Item = &'a Binding>,
    ) {
        // Align streams to ensure dependencies are handled correctly.
        self.align_streams(stream_id, bindings);

        // Get the stream for the given stream ID and add the task to its queue.
        let current = self.pool.get_mut(&stream_id);
        current.tasks.push(task);

        // If the task queue reaches the maximum, execute the stream.
        if current.tasks.len() >= self.max_tasks {
            self.execute_streams(vec![stream_id]);
        }
    }

    /// Aligns streams by flushing tasks from streams that conflict with the given bindings.
    pub(crate) fn align_streams<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) {
        let mut to_flush = Vec::new();
        // Get the index of the target stream.
        let index = self.pool.stream_index(&stream_id);

        // Identify streams that need to be flushed due to conflicting bindings.
        for binding in bindings {
            let index_stream = self.pool.stream_index(&binding.stream);
            if index != index_stream {
                to_flush.push(binding.stream);

                self.logger.log_streaming(
                    |level| matches!(level, StreamingLogLevel::Full),
                    || format!("Binding from {} is shared with {}", binding.stream, stream_id),
                );
            }
        }

        // If no streams need flushing, return early.
        if to_flush.is_empty() {
            return;
        }

        self.logger.log_streaming(
            |level| !matches!(level, StreamingLogLevel::Disabled),
            || {
                format!(
                    "Flushing streams {:?} before registering more tasks on {stream_id}",
                    to_flush
                )
            },
        );
        // Execute the streams that need to be flushed.
        self.execute_streams(to_flush);
    }

    /// Executes tasks from the specified streams based on the scheduling strategy.
    pub fn execute_streams(&mut self, stream_ids: Vec<StreamId>) {
        let mut indices = Vec::with_capacity(stream_ids.len());
        // Collect unique stream indices to avoid redundant processing.
        for id in stream_ids {
            let index = self.pool.stream_index(&id);
            if !indices.contains(&index) {
                indices.push(index);
            }
        }

        // Create schedules for each stream to be executed.
        let mut schedules = Vec::new();
        for index in indices {
            // SAFETY: `index` comes from `stream_index`, which is assumed to return a valid pool index.
            let stream = unsafe { self.pool.get_mut_index(index) };
            let tasks = stream.flush();
            let num_tasks = tasks.len();

            schedules.push(Schedule {
                tasks: tasks.into_iter(),
                num_tasks,
                stream_index: index,
            });
        }

        // If no schedules were created, return early.
        if schedules.is_empty() {
            return;
        }

        // Execute schedules based on the configured strategy.
        match self.strategy {
            SchedulerStrategy::Interleave => self.execute_schedules_interleave(schedules),
            SchedulerStrategy::Sequential => self.execute_schedules_sequence(schedules),
        }
    }

    /// Executes schedules sequentially, processing each stream's tasks in order.
    fn execute_schedules_sequence(&mut self, schedules: Vec<Schedule<B>>) {
        for schedule in schedules {
            // SAFETY: `stream_index` was obtained from the pool, so it is assumed to be a valid index.
            let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
            for task in schedule.tasks {
                // Enqueue each task on the stream.
                B::enqueue(task, &mut stream.stream);
            }

            // Makes sure the tasks are ordered on the compute queue.
            B::flush(&mut stream.stream);
        }
    }

    /// Executes schedules in an interleaved manner, alternating tasks from different streams.
    ///
    /// We choose the first stream as the one executing the tasks, ensuring proper ordering by
    /// flushing all other streams first and flushing the execution stream at the end.
    /// This way, we ensure that most tasks are actually interleaved on the real compute queue
    /// shared across all streams.
    fn execute_schedules_interleave(&mut self, mut schedules: Vec<Schedule<B>>) {
        // Flush all other streams first so their pending tasks are ordered on the compute queue.
        for schedule in schedules.iter_mut().skip(1) {
            let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
            B::flush(&mut stream.stream);
        }

        let execution_index = schedules.first().expect("At least one stream").stream_index;
        let stream = unsafe { self.pool.get_mut_index(execution_index) };

        // Find the maximum number of tasks across all schedules.
        let num_tasks_max = schedules
            .iter()
            .map(|s| s.num_tasks)
            .max()
            .expect("At least one schedule");

        // Iterate through tasks, interleaving them across streams.
        for _ in 0..num_tasks_max {
            for schedule in schedules.iter_mut() {
                // If there are tasks remaining in the schedule, enqueue the next one.
                if let Some(task) = schedule.tasks.next() {
                    B::enqueue(task, &mut stream.stream);
                }
            }
        }

        // Make sure all tasks are registered to the queue.
        B::flush(&mut stream.stream);
    }
}

/// Represents a schedule for executing tasks on a specific stream.
struct Schedule<B: SchedulerStreamBackend> {
    /// Iterator over the tasks to be executed.
    tasks: alloc::vec::IntoIter<B::Task>,
    /// Number of tasks in the schedule.
    num_tasks: usize,
    /// Index of the stream in the pool.
    stream_index: usize,
}
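
// End-to-end flow sketch (kept as comments, since constructing the `Arc<ServerLogger>`
// is outside the scope of this file): given a `backend` implementing
// `SchedulerStreamBackend` and a `logger`, usage follows the register/execute pattern
// defined above. Variable names and values are illustrative assumptions.
//
//     let mut scheduler = SchedulerMultiStream::new(
//         logger,
//         backend,
//         SchedulerMultiStreamOptions {
//             max_streams: 4,
//             max_tasks: 32,
//             strategy: SchedulerStrategy::Interleave,
//         },
//     );
//
//     // Queue a task on the current stream; conflicting streams are flushed first.
//     scheduler.register(stream_id, task, bindings.iter());
//
//     // Force execution of any pending tasks on the given streams.
//     scheduler.execute_streams(vec![stream_id]);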