use crate::PhysAddr;
use crate::registers::{TDMA_MASK_INIT, TDMA_INT_EOD, TDMA_INT_EOPMU, TDMA_ALL_IDLE};
/// Magic value `0xB5B5`.
///
/// NOTE(review): the name suggests this tags TPU DMA-buffer headers — confirm against the code
/// that reads it.
pub const TPU_DMABUF_HEADER_M: u16 = 0xB5B5;

/// Timeout value of 60 000 000.
///
/// NOTE(review): the `_US` suffix suggests microseconds (i.e. 60 seconds) — confirm at the call
/// sites.
pub const TIMEOUT_US: u64 = 60 * 1_000_000;
/// Platform configuration for the TPU.
///
/// `#[repr(C)]` fixes the field order and layout, so the fields below must not be reordered.
/// NOTE(review): the `iomem_*` fields look like MMIO base addresses and the `pmubuf_*` fields
/// like a PMU buffer descriptor — confirm against whoever fills this struct in.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct TpuPlatformCfg {
    /// Base address value for the TDMA block (presumably a mapped I/O base).
    pub iomem_tdma_base: usize,
    /// Base address value for the TIU block (presumably a mapped I/O base).
    pub iomem_tiu_base: usize,
    /// Size of the PMU buffer; units are not visible here (presumably bytes).
    pub pmubuf_size: u32,
    /// Physical address of the PMU buffer.
    pub pmubuf_addr_p: PhysAddr,
}
/// Pair of command identifiers: one for the BD (TIU) queue and one for the TDMA queue.
///
/// `#[repr(C)]` layout: two consecutive `u32`s, in the order declared.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct CmdIdNode {
    /// Command id on the BD side.
    pub bd_cmd_id: u32,
    /// Command id on the TDMA side.
    pub tdma_cmd_id: u32,
}
/// Snapshot of TDMA/TIU register values.
///
/// NOTE(review): the name suggests these are saved before and restored after a power or reset
/// transition — confirm at the save/restore sites.
#[derive(Clone, Copy, Debug, Default)]
pub struct TpuRegBackup {
    pub tdma_int_mask: u32,
    pub tdma_sync_status: u32,
    pub tiu_ctrl: u32,
    /// Ten array-base register values.
    pub tdma_arraybase: [u32; 10],
    pub tdma_des_base: u32,
    pub tdma_dbg_mode: u32,
    pub tdma_dcm_disable: u32,
    pub tdma_ctrl: u32,
}
/// Unpacked image of a 16-word TDMA command.
///
/// Each field holds one bit-field. [`TdmaReg::emit`] masks every field to its width and packs
/// the fields into a `[u32; 16]`. Values wider than a field are silently truncated. The
/// trailing comments give `word[bit range]` as packed by `emit`. Field order is left unchanged
/// so the `Debug` output stays the same.
#[derive(Clone, Copy, Debug)]
pub struct TdmaReg {
    pub vld: u32,                 // w0[0]
    pub compress_en: u32,         // w0[1]
    pub eod: u32,                 // w0[2]
    pub intp_en: u32,             // w0[3]
    pub bar_en: u32,              // w0[4]
    pub check_bf16_value: u32,    // w0[5]
    pub trans_dir: u32,           // w0[7:6]
    pub rsv00: u32,               // w0[9:8]
    pub trans_fmt: u32,           // w0[10]
    pub transpose_md: u32,        // w0[12:11]
    pub rsv01: u32,               // w0[13]
    pub intra_cmd_paral: u32,     // w0[14]
    pub outstanding_en: u32,      // w0[15]
    pub cmd_id: u32,              // w0[31:16]
    pub spec_func: u32,           // w1[2:0]
    pub dst_fmt: u32,             // w1[4:3]
    pub src_fmt: u32,             // w1[6:5]
    pub cmprs_fmt: u32,           // w1[7]
    pub sys_dtype: u32,           // w1[8]
    pub rsv2_1: u32,              // w1[12:9]
    pub int8_sign: u32,           // w1[13]
    pub compress_zero_guard: u32, // w1[14]
    pub int8_rnd_mode: u32,       // w1[15]
    pub wait_id_tpu: u32,         // w1[31:16]
    pub wait_id_other_tdma: u32,  // w2[15:0]
    pub wait_id_sdma: u32,        // w2[31:16]
    pub const_val: u32,           // w3[15:0]
    pub src_base_reg_sel: u32,    // w3[18:16]
    pub mv_lut_idx: u32,          // w3[19]
    pub dst_base_reg_sel: u32,    // w3[22:20]
    pub mv_lut_base: u32,         // w3[23]
    pub rsv4_5: u32,              // w3[31:24]
    pub dst_h_stride: u32,        // w4[15:0]
    pub dst_c_stride_low: u32,    // w4[31:16]
    pub dst_n_stride: u32,        // w5[31:0]
    pub src_h_stride: u32,        // w6[15:0]
    pub src_c_stride_low: u32,    // w6[31:16]
    pub src_n_stride: u32,        // w7[31:0]
    pub dst_c: u32,               // w8[15:0]
    pub src_c: u32,               // w8[31:16]
    pub dst_w: u32,               // w9[15:0]
    pub dst_h: u32,               // w9[31:16]
    pub src_w: u32,               // w10[15:0]
    pub src_h: u32,               // w10[31:16]
    pub dst_base_addr_low: u32,   // w11[31:0]
    pub src_base_addr_low: u32,   // w12[31:0]
    pub src_n: u32,               // w13[15:0]
    pub dst_base_addr_high: u32,  // w13[23:16]
    pub src_base_addr_high: u32,  // w13[31:24]
    pub src_c_stride_high: u32,   // w14[15:0]
    pub dst_c_stride_high: u32,   // w14[31:16]
    pub compress_bias0: u32,      // w15[7:0]
    pub compress_bias1: u32,      // w15[15:8]
    pub layer_id: u32,            // w15[31:16]
}

