use holochain_tracing_macros::newrelic_autotrace;
use lazy_static::lazy_static;
use lib3h_zombie_actor::GhostMutex;
use std::{
collections::{hash_map::Entry, HashMap},
io::{Read, Write},
};
use url2::prelude::*;
/// Listening endpoint for the in-memory `mem://` transport.
///
/// Created by [`MemListener::bind`]; incoming connections are polled with
/// [`MemListener::accept`]. Dropping it unregisters its url from the global
/// manager.
#[derive(Debug)]
pub struct MemListener {
    /// Normalized bound url (`mem://<host>:4242`).
    url: Url2,
    /// Receives the server-side stream for each `MemStream::connect` to `url`.
    recv: crossbeam_channel::Receiver<MemStream>,
    /// Streams already pulled off `recv` but not yet returned by `accept`.
    accept_queue: Vec<MemStream>,
}
#[newrelic_autotrace(SIM2H)]
impl MemListener {
    /// Wraps an already-registered url and its connection receiver.
    fn priv_new(url: Url2, recv: crossbeam_channel::Receiver<MemStream>) -> Self {
        Self {
            url,
            recv,
            accept_queue: Vec::new(),
        }
    }

    /// Binds a new listener through the global manager.
    ///
    /// # Errors
    /// `InvalidInput` if `url` is not a `mem://` url with a host and a port of
    /// none or 4242; `AddrInUse` if that address is already bound.
    pub fn bind(url: &Url2) -> std::io::Result<MemListener> {
        MEM_MANAGER.lock().bind(url)
    }

    /// The normalized url this listener is bound to; pass it to
    /// `MemStream::connect` to reach this listener.
    pub fn get_url(&self) -> &Url2 {
        &self.url
    }

    /// Returns the oldest pending connection without blocking.
    ///
    /// # Errors
    /// `WouldBlock` if no connection is pending; `BrokenPipe` if the
    /// connection channel is disconnected and nothing is queued.
    pub fn accept(&mut self) -> std::io::Result<MemStream> {
        // Move every stream currently sitting in the channel into the local queue.
        loop {
            match self.recv.try_recv() {
                Ok(stream) => self.accept_queue.push(stream),
                Err(crossbeam_channel::TryRecvError::Empty) => break,
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    if self.accept_queue.is_empty() {
                        return Err(std::io::ErrorKind::BrokenPipe.into());
                    }
                    // A disconnected channel reports `Disconnected` on every
                    // call, so we must stop polling and hand out what we have.
                    break;
                }
            }
        }
        if self.accept_queue.is_empty() {
            return Err(std::io::ErrorKind::WouldBlock.into());
        }
        Ok(self.accept_queue.remove(0))
    }
}
impl Drop for MemListener {
    /// Releases this listener's url in the global manager so it can be bound again.
    fn drop(&mut self) {
        let bound_url = &self.url;
        MEM_MANAGER.lock().unbind(bound_url);
    }
}
/// One end of an in-memory, bidirectional byte stream.
///
/// Implements non-blocking [`Read`] and [`Write`]: bytes written on one end
/// become readable on the other.
#[derive(Debug)]
pub struct MemStream {
    /// Address reported by `get_url`: the listener url on the connecting
    /// side, a random `mem://assigned-<id>` url on the accepted side.
    url: Url2,
    /// Outgoing chunks, delivered to the peer's `recv`.
    send: crossbeam_channel::Sender<Vec<u8>>,
    /// Incoming chunks written by the peer.
    recv: crossbeam_channel::Receiver<Vec<u8>>,
    /// Bytes received from `recv` but not yet handed out by `read`.
    recv_buf: Vec<u8>,
}
#[newrelic_autotrace(SIM2H)]
impl MemStream {
    /// Assembles one endpoint of an in-memory stream from its url and the two
    /// channel halves; the receive buffer starts out empty.
    fn priv_new(
        url: Url2,
        send: crossbeam_channel::Sender<Vec<u8>>,
        recv: crossbeam_channel::Receiver<Vec<u8>>,
    ) -> MemStream {
        let recv_buf = Vec::new();
        Self {
            recv_buf,
            recv,
            send,
            url,
        }
    }

    /// Opens a connection to the listener bound at `url`.
    ///
    /// # Errors
    /// Propagates the global manager's `connect` error, e.g.
    /// `ConnectionRefused` when nothing is bound at `url`.
    pub fn connect(url: &Url2) -> std::io::Result<MemStream> {
        MEM_MANAGER.lock().connect(url)
    }

