1use std::num::NonZero;
2
3use super::Compiler;
4use crate::{
5 ir::{
6 Binding, CubeDim, Elem, Id, Item, KernelDefinition, Location, ReadingStrategy, Scope,
7 UIntKind, Variable, VariableKind, Vectorization, Visibility,
8 },
9 Runtime,
10};
11
/// Assembles a [`KernelExpansion`] into a `KernelDefinition`.
///
/// Bindings are accumulated while the expansion's inputs and outputs are registered,
/// and are then moved into the final definition by `integrate`.
#[derive(Clone)]
pub struct KernelIntegrator {
    /// The expansion being integrated; its inputs and outputs are drained during registration.
    expansion: KernelExpansion,
    /// Bindings created from `InputInfo::Array` inputs, in registration order.
    input_bindings: Vec<Binding>,
    /// Bindings created for outputs that get their own storage (`ArrayWrite` / `Array`).
    output_bindings: Vec<Binding>,
    /// Extra named bindings; filled with `scalars_{elem}` bindings from scalar inputs.
    named_bindings: Vec<(String, Binding)>,
}
21
/// The description of a kernel before its bindings are finalized.
#[derive(Clone)]
pub struct KernelExpansion {
    /// Kernel inputs; drained into bindings during integration.
    pub inputs: Vec<InputInfo>,
    /// Kernel outputs; `InplaceMapping::pos_output` indexes into this list.
    pub outputs: Vec<OutputInfo>,
    /// The kernel body; reads/writes are patched into it during integration.
    pub scope: Scope,
    /// Name forwarded unchanged to the resulting kernel definition.
    pub kernel_name: String,
}
30
/// Requests that the output at `pos_output` be written into the input at `pos_input`
/// rather than into a separate output buffer.
#[derive(new, Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct InplaceMapping {
    /// Position of the input binding that receives the output's writes.
    pub pos_input: usize,
    /// Position (in the expansion's outputs) of the output being redirected.
    pub pos_output: usize,
}
39
/// A vectorization factor recorded for one input or output position.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum VectorizationPartial {
    /// Vectorization of the input at `pos`.
    Input {
        pos: usize,
        vectorization: Vectorization,
    },
    /// Vectorization of the output at `pos`.
    Output {
        pos: usize,
        vectorization: Vectorization,
    },
}
51
/// Configuration applied when integrating a kernel.
///
/// Usually built through the consuming builder methods (`inplace`, `vectorize_input`,
/// `vectorize_output`, `cube_dim`, `kernel_name`).
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelSettings {
    /// In-place mappings redirecting outputs into inputs.
    pub mappings: Vec<InplaceMapping>,
    /// Vectorization factors recorded per input/output position.
    vectorization_partial: Vec<VectorizationPartial>,
    /// Cube dimensions of the resulting kernel.
    pub cube_dim: CubeDim,
    /// Reading strategy applied to each listed input id.
    pub reading_strategy: Vec<(Id, ReadingStrategy)>,
    /// Kernel name. NOTE(review): not part of the `Display` output.
    pub kernel_name: String,
}
60
61impl core::fmt::Display for KernelSettings {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.write_str("m")?;
84 for mapping in self.mappings.iter() {
85 f.write_fmt(format_args!(
86 "i{}o{}",
87 mapping.pos_input, mapping.pos_output
88 ))?;
89 }
90
91 f.write_str("r")?;
92
93 for (input, strategy) in self.reading_strategy.iter() {
94 match strategy {
95 ReadingStrategy::OutputLayout => f.write_fmt(format_args!("i{}o", input)),
96 ReadingStrategy::Plain => f.write_fmt(format_args!("i{}p", input)),
97 }?;
98 }
99
100 for vectorization in self.vectorization_partial.iter() {
101 match vectorization {
102 VectorizationPartial::Input { pos, vectorization } => f.write_fmt(format_args!(
103 "v{}i{pos}",
104 vectorization.map(NonZero::get).unwrap_or(1)
105 ))?,
106 VectorizationPartial::Output { pos, vectorization } => f.write_fmt(
107 format_args!("v{}o{pos}", vectorization.map(NonZero::get).unwrap_or(1)),
108 )?,
109 };
110 }
111
112 f.write_fmt(format_args!(
113 "x{}y{}z{}",
114 self.cube_dim.x, self.cube_dim.y, self.cube_dim.x
115 ))
116 }
117}
118
119impl KernelSettings {
120 #[allow(dead_code)]
122 pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self {
123 if vectorization.is_none() {
126 return self;
127 }
128
129 self.vectorization_partial
130 .push(VectorizationPartial::Input {
131 pos: position,
132 vectorization,
133 });
134 self
135 }
136
137 #[allow(dead_code)]
139 pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self {
140 if vectorization.is_none() {
143 return self;
144 }
145
146 self.vectorization_partial
147 .push(VectorizationPartial::Output {
148 pos: position,
149 vectorization,
150 });
151 self
152 }
153
154 pub fn vectorization_input(&self, position: usize) -> Vectorization {
156 for partial in self.vectorization_partial.iter() {
157 if let VectorizationPartial::Input { pos, vectorization } = partial {
158 if *pos == position {
159 return *vectorization;
160 }
161 }
162 }
163
164 None
165 }
166
167 pub fn vectorization_output(&self, position: usize) -> Vectorization {
169 for partial in self.vectorization_partial.iter() {
170 if let VectorizationPartial::Output { pos, vectorization } = partial {
171 if *pos == position {
172 return *vectorization;
173 }
174 }
175 }
176
177 None
178 }
179
180 pub fn inplace(mut self, mappings: Vec<InplaceMapping>) -> Self {
187 self.mappings = mappings;
188 self
189 }
190
191 #[allow(dead_code)]
193 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
194 self.cube_dim = cube_dim;
195 self
196 }
197
198 #[allow(dead_code)]
200 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
201 self.kernel_name = name.as_ref().to_string();
202 self
203 }
204}
205
/// Returns `true` when `strides` never increases from one dimension to the next
/// (i.e. `strides[i] >= strides[i + 1]` for every adjacent pair).
///
/// Empty and single-element slices are trivially accepted. Note that only the ordering
/// of the strides is checked, not that they are exact products of the shape.
#[allow(dead_code)]
fn is_contiguous(strides: &[usize]) -> bool {
    strides
        .windows(2)
        .all(|pair| pair[0] >= pair[1])
}
219
/// Description of a single kernel input.
#[derive(Clone, Debug)]
pub enum InputInfo {
    /// An array input; registered as an input storage binding.
    Array {
        item: Item,
        visibility: Visibility,
        has_extended_meta: bool,
    },
    /// A scalar input; registered as a read-only `scalars_{elem}` named storage binding,
    /// with `size` forwarded as that binding's size.
    Scalar {
        elem: Elem,
        size: usize,
    },
}
234
235impl InputInfo {
236 #[allow(dead_code)]
238 pub fn item(&self) -> Item {
239 match self {
240 InputInfo::Array { item, .. } => *item,
241 InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
242 }
243 }
244}
245
246impl OutputInfo {
247 #[allow(dead_code)]
249 pub fn item(&self) -> Item {
250 match self {
251 OutputInfo::ArrayWrite { item, .. } => *item,
252 OutputInfo::InputArrayWrite { item, .. } => *item,
253 OutputInfo::Array { item, .. } => *item,
254 }
255 }
256}
257
/// Description of a single kernel output and how its values are written.
#[derive(Clone, Debug)]
pub enum OutputInfo {
    /// Writes the local variable `local` into a dedicated output array at `position`;
    /// gets its own output binding.
    ArrayWrite {
        item: Item,
        local: Id,
        position: Variable,
        has_extended_meta: bool,
    },
    /// Writes the local variable `local` into the input array `input` at `position`
    /// (in-place); no output binding is created for it.
    InputArrayWrite {
        item: Item,
        input: Id,
        local: Id,
        position: Variable,
    },
    /// An output array that gets a binding but no write registered by the integrator.
    Array {
        item: Item,
        has_extended_meta: bool,
    },
}
287
288impl OutputInfo {
289 #[allow(dead_code)]
290 pub fn elem_size<R: Runtime>(&self) -> usize {
291 let elem = match self {
292 OutputInfo::ArrayWrite { item, .. } => bool_elem(item.elem()),
293 OutputInfo::InputArrayWrite { item, .. } => bool_elem(item.elem()),
294 OutputInfo::Array { item, .. } => bool_elem(item.elem()),
295 };
296 <R::Compiler as Compiler>::elem_size(elem)
297 }
298}
299
300impl KernelIntegrator {
301 pub fn new(info: KernelExpansion) -> Self {
303 Self {
304 expansion: info,
305 input_bindings: Default::default(),
306 output_bindings: Default::default(),
307 named_bindings: Default::default(),
308 }
309 }
310
311 pub fn integrate(mut self, mut settings: KernelSettings) -> KernelDefinition {
313 self.register_inputs(&settings);
314 self.register_outputs(&mut settings);
315
316 let inputs = self.input_bindings;
317 let outputs = self.output_bindings;
318 let mut named = Vec::with_capacity(2);
319
320 named.push((
321 "info".to_string(),
322 Binding {
323 item: Item::new(Elem::UInt(UIntKind::U32)),
324 visibility: Visibility::Read,
325 location: Location::Storage,
326 has_extended_meta: false,
327 size: None, },
330 ));
331
332 for (name, binding) in self.named_bindings.into_iter() {
333 named.push((name, binding));
334 }
335
336 KernelDefinition {
337 inputs,
338 outputs,
339 named,
340 cube_dim: settings.cube_dim,
341 body: self.expansion.scope,
342 kernel_name: self.expansion.kernel_name,
343 }
344 }
345
346 fn register_inputs(&mut self, settings: &KernelSettings) {
347 for (id, strategy) in settings.reading_strategy.iter() {
348 self.expansion.scope.update_read(*id, *strategy);
349 }
350
351 for input in self.expansion.inputs.drain(..) {
352 match input {
353 InputInfo::Array {
354 item,
355 visibility,
356 has_extended_meta,
357 } => {
358 self.input_bindings.push(Binding {
359 item: bool_item(item),
360 visibility,
361 location: Location::Storage,
362 has_extended_meta,
363 size: None,
364 });
365 }
366 InputInfo::Scalar { elem, size } => {
367 let elem = bool_elem(elem);
368
369 self.named_bindings.push((
370 format!("scalars_{}", elem),
371 Binding {
372 item: Item::new(elem),
373 visibility: Visibility::Read,
374 location: Location::Storage,
375 has_extended_meta: false,
376 size: Some(size),
377 },
378 ));
379 }
380 }
381 }
382 }
383
384 fn register_outputs(&mut self, settings: &mut KernelSettings) {
385 let mut index = 0;
386
387 if !settings.mappings.is_empty() {
388 let mut mappings = Vec::new();
389 core::mem::swap(&mut settings.mappings, &mut mappings);
390
391 for mapping in mappings {
392 self.register_inplace_mapping(mapping);
393 }
394 }
395
396 for array in self.expansion.outputs.drain(..) {
397 match array {
398 OutputInfo::ArrayWrite {
399 item,
400 local,
401 position,
402 has_extended_meta,
403 } => {
404 let item_adapted = bool_item(item);
405
406 self.output_bindings.push(Binding {
407 item: item_adapted,
408 visibility: Visibility::ReadWrite,
409 location: Location::Storage,
410 has_extended_meta,
411 size: None,
412 });
413 self.expansion.scope.write_global(
414 Variable::new(VariableKind::LocalMut { id: local }, item),
415 Variable::new(VariableKind::GlobalOutputArray(index), item_adapted),
416 position,
417 );
418 index += 1;
419 }
420 OutputInfo::InputArrayWrite {
421 item,
422 input,
423 local,
424 position,
425 } => {
426 self.expansion.scope.write_global(
427 Variable::new(VariableKind::LocalMut { id: local }, item),
428 Variable::new(VariableKind::GlobalInputArray(input), bool_item(item)),
429 position,
430 );
431 }
432 OutputInfo::Array {
433 item,
434 has_extended_meta,
435 } => {
436 let elem_adapted = bool_item(item);
437
438 self.output_bindings.push(Binding {
439 item: elem_adapted,
440 visibility: Visibility::ReadWrite,
441 location: Location::Storage,
442 has_extended_meta,
443 size: None,
444 });
445
446 index += 1;
447 }
448 }
449 }
450 }
451
452 fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
453 let output = match self.expansion.outputs.get_mut(mapping.pos_output) {
454 Some(output) => output,
455 None => {
456 if let Some(binding) = self.input_bindings.get_mut(mapping.pos_input) {
457 binding.visibility = Visibility::ReadWrite;
459 }
460
461 return;
463 }
464 };
465
466 let (item, local, position) = match output {
467 OutputInfo::ArrayWrite { item, local, position, .. } => (item, local, position),
468 OutputInfo::InputArrayWrite {
469 item: _,
470 input,
471 local: _,
472 position: _,
473 } => {
474 assert_eq!(
475 *input, mapping.pos_input as Id,
476 "Can't use different inputs for the same output."
477 );
478 return;
479 }
480 OutputInfo::Array { .. } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."),
481 };
482
483 let item = match self.input_bindings.get_mut(mapping.pos_input) {
484 Some(binding) => {
485 binding.visibility = Visibility::ReadWrite;
487 self.expansion
489 .scope
490 .update_read(mapping.pos_input as Id, ReadingStrategy::Plain);
491
492 binding.item
496 }
497 None => *item,
498 };
499
500 *output = OutputInfo::InputArrayWrite {
502 item,
503 input: mapping.pos_input as Id,
504 local: *local,
505 position: *position,
506 };
507 }
508}
509
510fn bool_item(ty: Item) -> Item {
511 Item {
512 elem: bool_elem(ty.elem),
513 vectorization: ty.vectorization,
514 }
515}
516
517pub fn bool_elem(elem: Elem) -> Elem {
518 match elem {
519 Elem::Bool => Elem::UInt(UIntKind::U32),
521 _ => elem,
522 }
523}