tpu-sg2002 0.1.0

TPU driver in Rust for SG2002 SoC.
Documentation
//! TPU Platform abstraction layer for SG2002.

use crate::PhysAddr;
use crate::registers::{TDMA_MASK_INIT, TDMA_INT_EOD, TDMA_INT_EOPMU, TDMA_ALL_IDLE};

/// Magic value carried in a DMA buffer header.
pub const TPU_DMABUF_HEADER_M: u16 = 0xb5b5;

/// Timeout for TPU operations, expressed in microseconds (60 seconds).
pub const TIMEOUT_US: u64 = 60 * 1_000_000;

/// Platform configuration for TPU operations.
///
/// Plain-data record of the two MMIO base addresses (TDMA and TIU) and the
/// PMU buffer location/size. The struct is `#[repr(C)]`, so the field order
/// below is part of its layout and must not be rearranged.
///
/// The derived `Default` zero-fills the integer fields; `pmubuf_addr_p` takes
/// whatever `PhysAddr::default()` produces.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct TpuPlatformCfg {
    /// TDMA MMIO base address (virtual).
    pub iomem_tdma_base: usize,
    /// TIU MMIO base address (virtual).
    pub iomem_tiu_base: usize,
    /// PMU buffer size in bytes.
    pub pmubuf_size: u32,
    /// PMU buffer physical address.
    pub pmubuf_addr_p: PhysAddr,
}

/// Command ID node for tracking BD and TDMA command IDs.
///
/// A pair of 32-bit IDs, one per engine. `#[repr(C)]` pins the layout to
/// `bd_cmd_id` followed by `tdma_cmd_id`; the derived `Default` sets both to 0.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct CmdIdNode {
    /// BD (TIU) command ID.
    pub bd_cmd_id: u32,
    /// TDMA command ID.
    pub tdma_cmd_id: u32,
}

/// TPU register backup for suspend/resume.
///
/// A plain container of `u32` register snapshots; the derived `Default`
/// zero-fills every field. The code that fills and restores these values is
/// not visible in this section.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names only — confirm them against the suspend/resume implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct TpuRegBackup {
    /// Presumably the saved TDMA interrupt-mask register — TODO confirm.
    pub tdma_int_mask: u32,
    /// Presumably the saved TDMA sync-status register — TODO confirm.
    pub tdma_sync_status: u32,
    /// Presumably the saved TIU control register — TODO confirm.
    pub tiu_ctrl: u32,
    /// Array base low (8) + high (2) = 10 registers.
    pub tdma_arraybase: [u32; 10],
    /// Presumably the saved TDMA descriptor-base register — TODO confirm.
    pub tdma_des_base: u32,
    /// Presumably the saved TDMA debug-mode register — TODO confirm.
    pub tdma_dbg_mode: u32,
    /// Presumably the saved TDMA DCM-disable register — TODO confirm.
    pub tdma_dcm_disable: u32,
    /// Presumably the saved TDMA control register — TODO confirm.
    pub tdma_ctrl: u32,
}

