use crate::op::Activation;
use crate::quant::QuantScheme;
use crate::{Graph, NodeId, Op, Shape};
impl Graph {
pub fn matmul(&mut self, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId {
self.push(Op::MatMul, vec![lhs, rhs], out_shape, None)
}
pub fn dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId {
self.push(Op::DenseSolve, vec![a, b], out_shape, None)
}
pub fn batched_dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId {
self.push(Op::BatchedDenseSolve, vec![a, b], out_shape, None)
}
pub fn lora_matmul(
&mut self,
x: NodeId,
w: NodeId,
a: NodeId,
b: NodeId,
scale: f32,
shape: Shape,
) -> NodeId {
self.push(Op::LoraMatMul { scale }, vec![x, w, a, b], shape, None)
}
pub fn dequant_matmul(
&mut self,
x: NodeId,
w_q: NodeId,
scale: NodeId,
zp: NodeId,
scheme: QuantScheme,
shape: Shape,
) -> NodeId {
self.push(
Op::DequantMatMul { scheme },
vec![x, w_q, scale, zp],
shape,
None,
)
}
pub fn dequant_matmul_packed(
&mut self,
x: NodeId,
packed_w: NodeId,
scheme: QuantScheme,
shape: Shape,
) -> NodeId {
debug_assert!(
scheme.is_gguf(),
"dequant_matmul_packed requires a GGUF QuantScheme"
);
self.push(Op::DequantMatMul { scheme }, vec![x, packed_w], shape, None)
}
pub fn dequant_matmul_nvfp4(
&mut self,
x: NodeId,
w_q: NodeId,
block_scales: NodeId,
global_scale: NodeId,
shape: Shape,
) -> NodeId {
self.dequant_matmul(
x,
w_q,
block_scales,
global_scale,
QuantScheme::Nvfp4Block,
shape,
)
}
pub fn fused_matmul_bias_act(
&mut self,
input: NodeId,
weight: NodeId,
bias: NodeId,
activation: Option<Activation>,
shape: Shape,
) -> NodeId {
self.push(
Op::FusedMatMulBiasAct { activation },
vec![input, weight, bias],
shape,
None,
)
}
pub fn q_matmul(
&mut self,
x: NodeId,
w: NodeId,
bias: NodeId,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
out_shape: Shape,
) -> NodeId {
debug_assert_eq!(
out_shape.dtype(),
crate::DType::I8,
"q_matmul output dtype must be I8"
);
self.push(
Op::QMatMul {
x_zp,
w_zp,
out_zp,
mult,
},
vec![x, w, bias],
out_shape,
None,
)
}
#[allow(clippy::too_many_arguments)]
pub fn q_conv2d(
&mut self,
x: NodeId,
w: NodeId,
bias: NodeId,
kernel_size: Vec<usize>,
stride: Vec<usize>,
padding: Vec<usize>,
dilation: Vec<usize>,
groups: usize,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
out_shape: Shape,
) -> NodeId {
debug_assert_eq!(
out_shape.dtype(),
crate::DType::I8,
"q_conv2d output dtype must be I8"
);
self.push(
Op::QConv2d {
kernel_size,
stride,
padding,
dilation,
groups,
x_zp,
w_zp,
out_zp,
mult,
},
vec![x, w, bias],
out_shape,
None,
)
}
}