use crate::op::{Activation, BinaryOp};
use crate::{Graph, NodeId, Op, Shape};
impl Graph {
    /// Appends a binary node that applies `op` to `lhs` and `rhs`.
    ///
    /// `out_shape` is recorded on the node exactly as given; it is not inferred
    /// from or checked against the operand shapes here, so the caller is
    /// responsible for passing the correct result shape.
    ///
    /// Returns the [`NodeId`] of the pushed node.
    pub fn binary(&mut self, op: BinaryOp, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId {
        self.push(Op::Binary(op), vec![lhs, rhs], out_shape, None)
    }

    /// Appends an activation node that applies `act` to `input`.
    ///
    /// `shape` is recorded on the node exactly as given; it is not derived
    /// from `input`'s shape here.
    ///
    /// Returns the [`NodeId`] of the pushed node.
    pub fn activation(&mut self, act: Activation, input: NodeId, shape: Shape) -> NodeId {
        self.push(Op::Activation(act), vec![input], shape, None)
    }

    /// Appends a quantize node for `x` using a single scale / zero-point pair
    /// (`axis: None`, i.e. one set of parameters for the whole tensor).
    ///
    /// The output shape is `x`'s shape with its dtype replaced by
    /// [`DType::I8`](crate::DType::I8).
    ///
    /// Returns the [`NodeId`] of the pushed node.
    pub fn quantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId {
        let shape = self.shape(x).clone().with_dtype(crate::DType::I8);
        self.push(
            Op::Quantize {
                axis: None,
                scales: vec![scale],
                zero_points: vec![zero_point],
            },
            vec![x],
            shape,
            None,
        )
    }

    /// Appends a quantize node for `x` with one scale / zero-point pair per
    /// index along `axis`.
    ///
    /// The output shape is `x`'s shape with its dtype replaced by
    /// [`DType::I8`](crate::DType::I8).
    ///
    /// Returns the [`NodeId`] of the pushed node.
    ///
    /// # Panics
    ///
    /// In debug builds only, panics if `scales` and `zero_points` differ in
    /// length, or if the dimension at `axis` is not
    /// `Dim::Static(scales.len())`. Both checks are `debug_assert`s and are
    /// compiled out in release builds.
    // NOTE(review): behavior for an out-of-range `axis` depends on
    // `Shape::dim`, which is defined elsewhere — confirm.
    pub fn quantize_per_channel(
        &mut self,
        x: NodeId,
        axis: usize,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    ) -> NodeId {
        debug_assert_eq!(
            scales.len(),
            zero_points.len(),
            "quantize_per_channel: scales.len() must match zero_points.len()"
        );
        let shape = self.shape(x).clone().with_dtype(crate::DType::I8);
        debug_assert_eq!(
            shape.dim(axis),
            crate::shape::Dim::Static(scales.len()),
            "quantize_per_channel: scales.len() must match input.dim(axis)"
        );
        self.push(
            Op::Quantize {
                axis: Some(axis),
                scales,
                zero_points,
            },
            vec![x],
            shape,
            None,
        )
    }

    /// Appends a dequantize node for `x` using a single scale / zero-point pair
    /// (`axis: None`, i.e. one set of parameters for the whole tensor).
    ///
    /// The output shape is `x`'s shape with its dtype replaced by
    /// [`DType::F32`](crate::DType::F32).
    ///
    /// Returns the [`NodeId`] of the pushed node.
    pub fn dequantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId {
        let shape = self.shape(x).clone().with_dtype(crate::DType::F32);
        self.push(
            Op::Dequantize {
                axis: None,
                scales: vec![scale],
                zero_points: vec![zero_point],
            },
            vec![x],
            shape,
            None,
        )
    }

    /// Appends a dequantize node for `x` with one scale / zero-point pair per
    /// index along `axis`.
    ///
    /// The output shape is `x`'s shape with its dtype replaced by
    /// [`DType::F32`](crate::DType::F32).
    ///
    /// Returns the [`NodeId`] of the pushed node.
    ///
    /// # Panics
    ///
    /// In debug builds only, panics if `scales` and `zero_points` differ in
    /// length, or if the dimension at `axis` is not
    /// `Dim::Static(scales.len())`. Both checks are `debug_assert`s and are
    /// compiled out in release builds.
    // NOTE(review): behavior for an out-of-range `axis` depends on
    // `Shape::dim`, which is defined elsewhere — confirm.
    pub fn dequantize_per_channel(
        &mut self,
        x: NodeId,
        axis: usize,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
    ) -> NodeId {
        debug_assert_eq!(
            scales.len(),
            zero_points.len(),
            "dequantize_per_channel: scales.len() must match zero_points.len()"
        );
        let shape = self.shape(x).clone().with_dtype(crate::DType::F32);
        debug_assert_eq!(
            shape.dim(axis),
            crate::shape::Dim::Static(scales.len()),
            "dequantize_per_channel: scales.len() must match input.dim(axis)"
        );
        self.push(
            Op::Dequantize {
                axis: Some(axis),
                scales,
                zero_points,
            },
            vec![x],
            shape,
            None,
        )
    }
}