use alloc::{
collections::{BTreeMap, btree_map::Entry},
format,
vec::Vec,
};
use ax_errno::AxResult;
use ax_kspin::SpinNoIrq as Mutex;
use ax_page_table_multiarch::PagingHandler;
use crate::{GuestPhysAddr, HostPhysAddr};
/// Channel type stored in the global registry, specialised to the host paging handler.
type HostIVCChannel = IVCChannel<crate::HostPagingHandler>;

/// Upper bound, in bytes, on a channel's shared region. Larger requests are clamped
/// by [`IVCChannel::alloc`].
pub const MAX_IVC_CHANNEL_SIZE: usize = 0x1000;

/// Global registry of live channels, keyed by `(publisher VM id, channel key)`.
static IVC_CHANNELS: Mutex<BTreeMap<(usize, usize), HostIVCChannel>> = Mutex::new(BTreeMap::new());
/// Registers `channel` under `(publisher_vm_id, channel.key)`.
///
/// # Errors
/// `AlreadyExists` if a channel with the same publisher/key pair is already registered.
/// In that case `channel` is dropped, which releases its shared frame.
pub fn insert_channel(publisher_vm_id: usize, channel: HostIVCChannel) -> AxResult<()> {
    let mut registry = IVC_CHANNELS.lock();
    if let Entry::Vacant(slot) = registry.entry((publisher_vm_id, channel.key)) {
        slot.insert(channel);
        return Ok(());
    }
    Err(ax_errno::ax_err_type!(
        AlreadyExists,
        "IVC channel already exists"
    ))
}
/// Checks that no channel is registered for `(publisher_vm_id, key)`.
///
/// The registry lock is held only for the lookup.
///
/// # Errors
/// `AlreadyExists` if such a channel is present.
pub fn ensure_channel_absent(publisher_vm_id: usize, key: usize) -> AxResult<()> {
    let already_registered = IVC_CHANNELS.lock().contains_key(&(publisher_vm_id, key));
    if !already_registered {
        return Ok(());
    }
    Err(ax_errno::ax_err_type!(
        AlreadyExists,
        format!(
            "IVC channel for publisher VM {} with key {} already exists",
            publisher_vm_id, key
        )
    ))
}
/// Withdraws a channel from its publisher.
///
/// If subscribers are still attached, the channel stays registered but is marked
/// unpublished. Otherwise it is removed from the registry, which drops it.
///
/// Returns the publisher-side base GPA and the region size, so the caller can undo
/// the publisher's mapping.
///
/// # Errors
/// `NotFound` if the channel does not exist or was already unpublished.
pub fn unpublish_channel(publisher_vm_id: usize, key: usize) -> AxResult<(GuestPhysAddr, usize)> {
    let lookup = (publisher_vm_id, key);
    let mut registry = IVC_CHANNELS.lock();

    let Some(chan) = registry.get_mut(&lookup) else {
        return Err(ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM {} with key {} not found",
                publisher_vm_id, key
            )
        ));
    };

    let Some(publisher_gpa) = chan.base_gpa_in_publisher() else {
        return Err(ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM {} with key {} has no base GPA, it may have been \
                 marked as unpublished",
                publisher_vm_id, key
            )
        ));
    };

    let region_len = chan.size();
    if !chan.has_subscribers() {
        // Nobody else references the region, so it can be reclaimed right away.
        registry.remove(&lookup);
    } else {
        // Keep it alive for the remaining subscribers; the last unsubscribe reclaims it.
        chan.mark_unpublished();
    }
    Ok((publisher_gpa, region_len))
}
/// Validates that `subscriber_vm_id` may subscribe to `(publisher_vm_id, key)`.
///
/// Does not modify the registry.
///
/// Returns the channel's region size, so the caller can reserve guest address space
/// before calling `subscribe_to_channel_of_publisher`.
///
/// # Errors
/// - `NotFound` if the channel is missing or has been unpublished.
/// - `AlreadyExists` if the VM is already a subscriber.
pub fn prepare_subscribe_channel(
    publisher_vm_id: usize,
    key: usize,
    subscriber_vm_id: usize,
) -> AxResult<usize> {
    let registry = IVC_CHANNELS.lock();
    let Some(chan) = registry.get(&(publisher_vm_id, key)) else {
        return Err(ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM {} with key {} not found",
                publisher_vm_id, key
            )
        ));
    };

    match (chan.is_unpublished(), chan.has_subscriber(subscriber_vm_id)) {
        (true, _) => Err(ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM {} with key {} has been unpublished",
                publisher_vm_id, key
            )
        )),
        (false, true) => Err(ax_errno::ax_err_type!(
            AlreadyExists,
            format!(
                "VM[{}] has already subscribed to publisher VM[{}] Key {:#x}",
                subscriber_vm_id, publisher_vm_id, key
            )
        )),
        (false, false) => Ok(chan.size()),
    }
}
/// Records `subscriber_vm_id` as a subscriber of `(publisher_vm_id, key)`.
///
/// `subscriber_gpa` is the guest physical address supplied by the subscriber for
/// this channel.
///
/// Returns the host physical base address and the size of the shared region.
///
/// # Errors
/// - `NotFound` if the channel is missing or has been unpublished.
/// - `AlreadyExists` if the VM is already a subscriber.
pub fn subscribe_to_channel_of_publisher(
    publisher_vm_id: usize,
    key: usize,
    subscriber_vm_id: usize,
    subscriber_gpa: GuestPhysAddr,
) -> AxResult<(HostPhysAddr, usize)> {
    let mut registry = IVC_CHANNELS.lock();
    let Some(chan) = registry.get_mut(&(publisher_vm_id, key)) else {
        return Err(ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM [{}] key {:#x} not found",
                publisher_vm_id, key
            )
        ));
    };

    match (chan.is_unpublished(), chan.has_subscriber(subscriber_vm_id)) {
        (true, _) => Err(ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM [{}] key {:#x} has been unpublished",
                publisher_vm_id, key
            )
        )),
        (false, true) => Err(ax_errno::ax_err_type!(
            AlreadyExists,
            format!(
                "VM[{}] has already subscribed to publisher VM[{}] Key {:#x}",
                subscriber_vm_id, publisher_vm_id, key
            )
        )),
        (false, false) => {
            chan.add_subscriber(subscriber_vm_id, subscriber_gpa);
            Ok((chan.base_hpa(), chan.size()))
        }
    }
}
/// Removes `subscriber_vm_id` from the subscribers of `(publisher_vm_id, key)`.
///
/// If the channel has already been unpublished and this was its last subscriber,
/// the channel is removed from the registry, which drops it and frees its frame.
///
/// Returns the GPA the subscriber registered with and the region size, so the
/// caller can undo the subscriber's mapping.
///
/// # Errors
/// `NotFound` if the channel does not exist or the VM is not subscribed to it.
pub fn unsubscribe_from_channel_of_publisher(
    publisher_vm_id: usize,
    key: usize,
    subscriber_vm_id: usize,
) -> AxResult<(GuestPhysAddr, usize)> {
    let mut channels = IVC_CHANNELS.lock();
    let channel_key = (publisher_vm_id, key);

    // A single lookup serves both the unsubscribe and the reclaim decision below.
    let channel = channels.get_mut(&channel_key).ok_or_else(|| {
        ax_errno::ax_err_type!(
            NotFound,
            format!(
                "IVC channel for publisher VM {} with key {} not found",
                publisher_vm_id, key
            )
        )
    })?;

    let subscriber_gpa = channel.remove_subscriber(subscriber_vm_id).ok_or_else(|| {
        ax_errno::ax_err_type!(
            NotFound,
            format!(
                "VM[{}] tries to unsubscribe non-existed channel publisher VM[{}] Key {:#x}",
                subscriber_vm_id, publisher_vm_id, key
            )
        )
    })?;
    let size = channel.size();

    // An unpublished channel is kept alive only for its subscribers; once the last
    // one leaves, reclaim it. `has_subscribers` avoids building a Vec just to test
    // emptiness.
    if !channel.has_subscribers() && channel.is_unpublished() {
        channels.remove(&channel_key);
    }
    Ok((subscriber_gpa, size))
}
/// A shared-memory channel published by one VM and subscribed to by zero or more VMs.
///
/// The backing storage is a single host frame obtained from `H::alloc_frame` in
/// `IVCChannel::alloc` and returned with `H::dealloc_frame` when the channel is dropped.
pub struct IVCChannel<H: PagingHandler> {
    /// ID of the VM that published this channel.
    publisher_vm_id: usize,
    /// Publisher-chosen key; with `publisher_vm_id` it forms the registry key.
    key: usize,
    /// Subscriber VM id -> GPA that subscriber supplied when subscribing.
    subscriber_vms: BTreeMap<usize, GuestPhysAddr>,
    /// Host physical address of the shared frame.
    shared_region_base: HostPhysAddr,
    /// Requested size of the shared region, clamped to `MAX_IVC_CHANNEL_SIZE`.
    shared_region_size: usize,
    /// Publisher-side base GPA; `None` once the publisher has unpublished the channel.
    base_gpa: Option<GuestPhysAddr>,
    /// Ties the channel to its paging-handler type without storing a value of it.
    _phantom: core::marker::PhantomData<H>,
}
/// Fixed header written at the very start of each channel's shared region.
///
/// `#[repr(C)]` keeps the field order and layout stable. NOTE(review): presumably
/// guests read this from their mapping of the region; confirm against the guest-side
/// driver before changing it.
#[repr(C)]
pub struct IVCChannelHeader {
    /// Publisher VM id, filled in by `IVCChannel::alloc`.
    pub publisher_id: u64,
    /// Channel key, filled in by `IVCChannel::alloc`.
    pub key: u64,
}
impl<H: PagingHandler> IVCChannel<H> {
    /// Returns a shared reference to the header at the start of the shared region.
    // Kept as the read-only counterpart of `header_mut`, though nothing here calls it yet.
    #[allow(unused)]
    pub fn header(&self) -> &IVCChannelHeader {
        let header_ptr =
            H::phys_to_virt(self.shared_region_base).as_mut_ptr_of::<IVCChannelHeader>();
        // SAFETY: `shared_region_base` is a frame owned by this channel for its whole
        // lifetime, mapped at `phys_to_virt`, and large enough to hold the header.
        unsafe { &*header_ptr }
    }

