// cubecl_core/frontend/pipeline.rs

//! This module exposes pipelining utilities for multi-stage asynchronous data copies
//! with latency hiding.
//! We call *producers* all threads that call `producer_acquire` and `producer_commit`,
//! and *consumers* all threads that call `consumer_wait` and `consumer_release`.
//!
//! # Example
//! In this example, threads play both the role of producer and consumer.
//!
//! ```rust, ignore
//! #[cube(launch)]
//! /// Calculate the sum of an array, using pipelining
//! fn pipelined_sum<F: Float>(
//!     input: &Array<Line<F>>,
//!     output: &mut Array<Line<F>>,
//!     #[comptime] batch_len: u32,
//! ) {
//!     let smem_size = 2 * batch_len;
//!     let num_batches = input.len() / batch_len;
//!     let mut shared_memory = SharedMemory::<F>::new_lined(smem_size, input.line_size());
//!     // Two stages: one batch is being copied while the other one is summed
//!     let pipeline = Pipeline::new(2);
//!
//!     let mut sum = Line::<F>::empty(input.line_size()).fill(F::new(0.));
//!
//!     // Copy the first batch to shared memory
//!     pipeline.producer_acquire();
//!     pipeline.memcpy_async(
//!         input.slice(0, batch_len),
//!         shared_memory.slice_mut(0, batch_len),
//!     );
//!     pipeline.producer_commit();
//!
//!     for input_batch in 1..num_batches {
//!         // The copy and compute indices always alternate
//!         let copy_index = input_batch % 2;
//!         let compute_index = (input_batch + 1) % 2;
//!
//!         // Copy the next batch to shared memory
//!         pipeline.producer_acquire();
//!         pipeline.memcpy_async(
//!             input.slice(batch_len * input_batch, batch_len * (input_batch + 1)),
//!             shared_memory.slice_mut(batch_len * copy_index, batch_len * (copy_index + 1)),
//!         );
//!         pipeline.producer_commit();
//!
//!         // Compute the batch that is ready
//!         pipeline.consumer_wait();
//!         let compute_slice =
//!             shared_memory.slice(batch_len * compute_index, batch_len * (compute_index + 1));
//!         for i in 0..batch_len {
//!             sum += compute_slice[i];
//!         }
//!         pipeline.consumer_release();
//!     }
//!
//!     // Compute the last batch
//!     pipeline.consumer_wait();
//!     let compute_slice = shared_memory.slice(
//!         batch_len * ((num_batches + 1) % 2),
//!         batch_len * ((num_batches + 1) % 2 + 1),
//!     );
//!     for i in 0..batch_len {
//!         sum += compute_slice[i];
//!     }
//!     pipeline.consumer_release();
//!
//!     output[0] = sum;
//! }
//! ```
69
70use std::marker::PhantomData;
71
72use cubecl_ir::ExpandElement;
73
74use crate::{
75    ir::{Item, PipelineOps, Scope},
76    unexpanded,
77};
78
79use super::{CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, Init, Line, Slice, SliceMut};
80
81/// A mechanism for managing a sequence of `memcpy_async`
82/// For now, it only works at the Cube scope
83#[derive(Clone, Copy)]
84pub struct Pipeline<C: CubePrimitive> {
85    _c: PhantomData<C>,
86}
87
88impl<C: CubePrimitive> CubeType for Pipeline<C> {
89    type ExpandType = PipelineExpand<C>;
90}
91
92impl<C: CubePrimitive> Init for PipelineExpand<C> {
93    fn init(self, _scope: &mut Scope) -> Self {
94        self
95    }
96}
97
98impl<C: CubePrimitive> CubeDebug for PipelineExpand<C> {
99    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
100        scope.update_variable_name(*self.elem, name);
101    }
102}
103
104#[derive(Clone)]
105/// Expand type of [Pipeline]
106pub struct PipelineExpand<C: CubePrimitive> {
107    elem: ExpandElement,
108    _c: PhantomData<C>,
109}
110
111impl<C: CubePrimitive> Default for Pipeline<C> {
112    fn default() -> Self {
113        Self::new(1)
114    }
115}
116
117impl<C: CubePrimitive> Pipeline<C> {
118    /// Create a pipeline instance
119    pub fn new(_num_stages: u8) -> Self {
120        Self { _c: PhantomData }
121    }
122
123    /// Copy the source slice to destination
124    ///
125    /// # Safety
126    ///
127    /// This will try to copy the whole source slice, so
128    /// make sure source length <= destination length
129    pub fn memcpy_async(&self, _source: &Slice<Line<C>>, _destination: &mut SliceMut<Line<C>>) {
130        unexpanded!()
131    }
132
133    /// Reserves a specific stage for the producer to work on.
134    pub fn producer_acquire(&self) {
135        unexpanded!()
136    }
137
138    /// Signals that the producer is done and the stage is ready for the consumer.
139    pub fn producer_commit(&self) {
140        unexpanded!()
141    }
142
143    /// Waits until the producer has finished with the stage.
144    pub fn consumer_wait(&self) {
145        unexpanded!()
146    }
147
148    /// Frees the stage after the consumer is done using it.
149    pub fn consumer_release(&self) {
150        unexpanded!()
151    }
152
153    pub fn __expand_new(scope: &mut Scope, num_stages: u8) -> PipelineExpand<C> {
154        let elem = C::as_elem(scope);
155        let variable = scope.create_pipeline(Item::new(elem), num_stages);
156        PipelineExpand {
157            elem: variable,
158            _c: PhantomData,
159        }
160    }
161
162    pub fn __expand_memcpy_async(
163        scope: &mut Scope,
164        expand: PipelineExpand<C>,
165        source: ExpandElementTyped<Slice<Line<C>>>,
166        destination: ExpandElementTyped<SliceMut<Line<C>>>,
167    ) {
168        expand.__expand_memcpy_async_method(scope, source, destination);
169    }
170
171    pub fn __expand_producer_acquire(scope: &mut Scope, expand: PipelineExpand<C>) {
172        expand.__expand_producer_acquire_method(scope);
173    }
174
175    pub fn __expand_producer_commit(scope: &mut Scope, expand: PipelineExpand<C>) {
176        expand.__expand_producer_commit_method(scope);
177    }
178
179    pub fn __expand_consumer_wait(scope: &mut Scope, expand: PipelineExpand<C>) {
180        expand.__expand_consumer_wait_method(scope);
181    }
182
183    pub fn __expand_consumer_release(scope: &mut Scope, expand: PipelineExpand<C>) {
184        expand.__expand_consumer_release_method(scope);
185    }
186}
187
188impl<C: CubePrimitive> PipelineExpand<C> {
189    pub fn __expand_memcpy_async_method(
190        &self,
191        scope: &mut Scope,
192        source: ExpandElementTyped<Slice<Line<C>>>,
193        destination: ExpandElementTyped<SliceMut<Line<C>>>,
194    ) {
195        let pipeline = *self.elem;
196        let source = *source.expand;
197        let destination = *destination.expand;
198
199        let mem_copy = PipelineOps::MemCopyAsync {
200            pipeline,
201            source,
202            destination,
203        };
204
205        scope.register(mem_copy);
206    }
207
208    pub fn __expand_producer_acquire_method(&self, scope: &mut Scope) {
209        let pipeline = *self.elem;
210        scope.register(PipelineOps::ProducerAcquire { pipeline });
211    }
212    pub fn __expand_producer_commit_method(&self, scope: &mut Scope) {
213        let pipeline = *self.elem;
214        scope.register(PipelineOps::ProducerCommit { pipeline });
215    }
216    pub fn __expand_consumer_wait_method(&self, scope: &mut Scope) {
217        let pipeline = *self.elem;
218        scope.register(PipelineOps::ConsumerWait { pipeline });
219    }
220    pub fn __expand_consumer_release_method(&self, scope: &mut Scope) {
221        let pipeline = *self.elem;
222        scope.register(PipelineOps::ConsumerRelease { pipeline });
223    }
224}