use crate::{
common::{
channel::{PluginChannel, UpstreamChannel},
error::{inv_op, ErrorKind, Result},
protocol::{GatestreamDown, GatestreamUp, PluginToSimulator, SimulatorToPlugin},
},
trace,
};
use ipc_channel::ipc::{IpcOneShotServer, IpcReceiverSet, IpcSelectionResult, IpcSender};
use std::collections::{HashMap, VecDeque};
/// Identifies which peer an incoming IPC receiver is connected to.
///
/// Used as the value type of the receiver-id lookup table so that a raw
/// message can be decoded into the matching [`IncomingMessage`] variant.
#[derive(Clone, Copy, Debug)]
enum Incoming {
    /// The channel carrying requests from the simulator.
    Simulator,
    /// The channel from the upstream plugin.
    Upstream,
    /// The channel from the downstream plugin.
    Downstream,
}
/// A decoded message received by a plugin, tagged with the peer it came from.
#[derive(Debug, PartialEq)]
pub enum IncomingMessage {
    /// A message received on the simulator channel.
    Simulator(SimulatorToPlugin),
    /// A message received from the upstream plugin.
    Upstream(GatestreamDown),
    /// A message received from the downstream plugin.
    Downstream(GatestreamUp),
}
/// A message to be sent by a plugin, tagged with the peer it is addressed to.
#[derive(Debug, PartialEq)]
pub enum OutgoingMessage {
    /// A message for the simulator.
    Simulator(PluginToSimulator),
    /// A message for the upstream plugin.
    Upstream(GatestreamUp),
    /// A message for the downstream plugin.
    Downstream(GatestreamDown),
}
/// Plugin-side endpoint of the IPC connections to the simulator and to the
/// neighbouring upstream/downstream plugins.
pub struct Connection {
    /// All receivers this plugin listens on, waited on together.
    incoming: IpcReceiverSet,
    /// Maps receiver ids in `incoming` to the peer they belong to. An entry is
    /// removed once the corresponding channel reports that it has closed.
    incoming_map: HashMap<u64, Incoming>,
    /// Messages already received but not yet handed out to the caller.
    incoming_buffer: VecDeque<IncomingMessage>,
    /// One-shot server created by `serve_upstream`, waiting to be accepted by
    /// `accept_upstream`.
    pending_upstream: Option<IpcOneShotServer<UpstreamChannel>>,
    /// Sender for messages to the simulator.
    response: IpcSender<PluginToSimulator>,
    /// Sender towards the upstream plugin, once connected.
    upstream: Option<IpcSender<GatestreamUp>>,
    /// Sender towards the downstream plugin, once connected.
    downstream: Option<IpcSender<GatestreamDown>>,
}
impl Connection {
    /// Performs the initial handshake with the simulator's server at
    /// `simulator`.
    ///
    /// Two fresh IPC channels are created. The simulator-side ends (a sender
    /// for requests to this plugin and a receiver for its responses) are sent
    /// to the server. The plugin-side ends are returned as a
    /// `(response sender, request receiver)` pair.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to `simulator`, creating a channel or
    /// sending the handshake fails.
    fn connect(simulator: impl Into<String>) -> Result<PluginChannel> {
        let connect = IpcSender::connect(simulator.into())?;
        let (request_tx, request) = ipc_channel::ipc::channel()?;
        let (response, response_rx) = ipc_channel::ipc::channel()?;
        connect.send((request_tx, response_rx))?;
        Ok((response, request))
    }

    /// Creates a connection to the simulator whose server listens at
    /// `simulator`.
    ///
    /// Only the simulator channel is set up here. Neighbouring plugins are
    /// attached later through [`Connection::connect_downstream`] and
    /// [`Connection::serve_upstream`] / [`Connection::accept_upstream`].
    ///
    /// # Errors
    ///
    /// Returns an error if the handshake fails or the receiver set cannot be
    /// created or extended.
    pub fn new(simulator: impl Into<String>) -> Result<Connection> {
        let channel = Connection::connect(simulator)?;
        let mut incoming = IpcReceiverSet::new()?;
        // At most three receivers are ever registered: simulator, upstream
        // and downstream.
        let mut incoming_map = HashMap::with_capacity(3);
        incoming_map.insert(incoming.add(channel.1)?, Incoming::Simulator);
        Ok(Connection {
            incoming,
            incoming_map,
            incoming_buffer: VecDeque::new(),
            response: channel.0,
            downstream: None,
            pending_upstream: None,
            upstream: None,
        })
    }

    /// Connects to the downstream plugin whose server listens at
    /// `downstream`.
    ///
    /// The plugin is sent an [`UpstreamChannel`]: a sender for its messages
    /// back to us and a receiver for our messages to it. Our ends are stored
    /// and registered with the receiver set.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation error if a downstream plugin is already
    /// connected. Returns an IPC error if connecting, creating channels or
    /// sending fails.
    pub fn connect_downstream(&mut self, downstream: impl Into<String>) -> Result<()> {
        if self.downstream.is_some() {
            inv_op("already connected to a downstream plugin")?;
        }
        let downstream = IpcSender::connect(downstream.into())?;
        let (down_tx, down_rx) = ipc_channel::ipc::channel()?;
        let (up_tx, up_rx) = ipc_channel::ipc::channel()?;
        downstream.send((up_tx, down_rx) as UpstreamChannel)?;
        self.incoming_map
            .insert(self.incoming.add(up_rx)?, Incoming::Downstream);
        self.downstream = Some(down_tx);
        Ok(())
    }