    /// The url associated with this endpoint.
    pub fn get_url(&self) -> &Url2 {
        &self.url
    }
}
#[newrelic_autotrace(SIM2H)]
impl Read for MemStream {
    /// Non-blocking read of bytes written by the peer.
    ///
    /// All chunks currently in the channel are first appended to the internal
    /// buffer; then up to `buf.len()` bytes are copied out from its front.
    ///
    /// Returns `Ok(0)` once the peer's sender is gone and every buffered byte
    /// has been consumed (end of stream).
    ///
    /// # Errors
    /// `WouldBlock` when nothing is buffered and the peer is still connected.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut disconnected = false;
        loop {
            match self.recv.try_recv() {
                Ok(mut data) => self.recv_buf.append(&mut data),
                Err(crossbeam_channel::TryRecvError::Empty) => break,
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    disconnected = true;
                    break;
                }
            }
        }
        if self.recv_buf.is_empty() {
            return if disconnected {
                Ok(0)
            } else {
                Err(std::io::ErrorKind::WouldBlock.into())
            };
        }
        // Copy directly from the buffered prefix instead of collecting it into
        // a temporary Vec first, then discard the consumed bytes.
        let count = std::cmp::min(buf.len(), self.recv_buf.len());
        buf[..count].copy_from_slice(&self.recv_buf[..count]);
        self.recv_buf.drain(..count);
        Ok(count)
    }
}
#[newrelic_autotrace(SIM2H)]
impl Write for MemStream {
    /// Sends a copy of `buf` to the peer as a single chunk and reports the
    /// whole buffer as written.
    ///
    /// # Errors
    /// `NotConnected` if the peer's receiving end is gone.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        self.send
            .send(buf.to_vec())
            .map(|()| written)
            .map_err(|_| std::io::Error::from(std::io::ErrorKind::NotConnected))
    }

    /// Writes are delivered immediately, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// Builds a fresh `mem://<prefix>-<id>` url from a random nanoid, with any
/// `_` in the id turned into `-` and any `~` turned into `+`.
fn random_url(prefix: &str) -> Url2 {
    let id: String = nanoid::simple()
        .chars()
        .map(|c| match c {
            '_' => '-',
            '~' => '+',
            other => other,
        })
        .collect();
    Url2::parse(&format!("mem://{}-{}", prefix, id))
}
/// Creates two cross-wired stream endpoints: whatever the first (`url_a`)
/// writes, the second (`url_b`) reads, and vice versa.
fn create_mem_stream_pair(url_a: Url2, url_b: Url2) -> (MemStream, MemStream) {
    let (a_to_b_tx, a_to_b_rx) = crossbeam_channel::unbounded();
    let (b_to_a_tx, b_to_a_rx) = crossbeam_channel::unbounded();
    let end_a = MemStream::priv_new(url_a, a_to_b_tx, b_to_a_rx);
    let end_b = MemStream::priv_new(url_b, b_to_a_tx, a_to_b_rx);
    (end_a, end_b)
}
/// Registry of bound in-memory listeners (held in the global `MEM_MANAGER`).
struct MemManager {
    /// Bound url -> sender used to hand new server-side streams to that listener.
    listeners: HashMap<Url2, crossbeam_channel::Sender<MemStream>>,
}
#[newrelic_autotrace(SIM2H)]
impl MemManager {
    /// Creates a manager with no bound listeners.
    fn new() -> Self {
        Self {
            listeners: HashMap::new(),
        }
    }

