use crate::op::ReduceOp;
use crate::{Graph, NodeId, Op, Shape};
/// Builder helpers that append reduction, scan and sampling nodes to a `Graph`.
///
/// Every method forwards to `Graph::push` with exactly one input node, the
/// caller-supplied output shape, and `None` as the final argument. The output
/// shape is taken as given: nothing in these helpers derives it from the input
/// or checks it against the operation's parameters.
impl Graph {
    /// Appends an `Op::Reduce` node built from `op`, `axes` and `keep_dim`,
    /// with `input` as its sole operand.
    ///
    /// `shape` becomes the new node's output shape verbatim. It is not checked
    /// here against `axes`/`keep_dim`, so callers must pass a consistent shape.
    /// Returns the id that `push` assigns to the new node.
    pub fn reduce(
        &mut self,
        input: NodeId,
        op: ReduceOp,
        axes: Vec<usize>,
        keep_dim: bool,
        shape: Shape,
    ) -> NodeId {
        let node = Op::Reduce { op, axes, keep_dim };
        let inputs = vec![input];
        // NOTE(review): the trailing `None` is an optional argument of `push`
        // whose meaning is not visible from this file — confirm at `push`.
        self.push(node, inputs, shape, None)
    }

    /// Appends an `Op::Softmax` node over `axis` with `input` as its sole operand.
    ///
    /// `axis` is signed. Whether negative values index from the last dimension is
    /// decided wherever `Op::Softmax` is evaluated, not here — TODO confirm.
    /// Returns the id that `push` assigns to the new node.
    pub fn softmax(&mut self, input: NodeId, axis: i32, shape: Shape) -> NodeId {
        let node = Op::Softmax { axis };
        let inputs = vec![input];
        self.push(node, inputs, shape, None)
    }

    /// Appends an `Op::Cumsum` node along `axis` with `input` as its sole operand.
    ///
    /// `exclusive` is recorded on the node unchanged. Presumably it selects an
    /// exclusive (shifted) rather than inclusive running sum — verify against the
    /// evaluator. Returns the id that `push` assigns to the new node.
    pub fn cumsum(&mut self, input: NodeId, axis: i32, exclusive: bool, shape: Shape) -> NodeId {
        let node = Op::Cumsum { axis, exclusive };
        let inputs = vec![input];
        self.push(node, inputs, shape, None)
    }

    /// Appends an `Op::Sample` node that reads from `logits`.
    ///
    /// `top_k`, `top_p`, `temperature` and `seed` are stored on the node as-is.
    /// Nothing here range-checks them, e.g. `top_p` within `[0, 1]` or a
    /// positive `temperature`. `output_shape` becomes the node's output shape
    /// verbatim. Returns the id that `push` assigns to the new node.
    pub fn sample(
        &mut self,
        logits: NodeId,
        top_k: usize,
        top_p: f32,
        temperature: f32,
        seed: u64,
        output_shape: Shape,
    ) -> NodeId {
        let node = Op::Sample {
            top_k,
            top_p,
            temperature,
            seed,
        };
        let inputs = vec![logits];
        self.push(node, inputs, output_shape, None)
    }
}