use crate::infer::GraphExt;
use crate::{Graph, NodeId, Op, Shape};
/// Lowers a causal depthwise 1-D convolution into a grouped 2-D `Op::Conv`.
///
/// The input is read as `[batch, seq, channels]` (dims 0, 1, 2, all unwrapped as
/// static). `left_pad` is concatenated in front of `input` along axis 1. Because all
/// padding sits on the left and the conv itself uses zero padding, output step `t`
/// covers padded steps `t..t + kernel_size`, i.e. input steps `t - (kernel_size - 1)..=t`:
/// no output sees future input (causal).
///
/// The padded tensor is moved to channel-major layout
/// `[batch, channels, 1, kernel_size - 1 + seq]` and convolved with a `[1, kernel_size]`
/// kernel using `groups = channels`, so each channel is filtered independently
/// (depthwise). The resulting `[batch, channels, 1, seq]` tensor is reshaped and
/// transposed back to `[batch, seq, channels]`.
///
/// # Arguments
/// * `g` - graph the new nodes are appended to.
/// * `input` - sequence tensor, read as `[batch, seq, channels]`.
/// * `weight` - conv weight, passed straight to `Op::Conv`. It presumably needs the
///   grouped-conv layout `[channels, 1, 1, kernel_size]`. NOTE(review): confirm
///   against the `Op::Conv` lowering.
/// * `left_pad` - node prepended along axis 1. It must have exactly `kernel_size - 1`
///   steps on that axis (checked in debug builds). Presumably zeros or a cached state
///   from a previous chunk. TODO confirm with callers.
/// * `kernel_size` - temporal kernel width; must be at least 1.
/// * `out_shape` - expected result shape; only compared in debug builds.
///
/// # Panics
/// Panics if `kernel_size == 0`. `unwrap_static` on the input dims presumably panics
/// when a dimension is not static.
pub fn lower_depthwise_conv1d_causal(
    g: &mut Graph,
    input: NodeId,
    weight: NodeId,
    left_pad: NodeId,
    kernel_size: usize,
    out_shape: Shape,
) -> NodeId {
    // Without this, `kernel_size - 1` underflows. Debug builds would hit an opaque
    // overflow panic; release builds would wrap and silently build a bogus padded length.
    assert!(
        kernel_size >= 1,
        "lower_depthwise_conv1d_causal: kernel_size must be >= 1, got {}",
        kernel_size
    );
    // Cloned so `g` can be borrowed mutably while building the nodes below.
    let in_shape = g.node(input).shape.clone();
    let batch = in_shape.dim(0).unwrap_static();
    let seq = in_shape.dim(1).unwrap_static();
    let channels = in_shape.dim(2).unwrap_static();
    let dtype = in_shape.dtype();
    let k = kernel_size;
    // Causal padding: all k - 1 pad steps go on the left.
    let pad_len = k - 1;
    // The concat shape below assumes left_pad contributes exactly pad_len steps.
    debug_assert_eq!(
        g.node(left_pad).shape.dim(1).unwrap_static(),
        pad_len,
        "left_pad must have kernel_size - 1 steps along the sequence axis"
    );
    let padded_len = pad_len + seq;
    let padded = g.concat(
        vec![left_pad, input],
        1,
        Shape::new(&[batch, padded_len, channels], dtype),
    );
    // [batch, padded_len, channels] -> [batch, channels, padded_len]
    let bcw = g.transpose_(padded, vec![0, 2, 1]);
    // Add a unit height axis so the 1-D conv can be expressed as a 2-D conv.
    let nchw = g.reshape_(
        bcw,
        vec![batch as i64, channels as i64, 1, padded_len as i64],
    );
    // Stride 1 and no padding give output width padded_len - k + 1 == seq.
    let conv = g.add_node(
        Op::Conv {
            kernel_size: vec![1, k],
            stride: vec![1, 1],
            padding: vec![0, 0],
            dilation: vec![1, 1],
            groups: channels,
        },
        vec![nchw, weight],
        Shape::new(&[batch, channels, 1, seq], dtype),
    );
    // Drop the unit height axis and return to [batch, seq, channels].
    let bcs = g.reshape_(conv, vec![batch as i64, channels as i64, seq as i64]);
    let out = g.transpose_(bcs, vec![0, 2, 1]);
    debug_assert_eq!(g.node(out).shape, out_shape);
    out
}