// cubecl_runtime/stream/event.rs

1use crate::{
2    config::streaming::StreamingLogLevel,
3    logging::ServerLogger,
4    memory_management::ManagedMemoryId,
5    server::{Binding, ServerError},
6    stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::{backtrace::BackTrace, stream_id::StreamId};
10use hashbrown::HashMap;
11use std::{
12    boxed::Box,
13    format,
14    sync::{Arc, mpsc::SyncSender},
15    vec::Vec,
16};
17
18/// Trait defining the backend operations for managing streams and events.
19///
20/// This trait provides the necessary methods for initializing streams, flushing them to create events,
21/// and waiting on events for synchronization purposes.
22pub trait EventStreamBackend: 'static {
23    /// The type representing a stream in this backend.
24    type Stream: core::fmt::Debug;
25    /// The type representing an event in this backend.
26    type Event: Send + 'static;
27
28    /// Initializes and returns a new stream associated with the given stream ID.
29    fn create_stream(&self) -> Self::Stream;
30    /// Returns the cursor of the given handle on the given stream.
31    fn handle_cursor(stream: &Self::Stream, handle: &Binding) -> u64;
32    /// Returns whether the stream can access new tasks.
33    fn is_healthy(stream: &Self::Stream) -> bool;
34
35    /// Flushes the given stream, ensuring all pending operations are submitted, and returns an event
36    /// that can be used for synchronization.
37    fn flush(stream: &mut Self::Stream) -> Self::Event;
38    /// Makes the stream wait for the specified event to complete before proceeding with further operations.
39    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
40    /// Wait for the given event synching the CPU.
41    fn wait_event_sync(event: Self::Event) -> Result<(), ServerError>;
42}
43
44/// Manages multiple streams with synchronization logic based on shared bindings.
45///
46/// This struct handles the creation and alignment of streams to ensure proper synchronization
47/// when bindings (e.g., buffers) are shared across different streams.
48#[derive(Debug)]
49pub struct MultiStream<B: EventStreamBackend> {
50    /// The map of stream IDs to their corresponding stream wrappers.
51    streams: StreamPool<EventStreamBackendWrapper<B>>,
52    /// The logger used by the server.
53    pub logger: Arc<ServerLogger>,
54    max_streams: usize,
55    gc: GcThread<B>,
56    shared_bindings_pool: Vec<(ManagedMemoryId, StreamId, u64)>,
57}
58
59/// A wrapper around a backend stream that includes synchronization metadata.
60///
61/// This includes the stream itself, a map of last synchronized cursors from other streams,
62/// and the current cursor position for this stream.
63pub(crate) struct StreamWrapper<B: EventStreamBackend> {
64    /// The underlying backend stream.
65    stream: B::Stream,
66    /// The current cursor position, representing the logical progress or version of operations on this stream.
67    cursor: u64,
68    /// A map tracking the last synchronized cursor positions from other streams.
69    last_synced: HashMap<usize, u64>,
70}
71
72/// Streams that are synchronized correctly after a [`MultiStream::resolve`] is called.
73pub struct ResolvedStreams<'a, B: EventStreamBackend> {
74    /// The cursor on the current stream.
75    ///
76    /// This cursor should be use for new allocations happening on the current stream.
77    pub cursor: u64,
78    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
79    analysis: SharedBindingAnalysis,
80    gc: &'a GcThread<B>,
81    /// The current stream where new tasks can be sent safely.
82    pub current: StreamId,
83}
84
85#[derive(Debug)]
86/// A task to be enqueue on the gc stream that will be clearned after an event is reached.
87pub struct GcTask<B: EventStreamBackend> {
88    to_drop: Box<dyn Any + Send + 'static>,
89    /// The event to sync making sure the bindings in the batch are ready to be reused by other streams.
90    event: B::Event,
91}
92
93impl<B: EventStreamBackend> GcTask<B> {
94    /// Creates a new task that will be clearned when the event is reached.
95    pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
96        Self {
97            to_drop: Box::new(to_drop),
98            event,
99        }
100    }
101}
102
103#[derive(Debug)]
104struct EventStreamBackendWrapper<B: EventStreamBackend> {
105    backend: B,
106}
107
108impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
109    type Stream = StreamWrapper<B>;
110
111    fn create(&mut self) -> Self::Stream {
112        StreamWrapper {
113            stream: self.backend.create_stream(),
114            cursor: 0,
115            last_synced: Default::default(),
116        }
117    }
118}
119
120#[derive(Debug)]
121struct GcThread<B: EventStreamBackend> {
122    sender: SyncSender<GcTask<B>>,
123}
124
125impl<B: EventStreamBackend> GcThread<B> {
126    fn new() -> GcThread<B> {
127        let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
128
129        std::thread::spawn(move || {
130            while let Ok(event) = recv.recv() {
131                B::wait_event_sync(event.event).unwrap();
132                core::mem::drop(event.to_drop);
133            }
134        });
135
136        GcThread { sender }
137    }
138    fn register(&self, task: GcTask<B>) {
139        self.sender.send(task).unwrap()
140    }
141}
142
143fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
144    stream_id.value as usize % max_streams
145}
146
147impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
148    /// Get the stream associated to the given [`stream_id`](StreamId).
149    pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
150        let stream = self.streams.get_mut(stream_id);
151        &mut stream.stream
152    }
153
154    /// Get the stream associated to the [current `stream_id`](StreamId).
155    pub fn current(&mut self) -> &mut B::Stream {
156        let stream = self.streams.get_mut(&self.current);
157        &mut stream.stream
158    }
159
160    /// Enqueue a task to be cleaned.
161    pub fn gc(&mut self, gc: GcTask<B>) {
162        self.gc.sender.send(gc).unwrap();
163    }
164}
165
166impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
167    fn drop(&mut self) {
168        if self.analysis.slices.is_empty() {
169            return;
170        }
171
172        let stream = self.streams.get_mut(&self.current);
173        let event_origin = B::flush(&mut stream.stream);
174
175        let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
176        B::wait_event(stream_gc, event_origin);
177        let event = B::flush(stream_gc);
178
179        let mut ids = Vec::new();
180        self.analysis
181            .slices
182            .drain()
183            .for_each(|item| ids.extend(item.1));
184
185        self.gc.register(GcTask::new(ids, event));
186    }
187}
188
189impl<B: EventStreamBackend> MultiStream<B> {
190    /// Creates an empty multi-stream.
191    pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
192        let wrapper = EventStreamBackendWrapper { backend };
193        Self {
194            streams: StreamPool::new(wrapper, max_streams, 1),
195            logger,
196            max_streams: max_streams as usize,
197            gc: GcThread::new(),
198            shared_bindings_pool: Vec::new(),
199        }
200    }
201
202    /// Enqueue a task to be cleaned.
203    pub fn gc(&mut self, gc: GcTask<B>) {
204        self.gc.sender.send(gc).unwrap();
205    }
206
207    /// Resolves and returns a mutable reference to the stream for the given ID, performing any necessary
208    /// alignment based on the provided bindings.
209    ///
210    /// This method ensures that the stream is synchronized with any shared bindings from other streams
211    /// before returning the stream reference.
212    pub fn resolve<'a>(
213        &mut self,
214        stream_id: StreamId,
215        handles: impl Iterator<Item = &'a Binding>,
216        enforce_healthy: bool,
217    ) -> Result<ResolvedStreams<'_, B>, ServerError> {
218        let analysis = self.align_streams(stream_id, handles);
219
220        let stream = self.streams.get_mut(&stream_id);
221        stream.cursor += 1;
222
223        if enforce_healthy && !B::is_healthy(&stream.stream) {
224            return Err(ServerError::Generic {
225                reason: "Can't resolve the stream since it is currently in an error state".into(),
226                backtrace: BackTrace::capture(),
227            });
228        }
229
230        Ok(ResolvedStreams {
231            cursor: stream.cursor,
232            streams: &mut self.streams,
233            current: stream_id,
234            analysis,
235            gc: &self.gc,
236        })
237    }
238
239    /// Aligns the target stream with other streams based on shared bindings.
240    ///
241    /// This initializes the stream if it doesn't exist, analyzes which originating streams need flushing
242    /// for synchronization, flushes them, and waits on the events in the target stream.
243    fn align_streams<'a>(
244        &mut self,
245        stream_id: StreamId,
246        handles: impl Iterator<Item = &'a Binding>,
247    ) -> SharedBindingAnalysis {
248        let analysis = self.update_shared_bindings(stream_id, handles);
249
250        self.apply_analysis(stream_id, analysis)
251    }
252
253    /// Update and analyzes the bindings to determine which streams need alignment (flushing and waiting).
254    ///
255    /// This checks for shared bindings from other streams and determines if synchronization is needed
256    /// based on cursor positions.
257    pub(crate) fn update_shared_bindings<'a>(
258        &mut self,
259        stream_id: StreamId,
260        handles: impl Iterator<Item = &'a Binding>,
261    ) -> SharedBindingAnalysis {
262        // We reset the memory pool for the info.
263        self.shared_bindings_pool.clear();
264
265        for handle in handles {
266            let index = stream_index(&handle.stream, self.max_streams);
267            let stream = unsafe { self.streams.get_mut_index(index) };
268            let cursor_handle = B::handle_cursor(&stream.stream, handle);
269
270            // We only add the info to be consider if the handle stream is different from the current
271            // stream.
272            if handle.stream != stream_id {
273                self.shared_bindings_pool.push((
274                    handle.memory.descriptor().id,
275                    handle.stream,
276                    cursor_handle,
277                ));
278            }
279        }
280
281        let mut analysis = SharedBindingAnalysis::default();
282        let current = self.streams.get_mut(&stream_id);
283
284        for (handle_id, stream, cursor) in self.shared_bindings_pool.iter() {
285            let index = stream_index(stream, self.max_streams);
286
287            if let Some(last_synced) = current.last_synced.get(&index) {
288                if last_synced < cursor {
289                    self.logger.log_streaming(
290                        |level| matches!(level, StreamingLogLevel::Full),
291                        || {
292                            format!(
293                                "Binding on {} is shared on {} since it's not sync {} < {}",
294                                stream, stream_id, last_synced, cursor
295                            )
296                        },
297                    );
298                    analysis.shared(*handle_id, index);
299                }
300            } else {
301                self.logger.log_streaming(
302                    |level| matches!(level, StreamingLogLevel::Full),
303                    || {
304                        format!(
305                            "Binding on {} is shared on {} since it was never synced.",
306                            stream, stream_id,
307                        )
308                    },
309                );
310                analysis.shared(*handle_id, index);
311            }
312        }
313
314        analysis
315    }
316
317    pub(crate) fn apply_analysis(
318        &mut self,
319        stream_id: StreamId,
320        analysis: SharedBindingAnalysis,
321    ) -> SharedBindingAnalysis {
322        if analysis.slices.is_empty() {
323            return analysis;
324        }
325
326        let mut events = Vec::with_capacity(analysis.slices.len());
327
328        unsafe {
329            for origin in analysis.slices.keys() {
330                let stream = self.streams.get_mut_index(*origin);
331                let event = B::flush(&mut stream.stream);
332
333                events.push(((origin, stream.cursor), event));
334            }
335        }
336
337        let stream = self.streams.get_mut(&stream_id);
338
339        for ((stream_origin, cursor_origin), event) in events {
340            stream.last_synced.insert(*stream_origin, cursor_origin);
341
342            self.logger.log_streaming(
343                |level| !matches!(level, StreamingLogLevel::Disabled),
344                || format!("Waiting on {stream_origin} from {stream_id}",),
345            );
346
347            B::wait_event(&mut stream.stream, event);
348        }
349
350        analysis
351    }
352}
353
354impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
355    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
356        f.debug_struct("StreamWrapper")
357            .field("stream", &self.stream)
358            .field("cursor", &self.cursor)
359            .field("last_synced", &self.last_synced)
360            .finish()
361    }
362}
363
364#[derive(Default, Debug, PartialEq, Eq)]
365pub(crate) struct SharedBindingAnalysis {
366    slices: HashMap<usize, Vec<ManagedMemoryId>>,
367}
368
369impl SharedBindingAnalysis {
370    fn shared(&mut self, id: ManagedMemoryId, index: usize) {
371        match self.slices.get_mut(&index) {
372            Some(bindings) => bindings.push(id),
373            None => {
374                self.slices.insert(index, alloc::vec![id]);
375            }
376        }
377    }
378}
379
#[cfg(test)]
mod tests {
    use crate::server::Handle;

