use crate::{
config::streaming::StreamingLogLevel,
logging::ServerLogger,
memory_management::SliceId,
server::{Binding, ExecutionError},
stream::{StreamFactory, StreamPool},
};
use core::any::Any;
use cubecl_common::stream_id::StreamId;
use hashbrown::HashMap;
use std::sync::{Arc, mpsc::SyncSender};
/// Backend abstraction over an event-capable stream API.
///
/// Events are recorded on one stream via [`flush`](Self::flush) and can then be
/// waited on either asynchronously by another stream or synchronously by the
/// calling thread.
pub trait EventStreamBackend: 'static {
    /// Handle to a single ordered command stream.
    type Stream: core::fmt::Debug;
    /// Synchronization token recorded on a stream; `Send` so it can be moved
    /// to the garbage-collection thread.
    type Event: Send + 'static;
    /// Creates a new stream on this backend.
    fn create_stream(&self) -> Self::Stream;
    /// Records an event capturing all work submitted to `stream` so far.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes `stream` wait until `event` has completed before executing
    /// subsequently submitted work.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Blocks the calling thread until `event` has completed.
    fn wait_event_sync(event: Self::Event) -> Result<(), ExecutionError>;
}
/// Manages multiple backend streams and the cross-stream synchronization
/// required when a binding created on one stream is used by another.
#[derive(Debug)]
pub struct MultiStream<B: EventStreamBackend> {
    // Pool of user streams plus one special stream used by the GC path.
    streams: StreamPool<EventStreamBackendWrapper<B>>,
    pub logger: Arc<ServerLogger>,
    // Number of user streams; stream ids map to slots modulo this value.
    max_streams: usize,
    // Background thread that drops resources once their events complete.
    gc: GcThread<B>,
}
/// A backend stream together with its synchronization bookkeeping.
pub(crate) struct StreamWrapper<B: EventStreamBackend> {
    stream: B::Stream,
    // Monotonic counter bumped on each `resolve` targeting this stream.
    cursor: u64,
    // For each other stream slot, the origin cursor value up to which this
    // stream has already synchronized.
    last_synced: HashMap<usize, u64>,
}
/// Guard returned by [`MultiStream::resolve`]; grants access to the resolved
/// streams and, on drop, defers the release of shared slices until the
/// involved streams have completed (see the `Drop` impl).
pub struct ResolvedStreams<'a, B: EventStreamBackend> {
    // Cursor of the current stream after this resolution.
    pub cursor: u64,
    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
    // Slices detected as shared from other streams during resolution.
    analysis: SharedBindingAnalysis,
    gc: &'a GcThread<B>,
    // Id of the stream this resolution executes on.
    pub current: StreamId,
}
/// A deferred drop: `to_drop` is kept alive until `event` has completed.
#[derive(Debug)]
pub struct GcTask<B: EventStreamBackend> {
    // Type-erased payload to release once the event completes.
    to_drop: Box<dyn Any + Send + 'static>,
    event: B::Event,
}
impl<B: EventStreamBackend> GcTask<B> {
pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
Self {
to_drop: Box::new(to_drop),
event,
}
}
}
/// Adapter making an [`EventStreamBackend`] usable as a [`StreamFactory`]
/// for the stream pool.
#[derive(Debug)]
struct EventStreamBackendWrapper<B: EventStreamBackend> {
    backend: B,
}
impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
    type Stream = StreamWrapper<B>;

    /// Wraps a freshly created backend stream with zeroed bookkeeping state.
    fn create(&mut self) -> Self::Stream {
        let stream = self.backend.create_stream();

        StreamWrapper {
            cursor: 0,
            last_synced: Default::default(),
            stream,
        }
    }
}
/// Handle to the background garbage-collection thread; tasks are queued
/// through the bounded channel sender.
#[derive(Debug)]
struct GcThread<B: EventStreamBackend> {
    sender: SyncSender<GcTask<B>>,
}
impl<B: EventStreamBackend> GcThread<B> {
    /// Spawns the background thread that waits on events and drops payloads.
    fn new() -> GcThread<B> {
        // Bounded channel: producers block once 8 tasks are pending, applying
        // back-pressure instead of growing an unbounded queue.
        let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
        std::thread::spawn(move || {
            // The loop ends when every sender (hence the owning MultiStream)
            // has been dropped and the channel disconnects.
            while let Ok(event) = recv.recv() {
                // NOTE(review): a failed wait panics this detached thread,
                // which silently stops garbage collection and makes later
                // `register` calls panic — confirm this is intended.
                B::wait_event_sync(event.event).unwrap();
                // Release the payload only once the event has completed.
                core::mem::drop(event.to_drop);
            }
        });
        GcThread { sender }
    }
    /// Queues a task for deferred cleanup; panics if the GC thread has died.
    fn register(&self, task: GcTask<B>) {
        self.sender.send(task).unwrap()
    }
}
/// Maps a stream id onto a slot index in the stream pool.
fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
    let raw = stream_id.value as usize;
    raw % max_streams
}
impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
    /// Returns the backend stream associated with `stream_id`.
    pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
        let stream = self.streams.get_mut(stream_id);
        &mut stream.stream
    }
    /// Returns the backend stream this resolution executes on.
    pub fn current(&mut self) -> &mut B::Stream {
        let stream = self.streams.get_mut(&self.current);
        &mut stream.stream
    }
    /// Queues a task whose payload is dropped once its event completes.
    pub fn gc(&mut self, gc: GcTask<B>) {
        // Consistency: go through the `GcThread` API instead of poking its
        // raw sender (matches how the `Drop` impl registers tasks).
        self.gc.register(gc);
    }
}
impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
    fn drop(&mut self) {
        // Nothing was shared across streams: no deferred cleanup required.
        if self.analysis.slices.is_empty() {
            return;
        }
        // Capture an event covering all work submitted on the current stream.
        let stream = self.streams.get_mut(&self.current);
        let event_origin = B::flush(&mut stream.stream);
        // Route the event through the special GC stream (slot 0) so the final
        // event observed by the GC thread includes the current stream's work.
        // SAFETY: slot 0 exists because the pool is built with one special
        // stream (`StreamPool::new(.., 1)`) — TODO confirm `get_special`'s
        // exact contract.
        let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
        B::wait_event(stream_gc, event_origin);
        let event = B::flush(stream_gc);
        // Collect every shared slice id and keep them alive until the GC
        // thread sees the event complete.
        let mut ids = Vec::new();
        self.analysis
            .slices
            .drain()
            .for_each(|item| ids.extend(item.1));
        self.gc.register(GcTask::new(ids, event));
    }
}
impl<B: EventStreamBackend> MultiStream<B> {
    /// Creates a multi-stream with `max_streams` user streams plus one special
    /// stream reserved for the garbage-collection path.
    pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
        let wrapper = EventStreamBackendWrapper { backend };
        Self {
            // The trailing `1` reserves a special stream used on drop of
            // `ResolvedStreams` to serialize GC events.
            streams: StreamPool::new(wrapper, max_streams, 1),
            logger,
            max_streams: max_streams as usize,
            gc: GcThread::new(),
        }
    }

    /// Queues a task whose payload is dropped once its event completes.
    pub fn gc(&mut self, gc: GcTask<B>) {
        // Consistency: go through the `GcThread` API instead of using its raw
        // sender directly (matches `GcThread::register` usage elsewhere).
        self.gc.register(gc);
    }

    /// Resolves the stream for `stream_id`, synchronizing it with every other
    /// stream that owns one of the provided `bindings`.
    ///
    /// Bumps the stream's cursor and returns a [`ResolvedStreams`] guard; when
    /// the guard drops, slices detected as shared are kept alive until the
    /// involved streams have completed.
    pub fn resolve<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) -> ResolvedStreams<'_, B> {
        let analysis = self.align_streams(stream_id, bindings);
        let stream = self.streams.get_mut(&stream_id);
        stream.cursor += 1;

        ResolvedStreams {
            cursor: stream.cursor,
            streams: &mut self.streams,
            current: stream_id,
            analysis,
            gc: &self.gc,
        }
    }

    /// Detects bindings shared from other streams, then makes `stream_id`
    /// wait on those streams.
    fn align_streams<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) -> SharedBindingAnalysis {
        let analysis = self.update_shared_bindings(stream_id, bindings);
        self.apply_analysis(stream_id, analysis)
    }

    /// Collects the bindings that originate from another stream and are not
    /// yet known to be synchronized with `stream_id`.
    pub(crate) fn update_shared_bindings<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) -> SharedBindingAnalysis {
        let mut analysis = SharedBindingAnalysis::default();
        let current = self.streams.get_mut(&stream_id);

        for binding in bindings {
            // Bindings created on the current stream never need syncing.
            if stream_id == binding.stream {
                continue;
            }

            let index = stream_index(&binding.stream, self.max_streams);
            // A binding is shared unless we already synced past its cursor.
            let last_synced = current.last_synced.get(&index).copied();
            let is_synced = matches!(last_synced, Some(last) if last >= binding.cursor);

            if !is_synced {
                self.logger.log_streaming(
                    |level| matches!(level, StreamingLogLevel::Full),
                    || match last_synced {
                        Some(last) => format!(
                            "Binding on {} is shared on {} since it's not sync {} < {}",
                            binding.stream, stream_id, last, binding.cursor
                        ),
                        None => format!(
                            "Binding on {} is shared on {} since it was never synced.",
                            binding.stream, stream_id,
                        ),
                    },
                );
                analysis.shared(binding, index);
            }
        }

        analysis
    }

    /// Makes `stream_id` wait on each stream owning a shared binding, and
    /// records the cursor up to which each origin is now synchronized.
    pub(crate) fn apply_analysis(
        &mut self,
        stream_id: StreamId,
        analysis: SharedBindingAnalysis,
    ) -> SharedBindingAnalysis {
        if analysis.slices.is_empty() {
            return analysis;
        }

        // Flush each origin stream and remember (slot, cursor) with its event.
        let mut events = Vec::with_capacity(analysis.slices.len());
        // SAFETY: slot indices stored in the analysis come from
        // `stream_index`, so they are in range for the pool — TODO confirm
        // `get_mut_index`'s exact contract.
        unsafe {
            for origin in analysis.slices.keys() {
                let stream = self.streams.get_mut_index(*origin);
                let event = B::flush(&mut stream.stream);
                events.push(((origin, stream.cursor), event));
            }
        }

        let stream = self.streams.get_mut(&stream_id);
        for ((stream_origin, cursor_origin), event) in events {
            stream.last_synced.insert(*stream_origin, cursor_origin);
            self.logger.log_streaming(
                |level| !matches!(level, StreamingLogLevel::Disabled),
                || format!("Waiting on {stream_origin} from {stream_id}",),
            );
            B::wait_event(&mut stream.stream, event);
        }

        analysis
    }
}
// Manual impl: a `#[derive(Debug)]` would presumably add an unwanted
// `B: Debug` bound, while only `B::Stream: Debug` is required — confirm.
impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("StreamWrapper")
            .field("stream", &self.stream)
            .field("cursor", &self.cursor)
            .field("last_synced", &self.last_synced)
            .finish()
    }
}
/// Result of scanning a kernel's bindings: for each origin stream slot, the
/// slice ids that are shared with the current stream.
#[derive(Default, Debug, PartialEq, Eq)]
pub(crate) struct SharedBindingAnalysis {
    // Keyed by origin stream slot index (see `stream_index`).
    slices: HashMap<usize, Vec<SliceId>>,
}
impl SharedBindingAnalysis {
    /// Records that `binding`'s slice is shared from the stream slot `index`.
    fn shared(&mut self, binding: &Binding, index: usize) {
        // Entry API: a single hash lookup instead of `get_mut` + `insert`.
        self.slices
            .entry(index)
            .or_default()
            .push(*binding.memory.id());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};
    const MAX_STREAMS: u8 = 4;
    #[test_log::test]
    fn test_analysis_shared_bindings_1() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };
        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());
        // Only the binding from the other, never-synced stream is shared.
        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());
        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));
        assert_eq!(analysis, expected);
    }
    #[test_log::test]
    fn test_analysis_shared_bindings_2() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };
        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);
        // Consistency: use the shared constant like the sibling tests.
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());
        // Bindings owned by the current stream are never reported as shared.
        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());
        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));
        assert_eq!(analysis, expected);
    }
    #[test_log::test]
    fn test_analysis_no_shared() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };
        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());
        // All bindings live on the current stream: the analysis stays empty.
        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());
        let expected = SharedBindingAnalysis::default();
        assert_eq!(analysis, expected);
    }
    #[test_log::test]
    fn test_state() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };
        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());
        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());
        // Stream 1 synced with stream 2 at its cursor 1 and resolved twice.
        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);
        // Stream 2 never waited on anyone and resolved once.
        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }
    // Builds a binding owned by the given stream with cursor 0.
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }
    // Minimal no-op backend: events carry no state and waits succeed.
    struct TestBackend;
    #[derive(Debug)]
    struct TestStream {}
    #[derive(Debug)]
    struct TestEvent {}
    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;
        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }
        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }
        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}
        fn wait_event_sync(_event: Self::Event) -> Result<(), ExecutionError> {
            Ok(())
        }
    }
}