cubecl_core/frontend/operation/fma.rs

1use crate::{prelude::*, unexpanded};
2
3/// Fused multiply-add `A*B+C`.
4#[allow(unused_variables)]
5pub fn fma<C: CubePrimitive>(a: C, b: C, c: C) -> C {
6    unexpanded!()
7}
8
9/// Expand method of [fma].
10pub mod fma {
11    use super::*;
12    use cubecl_ir::{Arithmetic, FmaOperator, Instruction, Scope};
13
14    pub fn expand<C: CubePrimitive>(
15        scope: &mut Scope,
16        a: ExpandElementTyped<C>,
17        b: ExpandElementTyped<C>,
18        c: ExpandElementTyped<C>,
19    ) -> ExpandElementTyped<C> {
20        let output = scope.create_local(a.expand.ty);
21        let out = *output;
22        let a = *a.expand;
23        let b = *b.expand;
24        let c = *c.expand;
25
26        scope.register(Instruction::new(
27            Arithmetic::Fma(FmaOperator { a, b, c }),
28            out,
29        ));
30
31        output.into()
32    }
33}