use alloc::collections::VecDeque;
use core::{
cell::Cell,
sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicUsize, Ordering},
};
use ax_kspin::SpinNoPreempt as Mutex;
use super::{
TDMA_PHYS_BASE, TIU_PHYS_BASE,
error::TpuError,
platform::{DelayFn, TiuIrqCallback, TpuRuntimeState},
tdma::TdmaRegs,
tiu::TiuRegs,
};
/// Lifecycle state of the TPU device as tracked by the driver.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TpuState {
    /// Constructed but `init` has not run yet.
    Uninitialized,
    /// Ready to accept a job.
    Idle,
    /// A descriptor buffer is currently executing.
    Running,
    /// Registers have been backed up by `suspend`; `resume` is required.
    Suspended,
}
/// Submission path recorded on each queued task.
///
/// Only one path exists today (presumably normal descriptor-based execution).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TpuSubmitPath {
    /// Default path; explicit discriminant 0.
    DesNormal = 0,
}
/// One submitted job, queued in `TpuKernelWork` before and after execution.
///
/// Field order is significant for the derived `Debug` output.
#[derive(Debug)]
pub struct TpuTaskNode {
    /// Submitting process id (`submit_dmabuf` currently always stores 0).
    pub pid: u32,
    /// Sequence number used by `wait_dmabuf` to find the finished task.
    pub seq_no: u32,
    /// File descriptor of the dmabuf holding the command descriptors.
    pub dmabuf_fd: i32,
    /// Kernel-virtual address of the dmabuf contents.
    pub dmabuf_vaddr: usize,
    /// Physical address of the dmabuf contents.
    pub dmabuf_paddr: u64,
    /// Submission path the task was queued with.
    pub tpu_path: TpuSubmitPath,
    /// Completion code: 0 on success, -1 on failure once processed.
    pub ret: i32,
}
/// Pending and completed job queues, guarded by the device lock.
#[derive(Default)]
pub struct TpuKernelWork {
    /// Tasks waiting to run, consumed front-first.
    pub task_list: VecDeque<TpuTaskNode>,
    /// Finished tasks waiting to be collected by `wait_dmabuf`.
    pub done_list: VecDeque<TpuTaskNode>,
}
/// Mutable device state; only ever accessed through `Sg2002Tpu::inner`'s lock.
struct TpuDeviceInner {
    /// TDMA engine register accessor.
    tdma: TdmaRegs,
    /// TIU engine register accessor.
    tiu: TiuRegs,
    /// Current lifecycle state.
    state: TpuState,
    /// Per-run platform state (sequence number, callbacks, register backup).
    runtime: TpuRuntimeState,
    /// Submitted and completed job queues.
    kernel_work: TpuKernelWork,
    /// Optional TIU interrupt callback, copied into `runtime` before each run.
    tiu_irq_callback: Option<TiuIrqCallback>,
}
/// SG2002 TPU driver handle.
///
/// Mutable state sits behind a spinlock. The IRQ bookkeeping uses atomics so
/// that the interrupt handler never has to take the lock.
pub struct Sg2002Tpu {
    /// Virtual base of the TDMA register window; also read lock-free by the IRQ path.
    tdma_vaddr: *mut u8,
    /// Virtual base of the TIU register window.
    /// NOTE(review): only stored at construction; no visible method reads it.
    tiu_vaddr: *mut u8,
    /// Lock-protected device state.
    inner: Mutex<TpuDeviceInner>,
    /// Source of sequence numbers handed out by `next_seq_no`.
    seq_counter: AtomicU32,
    /// Set by the IRQ handler, consumed by the completion wait.
    irq_pending: AtomicBool,
    /// Number of interrupts observed by `handle_irq`.
    irq_handler_hits: AtomicU64,
    /// Number of completions detected by MMIO polling instead of the IRQ flag.
    poll_fallback_hits: AtomicU64,
    /// Ensures the poll-fallback warning is logged only once.
    fallback_warned: AtomicBool,
    /// `DelayFn` stored as an address; 0 means "no delay function installed".
    delay_fn: AtomicUsize,
}
/// Microseconds requested from the delay function between completion polls.
const WAIT_POLL_INTERVAL_US: u64 = 100;
impl Sg2002Tpu {
    /// Creates a driver instance for the fixed TDMA/TIU physical bases.
    ///
    /// Virtual addresses are computed as `phys + 0xffff_ffc0_0000_0000`
    /// (modulo 2^64).
    ///
    /// # Safety
    ///
    /// The TDMA and TIU register windows must be mapped at those virtual
    /// addresses for the whole lifetime of the returned value. See
    /// [`Self::from_vaddr`] for the remaining requirements.
    /// NOTE(review): the offset is hard-coded; confirm it matches the kernel's
    /// physical-memory mapping on every target this is built for.
    pub unsafe fn new() -> Self {
        let virt_offset = 0xffff_ffc0_0000_0000u64 as isize;
        let tdma_vaddr = (TDMA_PHYS_BASE as isize + virt_offset) as *mut u8;
        let tiu_vaddr = (TIU_PHYS_BASE as isize + virt_offset) as *mut u8;
        // SAFETY: the mapping requirement is forwarded from this function's
        // own `# Safety` contract.
        unsafe { Self::from_vaddr(tdma_vaddr, tiu_vaddr) }
    }