    /// Starts a one-shot IPC server for an upstream plugin to connect to.
    ///
    /// Returns the server address that should be given to the upstream
    /// plugin. The connection is completed with
    /// [`Connection::accept_upstream`].
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation error if an upstream connection is
    /// already pending or established, or an IPC error if the server cannot
    /// be created.
    pub fn serve_upstream(&mut self) -> Result<String> {
        if self.pending_upstream.is_some() {
            inv_op("already connecting to an upstream plugin")?;
        } else if self.upstream.is_some() {
            inv_op("already connected to an upstream plugin")?;
        }
        let (pending, address) = IpcOneShotServer::new()?;
        self.pending_upstream = Some(pending);
        Ok(address)
    }

    /// Accepts the upstream plugin on the server created by
    /// [`Connection::serve_upstream`].
    ///
    /// The received [`UpstreamChannel`] provides a sender for messages to the
    /// upstream plugin, which is stored, and a receiver for its messages,
    /// which is registered with the receiver set.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation error if `serve_upstream` was not called
    /// first, or if an upstream plugin is already connected. Returns an IPC
    /// error if accepting or registering the channel fails.
    pub fn accept_upstream(&mut self) -> Result<()> {
        if self.pending_upstream.is_none() {
            inv_op("not yet connecting to an upstream plugin, call serve_upstream() first")?;
        } else if self.upstream.is_some() {
            inv_op("already connected to an upstream plugin")?;
        }
        let server = self
            .pending_upstream
            .take()
            .expect("pending_upstream presence was checked above");
        let (_, upstream): (_, UpstreamChannel) = server.accept()?;
        self.incoming_map
            .insert(self.incoming.add(upstream.1)?, Incoming::Upstream);
        self.upstream = Some(upstream.0);
        Ok(())
    }

    /// Borrows the downstream sender, or fails if none is connected.
    fn downstream_ref(&self) -> Result<&IpcSender<GatestreamDown>> {
        Ok(self
            .downstream
            .as_ref()
            .ok_or_else(|| ErrorKind::IPCError("Downstream sender does not exist".to_string()))?)
    }

    /// Borrows the upstream sender, or fails if none is connected.
    fn upstream_ref(&self) -> Result<&IpcSender<GatestreamUp>> {
        Ok(self
            .upstream
            .as_ref()
            .ok_or_else(|| ErrorKind::IPCError("Upstream sender does not exist".to_string()))?)
    }

    /// Sends `message` to the peer indicated by its variant.
    ///
    /// # Errors
    ///
    /// Returns an IPC error if the addressed upstream/downstream plugin is
    /// not connected, or if sending fails.
    pub fn send(&self, message: OutgoingMessage) -> Result<()> {
        match message {
            OutgoingMessage::Simulator(response) => self.response.send(response)?,
            OutgoingMessage::Downstream(request) => self.downstream_ref()?.send(request)?,
            OutgoingMessage::Upstream(response) => self.upstream_ref()?.send(response)?,
        }
        Ok(())
    }

    /// Waits on the receiver set until at least one message has been
    /// buffered, or until every registered channel has closed.
    ///
    /// Every event of a selection batch is processed before returning. This
    /// means channel closures reported alongside an undecodable message are
    /// still removed from `incoming_map`. Otherwise stale ids would remain
    /// and the map might never become empty.
    ///
    /// # Errors
    ///
    /// Returns an error if waiting on the receiver set fails. Also returns an
    /// error if a message cannot be decoded into the type expected for its
    /// channel; only the first decode error of a batch is reported, after
    /// the rest of that batch has been processed.
    fn buffer_incoming(&mut self) -> Result<()> {
        let mut received_any = false;
        while !received_any && !self.incoming_map.is_empty() {
            let mut decode_error = None;
            for event in self.incoming.select()? {
                match event {
                    IpcSelectionResult::MessageReceived(id, msg) => {
                        // Messages on ids we do not track are dropped.
                        let source = match self.incoming_map.get(&id) {
                            Some(&source) => source,
                            None => continue,
                        };
                        let decoded = match source {
                            Incoming::Simulator => msg.to().map(IncomingMessage::Simulator),
                            Incoming::Upstream => msg.to().map(IncomingMessage::Upstream),
                            Incoming::Downstream => msg.to().map(IncomingMessage::Downstream),
                        };
                        match decoded {
                            Ok(message) => {
                                self.incoming_buffer.push_back(message);
                                received_any = true;
                            }
                            Err(error) => {
                                if decode_error.is_none() {
                                    decode_error = Some(error);
                                }
                            }
                        }
                    }
                    IpcSelectionResult::ChannelClosed(id) => {
                        trace!("Channel closed: {:?}", self.incoming_map.get(&id));
                        self.incoming_map.remove(&id);
                    }
                }
            }
            if let Some(error) = decode_error {
                return Err(error.into());
            }
        }
        Ok(())
    }

