cubecl_runtime/stream/
scheduler.rs1use crate::{
2 config::streaming::StreamingLogLevel,
3 logging::ServerLogger,
4 server::Binding,
5 stream::{StreamFactory, StreamPool},
6};
7use alloc::sync::Arc;
8use cubecl_common::stream_id::StreamId;
9
10pub trait SchedulerStreamBackend {
12 type Task: core::fmt::Debug;
14 type Stream: core::fmt::Debug;
16 type Factory: StreamFactory<Stream = Self::Stream>;
18
19 fn enqueue(task: Self::Task, stream: &mut Self::Stream);
21 fn flush(stream: &mut Self::Stream);
23 fn factory(&mut self) -> &mut Self::Factory;
25}
26
27#[derive(Debug)]
29pub struct SchedulerMultiStream<B: SchedulerStreamBackend> {
30 pool: StreamPool<SchedulerPoolMarker<B>>,
32 strategy: SchedulerStrategy,
34 max_tasks: usize,
36 pub logger: Arc<ServerLogger>,
38}
39
/// How the scheduler submits the pending tasks of the streams it executes.
#[derive(Debug)]
pub enum SchedulerStrategy {
    /// Tasks of all executed streams are enqueued round-robin, one task per stream per round.
    Interleave,
    /// Each stream's tasks are enqueued in full before moving on to the next stream.
    Sequential,
}
48
49#[derive(Debug)]
51pub struct Stream<B: SchedulerStreamBackend> {
52 tasks: Vec<B::Task>,
54 stream: B::Stream,
56}
57
58impl<B: SchedulerStreamBackend> Stream<B> {
59 fn flush(&mut self) -> Vec<B::Task> {
61 let mut returned = Vec::with_capacity(self.tasks.capacity());
62 core::mem::swap(&mut returned, &mut self.tasks);
63 returned
64 }
65}
66
67#[derive(Debug)]
68struct SchedulerPoolMarker<B: SchedulerStreamBackend> {
69 backend: B,
70}
71
72impl<B: SchedulerStreamBackend> StreamFactory for SchedulerPoolMarker<B> {
73 type Stream = Stream<B>;
75
76 fn create(&mut self) -> Self::Stream {
78 Stream {
79 tasks: Vec::new(),
80 stream: self.backend.factory().create(),
82 }
83 }
84}
85
86#[derive(Debug)]
88pub struct SchedulerMultiStreamOptions {
89 pub max_streams: u8,
91 pub max_tasks: usize,
93 pub strategy: SchedulerStrategy,
95}
96
97impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
98 pub fn new(
100 logger: Arc<ServerLogger>,
101 backend: B,
102 options: SchedulerMultiStreamOptions,
103 ) -> Self {
104 Self {
105 pool: StreamPool::new(SchedulerPoolMarker { backend }, options.max_streams, 0),
106 max_tasks: options.max_tasks,
107 strategy: options.strategy,
108 logger,
109 }
110 }
111
112 pub fn stream(&mut self, stream_id: &StreamId) -> &mut B::Stream {
114 let stream = self.pool.get_mut(stream_id);
115 &mut stream.stream
116 }
117
118 pub fn register<'a>(
120 &mut self,
121 stream_id: StreamId,
122 task: B::Task,
123 bindings: impl Iterator<Item = &'a Binding>,
124 ) {
125 self.align_streams(stream_id, bindings);
127
128 let current = self.pool.get_mut(&stream_id);
130 current.tasks.push(task);
131
132 if current.tasks.len() >= self.max_tasks {
134 self.execute_streams(vec![stream_id]);
135 }
136 }
137
138 pub(crate) fn align_streams<'a>(
140 &mut self,
141 stream_id: StreamId,
142 bindings: impl Iterator<Item = &'a Binding>,
143 ) {
144 let mut to_flush = Vec::new();
145 let index = self.pool.stream_index(&stream_id);
147
148 for binding in bindings {
150 let index_stream = self.pool.stream_index(&binding.stream);
151 if index != index_stream {
152 to_flush.push(binding.stream);
153
154 self.logger.log_streaming(
155 |level| matches!(level, StreamingLogLevel::Full),
156 || format!("Binding on {} is shared on {}", binding.stream, stream_id),
157 );
158 }
159 }
160
161 if to_flush.is_empty() {
163 return;
164 }
165
166 self.logger.log_streaming(
167 |level| !matches!(level, StreamingLogLevel::Disabled),
168 || {
169 format!(
170 "Flushing streams {:?} before registering more tasks on {stream_id}",
171 to_flush
172 )
173 },
174 );
175 self.execute_streams(to_flush);
177 }
178
179 pub fn execute_streams(&mut self, stream_ids: Vec<StreamId>) {
181 let mut indices = Vec::with_capacity(stream_ids.len());
182 for id in stream_ids {
184 let index = self.pool.stream_index(&id);
185 if !indices.contains(&index) {
186 indices.push(index);
187 }
188 }
189
190 let mut schedules = Vec::new();
192 for index in indices {
193 let stream = unsafe { self.pool.get_mut_index(index) }; let tasks = stream.flush();
195 let num_tasks = tasks.len();
196
197 schedules.push(Schedule {
198 tasks: tasks.into_iter(),
199 num_tasks,
200 stream_index: index,
201 });
202 }
203
204 if schedules.is_empty() {
206 return;
207 }
208
209 match self.strategy {
211 SchedulerStrategy::Interleave => self.execute_schedules_interleave(schedules),
212 SchedulerStrategy::Sequential => self.execute_schedules_sequence(schedules),
213 }
214 }
215
216 fn execute_schedules_sequence(&mut self, schedules: Vec<Schedule<B>>) {
218 for schedule in schedules {
219 let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) }; for task in schedule.tasks {
221 B::enqueue(task, &mut stream.stream);
223 }
224
225 B::flush(&mut stream.stream);
227 }
228 }
229
230 fn execute_schedules_interleave(&mut self, mut schedules: Vec<Schedule<B>>) {
237 for schedule in schedules.iter_mut().skip(1) {
239 let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
240 B::flush(&mut stream.stream);
241 }
242
243 let execution_index = schedules.first().expect("At least one stream").stream_index;
244 let stream = unsafe { self.pool.get_mut_index(execution_index) };
245
246 let num_tasks_max = schedules
248 .iter()
249 .map(|s| s.num_tasks)
250 .max()
251 .expect("At least one schedule");
252
253 for _ in 0..num_tasks_max {
255 for schedule in schedules.iter_mut() {
256 if let Some(task) = schedule.tasks.next() {
258 B::enqueue(task, &mut stream.stream);
259 }
260 }
261 }
262
263 B::flush(&mut stream.stream);
265 }
266}
267
268struct Schedule<B: SchedulerStreamBackend> {
270 tasks: alloc::vec::IntoIter<B::Task>,
272 num_tasks: usize,
274 stream_index: usize,
276}