use alloc::{collections::VecDeque, string::ToString};
use core::{
sync::atomic::{AtomicBool, AtomicU64, Ordering},
time::Duration,
};
use ax_driver::prelude::*;
use ax_errno::{AxError, AxResult, ax_bail};
use ax_sync::Mutex;
use ax_task::future::{block_on, interruptible};
use crate::vsock::connection_manager::VSOCK_CONN_MANAGER;
/// The registered vsock device; `None` until `register_vsock_device` stores one.
static VSOCK_DEVICE: Mutex<Option<AxVsockDevice>> = Mutex::new(None);
/// `Received` events that were deferred because the target connection's receive
/// buffer had no free space. They are taken and replayed at the start of the
/// next `poll_vsock_interfaces` pass.
static PENDING_EVENTS: Mutex<VecDeque<VsockDriverEvent>> = Mutex::new(VecDeque::new());
/// Size in bytes of the temporary buffer allocated on each poll pass and used to
/// copy received data from the device to the connection manager.
const VSOCK_RX_TMPBUF_SIZE: usize = 0x1000;
/// Installs `dev` as the global vsock device.
///
/// # Errors
///
/// Bails with `AlreadyExists` when a device is already registered; the
/// previously registered device is left untouched.
pub fn register_vsock_device(dev: AxVsockDevice) -> AxResult {
    let mut slot = VSOCK_DEVICE.lock();
    if slot.is_none() {
        *slot = Some(dev);
        return Ok(());
    }
    ax_bail!(AlreadyExists, "vsock device already registered");
}
/// Number of outstanding `start_vsock_poll` calls not yet balanced by
/// `stop_vsock_poll`. The poll task exits when it observes zero.
static POLL_REF_COUNT: Mutex<usize> = Mutex::new(0);
/// Set by `start_vsock_poll` just before it spawns the poll task and cleared by
/// the task itself when it exits.
static POLL_TASK_RUNNING: AtomicBool = AtomicBool::new(false);
/// Tracks consecutive idle polls and derives the sleep interval between polls.
static POLL_FREQUENCY: PollFrequencyController = PollFrequencyController::new();
/// Adaptive back-off for the vsock poll task.
///
/// Counts how many polls in a row found no device events and maps that count
/// to a sleep interval: the longer the device stays quiet, the longer the wait.
struct PollFrequencyController {
    /// Number of consecutive polls that produced no events.
    consecutive_idle: AtomicU64,
}

impl PollFrequencyController {
    /// Creates a controller with an idle streak of zero.
    const fn new() -> Self {
        Self {
            consecutive_idle: AtomicU64::new(0),
        }
    }

    /// Maps an idle streak to the sleep interval in microseconds.
    fn interval_micros(idle: u64) -> u64 {
        if idle <= 3 {
            100
        } else if idle <= 10 {
            500
        } else if idle <= 20 {
            2_000
        } else {
            10_000
        }
    }

    /// Returns the sleep interval for the current idle streak.
    fn current_interval(&self) -> Duration {
        let idle = self.consecutive_idle.load(Ordering::Relaxed);
        Duration::from_micros(Self::interval_micros(idle))
    }

    /// Records that a poll produced events, resetting the idle streak.
    fn on_event(&self) {
        self.consecutive_idle.store(0, Ordering::Release);
    }

