1use crate::{
2 config::streaming::StreamingLogLevel,
3 logging::ServerLogger,
4 memory_management::SliceId,
5 server::Binding,
6 stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::stream_id::StreamId;
10use hashbrown::HashMap;
11use std::sync::{Arc, mpsc::SyncSender};
12
/// Backend abstraction used by [`MultiStream`] to create streams and to
/// synchronize them through events.
///
/// The precise meaning of "flush" and "wait" is defined by each implementor;
/// the notes below only describe how this file calls each method.
pub trait EventStreamBackend: 'static {
    /// Backend-specific stream handle.
    type Stream: core::fmt::Debug;
    /// Backend-specific event. It is `Send` because events are moved to the
    /// garbage-collection thread inside a [`GcTask`].
    type Event: Send + 'static;

    /// Creates a new backend stream.
    fn create_stream(&self) -> Self::Stream;
    /// Flushes `stream` and returns an event tied to that flush.
    ///
    /// NOTE(review): presumably the event marks completion of the work
    /// submitted to `stream` so far — confirm against the implementors.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes `stream` wait on `event`.
    ///
    /// NOTE(review): presumably this enqueues the wait on the stream rather
    /// than blocking the calling thread — confirm against the implementors.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Waits on `event` from the calling thread.
    ///
    /// The garbage-collection thread calls this before dropping the payload
    /// of a [`GcTask`].
    fn wait_event_sync(event: Self::Event);
}
33
/// Manages a pool of backend streams and aligns them when a binding created
/// on one stream is used from another.
#[derive(Debug)]
pub struct MultiStream<B: EventStreamBackend> {
    /// Pool of wrapped backend streams.
    streams: StreamPool<EventStreamBackendWrapper<B>>,
    /// Logger used to report streaming decisions.
    pub logger: Arc<ServerLogger>,
    /// Modulus used by `stream_index` to map a stream id to a pool index.
    max_streams: usize,
    /// Handle to the background thread that drops resources once their event
    /// has been waited on.
    gc: GcThread<B>,
}
47
/// A backend stream together with the bookkeeping used to align it with
/// other streams.
pub(crate) struct StreamWrapper<B: EventStreamBackend> {
    /// The underlying backend stream.
    stream: B::Stream,
    /// Counter incremented every time this stream is resolved through
    /// `MultiStream::resolve`.
    cursor: u64,
    /// For each origin stream index, the origin's cursor value recorded the
    /// last time this stream waited on that origin.
    last_synced: HashMap<usize, u64>,
}
60
/// Value returned by `MultiStream::resolve`, giving access to the pool's
/// backend streams.
///
/// When dropped, the slice ids collected in `analysis` (if any) are handed to
/// the garbage-collection thread.
pub struct ResolvedStreams<'a, B: EventStreamBackend> {
    /// Cursor of the current stream, read right after `resolve` incremented it.
    pub cursor: u64,
    /// The stream pool borrowed from the owning `MultiStream`.
    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
    /// Shared bindings detected while resolving; drained on drop.
    analysis: SharedBindingAnalysis,
    /// Garbage-collection thread handle borrowed from the owning `MultiStream`.
    gc: &'a GcThread<B>,
    /// Identifier of the stream that was resolved.
    pub current: StreamId,
}
73
/// A payload to drop once a backend event has been waited on.
#[derive(Debug)]
pub struct GcTask<B: EventStreamBackend> {
    /// Type-erased payload, dropped by the garbage-collection thread after
    /// `event` has been waited on.
    to_drop: Box<dyn Any + Send + 'static>,
    /// Event the garbage-collection thread waits on before dropping `to_drop`.
    event: B::Event,
}
81
82impl<B: EventStreamBackend> GcTask<B> {
83 pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
85 Self {
86 to_drop: Box::new(to_drop),
87 event,
88 }
89 }
90}
91
/// Adapter that lets an [`EventStreamBackend`] be used as the stream pool's
/// `StreamFactory`.
#[derive(Debug)]
struct EventStreamBackendWrapper<B: EventStreamBackend> {
    /// Backend used to create the underlying streams.
    backend: B,
}
96
97impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
98 type Stream = StreamWrapper<B>;
99
100 fn create(&mut self) -> Self::Stream {
101 StreamWrapper {
102 stream: self.backend.create_stream(),
103 cursor: 0,
104 last_synced: Default::default(),
105 }
106 }
107}
108
/// Handle to the background garbage-collection thread.
#[derive(Debug)]
struct GcThread<B: EventStreamBackend> {
    /// Sending half of the bounded channel feeding the thread.
    sender: SyncSender<GcTask<B>>,
}
113
114impl<B: EventStreamBackend> GcThread<B> {
115 fn new() -> GcThread<B> {
116 let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
117
118 std::thread::spawn(move || {
119 while let Ok(event) = recv.recv() {
120 B::wait_event_sync(event.event);
121 core::mem::drop(event.to_drop);
122 }
123 });
124
125 GcThread { sender }
126 }
127 fn register(&self, task: GcTask<B>) {
128 self.sender.send(task).unwrap()
129 }
130}
131
132fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
133 stream_id.value as usize % max_streams
134}
135
136impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
137 pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
139 let stream = self.streams.get_mut(stream_id);
140 &mut stream.stream
141 }
142
143 pub fn current(&mut self) -> &mut B::Stream {
145 let stream = self.streams.get_mut(&self.current);
146 &mut stream.stream
147 }
148}
149
150impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
151 fn drop(&mut self) {
152 if self.analysis.slices.is_empty() {
153 return;
154 }
155
156 let stream = self.streams.get_mut(&self.current);
157 let event_origin = B::flush(&mut stream.stream);
158
159 let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
160 B::wait_event(stream_gc, event_origin);
161 let event = B::flush(stream_gc);
162
163 let mut ids = Vec::new();
164 self.analysis
165 .slices
166 .drain()
167 .for_each(|item| ids.extend(item.1));
168
169 self.gc.register(GcTask::new(ids, event));
170 }
171}
172
173impl<B: EventStreamBackend> MultiStream<B> {
174 pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
176 let wrapper = EventStreamBackendWrapper { backend };
177 Self {
178 streams: StreamPool::new(wrapper, max_streams, 1),
179 logger,
180 max_streams: max_streams as usize,
181 gc: GcThread::new(),
182 }
183 }
184
185 pub fn gc(&mut self, gc: GcTask<B>) {
187 self.gc.sender.send(gc).unwrap();
188 }
189
190 pub fn resolve<'a>(
196 &mut self,
197 stream_id: StreamId,
198 bindings: impl Iterator<Item = &'a Binding>,
199 ) -> ResolvedStreams<'_, B> {
200 let analysis = self.align_streams(stream_id, bindings);
201
202 let stream = self.streams.get_mut(&stream_id);
203 stream.cursor += 1;
204
205 ResolvedStreams {
206 cursor: stream.cursor,
207 streams: &mut self.streams,
208 current: stream_id,
209 analysis,
210 gc: &self.gc,
211 }
212 }
213
214 fn align_streams<'a>(
219 &mut self,
220 stream_id: StreamId,
221 bindings: impl Iterator<Item = &'a Binding>,
222 ) -> SharedBindingAnalysis {
223 let analysis = self.update_shared_bindings(stream_id, bindings);
224
225 self.apply_analysis(stream_id, analysis)
226 }
227
228 pub(crate) fn update_shared_bindings<'a>(
233 &mut self,
234 stream_id: StreamId,
235 bindings: impl Iterator<Item = &'a Binding>,
236 ) -> SharedBindingAnalysis {
237 let mut analysis = SharedBindingAnalysis::default();
238 let current = self.streams.get_mut(&stream_id);
239
240 for binding in bindings {
241 if stream_id != binding.stream {
242 let index = stream_index(&binding.stream, self.max_streams);
243
244 if let Some(last_synced) = current.last_synced.get(&index) {
245 if *last_synced < binding.cursor {
246 self.logger.log_streaming(
247 |level| matches!(level, StreamingLogLevel::Full),
248 || {
249 format!(
250 "Binding on {} is shared on {} since it's not sync {} < {}",
251 binding.stream, stream_id, last_synced, binding.cursor
252 )
253 },
254 );
255 analysis.shared(binding, index);
256 }
257 } else {
258 self.logger.log_streaming(
259 |level| matches!(level, StreamingLogLevel::Full),
260 || {
261 format!(
262 "Binding on {} is shared on {} since it was never synced.",
263 binding.stream, stream_id,
264 )
265 },
266 );
267 analysis.shared(binding, index);
268 }
269 }
270 }
271
272 analysis
273 }
274
275 pub(crate) fn apply_analysis(
276 &mut self,
277 stream_id: StreamId,
278 analysis: SharedBindingAnalysis,
279 ) -> SharedBindingAnalysis {
280 if analysis.slices.is_empty() {
281 return analysis;
282 }
283
284 let mut events = Vec::with_capacity(analysis.slices.len());
285
286 unsafe {
287 for origin in analysis.slices.keys() {
288 let stream = self.streams.get_mut_index(*origin);
289 let event = B::flush(&mut stream.stream);
290
291 events.push(((origin, stream.cursor), event));
292 }
293 }
294
295 let stream = self.streams.get_mut(&stream_id);
296
297 for ((stream_origin, cursor_origin), event) in events {
298 stream.last_synced.insert(*stream_origin, cursor_origin);
299
300 self.logger.log_streaming(
301 |level| !matches!(level, StreamingLogLevel::Disabled),
302 || format!("Waiting on {stream_origin} from {stream_id}",),
303 );
304
305 B::wait_event(&mut stream.stream, event);
306 }
307
308 analysis
309 }
310}
311
312impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
313 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
314 f.debug_struct("StreamWrapper")
315 .field("stream", &self.stream)
316 .field("cursor", &self.cursor)
317 .field("last_synced", &self.last_synced)
318 .finish()
319 }
320}
321
/// Bindings that the current stream shares with other streams.
#[derive(Default, Debug, PartialEq, Eq)]
pub(crate) struct SharedBindingAnalysis {
    /// Slice ids of the shared bindings, grouped by the index of the stream
    /// they originate from.
    slices: HashMap<usize, Vec<SliceId>>,
}
326
327impl SharedBindingAnalysis {
328 fn shared(&mut self, binding: &Binding, index: usize) {
329 match self.slices.get_mut(&index) {
330 Some(bindings) => bindings.push(*binding.memory.id()),
331 None => {
332 self.slices.insert(index, vec![*binding.memory.id()]);
333 }
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{memory_management::SliceHandle, server::Handle};

    /// Number of streams used by every test in this module.
    const MAX_STREAMS: u8 = 4;

    /// A binding owned by another stream that was never synced is shared.
    #[test]
    fn test_analysis_shared_bindings_1() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    /// Bindings owned by the current stream are ignored even when mixed with
    /// a shared one.
    #[test]
    fn test_analysis_shared_bindings_2() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(&binding_2, ms.streams.stream_index(&binding_2.stream));

        assert_eq!(analysis, expected);
    }

    /// Bindings that all belong to the current stream produce an empty analysis.
    #[test]
    fn test_analysis_no_shared() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_1);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let expected = SharedBindingAnalysis::default();

        assert_eq!(analysis, expected);
    }

    /// Resolving with a shared binding records the origin's cursor on the
    /// current stream and leaves the origin stream untouched.
    #[test]
    fn test_state() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = binding(stream_1);
        let binding_2 = binding(stream_2);
        let binding_3 = binding(stream_1);

        let mut ms = multi_stream(&[stream_1, stream_2]);

        ms.resolve(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Builds a `MultiStream` over `TestBackend` with `MAX_STREAMS` streams
    /// and resolves each of `streams` once with no bindings.
    fn multi_stream(streams: &[StreamId]) -> MultiStream<TestBackend> {
        let logger = Arc::new(ServerLogger::default());
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);

        for stream in streams {
            ms.resolve(*stream, [].into_iter());
        }

        ms
    }

    /// Creates a binding owned by `stream`.
    fn binding(stream: StreamId) -> Binding {
        Handle::new(SliceHandle::new(), None, None, stream, 0, 10).binding()
    }

    /// Backend whose streams and events carry no state.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) {}
    }
}