use alloc::{collections::VecDeque, string::String, sync::Arc};
use core::{
ptr::NonNull,
sync::atomic::{AtomicBool, AtomicPtr, Ordering},
time::Duration,
};
use ax_kspin::SpinNoIrq;
use ax_task::WaitQueue;
use sg2002_tpu::{
ion::IonBuffer,
tpu::{
Sg2002Tpu,
error::TpuError,
types::{
CVITPU_DMABUF_FLUSH, CVITPU_DMABUF_FLUSH_FD, CVITPU_DMABUF_INVLD,
CVITPU_DMABUF_INVLD_FD, CVITPU_LOAD_TEE, CVITPU_PIO_MODE, CVITPU_SUBMIT_DMABUF,
CVITPU_SUBMIT_TEE, CVITPU_UNLOAD_TEE, CVITPU_WAIT_DMABUF, CviCacheOpArg,
CviSubmitDmaArg, CviWaitDmaArg,
},
},
};
use crate::{
file::{get_file_like, ion::IonBufferFile},
pseudofs::DeviceOps,
};
/// One unit of TPU work, moved from `TASK_LIST` to `DONE_LIST` by the worker.
struct TpuTask {
/// Id of the task that submitted the work; `wait_dmabuf` matches on it.
tid: u64,
/// Sequence number copied from the submit argument; second half of the match key.
seq_no: u32,
/// CPU address of the command buffer (`dma_info.cpu_addr`), passed to `run_one`.
vaddr: usize,
/// Bus address of the command buffer (`dma_info.bus_addr`), passed to `run_one`.
paddr: u64,
/// Never read; holding the `Arc` keeps the ION buffer referenced while the
/// task sits in either queue.
_buffer: Arc<IonBuffer>,
/// Result written by the worker: `0` when `run_one` succeeded, `-1` otherwise.
ret: i32,
}
/// Pending tasks: pushed by `submit_dmabuf`, popped front-first by `tpu_worker`.
static TASK_LIST: SpinNoIrq<VecDeque<TpuTask>> = SpinNoIrq::new(VecDeque::new());
/// Finished tasks: pushed by `tpu_worker`, removed by `wait_dmabuf` when the
/// (tid, seq_no) pair matches.
static DONE_LIST: SpinNoIrq<VecDeque<TpuTask>> = SpinNoIrq::new(VecDeque::new());
/// Maximum number of entries kept in `DONE_LIST`; the worker drops the oldest
/// entries once the list grows past this.
const DONE_LIST_MAX: usize = 64;
/// The worker blocks on this queue while `TASK_LIST` is empty; `submit_dmabuf`
/// notifies it after queuing a task.
static TASK_WQ: WaitQueue = WaitQueue::new();
/// `wait_dmabuf` blocks on this queue; the worker notifies it after every
/// finished task.
static DONE_WQ: WaitQueue = WaitQueue::new();
/// Notified by the TDMA interrupt handler; `tpu_wait_irq` blocks on it.
static IRQ_WQ: WaitQueue = WaitQueue::new();
/// Flipped to `true` by the first `TpuDevice::setup` call so that only one
/// worker thread is ever spawned.
static WORKER_SPAWNED: AtomicBool = AtomicBool::new(false);
/// Raw pointer to the hardware object owned by the worker thread. Null until
/// the first `TpuDevice::setup` call stores it; read by `tpu_wait_irq`.
static HW_PTR: AtomicPtr<Sg2002Tpu> = AtomicPtr::new(core::ptr::null_mut());
/// Device node for the SG2002 TPU; its `DeviceOps::ioctl` implementation
/// dispatches the `CVITPU_*` commands.
pub struct TpuDevice {
/// Shared handle to the hardware driver object.
hw: Arc<Sg2002Tpu>,
/// Whether `register_tpu_irq` succeeded during setup; only used to warn on
/// submit that execution may time out.
irq_registered: bool,
}
/// Interrupt number requested for the TPU TDMA engine.
// NOTE(review): 76 is presumably the platform IRQ id for TDMA on SG2002 — confirm against the SoC interrupt map.
const TPU_TDMA_IRQ: usize = 76;
/// How long `wait_dmabuf` waits for a matching entry in `DONE_LIST`.
const TPU_WAIT_TIMEOUT: Duration = Duration::from_secs(10);
/// Registers `tpu_tdma_irq_handler` on `TPU_TDMA_IRQ` and enables the line.
///
/// The handler receives a raw pointer to the `Sg2002Tpu` inside `hw`. Nothing
/// in this file ever unregisters the handler, so one strong reference is
/// deliberately leaked to keep that pointer valid even if every `TpuDevice`
/// sharing `hw` is dropped.
///
/// Returns `true` when the handler was registered, `false` (after logging a
/// warning) when the request was rejected.
fn register_tpu_irq(hw: &Arc<Sg2002Tpu>) -> bool {
    // Hand the interrupt handler its own strong reference.
    let raw = Arc::into_raw(Arc::clone(hw));
    // SAFETY: `Arc::into_raw` never yields a null pointer.
    let data = unsafe { NonNull::new_unchecked(raw as *mut ()) };
    if ax_runtime::hal::irq::request_shared_irq(TPU_TDMA_IRQ, tpu_tdma_irq_handler, data).is_err() {
        // The request was rejected, so the handler is presumed not installed
        // and the extra reference can be given back.
        // SAFETY: `raw` came from `Arc::into_raw` above and is reclaimed
        // exactly once.
        drop(unsafe { Arc::from_raw(raw) });
        warn!("[TPU] failed to register tdma irq {}", TPU_TDMA_IRQ);
        return false;
    }
    ax_runtime::hal::irq::set_enable(TPU_TDMA_IRQ, true);
    true
}
/// Interrupt handler for the TDMA line.
///
/// Lets the hardware object process the interrupt, logs a warning when
/// `handle_irq` returns `true`, wakes every waiter on `IRQ_WQ`, and always
/// reports the interrupt as handled.
///
/// # Safety
///
/// `data` must point to a live `Sg2002Tpu`, as arranged by `register_tpu_irq`.
unsafe fn tpu_tdma_irq_handler(
    _ctx: ax_runtime::hal::irq::IrqContext,
    data: NonNull<()>,
) -> ax_runtime::hal::irq::IrqReturn {
    // SAFETY: guaranteed by this function's contract (see `# Safety`).
    let tpu: &Sg2002Tpu = unsafe { data.cast::<Sg2002Tpu>().as_ref() };

    let reported_error = tpu.handle_irq();
    if reported_error {
        warn!("[TPU] tdma irq {} reports error status", TPU_TDMA_IRQ);
    }

    IRQ_WQ.notify_all(false);
    ax_runtime::hal::irq::IrqReturn::Handled
}
/// Wait callback installed on the hardware object via `set_wait_irq_fn`.
///
/// Blocks on `IRQ_WQ` until `irq_pending()` holds or `timeout_us` microseconds
/// elapse. Returns `false` immediately when `HW_PTR` has not been set yet;
/// otherwise returns the negation of `wait_timeout_until`'s result, which this
/// file treats as a "timed out" flag.
fn tpu_wait_irq(timeout_us: u64) -> bool {
    // SAFETY: a non-null `HW_PTR` points into the `Arc` owned by the worker
    // thread, whose loop never returns.
    match unsafe { HW_PTR.load(Ordering::Acquire).as_ref() } {
        None => false,
        Some(tpu) => {
            let limit = Duration::from_micros(timeout_us);
            let timed_out = IRQ_WQ.wait_timeout_until(limit, || tpu.irq_pending());
            !timed_out
        }
    }
}
/// Body of the `tpu-worker` thread; never returns.
///
/// Repeatedly takes the oldest task from `TASK_LIST` (sleeping on `TASK_WQ`
/// while the list is empty), runs it on the hardware, records `0` or `-1` in
/// `ret`, appends it to `DONE_LIST`, and wakes all waiters on `DONE_WQ`.
fn tpu_worker(hw: Arc<Sg2002Tpu>) {
    info!("[TPU] worker thread started");
    loop {
        let mut job = loop {
            // Bind the popped value first so the list lock is released before
            // a possible sleep (the wait predicate locks the list again).
            let next = TASK_LIST.lock().pop_front();
            match next {
                Some(job) => break job,
                None => TASK_WQ.wait_until(|| !TASK_LIST.lock().is_empty()),
            }
        };

        job.ret = match hw.run_one(job.seq_no, job.vaddr, job.paddr) {
            Ok(_) => 0,
            Err(_) => -1,
        };

        {
            let mut finished = DONE_LIST.lock();
            finished.push_back(job);
            // Results nobody collected must not accumulate without bound.
            while finished.len() > DONE_LIST_MAX {
                if let Some(stale) = finished.pop_front() {
                    warn!(
                        "[TPU] done list full, dropping orphaned result (tid={}, seq_no={})",
                        stale.tid, stale.seq_no
                    );
                }
            }
        }

        DONE_WQ.notify_all(false);
    }
}
impl TpuDevice {
pub unsafe fn new() -> Self {
let hw = Arc::new(unsafe { Sg2002Tpu::new() });
Self::setup(hw)
}
#[allow(dead_code)]
pub unsafe fn from_vaddr(tdma_vaddr: *mut u8, tiu_vaddr: *mut u8) -> Self {
let hw = Arc::new(unsafe { Sg2002Tpu::from_vaddr(tdma_vaddr, tiu_vaddr) });
Self::setup(hw)
}
fn setup(hw: Arc<Sg2002Tpu>) -> Self {
hw.set_wait_irq_fn(tpu_wait_irq);
if let Err(err) = hw.init() {
warn!("[TPU] init failed: {:?}", err);
}
let irq_registered = register_tpu_irq(&hw);
if WORKER_SPAWNED
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
HW_PTR.store(Arc::as_ptr(&hw) as *mut Sg2002Tpu, Ordering::Release);
let worker_hw = hw.clone();
ax_task::spawn_with_name(move || tpu_worker(worker_hw), String::from("tpu-worker"));
}
Self { hw, irq_registered }
}
fn submit_dmabuf(&self, arg: usize) -> Result<usize, TpuError> {
let submit_arg = unsafe { &*(arg as *const CviSubmitDmaArg) };
debug!(
"[TPU] submit dmabuf: fd={}, seq_no={}",
submit_arg.fd, submit_arg.seq_no
);
if !self.irq_registered {
warn!(
"[TPU] tdma irq {} not registered, execution may timeout",
TPU_TDMA_IRQ
);
}
let fd = submit_arg.fd;
let file = get_file_like(fd).map_err(|_| {
error!("[TPU] Failed to get file for fd={}", fd);
TpuError::InvalidDmabuf
})?;
let ion_file: Arc<IonBufferFile> = file.downcast_arc::<IonBufferFile>().map_err(|_| {
error!("[TPU] fd={} is not an IonBufferFile", fd);
TpuError::InvalidDmabuf
})?;
let buffer = ion_file.buffer().clone();
debug!(
"[TPU] dmabuf info: handle={}, size={}, paddr=0x{:x}",
buffer.handle.as_u32(),
buffer.size,
buffer.dma_info.bus_addr.as_u64()
);
let task = TpuTask {
tid: ax_task::current().id().as_u64(),
seq_no: submit_arg.seq_no,
vaddr: buffer.dma_info.cpu_addr.as_ptr() as usize,
paddr: buffer.dma_info.bus_addr.as_u64(),
_buffer: buffer,
ret: 0,
};
TASK_LIST.lock().push_back(task);
TASK_WQ.notify_one(true);
Ok(0)
}
fn wait_dmabuf(&self, arg: usize) -> Result<usize, TpuError> {
let wait_arg = unsafe { &mut *(arg as *mut CviWaitDmaArg) };
let seq_no = wait_arg.seq_no;
let tid = ax_task::current().id().as_u64();
let timed_out = DONE_WQ.wait_timeout_until(TPU_WAIT_TIMEOUT, || {
DONE_LIST
.lock()
.iter()
.any(|t| t.tid == tid && t.seq_no == seq_no)
});
let found = {
let mut done = DONE_LIST.lock();
done.iter()
.position(|t| t.tid == tid && t.seq_no == seq_no)
.map(|idx| done.remove(idx).unwrap())
};
match found {
Some(task) => {
wait_arg.ret = task.ret;
if task.ret != 0 {
return Err(TpuError::Timeout);
}
Ok(0)
}
None => {
wait_arg.ret = -1;
warn!(
"[TPU] wait dmabuf: (tid={}, seq_no={}) not found (timed_out={})",
tid, seq_no, timed_out
);
Err(TpuError::Timeout)
}
}
}
fn cache_flush(&self, arg: usize) -> Result<usize, TpuError> {
let flush_arg = unsafe { &*(arg as *const CviCacheOpArg) };
self.hw.cache_flush_paddr(flush_arg.paddr, flush_arg.size)?;
Ok(0)
}
fn cache_invalidate(&self, arg: usize) -> Result<usize, TpuError> {
let invalidate_arg = unsafe { &*(arg as *const CviCacheOpArg) };
self.hw
.cache_invalidate_paddr(invalidate_arg.paddr, invalidate_arg.size)?;
Ok(0)
}
fn dmabuf_flush_fd(&self, arg: usize) -> Result<usize, TpuError> {
let fd = arg as i32;
debug!("TPU dmabuf flush fd: {}", fd);
let buffer = self.lookup_ion_buffer(fd)?;
let paddr = buffer.dma_info.bus_addr.as_u64();
let size = buffer.size as u64;
self.hw.cache_flush_paddr(paddr, size)?;
debug!("Flushed buffer: paddr=0x{:x}, size={}", paddr, size);
Ok(0)
}
fn dmabuf_invld_fd(&self, arg: usize) -> Result<usize, TpuError> {
let fd = arg as i32;
debug!("TPU dmabuf invalidate fd: {}", fd);
let buffer = self.lookup_ion_buffer(fd)?;
let paddr = buffer.dma_info.bus_addr.as_u64();
let size = buffer.size as u64;
self.hw.cache_invalidate_paddr(paddr, size)?;
Ok(0)
}
fn lookup_ion_buffer(&self, fd: i32) -> Result<Arc<IonBuffer>, TpuError> {
let file = get_file_like(fd).map_err(|err| {
error!("[TPU] failed to get file for fd={}: {:?}", fd, err);
TpuError::InvalidDmabuf
})?;
let ion_file: Arc<IonBufferFile> = file.downcast_arc::<IonBufferFile>().map_err(|_| {
error!("[TPU] fd={} is not an IonBufferFile", fd);
TpuError::InvalidDmabuf
})?;
Ok(ion_file.buffer().clone())
}
}
impl DeviceOps for TpuDevice {
fn read_at(&self, _buf: &mut [u8], _offset: u64) -> axfs_ng_vfs::VfsResult<usize> {
Ok(0)
}
fn write_at(&self, _buf: &[u8], _offset: u64) -> axfs_ng_vfs::VfsResult<usize> {
Ok(0)
}
fn ioctl(&self, cmd: u32, arg: usize) -> axfs_ng_vfs::VfsResult<usize> {
debug!("TPU ioctl: cmd=0x{:x}, arg=0x{:x}", cmd, arg);
let result = match cmd {
CVITPU_SUBMIT_DMABUF => self.submit_dmabuf(arg),
CVITPU_DMABUF_FLUSH_FD => self.dmabuf_flush_fd(arg),
CVITPU_DMABUF_INVLD_FD => self.dmabuf_invld_fd(arg),
CVITPU_DMABUF_FLUSH => self.cache_flush(arg),
CVITPU_DMABUF_INVLD => self.cache_invalidate(arg),
CVITPU_WAIT_DMABUF => self.wait_dmabuf(arg),
CVITPU_PIO_MODE => {
warn!("TPU PIO mode not implemented");
Ok(0)
}
CVITPU_LOAD_TEE | CVITPU_SUBMIT_TEE | CVITPU_UNLOAD_TEE => {
warn!("TPU TEE operations not supported");
Err(TpuError::NotInitialized)
}
_ => {
warn!("Unknown TPU ioctl command: 0x{:x}", cmd);
Err(TpuError::NotInitialized)
}
};
match result {
Ok(v) => Ok(v),
Err(e) => {
error!("TPU ioctl error: {:?}", e);
Err(ax_errno::AxError::Unsupported)
}
}
}
fn as_any(&self) -> &dyn core::any::Any {
self
}
}