use crate::{
config::streaming::StreamingLogLevel,
logging::ServerLogger,
stream::{StreamFactory, StreamPool},
};
use alloc::{format, sync::Arc, vec, vec::Vec};
use cubecl_common::stream_id::StreamId;
/// Backend hooks used by the multi-stream scheduler to create backend streams and to submit
/// buffered tasks to them.
pub trait SchedulerStreamBackend {
    /// Unit of work buffered by the scheduler before it is enqueued on a backend stream.
    type Task: core::fmt::Debug;
    /// Backend stream on which tasks are enqueued.
    type Stream: core::fmt::Debug;
    /// Factory that produces new backend streams.
    type Factory: StreamFactory<Stream = Self::Stream>;

    /// Gives mutable access to the factory used to create new backend streams.
    fn factory(&mut self) -> &mut Self::Factory;

    /// Enqueues a single `task` on `stream`.
    fn enqueue(task: Self::Task, stream: &mut Self::Stream);

    /// Flushes `stream`.
    ///
    /// The scheduler calls this after enqueuing a batch of tasks on `stream`. Under the
    /// interleaved strategy it also calls it on streams whose buffered tasks were redirected
    /// to another stream.
    fn flush(stream: &mut Self::Stream);
}
/// Buffers backend tasks per stream and submits them to the backend in batches.
///
/// Tasks registered on a stream stay in that stream's buffer until one of two things happens:
/// - the buffer reaches `max_tasks` entries;
/// - a task registered on a different pool slot lists this stream among its argument streams.
///
/// The buffered tasks are then handed to the backend according to `strategy`.
#[derive(Debug)]
pub struct SchedulerMultiStream<B: SchedulerStreamBackend> {
    /// Pool of scheduler streams; each wraps one backend stream plus its pending tasks.
    pool: StreamPool<SchedulerPoolMarker<B>>,
    /// How the buffered tasks of several streams are submitted together.
    strategy: SchedulerStrategy,
    /// Number of buffered tasks on a single stream that triggers its execution in `register`.
    max_tasks: usize,
    /// Logger used to report cross-stream bindings and the flushes they cause.
    pub logger: Arc<ServerLogger>,
}
/// Controls how the scheduler submits the buffered tasks of several streams at once.
///
/// This is a fieldless configuration value. It derives `Clone`, `Copy`, `PartialEq` and `Eq`
/// so that callers can reuse it across scheduler instances and compare configured strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulerStrategy {
    /// Submit all schedules through the first scheduled stream:
    /// - flush every scheduled stream except the first;
    /// - enqueue all buffered tasks round-robin on the first stream (at most one task per
    ///   schedule per round, in schedule order);
    /// - flush the first stream.
    Interleave,
    /// Enqueue each stream's buffered tasks on its own backend stream, then flush that stream.
    Sequential,
}
/// A backend stream paired with the tasks registered on it but not yet enqueued.
#[derive(Debug)]
pub struct Stream<B: SchedulerStreamBackend> {
    /// Tasks buffered by the scheduler's `register`, drained by `Stream::flush`.
    tasks: Vec<B::Task>,
    /// The backend stream the buffered tasks are eventually enqueued on.
    stream: B::Stream,
}
impl<B: SchedulerStreamBackend> Stream<B> {
    /// Takes every buffered task out of this stream and returns them in registration order.
    ///
    /// Nothing is sent to the backend here. The buffer left behind is allocated with the
    /// previous capacity, so the next batch of registrations can fill it without regrowing.
    fn flush(&mut self) -> Vec<B::Task> {
        let reserved = self.tasks.capacity();
        core::mem::replace(&mut self.tasks, Vec::with_capacity(reserved))
    }
}
/// Factory adapter that lets the stream pool create scheduler `Stream`s from the backend.
#[derive(Debug)]
struct SchedulerPoolMarker<B: SchedulerStreamBackend> {
    /// Backend whose stream factory supplies the underlying backend streams.
    backend: B,
}
impl<B: SchedulerStreamBackend> StreamFactory for SchedulerPoolMarker<B> {
    type Stream = Stream<B>;

