cubecl_runtime/stream/base.rs
1use cubecl_common::stream_id::StreamId;
2
/// Produces the streams stored in a [`StreamPool`].
///
/// The pool calls [`StreamFactory::create`] lazily, the first time an empty slot
/// is accessed, and keeps the returned stream for later requests of that slot.
pub trait StreamFactory {
    /// Stream type handed out by this factory.
    type Stream;

    /// Builds and returns a fresh stream.
    fn create(&mut self) -> <Self as StreamFactory>::Stream;
}
10
11/// Represents a pool of streams, managing a collection of streams created by a factory.
12#[derive(Debug)]
13pub struct StreamPool<F: StreamFactory> {
14 /// Vector storing optional streams, where None indicates an uninitialized stream.
15 streams: Vec<Option<F::Stream>>,
16 /// The factory used to create new streams when needed.
17 factory: F,
18 /// Maximum number of regular streams (excludes special streams).
19 max_streams: usize,
20}
21
22impl<F: StreamFactory> StreamPool<F> {
23 /// Creates a new stream pool with the given backend factory and capacity constraints.
24 pub fn new(backend: F, max_streams: u8, num_special: u8) -> Self {
25 // Initialize a vector with capacity for regular and special streams.
26 let mut streams = Vec::with_capacity(max_streams as usize);
27 // Pre-populate the vector with None to reserve space for all streams.
28 for _ in 0..(max_streams + num_special) {
29 streams.push(None);
30 }
31
32 Self {
33 streams,
34 factory: backend,
35 max_streams: max_streams as usize,
36 }
37 }
38
39 /// Retrieves a mutable reference to a stream for a given stream ID.
40 pub fn get_mut(&mut self, stream_id: &StreamId) -> &mut F::Stream {
41 // Calculate the index for the stream ID.
42 let index = self.stream_index(stream_id);
43
44 // Use unsafe method to retrieve the stream, assuming the index is valid.
45 //
46 // # Safety
47 //
48 // * The `stream_index` function ensures the index is within bounds.
49 unsafe { self.get_mut_index(index) }
50 }
51
52 /// Retrieves a mutable reference to a stream at the specified index, initializing it if needed.
53 ///
54 /// # Safety
55 ///
56 /// * Caller must ensure the index is valid (less than `max_streams + num_special`).
57 /// * Lifetimes still follow the Rust rules.
58 pub unsafe fn get_mut_index(&mut self, index: usize) -> &mut F::Stream {
59 unsafe {
60 // Access the stream entry without bounds checking for performance.
61 let entry = self.streams.get_unchecked_mut(index);
62 match entry {
63 // If the stream exists, return it.
64 Some(val) => val,
65 // If the stream is None, create a new one using the factory.
66 None => {
67 let stream = self.factory.create();
68 // Store the new stream in the vector.
69 *entry = Some(stream);
70
71 // Re-access the entry, which is now guaranteed to be Some.
72 match entry {
73 Some(val) => val,
74 // Unreachable because we just set it to Some.
75 None => unreachable!(),
76 }
77 }
78 }
79 }
80 }
81
82 /// Retrieves a mutable reference to a special stream at the given index.
83 ///
84 /// # Safety
85 ///
86 /// * Caller must ensure the index corresponds to a valid special stream.
87 /// * Lifetimes still follow the Rust rules.
88 pub unsafe fn get_special(&mut self, index: u8) -> &mut F::Stream {
89 // Calculate the index for the special stream (offset by max_streams).
90 unsafe { self.get_mut_index(self.max_streams + index as usize) }
91 }
92
93 /// Calculates the index for a given stream ID, mapping it to the pool's capacity.
94 pub fn stream_index(&mut self, id: &StreamId) -> usize {
95 stream_index(id, self.max_streams)
96 }
97}
98
99/// Maps a stream ID to an index within the pool's capacity using modulo arithmetic.
100pub fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
101 stream_id.value as usize % max_streams
102}