// cubecl_core/codegen/integrator.rs

1use std::num::NonZero;
2
3use super::Compiler;
4use crate::{
5    ir::{
6        Binding, CubeDim, Elem, Id, Item, KernelDefinition, Location, ReadingStrategy, Scope,
7        UIntKind, Variable, VariableKind, Vectorization, Visibility,
8    },
9    Runtime,
10};
11
/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
#[derive(Clone)]
pub struct KernelIntegrator {
    /// The expansion being integrated; its inputs and outputs are drained while the bindings
    /// are built.
    expansion: KernelExpansion,
    /// Storage bindings created for array inputs, in input order.
    input_bindings: Vec<Binding>,
    /// Storage bindings created for outputs that get their own array (`ArrayWrite` and `Array`).
    output_bindings: Vec<Binding>,
    /// Extra named bindings collected while registering inputs (the `scalars_{elem}` buffers).
    named_bindings: Vec<(String, Binding)>,
}
21
/// The information necessary to compile a [kernel definition](KernelDefinition).
#[derive(Clone)]
pub struct KernelExpansion {
    /// Kernel inputs; arrays become input bindings, scalars become named bindings.
    pub inputs: Vec<InputInfo>,
    /// Kernel outputs; they become output bindings and/or global writes added to `scope`.
    pub outputs: Vec<OutputInfo>,
    /// The kernel body; it becomes the `body` of the resulting definition.
    pub scope: Scope,
    /// Name given to the resulting kernel definition.
    pub kernel_name: String,
}
30
/// Indicates that an output can be written directly into an input (inplace) instead of into a
/// new array.
#[derive(new, Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct InplaceMapping {
    /// Input position (index into the kernel inputs).
    pub pos_input: usize,
    /// Output position (index into the kernel outputs).
    pub pos_output: usize,
}
39
/// A vectorization override for a single input or output, identified by its position.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum VectorizationPartial {
    /// Vectorization of the input at `pos`.
    Input {
        pos: usize,
        vectorization: Vectorization,
    },
    /// Vectorization of the output at `pos`.
    Output {
        pos: usize,
        vectorization: Vectorization,
    },
}
51
/// Settings used to compile a kernel.
///
/// Its `Display` output serves as the key deciding when a new version of a kernel must be
/// compiled.
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelSettings {
    /// Inplace mappings; taken out of the settings during integration.
    pub mappings: Vec<InplaceMapping>,
    /// Per-position vectorization overrides, recorded through `vectorize_input` /
    /// `vectorize_output` (which skip the default `None`).
    vectorization_partial: Vec<VectorizationPartial>,
    /// Cube dimension of the kernel.
    pub cube_dim: CubeDim,
    /// Reading strategy applied to each listed input id.
    pub reading_strategy: Vec<(Id, ReadingStrategy)>,
    /// Kernel name. NOTE(review): not part of the `Display` key — confirm this is intended.
    pub kernel_name: String,
}
60
61impl core::fmt::Display for KernelSettings {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        // The goal of this implementation is to generate the shortest representation
64        // that won't clash with any other compilation settings. This is crucial since we rely on
65        // this representation to know when to compile a new version of a kernel.
66        //
67        // Each main section starts with a letter that can't be used by other main sections:
68        //
69        // * Mapping:          m
70        //   * Input:  i
71        //   * Output: o
72        //
73        // * Reading Strategy: r
74        //   * Output layout: o
75        //   * Plain:         p
76        //
77        // * Vectorization Global:    vg{factor}
78        // * Vectorization Partial Input:    v{factor}i{pos}
79        // * Vectorization Partial Output:    vo
80        // * Cube Dim X: x
81        // * Cube Dim Y: y
82        // * Cube Dim Z: z
83        f.write_str("m")?;
84        for mapping in self.mappings.iter() {
85            f.write_fmt(format_args!(
86                "i{}o{}",
87                mapping.pos_input, mapping.pos_output
88            ))?;
89        }
90
91        f.write_str("r")?;
92
93        for (input, strategy) in self.reading_strategy.iter() {
94            match strategy {
95                ReadingStrategy::OutputLayout => f.write_fmt(format_args!("i{}o", input)),
96                ReadingStrategy::Plain => f.write_fmt(format_args!("i{}p", input)),
97            }?;
98        }
99
100        for vectorization in self.vectorization_partial.iter() {
101            match vectorization {
102                VectorizationPartial::Input { pos, vectorization } => f.write_fmt(format_args!(
103                    "v{}i{pos}",
104                    vectorization.map(NonZero::get).unwrap_or(1)
105                ))?,
106                VectorizationPartial::Output { pos, vectorization } => f.write_fmt(
107                    format_args!("v{}o{pos}", vectorization.map(NonZero::get).unwrap_or(1)),
108                )?,
109            };
110        }
111
112        f.write_fmt(format_args!(
113            "x{}y{}z{}",
114            self.cube_dim.x, self.cube_dim.y, self.cube_dim.x
115        ))
116    }
117}
118
119impl KernelSettings {
120    /// Compile the shader with vectorization enabled for an input.
121    #[allow(dead_code)]
122    pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self {
123        // Not setting the vectorization factor when it's the default value reduces the kernel id
124        // size.
125        if vectorization.is_none() {
126            return self;
127        }
128
129        self.vectorization_partial
130            .push(VectorizationPartial::Input {
131                pos: position,
132                vectorization,
133            });
134        self
135    }
136
137    /// Compile the shader with vectorization enabled for an output.
138    #[allow(dead_code)]
139    pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self {
140        // Not setting the vectorization factor when it's the default value reduces the kernel id
141        // size.
142        if vectorization.is_none() {
143            return self;
144        }
145
146        self.vectorization_partial
147            .push(VectorizationPartial::Output {
148                pos: position,
149                vectorization,
150            });
151        self
152    }
153
154    /// Fetch the vectorization for the provided input position.
155    pub fn vectorization_input(&self, position: usize) -> Vectorization {
156        for partial in self.vectorization_partial.iter() {
157            if let VectorizationPartial::Input { pos, vectorization } = partial {
158                if *pos == position {
159                    return *vectorization;
160                }
161            }
162        }
163
164        None
165    }
166
167    /// Fetch the vectorization for the provided output position.
168    pub fn vectorization_output(&self, position: usize) -> Vectorization {
169        for partial in self.vectorization_partial.iter() {
170            if let VectorizationPartial::Output { pos, vectorization } = partial {
171                if *pos == position {
172                    return *vectorization;
173                }
174            }
175        }
176
177        None
178    }
179
180    /// Compile the shader with inplace enabled by the given [mapping](InplaceMapping).
181    ///
182    /// Notes:
183    ///
184    /// You should favor using `dynamic_settings` when using fusion, since the mapping is going to
185    /// be created from the runtime information.
186    pub fn inplace(mut self, mappings: Vec<InplaceMapping>) -> Self {
187        self.mappings = mappings;
188        self
189    }
190
191    /// Set cube dimension.
192    #[allow(dead_code)]
193    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
194        self.cube_dim = cube_dim;
195        self
196    }
197
198    /// Set kernel name.
199    #[allow(dead_code)]
200    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
201        self.kernel_name = name.as_ref().to_string();
202        self
203    }
204}
205
/// Returns `true` when the strides never increase from one dimension to the next, i.e. they are
/// sorted in non-increasing order. Empty and single-element slices are accepted.
///
/// Only the ordering of the strides is checked; the shape is not considered.
#[allow(dead_code)]
fn is_contiguous(strides: &[usize]) -> bool {
    strides.windows(2).all(|pair| pair[0] >= pair[1])
}
219
/// Information related to an input.
#[derive(Clone, Debug)]
pub enum InputInfo {
    /// An array input, registered as a storage input binding.
    Array {
        /// Item of the array (`bool` elements are stored as `u32` in the binding).
        item: Item,
        /// Visibility of the binding; an inplace mapping may upgrade it to `ReadWrite`.
        visibility: Visibility,
        /// Whether this input has extended metadata (rank, shape, strides)
        has_extended_meta: bool,
    },
    /// A scalar input, registered as a `scalars_{elem}` named binding.
    Scalar {
        /// Element type of the scalar(s).
        elem: Elem,
        /// Size given to the `scalars_{elem}` binding.
        size: usize,
    },
}
234
235impl InputInfo {
236    /// The item type of the input.
237    #[allow(dead_code)]
238    pub fn item(&self) -> Item {
239        match self {
240            InputInfo::Array { item, .. } => *item,
241            InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
242        }
243    }
244}
245
246impl OutputInfo {
247    /// The item type of the input.
248    #[allow(dead_code)]
249    pub fn item(&self) -> Item {
250        match self {
251            OutputInfo::ArrayWrite { item, .. } => *item,
252            OutputInfo::InputArrayWrite { item, .. } => *item,
253            OutputInfo::Array { item, .. } => *item,
254        }
255    }
256}
257
/// Information related to an output.
#[derive(Clone, Debug)]
pub enum OutputInfo {
    /// Write the local variable to a new array.
    ///
    /// This will create a new binding in the [kernel definition](KernelDefinition).
    ArrayWrite {
        /// Item of the local variable; the binding uses its `bool`-adapted form.
        item: Item,
        /// Id of the mutable local whose value is written.
        local: Id,
        /// Position passed to the global write.
        position: Variable,
        /// Whether this output has extended metadata (rank, shape, strides)
        has_extended_meta: bool,
    },
    /// Write the local variable to an existing input binding.
    ///
    /// An inplace mapping turns an `ArrayWrite` into this variant; no output binding is created.
    InputArrayWrite {
        /// Item used for the write into the input.
        item: Item,
        /// Id of the input array written to.
        input: Id,
        /// Id of the mutable local whose value is written.
        local: Id,
        /// Position passed to the global write.
        position: Variable,
    },
    /// Simply register the output, but don't automatically add a write to it.
    ///
    /// Useful when a procedure writes to the output using operations.
    Array {
        /// Item of the output array (`bool` elements are stored as `u32` in the binding).
        item: Item,
        /// Whether this output has extended metadata (rank, shape, strides)
        has_extended_meta: bool,
    },
}
287
288impl OutputInfo {
289    #[allow(dead_code)]
290    pub fn elem_size<R: Runtime>(&self) -> usize {
291        let elem = match self {
292            OutputInfo::ArrayWrite { item, .. } => bool_elem(item.elem()),
293            OutputInfo::InputArrayWrite { item, .. } => bool_elem(item.elem()),
294            OutputInfo::Array { item, .. } => bool_elem(item.elem()),
295        };
296        <R::Compiler as Compiler>::elem_size(elem)
297    }
298}
299
300impl KernelIntegrator {
301    /// Starts a new compilation.
302    pub fn new(info: KernelExpansion) -> Self {
303        Self {
304            expansion: info,
305            input_bindings: Default::default(),
306            output_bindings: Default::default(),
307            named_bindings: Default::default(),
308        }
309    }
310
311    /// Performs the compilation with the provided [settings](KernelSettings).
312    pub fn integrate(mut self, mut settings: KernelSettings) -> KernelDefinition {
313        self.register_inputs(&settings);
314        self.register_outputs(&mut settings);
315
316        let inputs = self.input_bindings;
317        let outputs = self.output_bindings;
318        let mut named = Vec::with_capacity(2);
319
320        named.push((
321            "info".to_string(),
322            Binding {
323                item: Item::new(Elem::UInt(UIntKind::U32)),
324                visibility: Visibility::Read,
325                location: Location::Storage,
326                has_extended_meta: false,
327                size: None, // We avoid putting the length here since it will force a new kernel
328                            // for each tensor rank.
329            },
330        ));
331
332        for (name, binding) in self.named_bindings.into_iter() {
333            named.push((name, binding));
334        }
335
336        KernelDefinition {
337            inputs,
338            outputs,
339            named,
340            cube_dim: settings.cube_dim,
341            body: self.expansion.scope,
342            kernel_name: self.expansion.kernel_name,
343        }
344    }
345
346    fn register_inputs(&mut self, settings: &KernelSettings) {
347        for (id, strategy) in settings.reading_strategy.iter() {
348            self.expansion.scope.update_read(*id, *strategy);
349        }
350
351        for input in self.expansion.inputs.drain(..) {
352            match input {
353                InputInfo::Array {
354                    item,
355                    visibility,
356                    has_extended_meta,
357                } => {
358                    self.input_bindings.push(Binding {
359                        item: bool_item(item),
360                        visibility,
361                        location: Location::Storage,
362                        has_extended_meta,
363                        size: None,
364                    });
365                }
366                InputInfo::Scalar { elem, size } => {
367                    let elem = bool_elem(elem);
368
369                    self.named_bindings.push((
370                        format!("scalars_{}", elem),
371                        Binding {
372                            item: Item::new(elem),
373                            visibility: Visibility::Read,
374                            location: Location::Storage,
375                            has_extended_meta: false,
376                            size: Some(size),
377                        },
378                    ));
379                }
380            }
381        }
382    }
383
384    fn register_outputs(&mut self, settings: &mut KernelSettings) {
385        let mut index = 0;
386
387        if !settings.mappings.is_empty() {
388            let mut mappings = Vec::new();
389            core::mem::swap(&mut settings.mappings, &mut mappings);
390
391            for mapping in mappings {
392                self.register_inplace_mapping(mapping);
393            }
394        }
395
396        for array in self.expansion.outputs.drain(..) {
397            match array {
398                OutputInfo::ArrayWrite {
399                    item,
400                    local,
401                    position,
402                    has_extended_meta,
403                } => {
404                    let item_adapted = bool_item(item);
405
406                    self.output_bindings.push(Binding {
407                        item: item_adapted,
408                        visibility: Visibility::ReadWrite,
409                        location: Location::Storage,
410                        has_extended_meta,
411                        size: None,
412                    });
413                    self.expansion.scope.write_global(
414                        Variable::new(VariableKind::LocalMut { id: local }, item),
415                        Variable::new(VariableKind::GlobalOutputArray(index), item_adapted),
416                        position,
417                    );
418                    index += 1;
419                }
420                OutputInfo::InputArrayWrite {
421                    item,
422                    input,
423                    local,
424                    position,
425                } => {
426                    self.expansion.scope.write_global(
427                        Variable::new(VariableKind::LocalMut { id: local }, item),
428                        Variable::new(VariableKind::GlobalInputArray(input), bool_item(item)),
429                        position,
430                    );
431                }
432                OutputInfo::Array {
433                    item,
434                    has_extended_meta,
435                } => {
436                    let elem_adapted = bool_item(item);
437
438                    self.output_bindings.push(Binding {
439                        item: elem_adapted,
440                        visibility: Visibility::ReadWrite,
441                        location: Location::Storage,
442                        has_extended_meta,
443                        size: None,
444                    });
445
446                    index += 1;
447                }
448            }
449        }
450    }
451
452    fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
453        let output = match self.expansion.outputs.get_mut(mapping.pos_output) {
454            Some(output) => output,
455            None => {
456                if let Some(binding) = self.input_bindings.get_mut(mapping.pos_input) {
457                    // Update input visibility.
458                    binding.visibility = Visibility::ReadWrite;
459                }
460
461                // The mapping is handled differently, normally by cube itself.
462                return;
463            }
464        };
465
466        let (item, local, position) = match output {
467            OutputInfo::ArrayWrite { item, local, position, .. } => (item, local, position),
468            OutputInfo::InputArrayWrite {
469                item: _,
470                input,
471                local: _,
472                position: _,
473            } => {
474                assert_eq!(
475                    *input, mapping.pos_input as Id,
476                    "Can't use different inputs for the same output."
477                );
478                return;
479            }
480            OutputInfo::Array { .. } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."),
481        };
482
483        let item = match self.input_bindings.get_mut(mapping.pos_input) {
484            Some(binding) => {
485                // Update input visibility.
486                binding.visibility = Visibility::ReadWrite;
487                // Inputs modified inplace should be read without any specified layout.
488                self.expansion
489                    .scope
490                    .update_read(mapping.pos_input as Id, ReadingStrategy::Plain);
491
492                // Use the same item as the input.
493                //
494                // The output can be different (i.e inplace boolean operations on float bindings).
495                binding.item
496            }
497            None => *item,
498        };
499
500        // Update the output.
501        *output = OutputInfo::InputArrayWrite {
502            item,
503            input: mapping.pos_input as Id,
504            local: *local,
505            position: *position,
506        };
507    }
508}
509
510fn bool_item(ty: Item) -> Item {
511    Item {
512        elem: bool_elem(ty.elem),
513        vectorization: ty.vectorization,
514    }
515}
516
517pub fn bool_elem(elem: Elem) -> Elem {
518    match elem {
519        // U32 are used for bool tensors
520        Elem::Bool => Elem::UInt(UIntKind::U32),
521        _ => elem,
522    }
523}