    /// Builds a scheduler stream with an empty task buffer.
    ///
    /// The wrapped backend stream is created by the backend's own factory.
    fn create(&mut self) -> Self::Stream {
        let backend_stream = self.backend.factory().create();
        Stream {
            stream: backend_stream,
            tasks: Vec::new(),
        }
    }
}
/// Construction options for the multi-stream scheduler.
#[derive(Debug)]
pub struct SchedulerMultiStreamOptions {
    /// Forwarded as the stream-count argument of `StreamPool::new`.
    /// Presumably the maximum number of pooled streams; confirm against `StreamPool`.
    pub max_streams: u8,
    /// Number of buffered tasks on one stream that triggers its execution during `register`.
    pub max_tasks: usize,
    /// How the buffered tasks of several streams are submitted together.
    pub strategy: SchedulerStrategy,
}
impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
pub fn new(
logger: Arc<ServerLogger>,
backend: B,
options: SchedulerMultiStreamOptions,
) -> Self {
Self {
pool: StreamPool::new(SchedulerPoolMarker { backend }, options.max_streams, 0),
max_tasks: options.max_tasks,
strategy: options.strategy,
logger,
}
}
pub fn stream(&mut self, stream_id: &StreamId) -> &mut B::Stream {
let stream = self.pool.get_mut(stream_id);
&mut stream.stream
}
pub fn register(&mut self, stream_id: StreamId, task: B::Task, args_streams: &[StreamId]) {
self.align_streams(stream_id, args_streams);
let current = self.pool.get_mut(&stream_id);
current.tasks.push(task);
if current.tasks.len() >= self.max_tasks {
self.execute_streams(vec![stream_id]);
}
}
pub(crate) fn align_streams(&mut self, stream_id: StreamId, args_streams: &[StreamId]) {
let mut to_flush = Vec::new();
let index = self.pool.stream_index(&stream_id);
for arg_stream in args_streams {
let index_stream = self.pool.stream_index(arg_stream);
if index != index_stream {
to_flush.push(*arg_stream);
self.logger.log_streaming(
|level| matches!(level, StreamingLogLevel::Full),
|| format!("Binding on {} is shared on {}", arg_stream, stream_id),
);
}
}
if to_flush.is_empty() {
return;
}
self.logger.log_streaming(
|level| !matches!(level, StreamingLogLevel::Disabled),
|| {
format!(
"Flushing streams {to_flush:?} before registering more tasks on {stream_id}"
)
},
);
self.execute_streams(to_flush);
}
pub fn execute_streams(&mut self, stream_ids: Vec<StreamId>) {
let mut indices = Vec::with_capacity(stream_ids.len());
for id in stream_ids {
let index = self.pool.stream_index(&id);
if !indices.contains(&index) {
indices.push(index);
}
}
let mut schedules = Vec::new();
for index in indices {
let stream = unsafe { self.pool.get_mut_index(index) }; let tasks = stream.flush();
let num_tasks = tasks.len();
schedules.push(Schedule {
tasks: tasks.into_iter(),
num_tasks,
stream_index: index,
});
}
if schedules.is_empty() {
return;
}
match self.strategy {
SchedulerStrategy::Interleave => self.execute_schedules_interleave(schedules),
SchedulerStrategy::Sequential => self.execute_schedules_sequence(schedules),
}
}
fn execute_schedules_sequence(&mut self, schedules: Vec<Schedule<B>>) {
for schedule in schedules {
let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) }; for task in schedule.tasks {
B::enqueue(task, &mut stream.stream);
}
B::flush(&mut stream.stream);
}
}
fn execute_schedules_interleave(&mut self, mut schedules: Vec<Schedule<B>>) {
for schedule in schedules.iter_mut().skip(1) {
let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
B::flush(&mut stream.stream);
}
let execution_index = schedules.first().expect("At least one stream").stream_index;
let stream = unsafe { self.pool.get_mut_index(execution_index) };
let num_tasks_max = schedules
.iter()
.map(|s| s.num_tasks)
.max()
.expect("At least one schedule");
for _ in 0..num_tasks_max {
for schedule in schedules.iter_mut() {
if let Some(task) = schedule.tasks.next() {
B::enqueue(task, &mut stream.stream);
}
}
}
B::flush(&mut stream.stream);
}
}
/// Tasks drained from one pooled stream, waiting to be enqueued on a backend stream.
struct Schedule<B: SchedulerStreamBackend> {
    /// The drained tasks, yielded in registration order.
    tasks: alloc::vec::IntoIter<B::Task>,
    /// Number of tasks drained, i.e. the initial length of `tasks`.
    num_tasks: usize,
    /// Pool index of the stream the tasks were drained from.
    stream_index: usize,
}