1use tract_core::internal::tract_smallvec::ToSmallVec;
2use tract_core::internal::*;
3use tract_core::model::translator::Translate;
4use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat};
5use tract_core::ops::binary::TypedBinOp;
6use tract_core::ops::cast::Cast;
7use tract_core::ops::einsum::prefix_matmul::{PrefixMatMul, rewrite_einsum_to_prefix_matmul};
8use tract_core::ops::element_wise::ElementWiseOp;
9use tract_core::ops::konst::Const;
10use tract_core::ops::logic::Comp;
11use tract_core::ops::nn::{Reduce, Softmax};
12use tract_core::tract_data::itertools::Itertools;
13use tract_core::tract_linalg::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0};
14use tract_core::transform::ModelTransform;
15use tract_gpu::fact::{DeviceFact, DeviceTypedFactExt};
16use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs;
17use tract_gpu::sync::{DeviceSync, DeviceSyncKind};
18use tract_gpu::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
19use tract_gpu::utils::{as_q40_fact, as_q40_tensor};
20use tract_transformers::ops::apply_rope::{ApplyRope, RotateHalf};
21use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache;
22use tract_transformers::ops::gelu_approximate::GeluApproximate;
23use tract_transformers::ops::rms_norm::RmsNorm;
24use tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax;
25use tract_transformers::ops::silu::Silu;
26
27use crate::context::cuda_context;
28use crate::kernels::matmul::{GemmKernel, GgmlGemm};
29use crate::{Q40_ROW_PADDING, kernels, ops, rewrite_rules};
30
/// Model transform that lowers a `TypedModel` to CUDA device operators.
///
/// Run the full pipeline via `ModelTransform::transform`, or stop after a
/// given phase with [`CudaTransform::transform_up_to_phase`] (useful for
/// debugging the lowering).
#[derive(Debug, Default)]
pub struct CudaTransform;
33
impl ModelTransform for CudaTransform {
    /// Identifier used to select this transform by name.
    fn name(&self) -> StaticName {
        "cuda-transform".into()
    }

    /// Runs every phase of the CUDA lowering pipeline.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        self.transform_up_to_phase(model, usize::MAX)
    }
}
43
impl CudaTransform {
    /// Runs the CUDA lowering pipeline, stopping after `stop_at_phase`.
    ///
    /// Phases:
    /// * 0 — einsum → `PrefixMatMul` rewrite
    /// * 1 — matmul-oriented rewrites (untranspose output, pre-matmul broadcast)
    /// * 2 — node-by-node translation to CUDA ops (see the `Translate` impl)
    /// * 3+ — post-translation axis-op fusions and sync rewiring
    pub fn transform_up_to_phase(
        &self,
        model: &mut TypedModel,
        stop_at_phase: usize,
    ) -> TractResult<()> {
        // Touch the CUDA context up-front so it is initialized before any device work.
        cuda_context();

        rewrite_einsum_to_prefix_matmul(model)?;
        if stop_at_phase == 0 {
            return Ok(());
        }

        Rewriter::default()
            .with_rule_for("untranspose_matmul_output", rewrite_rules::untranspose_matmul_output)
            .with_rule_for("add_broadcast_pre_matmul", rewrite_rules::add_broadcast_pre_matmul)
            .rewrite(&(), model)?;

        if stop_at_phase == 1 {
            return Ok(());
        }

        *model = self.translate_model(model)?;

        if stop_at_phase == 2 {
            return Ok(());
        }

        // Two separate rewriters: run move-axis fusion to a fixed point before
        // the generic axis-op fusion pass.
        Rewriter::default()
            .with_rule_for("fuse_move_axis", rewrite_rules::fuse_move_axis)
            .rewrite(&(), model)?;
        Rewriter::default()
            .with_rule_for("fuse_axis_op", rewrite_rules::fuse_axis_op)
            .rewrite(&(), model)?;

        rewire_syncs(model)?;
        Ok(())
    }

