use core::{
cell::Cell,
sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicUsize, Ordering},
};
use ax_kspin::SpinNoPreempt as Mutex;
use super::{
TDMA_PHYS_BASE, TIU_PHYS_BASE,
error::TpuError,
platform::{TiuIrqCallback, TpuRuntimeState, WaitIrqFn},
tdma::TdmaRegs,
tiu::TiuRegs,
};
/// Lifecycle state of the TPU as tracked by `Sg2002Tpu`.
///
/// Transitions performed by `Sg2002Tpu`: `init`, `reset` and `resume` set `Idle`;
/// `run_one` sets `Running` while it owns the hardware and `Idle` afterwards;
/// `suspend` sets `Suspended`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TpuState {
/// State right after construction (`from_vaddr`). Note that `run_one` still
/// accepts jobs in this state.
Uninitialized,
/// Ready to accept a job via `run_one`.
Idle,
/// A `run_one` call has claimed the device; further `run_one` calls are rejected.
Running,
/// Registers were backed up by `suspend`; `run_one` rejects jobs until `resume`
/// (or `init`/`reset`) moves the device back to `Idle`.
Suspended,
}
/// Selector for how work is submitted to the TPU.
///
/// NOTE(review): this enum is not referenced in this file; the meaning above is inferred
/// from the type name only — confirm against its users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TpuSubmitPath {
/// Only defined path. The explicit discriminant `0` presumably matches an external
/// (hardware/ABI) value — TODO confirm before adding or reordering variants.
DesNormal = 0,
}
/// Mutable driver state of `Sg2002Tpu`, kept behind its `inner` lock.
struct TpuDeviceInner {
/// TDMA register handle built from the TDMA base address in `from_vaddr`.
tdma: TdmaRegs,
/// TIU register handle built from the TIU base address in `from_vaddr`.
tiu: TiuRegs,
/// Current lifecycle state.
state: TpuState,
/// Runtime state published at the end of `run_one`, reset by `init`/`reset`, and
/// holding the `reg_backup` used by `suspend`/`resume`.
runtime: TpuRuntimeState,
/// Optional TIU IRQ callback; copied into each job's runtime state when `run_one` starts.
tiu_irq_callback: Option<TiuIrqCallback>,
}
/// Driver for the SG2002 TPU's TDMA and TIU register blocks.
pub struct Sg2002Tpu {
/// Virtual base of the TDMA registers. Kept so `run_one` and `handle_irq` can build
/// register handles without taking the `inner` lock.
tdma_vaddr: *mut u8,
/// Virtual base of the TIU registers (used the same way by `run_one`).
tiu_vaddr: *mut u8,
/// Lock-protected mutable state (`ax_kspin::SpinNoPreempt`).
inner: Mutex<TpuDeviceInner>,
/// Source of `next_seq_no`.
seq_counter: AtomicU32,
/// Set by `handle_irq`, consumed (swapped back to `false`) by `run_one`'s wait loop.
irq_pending: AtomicBool,
/// Number of interrupts handled by `handle_irq`; reported by `irq_stats`.
irq_handler_hits: AtomicU64,
/// Number of completions detected by MMIO polling instead of the IRQ path; reported
/// by `irq_stats`.
poll_fallback_hits: AtomicU64,
/// Ensures the "MMIO poll fallback" warning is logged only once.
fallback_warned: AtomicBool,
/// `WaitIrqFn` stored as `usize`; `0` means no wait function has been installed.
wait_fn: AtomicUsize,
}
/// Granularity, in microseconds, of each blocking wait between completion polls.
const WAIT_POLL_INTERVAL_US: u64 = 100;
/// Maximum number of poll iterations in `run_one`'s completion wait.
///
/// Derived from a nominal 10 s budget (10 × 1_000_000 µs) split into
/// `WAIT_POLL_INTERVAL_US` slices. The real wall-clock limit depends on how long each
/// iteration actually blocks.
const WAIT_TOTAL_STEPS: u64 = (10 * 1_000_000) / WAIT_POLL_INTERVAL_US;
impl Sg2002Tpu {
pub unsafe fn new() -> Self {
let virt_offset = 0xffff_ffc0_0000_0000u64 as isize;
let tdma_vaddr = (TDMA_PHYS_BASE as isize + virt_offset) as *mut u8;
let tiu_vaddr = (TIU_PHYS_BASE as isize + virt_offset) as *mut u8;
unsafe { Self::from_vaddr(tdma_vaddr, tiu_vaddr) }
}
pub unsafe fn from_vaddr(tdma_vaddr: *mut u8, tiu_vaddr: *mut u8) -> Self {
Self {
tdma_vaddr,
tiu_vaddr,
inner: Mutex::new(TpuDeviceInner {
tdma: unsafe { TdmaRegs::new(tdma_vaddr) },
tiu: unsafe { TiuRegs::new(tiu_vaddr) },
state: TpuState::Uninitialized,
runtime: TpuRuntimeState::default(),
tiu_irq_callback: None,
}),
seq_counter: AtomicU32::new(0),
irq_pending: AtomicBool::new(false),
irq_handler_hits: AtomicU64::new(0),
poll_fallback_hits: AtomicU64::new(0),
fallback_warned: AtomicBool::new(false),
wait_fn: AtomicUsize::new(0),
}
}
pub fn register_tiu_irq_callback(&self, callback: TiuIrqCallback) {
let mut inner = self.inner.lock();
inner.tiu_irq_callback = Some(callback);
}
pub fn clear_tiu_irq_callback(&self) {
let mut inner = self.inner.lock();
inner.tiu_irq_callback = None;
}
pub fn set_wait_irq_fn(&self, wait_fn: WaitIrqFn) {
self.wait_fn.store(wait_fn as usize, Ordering::Release);
}
fn wait_irq_blocking(&self, timeout_us: u64) -> bool {
let raw = self.wait_fn.load(Ordering::Acquire);
if raw != 0 {
let wait: WaitIrqFn = unsafe { core::mem::transmute::<usize, WaitIrqFn>(raw) };
wait(timeout_us)
} else {
core::hint::spin_loop();
self.irq_pending.load(Ordering::Acquire)
}
}
pub fn irq_pending(&self) -> bool {
self.irq_pending.load(Ordering::Acquire)
}
pub fn init(&self) -> Result<(), TpuError> {
let mut inner = self.inner.lock();
super::platform::resync_cmd_id(&inner.tdma, &inner.tiu);
inner.state = TpuState::Idle;
inner.runtime = TpuRuntimeState::default();
info!("TPU device initialized");
Ok(())
}
pub fn state(&self) -> TpuState {
self.inner.lock().state
}
pub fn is_ready(&self) -> bool {
self.inner.lock().state == TpuState::Idle
}
pub fn handle_irq(&self) -> bool {
let tdma = unsafe { TdmaRegs::new(self.tdma_vaddr) };
let reg_value = tdma.read(super::tdma::TDMA_INT_MASK);
let int_status = (reg_value >> 16) & !super::tdma::TDMA_MASK_INIT;
if int_status == 0 {
return false;
}
let has_error =
int_status != super::tdma::TDMA_INT_EOD && int_status != super::tdma::TDMA_INT_EOPMU;
tdma.clear_interrupt();
self.irq_handler_hits.fetch_add(1, Ordering::AcqRel);
self.irq_pending.store(true, Ordering::Release);
has_error
}
pub fn irq_stats(&self) -> (u64, u64) {
(
self.irq_handler_hits.load(Ordering::Acquire),
self.poll_fallback_hits.load(Ordering::Acquire),
)
}
pub fn next_seq_no(&self) -> u32 {
self.seq_counter.fetch_add(1, Ordering::SeqCst)
}
pub fn run_one(
&self,
seq_no: u32,
dmabuf_vaddr: usize,
dmabuf_paddr: u64,
) -> Result<(), TpuError> {
debug!(
"[TPU] run_one: seq_no={}, vaddr=0x{:x}, paddr=0x{:x}",
seq_no, dmabuf_vaddr, dmabuf_paddr
);
let tiu_irq_callback = {
let mut inner = self.inner.lock();
if inner.state != TpuState::Idle && inner.state != TpuState::Uninitialized {
return Err(TpuError::NotInitialized);
}
inner.state = TpuState::Running;
inner.tiu_irq_callback
};
let tdma = unsafe { TdmaRegs::new(self.tdma_vaddr) };
let tiu = unsafe { TiuRegs::new(self.tiu_vaddr) };
let mut runtime = TpuRuntimeState {
current_seq_no: seq_no,
tiu_irq_callback,
..TpuRuntimeState::default()
};
let timeout_counter = Cell::new(0u64);
let timeout_limit = 10_000_000_000u64; self.irq_pending.store(false, Ordering::Release);
let tdma_irq_poll = unsafe { TdmaRegs::new(self.tdma_vaddr) };
let wait_irq = || -> Result<(), TpuError> {
let mut steps = 0u64;
while steps < WAIT_TOTAL_STEPS {
if self.irq_pending.swap(false, Ordering::AcqRel) {
return Ok(());
}
let int_status = tdma_irq_poll.get_int_status();
if int_status == super::tdma::TDMA_INT_EOD
|| int_status == super::tdma::TDMA_INT_EOPMU
{
tdma_irq_poll.clear_interrupt();
self.poll_fallback_hits.fetch_add(1, Ordering::AcqRel);
if self
.fallback_warned
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
warn!("[TPU] external IRQ path not observed yet, using MMIO poll fallback");
}
return Ok(());
}
self.wait_irq_blocking(WAIT_POLL_INTERVAL_US);
steps += 1;
}
Err(TpuError::Timeout)
};
let timeout_checker = || -> bool {
let next = timeout_counter.get().saturating_add(1);
timeout_counter.set(next);
next > timeout_limit
};
let result = unsafe {
super::platform::run_dmabuf(
&tdma,
&tiu,
dmabuf_vaddr as *const u8,
dmabuf_paddr,
&mut runtime,
wait_irq,
timeout_checker,
)
};
{
let mut inner = self.inner.lock();
inner.runtime = runtime;
inner.state = TpuState::Idle;
}
if let Err(e) = &result {
error!("TPU run_one failed: seq_no={}, err={:?}", seq_no, e);
}
result
}
pub fn cache_flush_paddr(&self, paddr: u64, size: u64) -> Result<(), TpuError> {
debug!("TPU cache flush: paddr=0x{:x}, size={}", paddr, size);
#[cfg(target_arch = "riscv64")]
{
unsafe {
core::arch::asm!("fence iorw, iorw");
}
}
let _ = (paddr, size);
Ok(())
}
pub fn cache_invalidate_paddr(&self, paddr: u64, size: u64) -> Result<(), TpuError> {
debug!("TPU cache invalidate: paddr=0x{:x}, size={}", paddr, size);
#[cfg(target_arch = "riscv64")]
{
unsafe {
core::arch::asm!("fence iorw, iorw");
}
}
let _ = (paddr, size);
Ok(())
}
pub fn suspend(&self) -> Result<(), TpuError> {
let mut inner = self.inner.lock();
if inner.state == TpuState::Suspended {
return Ok(());
}
let tdma = &inner.tdma as *const TdmaRegs;
let tiu = &inner.tiu as *const TiuRegs;
let reg_backup = &mut inner.runtime.reg_backup;
unsafe {
super::platform::backup_registers(&*tdma, &*tiu, reg_backup);
}
inner.state = TpuState::Suspended;
info!("TPU suspended");
Ok(())
}
pub fn resume(&self) -> Result<(), TpuError> {
let mut inner = self.inner.lock();
if inner.state != TpuState::Suspended {
return Err(TpuError::NotInitialized);
}
let tdma = &inner.tdma as *const TdmaRegs;
let tiu = &inner.tiu as *const TiuRegs;
let reg_backup = &inner.runtime.reg_backup;
unsafe {
super::platform::restore_registers(&*tdma, &*tiu, reg_backup);
}
inner.state = TpuState::Idle;
info!("TPU resumed");
Ok(())
}
pub fn reset(&self) {
let mut inner = self.inner.lock();
super::platform::resync_cmd_id(&inner.tdma, &inner.tiu);
inner.runtime = TpuRuntimeState::default();
inner.state = TpuState::Idle;
info!("TPU reset");
}
}
// SAFETY: NOTE(review): justification to be confirmed by the driver owner.
// The raw MMIO base pointers `tdma_vaddr`/`tiu_vaddr` make this type `!Send`/`!Sync`
// by default (the register handles inside `inner` may as well). The pointers are never
// written after construction and are only passed to `TdmaRegs::new`/`TiuRegs::new`.
// Mutable driver state lives behind the `inner` spin lock, and every other field is an
// atomic. Soundness relies on the register accessors being usable from several CPUs
// at once: `handle_irq` and `run_one` create register handles without taking the lock.
unsafe impl Send for Sg2002Tpu {}
unsafe impl Sync for Sg2002Tpu {}