/// TDMA register structure for PIO mode.
///
/// Unpacked counterpart of the C driver's `tdma_reg_t`, used for programmed
/// I/O (PIO) mode TDMA transfers. Every field is held in a full `u32`;
/// [`TdmaReg::emit`] truncates each one to its hardware width and packs the
/// result into sixteen 32-bit words. The per-field notes below give the word
/// index and bit range each field occupies in that packed form.
#[derive(Debug, Clone, Copy)]
pub struct TdmaReg {
    /// Word 0, bit 0.
    pub vld: u32,
    /// Word 0, bit 1.
    pub compress_en: u32,
    /// Word 0, bit 2.
    pub eod: u32,
    /// Word 0, bit 3.
    pub intp_en: u32,
    /// Word 0, bit 4.
    pub bar_en: u32,
    /// Word 0, bit 5.
    pub check_bf16_value: u32,
    /// Word 0, bits 6..=7.
    pub trans_dir: u32,
    /// Word 0, bits 8..=9.
    pub rsv00: u32,
    /// Word 0, bit 10.
    pub trans_fmt: u32,
    /// Word 0, bits 11..=12.
    pub transpose_md: u32,
    /// Word 0, bit 13.
    pub rsv01: u32,
    /// Word 0, bit 14.
    pub intra_cmd_paral: u32,
    /// Word 0, bit 15.
    pub outstanding_en: u32,
    /// Word 0, bits 16..=31.
    pub cmd_id: u32,
    /// Word 1, bits 0..=2.
    pub spec_func: u32,
    /// Word 1, bits 3..=4.
    pub dst_fmt: u32,
    /// Word 1, bits 5..=6.
    pub src_fmt: u32,
    /// Word 1, bit 7.
    pub cmprs_fmt: u32,
    /// Word 1, bit 8.
    pub sys_dtype: u32,
    /// Word 1, bits 9..=12.
    pub rsv2_1: u32,
    /// Word 1, bit 13.
    pub int8_sign: u32,
    /// Word 1, bit 14.
    pub compress_zero_guard: u32,
    /// Word 1, bit 15.
    pub int8_rnd_mode: u32,
    /// Word 1, bits 16..=31.
    pub wait_id_tpu: u32,
    /// Word 2, bits 0..=15.
    pub wait_id_other_tdma: u32,
    /// Word 2, bits 16..=31.
    pub wait_id_sdma: u32,
    /// Word 3, bits 0..=15.
    pub const_val: u32,
    /// Word 3, bits 16..=18.
    pub src_base_reg_sel: u32,
    /// Word 3, bit 19.
    pub mv_lut_idx: u32,
    /// Word 3, bits 20..=22.
    pub dst_base_reg_sel: u32,
    /// Word 3, bit 23.
    pub mv_lut_base: u32,
    /// Word 3, bits 24..=31.
    pub rsv4_5: u32,
    /// Word 4, bits 0..=15.
    pub dst_h_stride: u32,
    /// Word 4, bits 16..=31.
    pub dst_c_stride_low: u32,
    /// Word 5, all 32 bits.
    pub dst_n_stride: u32,
    /// Word 6, bits 0..=15.
    pub src_h_stride: u32,
    /// Word 6, bits 16..=31.
    pub src_c_stride_low: u32,
    /// Word 7, all 32 bits.
    pub src_n_stride: u32,
    /// Word 8, bits 0..=15.
    pub dst_c: u32,
    /// Word 8, bits 16..=31.
    pub src_c: u32,
    /// Word 9, bits 0..=15.
    pub dst_w: u32,
    /// Word 9, bits 16..=31.
    pub dst_h: u32,
    /// Word 10, bits 0..=15.
    pub src_w: u32,
    /// Word 10, bits 16..=31.
    pub src_h: u32,
    /// Word 11, all 32 bits.
    pub dst_base_addr_low: u32,
    /// Word 12, all 32 bits.
    pub src_base_addr_low: u32,
    /// Word 13, bits 0..=15.
    pub src_n: u32,
    /// Word 13, bits 16..=23.
    pub dst_base_addr_high: u32,
    /// Word 13, bits 24..=31.
    pub src_base_addr_high: u32,
    /// Word 14, bits 0..=15.
    pub src_c_stride_high: u32,
    /// Word 14, bits 16..=31.
    pub dst_c_stride_high: u32,
    /// Word 15, bits 0..=7.
    pub compress_bias0: u32,
    /// Word 15, bits 8..=15.
    pub compress_bias1: u32,
    /// Word 15, bits 16..=31.
    pub layer_id: u32,
}