    /// Wires host/device sync nodes in front of `node`'s inputs when needed.
    ///
    /// * `ToHost`: inserts a sync for every input currently carrying a device fact.
    /// * `ToDevice`: inserts a sync for host inputs; constant inputs are instead
    ///   converted in place (the outlet fact and its `konst` are rewritten as
    ///   device opaques) so no runtime copy node is wired for them.
    ///
    /// Returns the (possibly rewired) input outlets, in input order.
    fn sync_inputs_if_required(
        &self,
        model: &mut TypedModel,
        node: &TypedNode,
        mapping: &HashMap<OutletId, OutletId>,
        sync_kind: DeviceSyncKind,
    ) -> TractResult<TVec<OutletId>> {
        let mut mapped_inputs = tvec![];
        for (i_idx, i) in node.inputs.iter().enumerate() {
            let in_fact = model.outlet_fact_mut(mapping[i])?;
            match sync_kind {
                DeviceSyncKind::ToHost if in_fact.as_device_fact().is_some() => {
                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-cpu-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                DeviceSyncKind::ToDevice if in_fact.as_device_fact().is_none() => {
                    if let Some(ref konst) = in_fact.konst {
                        if konst.as_device_tensor().is_none() {
                            // Upload the constant once and mutate the outlet fact so
                            // every downstream consumer sees an opaque device tensor.
                            let device_konst =
                                konst.as_ref().clone().into_device()?.into_opaque_tensor();
                            let device_fact = DeviceFact::from_host(in_fact.clone())?;

                            *in_fact = TypedFact::dt_scalar(DatumType::Opaque)
                                .with_opaque_fact(device_fact);

                            in_fact.konst = Some(Arc::new(device_konst));
                            mapped_inputs.push(mapping[i]);
                            continue;
                        }
                    }
                    ensure!(
                        in_fact.datum_type.is_copy(),
                        "Only copy DatumType can be sync to Device: {:?}",
                        in_fact.datum_type
                    );

                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-device-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                // Input already lives on the requested side: keep it as-is.
                _ => mapped_inputs.push(mapping[i]),
            }
        }
        Ok(mapped_inputs)
    }

    /// Appends a device→host sync behind every output of `node` that is both a
    /// model output in `src` and device-resident in `target`, so the translated
    /// model always exposes host tensors at its outputs.
    fn sync_model_outputs_if_required(
        &self,
        src: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        target_node_outlet_ids: TVec<OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let mut outputs = tvec![];
        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
            let is_src_output = src.outputs.contains(&OutletId::new(node.id, o_idx));
            if target.outlet_fact(o)?.as_device_fact().is_some() && is_src_output {
                let sync_output = target.wire_node(
                    format!("{}.to-host-{o_idx}-out", node.name),
                    DeviceSync::new(DeviceSyncKind::ToHost),
                    &[o],
                )?[0];
                outputs.push(sync_output);
            } else {
                outputs.push(o)
            }
        }
        Ok(outputs)
    }
}
164
165fn can_translate_to_cuda_op(source: &TypedModel, node: &TypedNode) -> TractResult<bool> {
166 let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec();
167 let input_dts = input_facts
168 .iter()
169 .map(|f| f.as_device_fact().map(|f| f.datum_type).unwrap_or(f.datum_type))
170 .collect_vec();
171
172 let in_dts_compatible =
173 input_facts.iter().all(|fact| DeviceTensor::is_supported_dt(fact.datum_type));
174
175 Ok(in_dts_compatible
176 && (node
177 .op_as::<Const>()
178 .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
179 || node
180 .op_as::<Silu>()
181 .is_some_and(|_| kernels::UnaryOps::is_supported_dt(input_dts[0]))
182 || node.op_as::<ElementWiseOp>().is_some_and(|op| {
183 kernels::UnaryOps::is_supported_dt(input_dts[0])
184 && map_element_wise_ops_to_cuda(op).is_some()
185 })
186 || node.op_as::<TypedBinOp>().is_some_and(|op| {
187 map_binary_op_to_cuda(op).is_some_and(|op| op.0.is_supported_dt(input_dts[0]))
188 })
189 || node
190 .op_as::<Comp>()
191 .is_some_and(|op| convert_logic_op_to_cuda(op).0.is_supported_dt(input_dts[0]))
192 || node
193 .op_as::<Const>()
194 .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
195 || node.op_as::<Cast>().is_some_and(|op| {
196 ops::CudaCast::is_supported_dt(input_dts[0]) && ops::CudaCast::new(op.to).is_some()
197 })
198 || node.op_is::<MultiBroadcastTo>()
199 || node.op_is::<AxisOp>()
200 || node.op_is::<Slice>()
201 || node.op_is::<TypedConcat>()
202 || node.op_is::<DynKeyValueCache>()
203 || node.op_as::<Reduce>().is_some_and(|op| {
204 kernels::nn::Reducer::is_supported_dt(input_dts[0])
205 && ops::CudaReduce::from_tract_core(op).is_ok()
206 })
207 || node.op_as::<Softmax>().is_some_and(|op| {
208 kernels::nn::Softmax::is_supported_dt(input_dts[0])
209 && ops::CudaSoftmax::from_tract_core(op).is_ok()
210 })
211 || node
212 .op_as::<ScaledMaskedSoftmax>()
213 .is_some_and(|_| kernels::nn::ScaledMaskedSoftmax::is_supported_dt(input_dts[0]))
214 || node
215 .op_as::<RmsNorm>()
216 .is_some_and(|_| kernels::nn::RmsNorm::is_supported_dt(input_dts[0]))
217 || node
218 .op_as::<RotateHalf>()
219 .is_some_and(|_| kernels::array::RotateHalf::is_supported_dt(input_dts[0]))
220 || node
221 .op_as::<ApplyRope>()
222 .is_some_and(|_| kernels::nn::ApplyRope::is_supported_dt(input_dts[0]))
223 || node
224 .op_as::<GeluApproximate>()
225 .is_some_and(|_| kernels::nn::GeluApproximate::is_supported_dt(input_dts[0])))
226 || node.op_as::<PrefixMatMul>().is_some_and(|op| {
227 !op.transpose_c
228 && op.quantize_output.is_none()
229 && (GgmlGemm.is_supported_dts(&input_facts)
230 || GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()]))
231 }))
232}
233
234pub fn pad_q40(q40_bqv: &BlockQuantValue) -> TractResult<BlockQuantValue> {
235 let shape = q40_bqv.fact.shape();
236 ensure!(shape.len() >= 2);
237
238 let k = *shape.last().unwrap();
239 ensure!(k % 32 == 0);
240
241 let to_pad = k.next_multiple_of(Q40_ROW_PADDING) - k;
242 if to_pad == 0 {
243 return Ok(q40_bqv.clone()); }
245
246 let outer_rows: usize = shape[..shape.len() - 1].iter().product();
247 let row_bytes = k * Q4_0.block_bytes() / Q4_0.block_len();
248
249 let pad_quant = Q4_0.quant_f32(&vec![0f32; to_pad])?;
250 let pad_bytes = pad_quant.len();
251
252 let mut new_data = Vec::with_capacity(outer_rows * (row_bytes + pad_bytes));
253 let old_bytes = q40_bqv.value.as_bytes();
254
255 for row in 0..outer_rows {
256 let start = row * row_bytes;
257 new_data.extend_from_slice(&old_bytes[start..start + row_bytes]);
258 new_data.extend_from_slice(&pad_quant);
259 }
260
261 let mut new_shape = shape.to_smallvec();
262 *new_shape.last_mut().unwrap() += to_pad;
263
264 Ok(BlockQuantValue {
265 fact: BlockQuantFact::new(q40_bqv.fact.format.clone(), new_shape),
266 value: Arc::new(Blob::from_bytes(&new_data)?),
267 })
268}
269
/// Converts a host `Const` op into a device-resident `Const`.
///
/// Plain tensors are uploaded as-is. Opaque constants must be Q4_0
/// block-quantized; their rows are padded to `Q40_ROW_PADDING` via
/// [`pad_q40`] before upload.
fn convert_const(op: &Const) -> TractResult<Const> {
    let typed_fact: TypedFact = Arc::clone(op.val()).into();
    let cuda_const = op.val().clone();

    // Uploads `tensor` to the device and wraps value + fact as device opaques.
    let to_device_opaque = |fact: TypedFact, tensor: Arc<Tensor>| -> TractResult<_> {
        Ok((
            DeviceFact::from_host(fact)?,
            tensor.into_device()?.into_opaque_tensor().into_arc_tensor(),
        ))
    };

    let (cuda_fact, cuda_tensor) = match op.opaque_fact() {
        Some(_) => {
            ensure!(as_q40_fact(&typed_fact).is_some(), "Only support Q40 block quantization");

            let tensor = cuda_const.into_tensor();
            let bqv = as_q40_tensor(&tensor).unwrap();

            // Pad rows for the GGML kernels, then rebuild the opaque tensor at
            // the original rank.
            let padded_bqv = pad_q40(bqv)?;
            let padded_fact = typed_fact.with_opaque_fact(padded_bqv.fact.clone());
            let padded_tensor = tensor0(Opaque(Arc::new(padded_bqv)))
                .broadcast_into_rank(op.val().rank())?
                .into_arc_tensor();

            to_device_opaque(padded_fact, padded_tensor)?
        }
        None => to_device_opaque(typed_fact, cuda_const)?,
    };

    Const::new_with_opaque_fact(cuda_tensor, Box::new(cuda_fact))
}
301
/// Expands to a closure mapping an `ElementWiseOp` to its CUDA unary kernel:
/// the first `(tract op type, cuda kernel)` pair whose tract type matches the
/// boxed op wins; the closure yields `None` when no pair matches.
macro_rules! map_unary_ops {
    ([$(($tract_unary_op:path, $cuda_unary_op:ident)),* $(,)?]) => {
        |op: &tract_core::ops::element_wise::ElementWiseOp| {
            $(if let Some(_op) = op.0.downcast_ref::<$tract_unary_op>() {
                return Some($crate::ops::CudaUnaryOp(kernels::UnaryOps::$cuda_unary_op));
            })*
            return None;
        }
    };
}
312
313fn map_element_wise_ops_to_cuda(op: &ElementWiseOp) -> Option<ops::CudaUnaryOp> {
314 map_unary_ops!([
315 (tract_core::ops::math::Abs, Abs),
316 (tract_core::ops::math::Exp, Exp),
317 (tract_core::ops::math::Ln, Ln),
318 (tract_core::ops::nn::Sigmoid, Sigmoid),
319 (tract_core::ops::math::Square, Sqr),
320 (tract_core::ops::math::Sqrt, Sqrt),
321 (tract_core::ops::math::Rsqrt, Rsqrt),
322 (tract_core::ops::math::Recip, Recip),
323 (tract_core::ops::math::Ceil, Ceil),
324 (tract_core::ops::math::Floor, Floor),
325 (tract_core::ops::math::Round, Round),
326 (tract_core::ops::math::RoundHalfToEven, RoundHalfToEven),
327 (tract_core::ops::math::Cos, Cos),
328 (tract_core::ops::math::Acos, Acos),
329 (tract_core::ops::math::Acosh, Acosh),
330 (tract_core::ops::math::Cosh, Cosh),
331 (tract_core::ops::math::Sin, Sin),
332 (tract_core::ops::math::Asin, Asin),
333 (tract_core::ops::math::Asinh, Asinh),
334 (tract_core::ops::math::Sinh, Sinh),
335 (tract_core::ops::math::Tan, Tan),
336 (tract_core::ops::math::Atan, Atan),
337 (tract_core::ops::math::Atanh, Atanh),
338 (tract_core::ops::math::Tanh, Tanh),
339 (tract_core::ops::math::Erf, Erf),
340 (tract_core::ops::math::Neg, Neg),
341 ])(op)
342}
343
/// Expands to a closure mapping a `TypedBinOp` to its CUDA binary kernel: the
/// first `(tract op type, cuda kernel)` pair whose tract type matches the
/// boxed op wins; the closure yields `None` when no pair matches.
macro_rules! map_bin_ops {
    ([$(($tract_bin_op:path, $cuda_bin_op:ident)),* $(,)?]) => {
        |op: &TypedBinOp | {
            $(if let Some(_op) = op.0.downcast_ref::<$tract_bin_op>() {
                return Some($crate::ops::CudaBinOp(kernels::BinOps::$cuda_bin_op));
            })*
            return None;
        }
    };
}
354
355#[allow(clippy::borrowed_box)]
356fn map_binary_op_to_cuda(op: &TypedBinOp) -> Option<ops::CudaBinOp> {
357 map_bin_ops!([
358 (tract_core::ops::math::Mul, Mul),
359 (tract_core::ops::math::Add, Add),
360 (tract_core::ops::math::Div, Div),
361 (tract_core::ops::math::Sub, Sub),
362 (tract_core::ops::math::Pow, Pow),
363 (tract_core::ops::logic::And, And),
364 (tract_core::ops::logic::Or, Or),
365 ])(op)
366}
367
368fn convert_logic_op_to_cuda(op: &Comp) -> ops::CudaBinOp {
369 match op {
370 Comp::Eq => ops::CudaBinOp(kernels::BinOps::Equals),
371 Comp::NE => ops::CudaBinOp(kernels::BinOps::NotEquals),
372 Comp::LT => ops::CudaBinOp(kernels::BinOps::Less),
373 Comp::LTE => ops::CudaBinOp(kernels::BinOps::LessEqual),
374 Comp::GT => ops::CudaBinOp(kernels::BinOps::Greater),
375 Comp::GTE => ops::CudaBinOp(kernels::BinOps::GreaterEqual),
376 }
377}
378
/// Rewires a `PrefixMatMul` node into `target` as a CUDA GGML gemm.
///
/// The wired `CudaGemm::<GgmlGemm>` is fixed to `(transpose_a = false,
/// transpose_b = true)`, so this function inserts the surrounding plumbing
/// needed to reproduce the source node's semantics: an optional A/B swap
/// (undone on the output by a final axis move), axis moves materializing
/// `op.transpose_a` / the missing `op.transpose_b`, an F16→F32 activation
/// cast for Q40 weights, and a final cast back to the expected output type.
fn convert_matmul_to_cuda(
    model: &TypedModel,
    node: &TypedNode,
    target: &mut TypedModel,
    inputs: &mut [OutletId],
    op: &PrefixMatMul,
) -> TractResult<TVec<OutletId>> {
    let mut input_facts = model.node_input_facts(node.id)?;

    // If only the swapped operand order is supported by GGML, swap A and B;
    // the output is re-permuted at the end to compensate.
    let mut swap_inputs = false;
    if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()])
        && GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()])
    {
        input_facts.swap(0, 1);
        inputs.swap(0, 1);
        swap_inputs = true;
    }

    // Logical positions of A and B after the (possible) swap above.
    let a_pos = swap_inputs as usize;
    let b_pos = 1 - swap_inputs as usize;
    // Kernel takes A untransposed: materialize the transpose requested by `op`.
    if op.transpose_a {
        ensure!(as_q40_fact(input_facts[a_pos]).is_none(), "Cannot transpose Q40 tensor");

        let rank = input_facts[a_pos].rank();
        let perm_a_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
        let perm_a_name = node.name.clone() + ".perm_a";
        inputs[a_pos] = target.wire_node(perm_a_name, perm_a_op, &[inputs[a_pos]])?[0];
    }

    // Q40 weights with an F16 activation: promote the activation to F32 first.
    if input_facts[0].datum_type == DatumType::F16 && as_q40_fact(input_facts[1]).is_some() {
        let in_cast_op = ops::CudaCast::new(DatumType::F32).unwrap();
        inputs[0] = target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[inputs[0]])?[0];
    }

    // Kernel takes B transposed: materialize a transpose when `op` lacks one.
    if !op.transpose_b {
        ensure!(as_q40_fact(input_facts[b_pos]).is_none(), "Cannot transpose Q40 tensor");

        let rank = input_facts[b_pos].rank();
        let perm_b_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
        let perm_b_name = node.name.clone() + ".perm_b";
        inputs[b_pos] = target.wire_node(perm_b_name, perm_b_op, &[inputs[b_pos]])?[0];
    }

    let op = ops::CudaGemm::<GgmlGemm>::new(false, true);
    let mut matmul_output = target.wire_node(node.name.clone(), op, inputs)?;

    if swap_inputs {
        // Swap the two innermost output axes to undo the A/B swap above.
        // NOTE(review): the unwraps assume the gemm output fact always carries
        // a clarifiable dt/shape — confirm; they panic otherwise.
        let out_fact = target.outlet_fact(matmul_output[0])?;
        let rank = &out_fact
            .opaque_fact
            .clone()
            .map(|fact| fact.clarify_dt_shape().unwrap().1.len())
            .unwrap();

        let perm_out_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
        matmul_output =
            target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?;
    }

    // Cast the result back to the datum type the source model expects.
    let out_fact = target.outlet_fact(matmul_output[0])?;
    let out_dt = out_fact.to_device_fact().map(|f| f.datum_type).unwrap_or(out_fact.datum_type);

    let expected_dt = model.node_output_facts(node.id)?[0].datum_type;

    if out_dt != expected_dt {
        ensure!(
            ops::CudaCast::is_supported_dt(out_dt),
            "Matmul output type cannot be casted to expected type"
        );
        let cast_op = ops::CudaCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap();
        matmul_output =
            target.wire_node(node.name.clone() + ".out_cast", cast_op, &matmul_output)?
    }
    Ok(matmul_output)
}
454
455impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for CudaTransform {
    /// Translates one node from `source` into `target`.
    ///
    /// Ops accepted by [`can_translate_to_cuda_op`] are lowered to their CUDA
    /// counterparts behind host→device input syncs; anything else keeps its
    /// original op on host behind device→host syncs. Model outputs are synced
    /// back to host by `sync_model_outputs_if_required`.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let translatable = can_translate_to_cuda_op(source, node)?;

        if translatable {
            let mut device_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?;

            // Matmul needs bespoke wiring (swaps, casts, permutes); every other
            // supported op maps 1:1 to a CUDA op below.
            let outlet_ids: TVec<OutletId> = if let Some(op) = node.op_as::<PrefixMatMul>() {
                convert_matmul_to_cuda(source, node, target, &mut device_inputs, op)?
            } else {
                // NOTE(review): the `unwrap()`s below rely on the matching
                // checks performed in `can_translate_to_cuda_op`.
                let op: Box<dyn TypedOp> = if let Some(op) = node.op_as::<Const>() {
                    Box::new(convert_const(op)?)
                } else if let Some(op) = node.op_as::<ElementWiseOp>() {
                    Box::new(map_element_wise_ops_to_cuda(op).unwrap())
                } else if let Some(op) = node.op_as::<TypedBinOp>() {
                    Box::new(map_binary_op_to_cuda(op).unwrap())
                } else if let Some(op) = node.op_as::<Comp>() {
                    Box::new(convert_logic_op_to_cuda(op))
                } else if let Some(_op) = node.op_as::<Silu>() {
                    Box::new(ops::CudaUnaryOp(kernels::UnaryOps::Silu))
                } else if let Some(op) = node.op_as::<MultiBroadcastTo>() {
                    Box::new(ops::CudaMultiBroadcastTo::new(op.shape.clone()))
                } else if let Some(op) = node.op_as::<Cast>() {
                    Box::new(ops::CudaCast::new(op.to).unwrap())
                } else if let Some(op) = node.op_as::<AxisOp>() {
                    // The input fact is needed to specialize the axis op.
                    let in_fact = source.node_input_facts(node.id)?[0];
                    Box::new(ops::CudaAxisOp::from_tract_core_with_fact(op.clone(), in_fact))
                } else if let Some(op) = node.op_as::<Slice>() {
                    Box::new(ops::CudaSlice::from_tract_core(op.clone()))
                } else if let Some(op) = node.op_as::<TypedConcat>() {
                    Box::new(ops::CudaConcat::from_tract_core(op))
                } else if let Some(op) = node.op_as::<DynKeyValueCache>() {
                    Box::new(ops::CudaDynKVCache::from_tract_transformers(op))
                } else if let Some(op) = node.op_as::<Reduce>() {
                    Box::new(ops::CudaReduce::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<Softmax>() {
                    Box::new(ops::CudaSoftmax::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<ScaledMaskedSoftmax>() {
                    Box::new(ops::CudaScaledMaskedSoftmax { scale: op.scale.clone() })
                } else if let Some(_op) = node.op_as::<RotateHalf>() {
                    Box::new(ops::CudaRotateHalf)
                } else if let Some(_op) = node.op_as::<ApplyRope>() {
                    Box::new(ops::CudaApplyRope)
                } else if let Some(op) = node.op_as::<RmsNorm>() {
                    Box::new(ops::CudaRmsNorm::new(op.axis, op.eps.clone()))
                } else if let Some(op) = node.op_as::<GeluApproximate>() {
                    Box::new(ops::CudaGeluApproximate { fast_impl: op.fast_impl })
                } else {
                    // Unreachable when `can_translate_to_cuda_op` and this
                    // dispatch stay in sync.
                    bail!("Failed to translate a supported CUDA Op")
                };
                target.wire_node(node.name.clone(), op, &device_inputs)?
            };
            self.sync_model_outputs_if_required(source, node, target, outlet_ids)
        } else {
            // Unsupported op: run it on host, pulling device inputs back first.
            let cpu_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToHost)?;
            target.wire_node(&node.name, node.op.clone(), &cpu_inputs)
        }
    }
521}