1use crate::{
2 config::streaming::StreamingLogLevel,
3 logging::ServerLogger,
4 memory_management::SliceId,
5 server::Binding,
6 stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::stream_id::StreamId;
10use hashbrown::HashMap;
11use std::sync::{Arc, mpsc::SyncSender};
12
/// Backend operations that [`MultiStream`] needs in order to order work across streams.
///
/// The multi-stream layer only ever creates streams, flushes a stream to obtain an event,
/// makes a stream wait on an event, and waits on an event from the host (GC thread).
pub trait EventStreamBackend: 'static {
    /// Backend stream handle.
    type Stream: core::fmt::Debug;
    /// Event returned by [`EventStreamBackend::flush`]. It is moved to the GC thread inside a
    /// [`GcTask`], hence the `Send + 'static` bound.
    type Event: Send + 'static;

    /// Creates a new backend stream (called lazily by the stream pool's factory).
    fn create_stream(&self) -> Self::Stream;
    /// Flushes `stream` and returns an event that other streams or the GC thread can wait on.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes `stream` wait on `event`; used to order one stream after work on another.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Waits on `event` from the calling thread; the GC thread calls this before it drops
    /// the resources attached to a task.
    fn wait_event_sync(event: Self::Event);
}
33
34#[derive(Debug)]
39pub struct MultiStream<B: EventStreamBackend> {
40 streams: StreamPool<EventStreamBackendWrapper<B>>,
42 pub logger: Arc<ServerLogger>,
44 max_streams: usize,
45 gc: GcThread<B>,
46}
47
48pub(crate) struct StreamWrapper<B: EventStreamBackend> {
53 stream: B::Stream,
55 cursor: u64,
57 last_synced: HashMap<usize, u64>,
59}
60
61pub struct ResolvedStreams<'a, B: EventStreamBackend> {
63 pub cursor: u64,
67 streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
68 analysis: SharedBindingAnalysis,
69 gc: &'a GcThread<B>,
70 pub current: StreamId,
72}
73
74#[derive(Debug)]
75pub struct GcTask<B: EventStreamBackend> {
77 to_drop: Box<dyn Any + Send + 'static>,
78 event: B::Event,
80}
81
82impl<B: EventStreamBackend> GcTask<B> {
83 pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
85 Self {
86 to_drop: Box::new(to_drop),
87 event,
88 }
89 }
90}
91
92#[derive(Debug)]
93struct EventStreamBackendWrapper<B: EventStreamBackend> {
94 backend: B,
95}
96
97impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
98 type Stream = StreamWrapper<B>;
99
100 fn create(&mut self) -> Self::Stream {
101 StreamWrapper {
102 stream: self.backend.create_stream(),
103 cursor: 0,
104 last_synced: Default::default(),
105 }
106 }
107}
108
109#[derive(Debug)]
110struct GcThread<B: EventStreamBackend> {
111 sender: SyncSender<GcTask<B>>,
112}
113
114impl<B: EventStreamBackend> GcThread<B> {
115 fn new() -> GcThread<B> {
116 let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
117
118 std::thread::spawn(move || {
119 while let Ok(event) = recv.recv() {
120 B::wait_event_sync(event.event);
121 core::mem::drop(event.to_drop);
122 }
123 });
124
125 GcThread { sender }
126 }
127 fn register(&self, task: GcTask<B>) {
128 self.sender.send(task).unwrap()
129 }
130}
131
132fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
133 stream_id.value as usize % max_streams
134}
135
136impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
137 pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
139 let stream = self.streams.get_mut(stream_id);
140 &mut stream.stream
141 }
142
143 pub fn current(&mut self) -> &mut B::Stream {
145 let stream = self.streams.get_mut(&self.current);
146 &mut stream.stream
147 }
148
149 pub fn gc(&mut self, gc: GcTask<B>) {
151 self.gc.sender.send(gc).unwrap();
152 }
153}
154
155impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
156 fn drop(&mut self) {
157 if self.analysis.slices.is_empty() {
158 return;
159 }
160
161 let stream = self.streams.get_mut(&self.current);
162 let event_origin = B::flush(&mut stream.stream);
163
164 let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
165 B::wait_event(stream_gc, event_origin);
166 let event = B::flush(stream_gc);
167
168 let mut ids = Vec::new();
169 self.analysis
170 .slices
171 .drain()
172 .for_each(|item| ids.extend(item.1));
173
174 self.gc.register(GcTask::new(ids, event));
175 }
176}
177
178impl<B: EventStreamBackend> MultiStream<B> {
179 pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
181 let wrapper = EventStreamBackendWrapper { backend };
182 Self {
183 streams: StreamPool::new(wrapper, max_streams, 1),
184 logger,
185 max_streams: max_streams as usize,
186 gc: GcThread::new(),
187 }
188 }
189
190 pub fn gc(&mut self, gc: GcTask<B>) {
192 self.gc.sender.send(gc).unwrap();
193 }
194
195 pub fn resolve<'a>(
201 &mut self,
202 stream_id: StreamId,
203 bindings: impl Iterator<Item = &'a Binding>,
204 ) -> ResolvedStreams<'_, B> {
205 let analysis = self.align_streams(stream_id, bindings);
206
207 let stream = self.streams.get_mut(&stream_id);
208 stream.cursor += 1;
209
210 ResolvedStreams {
211 cursor: stream.cursor,
212 streams: &mut self.streams,
213 current: stream_id,
214 analysis,
215 gc: &self.gc,
216 }
217 }
218
219 fn align_streams<'a>(
224 &mut self,
225 stream_id: StreamId,
226 bindings: impl Iterator<Item = &'a Binding>,
227 ) -> SharedBindingAnalysis {
228 let analysis = self.update_shared_bindings(stream_id, bindings);
229
230 self.apply_analysis(stream_id, analysis)
231 }
232
233 pub(crate) fn update_shared_bindings<'a>(
238 &mut self,
239 stream_id: StreamId,
240 bindings: impl Iterator<Item = &'a Binding>,
241 ) -> SharedBindingAnalysis {
242 let mut analysis = SharedBindingAnalysis::default();
243 let current = self.streams.get_mut(&stream_id);
244
245 for binding in bindings {
246 if stream_id != binding.stream {
247 let index = stream_index(&binding.stream, self.max_streams);
248
249 if let Some(last_synced) = current.last_synced.get(&index) {
250 if *last_synced < binding.cursor {
251 self.logger.log_streaming(
252 |level| matches!(level, StreamingLogLevel::Full),
253 || {
254 format!(
255 "Binding on {} is shared on {} since it's not sync {} < {}",
256 binding.stream, stream_id, last_synced, binding.cursor
257 )
258 },
259 );
260 analysis.shared(binding, index);
261 }
262 } else {
263 self.logger.log_streaming(
264 |level| matches!(level, StreamingLogLevel::Full),
265 || {
266 format!(
267 "Binding on {} is shared on {} since it was never synced.",
268 binding.stream, stream_id,
269 )
270 },
271 );
272 analysis.shared(binding, index);
273 }
274 }
275 }
276
277 analysis
278 }
279
280 pub(crate) fn apply_analysis(
281 &mut self,
282 stream_id: StreamId,
283 analysis: SharedBindingAnalysis,
284 ) -> SharedBindingAnalysis {
285 if analysis.slices.is_empty() {
286 return analysis;
287 }
288
289 let mut events = Vec::with_capacity(analysis.slices.len());
290
291 unsafe {
292 for origin in analysis.slices.keys() {
293 let stream = self.streams.get_mut_index(*origin);
294 let event = B::flush(&mut stream.stream);
295
296 events.push(((origin, stream.cursor), event));
297 }
298 }
299
300 let stream = self.streams.get_mut(&stream_id);
301
302 for ((stream_origin, cursor_origin), event) in events {
303 stream.last_synced.insert(*stream_origin, cursor_origin);
304
305 self.logger.log_streaming(
306 |level| !matches!(level, StreamingLogLevel::Disabled),
307 || format!("Waiting on {stream_origin} from {stream_id}",),
308 );
309
310 B::wait_event(&mut stream.stream, event);
311 }
312
313 analysis
314 }
315}
316
317impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
318 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
319 f.debug_struct("StreamWrapper")
320 .field("stream", &self.stream)
321 .field("cursor", &self.cursor)
322 .field("last_synced", &self.last_synced)
323 .finish()
324 }
325}
326
327#[derive(Default, Debug, PartialEq, Eq)]
328pub(crate) struct SharedBindingAnalysis {
329 slices: HashMap<usize, Vec<SliceId>>,
330}
331
332impl SharedBindingAnalysis {
333 fn shared(&mut self, binding: &Binding, index: usize) {
334 match self.slices.get_mut(&index) {
335 Some(bindings) => bindings.push(*binding.memory.id()),
336 None => {
337 self.slices.insert(index, vec![*binding.memory.id()]);
338 }
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};

    const MAX_STREAMS: u8 = 4;

    /// Builds a `MultiStream` over `TestBackend` with `MAX_STREAMS` streams.
    ///
    /// Each of `streams` is resolved once with no bindings, so each starts with a cursor of 1.
    fn multi_stream(streams: &[StreamId]) -> MultiStream<TestBackend> {
        let logger = Arc::new(ServerLogger::default());
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        for stream in streams {
            ms.resolve(*stream, [].into_iter());
        }
        ms
    }

    #[test]
    fn test_analysis_shared_bindings_1() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    #[test]
    fn test_analysis_shared_bindings_2() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    #[test]
    fn test_analysis_no_shared() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        assert_eq!(analysis, SharedBindingAnalysis::default());
    }

    #[test]
    fn test_state() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, usize::from(MAX_STREAMS));
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Creates a binding owned by `stream` (via `Handle::new(.., stream, 0, 10)`).
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }

    /// Backend whose operations are all no-ops, for testing the bookkeeping only.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) {}
    }
}