impl Default for TdmaReg {
    /// Reset state: both formats, every shape dimension and every stride are 1; all other
    /// fields are 0.
    fn default() -> Self {
        Self {
            // Formats start at 1.
            dst_fmt: 1,
            src_fmt: 1,
            // Shape and stride fields start at 1.
            src_n: 1,
            src_c: 1,
            src_h: 1,
            src_w: 1,
            dst_c: 1,
            dst_h: 1,
            dst_w: 1,
            src_n_stride: 1,
            src_c_stride_low: 1,
            src_h_stride: 1,
            dst_n_stride: 1,
            dst_c_stride_low: 1,
            dst_h_stride: 1,
            // Everything else is zero.
            vld: 0,
            compress_en: 0,
            eod: 0,
            intp_en: 0,
            bar_en: 0,
            check_bf16_value: 0,
            trans_dir: 0,
            rsv00: 0,
            trans_fmt: 0,
            transpose_md: 0,
            rsv01: 0,
            intra_cmd_paral: 0,
            outstanding_en: 0,
            cmd_id: 0,
            spec_func: 0,
            cmprs_fmt: 0,
            sys_dtype: 0,
            rsv2_1: 0,
            int8_sign: 0,
            compress_zero_guard: 0,
            int8_rnd_mode: 0,
            wait_id_tpu: 0,
            wait_id_other_tdma: 0,
            wait_id_sdma: 0,
            const_val: 0,
            src_base_reg_sel: 0,
            mv_lut_idx: 0,
            dst_base_reg_sel: 0,
            mv_lut_base: 0,
            rsv4_5: 0,
            dst_base_addr_low: 0,
            src_base_addr_low: 0,
            dst_base_addr_high: 0,
            src_base_addr_high: 0,
            src_c_stride_high: 0,
            dst_c_stride_high: 0,
            compress_bias0: 0,
            compress_bias1: 0,
            layer_id: 0,
        }
    }
}

impl TdmaReg {
    /// Restores every field to the [`Default`] reset state.
    pub fn reset(&mut self) {
        *self = Default::default();
    }

