use std;
use std::env;
use std::fs::File;
use std::io::{Read, Write};
#[cfg(not(target_os = "linux"))]
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::atomic;
use std::sync::Arc;
use std::{thread, time};
use futures::Stream;
use native_tls;
use native_tls::{Certificate, TlsAcceptor};
use tokio::runtime::current_thread::Runtime;
use trust_dns_proto::error::ProtoError;
#[allow(unused)]
use {TlsStream, TlsStreamBuilder};
#[test]
fn test_tls_client_stream_ipv4() {
    // Plain TLS (mtls = false) against the IPv4 loopback address.
    let loopback = Ipv4Addr::new(127, 0, 0, 1);
    tls_client_stream_test(IpAddr::V4(loopback), false);
}
#[cfg(feature = "mtls")]
#[test]
#[cfg(not(target_os = "macos"))]
fn test_tls_client_stream_ipv4_mtls() {
    // Same IPv4 loopback exchange, but with the `mtls` flag set.
    let loopback = Ipv4Addr::new(127, 0, 0, 1);
    tls_client_stream_test(IpAddr::V4(loopback), true);
}
#[test]
#[cfg(not(target_os = "linux"))]
fn test_tls_client_stream_ipv6() {
    // Plain TLS (mtls = false) against the IPv6 loopback address `::1`.
    let loopback = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    tls_client_stream_test(IpAddr::V6(loopback), false);
}
/// Payload sent by the client and echoed back by the test server.
///
/// The array type is expressed in terms of `TEST_BYTES_LEN`, so changing the
/// literal without updating the length (or vice versa) fails to compile.
const TEST_BYTES: &'static [u8; TEST_BYTES_LEN] = b"DEADBEEF";
/// Length in bytes of `TEST_BYTES`; the server expects this as the frame length.
const TEST_BYTES_LEN: usize = 8;
/// Reads the entire contents of the file at `path` into a byte vector.
///
/// # Panics
///
/// Panics if the file cannot be opened or read. The message names the path
/// and includes the underlying I/O error.
fn read_file(path: &str) -> Vec<u8> {
    let mut bytes = vec![];
    // `unwrap_or_else` only formats the message on failure, unlike
    // `expect(&format!(..))`, which allocates it on every call.
    let mut file = File::open(path)
        .unwrap_or_else(|e| panic!("failed to open file: {}: {:?}", path, e));
    file.read_to_end(&mut bytes)
        .unwrap_or_else(|e| panic!("failed to read file: {}: {:?}", path, e));
    bytes
}
/// Echoes `TEST_BYTES` over a TLS connection `send_recv_times` (4) times and
/// checks every reply.
///
/// A server thread binds a `TcpListener` on `server_addr` with an OS-chosen
/// port. It accepts one connection using a native-tls acceptor built from
/// `cert.p12`, then echoes back frames made of a 2-byte big-endian length
/// followed by the payload. The client side is a `TlsStreamBuilder` that
/// trusts `ca.der` and is driven on a current-thread tokio `Runtime`.
///
/// Fixtures are read from `<root>/../tests/ca.der` and
/// `<root>/../tests/cert.p12`. `<root>` is `$TDNS_SERVER_SRC_ROOT`, or
/// `../server` if that variable is unset.
///
/// A watchdog thread terminates the whole process if this function has not
/// returned or unwound within roughly 15 seconds.
///
/// `mtls` is currently not read by this function.
/// NOTE(review): the `*_mtls` test therefore runs the same exchange as the
/// plain one — confirm whether client-certificate setup is missing here.
///
/// # Panics
///
/// Panics on any setup, I/O, or TLS failure, or if an echoed frame differs
/// from `TEST_BYTES`.
// Silences the warning for the unused `mtls` parameter (and any `mut`
// bindings that turn out not to need mutability).
#[allow(unused, unused_mut)]
fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
    // Tells the watchdog to stand down when dropped. The drop happens on a
    // normal return and also while unwinding from a failed assertion, so a
    // test that fails cleanly never triggers the process kill.
    struct StandDown(Arc<atomic::AtomicBool>);
    impl Drop for StandDown {
        fn drop(&mut self) {
            self.0.store(true, atomic::Ordering::Relaxed);
        }
    }

    let finished = Arc::new(atomic::AtomicBool::new(false));
    let stand_down = StandDown(finished.clone());
    thread::Builder::new()
        .name("thread_killer".to_string())
        .spawn(move || {
            for _ in 0..15 {
                thread::sleep(time::Duration::from_secs(1));
                if finished.load(atomic::Ordering::Relaxed) {
                    return;
                }
            }
            // A panic here would only unwind this helper thread and leave the
            // hung test blocked forever, so terminate the whole process.
            println!("thread_killer: timeout, killing process");
            std::process::exit(-1);
        })
        .unwrap();

    let server_path =
        env::var("TDNS_SERVER_SRC_ROOT").unwrap_or_else(|_| "../server".to_owned());
    println!("using server src path: {}", server_path);

    let root_cert_der = read_file(&format!("{}/../tests/ca.der", server_path));
    let dns_name = "ns.example.com";
    let server_pkcs12_der = read_file(&format!("{}/../tests/cert.p12", server_path));

    // Port 0 lets the OS pick a free port; read the real address back for the client.
    let server = std::net::TcpListener::bind(SocketAddr::new(server_addr, 0)).unwrap();
    let server_addr = server.local_addr().unwrap();
    let send_recv_times = 4;

    let server_handle = thread::Builder::new()
        .name("test_tls_client_stream:server".to_string())
        .spawn(move || {
            let pkcs12 = native_tls::Pkcs12::from_der(&server_pkcs12_der, "mypass")
                .expect("Pkcs12::from_der");
            let mut tls = TlsAcceptor::builder(pkcs12).expect("build with pkcs12 failed");
            let tls = tls.build().expect("tls build failed");

            let (socket, _) = server.accept().expect("tcp accept failed");
            // Bound each blocking read/write so a stalled peer cannot hang the server.
            socket
                .set_read_timeout(Some(std::time::Duration::from_secs(5)))
                .unwrap();
            socket
                .set_write_timeout(Some(std::time::Duration::from_secs(5)))
                .unwrap();
            let mut socket = tls.accept(socket).expect("tls accept failed");

            for _ in 0..send_recv_times {
                // Each frame is a 2-byte big-endian length followed by the payload.
                let mut len_bytes = [0_u8; 2];
                socket
                    .read_exact(&mut len_bytes)
                    .expect("SERVER: receive failed");
                let length = (u16::from(len_bytes[0]) << 8) | u16::from(len_bytes[1]);
                assert_eq!(length as usize, TEST_BYTES_LEN);

                let mut buffer = [0_u8; TEST_BYTES_LEN];
                socket
                    .read_exact(&mut buffer)
                    .expect("SERVER: receive failed");
                assert_eq!(&buffer, TEST_BYTES);

                // Echo the frame back unchanged.
                socket
                    .write_all(&len_bytes)
                    .expect("SERVER: send length failed");
                socket
                    .write_all(&buffer)
                    .expect("SERVER: send buffer failed");
                std::thread::yield_now();
            }
        })
        .unwrap();
    std::thread::yield_now();

    let mut io_loop = Runtime::new().unwrap();
    let trust_chain = Certificate::from_der(&root_cert_der).unwrap();
    let mut builder = TlsStreamBuilder::new();
    builder.add_ca(trust_chain);
    let (stream, sender) = builder.build::<ProtoError>(server_addr, dns_name.to_string());

    // NOTE(review): `.ok()` discards the error detail; presumably kept because
    // the future's error type may not implement `Debug` — verify.
    let mut stream = io_loop
        .block_on(stream)
        .ok()
        .expect("run failed to get stream");

    for _ in 0..send_recv_times {
        sender
            .unbounded_send((TEST_BYTES.to_vec(), server_addr))
            .expect("send failed");
        let (next, rest) = io_loop
            .block_on(stream.into_future())
            .ok()
            .expect("future iteration run failed");
        stream = rest;
        let (buffer, _) = next.expect("no buffer received");
        assert_eq!(&buffer, TEST_BYTES);
    }

    // Release the watchdog before waiting for the server thread to finish.
    drop(stand_down);
    server_handle.join().expect("server thread failed");
}