    /// Returns the oldest buffered incoming message, from any peer.
    ///
    /// If nothing is buffered, this blocks until a message arrives. Returns
    /// `Ok(None)` once all incoming channels have closed and the buffer is
    /// empty.
    ///
    /// # Errors
    ///
    /// Propagates errors from waiting on or decoding incoming messages.
    pub fn next_request(&mut self) -> Result<Option<IncomingMessage>> {
        if self.incoming_buffer.is_empty() {
            self.buffer_incoming()?;
        }
        Ok(self.incoming_buffer.pop_front())
    }

    /// Returns the oldest buffered message from the downstream plugin.
    ///
    /// Messages from other peers that arrive in the meantime stay buffered
    /// for later [`Connection::next_request`] calls. Returns `Ok(None)` once
    /// all incoming channels have closed without a downstream message being
    /// buffered.
    ///
    /// # Errors
    ///
    /// Propagates errors from waiting on or decoding incoming messages.
    pub fn next_downstream_request(&mut self) -> Result<Option<IncomingMessage>> {
        // Iterative rather than recursive: the number of rounds is unbounded.
        loop {
            let position = self.incoming_buffer.iter().position(|msg| match msg {
                IncomingMessage::Downstream(_) => true,
                _ => false,
            });
            if let Some(idx) = position {
                return Ok(self.incoming_buffer.remove(idx));
            }
            // Only check for closure after the buffer was searched. The final
            // batch may have buffered a downstream message together with the
            // last channel closures.
            if self.incoming_map.is_empty() {
                return Ok(None);
            }
            self.buffer_incoming()?;
        }
    }
}
/// Emits a `trace!` message when a [`Connection`] is dropped.
///
/// The IPC handles themselves are released afterwards by the fields' own
/// destructors.
impl Drop for Connection {
    fn drop(&mut self) {
        trace!("Dropping Connection");
    }
}
#[cfg(test)]
mod tests {
    use super::{Connection, IncomingMessage, OutgoingMessage};
    use crate::common::{
        channel::SimulatorChannel,
        protocol::{PluginToSimulator, SimulatorToPlugin},
    };
    use ipc_channel::ipc::IpcOneShotServer;

    /// The raw handshake yields a working request/response pair with the
    /// simulator.
    #[test]
    fn connect() {
        let (server, server_name) = IpcOneShotServer::new().unwrap();
        let plugin = std::thread::spawn(move || {
            let (response_tx, request_rx) = Connection::connect(server_name).unwrap();
            let request = request_rx
                .recv()
                .expect("plugin did not receive a request");
            assert_eq!(request, SimulatorToPlugin::Abort);
            response_tx
                .send(PluginToSimulator::Success)
                .expect("plugin could not send its response");
        });

        let (_, simulator): (_, SimulatorChannel) = server.accept().unwrap();
        simulator
            .0
            .send(SimulatorToPlugin::Abort)
            .expect("simulator could not send its request");
        let response = simulator
            .1
            .recv()
            .expect("simulator did not receive a response");
        assert_eq!(response, PluginToSimulator::Success);
        assert!(plugin.join().is_ok());
    }

    /// A full `Connection` receives simulator requests and can answer them.
    #[test]
    fn simulator_connection() {
        let (server, server_name) = IpcOneShotServer::new().unwrap();
        let plugin = std::thread::spawn(move || {
            let mut connection = Connection::new(server_name).unwrap();
            let incoming = connection
                .next_request()
                .expect("next_request failed")
                .expect("no request arrived before the channels closed");
            assert_eq!(
                incoming,
                IncomingMessage::Simulator(SimulatorToPlugin::Abort)
            );
            connection
                .send(OutgoingMessage::Simulator(PluginToSimulator::Success))
                .expect("could not respond to the simulator");
        });

        let (_, simulator): (_, SimulatorChannel) = server.accept().unwrap();
        simulator
            .0
            .send(SimulatorToPlugin::Abort)
            .expect("simulator could not send its request");
        let response = simulator
            .1
            .recv()
            .expect("simulator did not receive a response");
        assert_eq!(response, PluginToSimulator::Success);
        assert!(plugin.join().is_ok());
    }

    /// Connecting to a non-existent server fails with an OS-specific I/O error.
    #[test]
    fn bad_address() {
        let result = Connection::new("asdf");
        assert!(result.is_err());
        #[cfg(target_os = "macos")]
        assert_eq!(
            result.err().unwrap().to_string(),
            "I/O error: Unknown Mach error: 44e"
        );
        #[cfg(target_os = "linux")]
        assert_eq!(
            result.err().unwrap().to_string(),
            "I/O error: No such file or directory (os error 2)"
        );
    }
}