    use super::*;

    /// Number of stream slots used by every test's `MultiStream`.
    const MAX_STREAMS: u8 = 4;

    #[test_log::test]
    fn test_analysis_shared_bindings_1() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_2);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(
            binding_2.memory.descriptor().id,
            ms.streams.stream_index(&binding_2.stream),
        );

        assert_eq!(analysis, expected);
    }

    #[test_log::test]
    fn test_analysis_shared_bindings_2() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_2);
        let binding_3 = handle(stream_1);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(
            binding_2.memory.descriptor().id,
            ms.streams.stream_index(&binding_2.stream),
        );

        assert_eq!(analysis, expected);
    }

    #[test_log::test]
    fn test_analysis_no_shared() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_1);
        let binding_3 = handle(stream_1);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let expected = SharedBindingAnalysis::default();

        assert_eq!(analysis, expected);
    }

    #[test_log::test]
    fn test_analysis_groups_bindings_per_origin() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_2);
        let binding_2 = handle(stream_2);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        // Both bindings belong to the same origin, so they end up in the same slot list.
        let index_2 = ms.streams.stream_index(&stream_2);
        let mut expected = SharedBindingAnalysis::default();
        expected.shared(binding_1.memory.descriptor().id, index_2);
        expected.shared(binding_2.memory.descriptor().id, index_2);

        assert_eq!(analysis, expected);
    }

    #[test_log::test]
    fn test_analysis_already_synced() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_2 = handle(stream_2);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();
        ms.resolve(stream_1, [&binding_2].into_iter(), false)
            .unwrap();

        // `stream_1` already waited on `stream_2` past the binding's cursor: nothing to share.
        let analysis = ms.update_shared_bindings(stream_1, [&binding_2].into_iter());

        assert_eq!(analysis, SharedBindingAnalysis::default());
    }

    #[test_log::test]
    fn test_state() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_2);
        let binding_3 = handle(stream_1);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();

        ms.resolve(
            stream_1,
            [&binding_1, &binding_2, &binding_3].into_iter(),
            false,
        )
        .unwrap();

        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Creates a binding of 10 bytes owned by `stream`.
    fn handle(stream: StreamId) -> Binding {
        Handle::new(stream, 10).binding()
    }

    /// Backend whose operations are no-ops; every handle reports cursor 0 and streams are always
    /// healthy.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) -> Result<(), ServerError> {
            Ok(())
        }

        fn handle_cursor(_stream: &Self::Stream, _handle: &Binding) -> u64 {
            0
        }

        fn is_healthy(_stream: &Self::Stream) -> bool {
            true
        }
    }
}