    /// Registers a listener at `url`, normalized to `mem://<host>:4242`.
    ///
    /// # Errors
    /// `InvalidInput` for a non-`mem` scheme, a port other than none/4242, or
    /// a missing host; `AddrInUse` if the normalized url is already bound.
    fn bind(&mut self, url: &Url2) -> std::io::Result<MemListener> {
        if "mem" != url.scheme() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "mem bind: url scheme must be mem",
            ));
        }
        match url.port() {
            Some(4242) | None => (),
            _ => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "mem bind: url port must be None or 4242",
                ));
            }
        }
        let host = match url.host_str() {
            Some(host) => host,
            None => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "mem bind: host_str must be set",
                ));
            }
        };
        let new_url = Url2::parse(&format!("mem://{}:4242", host));
        match self.listeners.entry(new_url.clone()) {
            Entry::Occupied(_) => Err(std::io::ErrorKind::AddrInUse.into()),
            Entry::Vacant(e) => {
                let (send, recv) = crossbeam_channel::unbounded();
                e.insert(send);
                Ok(MemListener::priv_new(new_url, recv))
            }
        }
    }

    /// Forgets the listener registered at `url` (no-op if none).
    fn unbind(&mut self, url: &Url2) {
        self.listeners.remove(url);
    }

    /// Connects to the listener bound at `url`.
    ///
    /// A url that is not `mem://` with a host is wrapped as `mem://{url}`
    /// before lookup. On success the listener receives one end of a new
    /// stream pair (with a random `assigned` url) and the caller gets the
    /// other (carrying the listener's url).
    ///
    /// # Errors
    /// `InvalidInput` if the wrapped url cannot be parsed; `ConnectionRefused`
    /// if nothing is bound there or the listener's receiver is gone.
    fn connect(&mut self, url: &Url2) -> std::io::Result<MemStream> {
        let url = if url.scheme() != "mem" || url.host_str().is_none() {
            // Wrapping can yield an unparsable url (e.g. `a:b` -> `mem://a:b`,
            // whose "port" is not a number); report it rather than panic.
            Url2::try_parse(&format!("mem://{}", url)).map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "mem connect: url cannot be expressed as a mem url",
                )
            })?
        } else {
            url.clone()
        };
        if let Entry::Occupied(entry) = self.listeners.entry(url.clone()) {
            let (for_listener, for_caller) =
                create_mem_stream_pair(random_url("assigned"), url);
            if entry.get().send(for_listener).is_ok() {
                return Ok(for_caller);
            }
            // The listener's receiver is gone, so this registration is stale.
            entry.remove();
        }
        Err(std::io::ErrorKind::ConnectionRefused.into())
    }
}
lazy_static! {
    /// Process-wide registry of bound in-memory listeners, shared by
    /// `MemListener::bind` and `MemStream::connect`.
    static ref MEM_MANAGER: GhostMutex<MemManager> = GhostMutex::new(MemManager::new());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds a listener on a fresh random url and returns it together with a
    /// connected client stream and the matching accepted server stream.
    fn setup() -> (MemListener, MemStream, MemStream) {
        let url = random_url("test");
        println!("SETUP USING URL: {}", url);
        let mut listener = MemListener::bind(&url).unwrap();
        println!("LISTENER GOT BOUND URL: {}", listener.get_url());
        let client = MemStream::connect(listener.get_url()).unwrap();
        let server = listener.accept().unwrap();
        (listener, client, server)
    }

    #[test]
    fn it_should_connection_refused() {
        match MemStream::connect(&Url2::parse("badconnection:")) {
            Err(ref e) if e.kind() == std::io::ErrorKind::ConnectionRefused => (),
            e => panic!("unexpected {:?}", e),
        }
    }

    #[test]
    fn it_should_addr_in_use() {
        let (listener, _c, _s) = setup();
        match MemListener::bind(listener.get_url()) {
            Err(ref e) if e.kind() == std::io::ErrorKind::AddrInUse => (),
            e => panic!("unexpected {:?}", e),
        }
    }

    #[test]
    fn it_can_read_write() {
        let mut buf = [0_u8; 1024];
        let (_listener, mut client, mut server) = setup();
        client.write_all(b"test1").unwrap();
        assert_eq!(5, server.read(&mut buf).unwrap());
        assert_eq!(b"test1", &buf[..5]);
        server.write_all(b"test2").unwrap();
        assert_eq!(5, client.read(&mut buf).unwrap());
        assert_eq!(b"test2", &buf[..5]);
    }

    #[test]
    fn it_should_would_block() {
        let mut buf = [0_u8; 1024];
        let (mut listener, mut client, mut server) = setup();
        match listener.accept() {
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => (),
            e => panic!("unexpected {:?}", e),
        }
        match client.read(&mut buf) {
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => (),
            e => panic!("unexpected {:?}", e),
        }
        match server.read(&mut buf) {
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => (),
            e => panic!("unexpected {:?}", e),
        }
    }

    #[test]
    fn it_can_keep_talking_after_listener_drop() {
        let mut buf = [0_u8; 1024];
        let (mut client, mut server) = {
            let (_listener, client, server) = setup();
            (client, server)
        };
        client.write_all(b"test1").unwrap();
        assert_eq!(5, server.read(&mut buf).unwrap());
        assert_eq!(b"test1", &buf[..5]);
        server.write_all(b"test2").unwrap();
        assert_eq!(5, client.read(&mut buf).unwrap());
        assert_eq!(b"test2", &buf[..5]);
    }

    #[test]
    fn it_should_end_of_stream() {
        let mut buf = [0_u8; 1024];
        let mut server = {
            let (_listener, mut client, server) = setup();
            client.write_all(b"test1").unwrap();
            server
        };
        assert_eq!(5, server.read(&mut buf).unwrap());
        assert_eq!(b"test1", &buf[..5]);
        match server.read(&mut buf) {
            Ok(0) => (),
            r => panic!("unexpected {:?}", r),
        }
    }

    /// A read into a buffer smaller than the pending data returns a prefix
    /// and leaves the rest for the next read.
    #[test]
    fn it_can_read_in_partial_chunks() {
        let mut buf = [0_u8; 3];
        let (_listener, mut client, mut server) = setup();
        client.write_all(b"test1").unwrap();
        assert_eq!(3, server.read(&mut buf).unwrap());
        assert_eq!(b"tes", &buf[..3]);
        assert_eq!(2, server.read(&mut buf).unwrap());
        assert_eq!(b"t1", &buf[..2]);
        match server.read(&mut buf) {
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => (),
            r => panic!("unexpected {:?}", r),
        }
    }

    /// Pending connections are accepted in the order they were made.
    #[test]
    fn it_accepts_in_connection_order() {
        let mut buf = [0_u8; 16];
        let url = random_url("test");
        let mut listener = MemListener::bind(&url).unwrap();
        let mut first = MemStream::connect(listener.get_url()).unwrap();
        let mut second = MemStream::connect(listener.get_url()).unwrap();
        first.write_all(b"one").unwrap();
        second.write_all(b"two").unwrap();
        let mut first_srv = listener.accept().unwrap();
        let mut second_srv = listener.accept().unwrap();
        assert_eq!(3, first_srv.read(&mut buf).unwrap());
        assert_eq!(b"one", &buf[..3]);
        assert_eq!(3, second_srv.read(&mut buf).unwrap());
        assert_eq!(b"two", &buf[..3]);
    }
}