cubecl_runtime/stream/
base.rs

1use alloc::vec::Vec;
2use cubecl_common::stream_id::StreamId;
3
/// Abstraction over stream construction.
///
/// A [`StreamPool`] owns one factory and calls it whenever a slot that has not been
/// initialized yet is requested.
pub trait StreamFactory {
    /// Concrete stream type handed out by this factory.
    type Stream;

    /// Builds and returns a fresh stream.
    ///
    /// Takes `&mut self`, so an implementation may update its own state while
    /// constructing the stream.
    fn create(&mut self) -> Self::Stream;
}
11
12/// Represents a pool of streams, managing a collection of streams created by a factory.
13#[derive(Debug)]
14pub struct StreamPool<F: StreamFactory> {
15    /// Vector storing optional streams, where None indicates an uninitialized stream.
16    streams: Vec<Option<F::Stream>>,
17    /// The factory used to create new streams when needed.
18    factory: F,
19    /// Maximum number of regular streams (excludes special streams).
20    max_streams: usize,
21}
22
23impl<F: StreamFactory> StreamPool<F> {
24    /// Creates a new stream pool with the given backend factory and capacity constraints.
25    pub fn new(backend: F, max_streams: u8, num_special: u8) -> Self {
26        // Initialize a vector with capacity for regular and special streams.
27        let mut streams = Vec::with_capacity(max_streams as usize);
28        // Pre-populate the vector with None to reserve space for all streams.
29        for _ in 0..(max_streams + num_special) {
30            streams.push(None);
31        }
32
33        Self {
34            streams,
35            factory: backend,
36            max_streams: max_streams as usize,
37        }
38    }
39
40    /// Retrieves a mutable reference to a stream for a given stream ID.
41    pub fn get_mut(&mut self, stream_id: &StreamId) -> &mut F::Stream {
42        // Calculate the index for the stream ID.
43        let index = self.stream_index(stream_id);
44
45        // Use unsafe method to retrieve the stream, assuming the index is valid.
46        //
47        // # Safety
48        //
49        // * The `stream_index` function ensures the index is within bounds.
50        unsafe { self.get_mut_index(index) }
51    }
52
53    /// Retrieves a mutable reference to a stream at the specified index, initializing it if needed.
54    ///
55    /// # Safety
56    ///
57    /// * Caller must ensure the index is valid (less than `max_streams + num_special`).
58    /// * Lifetimes still follow the Rust rules.
59    pub unsafe fn get_mut_index(&mut self, index: usize) -> &mut F::Stream {
60        unsafe {
61            // Access the stream entry without bounds checking for performance.
62            let entry = self.streams.get_unchecked_mut(index);
63            match entry {
64                // If the stream exists, return it.
65                Some(val) => val,
66                // If the stream is None, create a new one using the factory.
67                None => {
68                    let stream = self.factory.create();
69                    // Store the new stream in the vector.
70                    *entry = Some(stream);
71
72                    // Re-access the entry, which is now guaranteed to be Some.
73                    match entry {
74                        Some(val) => val,
75                        // Unreachable because we just set it to Some.
76                        None => unreachable!(),
77                    }
78                }
79            }
80        }
81    }
82
83    /// Retrieves a mutable reference to a special stream at the given index.
84    ///
85    /// # Safety
86    ///
87    /// * Caller must ensure the index corresponds to a valid special stream.
88    /// * Lifetimes still follow the Rust rules.
89    pub unsafe fn get_special(&mut self, index: u8) -> &mut F::Stream {
90        // Calculate the index for the special stream (offset by max_streams).
91        unsafe { self.get_mut_index(self.max_streams + index as usize) }
92    }
93
94    /// Calculates the index for a given stream ID, mapping it to the pool's capacity.
95    pub fn stream_index(&mut self, id: &StreamId) -> usize {
96        stream_index(id, self.max_streams)
97    }
98}
99
100/// Maps a stream ID to an index within the pool's capacity using modulo arithmetic.
101pub fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
102    stream_id.value as usize % max_streams
103}