    /// Creates a driver instance from already-mapped register windows.
    ///
    /// The device starts `Uninitialized`, with empty queues, no TIU IRQ
    /// callback, no delay function, and all counters at zero.
    ///
    /// # Safety
    ///
    /// - `tdma_vaddr` and `tiu_vaddr` must point to the mapped TDMA and TIU
    ///   register windows.
    /// - They must satisfy whatever `TdmaRegs::new` / `TiuRegs::new` require.
    /// - They must stay valid for the lifetime of the returned value. The TDMA
    ///   window is accessed both under the internal lock and, from
    ///   [`Self::handle_irq`] and the completion poll, without it.
    pub unsafe fn from_vaddr(tdma_vaddr: *mut u8, tiu_vaddr: *mut u8) -> Self {
        Self {
            tdma_vaddr,
            tiu_vaddr,
            inner: Mutex::new(TpuDeviceInner {
                // SAFETY: pointer validity is the caller's obligation (see above).
                tdma: unsafe { TdmaRegs::new(tdma_vaddr) },
                // SAFETY: pointer validity is the caller's obligation (see above).
                tiu: unsafe { TiuRegs::new(tiu_vaddr) },
                state: TpuState::Uninitialized,
                runtime: TpuRuntimeState::default(),
                kernel_work: TpuKernelWork::default(),
                tiu_irq_callback: None,
            }),
            seq_counter: AtomicU32::new(0),
            irq_pending: AtomicBool::new(false),
            irq_handler_hits: AtomicU64::new(0),
            poll_fallback_hits: AtomicU64::new(0),
            fallback_warned: AtomicBool::new(false),
            delay_fn: AtomicUsize::new(0),
        }
    }

    /// Installs the TIU interrupt callback used by subsequent runs.
    pub fn register_tiu_irq_callback(&self, callback: TiuIrqCallback) {
        let mut inner = self.inner.lock();
        inner.tiu_irq_callback = Some(callback);
    }

    /// Removes any installed TIU interrupt callback.
    pub fn clear_tiu_irq_callback(&self) {
        let mut inner = self.inner.lock();
        inner.tiu_irq_callback = None;
    }

    /// Installs the delay routine called between completion polls.
    pub fn set_delay_fn(&self, delay_fn: DelayFn) {
        self.delay_fn.store(delay_fn as usize, Ordering::Release);
    }

    /// Waits about `usecs` microseconds between polls.
    ///
    /// Without an installed delay function this only issues one spin-loop hint.
    /// It therefore does not actually wait `usecs`.
    fn wait_poll_interval(&self, usecs: u64) {
        let raw = self.delay_fn.load(Ordering::Acquire);
        if raw != 0 {
            // SAFETY: `delay_fn` is only ever written with 0 (in `from_vaddr`)
            // or with a `DelayFn` cast to `usize` (in `set_delay_fn`). A
            // non-zero value is therefore a valid `DelayFn` of the same size.
            let delay: DelayFn = unsafe { core::mem::transmute::<usize, DelayFn>(raw) };
            delay(usecs);
        } else {
            core::hint::spin_loop();
        }
    }