    /// Returns a mutable reference to the header at the start of the shared region.
    pub fn header_mut(&mut self) -> &mut IVCChannelHeader {
        let header_ptr =
            H::phys_to_virt(self.shared_region_base).as_mut_ptr_of::<IVCChannelHeader>();
        // SAFETY: as in `header`; `&mut self` prevents other host-side references
        // to the header through this channel.
        unsafe { &mut *header_ptr }
    }

    /// Returns a pointer to the first byte after the header in the shared region.
    // Kept for consumers of the payload area, though nothing here calls it yet.
    #[allow(unused)]
    pub fn data_region(&self) -> *const u8 {
        let region_start = H::phys_to_virt(self.shared_region_base).as_mut_ptr();
        // SAFETY: the header is smaller than the frame, so the offset stays inside
        // the same allocation.
        let payload_start = unsafe { region_start.add(core::mem::size_of::<IVCChannelHeader>()) };
        payload_start
    }
}
impl<H: PagingHandler> core::fmt::Debug for IVCChannel<H> {
    /// Renders the publisher id, subscriber map, host base, size and publisher GPA.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            publisher_vm_id,
            subscriber_vms,
            shared_region_base,
            shared_region_size,
            base_gpa,
            ..
        } = self;
        write!(
            f,
            "IVCChannel(publisher[{}], subscribers {:?}, base: {:?}, size: {:#x}, gpa: {:?})",
            publisher_vm_id, subscriber_vms, shared_region_base, shared_region_size, base_gpa
        )
    }
}
impl<H: PagingHandler> Drop for IVCChannel<H> {
    /// Hands the shared frame back to the paging handler's allocator.
    fn drop(&mut self) {
        let frame = self.shared_region_base;
        debug!(
            "Dropping IVCChannel for VM[{}], shared region base: {:?}",
            self.publisher_vm_id, frame
        );
        H::dealloc_frame(frame);
    }
}
impl<H: PagingHandler> IVCChannel<H> {
    /// Allocates a channel backed by one freshly allocated host frame.
    ///
    /// A `shared_region_size` above [`MAX_IVC_CHANNEL_SIZE`] is clamped, with a warning.
    /// The frame is zero-filled before the header is written, so no stale host memory
    /// is exposed to the VMs that map the region.
    ///
    /// # Errors
    /// `NoMemory` if `H::alloc_frame` fails.
    pub fn alloc(
        publisher_vm_id: usize,
        key: usize,
        shared_region_size: usize,
        base_gpa: GuestPhysAddr,
    ) -> AxResult<Self> {
        if shared_region_size > MAX_IVC_CHANNEL_SIZE {
            warn!(
                "IVC channel requested size {shared_region_size:#x} > {MAX_IVC_CHANNEL_SIZE:#x}; \
                 truncating to {MAX_IVC_CHANNEL_SIZE:#x} (TODO: support larger sizes)"
            );
        }
        let shared_region_size = shared_region_size.min(MAX_IVC_CHANNEL_SIZE);
        let shared_region_base = H::alloc_frame().ok_or_else(|| {
            ax_errno::ax_err_type!(NoMemory, "Failed to allocate shared region frame")
        })?;

        // `PagingHandler::alloc_frame` does not promise zeroed memory. The frame is
        // about to be shared with guests, so scrub it first. This also makes the
        // header fields initialized before `header_mut` creates a `&mut` to them.
        // SAFETY: `shared_region_base` was just returned by `H::alloc_frame` (a 4 KiB
        // frame), is not yet referenced anywhere else, and `MAX_IVC_CHANNEL_SIZE`
        // (4096) does not exceed the frame size.
        unsafe {
            core::ptr::write_bytes(
                H::phys_to_virt(shared_region_base).as_mut_ptr(),
                0,
                MAX_IVC_CHANNEL_SIZE,
            );
        }

        let mut channel = IVCChannel {
            publisher_vm_id,
            key,
            subscriber_vms: BTreeMap::new(),
            shared_region_base,
            shared_region_size,
            base_gpa: Some(base_gpa),
            _phantom: core::marker::PhantomData,
        };
        {
            let header = channel.header_mut();
            header.publisher_id = publisher_vm_id as u64;
            header.key = key as u64;
        }
        debug!("Allocated IVCChannel: {channel:?}");
        Ok(channel)
    }

