use super::io::{stream_pair, LocalStream};
use super::MockNetRuntime;
use core::fmt;
use tor_rtcompat::tls::TlsConnector;
use tor_rtcompat::{CertifiedConn, Runtime, TcpListener, TcpProvider, TlsProvider};
use tor_rtcompat::{UdpProvider, UdpSocket};
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::io::{AsyncRead, AsyncWrite};
use futures::lock::Mutex as AsyncMutex;
use futures::sink::SinkExt;
use futures::stream::{Stream, StreamExt};
use futures::FutureExt;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::Formatter;
use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use thiserror::Error;
use void::Void;
/// Sending half of a listener's connection channel: each message is a new
/// stream together with the address of the peer that opened it.
type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
/// Receiving half of a listener's connection channel.
type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
/// A simulated network shared by any number of [`MockNetProvider`]s.
///
/// For each bound address it records what happens when someone connects
/// there: the connection is either handed to a listener or never answered.
#[derive(Default)]
pub struct MockNetwork {
    /// Map from each bound address to its connection behavior.
    listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
}
/// A listener registered on a [`MockNetwork`].
#[derive(Clone)]
struct ListenerEntry {
    /// Channel used to hand new connections (and the peer address) to the listener.
    send: ConnSender,
    /// The certificate reported to connecting peers, if this is a TLS listener.
    tls_cert: Option<Vec<u8>>,
}
/// What happens when a connection is attempted to a bound address.
#[derive(Clone)]
enum AddrBehavior {
    /// Deliver the connection to this listener.
    Listener(ListenerEntry),
    /// Never answer: the connection attempt neither succeeds nor fails.
    Timeout,
}
/// One simulated host's view of a [`MockNetwork`].
///
/// Clones share the same inner state (addresses and port counter) through an `Arc`.
#[derive(Clone)]
pub struct MockNetProvider {
    /// Shared state for this host.
    inner: Arc<MockNetProviderInner>,
}
impl fmt::Debug for MockNetProvider {
    /// Format as `MockNetProvider { .. }`; the inner state is deliberately not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MockNetProvider");
        out.finish_non_exhaustive()
    }
}
/// Shared state behind a [`MockNetProvider`].
struct MockNetProviderInner {
    /// The IP addresses this host may connect from and listen on.
    addrs: Vec<IpAddr>,
    /// The network this host is attached to.
    net: Arc<MockNetwork>,
    /// Next port number to hand out when a port must be chosen automatically.
    next_port: AtomicU16,
}
/// A listener bound to an address on a [`MockNetwork`].
pub struct MockNetListener {
    /// The address this listener is bound to.
    addr: SocketAddr,
    /// Receives incoming connections. It sits behind an async mutex because
    /// `accept` takes `&self` but needs mutable access to the receiver.
    receiver: AsyncMutex<ConnReceiver>,
}
/// Builder for [`MockNetProvider`]s attached to a particular [`MockNetwork`].
pub struct ProviderBuilder {
    /// Local addresses to give each provider built.
    addrs: Vec<IpAddr>,
    /// The network the providers will be attached to.
    net: Arc<MockNetwork>,
}
impl Default for MockNetProvider {
fn default() -> Self {
Arc::new(MockNetwork::default()).builder().provider()
}
}
impl MockNetwork {
pub fn new() -> Arc<Self> {
Default::default()
}
pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
ProviderBuilder {
addrs: vec![],
net: Arc::clone(self),
}
}
pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
if listener_map.contains_key(&address) {
return Err(err(ErrorKind::AddrInUse));
}
listener_map.insert(address, AddrBehavior::Timeout);
Ok(())
}
async fn send_connection(
&self,
source_addr: SocketAddr,
target_addr: SocketAddr,
peer_stream: LocalStream,
) -> IoResult<Option<Vec<u8>>> {
let entry = {
let listener_map = self.listening.lock().expect("Poisoned lock for listener");
listener_map.get(&target_addr).cloned()
};
match entry {
Some(AddrBehavior::Listener(mut entry)) => {
if entry.send.send((peer_stream, source_addr)).await.is_ok() {
return Ok(entry.tls_cert);
}
Err(err(ErrorKind::ConnectionRefused))
}
Some(AddrBehavior::Timeout) => futures::future::pending().await,
None => Err(err(ErrorKind::ConnectionRefused)),
}
}
fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
if listener_map.contains_key(&addr) {
return Err(err(ErrorKind::AddrInUse));
}
let (send, recv) = mpsc::channel(16);
let entry = ListenerEntry { send, tls_cert };
listener_map.insert(addr, AddrBehavior::Listener(entry));
Ok(recv)
}
}
impl ProviderBuilder {
    /// Add `addr` to the local addresses that built providers will own.
    pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
        self.addrs.push(addr);
        self
    }

    /// Wrap `runtime` in a [`MockNetRuntime`] whose networking goes through a
    /// freshly built provider.
    pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
        let net_provider = self.provider();
        MockNetRuntime::new(runtime, net_provider)
    }

    /// Build a new [`MockNetProvider`] with the addresses added so far.
    ///
    /// Each provider gets its own automatic-port counter, which starts at 1.
    pub fn provider(&self) -> MockNetProvider {
        MockNetProvider {
            inner: Arc::new(MockNetProviderInner {
                addrs: self.addrs.clone(),
                net: Arc::clone(&self.net),
                next_port: AtomicU16::new(1),
            }),
        }
    }
}
#[async_trait]
impl TcpListener for MockNetListener {
    type TcpStream = LocalStream;
    type Incoming = Self;

