cubecl_runtime/stream/
scheduler.rs1use crate::{
2 config::streaming::StreamingLogLevel,
3 logging::ServerLogger,
4 stream::{StreamFactory, StreamPool},
5};
6use alloc::{format, sync::Arc, vec, vec::Vec};
7use cubecl_common::stream_id::StreamId;
8
9pub trait SchedulerStreamBackend {
11 type Task: core::fmt::Debug;
13 type Stream: core::fmt::Debug;
15 type Factory: StreamFactory<Stream = Self::Stream>;
17
18 fn enqueue(task: Self::Task, stream: &mut Self::Stream);
20 fn flush(stream: &mut Self::Stream);
22 fn factory(&mut self) -> &mut Self::Factory;
24}
25
26#[derive(Debug)]
28pub struct SchedulerMultiStream<B: SchedulerStreamBackend> {
29 pool: StreamPool<SchedulerPoolMarker<B>>,
31 strategy: SchedulerStrategy,
33 max_tasks: usize,
35 pub logger: Arc<ServerLogger>,
37}
38
/// Strategy used to submit pending tasks when several streams are executed at once.
///
/// The enum is fieldless, so it is `Copy` and comparable. Callers can store, duplicate and
/// compare strategies freely, e.g. when building scheduler options.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulerStrategy {
    /// The backend streams of all but the first executed stream are flushed.
    /// Pending tasks of every executed stream are then enqueued round-robin on the first
    /// executed stream's backend stream, which is flushed at the end.
    Interleave,
    /// Each executed stream enqueues its own pending tasks on its own backend stream and
    /// is flushed, one stream after the other.
    Sequential,
}

48#[derive(Debug)]
50pub struct Stream<B: SchedulerStreamBackend> {
51 tasks: Vec<B::Task>,
53 stream: B::Stream,
55}
56
57impl<B: SchedulerStreamBackend> Stream<B> {
58 fn flush(&mut self) -> Vec<B::Task> {
60 let mut returned = Vec::with_capacity(self.tasks.capacity());
61 core::mem::swap(&mut returned, &mut self.tasks);
62 returned
63 }
64}
65
66#[derive(Debug)]
67struct SchedulerPoolMarker<B: SchedulerStreamBackend> {
68 backend: B,
69}
70
71impl<B: SchedulerStreamBackend> StreamFactory for SchedulerPoolMarker<B> {
72 type Stream = Stream<B>;
74
75 fn create(&mut self) -> Self::Stream {
77 Stream {
78 tasks: Vec::new(),
79 stream: self.backend.factory().create(),
81 }
82 }
83}
84
85#[derive(Debug)]
87pub struct SchedulerMultiStreamOptions {
88 pub max_streams: u8,
90 pub max_tasks: usize,
92 pub strategy: SchedulerStrategy,
94}
95
96impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
97 pub fn new(
99 logger: Arc<ServerLogger>,
100 backend: B,
101 options: SchedulerMultiStreamOptions,
102 ) -> Self {
103 Self {
104 pool: StreamPool::new(SchedulerPoolMarker { backend }, options.max_streams, 0),
105 max_tasks: options.max_tasks,
106 strategy: options.strategy,
107 logger,
108 }
109 }
110
111 pub fn stream(&mut self, stream_id: &StreamId) -> &mut B::Stream {
113 let stream = self.pool.get_mut(stream_id);
114 &mut stream.stream
115 }
116
117 pub fn register(&mut self, stream_id: StreamId, task: B::Task, args_streams: &[StreamId]) {
119 self.align_streams(stream_id, args_streams);
121
122 let current = self.pool.get_mut(&stream_id);
124 current.tasks.push(task);
125
126 if current.tasks.len() >= self.max_tasks {
128 self.execute_streams(vec![stream_id]);
129 }
130 }
131
132 pub(crate) fn align_streams(&mut self, stream_id: StreamId, args_streams: &[StreamId]) {
134 let mut to_flush = Vec::new();
135 let index = self.pool.stream_index(&stream_id);
137
138 for arg_stream in args_streams {
140 let index_stream = self.pool.stream_index(arg_stream);
141 if index != index_stream {
142 to_flush.push(*arg_stream);
143
144 self.logger.log_streaming(
145 |level| matches!(level, StreamingLogLevel::Full),
146 || format!("Binding on {} is shared on {}", arg_stream, stream_id),
147 );
148 }
149 }
150
151 if to_flush.is_empty() {
153 return;
154 }
155
156 self.logger.log_streaming(
157 |level| !matches!(level, StreamingLogLevel::Disabled),
158 || {
159 format!(
160 "Flushing streams {to_flush:?} before registering more tasks on {stream_id}"
161 )
162 },
163 );
164 self.execute_streams(to_flush);
166 }
167
168 pub fn execute_streams(&mut self, stream_ids: Vec<StreamId>) {
170 let mut indices = Vec::with_capacity(stream_ids.len());
171
172 for id in stream_ids {
174 let index = self.pool.stream_index(&id);
175 if !indices.contains(&index) {
176 indices.push(index);
177 }
178 }
179
180 let mut schedules = Vec::new();
182 for index in indices {
183 let stream = unsafe { self.pool.get_mut_index(index) }; let tasks = stream.flush();
185 let num_tasks = tasks.len();
186
187 schedules.push(Schedule {
188 tasks: tasks.into_iter(),
189 num_tasks,
190 stream_index: index,
191 });
192 }
193
194 if schedules.is_empty() {
196 return;
197 }
198
199 match self.strategy {
201 SchedulerStrategy::Interleave => self.execute_schedules_interleave(schedules),
202 SchedulerStrategy::Sequential => self.execute_schedules_sequence(schedules),
203 }
204 }
205
206 fn execute_schedules_sequence(&mut self, schedules: Vec<Schedule<B>>) {
208 for schedule in schedules {
209 let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) }; for task in schedule.tasks {
211 B::enqueue(task, &mut stream.stream);
213 }
214
215 B::flush(&mut stream.stream);
217 }
218 }
219
220 fn execute_schedules_interleave(&mut self, mut schedules: Vec<Schedule<B>>) {
227 for schedule in schedules.iter_mut().skip(1) {
229 let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
230 B::flush(&mut stream.stream);
231 }
232
233 let execution_index = schedules.first().expect("At least one stream").stream_index;
234 let stream = unsafe { self.pool.get_mut_index(execution_index) };
235
236 let num_tasks_max = schedules
238 .iter()
239 .map(|s| s.num_tasks)
240 .max()
241 .expect("At least one schedule");
242
243 for _ in 0..num_tasks_max {
245 for schedule in schedules.iter_mut() {
246 if let Some(task) = schedule.tasks.next() {
248 B::enqueue(task, &mut stream.stream);
249 }
250 }
251 }
252
253 B::flush(&mut stream.stream);
255 }
256}
257
258struct Schedule<B: SchedulerStreamBackend> {
260 tasks: alloc::vec::IntoIter<B::Task>,
262 num_tasks: usize,
264 stream_index: usize,
266}