    /// Packs the fields into 16 words, with `word[0]` first.
    ///
    /// Each field is masked to its width before it is shifted into place, so out-of-range
    /// values are truncated rather than spilling into neighbouring fields. Full-width (32-bit)
    /// fields are copied unchanged.
    pub fn emit(&self) -> [u32; 16] {
        // Keeps the low `width` bits of `value` and moves them up by `shift`.
        // Only called with `width < 32`, so the mask computation cannot overflow.
        const fn put(value: u32, width: u32, shift: u32) -> u32 {
            (value & ((1u32 << width) - 1)) << shift
        }

        let r = self;
        [
            // word 0
            put(r.vld, 1, 0)
                | put(r.compress_en, 1, 1)
                | put(r.eod, 1, 2)
                | put(r.intp_en, 1, 3)
                | put(r.bar_en, 1, 4)
                | put(r.check_bf16_value, 1, 5)
                | put(r.trans_dir, 2, 6)
                | put(r.rsv00, 2, 8)
                | put(r.trans_fmt, 1, 10)
                | put(r.transpose_md, 2, 11)
                | put(r.rsv01, 1, 13)
                | put(r.intra_cmd_paral, 1, 14)
                | put(r.outstanding_en, 1, 15)
                | put(r.cmd_id, 16, 16),
            // word 1
            put(r.spec_func, 3, 0)
                | put(r.dst_fmt, 2, 3)
                | put(r.src_fmt, 2, 5)
                | put(r.cmprs_fmt, 1, 7)
                | put(r.sys_dtype, 1, 8)
                | put(r.rsv2_1, 4, 9)
                | put(r.int8_sign, 1, 13)
                | put(r.compress_zero_guard, 1, 14)
                | put(r.int8_rnd_mode, 1, 15)
                | put(r.wait_id_tpu, 16, 16),
            // word 2
            put(r.wait_id_other_tdma, 16, 0) | put(r.wait_id_sdma, 16, 16),
            // word 3
            put(r.const_val, 16, 0)
                | put(r.src_base_reg_sel, 3, 16)
                | put(r.mv_lut_idx, 1, 19)
                | put(r.dst_base_reg_sel, 3, 20)
                | put(r.mv_lut_base, 1, 23)
                | put(r.rsv4_5, 8, 24),
            // word 4
            put(r.dst_h_stride, 16, 0) | put(r.dst_c_stride_low, 16, 16),
            // word 5
            r.dst_n_stride,
            // word 6
            put(r.src_h_stride, 16, 0) | put(r.src_c_stride_low, 16, 16),
            // word 7
            r.src_n_stride,
            // word 8
            put(r.dst_c, 16, 0) | put(r.src_c, 16, 16),
            // word 9
            put(r.dst_w, 16, 0) | put(r.dst_h, 16, 16),
            // word 10
            put(r.src_w, 16, 0) | put(r.src_h, 16, 16),
            // word 11
            r.dst_base_addr_low,
            // word 12
            r.src_base_addr_low,
            // word 13
            put(r.src_n, 16, 0)
                | put(r.dst_base_addr_high, 8, 16)
                | put(r.src_base_addr_high, 8, 24),
            // word 14
            put(r.src_c_stride_high, 16, 0) | put(r.dst_c_stride_high, 16, 16),
            // word 15
            put(r.compress_bias0, 8, 0)
                | put(r.compress_bias1, 8, 8)
                | put(r.layer_id, 16, 16),
        ]
    }
}
/// Transfer direction; the discriminant is the raw 2-bit value.
///
/// NOTE(review): the variant names suggest `Tg` = global memory and `L` = local memory
/// (e.g. `Tg2L` = global to local) — confirm against the hardware manual.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TdmaTransDir {
    Tg2L = 0b00,
    L2Tg = 0b01,
    G2G = 0b10,
    L2L = 0b11,
}
/// Transfer format; the discriminant is the raw 1-bit value.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TdmaTransFmt {
    Tensor = 0b0,
    Common = 0b1,
}
/// Element data type; the discriminant is the raw encoding.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GdmaDataType {
    F32 = 0,
    F16 = 1,
    I32 = 2,
    I16 = 3,
    I8 = 4,
    I4 = 5,
    I2 = 6,
    I1 = 7,
}

