use std::sync::{Arc, Mutex};
use std::collections::{VecDeque, HashMap};
use serde::{Serialize, de::DeserializeOwned};
use crate::error::{Error, Result};
use crate::communication::{CommunicationChannel, RpcChannel};
use crate::utils::logging;
/// An in-process [`CommunicationChannel`] backed by two FIFO queues of byte
/// messages, one for each direction between host and guest.
#[derive(Debug)]
pub struct MemoryChannel {
    /// Identifier passed to the logging helper on every send/receive.
    name: String,
    /// Messages queued by the host for the guest.
    host_to_guest: Mutex<VecDeque<Vec<u8>>>,
    /// Messages waiting to be read by the host.
    guest_to_host: Mutex<VecDeque<Vec<u8>>>,
    /// Maximum number of pending messages allowed in `host_to_guest`.
    capacity: usize,
    /// Becomes `true` once the channel has been closed.
    closed: Mutex<bool>,
}

impl MemoryChannel {
    /// Creates an open channel called `name` whose host-to-guest queue holds
    /// at most `capacity` pending messages.
    ///
    /// Each queue reserves room for `min(capacity, 64)` messages up front and
    /// grows on demand beyond that. This means a very large `capacity` (for
    /// example `usize::MAX` used as "effectively unbounded") neither panics
    /// with a capacity overflow nor eagerly allocates a huge buffer.
    pub fn new(name: &str, capacity: usize) -> Self {
        // Initial allocation hint only; it does not affect the capacity limit.
        const MAX_PREALLOCATED_MESSAGES: usize = 64;
        let initial = capacity.min(MAX_PREALLOCATED_MESSAGES);
        Self {
            name: name.to_string(),
            host_to_guest: Mutex::new(VecDeque::with_capacity(initial)),
            guest_to_host: Mutex::new(VecDeque::with_capacity(initial)),
            capacity,
            closed: Mutex::new(false),
        }
    }

    /// Returns the name given to this channel at construction.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the maximum number of pending host-to-guest messages.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
impl CommunicationChannel for MemoryChannel {
fn send_to_guest(&self, message: &[u8]) -> Result<()> {
if *self.closed.lock().unwrap() {
return Err(Error::Communication("Channel is closed".to_string()));
}
let mut queue = self.host_to_guest.lock().unwrap();
if queue.len() >= self.capacity {
return Err(Error::Communication("Channel is full".to_string()));
}
queue.push_back(message.to_vec());
logging::log_communication_event(&self.name, "sent", message.len());
Ok(())
}
fn receive_from_guest(&self) -> Result<Vec<u8>> {
if *self.closed.lock().unwrap() {
return Err(Error::Communication("Channel is closed".to_string()));
}
let mut queue = self.guest_to_host.lock().unwrap();
if let Some(message) = queue.pop_front() {
logging::log_communication_event(&self.name, "received", message.len());
Ok(message)
} else {
Err(Error::Communication("No messages available".to_string()))
}
}
fn has_messages(&self) -> bool {
if *self.closed.lock().unwrap() {
return false;
}
!self.guest_to_host.lock().unwrap().is_empty()
}
fn close(&self) -> Result<()> {
*self.closed.lock().unwrap() = true;
Ok(())
}
}
pub struct RpcChannel {
channel: Arc<dyn CommunicationChannel>,
host_functions: Mutex<HashMap<String, Box<dyn Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync>>>,
}
impl RpcChannel {
pub fn new(channel: Arc<dyn CommunicationChannel>) -> Self {
Self {
channel,
host_functions: Mutex::new(HashMap::new()),
}
}
}
impl RpcChannel for RpcChannel {
fn register_host_function<F, Params, Return>(
&mut self,
name: &str,
function: F,
) -> Result<()>
where
F: Fn(Params) -> Result<Return> + Send + Sync + 'static,
Params: DeserializeOwned + 'static,
Return: Serialize + 'static,
{
let wrapper = move |data: &[u8]| -> Result<Vec<u8>> {
let params: Params = serde_json::from_slice(data)
.map_err(|e| Error::Communication(format!("Failed to deserialize parameters: {}", e)))?;
let result = function(params)?;
let result_data = serde_json::to_vec(&result)
.map_err(|e| Error::Communication(format!("Failed to serialize result: {}", e)))?;
Ok(result_data)
};
let mut host_functions = self.host_functions.lock().unwrap();
host_functions.insert(name.to_string(), Box::new(wrapper));
Ok(())
}
fn call_guest_function<Params, Return>(
&self,
function_name: &str,
params: &Params,
) -> Result<Return>
where
Params: Serialize + ?Sized,
Return: DeserializeOwned + 'static,
{
let mut message = Vec::new();
let name_len = function_name.len() as u16;
message.extend_from_slice(&name_len.to_le_bytes());
message.extend_from_slice(function_name.as_bytes());
let params_data = serde_json::to_vec(params)
.map_err(|e| Error::Communication(format!("Failed to serialize parameters: {}", e)))?;
let params_len = params_data.len() as u32;
message.extend_from_slice(¶ms_len.to_le_bytes());
message.extend_from_slice(¶ms_data);
self.channel.send_to_guest(&message)?;
let response_data = self.channel.receive_from_guest()?;
let response: Return = serde_json::from_slice(&response_data)
.map_err(|e| Error::Communication(format!("Failed to deserialize response: {}", e)))?;
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips messages through a `MemoryChannel`, playing the guest by
    /// writing to `guest_to_host` directly, then checks that closing blocks sends.
    #[test]
    fn test_memory_channel() {
        let greeting = b"Hello, world!";
        let channel = MemoryChannel::new("test", 10);

        // Outgoing traffic lands in host_to_guest, so the host inbox stays empty.
        assert!(channel.send_to_guest(greeting).is_ok());
        assert!(!channel.has_messages());

        // Simulated guest reply; the guard is dropped at the end of the statement.
        channel
            .guest_to_host
            .lock()
            .unwrap()
            .push_back(b"Response".to_vec());

        assert!(channel.has_messages());
        assert_eq!(channel.receive_from_guest().unwrap(), b"Response");

        assert!(channel.close().is_ok());
        assert!(channel.send_to_guest(greeting).is_err());
    }

    /// Registers a host function and performs a guest call whose JSON reply
    /// has been queued in advance.
    #[test]
    fn test_rpc_channel() {
        let memory_channel = Arc::new(MemoryChannel::new("rpc", 10));
        let mut rpc_channel = RpcChannel::new(memory_channel.clone());

        let registered =
            rpc_channel.register_host_function("echo", |s: String| -> Result<String> { Ok(s) });
        assert!(registered.is_ok());

        // Pre-load the guest's reply so the synchronous call can read it.
        memory_channel
            .guest_to_host
            .lock()
            .unwrap()
            .push_back(br#""Hello, world!""#.to_vec());

        let reply: String = rpc_channel.call_guest_function("echo", &"Hello").unwrap();
        assert_eq!(reply, "Hello, world!");
    }
}