use super::{Device, DeviceId, DeviceInfo, DeviceProgramId, Event, Kernel, MemoryPool, PoolBufferId, PoolId};
use crate::{
error::{BackendError, ErrorStatus},
kernel::{Op, UOp},
shape::Dim,
slab::Slab,
};
use nanoserde::DeJson;
use std::{
fs,
fs::File,
io::{BufRead, BufReader, BufWriter, Write},
os::unix::io::AsRawFd,
path::PathBuf,
process::{Child, ChildStdin, ChildStdout, Command},
ptr,
sync::atomic::{AtomicU8, Ordering},
};
/// Magic byte that occupies bits 8..16 of every request code built by [`ioctl_code`].
const TENSTORRENT_IOCTL_MAGIC: u8 = 0xFA;

/// Builds an ioctl request code as `(TENSTORRENT_IOCTL_MAGIC << 8) | nr`.
///
/// No direction or size bits are encoded; the result is just magic plus command number.
const fn ioctl_code(nr: u32) -> u64 {
    ((TENSTORRENT_IOCTL_MAGIC as u64) << 8) | nr as u64
}

/// Request code (command number 0) used to fill a [`TTGetDeviceInfo`].
const TENSTORRENT_IOCTL_GET_DEVICE_INFO: u64 = ioctl_code(0);

/// Request code (command number 3) used to fill a [`TTAllocateDmaBuf`].
const TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF: u64 = ioctl_code(3);