impl Default for TdmaReg {
    /// Build the power-on style descriptor: formats, strides and shape
    /// dimensions start at 1, every other field starts at 0.
    fn default() -> Self {
        const ONE: u32 = 1;
        const ZERO: u32 = 0;

        Self {
            // Fields that start at one: formats, strides and dimensions.
            dst_fmt: ONE,
            src_fmt: ONE,
            dst_h_stride: ONE,
            dst_c_stride_low: ONE,
            dst_n_stride: ONE,
            src_h_stride: ONE,
            src_c_stride_low: ONE,
            src_n_stride: ONE,
            dst_c: ONE,
            src_c: ONE,
            dst_w: ONE,
            dst_h: ONE,
            src_w: ONE,
            src_h: ONE,
            src_n: ONE,

            // Control flags and command id.
            vld: ZERO,
            compress_en: ZERO,
            eod: ZERO,
            intp_en: ZERO,
            bar_en: ZERO,
            check_bf16_value: ZERO,
            trans_dir: ZERO,
            rsv00: ZERO,
            trans_fmt: ZERO,
            transpose_md: ZERO,
            rsv01: ZERO,
            intra_cmd_paral: ZERO,
            outstanding_en: ZERO,
            cmd_id: ZERO,

            // Function selection, data-type flags and wait ids.
            spec_func: ZERO,
            cmprs_fmt: ZERO,
            sys_dtype: ZERO,
            rsv2_1: ZERO,
            int8_sign: ZERO,
            compress_zero_guard: ZERO,
            int8_rnd_mode: ZERO,
            wait_id_tpu: ZERO,
            wait_id_other_tdma: ZERO,
            wait_id_sdma: ZERO,

            // Constant value, base-register selectors and LUT bits.
            const_val: ZERO,
            src_base_reg_sel: ZERO,
            mv_lut_idx: ZERO,
            dst_base_reg_sel: ZERO,
            mv_lut_base: ZERO,
            rsv4_5: ZERO,

            // Addresses, high stride halves, compression biases, layer id.
            dst_base_addr_low: ZERO,
            src_base_addr_low: ZERO,
            dst_base_addr_high: ZERO,
            src_base_addr_high: ZERO,
            src_c_stride_high: ZERO,
            dst_c_stride_high: ZERO,
            compress_bias0: ZERO,
            compress_bias1: ZERO,
            layer_id: ZERO,
        }
    }
}

impl TdmaReg {
    /// Overwrite every field with its [`Default`] value.
    pub fn reset(&mut self) {
        let pristine = Self::default();
        *self = pristine;
    }

    /// Pack the descriptor into the 16-word layout used for PIO mode.
    ///
    /// Counterpart of the C driver's `emit_tdma_reg()`. Each field is
    /// truncated to its hardware width (excess high bits are silently
    /// dropped) and shifted to its position; the returned array is indexed
    /// by word number, i.e. element 0 is word 0.
    pub fn emit(&self) -> [u32; 16] {
        // Keep the low `width` bits of `value` and place them at bit `lsb`.
        // Only used with widths of 16 or less, so the mask shift cannot overflow.
        fn field(value: u32, width: u32, lsb: u32) -> u32 {
            (value & ((1u32 << width) - 1)) << lsb
        }

        let word0 = field(self.vld, 1, 0)
            | field(self.compress_en, 1, 1)
            | field(self.eod, 1, 2)
            | field(self.intp_en, 1, 3)
            | field(self.bar_en, 1, 4)
            | field(self.check_bf16_value, 1, 5)
            | field(self.trans_dir, 2, 6)
            | field(self.rsv00, 2, 8)
            | field(self.trans_fmt, 1, 10)
            | field(self.transpose_md, 2, 11)
            | field(self.rsv01, 1, 13)
            | field(self.intra_cmd_paral, 1, 14)
            | field(self.outstanding_en, 1, 15)
            | field(self.cmd_id, 16, 16);

        let word1 = field(self.spec_func, 3, 0)
            | field(self.dst_fmt, 2, 3)
            | field(self.src_fmt, 2, 5)
            | field(self.cmprs_fmt, 1, 7)
            | field(self.sys_dtype, 1, 8)
            | field(self.rsv2_1, 4, 9)
            | field(self.int8_sign, 1, 13)
            | field(self.compress_zero_guard, 1, 14)
            | field(self.int8_rnd_mode, 1, 15)
            | field(self.wait_id_tpu, 16, 16);

        let word2 = field(self.wait_id_other_tdma, 16, 0) | field(self.wait_id_sdma, 16, 16);

        let word3 = field(self.const_val, 16, 0)
            | field(self.src_base_reg_sel, 3, 16)
            | field(self.mv_lut_idx, 1, 19)
            | field(self.dst_base_reg_sel, 3, 20)
            | field(self.mv_lut_base, 1, 23)
            | field(self.rsv4_5, 8, 24);

        let word4 = field(self.dst_h_stride, 16, 0) | field(self.dst_c_stride_low, 16, 16);
        let word6 = field(self.src_h_stride, 16, 0) | field(self.src_c_stride_low, 16, 16);
        let word8 = field(self.dst_c, 16, 0) | field(self.src_c, 16, 16);
        let word9 = field(self.dst_w, 16, 0) | field(self.dst_h, 16, 16);
        let word10 = field(self.src_w, 16, 0) | field(self.src_h, 16, 16);

        let word13 = field(self.src_n, 16, 0)
            | field(self.dst_base_addr_high, 8, 16)
            | field(self.src_base_addr_high, 8, 24);

        let word14 = field(self.src_c_stride_high, 16, 0) | field(self.dst_c_stride_high, 16, 16);

        let word15 = field(self.compress_bias0, 8, 0)
            | field(self.compress_bias1, 8, 8)
            | field(self.layer_id, 16, 16);

        [
            word0,
            word1,
            word2,
            word3,
            word4,
            // Words 5, 7, 11 and 12 each carry one full-width field unmasked.
            self.dst_n_stride,
            word6,
            self.src_n_stride,
            word8,
            word9,
            word10,
            self.dst_base_addr_low,
            self.src_base_addr_low,
            word13,
            word14,
            word15,
        ]
    }
}