    /// Wait for the next incoming connection and return it with the peer's address.
    ///
    /// Fails with [`ErrorKind::BrokenPipe`] if the connection channel yields
    /// no further items.
    async fn accept(&self) -> IoResult<(Self::TcpStream, SocketAddr)> {
        let mut queue = self.receiver.lock().await;
        match queue.next().await {
            Some(connection) => Ok(connection),
            None => Err(err(ErrorKind::BrokenPipe)),
        }
    }

    /// Return the address this listener is bound to. This never fails.
    fn local_addr(&self) -> IoResult<SocketAddr> {
        let bound = self.addr;
        Ok(bound)
    }

    /// The listener is itself a stream of incoming connections.
    fn incoming(self) -> Self {
        self
    }
}
impl Stream for MockNetListener {
    type Item = IoResult<(LocalStream, SocketAddr)>;

    /// Yield the next incoming connection, or `None` once the channel ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Acquire the receiver lock, reporting Pending if it is not yet available.
        // NOTE(review): `poll_next` holds `&mut self`, so the lock should never be
        // contended here; `AsyncMutex::get_mut` would avoid locking entirely — confirm.
        let mut queue = match self.receiver.lock().poll_unpin(cx) {
            Poll::Ready(guard) => guard,
            Poll::Pending => return Poll::Pending,
        };
        queue
            .poll_next_unpin(cx)
            .map(|maybe_conn| maybe_conn.map(Ok))
    }
}
/// A UDP socket on the mock network.
///
/// The mock network does not support UDP, so a value of this type can never
/// exist: its only field is the uninhabited [`Void`].
#[derive(Debug)]
#[non_exhaustive]
pub struct MockUdpSocket {
    /// Uninhabited marker making this type impossible to construct.
    void: Void,
}
#[async_trait]
impl UdpProvider for MockNetProvider {
type UdpSocket = MockUdpSocket;
async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
let _ = addr; Err(io::ErrorKind::Unsupported.into())
}
}
// `MockUdpSocket` contains a `Void`, so none of these methods can ever be
// called; each body just proves that with `void::unreachable`.
// The `(self.void, ...).0` tuples mention the otherwise-unused arguments so
// they do not trigger unused-variable warnings.
// NOTE(review): the `allow` presumably silences clippy's complaint about these
// diverging calls appearing as sub-expressions (possibly inside the
// `async_trait` expansion) — confirm.
#[allow(clippy::diverging_sub_expression)] #[async_trait]
impl UdpSocket for MockUdpSocket {
    /// Unreachable: no `MockUdpSocket` can exist.
    async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
        void::unreachable((self.void, buf).0)
    }
    /// Unreachable: no `MockUdpSocket` can exist.
    async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
        void::unreachable((self.void, buf, target).0)
    }
    /// Unreachable: no `MockUdpSocket` can exist.
    fn local_addr(&self) -> IoResult<SocketAddr> {
        void::unreachable(self.void)
    }
}
impl MockNetProvider {
    /// Return the first of our addresses in the same family (IPv4 vs IPv6) as `other`.
    fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
        let want_v4 = other.is_ipv4();
        self.inner
            .addrs
            .iter()
            .copied()
            .find(|candidate| candidate.is_ipv4() == want_v4)
    }

    /// Hand out the next port number from this provider's counter.
    ///
    /// Ports are assigned sequentially. The atomic counter wraps on overflow,
    /// and the assertion panics if it has wrapped around to 0.
    fn arbitrary_port(&self) -> u16 {
        let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
        assert!(next != 0);
        next
    }

    /// Choose a source address for an outgoing connection to `addr`.
    ///
    /// The result is one of our IPs in `addr`'s family, plus a fresh arbitrary
    /// port. Fails with [`ErrorKind::AddrNotAvailable`] if we have no address
    /// in that family. In that case no port is consumed.
    fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
        match self.get_addr_in_family(&addr.ip()) {
            Some(local_ip) => Ok(SocketAddr::new(local_ip, self.arbitrary_port())),
            None => Err(err(ErrorKind::AddrNotAvailable)),
        }
    }

    /// Turn a requested bind address into the concrete address to listen on.
    ///
    /// - An unspecified IP (`0.0.0.0` / `::`) is replaced with one of our
    ///   addresses in the same family.
    /// - A specific IP must be one of ours.
    /// - A port of 0 is replaced with an arbitrary port.
    ///
    /// Fails with [`ErrorKind::AddrNotAvailable`] if no suitable IP is available.
    fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
        let requested = spec.ip();
        let ip = if requested.is_unspecified() {
            self.get_addr_in_family(&requested)
                .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
        } else if self.inner.addrs.contains(&requested) {
            requested
        } else {
            return Err(err(ErrorKind::AddrNotAvailable));
        };
        let port = match spec.port() {
            0 => self.arbitrary_port(),
            explicit => explicit,
        };
        Ok(SocketAddr::new(ip, port))
    }

    /// Start a TLS listener on `addr` (resolved as for `listen`).
    ///
    /// Peers that connect to it receive `tls_cert` as the listener's certificate.
    ///
    /// # Errors
    ///
    /// Fails if the address cannot be resolved to one of ours, or if it is
    /// already in use.
    pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
        let bound = self.get_listener_addr(addr)?;
        let incoming = self.inner.net.add_listener(bound, Some(tls_cert))?;
        Ok(MockNetListener {
            addr: bound,
            receiver: AsyncMutex::new(incoming),
        })
    }
}
#[async_trait]
impl TcpProvider for MockNetProvider {
    type TcpStream = LocalStream;
    type TcpListener = MockNetListener;

