// tract_cuda/transform.rs
1use DatumType::{F16, F32};
2use tract_core::dyn_clone::clone_box;
3use tract_core::internal::*;
4use tract_core::model::translator::Translate;
5use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat};
6use tract_core::ops::binary::TypedBinOp;
7use tract_core::ops::cast::Cast;
8use tract_core::ops::cnn::conv::rewrite_kernel_conv_in_oihw;
9use tract_core::ops::cnn::{Conv, rewrite_conv_with_n_axis};
10use tract_core::ops::einsum::prefix_matmul::{PrefixMatMul, rewrite_einsum_to_prefix_matmul};
11use tract_core::ops::element_wise::ElementWiseOp;
12use tract_core::ops::konst::Const;
13use tract_core::ops::logic::Comp;
14use tract_core::ops::nn::{LeakyRelu, Reduce, Softmax};
15use tract_core::tract_data::itertools::Itertools;
16use tract_core::tract_linalg::block_quant::Q4_0;
17use tract_core::transform::ModelTransform;
18use tract_gpu::fact::{DeviceFact, DeviceTypedFactExt};
19use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs;
20use tract_gpu::rewrite_rules::rms_norm::remove_rms_norm_cast;
21use tract_gpu::sync::{DeviceSync, DeviceSyncKind};
22use tract_gpu::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
23use tract_gpu::utils::as_quant_fact;
24use tract_pulse_opl::ops::{Delay, PulsePad};
25use tract_transformers::ops::apply_rope::{ApplyRope, RotateHalf};
26use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache;
27use tract_transformers::ops::gelu_approximate::GeluApproximate;
28use tract_transformers::ops::rms_norm::RmsNorm;
29use tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax;
30use tract_transformers::ops::sdpa::Sdpa;
31use tract_transformers::ops::silu::Silu;
32
33use crate::context::cuda_context;
34use crate::ops::{CudaDelay, CudaPulsePad};
35use crate::ops::{CudaLeakyRelu, wire_cuda_conv};
36use crate::{kernels, ops, rewrite_rules};
37
/// Model transform that moves supported operators of a `TypedModel` onto CUDA,
/// inserting host/device synchronization nodes at the boundaries.
#[derive(Debug, Default)]
pub struct CudaTransform;
40
impl ModelTransform for CudaTransform {
    /// Identifier of this transform ("cuda-transform").
    fn name(&self) -> StaticName {
        "cuda-transform".into()
    }

    /// Runs the complete CUDA transformation pipeline (no early phase stop).
    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        self.transform_up_to_phase(model, usize::MAX)
    }
}
50
impl CudaTransform {
    /// Runs the CUDA transformation pipeline, returning early once the phase
    /// counted by `stop_at_phase` is reached. Phases, in order: einsum →
    /// prefix-matmul rewrite (0), pre-translation rewrite rules (1),
    /// translation of ops to CUDA (2), then post-translation fusions, sync
    /// rewiring and Q4_0 weight padding.
    pub fn transform_up_to_phase(
        &self,
        model: &mut TypedModel,
        stop_at_phase: usize,
    ) -> TractResult<()> {
        // Init CUDA Context if not done previously
        cuda_context();

        rewrite_einsum_to_prefix_matmul(model, false)?;
        if stop_at_phase == 0 {
            return Ok(());
        }

        // Pre-translation rewrites: matmul output/broadcast normalization and
        // conv kernel/axis canonicalization.
        Rewriter::default()
            .with_rule_for("untranspose_matmul_output", rewrite_rules::untranspose_matmul_output)
            .with_rule_for("add_broadcast_pre_matmul", rewrite_rules::add_broadcast_pre_matmul)
            .with_rule_for("rewrite_kernel_conv_in_oihw", rewrite_kernel_conv_in_oihw)
            .with_rule_for("rewrite_conv_with_n_axis", rewrite_conv_with_n_axis)
            .rewrite(&(), model)?;

        Rewriter::default()
            .with_rule_for("remove_rms_norm_cast", remove_rms_norm_cast)
            .rewrite(&(), model)?;

        if stop_at_phase == 1 {
            return Ok(());
        }

        // Core phase: node-by-node translation to CUDA ops (see the Translate
        // impl below).
        *model = self.translate_model(model)?;

        if stop_at_phase == 2 {
            return Ok(());
        }

        // Post-translation cleanups. fuse_move_axis runs to fixpoint before
        // fuse_axis_op (separate Rewriter passes, not one combined pass).
        Rewriter::default()
            .with_rule_for("fuse_move_axis", rewrite_rules::fuse_move_axis)
            .rewrite(&(), model)?;
        Rewriter::default()
            .with_rule_for("fuse_axis_op", rewrite_rules::fuse_axis_op)
            .rewrite(&(), model)?;

        rewire_syncs(model)?;

        Rewriter::default()
            .with_rule_for("pad_q40_weights", rewrite_rules::pad_q40_weights)
            .rewrite(&(), model)?;
        Ok(())
    }

