Skip to main content

cubecl_runtime/stream/
event.rs

1use crate::{
2    config::streaming::StreamingLogLevel,
3    logging::ServerLogger,
4    memory_management::SliceId,
5    server::{Binding, ExecutionError},
6    stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::stream_id::StreamId;
10use hashbrown::HashMap;
11use std::{
12    boxed::Box,
13    format,
14    sync::{Arc, mpsc::SyncSender},
15    vec,
16    vec::Vec,
17};
18
19/// Trait defining the backend operations for managing streams and events.
20///
21/// This trait provides the necessary methods for initializing streams, flushing them to create events,
22/// and waiting on events for synchronization purposes.
23pub trait EventStreamBackend: 'static {
24    /// The type representing a stream in this backend.
25    type Stream: core::fmt::Debug;
26    /// The type representing an event in this backend.
27    type Event: Send + 'static;
28
29    /// Initializes and returns a new stream associated with the given stream ID.
30    fn create_stream(&self) -> Self::Stream;
31    /// Flushes the given stream, ensuring all pending operations are submitted, and returns an event
32    /// that can be used for synchronization.
33    fn flush(stream: &mut Self::Stream) -> Self::Event;
34    /// Makes the stream wait for the specified event to complete before proceeding with further operations.
35    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
36    /// Wait for the given event synching the CPU.
37    fn wait_event_sync(event: Self::Event) -> Result<(), ExecutionError>;
38}
39
40/// Manages multiple streams with synchronization logic based on shared bindings.
41///
42/// This struct handles the creation and alignment of streams to ensure proper synchronization
43/// when bindings (e.g., buffers) are shared across different streams.
44#[derive(Debug)]
45pub struct MultiStream<B: EventStreamBackend> {
46    /// The map of stream IDs to their corresponding stream wrappers.
47    streams: StreamPool<EventStreamBackendWrapper<B>>,
48    /// The logger used by the server.
49    pub logger: Arc<ServerLogger>,
50    max_streams: usize,
51    gc: GcThread<B>,
52}
53
54/// A wrapper around a backend stream that includes synchronization metadata.
55///
56/// This includes the stream itself, a map of last synchronized cursors from other streams,
57/// and the current cursor position for this stream.
58pub(crate) struct StreamWrapper<B: EventStreamBackend> {
59    /// The underlying backend stream.
60    stream: B::Stream,
61    /// The current cursor position, representing the logical progress or version of operations on this stream.
62    cursor: u64,
63    /// A map tracking the last synchronized cursor positions from other streams.
64    last_synced: HashMap<usize, u64>,
65}
66
67/// Streams that are synchronized correctly after a [`MultiStream::resolve`] is called.
68pub struct ResolvedStreams<'a, B: EventStreamBackend> {
69    /// The cursor on the current stream.
70    ///
71    /// This cursor should be use for new allocations happening on the current stream.
72    pub cursor: u64,
73    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
74    analysis: SharedBindingAnalysis,
75    gc: &'a GcThread<B>,
76    /// The current stream where new tasks can be sent safely.
77    pub current: StreamId,
78}
79
80#[derive(Debug)]
81/// A task to be enqueue on the gc stream that will be clearned after an event is reached.
82pub struct GcTask<B: EventStreamBackend> {
83    to_drop: Box<dyn Any + Send + 'static>,
84    /// The event to sync making sure the bindings in the batch are ready to be reused by other streams.
85    event: B::Event,
86}
87
88impl<B: EventStreamBackend> GcTask<B> {
89    /// Creates a new task that will be clearned when the event is reached.
90    pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
91        Self {
92            to_drop: Box::new(to_drop),
93            event,
94        }
95    }
96}
97
98#[derive(Debug)]
99struct EventStreamBackendWrapper<B: EventStreamBackend> {
100    backend: B,
101}
102
103impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
104    type Stream = StreamWrapper<B>;
105
106    fn create(&mut self) -> Self::Stream {
107        StreamWrapper {
108            stream: self.backend.create_stream(),
109            cursor: 0,
110            last_synced: Default::default(),
111        }
112    }
113}
114
115#[derive(Debug)]
116struct GcThread<B: EventStreamBackend> {
117    sender: SyncSender<GcTask<B>>,
118}
119
120impl<B: EventStreamBackend> GcThread<B> {
121    fn new() -> GcThread<B> {
122        let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
123
124        std::thread::spawn(move || {
125            while let Ok(event) = recv.recv() {
126                B::wait_event_sync(event.event).unwrap();
127                core::mem::drop(event.to_drop);
128            }
129        });
130
131        GcThread { sender }
132    }
133    fn register(&self, task: GcTask<B>) {
134        self.sender.send(task).unwrap()
135    }
136}
137
138fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
139    stream_id.value as usize % max_streams
140}
141
142impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
143    /// Get the stream associated to the given [`stream_id`](StreamId).
144    pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
145        let stream = self.streams.get_mut(stream_id);
146        &mut stream.stream
147    }
148
149    /// Get the stream associated to the [current `stream_id`](StreamId).
150    pub fn current(&mut self) -> &mut B::Stream {
151        let stream = self.streams.get_mut(&self.current);
152        &mut stream.stream
153    }
154
155    /// Enqueue a task to be cleaned.
156    pub fn gc(&mut self, gc: GcTask<B>) {
157        self.gc.sender.send(gc).unwrap();
158    }
159}
160
161impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
162    fn drop(&mut self) {
163        if self.analysis.slices.is_empty() {
164            return;
165        }
166
167        let stream = self.streams.get_mut(&self.current);
168        let event_origin = B::flush(&mut stream.stream);
169
170        let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
171        B::wait_event(stream_gc, event_origin);
172        let event = B::flush(stream_gc);
173
174        let mut ids = Vec::new();
175        self.analysis
176            .slices
177            .drain()
178            .for_each(|item| ids.extend(item.1));
179
180        self.gc.register(GcTask::new(ids, event));
181    }
182}
183
184impl<B: EventStreamBackend> MultiStream<B> {
185    /// Creates an empty multi-stream.
186    pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
187        let wrapper = EventStreamBackendWrapper { backend };
188        Self {
189            streams: StreamPool::new(wrapper, max_streams, 1),
190            logger,
191            max_streams: max_streams as usize,
192            gc: GcThread::new(),
193        }
194    }
195
196    /// Enqueue a task to be cleaned.
197    pub fn gc(&mut self, gc: GcTask<B>) {
198        self.gc.sender.send(gc).unwrap();
199    }
200
201    /// Resolves and returns a mutable reference to the stream for the given ID, performing any necessary
202    /// alignment based on the provided bindings.
203    ///
204    /// This method ensures that the stream is synchronized with any shared bindings from other streams
205    /// before returning the stream reference.
206    pub fn resolve<'a>(
207        &mut self,
208        stream_id: StreamId,
209        bindings: impl Iterator<Item = &'a Binding>,
210    ) -> ResolvedStreams<'_, B> {
211        let analysis = self.align_streams(stream_id, bindings);
212
213        let stream = self.streams.get_mut(&stream_id);
214        stream.cursor += 1;
215
216        ResolvedStreams {
217            cursor: stream.cursor,
218            streams: &mut self.streams,
219            current: stream_id,
220            analysis,
221            gc: &self.gc,
222        }
223    }
224
225    /// Aligns the target stream with other streams based on shared bindings.
226    ///
227    /// This initializes the stream if it doesn't exist, analyzes which originating streams need flushing
228    /// for synchronization, flushes them, and waits on the events in the target stream.
229    fn align_streams<'a>(
230        &mut self,
231        stream_id: StreamId,
232        bindings: impl Iterator<Item = &'a Binding>,
233    ) -> SharedBindingAnalysis {
234        let analysis = self.update_shared_bindings(stream_id, bindings);
235
236        self.apply_analysis(stream_id, analysis)
237    }
238
239    /// Update and analyzes the bindings to determine which streams need alignment (flushing and waiting).
240    ///
241    /// This checks for shared bindings from other streams and determines if synchronization is needed
242    /// based on cursor positions.
243    pub(crate) fn update_shared_bindings<'a>(
244        &mut self,
245        stream_id: StreamId,
246        bindings: impl Iterator<Item = &'a Binding>,
247    ) -> SharedBindingAnalysis {
248        let mut analysis = SharedBindingAnalysis::default();
249        let current = self.streams.get_mut(&stream_id);
250
251        for binding in bindings {
252            if stream_id != binding.stream {
253                let index = stream_index(&binding.stream, self.max_streams);
254
255                if let Some(last_synced) = current.last_synced.get(&index) {
256                    if *last_synced < binding.cursor {
257                        self.logger.log_streaming(
258                            |level| matches!(level, StreamingLogLevel::Full),
259                            || {
260                                format!(
261                                    "Binding on {} is shared on {} since it's not sync {} < {}",
262                                    binding.stream, stream_id, last_synced, binding.cursor
263                                )
264                            },
265                        );
266                        analysis.shared(binding, index);
267                    }
268                } else {
269                    self.logger.log_streaming(
270                        |level| matches!(level, StreamingLogLevel::Full),
271                        || {
272                            format!(
273                                "Binding on {} is shared on {} since it was never synced.",
274                                binding.stream, stream_id,
275                            )
276                        },
277                    );
278                    analysis.shared(binding, index);
279                }
280            }
281        }
282
283        analysis
284    }
285
286    pub(crate) fn apply_analysis(
287        &mut self,
288        stream_id: StreamId,
289        analysis: SharedBindingAnalysis,
290    ) -> SharedBindingAnalysis {
291        if analysis.slices.is_empty() {
292            return analysis;
293        }
294
295        let mut events = Vec::with_capacity(analysis.slices.len());
296
297        unsafe {
298            for origin in analysis.slices.keys() {
299                let stream = self.streams.get_mut_index(*origin);
300                let event = B::flush(&mut stream.stream);
301
302                events.push(((origin, stream.cursor), event));
303            }
304        }
305
306        let stream = self.streams.get_mut(&stream_id);
307
308        for ((stream_origin, cursor_origin), event) in events {
309            stream.last_synced.insert(*stream_origin, cursor_origin);
310
311            self.logger.log_streaming(
312                |level| !matches!(level, StreamingLogLevel::Disabled),
313                || format!("Waiting on {stream_origin} from {stream_id}",),
314            );
315
316            B::wait_event(&mut stream.stream, event);
317        }
318
319        analysis
320    }
321}
322
323impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
324    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
325        f.debug_struct("StreamWrapper")
326            .field("stream", &self.stream)
327            .field("cursor", &self.cursor)
328            .field("last_synced", &self.last_synced)
329            .finish()
330    }
331}
332
333#[derive(Default, Debug, PartialEq, Eq)]
334pub(crate) struct SharedBindingAnalysis {
335    slices: HashMap<usize, Vec<SliceId>>,
336}
337
338impl SharedBindingAnalysis {
339    fn shared(&mut self, binding: &Binding, index: usize) {
340        match self.slices.get_mut(&index) {
341            Some(bindings) => bindings.push(*binding.memory.id()),
342            None => {
343                self.slices.insert(index, vec![*binding.memory.id()]);
344            }
345        }
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};

    const MAX_STREAMS: u8 = 4;

    /// Builds a multi-stream on which each of the given streams was resolved once without
    /// any binding.
    fn resolved_once(streams: [StreamId; 2]) -> MultiStream<TestBackend> {
        let logger = Arc::new(ServerLogger::default());
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);

        for stream in streams {
            ms.resolve(stream, core::iter::empty());
        }

        ms
    }

    /// A binding from another stream is reported as shared; one from the same stream is not.
    #[test_log::test]
    fn test_analysis_shared_bindings_1() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);

        let mut ms = resolved_once([stream_1, stream_2]);

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        let index_2 = ms.streams.stream_index(&binding_2.stream);
        expected.shared(&binding_2, index_2);

        assert_eq!(analysis, expected);
    }

    /// Only the foreign binding is reported even when it sits between same-stream bindings.
    #[test_log::test]
    fn test_analysis_shared_bindings_2() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = resolved_once([stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        let index_2 = ms.streams.stream_index(&binding_2.stream);
        expected.shared(&binding_2, index_2);

        assert_eq!(analysis, expected);
    }

    /// Bindings that all belong to the target stream yield an empty analysis.
    #[test_log::test]
    fn test_analysis_no_shared() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);

        let mut ms = resolved_once([stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        assert_eq!(analysis, SharedBindingAnalysis::default());
    }

    /// Resolving with a foreign binding updates the target's sync state and cursor only.
    #[test_log::test]
    fn test_state() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = resolved_once([stream_1, stream_2]);

        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);

        let stream1 = ms.streams.get_mut(&stream_1);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Creates a binding that belongs to `stream`.
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }

    /// Backend whose streams and events carry no state and whose operations do nothing.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) -> Result<(), ExecutionError> {
            Ok(())
        }
    }
}