cubecl_runtime/stream/
event.rs

1use crate::{
2    config::streaming::StreamingLogLevel,
3    logging::ServerLogger,
4    memory_management::SliceId,
5    server::Binding,
6    stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::stream_id::StreamId;
10use hashbrown::HashMap;
11use std::sync::{Arc, mpsc::SyncSender};
12
/// Trait defining the backend operations for managing streams and events.
///
/// This trait provides the necessary methods for creating streams, flushing them to create events,
/// and waiting on events for synchronization purposes, either from another stream or from the CPU.
pub trait EventStreamBackend: 'static {
    /// The type representing a stream in this backend.
    type Stream: core::fmt::Debug;
    /// The type representing an event in this backend.
    type Event: Send + 'static;

    /// Initializes and returns a new stream.
    fn create_stream(&self) -> Self::Stream;
    /// Flushes the given stream, ensuring all pending operations are submitted, and returns an event
    /// that can be used for synchronization.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes the stream wait for the specified event to complete before proceeding with further operations.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Waits for the given event to complete, synchronizing with the CPU.
    fn wait_event_sync(event: Self::Event);
}
33
/// Manages multiple streams with synchronization logic based on shared bindings.
///
/// This struct handles the creation and alignment of streams to ensure proper synchronization
/// when bindings (e.g., buffers) are shared across different streams.
#[derive(Debug)]
pub struct MultiStream<B: EventStreamBackend> {
    /// The pool holding the stream wrappers, addressed by stream ID.
    streams: StreamPool<EventStreamBackendWrapper<B>>,
    /// The logger used by the server.
    pub logger: Arc<ServerLogger>,
    /// The number of streams, used as the modulus when mapping a stream ID to a pool index.
    max_streams: usize,
    /// Handle to the background thread that drops resources once their event has been reached.
    gc: GcThread<B>,
}
47
/// A wrapper around a backend stream that includes synchronization metadata.
///
/// This includes the stream itself, a map of last synchronized cursors from other streams,
/// and the current cursor position for this stream.
pub(crate) struct StreamWrapper<B: EventStreamBackend> {
    /// The underlying backend stream.
    stream: B::Stream,
    /// The current cursor position, representing the logical progress or version of operations on this stream.
    cursor: u64,
    /// A map tracking the last synchronized cursor positions from other streams, keyed by the
    /// other stream's index in the pool.
    last_synced: HashMap<usize, u64>,
}
60
/// Streams that are synchronized correctly after a [MultiStream::resolve] is called.
///
/// When dropped with a non-empty analysis, the shared slice ids are handed to the GC thread
/// together with an event flushed from the special GC stream.
pub struct ResolvedStreams<'a, B: EventStreamBackend> {
    /// The cursor on the current stream.
    ///
    /// This cursor should be used for new allocations happening on the current stream.
    pub cursor: u64,
    /// The stream pool borrowed from the owning [MultiStream].
    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
    /// The shared bindings found while resolving; drained when this value is dropped.
    analysis: SharedBindingAnalysis,
    /// The GC thread handle borrowed from the owning [MultiStream].
    gc: &'a GcThread<B>,
    /// The current stream where new tasks can be sent safely.
    pub current: StreamId,
}
73
74#[derive(Debug)]
75/// A task to be enqueue on the gc stream that will be clearned after an event is reached.
76pub struct GcTask<B: EventStreamBackend> {
77    to_drop: Box<dyn Any + Send + 'static>,
78    /// The event to sync making sure the bindings in the batch are ready to be reused by other streams.
79    event: B::Event,
80}
81
82impl<B: EventStreamBackend> GcTask<B> {
83    /// Creates a new task that will be clearned when the event is reached.
84    pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
85        Self {
86            to_drop: Box::new(to_drop),
87            event,
88        }
89    }
90}
91
92#[derive(Debug)]
93struct EventStreamBackendWrapper<B: EventStreamBackend> {
94    backend: B,
95}
96
97impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
98    type Stream = StreamWrapper<B>;
99
100    fn create(&mut self) -> Self::Stream {
101        StreamWrapper {
102            stream: self.backend.create_stream(),
103            cursor: 0,
104            last_synced: Default::default(),
105        }
106    }
107}
108
109#[derive(Debug)]
110struct GcThread<B: EventStreamBackend> {
111    sender: SyncSender<GcTask<B>>,
112}
113
114impl<B: EventStreamBackend> GcThread<B> {
115    fn new() -> GcThread<B> {
116        let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
117
118        std::thread::spawn(move || {
119            while let Ok(event) = recv.recv() {
120                B::wait_event_sync(event.event);
121                core::mem::drop(event.to_drop);
122            }
123        });
124
125        GcThread { sender }
126    }
127    fn register(&self, task: GcTask<B>) {
128        self.sender.send(task).unwrap()
129    }
130}
131
132fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
133    stream_id.value as usize % max_streams
134}
135
136impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
137    /// Get the stream associated to the given [stream_id](StreamId).
138    pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
139        let stream = self.streams.get_mut(stream_id);
140        &mut stream.stream
141    }
142
143    /// Get the stream associated to the [current stream_id](StreamId).
144    pub fn current(&mut self) -> &mut B::Stream {
145        let stream = self.streams.get_mut(&self.current);
146        &mut stream.stream
147    }
148
149    /// Enqueue a task to be cleaned.
150    pub fn gc(&mut self, gc: GcTask<B>) {
151        self.gc.sender.send(gc).unwrap();
152    }
153}
154
155impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
156    fn drop(&mut self) {
157        if self.analysis.slices.is_empty() {
158            return;
159        }
160
161        let stream = self.streams.get_mut(&self.current);
162        let event_origin = B::flush(&mut stream.stream);
163
164        let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
165        B::wait_event(stream_gc, event_origin);
166        let event = B::flush(stream_gc);
167
168        let mut ids = Vec::new();
169        self.analysis
170            .slices
171            .drain()
172            .for_each(|item| ids.extend(item.1));
173
174        self.gc.register(GcTask::new(ids, event));
175    }
176}
177
178impl<B: EventStreamBackend> MultiStream<B> {
179    /// Creates an empty multi-stream.
180    pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
181        let wrapper = EventStreamBackendWrapper { backend };
182        Self {
183            streams: StreamPool::new(wrapper, max_streams, 1),
184            logger,
185            max_streams: max_streams as usize,
186            gc: GcThread::new(),
187        }
188    }
189
190    /// Enqueue a task to be cleaned.
191    pub fn gc(&mut self, gc: GcTask<B>) {
192        self.gc.sender.send(gc).unwrap();
193    }
194
195    /// Resolves and returns a mutable reference to the stream for the given ID, performing any necessary
196    /// alignment based on the provided bindings.
197    ///
198    /// This method ensures that the stream is synchronized with any shared bindings from other streams
199    /// before returning the stream reference.
200    pub fn resolve<'a>(
201        &mut self,
202        stream_id: StreamId,
203        bindings: impl Iterator<Item = &'a Binding>,
204    ) -> ResolvedStreams<'_, B> {
205        let analysis = self.align_streams(stream_id, bindings);
206
207        let stream = self.streams.get_mut(&stream_id);
208        stream.cursor += 1;
209
210        ResolvedStreams {
211            cursor: stream.cursor,
212            streams: &mut self.streams,
213            current: stream_id,
214            analysis,
215            gc: &self.gc,
216        }
217    }
218
219    /// Aligns the target stream with other streams based on shared bindings.
220    ///
221    /// This initializes the stream if it doesn't exist, analyzes which originating streams need flushing
222    /// for synchronization, flushes them, and waits on the events in the target stream.
223    fn align_streams<'a>(
224        &mut self,
225        stream_id: StreamId,
226        bindings: impl Iterator<Item = &'a Binding>,
227    ) -> SharedBindingAnalysis {
228        let analysis = self.update_shared_bindings(stream_id, bindings);
229
230        self.apply_analysis(stream_id, analysis)
231    }
232
233    /// Update and analyzes the bindings to determine which streams need alignment (flushing and waiting).
234    ///
235    /// This checks for shared bindings from other streams and determines if synchronization is needed
236    /// based on cursor positions.
237    pub(crate) fn update_shared_bindings<'a>(
238        &mut self,
239        stream_id: StreamId,
240        bindings: impl Iterator<Item = &'a Binding>,
241    ) -> SharedBindingAnalysis {
242        let mut analysis = SharedBindingAnalysis::default();
243        let current = self.streams.get_mut(&stream_id);
244
245        for binding in bindings {
246            if stream_id != binding.stream {
247                let index = stream_index(&binding.stream, self.max_streams);
248
249                if let Some(last_synced) = current.last_synced.get(&index) {
250                    if *last_synced < binding.cursor {
251                        self.logger.log_streaming(
252                            |level| matches!(level, StreamingLogLevel::Full),
253                            || {
254                                format!(
255                                    "Binding on {} is shared on {} since it's not sync {} < {}",
256                                    binding.stream, stream_id, last_synced, binding.cursor
257                                )
258                            },
259                        );
260                        analysis.shared(binding, index);
261                    }
262                } else {
263                    self.logger.log_streaming(
264                        |level| matches!(level, StreamingLogLevel::Full),
265                        || {
266                            format!(
267                                "Binding on {} is shared on {} since it was never synced.",
268                                binding.stream, stream_id,
269                            )
270                        },
271                    );
272                    analysis.shared(binding, index);
273                }
274            }
275        }
276
277        analysis
278    }
279
280    pub(crate) fn apply_analysis(
281        &mut self,
282        stream_id: StreamId,
283        analysis: SharedBindingAnalysis,
284    ) -> SharedBindingAnalysis {
285        if analysis.slices.is_empty() {
286            return analysis;
287        }
288
289        let mut events = Vec::with_capacity(analysis.slices.len());
290
291        unsafe {
292            for origin in analysis.slices.keys() {
293                let stream = self.streams.get_mut_index(*origin);
294                let event = B::flush(&mut stream.stream);
295
296                events.push(((origin, stream.cursor), event));
297            }
298        }
299
300        let stream = self.streams.get_mut(&stream_id);
301
302        for ((stream_origin, cursor_origin), event) in events {
303            stream.last_synced.insert(*stream_origin, cursor_origin);
304
305            self.logger.log_streaming(
306                |level| !matches!(level, StreamingLogLevel::Disabled),
307                || format!("Waiting on {stream_origin} from {stream_id}",),
308            );
309
310            B::wait_event(&mut stream.stream, event);
311        }
312
313        analysis
314    }
315}
316
317impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
318    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
319        f.debug_struct("StreamWrapper")
320            .field("stream", &self.stream)
321            .field("cursor", &self.cursor)
322            .field("last_synced", &self.last_synced)
323            .finish()
324    }
325}
326
327#[derive(Default, Debug, PartialEq, Eq)]
328pub(crate) struct SharedBindingAnalysis {
329    slices: HashMap<usize, Vec<SliceId>>,
330}
331
332impl SharedBindingAnalysis {
333    fn shared(&mut self, binding: &Binding, index: usize) {
334        match self.slices.get_mut(&index) {
335            Some(bindings) => bindings.push(*binding.memory.id()),
336            None => {
337                self.slices.insert(index, vec![*binding.memory.id()]);
338            }
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};

    /// Number of streams used by every pool created in these tests.
    const MAX_STREAMS: u8 = 4;

    #[test]
    fn test_analysis_shared_bindings_1() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);

        let mut ms = multi_stream([stream_1, stream_2]);

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    #[test]
    fn test_analysis_shared_bindings_2() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream([stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    #[test]
    fn test_analysis_no_shared() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream([stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let expected = SharedBindingAnalysis::default();

        assert_eq!(analysis, expected);
    }

    #[test]
    fn test_state() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream([stream_1, stream_2]);

        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        // Stream 1 was resolved twice and synchronized once with stream 2 (at cursor 1).
        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        // Stream 2 was resolved once and never had to wait on another stream.
        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Builds a multi-stream over [`TestBackend`] and resolves each given stream once without
    /// any binding, so every stream starts with a cursor of 1.
    fn multi_stream(streams: [StreamId; 2]) -> MultiStream<TestBackend> {
        let logger = Arc::new(ServerLogger::default());
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);

        for stream_id in streams {
            ms.resolve(stream_id, [].into_iter());
        }

        ms
    }

    /// Creates a binding owned by the given stream.
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }

    /// Backend whose streams and events carry no state; every operation is a no-op.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) {}
    }
}