arceos-rust-interface 1.0.0

ArceOS ABI definitions
use alloc::{collections::BTreeMap, sync::Arc};
use core::time::Duration;

use ax_api::modules::{ax_hal::time::monotonic_time, ax_sync::Mutex, ax_task::WaitQueue};
use ax_errno::LinuxError;
use ax_posix_api::ctypes::{FUTEX_RELATIVE_TIMEOUT, timespec};
use log::{info, trace};

use crate::err;

/// Global map from the address of a futex word to the wait queue of threads parked on it.
///
/// Entries are created on demand by `sys_futex_wait` and are never removed by
/// `sys_futex_wait`/`sys_futex_wake`, so the table grows with the number of distinct
/// addresses that have ever been waited on.
// NOTE(review): consider reclaiming empty queues; any reclamation must not let a waiter park on
// a queue that a concurrent `sys_futex_wake` can no longer find (lost wakeup).
static FUTEX_TABLE: Mutex<BTreeMap<usize, Arc<WaitQueue>>> = Mutex::new(BTreeMap::new());

/// If the value at `address` matches `expected`, park the current thread until it is either
/// woken up with [`sys_futex_wake`] (returns 0) or an optional timeout elapses (returns
/// -ETIMEDOUT).
///
/// Setting `timeout` to null means the function will only return if [`sys_futex_wake`] is
/// called. Otherwise, `timeout` is interpreted as an absolute time on the monotonic clock
/// (`CLOCK_MONOTONIC`, as reported by `monotonic_time`). If [`FUTEX_RELATIVE_TIMEOUT`] is set
/// in `flags` the timeout is understood to be relative to the current time.
///
/// # Errors
///
/// - `-EINVAL` if `address` is null or not aligned for a 32-bit atomic, if the timeout is
///   negative or its `tv_nsec` is outside `0..1_000_000_000`, or if `flags` contains unknown
///   values.
/// - `-EAGAIN` if the value at `address` is not `expected`.
/// - `-ETIMEDOUT` if the timeout elapses before a wakeup (or an absolute deadline has already
///   passed).
#[unsafe(no_mangle)]
pub fn sys_futex_wait(
    address: *mut u32,
    expected: u32,
    timeout: *const timespec,
    flags: u32,
) -> i32 {
    use core::sync::atomic::{AtomicU32, Ordering};

    if flags & !FUTEX_RELATIVE_TIMEOUT != 0 {
        // other flags are not supported
        return err(LinuxError::EINVAL);
    }
    // A futex word must be a non-null, naturally aligned 32-bit word; anything else cannot be
    // accessed atomically, and dereferencing a misaligned pointer would be UB.
    if address.is_null() || !address.cast::<AtomicU32>().is_aligned() {
        return err(LinuxError::EINVAL);
    }
    // Validate the timeout up front: negative fields or an out-of-range `tv_nsec` are rejected
    // instead of being silently clamped, which also keeps `Duration::new` from overflowing.
    // SAFETY: the caller passes either null or a pointer to a valid `timespec`.
    let timeout = match unsafe { timeout.as_ref() } {
        None => None,
        Some(ts) => {
            trace!("called sys_futex_wait with timeout: {:?}", ts);
            let (Ok(secs), Ok(nanos)) = (u64::try_from(ts.tv_sec), u32::try_from(ts.tv_nsec))
            else {
                return err(LinuxError::EINVAL);
            };
            if nanos >= 1_000_000_000 {
                return err(LinuxError::EINVAL);
            }
            Some(Duration::new(secs, nanos))
        }
    };
    // Other threads modify the futex word concurrently, so it must be read atomically rather
    // than through a plain `&u32` (which would be a data race).
    // SAFETY: `address` is non-null and suitably aligned (checked above); the caller guarantees
    // it points to a live futex word for the duration of this call.
    let value = unsafe { AtomicU32::from_ptr(address) }.load(Ordering::SeqCst);
    if value != expected {
        return err(LinuxError::EAGAIN);
    }
    let wait_queue = {
        let mut table = FUTEX_TABLE.lock();
        table
            .entry(address as usize)
            .or_insert_with(|| Arc::new(WaitQueue::new()))
            .clone()
    };
    trace!(
        "futex wait on address {:p} with expected value {}",
        address, expected
    );
    // NOTE(review): the value check above and the enqueue inside `wait`/`wait_timeout` are not
    // atomic with respect to `sys_futex_wake`. If another thread updates the word and calls
    // `sys_futex_wake` in between, that wakeup is lost and this thread sleeps until its timeout
    // (or forever). Closing the window requires re-checking the word under the wait queue's
    // lock; TODO confirm which ax_task `WaitQueue` primitive supports that.
    match timeout {
        Some(timeout) => {
            let duration = if flags & FUTEX_RELATIVE_TIMEOUT != 0 {
                timeout
            } else {
                // Absolute deadline: convert it to the time remaining; a deadline that has
                // already passed times out immediately.
                let Some(remaining) = timeout.checked_sub(monotonic_time()) else {
                    return err(LinuxError::ETIMEDOUT);
                };
                remaining
            };
            // `wait_timeout` returns true when the timeout elapsed before a wakeup.
            if wait_queue.wait_timeout(duration) {
                return err(LinuxError::ETIMEDOUT);
            }
        }
        None => {
            trace!("called sys_futex_wait without timeout");
            wait_queue.wait();
        }
    }
    0
}

/// Wake up to `count` threads waiting on the futex at `address`.
///
/// Returns the number of threads actually woken. This is never more than `count`, so it
/// always fits in an `i32`. Passing `i32::MAX` wakes every thread currently waiting on
/// `address`.
///
/// # Errors
///
/// Returns `-EINVAL` if `count` is negative or `address` is null.
#[unsafe(no_mangle)]
pub fn sys_futex_wake(address: *mut u32, count: i32) -> i32 {
    info!(
        "called sys_futex_wake with address {:p} and count {}",
        address, count
    );
    if count < 0 || address.is_null() {
        return err(LinuxError::EINVAL);
    }
    // Clone the queue handle and drop the table lock before notifying, so the wake loop does
    // not run with the global table locked.
    let Some(wait_queue) = FUTEX_TABLE.lock().get(&(address as usize)).cloned() else {
        // No wait queue exists for this address, so there is nobody to wake.
        return 0;
    };
    // `notify_one` reports whether a waiter was woken; stop at the first miss.
    let woken = (0..count)
        .take_while(|_| wait_queue.notify_one(true))
        .count();
    trace!("futex woke {} threads", woken);
    // `woken <= count <= i32::MAX`, so the conversion cannot actually saturate.
    i32::try_from(woken).unwrap_or(i32::MAX)
}