    /// Initializes the device.
    ///
    /// Resynchronizes command IDs, resets the per-run runtime state and marks
    /// the device `Idle`. Currently always returns `Ok(())`.
    pub fn init(&self) -> Result<(), TpuError> {
        let mut inner = self.inner.lock();
        super::platform::resync_cmd_id(&inner.tdma, &inner.tiu);
        inner.state = TpuState::Idle;
        inner.runtime = TpuRuntimeState::default();
        info!("TPU device initialized");
        Ok(())
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> TpuState {
        self.inner.lock().state
    }

    /// Returns `true` when the device is `Idle` and can accept a job.
    pub fn is_ready(&self) -> bool {
        self.inner.lock().state == TpuState::Idle
    }

    /// TDMA interrupt handler.
    ///
    /// It does not take the device lock. A submission holds that lock while it
    /// waits for `irq_pending`, so locking here would deadlock.
    ///
    /// If no (unmasked) status bits are set, returns `false` without side
    /// effects. Otherwise it clears the interrupt, bumps the handler counter
    /// and sets `irq_pending`. It then returns `true` when the status is
    /// anything other than exactly `TDMA_INT_EOD` or `TDMA_INT_EOPMU`.
    ///
    /// NOTE(review): the waiter treats any `irq_pending` as success, so an
    /// error reported here does not reach the submission's result.
    pub fn handle_irq(&self) -> bool {
        // SAFETY: `tdma_vaddr` is valid for `self`'s lifetime per the
        // constructor contract; concurrent MMIO access is intended here.
        let tdma = unsafe { TdmaRegs::new(self.tdma_vaddr) };
        let reg_value = tdma.read(super::tdma::TDMA_INT_MASK);
        let int_status = (reg_value >> 16) & !super::tdma::TDMA_MASK_INIT;
        if int_status == 0 {
            return false;
        }
        let has_error =
            int_status != super::tdma::TDMA_INT_EOD && int_status != super::tdma::TDMA_INT_EOPMU;
        tdma.clear_interrupt();
        self.irq_handler_hits.fetch_add(1, Ordering::AcqRel);
        self.irq_pending.store(true, Ordering::Release);
        has_error
    }

    /// Returns `(irq_handler_hits, poll_fallback_hits)`.
    pub fn irq_stats(&self) -> (u64, u64) {
        (
            self.irq_handler_hits.load(Ordering::Acquire),
            self.poll_fallback_hits.load(Ordering::Acquire),
        )
    }

    /// Returns a fresh sequence number (wraps around at `u32::MAX`).
    pub fn next_seq_no(&self) -> u32 {
        self.seq_counter.fetch_add(1, Ordering::SeqCst)
    }

    /// Queues a descriptor buffer and runs every queued task synchronously.
    ///
    /// Execution errors are not returned here. They are logged and recorded
    /// in the task's `ret`, which the caller retrieves with
    /// [`Self::wait_dmabuf`] using `seq_no`.
    pub fn submit_dmabuf(
        &self,
        fd: i32,
        seq_no: u32,
        dmabuf_vaddr: usize,
        dmabuf_paddr: u64,
    ) -> Result<(), TpuError> {
        debug!("[TPU] submit dmabuf: fd={}, seq_no={}", fd, seq_no);
        debug!(
            "[TPU] Buffer: vaddr=0x{:x}, paddr=0x{:x}",
            dmabuf_vaddr, dmabuf_paddr
        );
        let task = TpuTaskNode {
            pid: 0,
            seq_no,
            dmabuf_fd: fd,
            dmabuf_vaddr,
            dmabuf_paddr,
            tpu_path: TpuSubmitPath::DesNormal,
            ret: 0,
        };
        let mut inner = self.inner.lock();
        inner.kernel_work.task_list.push_back(task);
        self.process_task_locked(&mut inner)
    }

    /// Drains `task_list` in FIFO order.
    ///
    /// Each task is run and moved to `done_list` with `ret` set to `0` on
    /// success or `-1` on failure. Per-task failures are logged rather than
    /// propagated, so this currently always returns `Ok(())`.
    ///
    /// NOTE(review): `resync_cmd_id` runs before `run_dmabuf_internal` checks
    /// the state, so it also runs while `Suspended`. Confirm that register
    /// access is harmless then.
    fn process_task_locked(&self, inner: &mut TpuDeviceInner) -> Result<(), TpuError> {
        while let Some(mut task) = inner.kernel_work.task_list.pop_front() {
            super::platform::resync_cmd_id(&inner.tdma, &inner.tiu);
            inner.runtime.irq_received = false;
            let result = self.run_dmabuf_internal(
                inner,
                task.seq_no,
                task.dmabuf_vaddr as *const u8,
                task.dmabuf_paddr,
            );
            task.ret = match result {
                Ok(_) => 0,
                Err(e) => {
                    error!("TPU run dmabuf failed: {:?}", e);
                    -1
                }
            };
            inner.kernel_work.done_list.push_back(task);
        }
        Ok(())
    }

    /// Executes one descriptor buffer while the device lock is held.
    ///
    /// Starts only from `Idle` or `Uninitialized`. Any other state is rejected
    /// with `TpuError::NotInitialized`. The state is `Running` during the call
    /// and `Idle` afterwards, whatever the platform run returns.
    ///
    /// Completion is detected by `wait_irq`. It polls both `irq_pending` (set
    /// by [`Self::handle_irq`]) and the TDMA status register directly, so a
    /// missing external IRQ only degrades to polling.
    ///
    /// NOTE(review): the `SpinNoPreempt` device lock stays held for the whole
    /// wait.
    fn run_dmabuf_internal(
        &self,
        inner: &mut TpuDeviceInner,
        seq_no: u32,
        dmabuf_vaddr: *const u8,
        dmabuf_paddr: u64,
    ) -> Result<(), TpuError> {
        // Nominal completion timeout, split into `WAIT_POLL_INTERVAL_US` steps.
        const WAIT_TIMEOUT_US: u64 = 10_000_000;

        if inner.state != TpuState::Idle && inner.state != TpuState::Uninitialized {
            return Err(TpuError::NotInitialized);
        }
        inner.state = TpuState::Running;

        // `timeout_checker` reports a timeout after this many invocations.
        let timeout_counter = Cell::new(0u64);
        let timeout_limit = 10_000_000_000u64;
        let wait_poll_steps = WAIT_TIMEOUT_US / WAIT_POLL_INTERVAL_US;

        // Forget any completion flag left over from an earlier job.
        self.irq_pending.store(false, Ordering::Release);
        // SAFETY: `tdma_vaddr` is valid for `self`'s lifetime per the
        // constructor contract.
        let tdma_irq_poll = unsafe { TdmaRegs::new(self.tdma_vaddr) };

        // With a delay function installed this bounds the wait to roughly
        // `WAIT_TIMEOUT_US`. Without one, each step is a single spin hint, so
        // the bound is `wait_poll_steps` iterations, not a wall-clock time.
        //
        // NOTE(review): error status bits are not handled here. Via
        // `irq_pending` they count as success. Via the MMIO poll they are
        // ignored until timeout.
        let wait_irq = || -> Result<(), TpuError> {
            for _ in 0..wait_poll_steps {
                if self.irq_pending.swap(false, Ordering::AcqRel) {
                    return Ok(());
                }
                let int_status = tdma_irq_poll.get_int_status();
                if int_status == super::tdma::TDMA_INT_EOD
                    || int_status == super::tdma::TDMA_INT_EOPMU
                {
                    tdma_irq_poll.clear_interrupt();
                    self.poll_fallback_hits.fetch_add(1, Ordering::AcqRel);
                    if self
                        .fallback_warned
                        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
                        .is_ok()
                    {
                        warn!("[TPU] external IRQ path not observed yet, using MMIO poll fallback");
                    }
                    return Ok(());
                }
                self.wait_poll_interval(WAIT_POLL_INTERVAL_US);
            }
            Err(TpuError::Timeout)
        };

        let timeout_checker = || -> bool {
            let next = timeout_counter.get().saturating_add(1);
            timeout_counter.set(next);
            next > timeout_limit
        };

        let tdma = &inner.tdma as *const TdmaRegs;
        let tiu = &inner.tiu as *const TiuRegs;
        let tiu_irq_callback = inner.tiu_irq_callback;
        let runtime = &mut inner.runtime;
        runtime.current_seq_no = seq_no;
        runtime.tiu_irq_callback = tiu_irq_callback;
        // SAFETY: `tdma`/`tiu` point at fields of `*inner`, which outlives this
        // call. They are disjoint from the `runtime` field borrowed mutably
        // above, so the shared reborrows do not alias the `&mut`.
        let result = unsafe {
            super::platform::run_dmabuf(
                &*tdma,
                &*tiu,
                dmabuf_vaddr,
                dmabuf_paddr,
                runtime,
                wait_irq,
                timeout_checker,
            )
        };
        inner.state = TpuState::Idle;
        result
    }

    /// Collects the result of a finished task.
    ///
    /// Removes the task with `seq_no` from the completed queue and returns its
    /// `ret` (0 = success, -1 = failure).
    ///
    /// # Errors
    ///
    /// Returns `TpuError::NotInitialized` when no completed task has that
    /// sequence number.
    /// NOTE(review): a dedicated "not found" error would be clearer; the
    /// variant is kept for caller compatibility.
    pub fn wait_dmabuf(&self, seq_no: u32) -> Result<i32, TpuError> {
        debug!("TPU wait dmabuf: seq_no={}", seq_no);
        let mut inner = self.inner.lock();
        let done_list = &mut inner.kernel_work.done_list;
        let position = done_list.iter().position(|task| task.seq_no == seq_no);
        match position.and_then(|idx| done_list.remove(idx)) {
            Some(task) => {
                debug!(
                    "TPU wait dmabuf completed: seq_no={}, ret={}",
                    seq_no, task.ret
                );
                Ok(task.ret)
            }
            None => {
                warn!("TPU wait dmabuf: seq_no {} not found", seq_no);
                Err(TpuError::NotInitialized)
            }
        }
    }

    /// "Flushes" CPU caches for `paddr..paddr + size` before device access.
    ///
    /// On riscv64 this issues a `fence iorw, iorw`; elsewhere it does nothing.
    /// Always returns `Ok(())`.
    ///
    /// NOTE(review): a RISC-V `fence` only orders accesses. It does not write
    /// back dirty data-cache lines, and `paddr`/`size` are unused. If the TPU
    /// is not cache-coherent with the CPU, real cache maintenance (Zicbom
    /// `cbo.clean`/`cbo.flush` or vendor ops) is required here.
    pub fn cache_flush_paddr(&self, paddr: u64, size: u64) -> Result<(), TpuError> {
        debug!("TPU cache flush: paddr=0x{:x}, size={}", paddr, size);
        #[cfg(target_arch = "riscv64")]
        {
            // SAFETY: a full fence has no memory-safety preconditions.
            unsafe {
                core::arch::asm!("fence iorw, iorw");
            }
        }
        let _ = (paddr, size);
        Ok(())
    }

    /// "Invalidates" CPU caches for `paddr..paddr + size` after device writes.
    ///
    /// On riscv64 this issues a `fence iorw, iorw`; elsewhere it does nothing.
    /// Always returns `Ok(())`.
    ///
    /// NOTE(review): a `fence` does not invalidate data-cache lines, so the CPU
    /// may still read stale data if the TPU is not cache-coherent (see Zicbom
    /// `cbo.inval` or vendor equivalents).
    pub fn cache_invalidate_paddr(&self, paddr: u64, size: u64) -> Result<(), TpuError> {
        debug!("TPU cache invalidate: paddr=0x{:x}, size={}", paddr, size);
        #[cfg(target_arch = "riscv64")]
        {
            // SAFETY: a full fence has no memory-safety preconditions.
            unsafe {
                core::arch::asm!("fence iorw, iorw");
            }
        }
        let _ = (paddr, size);
        Ok(())
    }

    /// Backs up the engine registers and marks the device `Suspended`.
    ///
    /// Suspending an already suspended device is a no-op returning `Ok(())`.
    pub fn suspend(&self) -> Result<(), TpuError> {
        let mut guard = self.inner.lock();
        // Reborrow the guard once. Previously the shared raw pointers came from
        // `Deref`, and the later `&mut inner.runtime...` went through
        // `DerefMut`. That `DerefMut` call produced a fresh `&mut` to the whole
        // struct and invalidated those pointers under Stacked Borrows. Through
        // one `&mut TpuDeviceInner`, the field borrows below are disjoint.
        let inner = &mut *guard;
        if inner.state == TpuState::Suspended {
            return Ok(());
        }
        let tdma = &inner.tdma as *const TdmaRegs;
        let tiu = &inner.tiu as *const TiuRegs;
        let reg_backup = &mut inner.runtime.reg_backup;
        // SAFETY: `tdma`/`tiu` point at fields of `*inner`, which is alive and
        // locked for this whole function. They are disjoint from
        // `runtime.reg_backup`, the only field borrowed mutably.
        unsafe {
            super::platform::backup_registers(&*tdma, &*tiu, reg_backup);
        }
        inner.state = TpuState::Suspended;
        info!("TPU suspended");
        Ok(())
    }

    /// Restores the backed-up registers and returns the device to `Idle`.
    ///
    /// # Errors
    ///
    /// Returns `TpuError::NotInitialized` if the device is not `Suspended`.
    pub fn resume(&self) -> Result<(), TpuError> {
        let mut guard = self.inner.lock();
        let inner = &mut *guard;
        if inner.state != TpuState::Suspended {
            return Err(TpuError::NotInitialized);
        }
        let tdma = &inner.tdma as *const TdmaRegs;
        let tiu = &inner.tiu as *const TiuRegs;
        let reg_backup = &inner.runtime.reg_backup;
        // SAFETY: `tdma`/`tiu` point at fields of `*inner`, which is alive and
        // locked for this whole function. Only shared borrows are live here.
        unsafe {
            super::platform::restore_registers(&*tdma, &*tiu, reg_backup);
        }
        inner.state = TpuState::Idle;
        info!("TPU resumed");
        Ok(())
    }

    /// Resynchronizes command IDs, resets runtime state and forces `Idle`.
    ///
    /// NOTE(review): queued and completed tasks in `kernel_work` are left
    /// untouched. `done_list` entries are only removed by `wait_dmabuf`.
    pub fn reset(&self) {
        let mut inner = self.inner.lock();
        super::platform::resync_cmd_id(&inner.tdma, &inner.tiu);
        inner.runtime = TpuRuntimeState::default();
        inner.state = TpuState::Idle;
        info!("TPU reset");
    }
}
// SAFETY: the auto traits are suppressed by the raw MMIO pointers
// (`tdma_vaddr`, `tiu_vaddr`, and presumably those inside `TdmaRegs`/`TiuRegs`).
// - Those pointers are fixed at construction and never reassigned by the
//   driver's methods.
// - All other mutable state is either behind the `inner` spinlock or in atomics.
// - `handle_irq` deliberately reads and clears TDMA registers without the lock.
// NOTE(review): soundness also relies on those register accessors tolerating
// concurrent use and on the constructor's mapping contract; confirm both.
unsafe impl Send for Sg2002Tpu {}
// SAFETY: shared access goes through the lock, the atomics, or MMIO accesses
// that are designed to be concurrent (same reasoning as the `Send` impl).
unsafe impl Sync for Sg2002Tpu {}