use alloc::{collections::BTreeMap, sync::Arc};
use core::time::Duration;
use ax_api::modules::{ax_hal::time::monotonic_time, ax_sync::Mutex, ax_task::WaitQueue};
use ax_errno::LinuxError;
use ax_posix_api::ctypes::{FUTEX_RELATIVE_TIMEOUT, timespec};
use log::{info, trace};
use crate::err;
/// Futex wait queues keyed by the address of the futex word (`address as usize`).
///
/// `sys_futex_wait` inserts a queue for an address on first use, and
/// `sys_futex_wake` looks it up to notify the threads blocked on it.
static FUTEX_TABLE: Mutex<BTreeMap<usize, Arc<WaitQueue>>> = Mutex::new(BTreeMap::new());
#[unsafe(no_mangle)]
pub fn sys_futex_wait(
address: *mut u32,
expected: u32,
timeout: *const timespec,
flags: u32,
) -> i32 {
if flags & !FUTEX_RELATIVE_TIMEOUT != 0 {
return err(LinuxError::EINVAL);
}
let Some(value) = (unsafe { address.as_ref() }) else {
return err(LinuxError::EINVAL);
};
if *value != expected {
return err(LinuxError::EAGAIN);
}
let wait_queue = {
let mut table = FUTEX_TABLE.lock();
table
.entry(address as usize)
.or_insert_with(|| Arc::new(WaitQueue::new()))
.clone()
};
trace!(
"futex wait on address {:p} with expected value {}",
address, expected
);
if let Some(timeout) = unsafe { timeout.as_ref() } {
trace!("called sys_futex_wait with timeout: {:?}", timeout);
let timeout = Duration::new(
timeout.tv_sec.saturating_cast_unsigned(),
timeout.tv_nsec.saturating_cast_unsigned(),
);
let duration = if flags & FUTEX_RELATIVE_TIMEOUT != 0 {
timeout
} else {
let now = monotonic_time();
let Some(duration) = timeout.checked_sub(now) else {
return err(LinuxError::ETIMEDOUT);
};
duration
};
if wait_queue.wait_timeout(duration) {
return err(LinuxError::ETIMEDOUT);
}
} else {
trace!("called sys_futex_wait without timeout");
wait_queue.wait();
}
0
}
/// Wakes up to `count` threads blocked in `sys_futex_wait` on `address`.
///
/// Returns the number of threads actually woken. This is `0` when no wait
/// queue has been registered for `address`. A negative `count` is rejected
/// with `err(LinuxError::EINVAL)`.
#[unsafe(no_mangle)]
pub fn sys_futex_wake(address: *mut u32, count: i32) -> i32 {
    info!(
        "called sys_futex_wake with address {:p} and count {}",
        address, count
    );
    if count < 0 {
        return err(LinuxError::EINVAL);
    }
    // Clone the queue handle out of the table. The temporary lock guard is
    // released at the end of this statement, so no tasks are woken while the
    // table is locked.
    let registered = FUTEX_TABLE.lock().get(&(address as usize)).cloned();
    let Some(wait_queue) = registered else {
        return 0;
    };
    // Notify one waiter at a time and stop at the first `notify_one` that
    // finds nobody to wake.
    let woken = (0..count)
        .take_while(|_| wait_queue.notify_one(true))
        .count();
    trace!("futex woke {} threads", woken);
    // `woken <= count` and `count` is a non-negative `i32`, so this cannot truncate.
    woken as i32
}