#[crate::module(tile_rust_crate = true)]
pub mod creation {
use crate::core::*;
#[crate::entry()]
pub fn full<T: ElementType, const S: [i32; 1]>(value: T, tensor: &mut Tensor<T, S>) {
let value_tile: Tile<T, S> = value.broadcast(tensor.shape());
tensor.store(value_tile);
}
#[crate::entry()]
pub fn arange<T: ElementType, const S: [i32; 1]>(tensor: &mut Tensor<T, S>) {
let pid: (i32, i32, i32) = get_tile_block_id();
let offset = pid.0 * S[0];
let offset_tile = broadcast_scalar(offset, tensor.shape());
let range = iota::<i32, S>(tensor.shape());
let offset_range = offset_tile + range;
let offset_range: Tile<T, S> = convert_tile(offset_range);
tensor.store(offset_range);
}
#[crate::entry()]
pub fn linspace<const S: [i32; 1]>(tensor: &mut Tensor<f32, S>, start: f32, step: f32) {
let pid: (i32, i32, i32) = get_tile_block_id();
let offset: i32 = pid.0 * S[0];
let indices: Tile<i32, S> = iota(tensor.shape());
let offset_indices: Tile<i32, S> = indices + broadcast_scalar(offset, tensor.shape());
let float_indices: Tile<f32, S> = convert_tile(offset_indices);
let start_tile: Tile<f32, S> = broadcast_scalar(start, tensor.shape());
let step_tile: Tile<f32, S> = broadcast_scalar(step, tensor.shape());
tensor.store(start_tile + float_indices * step_tile);
}
#[crate::entry()]
pub fn eye<const BR: i32, const BC: i32>(tensor: &mut Tensor<f32, { [BR, BC] }>) {
let pid: (i32, i32, i32) = get_tile_block_id();
let row_offset: i32 = pid.0 * BR;
let col_offset: i32 = pid.1 * BC;
let row_iota: Tile<i32, { [BR] }> = iota(const_shape![BR]);
let row_base: Tile<i32, { [BR] }> = broadcast_scalar(row_offset, const_shape![BR]);
let rows: Tile<i32, { [BR] }> = row_iota + row_base;
let rows_2d: Tile<i32, { [BR, 1] }> = rows.reshape(const_shape![BR, 1]);
let rows_bc: Tile<i32, { [BR, BC] }> = rows_2d.broadcast(const_shape![BR, BC]);
let col_iota: Tile<i32, { [BC] }> = iota(const_shape![BC]);
let col_base: Tile<i32, { [BC] }> = broadcast_scalar(col_offset, const_shape![BC]);
let cols: Tile<i32, { [BC] }> = col_iota + col_base;
let cols_2d: Tile<i32, { [1, BC] }> = cols.reshape(const_shape![1, BC]);
let cols_bc: Tile<i32, { [BR, BC] }> = cols_2d.broadcast(const_shape![BR, BC]);
let is_diag: Tile<bool, { [BR, BC] }> = eq_tile(rows_bc, cols_bc);
let ones: Tile<f32, { [BR, BC] }> = constant(1.0f32, const_shape![BR, BC]);
let zeros: Tile<f32, { [BR, BC] }> = constant(0.0f32, const_shape![BR, BC]);
tensor.store(select(is_diag, ones, zeros));
}
}