cubecl_runtime/stream/base.rs
1use cubecl_common::stream_id::StreamId;
2
/// Produces new streams on demand.
///
/// [`StreamPool`] owns a factory of this type and calls [`StreamFactory::create`]
/// the first time an empty slot of the pool is accessed.
pub trait StreamFactory {
    /// Stream type handed out by this factory.
    type Stream;

    /// Builds and returns a brand-new stream.
    fn create(&mut self) -> Self::Stream;
}
10
11/// Represents a pool of streams, managing a collection of streams created by a factory.
12#[derive(Debug)]
13pub struct StreamPool<F: StreamFactory> {
14    /// Vector storing optional streams, where None indicates an uninitialized stream.
15    streams: Vec<Option<F::Stream>>,
16    /// The factory used to create new streams when needed.
17    factory: F,
18    /// Maximum number of regular streams (excludes special streams).
19    max_streams: usize,
20}
21
22impl<F: StreamFactory> StreamPool<F> {
23    /// Creates a new stream pool with the given backend factory and capacity constraints.
24    pub fn new(backend: F, max_streams: u8, num_special: u8) -> Self {
25        // Initialize a vector with capacity for regular and special streams.
26        let mut streams = Vec::with_capacity(max_streams as usize);
27        // Pre-populate the vector with None to reserve space for all streams.
28        for _ in 0..(max_streams + num_special) {
29            streams.push(None);
30        }
31
32        Self {
33            streams,
34            factory: backend,
35            max_streams: max_streams as usize,
36        }
37    }
38
39    /// Retrieves a mutable reference to a stream for a given stream ID.
40    pub fn get_mut(&mut self, stream_id: &StreamId) -> &mut F::Stream {
41        // Calculate the index for the stream ID.
42        let index = self.stream_index(stream_id);
43
44        // Use unsafe method to retrieve the stream, assuming the index is valid.
45        //
46        // # Safety
47        //
48        // * The `stream_index` function ensures the index is within bounds.
49        unsafe { self.get_mut_index(index) }
50    }
51
52    /// Retrieves a mutable reference to a stream at the specified index, initializing it if needed.
53    ///
54    /// # Safety
55    ///
56    /// * Caller must ensure the index is valid (less than `max_streams + num_special`).
57    /// * Lifetimes still follow the Rust rules.
58    pub unsafe fn get_mut_index(&mut self, index: usize) -> &mut F::Stream {
59        unsafe {
60            // Access the stream entry without bounds checking for performance.
61            let entry = self.streams.get_unchecked_mut(index);
62            match entry {
63                // If the stream exists, return it.
64                Some(val) => val,
65                // If the stream is None, create a new one using the factory.
66                None => {
67                    let stream = self.factory.create();
68                    // Store the new stream in the vector.
69                    *entry = Some(stream);
70
71                    // Re-access the entry, which is now guaranteed to be Some.
72                    match entry {
73                        Some(val) => val,
74                        // Unreachable because we just set it to Some.
75                        None => unreachable!(),
76                    }
77                }
78            }
79        }
80    }
81
82    /// Retrieves a mutable reference to a special stream at the given index.
83    ///
84    /// # Safety
85    ///
86    /// * Caller must ensure the index corresponds to a valid special stream.
87    /// * Lifetimes still follow the Rust rules.
88    pub unsafe fn get_special(&mut self, index: u8) -> &mut F::Stream {
89        // Calculate the index for the special stream (offset by max_streams).
90        unsafe { self.get_mut_index(self.max_streams + index as usize) }
91    }
92
93    /// Calculates the index for a given stream ID, mapping it to the pool's capacity.
94    pub fn stream_index(&mut self, id: &StreamId) -> usize {
95        stream_index(id, self.max_streams)
96    }
97}
98
99/// Maps a stream ID to an index within the pool's capacity using modulo arithmetic.
100pub fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
101    stream_id.value as usize % max_streams
102}