    /// Returns the (possibly rewired) input outlets of `node` in `model`,
    /// inserting `DeviceSync` nodes where an input lives on the wrong side
    /// of the host/device boundary for `sync_kind`.
    ///
    /// Constants headed to the device are not synced at runtime: their fact is
    /// rewritten in place to an opaque device fact holding the uploaded tensor.
    fn sync_inputs_if_required(
        &self,
        model: &mut TypedModel,
        node: &TypedNode,
        mapping: &HashMap<OutletId, OutletId>,
        sync_kind: DeviceSyncKind,
    ) -> TractResult<TVec<OutletId>> {
        let mut mapped_inputs = tvec![];
        for (i_idx, i) in node.inputs.iter().enumerate() {
            let in_fact = model.outlet_fact_mut(mapping[i])?;
            match sync_kind {
                // Device-resident input feeding a host op: wire a ToHost sync.
                DeviceSyncKind::ToHost if in_fact.as_device_fact().is_some() => {
                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-cpu-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                // Host-resident input feeding a device op.
                DeviceSyncKind::ToDevice if in_fact.as_device_fact().is_none() => {
                    if let Some(ref konst) = in_fact.konst {
                        if konst.as_device_tensor().is_none() {
                            // Upload the constant once and mutate the outlet
                            // fact in place — no runtime sync node needed.
                            let device_konst =
                                konst.as_ref().clone().into_device()?.into_opaque_tensor();
                            let device_fact = DeviceFact::from_host(in_fact.clone())?;

                            *in_fact = TypedFact::dt_scalar(DatumType::Opaque)
                                .with_opaque_fact(device_fact);

                            in_fact.konst = Some(Arc::new(device_konst));
                            mapped_inputs.push(mapping[i]);
                            continue;
                        }
                    }
                    ensure!(
                        in_fact.datum_type.is_copy(),
                        "Only copy DatumType can be sync to Device: {:?}",
                        in_fact.datum_type
                    );

                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-device-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                // Already on the right side: keep the mapped outlet as-is.
                _ => mapped_inputs.push(mapping[i]),
            }
        }
        Ok(mapped_inputs)
    }

    /// Appends a ToHost `DeviceSync` after each of `node`'s outputs that is
    /// both device-resident and an output of the source model, so model
    /// outputs are always host tensors.
    fn sync_model_outputs_if_required(
        &self,
        src: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        target_node_outlet_ids: TVec<OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let mut outputs = tvec![];
        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
            // Add DeviceSync op for model output
            let is_src_output = src.outputs.contains(&OutletId::new(node.id, o_idx));
            if target.outlet_fact(o)?.as_device_fact().is_some() && is_src_output {
                let sync_output = target.wire_node(
                    format!("{}.to-host-{o_idx}-out", node.name),
                    DeviceSync::new(DeviceSyncKind::ToHost),
                    &[o],
                )?[0];
                outputs.push(sync_output);
            } else {
                outputs.push(o)
            }
        }
        Ok(outputs)
    }
}
181
182fn can_translate_to_cuda_op(source: &TypedModel, node: &TypedNode) -> TractResult<bool> {
183    let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec();
184    let input_dts = input_facts
185        .iter()
186        .map(|f| f.as_device_fact().map(|f| f.datum_type).unwrap_or(f.datum_type))
187        .collect_vec();
188
189    let in_dts_compatible =
190        input_facts.iter().all(|fact| DeviceTensor::is_supported_dt(fact.datum_type));
191
192    Ok(in_dts_compatible
193        && (node
194            .op_as::<Const>()
195            .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
196            || node
197                .op_as::<Silu>()
198                .is_some_and(|_| kernels::UnaryOps::is_supported_dt(input_dts[0]))
199            || node.op_as::<ElementWiseOp>().is_some_and(|op| op.0.is::<LeakyRelu>())
200            || node.op_as::<ElementWiseOp>().is_some_and(|op| {
201                kernels::UnaryOps::is_supported_dt(input_dts[0])
202                    && map_element_wise_ops_to_cuda(op).is_some()
203            })
204            || node.op_as::<TypedBinOp>().is_some_and(|op| {
205                map_binary_op_to_cuda(op).is_some_and(|op| op.0.is_supported_dt(input_dts[0]))
206            })
207            || node
208                .op_as::<Comp>()
209                .is_some_and(|op| convert_logic_op_to_cuda(op).0.is_supported_dt(input_dts[0]))
210            || node
211                .op_as::<Const>()
212                .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
213            || node.op_as::<Cast>().is_some_and(|op| {
214                ops::CudaCast::is_supported_dt(input_dts[0]) && ops::CudaCast::new(op.to).is_some()
215            })
216            || node.op_is::<MultiBroadcastTo>()
217            || node.op_is::<AxisOp>()
218            || node.op_is::<Slice>()
219            || node.op_is::<Delay>()
220            || node.op_is::<PulsePad>()
221            || node.op_is::<TypedConcat>()
222            || node.op_is::<DynKeyValueCache>()
223            || node.op_as::<Reduce>().is_some_and(|op| {
224                ops::CudaReduce::from_tract_core(op)
225                    .is_ok_and(|op| op.reducer.is_supported_dt(input_dts[0]))
226            })
227            || node.op_as::<Softmax>().is_some_and(|op| {
228                kernels::nn::Softmax::is_supported_dt(input_dts[0])
229                    && ops::CudaSoftmax::from_tract_core(op).is_ok()
230            })
231            || node
232                .op_as::<ScaledMaskedSoftmax>()
233                .is_some_and(|_| kernels::nn::ScaledMaskedSoftmax::is_supported_dt(input_dts[0]))
234            || node
235                .op_as::<RmsNorm>()
236                .is_some_and(|_| kernels::nn::RmsNorm::is_supported_dt(input_dts[0]))
237            || node
238                .op_as::<RotateHalf>()
239                .is_some_and(|_| kernels::array::RotateHalf::is_supported_dt(input_dts[0]))
240            || node
241                .op_as::<ApplyRope>()
242                .is_some_and(|_| kernels::nn::ApplyRope::is_supported_dt(input_dts[0]))
243            || node
244                .op_as::<GeluApproximate>()
245                .is_some_and(|_| kernels::nn::GeluApproximate::is_supported_dt(input_dts[0]))
246            || node.op_as::<Sdpa>().is_some()
247            || node.op_as::<PrefixMatMul>().is_some_and(|op| {
248                !op.transpose_c
249                    && op.quantize_output.is_none()
250                    && (can_convert_to_cuda_gemm(&input_facts)
251                        || can_convert_to_cuda_gemm(&[
252                            input_facts[1].clone(),
253                            input_facts[0].clone(),
254                        ]))
255            })
256            || (node.op_is::<Conv>() && input_facts[0].datum_type.is::<f32>())))
257}
258
259fn convert_const(op: &Const) -> TractResult<Const> {
260    let typed_fact: TypedFact = Arc::clone(op.val()).into();
261    let cuda_fact = if let Some(of) = op.opaque_fact() {
262        DeviceFact::from_host(typed_fact.with_opaque_fact(clone_box(of)))?
263    } else {
264        DeviceFact::from_host(typed_fact)?
265    };
266
267    let cuda_const = op.val().clone().into_device()?.into_opaque_tensor().into_arc_tensor();
268    Const::new_with_opaque_fact(cuda_const, Box::new(cuda_fact))
269}
270
/// Expands to a closure mapping a tract `ElementWiseOp` to the matching CUDA
/// unary op: each `(tract_op, cuda_kernel)` pair is tried in turn via
/// downcast, returning `None` when nothing matches.
macro_rules! map_unary_ops {
    ([$(($tract_unary_op:path, $cuda_unary_op:ident)),* $(,)?]) => {
        |op: &tract_core::ops::element_wise::ElementWiseOp| {
            // First successful downcast wins.
            $(if let Some(_op) = op.0.downcast_ref::<$tract_unary_op>() {
                return Some($crate::ops::CudaUnaryOp(kernels::UnaryOps::$cuda_unary_op));
            })*
            return None;
        }
    };
}
281
/// Maps a tract element-wise op to its CUDA unary kernel, or `None` when the
/// op has no CUDA counterpart (this table is the single source of truth used
/// both for translatability checks and for the actual translation).
fn map_element_wise_ops_to_cuda(op: &ElementWiseOp) -> Option<ops::CudaUnaryOp> {
    map_unary_ops!([
        (tract_core::ops::math::Abs, Abs),
        (tract_core::ops::math::Exp, Exp),
        (tract_core::ops::math::Ln, Ln),
        (tract_core::ops::nn::Sigmoid, Sigmoid),
        (tract_core::ops::math::Square, Sqr),
        (tract_core::ops::math::Sqrt, Sqrt),
        (tract_core::ops::math::Rsqrt, Rsqrt),
        (tract_core::ops::math::Recip, Recip),
        (tract_core::ops::math::Ceil, Ceil),
        (tract_core::ops::math::Floor, Floor),
        (tract_core::ops::math::Round, Round),
        (tract_core::ops::math::RoundHalfToEven, RoundHalfToEven),
        (tract_core::ops::math::Cos, Cos),
        (tract_core::ops::math::Acos, Acos),
        (tract_core::ops::math::Acosh, Acosh),
        (tract_core::ops::math::Cosh, Cosh),
        (tract_core::ops::math::Sin, Sin),
        (tract_core::ops::math::Asin, Asin),
        (tract_core::ops::math::Asinh, Asinh),
        (tract_core::ops::math::Sinh, Sinh),
        (tract_core::ops::math::Tan, Tan),
        (tract_core::ops::math::Atan, Atan),
        (tract_core::ops::math::Atanh, Atanh),
        (tract_core::ops::math::Tanh, Tanh),
        (tract_core::ops::math::Erf, Erf),
        (tract_core::ops::math::Neg, Neg),
    ])(op)
}
312
/// Expands to a closure mapping a tract `TypedBinOp` to the matching CUDA
/// binary op: each `(tract_op, cuda_kernel)` pair is tried in turn via
/// downcast, returning `None` when nothing matches.
macro_rules! map_bin_ops {
    ([$(($tract_bin_op:path, $cuda_bin_op:ident)),* $(,)?]) => {
        |op: &TypedBinOp | {
            // First successful downcast wins.
            $(if let Some(_op) = op.0.downcast_ref::<$tract_bin_op>() {
                return Some($crate::ops::CudaBinOp(kernels::BinOps::$cuda_bin_op));
            })*
            return None;
        }
    };
}
323
/// Maps a tract binary op to its CUDA kernel, or `None` when the op has no
/// CUDA counterpart.
#[allow(clippy::borrowed_box)]
fn map_binary_op_to_cuda(op: &TypedBinOp) -> Option<ops::CudaBinOp> {
    map_bin_ops!([
        (tract_core::ops::math::Mul, Mul),
        (tract_core::ops::math::Add, Add),
        (tract_core::ops::math::Div, Div),
        (tract_core::ops::math::Sub, Sub),
        (tract_core::ops::math::Min, Min),
        (tract_core::ops::math::Max, Max),
        (tract_core::ops::math::Pow, Pow),
        (tract_core::ops::logic::And, And),
        (tract_core::ops::logic::Or, Or),
    ])(op)
}
338
339fn convert_logic_op_to_cuda(op: &Comp) -> ops::CudaBinOp {
340    match op {
341        Comp::Eq => ops::CudaBinOp(kernels::BinOps::Equals),
342        Comp::NE => ops::CudaBinOp(kernels::BinOps::NotEquals),
343        Comp::LT => ops::CudaBinOp(kernels::BinOps::Less),
344        Comp::LTE => ops::CudaBinOp(kernels::BinOps::LessEqual),
345        Comp::GT => ops::CudaBinOp(kernels::BinOps::Greater),
346        Comp::GTE => ops::CudaBinOp(kernels::BinOps::GreaterEqual),
347    }
348}
349
350fn can_convert_to_cuda_gemm(facts: &[TypedFact]) -> bool {
351    assert!(facts.len() == 2, "Ggml: Expected 2 inputs for Matmul");
352
353    let regular_types_support =
354        matches!((facts[0].datum_type, facts[1].datum_type), (F32, F32) | (F16, F16) | (F16, F32));
355
356    regular_types_support
357        || (as_quant_fact(&facts[1], &Q4_0).is_some() && matches!(facts[0].datum_type, F16 | F32))
358}
359
360fn convert_matmul_to_cuda(
361    model: &TypedModel,
362    node: &TypedNode,
363    target: &mut TypedModel,
364    inputs: &mut [OutletId],
365    op: &PrefixMatMul,
366) -> TractResult<TVec<OutletId>> {
367    let mut input_facts = model.node_input_facts(node.id)?;
368    // GGML kernel expects weights in second position and activations in first position
369    // This avoid output transposition due to GGML column-major data expectations
370
371    let mut swap_inputs = false;
372    if !can_convert_to_cuda_gemm(&[input_facts[0].clone(), input_facts[1].clone()])
373        && can_convert_to_cuda_gemm(&[input_facts[1].clone(), input_facts[0].clone()])
374    {
375        input_facts.swap(0, 1);
376        inputs.swap(0, 1);
377        swap_inputs = true;
378    }
379
380    let act_fact = input_facts[0];
381    let weight_fact = input_facts[1];
382    let outlets = inputs.split_at_mut(1);
383    let act_outlet = &mut outlets.0[0];
384    let weights_outlet = &mut outlets.1[0];
385
386    let transpose_act = if swap_inputs { !op.transpose_b } else { op.transpose_a };
387    let transpose_weight = if swap_inputs { !op.transpose_a } else { op.transpose_b };
388
389    if transpose_act {
390        let rank = act_fact.rank();
391        let perm_act_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
392        let perm_act_name = node.name.clone() + ".perm_activs";
393        *act_outlet = target.wire_node(perm_act_name, perm_act_op, &[*act_outlet])?[0];
394    }
395
396    if act_fact.datum_type == DatumType::F16 && as_quant_fact(weight_fact, &Q4_0).is_some() {
397        let in_cast_op = ops::CudaCast::new(DatumType::F32).unwrap();
398        *act_outlet =
399            target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[*act_outlet])?[0];
400    } else if act_fact.datum_type == DatumType::F16 && weight_fact.datum_type == DatumType::F32 {
401        let in_cast_op = ops::CudaCast::new(DatumType::F16).unwrap();
402        *weights_outlet =
403            target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[*weights_outlet])?[0];
404    }
405
406    if !transpose_weight {
407        ensure!(as_quant_fact(weight_fact, &Q4_0).is_none(), "Cannot transpose Q40 tensor");
408
409        let rank = weight_fact.rank();
410        let perm_weights_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
411        let perm_weights_name = node.name.clone() + ".perm_weights";
412        *weights_outlet =
413            target.wire_node(perm_weights_name, perm_weights_op, &[*weights_outlet])?[0];
414    }
415
416    if as_quant_fact(weight_fact, &Q4_0).is_some() {
417        let device_fact = target.outlet_fact(*act_outlet)?.to_device_fact()?;
418        let quant_op = ops::CudaGgmlQuantQ81::new(device_fact.shape.clone())?;
419        *act_outlet =
420            target.wire_node(node.name.clone() + ".quant_activs", quant_op, &[*act_outlet])?[0];
421    }
422    let mut matmul_output =
423        target.wire_node(node.name.clone(), *Box::new(ops::CudaGgmlGemm), inputs)?;
424
425    if swap_inputs {
426        let out_fact = target.outlet_fact(matmul_output[0])?;
427        let rank = &out_fact
428            .opaque_fact
429            .clone()
430            .map(|fact| fact.clarify_dt_shape().unwrap().1.len())
431            .unwrap();
432
433        let perm_out_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
434        matmul_output =
435            target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?;
436    }
437
438    let out_fact = target.outlet_fact(matmul_output[0])?;
439    let out_dt = out_fact.as_device_fact().map(|f| f.datum_type).unwrap_or(out_fact.datum_type);
440
441    let expected_dt = model.node_output_facts(node.id)?[0].datum_type;
442    if out_dt != expected_dt {
443        ensure!(
444            ops::CudaCast::is_supported_dt(out_dt),
445            "Matmul output type cannot be casted to expected type"
446        );
447        let cast_op = ops::CudaCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap();
448        matmul_output =
449            target.wire_node(node.name.clone() + ".out_cast", cast_op, &matmul_output)?
450    }
451    Ok(matmul_output)
452}
453
/// Wires an `Sdpa` node as a CUDA flash-attention op: Q/K/V (and the optional
/// mask) are cast to F16, rank-3 inputs get a head axis inserted, and the
/// output is reshaped/cast back to match the source node's contract.
fn convert_sdpa_to_cuda_flash_attn(
    model: &TypedModel,
    node: &TypedNode,
    target: &mut TypedModel,
    inputs: &mut [OutletId],
    op: &Sdpa,
) -> TractResult<TVec<OutletId>> {
    let facts = model.node_input_facts(node.id)?;

    let [qf, kf, vf] = [facts[0], facts[1], facts[2]];
    ensure!(kf.datum_type() == vf.datum_type(), "K/V dtypes must match");

    // Optional attention mask is the fourth input when present.
    let mask_fact = if facts.len() == 4 { Some(facts[3]) } else { None };

    let (q, k, v, m_opt) = match &mut inputs[..] {
        [q, k, v, m, ..] => (q, k, v, Some(m)),
        [q, k, v] => (q, k, v, None),
        _ => bail!("unexpected number of inputs"),
    };

    /// Builds a wire name by suffixing the node name.
    fn name(base: &str, suffix: &str) -> String {
        format!("{base}{suffix}")
    }

    /// Rewires `dst` through a CudaCast to `want` when `have` differs.
    fn mut_cast(
        target: &mut TypedModel,
        node_name: &str,
        dst: &mut OutletId,
        have: DatumType,
        want: DatumType,
        suffix: &str,
    ) -> TractResult<()> {
        if have != want {
            *dst = target.wire_node(
                name(node_name, suffix),
                ops::CudaCast::new(want).unwrap(),
                &[*dst],
            )?[0];
        }
        Ok(())
    }

    /// Inserts a head axis (axis 1) when `fact` is rank 3; returns whether an
    /// axis was added. Rejects ranks other than 3 or 4.
    fn add_head_axis_if_rank3(
        target: &mut TypedModel,
        node_name: &str,
        dst: &mut OutletId,
        fact: &TypedFact,
        suffix: &str,
    ) -> TractResult<bool> {
        if fact.rank() == 3 {
            let ax = ops::CudaAxisOp::from_tract_core(AxisOp::Add(1));
            *dst = target.wire_node(name(node_name, suffix), ax, &[*dst])?[0];
            Ok(true)
        } else {
            ensure!(fact.rank() == 4, "Q/K/V must be rank 3 or 4");
            Ok(false)
        }
    }

    // ----- casts
    // NOTE(review): datum_type() here returns an Option (presumably from a
    // device-fact extension trait) — the unwrap assumes it is resolvable.
    let q_dt = qf.datum_type().unwrap();
    let kv_dt = kf.datum_type().unwrap();
    mut_cast(target, &node.name, k, kv_dt, DatumType::F16, ".cast_k")?;
    mut_cast(target, &node.name, v, kv_dt, DatumType::F16, ".cast_v")?;
    mut_cast(target, &node.name, q, q_dt, DatumType::F16, ".cast_q")?;

    // ----- rank normalize
    let mut added_head_axis = false;
    added_head_axis |= add_head_axis_if_rank3(target, &node.name, q, qf, ".reshape_q")?;
    added_head_axis |= add_head_axis_if_rank3(target, &node.name, k, kf, ".reshape_k")?;
    added_head_axis |= add_head_axis_if_rank3(target, &node.name, v, vf, ".reshape_v")?;

    // Kernel only supports head dims 64 and 128.
    let out_dim = kf.shape[kf.rank() - 1].to_i64()?;
    ensure!(matches!(out_dim, 64 | 128), "Unsupported head dim (D): {out_dim}");
    ensure!(kf.shape == vf.shape, "K and V shapes must be identical");

    // ----- mask: cast & reshape
    if let Some(mf) = mask_fact {
        let m = m_opt.unwrap();
        mut_cast(target, &node.name, m, mf.datum_type().unwrap(), DatumType::F16, ".cast_m")?;
        if mf.rank() != 4 {
            // NOTE(review): a single Add(1) only brings rank 3 up to 4; a
            // rank-2 mask would stay rank 3 here — confirm masks are rank 3/4.
            let ax = ops::CudaAxisOp::from_tract_core(AxisOp::Add(1));
            *m = target.wire_node(name(&node.name, ".reshape_m"), ax, &[*m])?[0];
        }
    }

    // ----- scale & op
    // Default scale is 1/sqrt(D) when the op carries no explicit scale.
    let scale = op
        .scale
        .as_ref()
        .map(|s| *s.to_scalar::<f32>().unwrap())
        .unwrap_or(1.0 / (out_dim as f32).sqrt());
    let sdpa = ops::CudaFlashAttention::new(scale, op.is_causal);

    let mut out = target.wire_node(node.name.clone(), sdpa, inputs)?;

    // Drop the head axis we inserted so the output rank matches the source.
    if added_head_axis {
        out = target.wire_node(
            name(&node.name, ".reshape_out"),
            ops::CudaAxisOp::from_tract_core(AxisOp::Rm(1)),
            &out,
        )?;
    }

    // Cast back to the original Q dtype when it was not F16.
    if q_dt != DatumType::F16 {
        out = target.wire_node(
            name(&node.name, ".cast_out"),
            ops::CudaCast::new(q_dt).unwrap(),
            &out,
        )?;
    }

    Ok(out)
}
568
impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for CudaTransform {
    /// Translates one source node into the target model: supported ops are
    /// replaced by their CUDA counterparts (with ToDevice syncs on inputs and
    /// ToHost syncs on model outputs); unsupported ops are kept as-is on the
    /// host, with ToHost syncs pulling any device inputs back first.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let translatable = can_translate_to_cuda_op(source, node)?;

        if translatable {
            let mut device_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?;

            // Matmul, SDPA and Conv need custom wiring (extra nodes around the
            // main op); everything else is a one-for-one op substitution.
            let outlet_ids: TVec<OutletId> = if let Some(op) = node.op_as::<PrefixMatMul>() {
                convert_matmul_to_cuda(source, node, target, &mut device_inputs, op)?
            } else if let Some(op) = node.op_as::<Sdpa>() {
                convert_sdpa_to_cuda_flash_attn(source, node, target, &mut device_inputs, op)?
            } else if let Some(conv) = node.op_as::<Conv>() {
                wire_cuda_conv(source, node, target, &device_inputs, conv)?
            } else {
                let op: Box<dyn TypedOp> = if let Some(op) = node.op_as::<Const>() {
                    Box::new(convert_const(op)?)
                } else if let Some(op) = node.op_as::<ElementWiseOp>() {
                    // LeakyRelu has a dedicated op; other element-wise ops go
                    // through the unary mapping table.
                    if let Some(leaky) = op.0.downcast_ref::<LeakyRelu>() {
                        Box::new(CudaLeakyRelu { alpha: leaky.alpha })
                    } else {
                        Box::new(map_element_wise_ops_to_cuda(op).unwrap())
                    }
                } else if let Some(op) = node.op_as::<TypedBinOp>() {
                    Box::new(map_binary_op_to_cuda(op).unwrap())
                } else if let Some(op) = node.op_as::<Comp>() {
                    Box::new(convert_logic_op_to_cuda(op))
                } else if let Some(_op) = node.op_as::<Silu>() {
                    Box::new(ops::CudaUnaryOp(kernels::UnaryOps::Silu))
                } else if let Some(op) = node.op_as::<MultiBroadcastTo>() {
                    Box::new(ops::CudaMultiBroadcastTo::new(op.shape.clone()))
                } else if let Some(op) = node.op_as::<Cast>() {
                    Box::new(ops::CudaCast::new(op.to).unwrap())
                } else if let Some(op) = node.op_as::<AxisOp>() {
                    let in_fact = source.node_input_facts(node.id)?[0];
                    Box::new(ops::CudaAxisOp::from_tract_core_with_fact(op.clone(), in_fact))
                } else if let Some(op) = node.op_as::<Slice>() {
                    Box::new(ops::CudaSlice::from_tract_core(op.clone()))
                } else if let Some(op) = node.op_as::<TypedConcat>() {
                    Box::new(ops::CudaConcat::from_tract_core(op))
                } else if let Some(op) = node.op_as::<DynKeyValueCache>() {
                    Box::new(ops::CudaDynKVCache::from_tract_transformers(op))
                } else if let Some(op) = node.op_as::<Reduce>() {
                    Box::new(ops::CudaReduce::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<Softmax>() {
                    Box::new(ops::CudaSoftmax::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<ScaledMaskedSoftmax>() {
                    Box::new(ops::CudaScaledMaskedSoftmax { scale: op.scale.clone() })
                } else if let Some(_op) = node.op_as::<RotateHalf>() {
                    Box::new(ops::CudaRotateHalf)
                } else if let Some(_op) = node.op_as::<ApplyRope>() {
                    Box::new(ops::CudaApplyRope)
                } else if let Some(op) = node.op_as::<RmsNorm>() {
                    Box::new(ops::CudaRmsNorm::new(op.axis, op.eps.clone()))
                } else if let Some(op) = node.op_as::<GeluApproximate>() {
                    Box::new(ops::CudaGeluApproximate { fast_impl: op.fast_impl })
                } else if let Some(op) = node.op_as::<Delay>() {
                    Box::new(CudaDelay::new(op.clone()))
                } else if let Some(op) = node.op_as::<PulsePad>() {
                    Box::new(CudaPulsePad::new(op)?)
                } else {
                    // can_translate_to_cuda_op said yes but no branch matched:
                    // the two lists are out of sync.
                    bail!("Failed to translate a supported CUDA Op")
                };
                target.wire_node(node.name.clone(), op, &device_inputs)?
            };
            self.sync_model_outputs_if_required(source, node, target, outlet_ids)
        } else {
            // Keep the op on CPU, pulling any device-resident inputs back.
            let cpu_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToHost)?;
            target.wire_node(&node.name, node.op.clone(), &cpu_inputs)
        }
    }
}
648
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the full CUDA transform over a mixed-precision (F32 x F16)
    /// prefix-matmul and checks the transformed model still executes.
    #[test]
    fn test_prefix_matmul_transform_f32_f16() -> TractResult<()> {
        let (batch, m, k, n) = (1, 16, 128, 32);

        // Tiny model: a(F32) [b,m,k] x b(F16) [b,k,n] -> matmul.
        let mut model = TypedModel::default();
        let a = model.add_source("a", TypedFact::dt_shape(DatumType::F32, &[batch, m, k]))?;
        let b = model.add_source("b", TypedFact::dt_shape(DatumType::F16, &[batch, k, n]))?;

        let matmul = PrefixMatMul {
            transpose_a: false,
            transpose_b: false,
            transpose_c: false,
            quantize_output: None,
            operating_dt: Some(DatumType::F32),
        };
        let outputs = model.wire_node("matmul", matmul, &[a, b])?;
        model.set_output_outlets(&outputs)?;

        // Translate to CUDA, then check the result is runnable end to end.
        CudaTransform::default().transform(&mut model)?;
        let runnable = model.into_runnable()?;

        let run_inputs = tvec!(
            Tensor::zero::<f32>(&[batch, m, k])?.into(),
            Tensor::zero::<f16>(&[batch, k, n])?.into()
        );
        let _ = runnable.run(run_inputs)?;
        Ok(())
    }
}