/// TDMA transfer direction.
///
/// `#[repr(u32)]` with explicit discriminants 0–3, so `dir as u32` yields the
/// numeric code directly. All four values fit in two bits.
/// NOTE(review): presumably this is the value stored in `TdmaReg::trans_dir`
/// (a 2-bit field) — confirm against the call sites.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TdmaTransDir {
    /// Transfer from global memory to local memory.
    Tg2L = 0,
    /// Transfer from local memory to global memory.
    L2Tg = 1,
    /// Transfer within global memory (global to global).
    G2G = 2,
    /// Transfer within local memory (local to local).
    L2L = 3,
}

/// TDMA transfer format.
///
/// `#[repr(u32)]` with explicit discriminants 0 and 1, so `fmt as u32` yields
/// the numeric code directly.
/// NOTE(review): presumably this is the value stored in `TdmaReg::trans_fmt`
/// (a 1-bit field) — confirm against the call sites.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TdmaTransFmt {
    /// Tensor format with shape (N, C, H, W).
    Tensor = 0,
    /// Common/linear format.
    Common = 1,
}

/// GDMA data types (backward compatible).
///
/// `#[repr(u32)]` with explicit discriminants 0–7. Use
/// [`GdmaDataType::bits`] to obtain the element width.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GdmaDataType {
    /// 32 bits per element.
    F32 = 0,
    /// 16 bits per element.
    F16 = 1,
    /// 32 bits per element.
    I32 = 2,
    /// 16 bits per element.
    I16 = 3,
    /// 8 bits per element.
    I8 = 4,
    /// 4 bits per element.
    I4 = 5,
    /// 2 bits per element.
    I2 = 6,
    /// 1 bit per element.
    I1 = 7,
}

impl GdmaDataType {
    /// Width of a single element of this type, in bits.
    ///
    /// Usable in `const` contexts. The match is exhaustive on purpose so that
    /// adding a variant forces this table to be updated.
    pub const fn bits(&self) -> u32 {
        match *self {
            Self::I1 => 1,
            Self::I2 => 2,
            Self::I4 => 4,
            Self::I8 => 8,
            Self::I16 => 16,
            Self::F16 => 16,
            Self::I32 => 32,
            Self::F32 => 32,
        }
    }
}

/// TDMA special function codes.
///
/// `#[repr(u32)]` with explicit discriminants 0–4; `Default` is
/// [`TdmaSpecFunc::None`]. All values fit in three bits.
/// NOTE(review): presumably this is the value stored in `TdmaReg::spec_func`
/// (a 3-bit field) — confirm against the call sites.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TdmaSpecFunc {
    /// No special function.
    #[default]
    None = 0,
    /// Transpose operation.
    Transpose = 1,
    /// Fill constant operation.
    FillConst = 2,
    /// CW transpose operation.
    CwTranspose = 3,
    /// Matrix multiplication operation.
    MatMul = 4,
}

