use std::sync::Arc;
use morok_dtype::DType;
use morok_dtype::DeviceSpec;
use morok_dtype::ext::HasDType;
use crate::IntoUOp;
use crate::op::Op;
use crate::types::{ConstValue, ConstValueHash};
use crate::uop::core::UOp;
use crate::uop::hash_consing::next_unique_id;
impl UOp {
    /// Creates a constant node (`Op::Const`) of type `dtype`.
    ///
    /// `value` is first converted with `ConstValue::cast` to match `dtype`. If that
    /// conversion returns `None`, the value is stored exactly as given.
    /// NOTE(review): presumably this normalization lets equal constants expressed with
    /// different `ConstValue` variants map to the same payload — confirm against
    /// `ConstValue::cast`.
    pub fn const_(dtype: DType, value: ConstValue) -> Arc<Self> {
        let normalized = value.cast(&dtype).unwrap_or(value);
        Self::new(Op::Const(ConstValueHash(normalized)), dtype)
    }

    /// Creates a constant from a native Rust value.
    ///
    /// The dtype is the type's associated `HasDType::DTYPE`.
    pub fn native_const<T: HasDType + IntoUOp>(value: T) -> Arc<Self> {
        value.into_uop(T::DTYPE)
    }

    /// Creates an integer constant of dtype `DType::Index`.
    pub fn index_const(value: i64) -> Arc<Self> {
        Self::const_(DType::Index, ConstValue::Int(value))
    }

    /// Creates a constant with the same dtype as `self`.
    ///
    /// The conversion is delegated to `IntoUOp::into_uop`.
    /// NOTE(review): how a vector-typed `self` is handled depends on `into_uop` — confirm.
    pub fn const_like<T: IntoUOp>(self: &Arc<Self>, value: T) -> Arc<Self> {
        value.into_uop(self.dtype())
    }

    /// Creates a vector constant (`Op::VConst`) with one lane per element of `values`.
    ///
    /// Each lane is normalized to `scalar_dtype` with the same rule `const_` applies to
    /// scalars: `ConstValue::cast`, falling back to the original value when the cast
    /// returns `None`. The node's dtype is `scalar_dtype.vec(values.len())`.
    pub fn vconst(values: Vec<ConstValue>, scalar_dtype: DType) -> Arc<Self> {
        // Normalize lanes first (borrowing `scalar_dtype`), then build the vector dtype,
        // so this works whether `DType::vec` borrows or consumes its receiver.
        let values: Vec<ConstValue> = values
            .into_iter()
            .map(|v| v.cast(&scalar_dtype).unwrap_or(v))
            .collect();
        let vec_dtype = scalar_dtype.vec(values.len());
        Self::new(Op::VConst { values }, vec_dtype)
    }

    /// Creates an `Op::Unique` identifier node with dtype `Void`.
    ///
    /// Uses `num` when provided; otherwise draws a fresh id from `next_unique_id`.
    pub fn buffer_id(num: Option<usize>) -> Arc<Self> {
        let id = num.unwrap_or_else(next_unique_id);
        Self::new(Op::Unique(id), DType::Void)
    }

    /// Creates an `Op::Buffer` node of `size` elements and element type `dtype` on `device`.
    ///
    /// Every call allocates a new `Unique` child via `buffer_id(None)`. Presumably this keeps
    /// two buffers with identical parameters distinct — confirm against `next_unique_id`.
    pub fn new_buffer(device: DeviceSpec, size: usize, dtype: DType) -> Arc<Self> {
        let unique = Self::buffer_id(None);
        let dev = Self::device(device);
        Self::new(Op::Buffer { unique, device: dev, size }, dtype)
    }

    /// Creates an `Op::Param` node for parameter `slot` with the given `size` and `dtype`.
    ///
    /// `device` is optional and stored as given.
    pub fn param(slot: usize, size: usize, dtype: DType, device: Option<Arc<Self>>) -> Arc<Self> {
        Self::new(Op::Param { slot, size, device }, dtype)
    }

    /// Creates an `Op::BufferView` over `self` covering `size` elements starting at `offset`.
    ///
    /// The view has the same dtype as `self`.
    pub fn view(self: &Arc<Self>, size: usize, offset: usize) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::BufferView { buffer: Arc::clone(self), size, offset }, dtype)
    }

    /// Creates an `Op::Device` node (dtype `Void`) for `device`.
    pub fn device(device: DeviceSpec) -> Arc<Self> {
        Self::new(Op::Device(device), DType::Void)
    }

    /// Creates an `Op::Noop` node with dtype `Void`.
    pub fn noop() -> Arc<Self> {
        Self::new(Op::Noop, DType::Void)
    }

    /// Casts `self` to `dtype`, producing an `Op::Cast` node.
    ///
    /// If `dtype` has a single lane but `self` is a vector, the target is widened to
    /// `self`'s lane count. For example, a 4-lane source cast to a scalar dtype yields a
    /// 4-lane result. If the resulting target already equals `self`'s dtype, `self` is
    /// returned and no new node is created.
    pub fn cast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
        let src_vcount = self.dtype().vcount();
        let dst_vcount = dtype.vcount();
        // Preserve the source's vector width when the caller names only the element type.
        let dtype = if dst_vcount == 1 && src_vcount > 1 { dtype.vec(src_vcount) } else { dtype };
        if self.dtype() == dtype {
            return Arc::clone(self);
        }
        Self::new(Op::Cast { src: Arc::clone(self), dtype: dtype.clone() }, dtype)
    }

    /// Reinterprets the bits of `self` as `dtype`, producing an `Op::BitCast` node.
    ///
    /// Unlike `cast`, this does not adjust the lane count. It also always creates a new
    /// node, even when `dtype` already matches `self`'s dtype.
    pub fn bitcast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
        Self::new(Op::BitCast { src: Arc::clone(self), dtype: dtype.clone() }, dtype)
    }
}