    /// Open a connection to `addr` on the mock network.
    ///
    /// If the remote side is a TLS listener, its certificate is recorded on
    /// the returned stream.
    async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
        let origin = self.get_origin_addr_for(addr)?;
        let (mut local_end, remote_end) = stream_pair();
        local_end.tls_cert = self
            .inner
            .net
            .send_connection(origin, *addr, remote_end)
            .await?;
        Ok(local_end)
    }

    /// Start a plain (non-TLS) listener on `addr`.
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
        let bound = self.get_listener_addr(addr)?;
        let incoming = self.inner.net.add_listener(bound, None)?;
        Ok(MockNetListener {
            addr: bound,
            receiver: AsyncMutex::new(incoming),
        })
    }
}
#[async_trait]
impl TlsProvider<LocalStream> for MockNetProvider {
type Connector = MockTlsConnector;
type TlsStream = MockTlsStream;
fn tls_connector(&self) -> MockTlsConnector {
MockTlsConnector {}
}
fn supports_keying_material_export(&self) -> bool {
false
}
}
/// A mock TLS connector.
///
/// "Negotiating" TLS only checks that the stream was opened to a TLS listener
/// and records that listener's certificate. No real handshake takes place.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
/// A mock "TLS" stream: the underlying [`LocalStream`] plus the peer's certificate.
///
/// Reads and writes pass straight through; nothing is encrypted.
pub struct MockTlsStream {
    /// The certificate of the TLS listener this stream is connected to.
    peer_cert: Option<Vec<u8>>,
    /// The underlying (unencrypted) stream.
    stream: LocalStream,
}
#[async_trait]
impl TlsConnector<LocalStream> for MockTlsConnector {
type Conn = MockTlsStream;
async fn negotiate_unvalidated(
&self,
mut stream: LocalStream,
_sni_hostname: &str,
) -> IoResult<MockTlsStream> {
let peer_cert = stream.tls_cert.take();
if peer_cert.is_none() {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"attempted to wrap non-TLS stream!",
));
}
Ok(MockTlsStream { peer_cert, stream })
}
}
impl CertifiedConn for MockTlsStream {
    /// Return a copy of the peer's certificate.
    fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
        let cert = self.peer_cert.clone();
        Ok(cert)
    }

    /// Always returns an empty vector: this mock does not derive any keying material.
    fn export_keying_material(
        &self,
        _len: usize,
        _label: &[u8],
        _context: Option<&[u8]>,
    ) -> IoResult<Vec<u8>> {
        Ok(vec![])
    }
}
impl AsyncRead for MockTlsStream {
    /// Read directly from the underlying stream; no decryption is involved.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        // `MockTlsStream` is `Unpin` (the original pin projection already required it).
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_read(cx, buf)
    }
}
impl AsyncWrite for MockTlsStream {
    /// Write directly to the underlying stream; no encryption is involved.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_write(cx, buf)
    }

    /// Flush the underlying stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Close the underlying stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_close(cx)
    }
}
/// An error produced by the mock network.
///
/// It is wrapped inside the `io::Error`s that this module returns; the
/// `io::ErrorKind` of the outer error says what went wrong.
#[derive(Clone, Error, Debug)]
#[non_exhaustive]
pub enum MockNetError {
    /// An operation on the mock network was not permitted or could not complete.
    #[error("Invalid operation on mock network")]
    BadOp,
}
/// Build an `io::Error` of kind `k` whose inner error is [`MockNetError::BadOp`].
fn err(k: ErrorKind) -> IoError {
    let inner = MockNetError::BadOp;
    IoError::new(k, inner)
}
// Unit tests for the mock network.
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Build two providers on one fresh network: the first owns 192.0.2.55,
    /// the second owns 198.51.100.7.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let client1 = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let client2 = net
            .builder()
            .add_address("198.51.100.7".parse().unwrap())
            .provider();
        (client1, client2)
    }

    /// Data sent over a connection arrives intact, the listener sees the
    /// right peer IP, and connecting to an address with no listener fails.
    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;
            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                // Client side: connect, send, close, then try an unbound address.
                async {
                    let mut conn = client1.connect(&address).await?;
                    conn.write_all(b"This is totally a network.").await?;
                    conn.close().await?;
                    let a2 = "192.0.2.200:99".parse().unwrap();
                    let cant_connect = client1.connect(&a2).await;
                    assert!(cant_connect.is_err());
                    Ok(())
                },
                // Server side: accept and read everything until EOF.
                async {
                    let (mut conn, a) = lis.accept().await?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = Vec::new();
                    conn.read_to_end(&mut inp).await?;
                    assert_eq!(&inp[..], &b"This is totally a network."[..]);
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// `get_listener_addr` fills in unspecified IPs and zero ports per address
    /// family, and rejects IPs that the provider does not own.
    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4 = "192.0.2.55".parse().unwrap();
        let ip6 = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();
        let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
        assert_eq!(a1.ip(), ip4);
        assert_eq!(a1.port(), 99);
        let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
        assert_eq!(a2.ip(), ip4);
        assert_eq!(a2.port(), 100);
        let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
        assert_eq!(a3.ip(), ip4);
        assert!(a3.port() != 0);
        let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
        assert_eq!(a4.ip(), ip4);
        assert!(a4.port() != 0);
        // Each automatic port request gets a distinct port.
        assert!(a4.port() != a3.port());
        let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
        assert_eq!(a5.ip(), ip6);
        assert_eq!(a5.port(), 99);
        let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
        assert_eq!(a6.ip(), ip6);
        assert_eq!(a6.port(), 100);
        // Addresses the provider does not own are rejected.
        let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
        let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
        assert!(e1.is_err());
        assert!(e2.is_err());
        IoResult::Ok(())
    }

    /// Incoming connections can be consumed through the `incoming()` stream.
    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;
            let mut incoming = lis.incoming();
            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                // Open and immediately close three connections.
                async {
                    for _ in 0..3_u8 {
                        let mut c = client1.connect(&address).await?;
                        c.close().await?;
                    }
                    Ok(())
                },
                // Receive all three and check the peer IP of each.
                async {
                    for _ in 0..3_u8 {
                        let (mut c, a) = incoming.next().await.unwrap()?;
                        let mut v = Vec::new();
                        let _ = c.read_to_end(&mut v).await?;
                        assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    }
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// A connection to a TLS listener exposes the listener's certificate after
    /// mock negotiation, and data flows in both directions.
    #[test]
    fn tls_basics() {
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";
        let lis = client2
            .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
            .unwrap();
        let address = lis.local_addr().unwrap();
        test_with_all_runtimes!(|_rt| async {
            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                // Client side: connect, negotiate, check cert, send, and read the reply.
                async {
                    let connector = client1.tls_connector();
                    let conn = client1.connect(&address).await?;
                    let mut conn = connector
                        .negotiate_unvalidated(conn, "zombo.example.com")
                        .await?;
                    assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
                    conn.write_all(b"This is totally encrypted.").await?;
                    let mut v = Vec::new();
                    conn.read_to_end(&mut v).await?;
                    conn.close().await?;
                    assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
                    Ok(())
                },
                // Server side: accept, read exactly the client's 26-byte message, reply.
                async {
                    let (mut conn, a) = lis.accept().await?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = [0_u8; 26];
                    conn.read_exact(&mut inp[..]).await?;
                    assert_eq!(&inp[..], &b"This is totally encrypted."[..]);
                    conn.write_all(b"Yup, your secrets is safe").await?;
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }
}