cubecl_runtime/stream/
event.rs

1use crate::{
2    config::streaming::StreamingLogLevel,
3    logging::ServerLogger,
4    memory_management::SliceId,
5    server::Binding,
6    stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::stream_id::StreamId;
10use hashbrown::HashMap;
11use std::sync::{Arc, mpsc::SyncSender};
12
/// Backend hooks used to drive streams and the events that order work between them.
///
/// An implementation knows how to create a stream, how to flush a stream into an event, and
/// how to wait on an event, either from another stream or from the CPU.
pub trait EventStreamBackend: 'static {
    /// The backend's stream type.
    type Stream: core::fmt::Debug;
    /// The backend's event type.
    ///
    /// Events must be [`Send`] because they are handed over to the garbage-collection thread.
    type Event: Send + 'static;

    /// Creates and returns a new backend stream.
    fn create_stream(&self) -> Self::Stream;
    /// Flushes `stream`, ensuring all pending operations are submitted, and returns an event
    /// that can be used for synchronization.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes `stream` wait for `event` to complete before proceeding with further operations.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Waits for `event` to complete, synchronizing the CPU.
    fn wait_event_sync(event: Self::Event);
}
33
34/// Manages multiple streams with synchronization logic based on shared bindings.
35///
36/// This struct handles the creation and alignment of streams to ensure proper synchronization
37/// when bindings (e.g., buffers) are shared across different streams.
38#[derive(Debug)]
39pub struct MultiStream<B: EventStreamBackend> {
40    /// The map of stream IDs to their corresponding stream wrappers.
41    streams: StreamPool<EventStreamBackendWrapper<B>>,
42    /// The logger used by the server.
43    pub logger: Arc<ServerLogger>,
44    max_streams: usize,
45    gc: GcThread<B>,
46}
47
48/// A wrapper around a backend stream that includes synchronization metadata.
49///
50/// This includes the stream itself, a map of last synchronized cursors from other streams,
51/// and the current cursor position for this stream.
52pub(crate) struct StreamWrapper<B: EventStreamBackend> {
53    /// The underlying backend stream.
54    stream: B::Stream,
55    /// The current cursor position, representing the logical progress or version of operations on this stream.
56    cursor: u64,
57    /// A map tracking the last synchronized cursor positions from other streams.
58    last_synced: HashMap<usize, u64>,
59}
60
61/// Streams that are synchronized correctly after a [MultiStream::resolve] is called.
62pub struct ResolvedStreams<'a, B: EventStreamBackend> {
63    /// The cursor on the current stream.
64    ///
65    /// This cursor should be use for new allocations happening on the current stream.
66    pub cursor: u64,
67    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
68    analysis: SharedBindingAnalysis,
69    gc: &'a GcThread<B>,
70    /// The current stream where new tasks can be sent safely.
71    pub current: StreamId,
72}
73
74#[derive(Debug)]
75/// A task to be enqueue on the gc stream that will be clearned after an event is reached.
76pub struct GcTask<B: EventStreamBackend> {
77    to_drop: Box<dyn Any + Send + 'static>,
78    /// The event to sync making sure the bindings in the batch are ready to be reused by other streams.
79    event: B::Event,
80}
81
82impl<B: EventStreamBackend> GcTask<B> {
83    /// Creates a new task that will be clearned when the event is reached.
84    pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
85        Self {
86            to_drop: Box::new(to_drop),
87            event,
88        }
89    }
90}
91
92#[derive(Debug)]
93struct EventStreamBackendWrapper<B: EventStreamBackend> {
94    backend: B,
95}
96
97impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
98    type Stream = StreamWrapper<B>;
99
100    fn create(&mut self) -> Self::Stream {
101        StreamWrapper {
102            stream: self.backend.create_stream(),
103            cursor: 0,
104            last_synced: Default::default(),
105        }
106    }
107}
108
109#[derive(Debug)]
110struct GcThread<B: EventStreamBackend> {
111    sender: SyncSender<GcTask<B>>,
112}
113
114impl<B: EventStreamBackend> GcThread<B> {
115    fn new() -> GcThread<B> {
116        let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
117
118        std::thread::spawn(move || {
119            while let Ok(event) = recv.recv() {
120                B::wait_event_sync(event.event);
121                core::mem::drop(event.to_drop);
122            }
123        });
124
125        GcThread { sender }
126    }
127    fn register(&self, task: GcTask<B>) {
128        self.sender.send(task).unwrap()
129    }
130}
131
132fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
133    stream_id.value as usize % max_streams
134}
135
136impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
137    /// Get the stream associated to the given [stream_id](StreamId).
138    pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
139        let stream = self.streams.get_mut(stream_id);
140        &mut stream.stream
141    }
142
143    /// Get the stream associated to the [current stream_id](StreamId).
144    pub fn current(&mut self) -> &mut B::Stream {
145        let stream = self.streams.get_mut(&self.current);
146        &mut stream.stream
147    }
148}
149
150impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
151    fn drop(&mut self) {
152        if self.analysis.slices.is_empty() {
153            return;
154        }
155
156        let stream = self.streams.get_mut(&self.current);
157        let event_origin = B::flush(&mut stream.stream);
158
159        let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
160        B::wait_event(stream_gc, event_origin);
161        let event = B::flush(stream_gc);
162
163        let mut ids = Vec::new();
164        self.analysis
165            .slices
166            .drain()
167            .for_each(|item| ids.extend(item.1));
168
169        self.gc.register(GcTask::new(ids, event));
170    }
171}
172
173impl<B: EventStreamBackend> MultiStream<B> {
174    /// Creates an empty multi-stream.
175    pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
176        let wrapper = EventStreamBackendWrapper { backend };
177        Self {
178            streams: StreamPool::new(wrapper, max_streams, 1),
179            logger,
180            max_streams: max_streams as usize,
181            gc: GcThread::new(),
182        }
183    }
184
185    /// Enqueue a task to be cleaned.
186    pub fn gc(&mut self, gc: GcTask<B>) {
187        self.gc.sender.send(gc).unwrap();
188    }
189
190    /// Resolves and returns a mutable reference to the stream for the given ID, performing any necessary
191    /// alignment based on the provided bindings.
192    ///
193    /// This method ensures that the stream is synchronized with any shared bindings from other streams
194    /// before returning the stream reference.
195    pub fn resolve<'a>(
196        &mut self,
197        stream_id: StreamId,
198        bindings: impl Iterator<Item = &'a Binding>,
199    ) -> ResolvedStreams<'_, B> {
200        let analysis = self.align_streams(stream_id, bindings);
201
202        let stream = self.streams.get_mut(&stream_id);
203        stream.cursor += 1;
204
205        ResolvedStreams {
206            cursor: stream.cursor,
207            streams: &mut self.streams,
208            current: stream_id,
209            analysis,
210            gc: &self.gc,
211        }
212    }
213
214    /// Aligns the target stream with other streams based on shared bindings.
215    ///
216    /// This initializes the stream if it doesn't exist, analyzes which originating streams need flushing
217    /// for synchronization, flushes them, and waits on the events in the target stream.
218    fn align_streams<'a>(
219        &mut self,
220        stream_id: StreamId,
221        bindings: impl Iterator<Item = &'a Binding>,
222    ) -> SharedBindingAnalysis {
223        let analysis = self.update_shared_bindings(stream_id, bindings);
224
225        self.apply_analysis(stream_id, analysis)
226    }
227
228    /// Update and analyzes the bindings to determine which streams need alignment (flushing and waiting).
229    ///
230    /// This checks for shared bindings from other streams and determines if synchronization is needed
231    /// based on cursor positions.
232    pub(crate) fn update_shared_bindings<'a>(
233        &mut self,
234        stream_id: StreamId,
235        bindings: impl Iterator<Item = &'a Binding>,
236    ) -> SharedBindingAnalysis {
237        let mut analysis = SharedBindingAnalysis::default();
238        let current = self.streams.get_mut(&stream_id);
239
240        for binding in bindings {
241            if stream_id != binding.stream {
242                let index = stream_index(&binding.stream, self.max_streams);
243
244                if let Some(last_synced) = current.last_synced.get(&index) {
245                    if *last_synced < binding.cursor {
246                        self.logger.log_streaming(
247                            |level| matches!(level, StreamingLogLevel::Full),
248                            || {
249                                format!(
250                                    "Binding on {} is shared on {} since it's not sync {} < {}",
251                                    binding.stream, stream_id, last_synced, binding.cursor
252                                )
253                            },
254                        );
255                        analysis.shared(binding, index);
256                    }
257                } else {
258                    self.logger.log_streaming(
259                        |level| matches!(level, StreamingLogLevel::Full),
260                        || {
261                            format!(
262                                "Binding on {} is shared on {} since it was never synced.",
263                                binding.stream, stream_id,
264                            )
265                        },
266                    );
267                    analysis.shared(binding, index);
268                }
269            }
270        }
271
272        analysis
273    }
274
275    pub(crate) fn apply_analysis(
276        &mut self,
277        stream_id: StreamId,
278        analysis: SharedBindingAnalysis,
279    ) -> SharedBindingAnalysis {
280        if analysis.slices.is_empty() {
281            return analysis;
282        }
283
284        let mut events = Vec::with_capacity(analysis.slices.len());
285
286        unsafe {
287            for origin in analysis.slices.keys() {
288                let stream = self.streams.get_mut_index(*origin);
289                let event = B::flush(&mut stream.stream);
290
291                events.push(((origin, stream.cursor), event));
292            }
293        }
294
295        let stream = self.streams.get_mut(&stream_id);
296
297        for ((stream_origin, cursor_origin), event) in events {
298            stream.last_synced.insert(*stream_origin, cursor_origin);
299
300            self.logger.log_streaming(
301                |level| !matches!(level, StreamingLogLevel::Disabled),
302                || format!("Waiting on {stream_origin} from {stream_id}",),
303            );
304
305            B::wait_event(&mut stream.stream, event);
306        }
307
308        analysis
309    }
310}
311
312impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
313    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
314        f.debug_struct("StreamWrapper")
315            .field("stream", &self.stream)
316            .field("cursor", &self.cursor)
317            .field("last_synced", &self.last_synced)
318            .finish()
319    }
320}
321
322#[derive(Default, Debug, PartialEq, Eq)]
323pub(crate) struct SharedBindingAnalysis {
324    slices: HashMap<usize, Vec<SliceId>>,
325}
326
327impl SharedBindingAnalysis {
328    fn shared(&mut self, binding: &Binding, index: usize) {
329        match self.slices.get_mut(&index) {
330            Some(bindings) => bindings.push(*binding.memory.id()),
331            None => {
332                self.slices.insert(index, vec![*binding.memory.id()]);
333            }
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};

    const MAX_STREAMS: u8 = 4;

    /// A binding from another stream must be reported; one from the same stream must not.
    #[test]
    fn test_analysis_shared_bindings_1() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    /// Same as above, with the foreign binding surrounded by bindings of the target stream.
    #[test]
    fn test_analysis_shared_bindings_2() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    /// Bindings that all belong to the target stream yield an empty analysis.
    #[test]
    fn test_analysis_no_shared() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let expected = SharedBindingAnalysis::default();

        assert_eq!(analysis, expected);
    }

    /// Resolving with a foreign binding updates the target stream's cursor and sync map only.
    #[test]
    fn test_state() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Builds a multi-stream over [TestBackend] with [MAX_STREAMS] streams, then resolves each
    /// of the given streams once without bindings so that each one has a cursor of 1.
    fn multi_stream(streams: &[StreamId]) -> MultiStream<TestBackend> {
        let logger = Arc::new(ServerLogger::default());
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);

        for stream in streams {
            ms.resolve(*stream, [].into_iter());
        }

        ms
    }

    /// Creates a binding owned by the given stream.
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }

    /// Backend whose streams and events are empty markers and whose waits return at once.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) {}
    }
}