use crate::{Graph, NodeId, Op, Shape};
impl Graph {
/// Adds an [`Op::Input`] node named `name` with the given `shape`.
///
/// The node is pushed with no input edges. The name is stored inside the op
/// and is also handed to `push` as its trailing `Some(..)` argument
/// (presumably the node's label — confirm against `push`).
///
/// Returns the [`NodeId`] produced by `push`.
pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> NodeId {
    let label = name.into();
    let op = Op::Input {
        name: label.clone(),
    };
    self.push(op, Vec::new(), shape, Some(label))
}
/// Adds an [`Op::Param`] node named `name` with the given `shape`.
///
/// The node is pushed with no input edges. The name is stored inside the op
/// and is also handed to `push` as its trailing `Some(..)` argument
/// (presumably the node's label — confirm against `push`).
///
/// Returns the [`NodeId`] produced by `push`.
pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> NodeId {
    let label = name.into();
    let op = Op::Param {
        name: label.clone(),
    };
    self.push(op, Vec::new(), shape, Some(label))
}
/// Adds a node for an arbitrary `op` with the given `inputs` and `shape`.
///
/// This forwards straight to `push`, passing `None` as the trailing
/// argument (the slot that `input`/`param` fill with the node's name).
/// No validation of `op`, `inputs` or `shape` happens in this method.
///
/// Returns the [`NodeId`] produced by `push`.
pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, shape: Shape) -> NodeId {
    let unnamed = None;
    self.push(op, inputs, shape, unnamed)
}
/// Adds an [`Op::Custom`] node for the registered op `name`, inferring the
/// output shape from the input shapes.
///
/// The op is looked up through `crate::lookup_op`. Its `infer_shape` is
/// called with the shapes of `inputs` (in order) and the raw `attrs` bytes,
/// and the result becomes the new node's shape.
///
/// # Panics
///
/// * if `name` is not present in the op registry;
/// * if `inputs.len()` differs from the registered op's `num_inputs()`;
/// * if the input count does not fit in the `u32` stored in `Op::Custom`.
pub fn custom_op(
    &mut self,
    name: impl Into<String>,
    attrs: Vec<u8>,
    inputs: Vec<NodeId>,
) -> NodeId {
    let name: String = name.into();
    let ext = crate::lookup_op(&name)
        .unwrap_or_else(|| panic!("custom_op: '{name}' is not registered in the op registry"));

    let expected = ext.num_inputs();
    assert_eq!(
        expected,
        inputs.len(),
        "custom_op '{name}': registered op expects {} inputs, got {}",
        expected,
        inputs.len(),
    );
    // Checked narrowing: a plain `as u32` would silently wrap for counts
    // above `u32::MAX` and record a wrong arity in the op.
    let num_inputs = u32::try_from(expected).unwrap_or_else(|_| {
        panic!("custom_op '{name}': input count {expected} does not fit in u32")
    });

    // Keep the shape borrows in their own scope so they are released before
    // the mutable `push` below.
    let out_shape = {
        let in_shapes: Vec<&Shape> = inputs.iter().map(|id| self.shape(*id)).collect();
        ext.infer_shape(&in_shapes, &attrs)
    };

    self.push(
        Op::Custom {
            name,
            num_inputs,
            attrs,
        },
        inputs,
        out_shape,
        None,
    )
}
/// Adds an [`Op::Custom`] node for the registered op `name` using a
/// caller-supplied output shape.
///
/// Unlike `custom_op`, this does not call the op's `infer_shape`; the node
/// is created with exactly the `out_shape` passed in. The op must still be
/// registered and the input count must still match its `num_inputs()`.
///
/// # Panics
///
/// * if `name` is not present in the op registry;
/// * if `inputs.len()` differs from the registered op's `num_inputs()`;
/// * if the input count does not fit in the `u32` stored in `Op::Custom`.
pub fn custom_op_packed(
    &mut self,
    name: impl Into<String>,
    attrs: Vec<u8>,
    inputs: Vec<NodeId>,
    out_shape: Shape,
) -> NodeId {
    let name: String = name.into();
    let ext = crate::lookup_op(&name).unwrap_or_else(|| {
        panic!("custom_op_packed: '{name}' is not registered in the op registry")
    });

    let expected = ext.num_inputs();
    assert_eq!(
        expected,
        inputs.len(),
        "custom_op_packed '{name}': registered op expects {} inputs, got {}",
        expected,
        inputs.len(),
    );
    // Checked narrowing: a plain `as u32` would silently wrap for counts
    // above `u32::MAX` and record a wrong arity in the op.
    let num_inputs = u32::try_from(expected).unwrap_or_else(|_| {
        panic!("custom_op_packed '{name}': input count {expected} does not fit in u32")
    });

    self.push(
        Op::Custom {
            name,
            num_inputs,
            attrs,
        },
        inputs,
        out_shape,
        None,
    )
}
pub fn fft(&mut self, x: NodeId, inverse: bool) -> NodeId {
self.fft_norm(x, inverse, crate::fft::FftNorm::Backward)
}
/// Adds an [`Op::Fft`] node over `x` with an explicit normalisation mode.
///
/// The new node has `x` as its only input and the same shape as `x`.
/// `crate::fft::fft_meta` is called on that shape first and its result is
/// discarded (presumably it validates the shape and panics on unsupported
/// layouts — confirm against `fft_meta`).
pub fn fft_norm(&mut self, x: NodeId, inverse: bool, norm: crate::fft::FftNorm) -> NodeId {
    let out_shape = self.shape(x).clone();
    // Called only for its side effects; the return value is not used.
    crate::fft::fft_meta(&out_shape);

    let op = Op::Fft { inverse, norm };
    let operands = vec![x];
    self.push(op, operands, out_shape, None)
}
/// Applies `fft` along a chosen `axis` of `x`.
///
/// When `axis` is already the last axis this is exactly `fft(x, inverse)`.
/// Otherwise the target axis is swapped with the last one via `transpose_`,
/// `fft` is applied, and the same swap is applied again to restore the
/// original axis order (a two-element swap is its own inverse).
///
/// # Panics
///
/// Panics if `axis` is not smaller than the rank of `x`.
pub fn fft_axis(&mut self, x: NodeId, axis: usize, inverse: bool) -> NodeId {
    use crate::infer::GraphExt as _;

    let rank = self.shape(x).rank();
    if axis >= rank {
        panic!("fft_axis: axis {axis} out of range for rank-{rank} tensor");
    }

    // `rank >= 1` here, because `axis < rank` was just established.
    let last = rank - 1;
    if axis != last {
        // Identity permutation with `axis` and `last` exchanged.
        let perm: Vec<usize> = (0..rank)
            .map(|dim| match dim {
                d if d == axis => last,
                d if d == last => axis,
                d => d,
            })
            .collect();
        let swapped_in = self.transpose_(x, perm.clone());
        let transformed = self.fft(swapped_in, inverse);
        self.transpose_(transformed, perm)
    } else {
        self.fft(x, inverse)
    }
}
/// Applies `fft_axis` along several axes of `x`, one after another.
///
/// `axes` is first passed through `crate::fft::normalize_fftn_axes` together
/// with the rank of `x`; the transforms are then chained in the order of the
/// normalised list. If that list is empty, `x` is returned unchanged and no
/// node is added.
///
/// # Panics
///
/// Panics if more than one axis remains after normalisation and the dtype of
/// `x` is not complex.
pub fn fftn(&mut self, x: NodeId, axes: &[usize], inverse: bool) -> NodeId {
    let rank = self.shape(x).rank();
    let normalized = crate::fft::normalize_fftn_axes(rank, axes);

    let multi_axis = normalized.len() > 1;
    if multi_axis && !self.shape(x).dtype().is_complex() {
        panic!(
            "fftn: multi-axis FFT on {:?} requires DType::C64; \
             the F32/F64 2N real-block layout supports only one complex axis — \
             call fft_axis for a single transform",
            self.shape(x).dtype()
        );
    }

    // Folding over an empty axis list yields `x` itself, which covers the
    // "nothing to transform" case without a separate early return.
    normalized
        .into_iter()
        .fold(x, |node, axis| self.fft_axis(node, axis, inverse))
}
/// Inverse multi-axis FFT: `fftn(x, axes, true)`.
///
/// Shares all behaviour of `fftn`, including its panics.
pub fn ifftn(&mut self, x: NodeId, axes: &[usize]) -> NodeId {
    let inverse = true;
    self.fftn(x, axes, inverse)
}
}