1use crate::{
2 config::streaming::StreamingLogLevel,
3 logging::ServerLogger,
4 memory_management::SliceId,
5 server::{Binding, ExecutionError},
6 stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::stream_id::StreamId;
10use hashbrown::HashMap;
11use std::{
12 boxed::Box,
13 format,
14 sync::{Arc, mpsc::SyncSender},
15 vec,
16 vec::Vec,
17};
18
/// Backend abstraction used by [`MultiStream`] to create streams and to order work between
/// them with events.
pub trait EventStreamBackend: 'static {
    /// Stream handle produced by [`EventStreamBackend::create_stream`].
    type Stream: core::fmt::Debug;
    /// Event produced by [`EventStreamBackend::flush`].
    ///
    /// It must be `Send` because events are moved to the garbage-collection thread inside a
    /// [`GcTask`].
    type Event: Send + 'static;

    /// Creates a new stream.
    fn create_stream(&self) -> Self::Stream;
    /// Flushes `stream` and returns an event for it.
    ///
    /// NOTE(review): presumably the event marks the completion of the work submitted to
    /// `stream` so far — confirm with the backend implementations.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes `stream` wait on `event`, consuming the event.
    ///
    /// NOTE(review): presumably this only enqueues the wait and does not block the calling
    /// thread — confirm with the backend implementations.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Waits on `event` without a stream, reporting a failure as an [`ExecutionError`].
    ///
    /// The garbage-collection thread calls this before dropping the data attached to a
    /// [`GcTask`].
    fn wait_event_sync(event: Self::Event) -> Result<(), ExecutionError>;
}
39
/// Manages a pool of backend streams and keeps them synchronized when a binding created on
/// one stream is used on another.
#[derive(Debug)]
pub struct MultiStream<B: EventStreamBackend> {
    /// Pool holding one [`StreamWrapper`] per stream.
    streams: StreamPool<EventStreamBackendWrapper<B>>,
    /// Logger used to report streaming decisions.
    pub logger: Arc<ServerLogger>,
    /// Value given to [`MultiStream::new`]; used as the modulus when mapping a stream id to a
    /// stream index.
    max_streams: usize,
    /// Handle to the background thread that drops data once its event completed.
    gc: GcThread<B>,
}
53
/// A backend stream together with the bookkeeping needed to synchronize it with other streams.
pub(crate) struct StreamWrapper<B: EventStreamBackend> {
    /// The backend stream.
    stream: B::Stream,
    /// Counter incremented by one every time [`MultiStream::resolve`] is called for this stream.
    cursor: u64,
    /// For each origin stream index, the cursor that origin had when this stream last waited
    /// on it.
    last_synced: HashMap<usize, u64>,
}
66
/// Streams returned by [`MultiStream::resolve`], already aligned for the current operation.
///
/// When the value is dropped and shared bindings were detected, a [`GcTask`] is registered on
/// the garbage-collection thread (see the `Drop` implementation).
pub struct ResolvedStreams<'a, B: EventStreamBackend> {
    /// Cursor of the current stream, after it was incremented by [`MultiStream::resolve`].
    pub cursor: u64,
    /// The stream pool borrowed from the [`MultiStream`].
    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
    /// Bindings that were found to be shared with other streams during resolution.
    analysis: SharedBindingAnalysis,
    /// Handle to the garbage-collection thread.
    gc: &'a GcThread<B>,
    /// Id of the stream the operation is resolved for.
    pub current: StreamId,
}
79
/// A unit of work for the garbage-collection thread: data that is dropped only after an event
/// has been waited on.
#[derive(Debug)]
pub struct GcTask<B: EventStreamBackend> {
    /// Type-erased value dropped by the garbage-collection thread after the wait on `event`.
    to_drop: Box<dyn Any + Send + 'static>,
    /// Event the garbage-collection thread waits on before dropping `to_drop`.
    event: B::Event,
}
87
88impl<B: EventStreamBackend> GcTask<B> {
89 pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
91 Self {
92 to_drop: Box::new(to_drop),
93 event,
94 }
95 }
96}
97
/// Adapter that lets an [`EventStreamBackend`] be used as the [`StreamFactory`] of a
/// [`StreamPool`].
#[derive(Debug)]
struct EventStreamBackendWrapper<B: EventStreamBackend> {
    /// Backend used to create the underlying streams.
    backend: B,
}
102
103impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
104 type Stream = StreamWrapper<B>;
105
106 fn create(&mut self) -> Self::Stream {
107 StreamWrapper {
108 stream: self.backend.create_stream(),
109 cursor: 0,
110 last_synced: Default::default(),
111 }
112 }
113}
114
/// Handle to the background thread that waits on events and then drops the attached data.
#[derive(Debug)]
struct GcThread<B: EventStreamBackend> {
    /// Sending half of the bounded channel read by the background thread.
    sender: SyncSender<GcTask<B>>,
}
119
120impl<B: EventStreamBackend> GcThread<B> {
121 fn new() -> GcThread<B> {
122 let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
123
124 std::thread::spawn(move || {
125 while let Ok(event) = recv.recv() {
126 B::wait_event_sync(event.event).unwrap();
127 core::mem::drop(event.to_drop);
128 }
129 });
130
131 GcThread { sender }
132 }
133 fn register(&self, task: GcTask<B>) {
134 self.sender.send(task).unwrap()
135 }
136}
137
138fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
139 stream_id.value as usize % max_streams
140}
141
142impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
143 pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
145 let stream = self.streams.get_mut(stream_id);
146 &mut stream.stream
147 }
148
149 pub fn current(&mut self) -> &mut B::Stream {
151 let stream = self.streams.get_mut(&self.current);
152 &mut stream.stream
153 }
154
155 pub fn gc(&mut self, gc: GcTask<B>) {
157 self.gc.sender.send(gc).unwrap();
158 }
159}
160
161impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
162 fn drop(&mut self) {
163 if self.analysis.slices.is_empty() {
164 return;
165 }
166
167 let stream = self.streams.get_mut(&self.current);
168 let event_origin = B::flush(&mut stream.stream);
169
170 let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
171 B::wait_event(stream_gc, event_origin);
172 let event = B::flush(stream_gc);
173
174 let mut ids = Vec::new();
175 self.analysis
176 .slices
177 .drain()
178 .for_each(|item| ids.extend(item.1));
179
180 self.gc.register(GcTask::new(ids, event));
181 }
182}
183
184impl<B: EventStreamBackend> MultiStream<B> {
185 pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
187 let wrapper = EventStreamBackendWrapper { backend };
188 Self {
189 streams: StreamPool::new(wrapper, max_streams, 1),
190 logger,
191 max_streams: max_streams as usize,
192 gc: GcThread::new(),
193 }
194 }
195
196 pub fn gc(&mut self, gc: GcTask<B>) {
198 self.gc.sender.send(gc).unwrap();
199 }
200
201 pub fn resolve<'a>(
207 &mut self,
208 stream_id: StreamId,
209 bindings: impl Iterator<Item = &'a Binding>,
210 ) -> ResolvedStreams<'_, B> {
211 let analysis = self.align_streams(stream_id, bindings);
212
213 let stream = self.streams.get_mut(&stream_id);
214 stream.cursor += 1;
215
216 ResolvedStreams {
217 cursor: stream.cursor,
218 streams: &mut self.streams,
219 current: stream_id,
220 analysis,
221 gc: &self.gc,
222 }
223 }
224
225 fn align_streams<'a>(
230 &mut self,
231 stream_id: StreamId,
232 bindings: impl Iterator<Item = &'a Binding>,
233 ) -> SharedBindingAnalysis {
234 let analysis = self.update_shared_bindings(stream_id, bindings);
235
236 self.apply_analysis(stream_id, analysis)
237 }
238
239 pub(crate) fn update_shared_bindings<'a>(
244 &mut self,
245 stream_id: StreamId,
246 bindings: impl Iterator<Item = &'a Binding>,
247 ) -> SharedBindingAnalysis {
248 let mut analysis = SharedBindingAnalysis::default();
249 let current = self.streams.get_mut(&stream_id);
250
251 for binding in bindings {
252 if stream_id != binding.stream {
253 let index = stream_index(&binding.stream, self.max_streams);
254
255 if let Some(last_synced) = current.last_synced.get(&index) {
256 if *last_synced < binding.cursor {
257 self.logger.log_streaming(
258 |level| matches!(level, StreamingLogLevel::Full),
259 || {
260 format!(
261 "Binding on {} is shared on {} since it's not sync {} < {}",
262 binding.stream, stream_id, last_synced, binding.cursor
263 )
264 },
265 );
266 analysis.shared(binding, index);
267 }
268 } else {
269 self.logger.log_streaming(
270 |level| matches!(level, StreamingLogLevel::Full),
271 || {
272 format!(
273 "Binding on {} is shared on {} since it was never synced.",
274 binding.stream, stream_id,
275 )
276 },
277 );
278 analysis.shared(binding, index);
279 }
280 }
281 }
282
283 analysis
284 }
285
286 pub(crate) fn apply_analysis(
287 &mut self,
288 stream_id: StreamId,
289 analysis: SharedBindingAnalysis,
290 ) -> SharedBindingAnalysis {
291 if analysis.slices.is_empty() {
292 return analysis;
293 }
294
295 let mut events = Vec::with_capacity(analysis.slices.len());
296
297 unsafe {
298 for origin in analysis.slices.keys() {
299 let stream = self.streams.get_mut_index(*origin);
300 let event = B::flush(&mut stream.stream);
301
302 events.push(((origin, stream.cursor), event));
303 }
304 }
305
306 let stream = self.streams.get_mut(&stream_id);
307
308 for ((stream_origin, cursor_origin), event) in events {
309 stream.last_synced.insert(*stream_origin, cursor_origin);
310
311 self.logger.log_streaming(
312 |level| !matches!(level, StreamingLogLevel::Disabled),
313 || format!("Waiting on {stream_origin} from {stream_id}",),
314 );
315
316 B::wait_event(&mut stream.stream, event);
317 }
318
319 analysis
320 }
321}
322
323impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
324 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
325 f.debug_struct("StreamWrapper")
326 .field("stream", &self.stream)
327 .field("cursor", &self.cursor)
328 .field("last_synced", &self.last_synced)
329 .finish()
330 }
331}
332
/// Bindings that belong to another stream than the one being resolved and that still require
/// a synchronization.
#[derive(Default, Debug, PartialEq, Eq)]
pub(crate) struct SharedBindingAnalysis {
    /// Slice ids of the shared bindings, grouped by the index of the stream that owns them.
    slices: HashMap<usize, Vec<SliceId>>,
}
337
338impl SharedBindingAnalysis {
339 fn shared(&mut self, binding: &Binding, index: usize) {
340 match self.slices.get_mut(&index) {
341 Some(bindings) => bindings.push(*binding.memory.id()),
342 None => {
343 self.slices.insert(index, vec![*binding.memory.id()]);
344 }
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};

    /// Number of streams used by every test.
    const MAX_STREAMS: u8 = 4;

    /// A binding owned by another stream is reported as shared; the one owned by the current
    /// stream is not.
    #[test_log::test]
    fn test_analysis_shared_bindings_1() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    /// Only the binding owned by the other stream is reported, even when it sits between two
    /// bindings of the current stream.
    #[test_log::test]
    fn test_analysis_shared_bindings_2() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    /// Bindings that all belong to the current stream produce an empty analysis.
    #[test_log::test]
    fn test_analysis_no_shared() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let expected = SharedBindingAnalysis::default();

        assert_eq!(analysis, expected);
    }

    /// Resolving with a shared binding records the origin's cursor on the current stream and
    /// leaves the origin stream untouched.
    #[test_log::test]
    fn test_state() {
        let logger = Arc::new(ServerLogger::default());
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter());
        ms.resolve(stream_2, [].into_iter());

        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Creates a binding owned by `stream`.
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }

    /// Backend whose streams and events carry no state and whose waits always succeed.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) -> Result<(), ExecutionError> {
            Ok(())
        }
    }
}