use std::io::{BufRead, BufReader, Read, Write};
use std::sync::{Arc, Mutex};
use go_lib::{
chan::chan,
go,
net::{TcpListener, TcpStream},
select,
sync::WaitGroup,
};
// Binding to port 0 lets the OS choose the port; `local_addr` must report the loopback IP
// that was requested together with the concrete (non-zero) port that was assigned.
#[test]
#[go_lib::main]
fn net_listener_local_addr() {
    let sock = TcpListener::bind("127.0.0.1:0").expect("bind failed");
    let bound = sock.local_addr().expect("local_addr failed");
    let ip_text = bound.ip().to_string();
    assert_eq!(ip_text, "127.0.0.1");
    let port = bound.port();
    assert_ne!(port, 0, "OS must assign a non-zero port");
}
// Echo round-trip that drives `Read`/`Write` through `&mut TcpStream` on both the server and
// the client side.
#[test]
#[go_lib::main]
fn net_read_write_mut_ref() {
    let server = TcpListener::bind("127.0.0.1:0").unwrap();
    let server_addr = server.local_addr().unwrap();
    let (finished_tx, finished_rx) = chan::<()>(1);

    // Server: echo back whatever a single `read` returned, then signal completion.
    go!(move || {
        let mut peer = server.accept().unwrap();
        let mut scratch = [0u8; 64];
        let got = peer.read(&mut scratch).unwrap();
        let echoed = &scratch[..got];
        peer.write_all(echoed).unwrap();
        finished_tx.send(());
    });

    let mut stream = TcpStream::connect(server_addr).unwrap();
    stream.write_all(b"hello").unwrap();
    let mut reply = [0u8; 5];
    stream.read_exact(&mut reply).unwrap();
    assert_eq!(&reply, b"hello");
    finished_rx.recv();
}
// Same echo round-trip, but going through the `Read`/`Write` impls on `&TcpStream`, so
// neither endpoint needs a mutable binding to the stream itself.
#[test]
#[go_lib::main]
fn net_read_write_shared_ref() {
    let server = TcpListener::bind("127.0.0.1:0").unwrap();
    let server_addr = server.local_addr().unwrap();
    let (finished_tx, finished_rx) = chan::<()>(1);

    go!(move || {
        let peer = server.accept().unwrap();
        // `io` is a shared reference; only the reference binding is mutable.
        let mut io = &peer;
        let mut scratch = [0u8; 64];
        let got = io.read(&mut scratch).unwrap();
        io.write_all(&scratch[..got]).unwrap();
        finished_tx.send(());
    });

    let client = TcpStream::connect(server_addr).unwrap();
    let mut handle = &client;
    handle.write_all(b"shared").unwrap();
    let mut reply = [0u8; 6];
    handle.read_exact(&mut reply).unwrap();
    assert_eq!(&reply, b"shared");
    finished_rx.recv();
}
// A `try_clone`d handle and the original refer to the same connection: the server reads
// through the original and writes the echo through the clone.
#[test]
#[go_lib::main]
fn net_try_clone_split_halves() {
    let server = TcpListener::bind("127.0.0.1:0").unwrap();
    let server_addr = server.local_addr().unwrap();
    let (finished_tx, finished_rx) = chan::<()>(1);

    go!(move || {
        let original = server.accept().unwrap();
        let mut write_half = original.try_clone().expect("try_clone failed");
        let mut read_half = &original;
        let mut scratch = [0u8; 64];
        let got = read_half.read(&mut scratch).unwrap();
        write_half.write_all(&scratch[..got]).unwrap();
        finished_tx.send(());
    });

    let mut stream = TcpStream::connect(server_addr).unwrap();
    stream.write_all(b"cloned").unwrap();
    let mut reply = [0u8; 6];
    stream.read_exact(&mut reply).unwrap();
    assert_eq!(&reply, b"cloned");
    finished_rx.recv();
}
// The two handles of one accepted connection (original + `try_clone`) are moved into two
// different goroutines joined by a channel: one reads and forwards the bytes, the other
// writes them back to the client.
#[test]
#[go_lib::main]
fn net_try_clone_separate_goroutines() {
    let server = TcpListener::bind("127.0.0.1:0").unwrap();
    let server_addr = server.local_addr().unwrap();
    let (finished_tx, finished_rx) = chan::<()>(1);

    go!(move || {
        let inbound = server.accept().unwrap();
        let outbound = inbound.try_clone().expect("try_clone failed");
        let (bytes_tx, bytes_rx) = chan::<Vec<u8>>(1);

        // Reader goroutine: owns `inbound`.
        go!(move || {
            let mut io = &inbound;
            let mut scratch = [0u8; 64];
            let got = io.read(&mut scratch).unwrap();
            bytes_tx.send(scratch[..got].to_vec());
        });

        // Writer goroutine: owns `outbound`.
        go!(move || {
            let chunk = bytes_rx.recv().unwrap();
            let mut io = &outbound;
            io.write_all(&chunk).unwrap();
            finished_tx.send(());
        });
    });

    let mut stream = TcpStream::connect(server_addr).unwrap();
    stream.write_all(b"split").unwrap();
    let mut reply = [0u8; 5];
    stream.read_exact(&mut reply).unwrap();
    assert_eq!(&reply, b"split");
    finished_rx.recv();
}
// Cross-checks addresses from both ends of one connection:
// - the accepted socket's local port matches the listener's port;
// - the server-side peer port matches the client's local port.
#[test]
#[go_lib::main]
fn net_peer_and_local_addr() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let listen_addr = listener.local_addr().unwrap();
    let (peer_tx, peer_rx) = chan::<std::net::SocketAddr>(1);

    go!(move || {
        let accepted = listener.accept().unwrap();
        let server_side = accepted.local_addr().expect("local_addr failed");
        assert_eq!(server_side.port(), listen_addr.port());
        let remote = accepted.peer_addr().expect("peer_addr failed");
        assert_ne!(remote.port(), 0);
        peer_tx.send(remote);
    });

    let client = TcpStream::connect(listen_addr).unwrap();
    let client_side = client.local_addr().expect("client local_addr failed");
    let seen_by_server = peer_rx.recv().unwrap();
    assert_eq!(seen_by_server.port(), client_side.port());
}
// `TcpStream` composes with `std::io::BufReader` for line-oriented reads. The server also
// writes through `get_mut()` on the wrapped stream.
#[test]
#[go_lib::main]
fn net_bufreader_adapter() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let addr = listener.local_addr().unwrap();
    let (finished_tx, finished_rx) = chan::<()>(1);

    go!(move || {
        let mut reader = BufReader::new(listener.accept().unwrap());
        let mut request = String::new();
        reader.read_line(&mut request).unwrap();
        assert_eq!(request.trim_end(), "ping");
        let stream = reader.get_mut();
        stream.write_all(b"pong\n").unwrap();
        finished_tx.send(());
    });

    let mut client = TcpStream::connect(addr).unwrap();
    client.write_all(b"ping\n").unwrap();
    // The reader (and the client stream it owns) is dropped at the end of this block.
    let reply = {
        let mut lines = BufReader::new(client);
        let mut text = String::new();
        lines.read_line(&mut text).unwrap();
        text
    };
    assert_eq!(reply.trim_end(), "pong");
    finished_rx.recv();
}
// Opens N client connections concurrently against one server that handles each accepted
// connection in its own goroutine. Every client sends a distinct 4-byte tag and expects the
// identical tag echoed back.
#[test]
#[go_lib::main]
fn net_concurrent_connections() {
    const N: usize = 8;
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let addr = listener.local_addr().unwrap();

    let handlers = Arc::new(WaitGroup::new());
    let handlers_for_accept = Arc::clone(&handlers);
    go!(move || {
        for _ in 0..N {
            let conn = listener.accept().unwrap();
            let handler_done = Arc::clone(&handlers_for_accept);
            // Registered before the handler goroutine is spawned.
            handler_done.add(1);
            go!(move || {
                let mut io = &conn;
                let mut frame = [0u8; 4];
                io.read_exact(&mut frame).unwrap();
                io.write_all(&frame).unwrap();
                handler_done.done();
            });
        }
    });

    let outcomes = Arc::new(Mutex::new(Vec::<bool>::new()));
    let clients = Arc::new(WaitGroup::new());
    for idx in 0..N {
        clients.add(1);
        let outcomes_for_client = Arc::clone(&outcomes);
        let client_done = Arc::clone(&clients);
        go!(move || {
            let mut stream = TcpStream::connect(addr).unwrap();
            let tag = [idx as u8; 4];
            stream.write_all(&tag).unwrap();
            let mut echo = [0u8; 4];
            stream.read_exact(&mut echo).unwrap();
            outcomes_for_client.lock().unwrap().push(echo == tag);
            client_done.done();
        });
    }

    // Clients first: each one finishes only after its handler has echoed. By then that
    // handler's `add(1)` has already run, so the server-side wait cannot return early.
    clients.wait();
    handlers.wait();

    let recorded = outcomes.lock().unwrap();
    assert_eq!(recorded.len(), N, "wrong number of results");
    assert!(!recorded.contains(&false), "some echo checks failed");
}
// Echoes a 128 KiB payload, large enough that it likely spans several partial reads and
// writes, so `read_exact`/`write_all` must loop correctly on both ends.
#[test]
#[go_lib::main]
fn net_large_payload() {
    const SIZE: usize = 128 * 1024;
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let addr = listener.local_addr().unwrap();
    // Repeating 0..251 pattern. The odd period presumably makes shifted or duplicated chunks
    // show up as mismatches.
    let payload: Vec<u8> = (0..SIZE).map(|i| (i % 251) as u8).collect();
    let (done_tx, done_rx) = chan::<()>(1);

    // The server reads the whole payload before writing anything back. That lets the client
    // finish its `write_all` before it starts reading, without the two ends deadlocking on
    // full socket buffers.
    go!(move || {
        let mut conn = listener.accept().unwrap();
        let mut buf = vec![0u8; SIZE];
        conn.read_exact(&mut buf).unwrap();
        conn.write_all(&buf).unwrap();
        done_tx.send(());
    });

    let mut client = TcpStream::connect(addr).unwrap();
    client.write_all(&payload).unwrap();
    let mut received = vec![0u8; SIZE];
    client.read_exact(&mut received).unwrap();
    // Report the first differing byte instead of `assert_eq!`, which would Debug-print both
    // 128 KiB buffers on failure.
    if let Some(pos) = received.iter().zip(&payload).position(|(got, want)| got != want) {
        panic!("large payload echo mismatch at byte {pos}");
    }
    done_rx.recv();
}
// Connecting by hostname goes through name resolution. `.invalid` is a reserved TLD
// (RFC 2606), so the lookup should fail and `connect` must return `Err`. Per the test name,
// this presumably also guards against resolution overflowing a goroutine stack.
#[test]
#[go_lib::main]
fn net_connect_hostname_does_not_overflow_stack() {
    let outcome = TcpStream::connect("go-lib-nonexistent-host.invalid:80");
    let failed = outcome.is_err();
    assert!(
        failed,
        "connecting to an unresolvable hostname should return Err, not succeed",
    );
}
// Starts N echo servers that are never shut down while clients are active. Each server has:
// - an accept goroutine that forwards connections over a channel;
// - a `select!`-driven dispatcher that spawns a handler goroutine per connection.
// Each client round-trips a 4-byte tag through its own server.
#[test]
#[go_lib::main]
fn net_leaked_select_dispatch_servers() {
    use std::time::Duration;
    const N: usize = 8;

    // Returns the server's address and the dispatcher's shutdown sender. The sender MUST be
    // kept alive by the caller. Dropping it closes `shutdown_rx`, which makes the
    // `recv(shutdown_rx)` arm ready and stops the dispatcher. Before this fix the sender was a
    // local `_shutdown_tx` dropped on return, so the dispatcher exited before serving anyone.
    // A closure (not a nested `fn`) is used so the sender's type can be inferred.
    let spawn_forever_server = || {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let (conn_tx, conn_rx) = chan::<TcpStream>(8);
        go!(move || {
            while let Ok(s) = listener.accept() {
                conn_tx.send(s);
            }
        });
        let (shutdown_tx, shutdown_rx) = chan::<()>(1);
        let mut stop = false;
        go!(move || loop {
            select! {
                recv(shutdown_rx) -> _sig => { stop = true; }
                recv(conn_rx) -> conn => {
                    match conn {
                        None => { stop = true; }
                        Some(s) => {
                            go!(move || {
                                let mut w = s.try_clone().unwrap();
                                let mut r = s.try_clone().unwrap();
                                let mut buf = [0u8; 4];
                                // One byte at a time, so the handler has to reassemble the frame
                                // across several reads.
                                for k in 0..4 {
                                    r.read_exact(&mut buf[k..k + 1]).unwrap();
                                }
                                w.write_all(&buf).unwrap();
                            });
                        }
                    }
                }
            }
            if stop { break; }
        });
        (addr, shutdown_tx)
    };

    let (addrs, shutdown_senders): (Vec<_>, Vec<_>) =
        (0..N).map(|_| spawn_forever_server()).unzip();
    // Presumably gives the server goroutines a chance to start before clients connect.
    go_lib::sleep(Duration::from_millis(50));
    let results = Arc::new(Mutex::new(0usize));
    let client_wg = Arc::new(WaitGroup::new());
    for (i, addr) in addrs.into_iter().enumerate() {
        client_wg.add(1);
        let results2 = Arc::clone(&results);
        let wg2 = Arc::clone(&client_wg);
        go!(move || {
            let client = TcpStream::connect(addr).unwrap();
            let mut wc = client.try_clone().unwrap();
            let mut rc = client.try_clone().unwrap();
            let tag = [i as u8; 4];
            wc.write_all(&tag).unwrap();
            let mut resp = [0u8; 4];
            rc.read_exact(&mut resp).unwrap();
            if resp == tag { *results2.lock().unwrap() += 1; }
            wg2.done();
        });
    }
    client_wg.wait();
    assert_eq!(*results.lock().unwrap(), N);
    // Only now close the shutdown channels, letting the dispatchers exit. The accept
    // goroutines stay blocked in `accept` (leaked).
    drop(shutdown_senders);
}
// Starts N echo servers whose accept loops never terminate, so their goroutines are still
// running when the test body returns. Per the test name, this presumably checks that such
// leaked goroutines don't stop the test from completing. Every server must still echo
// "ping" back correctly.
#[test]
#[go_lib::main]
fn net_leaked_forever_accept() {
    use std::time::Duration;
    const N: usize = 8;
    let mut addrs = Vec::with_capacity(N);
    for _ in 0..N {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        addrs.push(listener.local_addr().unwrap());
        go!(move || {
            while let Ok(conn) = listener.accept() {
                go!(move || {
                    let mut c = conn;
                    let mut buf = [0u8; 4];
                    // Best-effort echo of one 4-byte frame. `read_exact`/`write_all` retry
                    // short reads/writes. The previous single `read` + `write(&buf)` could
                    // echo zero padding after a short read, or silently send a partial reply.
                    if c.read_exact(&mut buf).is_ok() {
                        let _ = c.write_all(&buf);
                    }
                });
            }
        });
    }
    // Presumably gives the accept goroutines a chance to start before clients connect.
    go_lib::sleep(Duration::from_millis(50));
    let echoed = Arc::new(Mutex::new(0usize));
    let wg = Arc::new(WaitGroup::new());
    for addr in addrs {
        wg.add(1);
        let echoed2 = Arc::clone(&echoed);
        let wg2 = Arc::clone(&wg);
        go!(move || {
            let mut c = TcpStream::connect(addr).unwrap();
            c.write_all(b"ping").unwrap();
            let mut resp = [0u8; 4];
            c.read_exact(&mut resp).unwrap();
            // Record the outcome rather than asserting here. A panic in this goroutine would
            // skip `done()` and could leave `wg.wait()` blocked instead of failing the test.
            if &resp == b"ping" {
                *echoed2.lock().unwrap() += 1;
            }
            wg2.done();
        });
    }
    wg.wait();
    assert_eq!(*echoed.lock().unwrap(), N, "some servers did not echo \"ping\"");
}