1use crate::{
2 config::streaming::StreamingLogLevel,
3 logging::ServerLogger,
4 memory_management::ManagedMemoryId,
5 server::{Binding, ServerError},
6 stream::{StreamFactory, StreamPool},
7};
8use core::any::Any;
9use cubecl_common::{backtrace::BackTrace, stream_id::StreamId};
10use hashbrown::HashMap;
11use std::{
12 boxed::Box,
13 format,
14 sync::{Arc, mpsc::SyncSender},
15 vec::Vec,
16};
17
/// Backend operations that `MultiStream` relies on to create streams and to order work between
/// them with events.
pub trait EventStreamBackend: 'static {
    /// Backend stream type; every pooled stream wraps one value of this type.
    type Stream: core::fmt::Debug;
    /// Event type returned by [`flush`](Self::flush). Events are moved to the
    /// garbage-collection thread, hence the `Send + 'static` bounds.
    type Event: Send + 'static;

    /// Creates a new backend stream. Called by the stream factory handed to the pool.
    fn create_stream(&self) -> Self::Stream;
    /// Returns the cursor associated with `handle`, looked up through `stream`.
    ///
    /// The value is compared with the cursor recorded at the last synchronization to decide
    /// whether the binding still requires a cross-stream wait.
    fn handle_cursor(stream: &Self::Stream, handle: &Binding) -> u64;
    /// Reports whether `stream` is usable. Resolving a stream fails when health enforcement is
    /// requested and this returns `false`.
    fn is_healthy(stream: &Self::Stream) -> bool;

    /// Produces an event for `stream`.
    ///
    /// NOTE(review): presumably the event completes once the work submitted to `stream` so far
    /// has finished — confirm against the backend implementations.
    fn flush(stream: &mut Self::Stream) -> Self::Event;
    /// Makes `stream` wait on `event`.
    ///
    /// NOTE(review): presumably this only enqueues the dependency and does not block the
    /// calling thread — confirm against the backend implementations.
    fn wait_event(stream: &mut Self::Stream, event: Self::Event);
    /// Waits on `event` from the calling thread and reports failures as a [`ServerError`].
    ///
    /// The garbage-collection thread calls this before dropping the resources tied to `event`.
    fn wait_event_sync(event: Self::Event) -> Result<(), ServerError>;
}
43
/// Manages a pool of backend streams and the event-based synchronization between them.
#[derive(Debug)]
pub struct MultiStream<B: EventStreamBackend> {
    /// Pool of wrapped backend streams, indexed by stream id.
    streams: StreamPool<EventStreamBackendWrapper<B>>,
    /// Logger used to report the synchronization decisions taken while resolving streams.
    pub logger: Arc<ServerLogger>,
    /// Modulus used to map a `StreamId` onto a pool index.
    max_streams: usize,
    /// Handle to the background thread that drops resources after waiting on their event.
    gc: GcThread<B>,
    /// Scratch buffer, cleared and refilled on every binding analysis, holding
    /// `(memory id, owning stream, cursor)` for each binding owned by another stream.
    shared_bindings_pool: Vec<(ManagedMemoryId, StreamId, u64)>,
}
58
/// Per-stream state stored in the pool: the backend stream plus its synchronization bookkeeping.
pub(crate) struct StreamWrapper<B: EventStreamBackend> {
    /// The backend stream itself.
    stream: B::Stream,
    /// Starts at zero and is incremented by one every time the stream is resolved.
    cursor: u64,
    /// For each origin stream index this stream has waited on, the cursor that origin stream
    /// had when the wait was issued.
    last_synced: HashMap<usize, u64>,
}
71
/// Result of resolving a stream: gives access to the backend streams and, when dropped,
/// schedules the garbage collection of the shared bindings found during resolution.
pub struct ResolvedStreams<'a, B: EventStreamBackend> {
    /// Cursor of the current stream, read right after the resolution advanced it.
    pub cursor: u64,
    /// The stream pool, borrowed for as long as the resolution is alive.
    streams: &'a mut StreamPool<EventStreamBackendWrapper<B>>,
    /// Bindings owned by other streams that required synchronization; consumed on drop.
    analysis: SharedBindingAnalysis,
    /// Garbage-collection thread that receives the tasks registered through this value.
    gc: &'a GcThread<B>,
    /// Id of the stream that was resolved.
    pub current: StreamId,
}
84
/// Unit of work for the garbage-collection thread: a payload to drop once an event was waited on.
#[derive(Debug)]
pub struct GcTask<B: EventStreamBackend> {
    /// Type-erased payload; the garbage-collection thread drops it after the event wait.
    to_drop: Box<dyn Any + Send + 'static>,
    /// Event the garbage-collection thread waits on before dropping the payload.
    event: B::Event,
}
92
93impl<B: EventStreamBackend> GcTask<B> {
94 pub fn new<T: Send + 'static>(to_drop: T, event: B::Event) -> Self {
96 Self {
97 to_drop: Box::new(to_drop),
98 event,
99 }
100 }
101}
102
/// Adapter that lets an [`EventStreamBackend`] be used as the pool's `StreamFactory`.
#[derive(Debug)]
struct EventStreamBackendWrapper<B: EventStreamBackend> {
    /// Backend used to create the underlying streams.
    backend: B,
}
107
108impl<B: EventStreamBackend> StreamFactory for EventStreamBackendWrapper<B> {
109 type Stream = StreamWrapper<B>;
110
111 fn create(&mut self) -> Self::Stream {
112 StreamWrapper {
113 stream: self.backend.create_stream(),
114 cursor: 0,
115 last_synced: Default::default(),
116 }
117 }
118}
119
/// Handle to the background garbage-collection thread.
#[derive(Debug)]
struct GcThread<B: EventStreamBackend> {
    /// Sending half of the bounded channel feeding the thread.
    sender: SyncSender<GcTask<B>>,
}
124
125impl<B: EventStreamBackend> GcThread<B> {
126 fn new() -> GcThread<B> {
127 let (sender, recv) = std::sync::mpsc::sync_channel::<GcTask<B>>(8);
128
129 std::thread::spawn(move || {
130 while let Ok(event) = recv.recv() {
131 B::wait_event_sync(event.event).unwrap();
132 core::mem::drop(event.to_drop);
133 }
134 });
135
136 GcThread { sender }
137 }
138 fn register(&self, task: GcTask<B>) {
139 self.sender.send(task).unwrap()
140 }
141}
142
143fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
144 stream_id.value as usize % max_streams
145}
146
147impl<'a, B: EventStreamBackend> ResolvedStreams<'a, B> {
148 pub fn get(&mut self, stream_id: &StreamId) -> &mut B::Stream {
150 let stream = self.streams.get_mut(stream_id);
151 &mut stream.stream
152 }
153
154 pub fn current(&mut self) -> &mut B::Stream {
156 let stream = self.streams.get_mut(&self.current);
157 &mut stream.stream
158 }
159
160 pub fn gc(&mut self, gc: GcTask<B>) {
162 self.gc.sender.send(gc).unwrap();
163 }
164}
165
166impl<'a, B: EventStreamBackend> Drop for ResolvedStreams<'a, B> {
167 fn drop(&mut self) {
168 if self.analysis.slices.is_empty() {
169 return;
170 }
171
172 let stream = self.streams.get_mut(&self.current);
173 let event_origin = B::flush(&mut stream.stream);
174
175 let stream_gc = &mut unsafe { self.streams.get_special(0) }.stream;
176 B::wait_event(stream_gc, event_origin);
177 let event = B::flush(stream_gc);
178
179 let mut ids = Vec::new();
180 self.analysis
181 .slices
182 .drain()
183 .for_each(|item| ids.extend(item.1));
184
185 self.gc.register(GcTask::new(ids, event));
186 }
187}
188
189impl<B: EventStreamBackend> MultiStream<B> {
190 pub fn new(logger: Arc<ServerLogger>, backend: B, max_streams: u8) -> Self {
192 let wrapper = EventStreamBackendWrapper { backend };
193 Self {
194 streams: StreamPool::new(wrapper, max_streams, 1),
195 logger,
196 max_streams: max_streams as usize,
197 gc: GcThread::new(),
198 shared_bindings_pool: Vec::new(),
199 }
200 }
201
202 pub fn gc(&mut self, gc: GcTask<B>) {
204 self.gc.sender.send(gc).unwrap();
205 }
206
207 pub fn resolve<'a>(
213 &mut self,
214 stream_id: StreamId,
215 handles: impl Iterator<Item = &'a Binding>,
216 enforce_healthy: bool,
217 ) -> Result<ResolvedStreams<'_, B>, ServerError> {
218 let analysis = self.align_streams(stream_id, handles);
219
220 let stream = self.streams.get_mut(&stream_id);
221 stream.cursor += 1;
222
223 if enforce_healthy && !B::is_healthy(&stream.stream) {
224 return Err(ServerError::Generic {
225 reason: "Can't resolve the stream since it is currently in an error state".into(),
226 backtrace: BackTrace::capture(),
227 });
228 }
229
230 Ok(ResolvedStreams {
231 cursor: stream.cursor,
232 streams: &mut self.streams,
233 current: stream_id,
234 analysis,
235 gc: &self.gc,
236 })
237 }
238
239 fn align_streams<'a>(
244 &mut self,
245 stream_id: StreamId,
246 handles: impl Iterator<Item = &'a Binding>,
247 ) -> SharedBindingAnalysis {
248 let analysis = self.update_shared_bindings(stream_id, handles);
249
250 self.apply_analysis(stream_id, analysis)
251 }
252
253 pub(crate) fn update_shared_bindings<'a>(
258 &mut self,
259 stream_id: StreamId,
260 handles: impl Iterator<Item = &'a Binding>,
261 ) -> SharedBindingAnalysis {
262 self.shared_bindings_pool.clear();
264
265 for handle in handles {
266 let index = stream_index(&handle.stream, self.max_streams);
267 let stream = unsafe { self.streams.get_mut_index(index) };
268 let cursor_handle = B::handle_cursor(&stream.stream, handle);
269
270 if handle.stream != stream_id {
273 self.shared_bindings_pool.push((
274 handle.memory.descriptor().id,
275 handle.stream,
276 cursor_handle,
277 ));
278 }
279 }
280
281 let mut analysis = SharedBindingAnalysis::default();
282 let current = self.streams.get_mut(&stream_id);
283
284 for (handle_id, stream, cursor) in self.shared_bindings_pool.iter() {
285 let index = stream_index(stream, self.max_streams);
286
287 if let Some(last_synced) = current.last_synced.get(&index) {
288 if last_synced < cursor {
289 self.logger.log_streaming(
290 |level| matches!(level, StreamingLogLevel::Full),
291 || {
292 format!(
293 "Binding on {} is shared on {} since it's not sync {} < {}",
294 stream, stream_id, last_synced, cursor
295 )
296 },
297 );
298 analysis.shared(*handle_id, index);
299 }
300 } else {
301 self.logger.log_streaming(
302 |level| matches!(level, StreamingLogLevel::Full),
303 || {
304 format!(
305 "Binding on {} is shared on {} since it was never synced.",
306 stream, stream_id,
307 )
308 },
309 );
310 analysis.shared(*handle_id, index);
311 }
312 }
313
314 analysis
315 }
316
317 pub(crate) fn apply_analysis(
318 &mut self,
319 stream_id: StreamId,
320 analysis: SharedBindingAnalysis,
321 ) -> SharedBindingAnalysis {
322 if analysis.slices.is_empty() {
323 return analysis;
324 }
325
326 let mut events = Vec::with_capacity(analysis.slices.len());
327
328 unsafe {
329 for origin in analysis.slices.keys() {
330 let stream = self.streams.get_mut_index(*origin);
331 let event = B::flush(&mut stream.stream);
332
333 events.push(((origin, stream.cursor), event));
334 }
335 }
336
337 let stream = self.streams.get_mut(&stream_id);
338
339 for ((stream_origin, cursor_origin), event) in events {
340 stream.last_synced.insert(*stream_origin, cursor_origin);
341
342 self.logger.log_streaming(
343 |level| !matches!(level, StreamingLogLevel::Disabled),
344 || format!("Waiting on {stream_origin} from {stream_id}",),
345 );
346
347 B::wait_event(&mut stream.stream, event);
348 }
349
350 analysis
351 }
352}
353
354impl<B: EventStreamBackend> core::fmt::Debug for StreamWrapper<B> {
355 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
356 f.debug_struct("StreamWrapper")
357 .field("stream", &self.stream)
358 .field("cursor", &self.cursor)
359 .field("last_synced", &self.last_synced)
360 .finish()
361 }
362}
363
/// Outcome of a shared-binding analysis: the memory ids that need synchronization, grouped by
/// the pool index of the stream that owns them.
#[derive(Default, Debug, PartialEq, Eq)]
pub(crate) struct SharedBindingAnalysis {
    /// Maps an origin stream index to the memory ids shared from that stream.
    slices: HashMap<usize, Vec<ManagedMemoryId>>,
}
368
369impl SharedBindingAnalysis {
370 fn shared(&mut self, id: ManagedMemoryId, index: usize) {
371 match self.slices.get_mut(&index) {
372 Some(bindings) => bindings.push(id),
373 None => {
374 self.slices.insert(index, alloc::vec![id]);
375 }
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use crate::server::Handle;

    use super::*;

    /// Pool size shared by every test.
    const MAX_STREAMS: u8 = 4;

    /// A binding owned by another stream is reported; one owned by the current stream is not.
    #[test_log::test]
    fn test_analysis_shared_bindings_1() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_2);

        let mut ms = multi_stream(stream_1, stream_2);

        let analysis = ms.update_shared_bindings(stream_1, [&binding_1, &binding_2].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(
            binding_2.memory.descriptor().id,
            ms.streams.stream_index(&binding_2.stream),
        );

        assert_eq!(analysis, expected);
    }

    /// Several bindings of the current stream around a foreign one still yield a single entry.
    #[test_log::test]
    fn test_analysis_shared_bindings_2() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_2);
        let binding_3 = handle(stream_1);

        let mut ms = multi_stream(stream_1, stream_2);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let mut expected = SharedBindingAnalysis::default();
        expected.shared(
            binding_2.memory.descriptor().id,
            ms.streams.stream_index(&binding_2.stream),
        );

        assert_eq!(analysis, expected);
    }

    /// Bindings that all belong to the current stream produce an empty analysis.
    #[test_log::test]
    fn test_analysis_no_shared() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_1);
        let binding_3 = handle(stream_1);

        let mut ms = multi_stream(stream_1, stream_2);

        let analysis =
            ms.update_shared_bindings(stream_1, [&binding_1, &binding_2, &binding_3].into_iter());

        let expected = SharedBindingAnalysis::default();

        assert_eq!(analysis, expected);
    }

    /// Resolving with a foreign binding records the origin's cursor and advances only the
    /// resolved stream.
    #[test_log::test]
    fn test_state() {
        let stream_1 = StreamId { value: 1 };
        let stream_2 = StreamId { value: 2 };

        let binding_1 = handle(stream_1);
        let binding_2 = handle(stream_2);
        let binding_3 = handle(stream_1);

        let mut ms = multi_stream(stream_1, stream_2);

        ms.resolve(
            stream_1,
            [&binding_1, &binding_2, &binding_3].into_iter(),
            false,
        )
        .unwrap();

        let stream1 = ms.streams.get_mut(&stream_1);
        let index_2 = stream_index(&stream_2, MAX_STREAMS as usize);
        assert_eq!(stream1.last_synced.get(&index_2), Some(&1));
        assert_eq!(stream1.cursor, 2);

        let stream2 = ms.streams.get_mut(&stream_2);
        assert!(stream2.last_synced.is_empty());
        assert_eq!(stream2.cursor, 1);
    }

    /// Builds a `MultiStream` over the test backend and resolves both streams once without any
    /// binding, which leaves each of them with a cursor of 1.
    fn multi_stream(stream_1: StreamId, stream_2: StreamId) -> MultiStream<TestBackend> {
        let logger = Arc::new(ServerLogger::default());
        let mut ms = MultiStream::new(logger, TestBackend, MAX_STREAMS);
        ms.resolve(stream_1, [].into_iter(), false).unwrap();
        ms.resolve(stream_2, [].into_iter(), false).unwrap();
        ms
    }

    /// Creates a binding owned by `stream`.
    fn handle(stream: StreamId) -> Binding {
        Handle::new(stream, 10).binding()
    }

    /// Backend whose operations do nothing and always succeed.
    struct TestBackend;

    #[derive(Debug)]
    struct TestStream {}

    #[derive(Debug)]
    struct TestEvent {}

    impl EventStreamBackend for TestBackend {
        type Stream = TestStream;
        type Event = TestEvent;

        fn create_stream(&self) -> Self::Stream {
            TestStream {}
        }

        fn flush(_stream: &mut Self::Stream) -> Self::Event {
            TestEvent {}
        }

        fn wait_event(_stream: &mut Self::Stream, _event: Self::Event) {}

        fn wait_event_sync(_event: Self::Event) -> Result<(), ServerError> {
            Ok(())
        }

        // Every binding reports cursor 0, so only never-synced streams are reported as shared.
        fn handle_cursor(_stream: &Self::Stream, _handle: &Binding) -> u64 {
            0
        }

        fn is_healthy(_stream: &Self::Stream) -> bool {
            true
        }
    }
}