cubecl_runtime/stream/base.rs
1use alloc::vec::Vec;
2use cubecl_common::stream_id::StreamId;
3
/// Produces streams on demand for a [`StreamPool`].
///
/// The pool only calls [`StreamFactory::create`] the first time an empty slot is
/// accessed, so no stream is built for a slot that is never requested.
pub trait StreamFactory {
    /// Stream handle returned by [`StreamFactory::create`].
    type Stream;

    /// Builds and returns a brand-new stream.
    fn create(&mut self) -> Self::Stream;
}
11
12/// Represents a pool of streams, managing a collection of streams created by a factory.
13#[derive(Debug)]
14pub struct StreamPool<F: StreamFactory> {
15 /// Vector storing optional streams, where None indicates an uninitialized stream.
16 streams: Vec<Option<F::Stream>>,
17 /// The factory used to create new streams when needed.
18 factory: F,
19 /// Maximum number of regular streams (excludes special streams).
20 max_streams: usize,
21}
22
23impl<F: StreamFactory> StreamPool<F> {
24 /// Creates a new stream pool with the given backend factory and capacity constraints.
25 pub fn new(backend: F, max_streams: u8, num_special: u8) -> Self {
26 // Initialize a vector with capacity for regular and special streams.
27 let mut streams = Vec::with_capacity(max_streams as usize);
28 // Pre-populate the vector with None to reserve space for all streams.
29 for _ in 0..(max_streams + num_special) {
30 streams.push(None);
31 }
32
33 Self {
34 streams,
35 factory: backend,
36 max_streams: max_streams as usize,
37 }
38 }
39
40 /// Retrieves a mutable reference to a stream for a given stream ID.
41 pub fn get_mut(&mut self, stream_id: &StreamId) -> &mut F::Stream {
42 // Calculate the index for the stream ID.
43 let index = self.stream_index(stream_id);
44
45 // Use unsafe method to retrieve the stream, assuming the index is valid.
46 //
47 // # Safety
48 //
49 // * The `stream_index` function ensures the index is within bounds.
50 unsafe { self.get_mut_index(index) }
51 }
52
53 /// Retrieves a mutable reference to a stream at the specified index, initializing it if needed.
54 ///
55 /// # Safety
56 ///
57 /// * Caller must ensure the index is valid (less than `max_streams + num_special`).
58 /// * Lifetimes still follow the Rust rules.
59 pub unsafe fn get_mut_index(&mut self, index: usize) -> &mut F::Stream {
60 unsafe {
61 // Access the stream entry without bounds checking for performance.
62 let entry = self.streams.get_unchecked_mut(index);
63 match entry {
64 // If the stream exists, return it.
65 Some(val) => val,
66 // If the stream is None, create a new one using the factory.
67 None => {
68 let stream = self.factory.create();
69 // Store the new stream in the vector.
70 *entry = Some(stream);
71
72 // Re-access the entry, which is now guaranteed to be Some.
73 match entry {
74 Some(val) => val,
75 // Unreachable because we just set it to Some.
76 None => unreachable!(),
77 }
78 }
79 }
80 }
81 }
82
83 /// Retrieves a mutable reference to a special stream at the given index.
84 ///
85 /// # Safety
86 ///
87 /// * Caller must ensure the index corresponds to a valid special stream.
88 /// * Lifetimes still follow the Rust rules.
89 pub unsafe fn get_special(&mut self, index: u8) -> &mut F::Stream {
90 // Calculate the index for the special stream (offset by max_streams).
91 unsafe { self.get_mut_index(self.max_streams + index as usize) }
92 }
93
94 /// Calculates the index for a given stream ID, mapping it to the pool's capacity.
95 pub fn stream_index(&mut self, id: &StreamId) -> usize {
96 stream_index(id, self.max_streams)
97 }
98}
99
100/// Maps a stream ID to an index within the pool's capacity using modulo arithmetic.
101pub fn stream_index(stream_id: &StreamId, max_streams: usize) -> usize {
102 stream_id.value as usize % max_streams
103}