    /// Host physical address of the shared frame.
    pub fn base_hpa(&self) -> HostPhysAddr {
        self.shared_region_base
    }

    /// Publisher-side base GPA, or `None` once the channel has been unpublished.
    pub fn base_gpa_in_publisher(&self) -> Option<GuestPhysAddr> {
        self.base_gpa
    }

    /// Size of the shared region in bytes (already clamped to `MAX_IVC_CHANNEL_SIZE`).
    pub fn size(&self) -> usize {
        self.shared_region_size
    }

    /// Records `subscriber_vm_id` with the GPA it supplied, replacing any previous entry.
    pub fn add_subscriber(&mut self, subscriber_vm_id: usize, subscriber_gpa: GuestPhysAddr) {
        self.subscriber_vms.insert(subscriber_vm_id, subscriber_gpa);
    }

    /// Removes a subscriber, returning its recorded GPA if it was subscribed.
    pub fn remove_subscriber(&mut self, subscriber_vm_id: usize) -> Option<GuestPhysAddr> {
        self.subscriber_vms.remove(&subscriber_vm_id)
    }

    /// Snapshot of `(subscriber VM id, GPA)` pairs, ordered by VM id.
    pub fn subscribers(&self) -> Vec<(usize, GuestPhysAddr)> {
        self.subscriber_vms
            .iter()
            .map(|(vm_id, gpa)| (*vm_id, *gpa))
            .collect()
    }

    /// Whether at least one VM is subscribed.
    pub fn has_subscribers(&self) -> bool {
        !self.subscriber_vms.is_empty()
    }

    /// Whether `subscriber_vm_id` is currently subscribed.
    pub fn has_subscriber(&self, subscriber_vm_id: usize) -> bool {
        self.subscriber_vms.contains_key(&subscriber_vm_id)
    }

    /// Marks the channel as withdrawn by its publisher by clearing the publisher GPA.
    pub fn mark_unpublished(&mut self) {
        self.base_gpa = None;
    }

    /// Whether the publisher has withdrawn the channel.
    pub fn is_unpublished(&self) -> bool {
        self.base_gpa.is_none()
    }
}