    /// Records that a poll produced no events, extending the idle streak by one.
    fn on_idle(&self) {
        self.consecutive_idle.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns `(idle streak, current interval in microseconds)`.
    fn stats(&self) -> (u64, u64) {
        let idle = self.consecutive_idle.load(Ordering::Relaxed);
        (idle, Self::interval_micros(idle))
    }
}
/// Takes one reference on the vsock poll task, spawning it on the 0 -> 1
/// transition.
///
/// The ref-count lock is held while `POLL_TASK_RUNNING` is tested-and-set and
/// is released before the task is spawned. If the flag is already set on the
/// 0 -> 1 transition, no new task is spawned and a warning is logged.
pub fn start_vsock_poll() {
    let mut refs = POLL_REF_COUNT.lock();
    *refs += 1;
    debug!("start_vsock_poll: ref_count -> {}", *refs);

    // Only the first reference is responsible for starting the task.
    if *refs != 1 {
        return;
    }
    if POLL_TASK_RUNNING.swap(true, Ordering::SeqCst) {
        warn!("Poll task already running!");
        return;
    }

    // Release the counter before spawning so the new task can read it.
    drop(refs);
    debug!("Starting vsock poll task");
    ax_task::spawn_with_name(vsock_poll_loop, "vsock-poll".to_string());
}
/// Releases one reference on the vsock poll task.
///
/// Only the counter is updated here; the poll task notices a zero count on its
/// own and exits. An unbalanced call (count already zero) logs a warning and
/// leaves the counter unchanged.
pub fn stop_vsock_poll() {
    let mut refs = POLL_REF_COUNT.lock();
    match refs.checked_sub(1) {
        None => warn!("stop_vsock_poll called but ref_count already 0"),
        Some(new_count) => {
            *refs = new_count;
            debug!("stop_vsock_poll: ref_count -> {new_count}");
        }
    }
}
/// Body of the vsock poll task: polls the device until no references remain.
///
/// Each iteration first checks `POLL_REF_COUNT`; when it is zero the task marks
/// itself as not running and exits. Otherwise it runs one adaptive poll pass
/// (poll + sleep), ignoring its result so that a failed or interrupted pass
/// does not terminate the task.
fn vsock_poll_loop() {
    loop {
        {
            // The "running" flag must be cleared while the ref-count lock is
            // still held. `start_vsock_poll` tests-and-sets the flag under the
            // same lock, so it either sees the flag already cleared and spawns
            // a fresh task, or bumps the count before this check and keeps
            // this task alive. Releasing the lock before clearing the flag
            // would let a concurrent start skip spawning while this task
            // exits, leaving nothing to poll.
            let ref_count = POLL_REF_COUNT.lock();
            if *ref_count == 0 {
                POLL_TASK_RUNNING.store(false, Ordering::SeqCst);
                drop(ref_count);
                debug!("Vsock poll task exiting (no active connections)");
                break;
            }
        }
        let _ = block_on(interruptible(poll_interfaces_adaptive()));
    }
}
/// Runs one poll pass and then sleeps for the adaptively chosen interval.
///
/// A pass that handled device events resets the idle streak; a pass that found
/// nothing, or that failed, extends it.
///
/// # Errors
///
/// Returns the error from `poll_vsock_interfaces`, but only after sleeping.
/// The caller loops unconditionally, so returning before the sleep would make
/// a persistent failure (for example, no device registered) spin the poll task
/// without ever yielding.
async fn poll_interfaces_adaptive() -> AxResult<()> {
    let outcome = poll_vsock_interfaces();
    if matches!(outcome, Ok(true)) {
        POLL_FREQUENCY.on_event();
    } else {
        // Failures are treated as idle so that repeated errors back off too.
        POLL_FREQUENCY.on_idle();
    }

    let interval = POLL_FREQUENCY.current_interval();
    let (idle_count, interval_us) = POLL_FREQUENCY.stats();
    // Log the back-off state once every ten idle polls.
    if idle_count > 0 && idle_count % 10 == 0 {
        trace!("Poll frequency: idle_count={idle_count}, interval={interval_us}μs",);
    }

    ax_task::future::sleep(interval).await;
    outcome.map(|_| ())
}
/// Performs one poll pass over the vsock device.
///
/// Deferred events from `PENDING_EVENTS` are replayed first, then the device is
/// drained until it reports no more events or a poll error occurs. The device
/// lock is held for the whole pass.
///
/// Returns `Ok(true)` if the device itself produced at least one event during
/// this pass; replayed deferred events do not count.
///
/// # Errors
///
/// Returns `AxError::NotFound` if no device is registered. A device poll error
/// is logged and ends the pass, but is not propagated.
fn poll_vsock_interfaces() -> AxResult<bool> {
    let mut device = VSOCK_DEVICE.lock();
    let dev = match device.as_mut() {
        Some(dev) => dev,
        None => return Err(AxError::NotFound),
    };
    let mut scratch = alloc::vec![0; VSOCK_RX_TMPBUF_SIZE];

    // Take the deferred queue in one step so that events re-deferred while
    // replaying land in the fresh queue and are not revisited in this pass.
    let deferred = core::mem::take(&mut *PENDING_EVENTS.lock());
    for event in deferred {
        handle_vsock_event(event, dev, &mut scratch);
    }

    let mut saw_device_event = false;
    loop {
        let event = match dev.poll_event() {
            Ok(Some(event)) => event,
            Ok(None) => break,
            Err(e) => {
                info!("Failed to poll vsock event: {e:?}");
                break;
            }
        };
        saw_device_event = true;
        handle_vsock_event(event, dev, &mut scratch);
    }
    Ok(saw_device_event)
}
/// Dispatches a single driver event to the connection manager.
///
/// The connection-manager lock is held for the whole call. Handler failures
/// are logged and otherwise ignored.
///
/// * `event` - the driver event to process.
/// * `dev` - the vsock device, used to pull payload bytes for `Received`.
/// * `buf` - scratch space for received payload; at most `buf.len()` bytes are
///   read per event.
fn handle_vsock_event(event: VsockDriverEvent, dev: &mut AxVsockDevice, buf: &mut [u8]) {
    let mut manager = VSOCK_CONN_MANAGER.lock();
    debug!("Handling vsock event: {event:?}");
    match event {
        VsockDriverEvent::ConnectionRequest(conn_id) => {
            if let Err(e) = manager.on_connection_request(conn_id) {
                info!("Connection request failed: {conn_id:?}, error={e:?}");
            }
        }
        VsockDriverEvent::Connected(conn_id) => {
            if let Err(e) = manager.on_connected(conn_id) {
                info!("Failed to handle connection established: {conn_id:?}, error={e:?}",);
            }
        }
        VsockDriverEvent::Disconnected(conn_id) => {
            if let Err(e) = manager.on_disconnected(conn_id) {
                info!("Failed to handle disconnection: {conn_id:?}, error={e:?}",);
            }
        }
        VsockDriverEvent::CreditUpdate(conn_id) => {
            if let Err(e) = manager.on_credit_update(conn_id) {
                warn!(
                    "Failed to handle credit update: {:?}, error={:?}",
                    conn_id, e
                );
            }
        }
        VsockDriverEvent::Received(conn_id, len) => {
            let free_space = match manager.get_connection(conn_id) {
                Some(conn) => conn.lock().rx_buffer_free(),
                None => {
                    info!("Received data for unknown connection: {conn_id:?}");
                    return;
                }
            };

            // No room at all: defer the event so it is retried on a later poll.
            if free_space == 0 {
                PENDING_EVENTS
                    .lock()
                    .push_back(VsockDriverEvent::Received(conn_id, len));
                return;
            }

            // NOTE(review): `len` is not used to size the read. If the device
            // still holds data after a read shorter than `len`, nothing here
            // re-queues the remainder — confirm against the driver's `recv`
            // semantics.
            let max_read = buf.len().min(free_space);
            let read_len = match dev.recv(conn_id, &mut buf[..max_read]) {
                Ok(read_len) => read_len,
                Err(e) => {
                    info!("Failed to receive vsock data: conn_id={conn_id:?}, error={e:?}",);
                    return;
                }
            };
            if let Err(e) = manager.on_data_received(conn_id, &buf[..read_len]) {
                info!("Failed to handle received data: conn_id={conn_id:?}, error={e:?}",);
            }
        }
        VsockDriverEvent::Unknown => warn!("Received unknown vsock event"),
    }
}
/// Asks the registered vsock device to listen on `addr.port`.
///
/// # Errors
///
/// Returns `AxError::NotFound` if no device is registered.
pub fn vsock_listen(addr: VsockAddr) -> AxResult<()> {
    match VSOCK_DEVICE.lock().as_mut() {
        None => Err(AxError::NotFound),
        Some(dev) => {
            dev.listen(addr.port);
            Ok(())
        }
    }
}
/// Translates a driver-level error into the corresponding `AxError`.
///
/// Any driver error without an explicit mapping below is reported as
/// `AxError::BadState`.
fn map_dev_err(err: DevError) -> AxError {
    match err {
        DevError::Again => AxError::WouldBlock,
        DevError::AlreadyExists => AxError::AlreadyExists,
        DevError::InvalidParam => AxError::InvalidInput,
        DevError::Io => AxError::Io,
        _ => AxError::BadState,
    }
}
/// Starts a connection for `conn_id` on the registered vsock device.
///
/// # Errors
///
/// Returns `AxError::NotFound` if no device is registered; otherwise any
/// driver error is translated through `map_dev_err`.
pub fn vsock_connect(conn_id: VsockConnId) -> AxResult<()> {
    match VSOCK_DEVICE.lock().as_mut() {
        Some(dev) => dev.connect(conn_id).map_err(map_dev_err),
        None => Err(AxError::NotFound),
    }
}
/// Sends `buf` on connection `conn_id`, retrying while the device reports
/// `DevError::Again`.
///
/// The device lock is held only for each individual send attempt. After an
/// `Again`, the connection (if the manager knows it) is asked to
/// `wait_for_tx()` before the next attempt.
///
/// Returns the length reported by the device on success.
///
/// # Errors
///
/// * `AxError::NotFound` if no device is registered.
/// * The mapped form of `DevError::Again` once all attempts are used up.
/// * Any other driver error, translated through `map_dev_err`, immediately.
pub fn vsock_send(conn_id: VsockConnId, buf: &[u8]) -> AxResult<usize> {
    /// Upper bound on send attempts before giving up.
    const MAX_RETRIES: usize = 10;

    for _attempt in 0..MAX_RETRIES {
        // The device guard is a temporary of this statement, so the device is
        // unlocked again before any waiting happens below.
        let outcome = match VSOCK_DEVICE.lock().as_mut() {
            Some(dev) => dev.send(conn_id, buf),
            None => return Err(AxError::NotFound),
        };

        match outcome {
            Ok(sent) => return Ok(sent),
            Err(DevError::Again) => {
                // Look the connection up and release the manager lock before
                // waiting on the connection itself.
                let conn = VSOCK_CONN_MANAGER.lock().get_connection(conn_id);
                if let Some(conn) = conn {
                    conn.lock().wait_for_tx();
                }
            }
            Err(other) => return Err(map_dev_err(other)),
        }
    }
    Err(map_dev_err(DevError::Again))
}
/// Disconnects `conn_id` on the registered vsock device.
///
/// # Errors
///
/// Returns `AxError::NotFound` if no device is registered; otherwise any
/// driver error is translated through `map_dev_err`.
pub fn vsock_disconnect(conn_id: VsockConnId) -> AxResult<()> {
    match VSOCK_DEVICE.lock().as_mut() {
        Some(dev) => dev.disconnect(conn_id).map_err(map_dev_err),
        None => Err(AxError::NotFound),
    }
}
/// Returns the guest CID reported by the registered vsock device.
///
/// # Errors
///
/// Returns `AxError::NotFound` if no device is registered.
pub fn vsock_guest_cid() -> AxResult<u64> {
    VSOCK_DEVICE
        .lock()
        .as_mut()
        .map(|dev| dev.guest_cid())
        .ok_or(AxError::NotFound)
}