impl GdmaDataType {
    /// Returns the width of one element in bits (1, 2, 4, 8, 16 or 32).
    pub const fn bits(&self) -> u32 {
        match *self {
            Self::I1 => 1,
            Self::I2 => 2,
            Self::I4 => 4,
            Self::I8 => 8,
            Self::F16 => 16,
            Self::I16 => 16,
            Self::F32 => 32,
            Self::I32 => 32,
        }
    }
}
/// Special-function selector; the discriminant is the raw value. Defaults to
/// [`TdmaSpecFunc::None`].
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TdmaSpecFunc {
    None = 0,
    Transpose = 1,
    FillConst = 2,
    CwTranspose = 3,
    MatMul = 4,
}

impl Default for TdmaSpecFunc {
    fn default() -> Self {
        TdmaSpecFunc::None
    }
}
/// Transpose-mode selector; the discriminant is the raw value. Defaults to
/// [`TdmaTransposeMode::None`].
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TdmaTransposeMode {
    None = 0,
    TransposeWC = 1,
    TransposeHC = 2,
    TransposeNC = 3,
}

impl Default for TdmaTransposeMode {
    fn default() -> Self {
        TdmaTransposeMode::None
    }
}
/// Decoded TDMA interrupt register value.
#[derive(Clone, Copy, Debug, Default)]
pub struct TdmaIntStatus {
    /// Register value exactly as read.
    pub raw: u32,
    /// `raw >> 16` with the `TDMA_MASK_INIT` bits cleared.
    pub status: u32,
}

impl TdmaIntStatus {
    /// Decodes a raw register value. The status lives in the upper half-word, and the
    /// `TDMA_MASK_INIT` bits are cleared from it.
    pub fn from_raw(raw: u32) -> Self {
        let status = (raw >> 16) & !TDMA_MASK_INIT;
        Self { raw, status }
    }

    /// Returns true when the status is exactly `TDMA_INT_EOD`.
    pub fn is_eod(&self) -> bool {
        TDMA_INT_EOD == self.status
    }

    /// Returns true when the status is exactly `TDMA_INT_EOPMU`.
    pub fn is_eopmu(&self) -> bool {
        TDMA_INT_EOPMU == self.status
    }

    /// Returns true for any non-zero status other than exactly EOD or exactly EOPMU.
    pub fn has_error(&self) -> bool {
        !(self.status == 0 || self.is_eod() || self.is_eopmu())
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct TdmaSyncStatus {
pub raw: u32,
}
impl TdmaSyncStatus {
pub fn from_raw(raw: u32) -> Self { Self { raw } }
pub fn sync_id(&self) -> u16 { (self.raw >> 16) as u16 }
pub fn is_all_idle(&self) -> bool { (self.raw as u16 & TDMA_ALL_IDLE as u16) == TDMA_ALL_IDLE as u16 }
}
/// Wrapper around a raw TIU control/status register value.
#[derive(Clone, Copy, Debug, Default)]
pub struct TiuCtrlStatus {
    /// Register value exactly as read.
    pub raw: u32,
}

impl TiuCtrlStatus {
    /// Wraps a raw register value.
    pub fn from_raw(raw: u32) -> Self {
        TiuCtrlStatus { raw }
    }

    /// Bits 21:6 of the register.
    pub fn cmd_id(&self) -> u16 {
        // The `as u16` cast keeps exactly the low 16 bits of the shifted value.
        (self.raw >> 6) as u16
    }

    /// Bit 0 of the register.
    pub fn enabled(&self) -> bool {
        self.raw & 0b01 == 0b01
    }

    /// Bit 1 of the register.
    pub fn interrupt(&self) -> bool {
        (self.raw >> 1) & 1 == 1
    }
}