use std::sync::{Arc, Mutex};
use std::collections::{VecDeque, HashMap};
use crate::error::{Error, Result};
use crate::communication::{CommunicationChannel, RpcChannel};
use crate::communication::memory::SharedMemoryRegion;
use crate::utils::logging;
/// Settings consumed by `MemoryChannel::new`.
#[derive(Debug, Clone)]
pub struct MemoryChannelConfig {
    /// Channel name; copied into the channel and used when logging its events.
    pub name: String,
    /// Requested size in bytes of the data (payload) region.
    /// NOTE(review): not read by `MemoryChannel::new`; presumably used by whoever
    /// allocates the shared-memory regions — confirm.
    pub data_size: usize,
    /// Requested size in bytes of the control region (flag byte + message length).
    pub control_size: usize,
    /// Upper bound on entries in the channel's send queue; `send_to_guest`
    /// fails with "Channel is full" once it is reached.
    pub queue_capacity: usize,
}

impl Default for MemoryChannelConfig {
    /// Returns a config named `"default"` with a 64 KiB data region, a 1 KiB
    /// control region and a queue capacity of 32.
    fn default() -> Self {
        Self {
            name: "default".to_string(),
            data_size: 64 * 1024,
            control_size: 1024,
            queue_capacity: 32,
        }
    }
}
/// One-byte channel state stored at offset 0 of a channel's control region.
///
/// The discriminants are the raw byte values written to and read from shared
/// memory (via `as u8`), so changing them changes the wire protocol.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ControlFlag {
    /// A complete message is in the data region and may be read.
    Ready = 1,
    /// A reader is currently copying the message out.
    Reading = 2,
    /// The slot is free; a writer may start writing.
    WritingReady = 3,
    /// A writer is currently filling the slot.
    Writing = 4,
    /// The channel has been closed.
    Closed = 5,
}
/// Host-side message channel built on two shared-memory regions.
///
/// The control region holds a one-byte `ControlFlag` at offset 0 and a
/// little-endian `u32` message length at offsets 1..5; message payloads live
/// in the data region.
pub struct MemoryChannel {
/// Channel name, copied from `MemoryChannelConfig::name`.
name: String,
/// Region holding message payloads.
data_region: Arc<SharedMemoryRegion>,
/// Region holding the control flag and the message length.
control_region: Arc<SharedMemoryRegion>,
/// `(offset, length)` entries recorded by `send_to_guest`; its length is
/// compared against `capacity` before each send.
message_queue: Mutex<VecDeque<(usize, usize)>>,
/// Set by `close`; checked by send, receive and `has_messages`.
closed: Mutex<bool>,
/// Maximum number of entries allowed in `message_queue`.
capacity: usize,
}
impl MemoryChannel {
pub fn new(
config: &MemoryChannelConfig,
data_region: Arc<SharedMemoryRegion>,
control_region: Arc<SharedMemoryRegion>,
) -> Self {
Self {
name: config.name.clone(),
data_region,
control_region,
message_queue: Mutex::new(VecDeque::with_capacity(config.queue_capacity)),
closed: Mutex::new(false),
capacity: config.queue_capacity,
}
}
pub fn name(&self) -> &str {
&self.name
}
fn set_control_flag(&self, flag: ControlFlag) -> Result<()> {
let mut buffer = [0u8; 1];
buffer[0] = flag as u8;
self.control_region.write(0, &buffer)?;
Ok(())
}
fn get_control_flag(&self) -> Result<ControlFlag> {
let mut buffer = [0u8; 1];
self.control_region.read(0, &mut buffer)?;
match buffer[0] {
1 => Ok(ControlFlag::Ready),
2 => Ok(ControlFlag::Reading),
3 => Ok(ControlFlag::WritingReady),
4 => Ok(ControlFlag::Writing),
5 => Ok(ControlFlag::Closed),
_ => Err(Error::Communication("Invalid control flag".to_string())),
}
}
fn write_message_length(&self, length: usize) -> Result<()> {
let length_bytes = (length as u32).to_le_bytes();
self.control_region.write(1, &length_bytes)?;
Ok(())
}
fn read_message_length(&self) -> Result<usize> {
let mut length_bytes = [0u8; 4];
self.control_region.read(1, &mut length_bytes)?;
Ok(u32::from_le_bytes(length_bytes) as usize)
}
}
impl CommunicationChannel for MemoryChannel {
fn send_to_guest(&self, message: &[u8]) -> Result<()> {
if *self.closed.lock().unwrap() {
return Err(Error::Communication("Channel is closed".to_string()));
}
let mut queue = self.message_queue.lock().unwrap();
if queue.len() >= self.capacity {
return Err(Error::Communication("Channel is full".to_string()));
}
let mut retries = 0;
while self.get_control_flag()? != ControlFlag::WritingReady && retries < 100 {
std::thread::sleep(std::time::Duration::from_millis(10));
retries += 1;
}
if retries >= 100 {
return Err(Error::Communication("Timeout waiting for channel to be ready".to_string()));
}
self.set_control_flag(ControlFlag::Writing)?;
self.write_message_length(message.len())?;
let offset = if let Some((last_offset, last_len)) = queue.back() {
last_offset + last_len
} else {
0
};
self.data_region.write(offset, message)?;
queue.push_back((offset, message.len()));
self.set_control_flag(ControlFlag::Ready)?;
logging::log_communication_event(&self.name, "sent", message.len());
Ok(())
}
fn receive_from_guest(&self) -> Result<Vec<u8>> {
if *self.closed.lock().unwrap() {
return Err(Error::Communication("Channel is closed".to_string()));
}
let mut retries = 0;
while self.get_control_flag()? != ControlFlag::Ready && retries < 100 {
std::thread::sleep(std::time::Duration::from_millis(10));
retries += 1;
}
if retries >= 100 {
return Err(Error::Communication("Timeout waiting for message".to_string()));
}
self.set_control_flag(ControlFlag::Reading)?;
let length = self.read_message_length()?;
let mut message = vec![0u8; length];
self.data_region.read(0, &mut message)?;
self.set_control_flag(ControlFlag::WritingReady)?;
logging::log_communication_event(&self.name, "received", message.len());
Ok(message)
}
fn has_messages(&self) -> bool {
if *self.closed.lock().unwrap() {
return false;
}
if let Ok(flag) = self.get_control_flag() {
return flag == ControlFlag::Ready;
}
false
}
fn close(&self) -> Result<()> {
*self.closed.lock().unwrap() = true;
self.set_control_flag(ControlFlag::Closed)?;
Ok(())
}
}
/// RPC layer over a `MemoryChannel`.
///
/// Guest calls are framed as `[name_len: u8][name bytes][payload]`, sent with
/// `send_to_guest`, and answered by the next message from `receive_from_guest`.
pub struct MemoryRpcChannel {
/// Underlying transport shared with other owners of the same `Arc`.
channel: Arc<MemoryChannel>,
/// Host functions registered by name; JSON functions are stored wrapped as
/// byte-level handlers. NOTE(review): nothing visible here invokes these —
/// presumably a dispatcher elsewhere does; confirm.
functions: Mutex<HashMap<String, Box<dyn Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync>>>,
}
impl MemoryRpcChannel {
    /// Wraps `channel` in an RPC layer that starts with no host functions registered.
    pub fn new(channel: Arc<MemoryChannel>) -> Self {
        let functions = Mutex::new(HashMap::new());
        Self { channel, functions }
    }
}
impl RpcChannel for MemoryRpcChannel {
fn register_host_function_json(
&mut self,
name: &str,
function: Box<dyn Fn(&str) -> Result<String> + Send + Sync + 'static>,
) -> Result<()> {
let name = name.to_string();
let func = Box::new(move |data: &[u8]| -> Result<Vec<u8>> {
let params_json = String::from_utf8_lossy(data);
let result_json = function(¶ms_json)?;
Ok(result_json.into_bytes())
});
self.functions.lock().unwrap().insert(name, func);
Ok(())
}
fn call_guest_function_json(
&self,
function_name: &str,
params_json: &str,
) -> Result<String> {
let mut message = Vec::with_capacity(function_name.len() + params_json.len() + 5);
message.push(function_name.len() as u8);
message.extend_from_slice(function_name.as_bytes());
message.extend_from_slice(params_json.as_bytes());
self.channel.send_to_guest(&message)?;
let response_bytes = self.channel.receive_from_guest()?;
let response = String::from_utf8_lossy(&response_bytes).to_string();
Ok(response)
}
fn register_host_function_msgpack(
&mut self,
name: &str,
function: Box<dyn Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync + 'static>,
) -> Result<()> {
let name = name.to_string();
self.functions.lock().unwrap().insert(name, function);
Ok(())
}
fn call_guest_function_msgpack(
&self,
function_name: &str,
params_msgpack: &[u8],
) -> Result<Vec<u8>> {
let mut message = Vec::with_capacity(function_name.len() + params_msgpack.len() + 5);
message.push(function_name.len() as u8);
message.extend_from_slice(function_name.as_bytes());
message.extend_from_slice(params_msgpack);
self.channel.send_to_guest(&message)?;
let response_bytes = self.channel.receive_from_guest()?;
Ok(response_bytes)
}
}