/// Granularity, in bytes, to which DMA buffer requests are rounded up.
const PAGE_SIZE: u32 = 4096;
/// Argument block for `TENSTORRENT_IOCTL_GET_DEVICE_INFO`.
///
/// `#[repr(C)]` with one `u32` input followed by the driver-written output, 24 bytes in total.
/// NOTE(review): presumably mirrors the kernel driver's `tenstorrent_get_device_info`
/// (an `in` part followed by an `out` part) — confirm against the driver's `ioctl.h`.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
struct TTGetDeviceInfo {
    /// Input: size of the output area; the caller sets it to `size_of::<Self>() - 4`.
    output_size_bytes: u32,
    /// Output: presumably the number of bytes the driver actually wrote — TODO confirm.
    out_output_size_bytes: u32,
    /// Output: PCI vendor id (printed as `vendor=` in the debug log).
    vendor_id: u16,
    /// Output: PCI device id (printed as `device=` in the debug log).
    device_id: u16,
    /// Output: PCI subsystem vendor id (printed as `subven=` in the debug log).
    subsystem_vendor_id: u16,
    /// Output: PCI subsystem id; used as the lookup key into `DRAM_SIZE_TABLE`.
    subsystem_id: u16,
    /// Output: presumably the packed PCI bus/device/function — TODO confirm encoding.
    bus_dev_fn: u16,
    /// Output: log2 of the largest DMA buffer size reported by the driver.
    max_dma_buf_size_log2: u16,
    /// Output: presumably the PCI domain number — TODO confirm.
    pci_domain: u16,
    /// Trailing `u16` that rounds the struct up to 24 bytes.
    reserved: u16,
}
/// Input half of the `TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF` argument (24 bytes, `#[repr(C)]`).
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
struct TTAllocateDmaBufIn {
    /// Requested buffer size in bytes (the allocator passes a page-aligned value).
    requested_size: u32,
    /// Caller-chosen index identifying the buffer to the driver.
    buf_index: u8,
    /// Allocation flags; the allocator in this file always passes `1`.
    /// NOTE(review): presumably the driver's "NOC DMA" flag — confirm against the driver header.
    flags: u8,
    /// Reserved bytes, always zeroed by this file.
    reserved0: [u8; 2],
    /// Reserved words, always zeroed by this file.
    reserved1: [u64; 2],
}
/// Output half of the `TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF` argument (40 bytes, `#[repr(C)]`).
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
struct TTAllocateDmaBufOut {
    /// Address of the buffer as reported by the driver.
    /// NOTE(review): may be an IOVA rather than a true physical address — confirm.
    physical_address: u64,
    /// Offset to pass to `mmap` on the same fd in order to map the buffer.
    mapping_offset: u64,
    /// Size in bytes of the buffer that was actually allocated.
    size: u32,
    /// Reserved.
    reserved0: u32,
    /// Address handed to the runtime as the `src_noc` / `dst_noc` of a kernel run.
    noc_address: u64,
    /// Reserved.
    reserved1: u64,
}
/// Complete argument for `TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF`: the request immediately
/// followed by the area the driver fills in.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
struct TTAllocateDmaBuf {
    /// Request parameters (named `inn` because `in` is a Rust keyword).
    inn: TTAllocateDmaBufIn,
    /// Result written by the driver.
    out: TTAllocateDmaBufOut,
}
/// Issues `libc::ioctl(fd, request, arg)` with a typed pointer argument and returns the raw
/// result of the call.
///
/// # Safety
///
/// `fd` must be an open file descriptor and `arg` must point to a live, writable `T` whose
/// layout is what the driver expects for `request`.
unsafe fn ioctl_ptr<T>(fd: i32, request: u64, arg: *mut T) -> i32 {
    let raw_request = request as libc::c_ulong;
    let raw_arg = arg as *mut libc::c_void;
    // SAFETY: the caller upholds the contract documented above.
    unsafe { libc::ioctl(fd, raw_request, raw_arg) }
}
/// Known boards as `(PCI subsystem id, card name, DRAM size in bytes)`.
///
/// Sizes are written as `N << 30`, i.e. `N` GiB.
const DRAM_SIZE_TABLE: &[(u16, &str, u64)] = &[
    (0x0036, "p100", 28u64 << 30),
    (0x0040, "p150a", 32u64 << 30),
    (0x0041, "p150b", 32u64 << 30),
    (0x0042, "p150c", 32u64 << 30),
    (0x0043, "p100a", 28u64 << 30),
    (0x0044, "p300b", 64u64 << 30),
    (0x0045, "p300a", 64u64 << 30),
    (0x0046, "p300c", 64u64 << 30),
];
/// Looks up the DRAM size of the board with the given PCI subsystem id.
///
/// # Errors
///
/// Returns an `Initialization` error when the id is not listed in `DRAM_SIZE_TABLE`.
fn dram_size_for_subsystem_id(subsystem_id: u16) -> Result<Dim, BackendError> {
    DRAM_SIZE_TABLE
        .iter()
        .find(|entry| entry.0 == subsystem_id)
        .map(|entry| entry.2 as Dim)
        .ok_or_else(|| BackendError {
            status: ErrorStatus::Initialization,
            context: format!("unknown Tenstorrent board (subsystem_id=0x{subsystem_id:04x}, card_type=?), please report this to zyx with `lspci -nn | grep 1e52` output").into(),
        })
}
/// User-facing configuration of the Tenstorrent backend, deserialized from JSON.
#[derive(Default, Debug, DeJson)]
#[nserde(default)]
pub struct TTConfig {
    /// Device selection. An explicitly empty list disables the backend; `initialize_device`
    /// does not otherwise inspect the listed ids.
    device_ids: Option<Vec<i32>>,
}
/// One DMA buffer handed out by the memory pool, together with its host mapping.
struct TTBuffer {
    /// Device file the buffer was allocated and mapped through; stored so it stays open for as
    /// long as the buffer exists.
    file: File,
    /// Start of the host mapping of the buffer.
    mmap_ptr: *mut u8,
    /// Length of the mapping in bytes, as reported by the driver.
    size: u32,
    /// Address passed to the runtime when a kernel reads from or writes to this buffer.
    noc_address: u64,
    /// Index that was sent to the driver in the allocation request.
    buf_index: u8,
}
/// Field-by-field `Debug` output for [`TTBuffer`]; the mapping is shown as its raw pointer.
impl std::fmt::Debug for TTBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TTBuffer");
        out.field("file", &self.file);
        out.field("mmap_ptr", &self.mmap_ptr);
        out.field("size", &self.size);
        out.field("noc_address", &self.noc_address);
        out.field("buf_index", &self.buf_index);
        out.finish()
    }
}
/// Unmaps the host mapping when the buffer goes away; the result of `munmap` is ignored.
impl Drop for TTBuffer {
    fn drop(&mut self) {
        // A null pointer means there is no mapping to tear down.
        if self.mmap_ptr.is_null() {
            return;
        }
        let mapping = self.mmap_ptr as *mut libc::c_void;
        let length = self.size as usize;
        // SAFETY: relies on `mmap_ptr`/`size` describing a mapping that is still live, which is
        // how the allocator in this file constructs every `TTBuffer`.
        unsafe {
            libc::munmap(mapping, length);
        }
    }
}
// `TTBuffer` holds a raw `*mut u8`, so it is neither `Send` nor `Sync` automatically; these
// impls opt in by hand.
// NOTE(review): soundness depends on callers never racing on the mapped memory through
// `mmap_ptr` — confirm that all access is serialized by the owning pool.
unsafe impl Send for TTBuffer {}
unsafe impl Sync for TTBuffer {}
/// Memory pool that serves allocations as driver DMA buffers mapped into the host.
#[derive(Debug)]
pub struct TTMemoryPool {
    /// Handle opened during initialization; allocation is refused when this is `None`.
    device_file: Option<File>,
    /// DRAM size of the board, taken from `DRAM_SIZE_TABLE`.
    #[allow(unused)]
    total_bytes: Dim,
    /// Bytes not currently accounted to a live buffer.
    free_bytes: Dim,
    /// Source of the `buf_index` sent with each allocation request.
    next_buf_index: AtomicU8,
    /// Live buffers keyed by the id returned from `allocate`.
    buffers: Slab<PoolBufferId, TTBuffer>,
}
/// Event marker for the Tenstorrent backend.
///
/// It carries no state: every operation in this file that returns an event has already
/// completed by the time it returns.
#[derive(Debug, Clone)]
pub struct TTEvent;
/// Probes `/dev/tenstorrent/0` and, on success, registers one memory pool and one device.
///
/// Steps, in order:
/// 1. Return `Ok(())` without touching anything if `config.device_ids` is an explicitly empty list.
/// 2. Open the device node and query it with `TENSTORRENT_IOCTL_GET_DEVICE_INFO`.
/// 3. Resolve the board's DRAM size from its subsystem id.
/// 4. Locate the `zyx-tt-runtime` helper binary and the kernel cache directory under
///    `$XDG_CONFIG_HOME` (falling back to `~/.config`, then `/tmp`).
/// 5. Only after every fallible step has passed, push the pool and the device, so a failed
///    initialization never leaves a pool without a matching device behind.
///
/// The runtime process itself is not started here; `runtime` is left as `None`.
///
/// # Errors
///
/// Returns an `Initialization` error if the device node cannot be opened, the ioctl fails,
/// the board is unknown, or the runtime binary is missing.
pub(super) fn initialize_device(
    config: &TTConfig,
    memory_pools: &mut Slab<PoolId, MemoryPool>,
    devices: &mut Slab<DeviceId, Device>,
    debug_dev: bool,
) -> Result<(), BackendError> {
    if let Some(device_ids) = &config.device_ids
        && device_ids.is_empty()
    {
        if debug_dev {
            println!("[tenstorrent] configured out");
        }
        return Ok(());
    }

    let device_file = File::options()
        .read(true)
        .write(true)
        .open("/dev/tenstorrent/0")
        .map_err(|e| BackendError {
            status: ErrorStatus::Initialization,
            context: format!("open /dev/tenstorrent/0: {e}").into(),
        })?;
    let fd = device_file.as_raw_fd();

    // The output area is everything after the leading `output_size_bytes` input field.
    let out_size = size_of::<TTGetDeviceInfo>() as u32 - 4;
    let mut info = TTGetDeviceInfo { output_size_bytes: out_size, ..Default::default() };
    // SAFETY: `fd` belongs to `device_file`, which is alive, and `info` is a valid, writable
    // `#[repr(C)]` struct for the duration of the call.
    let ret = unsafe { ioctl_ptr(fd, TENSTORRENT_IOCTL_GET_DEVICE_INFO, &mut info) };
    if ret != 0 {
        // The return value alone is just -1; errno carries the actual reason.
        let os_err = std::io::Error::last_os_error();
        return Err(BackendError {
            status: ErrorStatus::Initialization,
            context: format!("TENSTORRENT_IOCTL_GET_DEVICE_INFO: {ret} ({os_err})").into(),
        });
    }

    let total_bytes = dram_size_for_subsystem_id(info.subsystem_id)?;
    if debug_dev {
        let card_name = DRAM_SIZE_TABLE
            .iter()
            .find(|&&(id, _, _)| id == info.subsystem_id)
            .map(|&(_, name, _)| name)
            .unwrap_or("?");
        println!(
            "[tenstorrent] vendor=0x{:04x} device=0x{:04x} subsys=0x{:04x} card={card_name} (subven=0x{:04x})",
            info.vendor_id, info.device_id, info.subsystem_id, info.subsystem_vendor_id
        );
        println!("[tenstorrent] total_dram={} MB", total_bytes / (1024 * 1024));
        println!("[tenstorrent] max_dma_buf_size_log2={}", info.max_dma_buf_size_log2);
    }

    // A relative XDG_CONFIG_HOME is ignored in favour of the fallbacks.
    let config_base = std::env::var_os("XDG_CONFIG_HOME")
        .and_then(|p| {
            let p = PathBuf::from(p);
            if p.is_absolute() { Some(p) } else { None }
        })
        .or_else(|| std::env::home_dir().map(|h| h.join(".config")))
        .unwrap_or_else(|| PathBuf::from("/tmp"));
    let cache_dir = config_base.join("zyx/cache/tt");
    let runtime_path = config_base.join("zyx/zyx-tt-runtime");
    // Checked before anything is registered: bailing out here must not leave a pool behind.
    if !runtime_path.exists() {
        return Err(BackendError {
            status: ErrorStatus::Initialization,
            context: format!(
                "runtime not found at {}. Rebuild with TT_METAL_ROOT set.",
                runtime_path.display()
            )
            .into(),
        });
    }
    let kernel_dir = PathBuf::from(env!("ZYX_TT_KERNEL_DIR"));

    let pool_id = memory_pools.len();
    memory_pools.push(MemoryPool::TT(TTMemoryPool {
        device_file: Some(device_file),
        total_bytes,
        free_bytes: total_bytes,
        next_buf_index: AtomicU8::new(0),
        buffers: Slab::new(),
    }));

    devices.push(Device::TT(TTDevice {
        device_info: DeviceInfo {
            compute: 200_000_000_000_000,
            max_global_work_dims: vec![Dim::from(u32::MAX); 3],
            max_local_threads: 1024,
            max_local_work_dims: vec![1, 1024, 1],
            preferred_vector_size: 16,
            local_mem_size: 1_500_000,
            max_register_bytes: 128,
            tensor_cores: true,
            warp_size: 1,
            supported_dtypes: u32::MAX,
            has_native_exp2: false,
        },
        memory_pool_id: pool_id,
        runtime: None,
        kernel_dir,
        cache_dir,
        runtime_path,
        programs: Slab::new(),
    }));
    Ok(())
}
impl TTMemoryPool {
pub fn deinitialize(&mut self) {
}
pub fn free_bytes(&self) -> Dim {
self.free_bytes
}
pub fn allocate(&mut self, bytes: Dim) -> Result<(PoolBufferId, Event), BackendError> {
let bytes32: u32 = u32::try_from(bytes).map_err(|_| BackendError {
status: ErrorStatus::MemoryAllocation,
context: "allocation size exceeds 4 GiB".into(),
})?;
if self.device_file.is_none() {
return Err(BackendError { status: ErrorStatus::MemoryAllocation, context: "device not opened".into() });
}
let page_aligned = bytes32.next_multiple_of(PAGE_SIZE);
if self.free_bytes < page_aligned as Dim {
return Err(BackendError { status: ErrorStatus::MemoryAllocation, context: "OOM on tenstorrent device".into() });
}
let buf_index = self.next_buf_index.fetch_add(1, Ordering::Relaxed);
let buf_file = File::options()
.read(true)
.write(true)
.open("/dev/tenstorrent/0")
.map_err(|e| BackendError {
status: ErrorStatus::MemoryAllocation,
context: format!("open /dev/tenstorrent/0 for buffer {buf_index}: {e}").into(),
})?;
let buf_fd = buf_file.as_raw_fd();
let mut alloc = TTAllocateDmaBuf {
inn: TTAllocateDmaBufIn {
requested_size: page_aligned,
buf_index,
flags: 1, reserved0: [0; 2],
reserved1: [0; 2],
},
out: TTAllocateDmaBufOut::default(),
};
unsafe {
let ret = ioctl_ptr(buf_fd, TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF, &mut alloc);
if ret != 0 {
return Err(BackendError {
status: ErrorStatus::MemoryAllocation,
context: format!("TENSTORRENT_IOCTL_ALLOCATE_DMA_BUF: {ret}").into(),
});
}
}
let actual_size = alloc.out.size;
let mmap_ptr = unsafe {
let ptr = libc::mmap(
ptr::null_mut(),
actual_size as usize,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
buf_fd,
alloc.out.mapping_offset as i64,
);
if ptr == libc::MAP_FAILED {
return Err(BackendError {
status: ErrorStatus::MemoryAllocation,
context: format!(
"mmap DMA buffer (size={actual_size}, offset=0x{:x})",
alloc.out.mapping_offset
)
.into(),
});
}
ptr as *mut u8
};
self.free_bytes -= actual_size as Dim;
let buf = TTBuffer { file: buf_file, mmap_ptr, size: actual_size, noc_address: alloc.out.noc_address, buf_index };
let id = self.buffers.push(buf);
Ok((id, Event::TT(TTEvent)))
}
pub fn deallocate(&mut self, buffer_id: PoolBufferId, event_wait_list: Vec<Event>) {
let _ = event_wait_list;
if self.buffers.contains_key(buffer_id) {
let buf = unsafe { self.buffers.remove_and_return(buffer_id) };
self.free_bytes += buf.size as Dim;
}
}
pub fn host_to_pool(&mut self, src: &[u8], dst: PoolBufferId, event_wait_list: Vec<Event>) -> Result<Event, BackendError> {
let _ = event_wait_list;
let buf = self
.buffers
.get_mut(dst)
.ok_or_else(|| BackendError { status: ErrorStatus::MemoryCopyH2P, context: "invalid buffer id".into() })?;
let len = src.len().min(buf.size as usize);
unsafe {
ptr::copy_nonoverlapping(src.as_ptr(), buf.mmap_ptr, len);
}
Ok(Event::TT(TTEvent))
}
pub fn pool_to_host(&mut self, src: PoolBufferId, dst: &mut [u8], event_wait_list: Vec<Event>) -> Result<(), BackendError> {
let _ = event_wait_list;
let buf = self
.buffers
.get_mut(src)
.ok_or_else(|| BackendError { status: ErrorStatus::MemoryCopyP2H, context: "invalid buffer id".into() })?;
let len = dst.len().min(buf.size as usize);
unsafe {
ptr::copy_nonoverlapping(buf.mmap_ptr, dst.as_mut_ptr(), len);
}
Ok(())
}
pub fn sync_events(&mut self, events: Vec<Event>) -> Result<(), BackendError> {
let _ = self;
let _ = events;
Ok(())
}
pub fn release_events(&mut self, events: Vec<Event>) {
let _ = self;
let _ = events;
}
pub fn noc_address(&self, buffer_id: PoolBufferId) -> Result<u64, BackendError> {
if self.buffers.contains_key(buffer_id) {
Ok(self.buffers[buffer_id].noc_address)
} else {
Err(BackendError { status: ErrorStatus::MemoryAllocation, context: "invalid buffer id".into() })
}
}
pub fn buffer_size(&self, buffer_id: PoolBufferId) -> Result<u64, BackendError> {
if self.buffers.contains_key(buffer_id) {
Ok(self.buffers[buffer_id].size as u64)
} else {
Err(BackendError { status: ErrorStatus::MemoryAllocation, context: "invalid buffer id".into() })
}
}
}
/// Handle to the helper process that is driven with line-delimited JSON over its
/// stdin/stdout pipes.
#[derive(Debug)]
struct RuntimeProcess {
    /// Buffered writer for commands sent to the helper.
    stdin: BufWriter<ChildStdin>,
    /// Buffered reader for the helper's replies.
    stdout: BufReader<ChildStdout>,
    /// The spawned helper itself, kept for liveness checks and reaping.
    child: Child,
    /// Time budget, in milliseconds, for waiting on a single reply.
    timeout_ms: u64,
}
impl RuntimeProcess {
    /// Spawns the helper at `runtime_path` and performs the `init` handshake.
    ///
    /// The helper's stdin and stdout are piped; its stderr is inherited. The reply timeout
    /// starts at 30 seconds. If the handshake fails the helper is killed and reaped so that a
    /// failed start does not leave a stray process behind.
    ///
    /// # Errors
    ///
    /// Returns an `Initialization` error if the process cannot be spawned, a pipe is missing,
    /// or the helper answers `init` with an error. I/O failures during the handshake are
    /// returned as produced by `send` / `recv_with_timeout`.
    fn new(runtime_path: &str, kernel_dir: &str, cache_dir: &str) -> Result<Self, BackendError> {
        let mut child = Command::new(runtime_path)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::inherit())
            .spawn()
            .map_err(|e| BackendError {
                status: ErrorStatus::Initialization,
                context: format!("spawn tt-runtime {runtime_path}: {e}").into(),
            })?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| BackendError { status: ErrorStatus::Initialization, context: "tt-runtime: no stdin".into() })?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| BackendError { status: ErrorStatus::Initialization, context: "tt-runtime: no stdout".into() })?;
        let mut rt = RuntimeProcess {
            stdin: BufWriter::new(stdin),
            stdout: BufReader::new(stdout),
            child,
            timeout_ms: 30000,
        };
        if let Err(e) = rt.handshake(kernel_dir, cache_dir) {
            // Dropping a `Child` neither kills nor reaps it, so do both explicitly.
            let _ = rt.child.kill();
            let _ = rt.child.wait();
            return Err(e);
        }
        Ok(rt)
    }

    /// Sends the `init` command and checks the helper's reply.
    fn handshake(&mut self, kernel_dir: &str, cache_dir: &str) -> Result<(), BackendError> {
        let init_json = format!(
            r#"{{"cmd":"init","kernel_dir":"{}","cache_dir":"{}"}}"#,
            Self::json_escape(kernel_dir),
            Self::json_escape(cache_dir),
        );
        self.send(&init_json)?;
        let resp = self.recv_with_timeout(self.timeout_ms)?;
        if resp.contains("\"error\"") {
            let msg = extract_json_str(&resp, "msg").unwrap_or_else(|| "unknown".into());
            return Err(BackendError {
                status: ErrorStatus::Initialization,
                context: format!("tt-runtime init error: {msg}").into(),
            });
        }
        Ok(())
    }

    /// Escapes `raw` for use inside a JSON string literal, so that paths containing quotes,
    /// backslashes or control characters cannot break the command line sent to the helper.
    fn json_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for c in raw.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }

    /// Writes `json` plus a newline to the helper's stdin and flushes.
    ///
    /// # Errors
    ///
    /// Returns a `KernelLaunch` error if writing or flushing fails.
    fn send(&mut self, json: &str) -> Result<(), BackendError> {
        self.stdin
            .write_all(json.as_bytes())
            .map_err(|e| BackendError { status: ErrorStatus::KernelLaunch, context: format!("tt-runtime write: {e}").into() })?;
        self.stdin.write_all(b"\n").map_err(|e| BackendError {
            status: ErrorStatus::KernelLaunch,
            context: format!("tt-runtime write nl: {e}").into(),
        })?;
        self.stdin
            .flush()
            .map_err(|e| BackendError { status: ErrorStatus::KernelLaunch, context: format!("tt-runtime flush: {e}").into() })?;
        Ok(())
    }

    /// Waits up to `timeout_ms` for the helper's stdout to become readable.
    ///
    /// Returns `Ok(true)` when a read will not block: bytes are already buffered, the pipe
    /// has data, or the pipe was closed (so the read observes EOF). Returns `Ok(false)` when
    /// the wait elapsed with nothing to read and the helper is still running.
    ///
    /// # Errors
    ///
    /// Returns a `KernelLaunch` error if `poll` fails, or if nothing is readable and the
    /// helper has exited or its status cannot be queried.
    fn poll_read(&mut self, timeout_ms: u64) -> Result<bool, BackendError> {
        // `poll` only sees the pipe; bytes already pulled into the `BufReader` would be missed.
        if !self.stdout.buffer().is_empty() {
            return Ok(true);
        }
        let fd = self.stdout.get_ref().as_raw_fd();
        let mut pollfd = libc::pollfd { fd, events: libc::POLLIN, revents: 0 };
        let timeout = i32::try_from(timeout_ms).unwrap_or(i32::MAX);
        // SAFETY: `pollfd` is one valid, initialized entry and the count passed is 1.
        let ret = unsafe { libc::poll(&mut pollfd, 1, timeout) };
        if ret < 0 {
            let err = std::io::Error::last_os_error();
            return Err(BackendError { status: ErrorStatus::KernelLaunch, context: format!("poll error: {err}").into() });
        }
        // POLLHUP/POLLERR are reported as readable so the caller's read sees EOF or the error,
        // rather than spinning until it reports a misleading timeout.
        let ready_mask = libc::POLLIN | libc::POLLHUP | libc::POLLERR;
        if ret > 0 && pollfd.revents & ready_mask != 0 {
            return Ok(true);
        }
        // Liveness is checked only after looking for data, so a reply written just before the
        // helper exited is still delivered.
        match self.child.try_wait() {
            Ok(Some(status)) => Err(BackendError {
                status: ErrorStatus::KernelLaunch,
                context: format!("tt-runtime exited unexpectedly (status {status})").into(),
            }),
            Err(e) => Err(BackendError {
                status: ErrorStatus::KernelLaunch,
                context: format!("tt-runtime wait error: {e}").into(),
            }),
            Ok(None) => Ok(false),
        }
    }

    /// Reads one reply line from the helper, trimmed of surrounding whitespace.
    ///
    /// The budget `timeout_ms` is split into three polling attempts. A read error is retried
    /// until the last attempt.
    ///
    /// # Errors
    ///
    /// Returns a `KernelLaunch` error on EOF, on a read error in the final attempt, when the
    /// helper exits without replying, or when all attempts elapse.
    fn recv_with_timeout(&mut self, timeout_ms: u64) -> Result<String, BackendError> {
        let max_attempts = 3;
        let poll_timeout = timeout_ms / max_attempts;
        let mut attempts = 0;
        while attempts < max_attempts {
            if self.poll_read(poll_timeout)? {
                let mut line = String::new();
                match self.stdout.read_line(&mut line) {
                    Ok(0) => {
                        // EOF: name the exit status when the helper is already gone.
                        let context = match self.child.try_wait() {
                            Ok(Some(status)) => format!("tt-runtime exited unexpectedly (status {status})"),
                            _ => String::from("tt-runtime closed stdout"),
                        };
                        return Err(BackendError { status: ErrorStatus::KernelLaunch, context: context.into() });
                    }
                    Ok(_) => return Ok(line.trim().to_string()),
                    Err(e) if attempts == max_attempts - 1 => {
                        return Err(BackendError {
                            status: ErrorStatus::KernelLaunch,
                            context: format!("tt-runtime read error: {e}").into(),
                        });
                    }
                    Err(_) => {}
                }
            }
            attempts += 1;
        }
        Err(BackendError {
            status: ErrorStatus::KernelLaunch,
            context: format!("tt-runtime read timeout after {}ms", timeout_ms).into(),
        })
    }

    /// Asks the helper to run the kernel identified by `hash` over `n_tiles` tiles, reading
    /// from `src_noc` and writing to `dst_noc`, and waits for its reply.
    ///
    /// # Errors
    ///
    /// Returns a `KernelLaunch` error if the exchange fails or the helper reports an error.
    fn run(&mut self, hash: &str, n_tiles: u32, src_noc: u64, dst_noc: u64) -> Result<(), BackendError> {
        let hash = Self::json_escape(hash);
        let cmd = format!(r#"{{"cmd":"run","hash":"{hash}","n_tiles":{n_tiles},"src_noc":{src_noc},"dst_noc":{dst_noc}}}"#);
        self.send(&cmd)?;
        let resp = self.recv_with_timeout(self.timeout_ms)?;
        if resp.contains("\"error\"") {
            let msg = extract_json_str(&resp, "msg").unwrap_or_else(|| "unknown".into());
            return Err(BackendError {
                status: ErrorStatus::KernelLaunch,
                context: format!("tt-runtime run error: {msg}").into(),
            });
        }
        Ok(())
    }

    /// Sends `exit`, waits for the acknowledgement and then reaps the helper.
    ///
    /// # Errors
    ///
    /// Returns a `KernelLaunch` error if the exchange fails or the helper reports an error;
    /// in that case the process is not reaped here.
    fn exit(&mut self) -> Result<(), BackendError> {
        self.send(r#"{"cmd":"exit"}"#)?;
        let resp = self.recv_with_timeout(self.timeout_ms)?;
        if resp.contains("\"error\"") {
            let msg = extract_json_str(&resp, "msg").unwrap_or_else(|| "unknown".into());
            return Err(BackendError {
                status: ErrorStatus::KernelLaunch,
                context: format!("tt-runtime exit error: {msg}").into(),
            });
        }
        self.child.wait().ok();
        Ok(())
    }

    /// Sets the reply timeout, in milliseconds, used by subsequent commands.
    fn set_timeout(&mut self, timeout_ms: u64) {
        self.timeout_ms = timeout_ms;
    }
}
/// Extracts the string value stored under `key` from a JSON object given as text.
///
/// This is a deliberately small scanner, not a JSON parser: it looks for `"key"` followed by
/// optional whitespace, a colon, optional whitespace and a string literal, and returns that
/// literal with its escape sequences decoded. Occurrences of `"key"` that are not followed
/// by a colon (for example the same text appearing as a value) are skipped.
///
/// Returns `None` when the key is absent, its value is not a string, or the string is
/// unterminated or contains a malformed `\u` escape. It never panics, whatever the input.
fn extract_json_str(json: &str, key: &str) -> Option<String> {
    let needle = format!("\"{key}\"");
    let mut search_from = 0;
    while let Some(pos) = json[search_from..].find(&needle) {
        // `needle` ends in an ASCII quote, so this index is always a char boundary.
        search_from += pos + needle.len();
        let rest = json[search_from..].trim_start();
        let Some(rest) = rest.strip_prefix(':') else { continue };
        let Some(rest) = rest.trim_start().strip_prefix('"') else { continue };

        let mut value = String::new();
        let mut chars = rest.chars();
        while let Some(c) = chars.next() {
            match c {
                '"' => return Some(value),
                '\\' => match chars.next()? {
                    'n' => value.push('\n'),
                    't' => value.push('\t'),
                    'r' => value.push('\r'),
                    'b' => value.push('\u{0008}'),
                    'f' => value.push('\u{000C}'),
                    'u' => {
                        let hex: String = chars.by_ref().take(4).collect();
                        let code = u32::from_str_radix(&hex, 16).ok()?;
                        // Lone surrogates have no `char`; substitute the replacement character.
                        value.push(char::from_u32(code).unwrap_or('\u{FFFD}'));
                    }
                    // Covers `\"`, `\\` and `\/`: the escaped character stands for itself.
                    other => value.push(other),
                },
                other => value.push(other),
            }
        }
        // The input ended before the closing quote.
        return None;
    }
    None
}
/// A compiled program, identified to the runtime by the hash of its kernel.
#[derive(Debug)]
struct TTProgram {
    /// Kernel hash as 16 lower-case hex digits; also the stem of the cached `.cpp` file.
    hash: String,
}
/// A Tenstorrent device together with the helper process that runs kernels on it.
#[derive(Debug)]
pub struct TTDevice {
    /// Capabilities advertised for this device.
    device_info: DeviceInfo,
    /// Id of the memory pool that serves this device.
    memory_pool_id: PoolId,
    /// Helper process; `None` until the first `compile` starts it.
    runtime: Option<RuntimeProcess>,
    /// Directory of kernel sources passed to the helper at start-up.
    kernel_dir: PathBuf,
    /// Directory where generated `<hash>.cpp` compute kernels are written.
    cache_dir: PathBuf,
    /// Location of the helper binary.
    runtime_path: PathBuf,
    /// Programs returned by `compile` that have not been released yet.
    programs: Slab<DeviceProgramId, TTProgram>,
}
impl TTDevice {
    /// Shuts the helper process down, if one was started.
    ///
    /// The helper gets 10 seconds to acknowledge `exit`. If the handshake fails it is killed
    /// and reaped instead, so no stray process is left behind.
    pub fn deinitialize(&mut self) {
        if let Some(mut rt) = self.runtime.take() {
            rt.set_timeout(10000);
            if rt.exit().is_err() {
                let _ = rt.child.kill();
                let _ = rt.child.wait();
            }
        }
    }

    /// Capabilities advertised for this device.
    pub const fn info(&self) -> &DeviceInfo {
        &self.device_info
    }

    /// Id of the memory pool that serves this device.
    pub const fn memory_pool_id(&self) -> PoolId {
        self.memory_pool_id
    }

    /// Currently just the advertised compute figure; load is not tracked.
    pub const fn free_compute(&self) -> u128 {
        self.device_info.compute
    }

    /// Registers `kernel` as a program, starting the helper process on first use.
    ///
    /// The compute kernel source is generated and written to `<cache_dir>/<hash>.cpp` unless a
    /// file of that name already exists. `debug_asm` enables progress messages on stderr.
    ///
    /// # Errors
    ///
    /// Propagates a failure to start the helper, and returns a `KernelCompilation` error if
    /// the source cannot be generated or written.
    pub fn compile(&mut self, kernel: &Kernel, debug_asm: bool) -> Result<DeviceProgramId, BackendError> {
        let hash = format!("{:016x}", kernel.get_hash());
        if self.runtime.is_none() {
            let kernel_dir = self.kernel_dir.to_string_lossy().into_owned();
            let cache_dir = self.cache_dir.to_string_lossy().into_owned();
            let runtime_path = self.runtime_path.to_string_lossy().into_owned();
            match RuntimeProcess::new(&runtime_path, &kernel_dir, &cache_dir) {
                Ok(rt) => self.runtime = Some(rt),
                Err(e) => {
                    if debug_asm {
                        eprintln!("[tenstorrent] runtime: {e}");
                    }
                    return Err(e);
                }
            }
        }
        let compute_path = self.cache_dir.join(format!("{hash}.cpp"));
        if !compute_path.exists() {
            if debug_asm {
                eprintln!("[tenstorrent] generating {hash}.cpp");
            }
            let source = generate_compute_kernel(kernel)?;
            fs::create_dir_all(&self.cache_dir).map_err(|e| BackendError {
                status: ErrorStatus::KernelCompilation,
                context: format!("create cache dir: {e}").into(),
            })?;
            fs::write(&compute_path, &source).map_err(|e| BackendError {
                status: ErrorStatus::KernelCompilation,
                context: format!("write {hash}.cpp: {e}").into(),
            })?;
        } else if debug_asm {
            eprintln!("[tenstorrent] using cached {hash}.cpp");
        }
        let prog_id = self.programs.push(TTProgram { hash });
        Ok(prog_id)
    }

    /// Forgets `program_id`; unknown ids are ignored. The cached source file is kept.
    pub fn release(&mut self, program_id: DeviceProgramId) {
        if self.programs.contains_key(program_id) {
            // SAFETY: presence of the key was checked on the line above.
            unsafe { self.programs.remove_and_return(program_id) };
        }
    }

    /// Runs `program_id` and blocks until the helper reports completion.
    ///
    /// The source buffer is `args[0]` and the destination is `args[args.len() / 2]`. The tile
    /// count is the source buffer size divided by 2048 bytes, rounded up. The helper gets a
    /// 60-second reply timeout for the run; afterwards the timeout goes back to 30 seconds
    /// whether or not the run succeeded. `event_wait_list` is unused and the returned event is
    /// already complete.
    ///
    /// # Errors
    ///
    /// Returns a `KernelLaunch` error for an unknown program, a helper that was never
    /// started, fewer than two buffers, an unknown buffer id, an empty source buffer, or a
    /// failed run.
    pub fn launch(
        &mut self,
        program_id: DeviceProgramId,
        memory_pool: &mut TTMemoryPool,
        args: &[PoolBufferId],
        event_wait_list: Vec<Event>,
    ) -> Result<Event, BackendError> {
        let _ = event_wait_list;
        let prog = if self.programs.contains_key(program_id) {
            &self.programs[program_id]
        } else {
            return Err(BackendError { status: ErrorStatus::KernelLaunch, context: "invalid program id".into() });
        };
        let rt = self
            .runtime
            .as_mut()
            .ok_or_else(|| BackendError { status: ErrorStatus::KernelLaunch, context: "runtime not initialized".into() })?;
        if args.len() < 2 {
            return Err(BackendError {
                status: ErrorStatus::KernelLaunch,
                context: format!("expected at least 2 buffers, got {}", args.len()).into(),
            });
        }
        let n_inputs = args.len() / 2;
        let src_buf = args[0];
        let dst_buf = args[n_inputs];
        let src_noc = memory_pool
            .noc_address(src_buf)
            .map_err(|e| BackendError { status: ErrorStatus::KernelLaunch, context: format!("src noc address: {e}").into() })?;
        let dst_noc = memory_pool
            .noc_address(dst_buf)
            .map_err(|e| BackendError { status: ErrorStatus::KernelLaunch, context: format!("dst noc address: {e}").into() })?;
        let src_bytes = memory_pool
            .buffer_size(src_buf)
            .map_err(|e| BackendError { status: ErrorStatus::KernelLaunch, context: format!("src buffer size: {e}").into() })?;
        let tile_bytes: u64 = 2048;
        // Buffer sizes originate from a `u32`, so the tile count always fits in `u32`.
        let n_tiles = src_bytes.div_ceil(tile_bytes) as u32;
        if n_tiles == 0 {
            return Err(BackendError { status: ErrorStatus::KernelLaunch, context: "empty buffer".into() });
        }
        let kernel_timeout_ms = 60000;
        rt.set_timeout(kernel_timeout_ms);
        let outcome = rt.run(&prog.hash, n_tiles, src_noc, dst_noc);
        // Restore the default before looking at the outcome, so a failed run does not leave
        // the longer timeout in effect for later commands.
        rt.set_timeout(30000);
        outcome.map_err(|e| BackendError { status: ErrorStatus::KernelLaunch, context: format!("runtime run: {e}").into() })?;
        Ok(Event::TT(TTEvent))
    }
}
/// Names needed to emit a compute kernel for one unary operation.
struct SfpuInfo {
    /// Header to `#include` in the generated source.
    header: &'static str,
    /// Function called once before the tile loop.
    init_fn: &'static str,
    /// Function called for every tile inside the loop.
    tile_fn: &'static str,
}
/// Maps a unary IR operation to the header, init function and per-tile function used in
/// the generated compute kernel.
///
/// # Errors
///
/// Returns a `KernelCompilation` error for any operation without an entry below.
fn uop_to_sfpu(uop: UOp) -> Result<SfpuInfo, BackendError> {
    let (header, init_fn, tile_fn) = match uop {
        UOp::Exp => ("api/compute/eltwise_unary/exp.h", "exp_tile_init", "exp_tile"),
        UOp::Reciprocal => ("api/compute/eltwise_unary/recip.h", "recip_tile_init", "recip_tile"),
        UOp::Sqrt => ("api/compute/eltwise_unary/sqrt.h", "sqrt_tile_init", "sqrt_tile"),
        UOp::Sin => ("api/compute/eltwise_unary/trigonometry.h", "sin_tile_init", "sin_tile"),
        UOp::Cos => ("api/compute/eltwise_unary/trigonometry.h", "cos_tile_init", "cos_tile"),
        UOp::Neg => ("api/compute/eltwise_unary/negative.h", "negative_tile_init", "negative_tile"),
        UOp::Floor => ("api/compute/eltwise_unary/rounding.h", "floor_tile_init", "floor_tile"),
        UOp::Trunc => ("api/compute/eltwise_unary/rounding.h", "trunc_tile_init", "trunc_tile"),
        _ => {
            return Err(BackendError {
                status: ErrorStatus::KernelCompilation,
                context: format!("unsupported unary op {uop:?} for Tenstorrent (add an IR optimization pass)").into(),
            });
        }
    };
    Ok(SfpuInfo { header, init_fn, tile_fn })
}
fn generate_compute_kernel(kernel: &Kernel) -> Result<String, BackendError> {
let mut uop = None;
let mut op_id = kernel.head;
while !op_id.is_null() {
match kernel.at(op_id) {
Op::Unary { uop: op, .. } => {
uop = Some(*op);
break;
}
_ => {}
}
op_id = kernel.next_op(op_id);
}
let sfpu = uop_to_sfpu(uop.ok_or_else(|| BackendError {
status: ErrorStatus::KernelCompilation,
context: "no unary op found in kernel".into(),
})?)?;
Ok(format!(
r####"#include <cstdint>
#include "api/compute/cb_api.h"
#include "api/compute/tile_move_copy.h"
#include "api/compute/eltwise_unary/eltwise_unary.h"
#include "{header}"
void kernel_main() {{
uint32_t n_tiles = get_arg_val<uint32_t>(0);
unary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_16);
{init_fn}();
for (uint32_t i = 0; i < n_tiles; i++) {{
tile_regs_acquire();
cb_wait_front(tt::CBIndex::c_0, 1);
copy_tile(tt::CBIndex::c_0, 0, 0);
{tile_fn}(0);
cb_pop_front(tt::CBIndex::c_0, 1);
tile_regs_commit();
tile_regs_wait();
cb_reserve_back(tt::CBIndex::c_16, 1);
pack_tile(0, tt::CBIndex::c_16);
cb_push_back(tt::CBIndex::c_16, 1);
tile_regs_release();
}}
}}
"####,
header = sfpu.header,
init_fn = sfpu.init_fn,
tile_fn = sfpu.tile_fn,
))
}