use candle_core::{Result, Tensor, Device};
use num_traits::Float;
/// Returns an initializer closure intended to draw from a zero-mean normal
/// distribution with standard deviation `std`, truncated to `[a, b]`.
///
/// **Placeholder:** no random sampling is performed. Every call to the
/// returned closure yields the same deterministic value: zero clamped into
/// `[a, b]`. That point is the mode of a zero-mean normal truncated to the
/// interval. The result therefore always respects the truncation bounds,
/// but `std` has no effect. Replace this with a real sampler if callers
/// need random draws.
///
/// `a <= b` is expected. If `a > b`, the closure returns `a`.
pub fn trunc_normal_init<F: Float>(std: F, a: F, b: F) -> impl Fn() -> F {
    // `std` only controls the spread of random draws, which this placeholder
    // does not perform; bind it explicitly so the omission is deliberate.
    let _ = std;
    let zero = F::zero();
    // Nearest point of [a, b] to the mean (0). Written out by hand so an
    // inverted interval cannot trip a precondition check.
    let value = if zero < a {
        a
    } else if zero > b {
        b
    } else {
        zero
    };
    move || value
}
/// Returns the total number of scalar elements stored in `tensor`.
///
/// This is the product of its dimension sizes, so a rank-0 (scalar) tensor
/// counts as `1` and any tensor with a zero-sized dimension counts as `0`.
pub fn count_parameters(tensor: &Tensor) -> usize {
    tensor.elem_count()
}
pub fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
todo!("Implement causal mask")
}