/// TDMA transpose mode.
///
/// `#[repr(u32)]` with explicit discriminants 0–3; `Default` is
/// [`TdmaTransposeMode::None`]. All values fit in two bits.
/// NOTE(review): presumably this is the value stored in
/// `TdmaReg::transpose_md` (a 2-bit field) — confirm against the call sites.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TdmaTransposeMode {
    /// No transpose.
    #[default]
    None = 0,
    /// Transpose W and C dimensions.
    TransposeWC = 1,
    /// Transpose H and C dimensions.
    TransposeHC = 2,
    /// Transpose N and C dimensions.
    TransposeNC = 3,
}

/// TDMA interrupt status.
///
/// Pairs the untouched register word with the status bits decoded from it.
#[derive(Debug, Clone, Copy, Default)]
pub struct TdmaIntStatus {
    /// Register word exactly as handed to [`TdmaIntStatus::from_raw`].
    pub raw: u32,
    /// Upper 16 bits of `raw` with every bit of `TDMA_MASK_INIT` cleared.
    pub status: u32,
}

impl TdmaIntStatus {
    /// Decode a raw register word.
    ///
    /// The status lives in the upper half of the word; bits that are set in
    /// `TDMA_MASK_INIT` are stripped from it.
    pub fn from_raw(raw: u32) -> Self {
        let upper_half = raw >> 16;
        let kept_bits = !TDMA_MASK_INIT;
        Self {
            raw,
            status: upper_half & kept_bits,
        }
    }

    /// `true` when the decoded status is exactly `TDMA_INT_EOD`.
    pub fn is_eod(&self) -> bool {
        TDMA_INT_EOD == self.status
    }

    /// `true` when the decoded status is exactly `TDMA_INT_EOPMU`.
    pub fn is_eopmu(&self) -> bool {
        TDMA_INT_EOPMU == self.status
    }

    /// `true` when the status is non-zero and is neither exactly EOD nor
    /// exactly EOPMU.
    ///
    /// NOTE(review): because the checks are exact comparisons, a status with
    /// both the EOD and EOPMU bits set is reported as an error — confirm that
    /// this is intended.
    pub fn has_error(&self) -> bool {
        match self.status {
            0 => false,
            _ => !(self.is_eod() || self.is_eopmu()),
        }
    }
}

/// TDMA sync status.
#[derive(Debug, Clone, Copy, Default)]
pub struct TdmaSyncStatus {
    pub raw: u32,
}

impl TdmaSyncStatus {
    pub fn from_raw(raw: u32) -> Self { Self { raw } }
    pub fn sync_id(&self) -> u16 { (self.raw >> 16) as u16 }
    pub fn is_all_idle(&self) -> bool { (self.raw as u16 & TDMA_ALL_IDLE as u16) == TDMA_ALL_IDLE as u16 }
}

/// TIU control status.
///
/// Thin wrapper over the raw control register word with bit-field accessors.
#[derive(Debug, Clone, Copy, Default)]
pub struct TiuCtrlStatus {
    /// Register word exactly as handed to [`TiuCtrlStatus::from_raw`].
    pub raw: u32,
}

impl TiuCtrlStatus {
    /// Wrap a raw register word without interpreting it.
    pub fn from_raw(raw: u32) -> Self {
        Self { raw }
    }

    /// Bits 6..=21 of the raw word, returned as a 16-bit command ID.
    pub fn cmd_id(&self) -> u16 {
        // The narrowing cast keeps exactly the low 16 bits of the shifted word.
        (self.raw >> 6) as u16
    }

    /// State of bit 0.
    pub fn enabled(&self) -> bool {
        const ENABLE_BIT: u32 = 1 << 0;
        (self.raw & ENABLE_BIT) != 0
    }

    /// State of bit 1.
    pub fn interrupt(&self) -> bool {
        const INTERRUPT_BIT: u32 = 1 << 1;
        (self.raw & INTERRUPT_BIT) != 0
    }
}