use openssl::error::ErrorStack;
use openssl::ssl::{Ssl, SslContext, SslMethod, SslVersion};
use std::fmt::Debug;
use std::io::{Read, Write};
use std::net::{ToSocketAddrs, TcpStream};
mod error;
mod psk_providers;
mod stream;
#[cfg(test)]
mod tests;
pub use error::TunnelError;
pub use psk_providers::{PskProvider, SimplePskProvider};
pub use stream::TunnelStream;
pub fn connect_simple(addr: impl ToSocketAddrs, identity: &str, psk: &[u8]) -> Result<TunnelStream<TcpStream>, TunnelError> {
let sock_addr = addr.to_socket_addrs()
.map_err(|e| TunnelError::from(e, "Error when resolving the socket address!"))?.next();
let sock_addr = sock_addr.ok_or_else(|| TunnelError::new("No parseable addresses."))?;
let stream = match std::net::TcpStream::connect(sock_addr) {
Ok(s) => s,
Err(e) => {
eprintln!("Couldn't connect to server! {}", e);
std::process::exit(4);
}
};
let stream = match client(stream, identity, psk) {
Ok(s) => s,
Err(e) => {
eprintln!("Error while initialising the connection! {}", e);
std::process::exit(5);
}
};
Ok(stream)
}
/// Runs the server side of a TLS handshake over `stream`, using pre-shared
/// keys supplied by `psk_provider`.
///
/// The SSL context is configured with TLS 1.3 as the minimum protocol version
/// and with the cipher list `PSK-CHACHA20-POLY1305`. During the handshake the
/// identity sent by the peer is decoded as UTF-8 and passed to
/// `psk_provider.get_psk`; the key it returns is copied into the buffer
/// OpenSSL hands to the PSK callback.
///
/// # Errors
///
/// Returns a [`TunnelError`] when the SSL context or session cannot be
/// created, when either configuration call fails, or when `accept` fails.
/// The PSK callback itself reports an error to OpenSSL when the peer sends no
/// identity, the identity is not valid UTF-8, the provider returns an error,
/// or the key does not fit into the callback buffer.
pub fn server<S>(
    stream: S,
    psk_provider: impl PskProvider + Send + Sync + 'static,
) -> Result<TunnelStream<S>, TunnelError>
where
    S: Read + Write + Debug + 'static,
{
    let mut ctx = SslContext::builder(SslMethod::tls())
        .map_err(|e| TunnelError::from(e, "Error when building the SSL context!"))?;

    ctx.set_psk_server_callback(move |_ssl, identity, psk_buf| {
        let identity = identity.ok_or_else(ErrorStack::get)?;
        let identity = std::str::from_utf8(identity).map_err(|_| ErrorStack::get())?;
        let psk = psk_provider
            .get_psk(identity)
            .map_err(|_| ErrorStack::get())?;
        // Checked slicing: an over-long key becomes a callback error instead
        // of an out-of-bounds panic.
        psk_buf
            .get_mut(..psk.len())
            .ok_or_else(ErrorStack::get)?
            .copy_from_slice(psk);
        Ok(psk.len())
    });

    ctx.set_min_proto_version(Some(SslVersion::TLS1_3))
        .map_err(|e| {
            TunnelError::from(e, "Error setting the minimum protocol version to TLS 1.3")
        })?;
    ctx.set_cipher_list("PSK-CHACHA20-POLY1305").map_err(|e| {
        TunnelError::from(
            e,
            "Error setting the cipher suite list to PSK-CHACHA20-POLY1305",
        )
    })?;

    let ssl = Ssl::new(&ctx.build())
        .map_err(|e| TunnelError::from(e, "Error on starting a new TLS session"))?;
    let tls_stream = ssl
        .accept(stream)
        .map_err(|e| TunnelError::from(e, "Error on accepting a new TLS connection"))?;
    Ok(TunnelStream { tls_stream })
}
pub fn client<S>(
stream: S,
identity: impl Into<String>,
psk: impl Into<Vec<u8>>,
) -> Result<TunnelStream<S>, TunnelError>
where
S: Read + Write + Debug + 'static,
{
let mut ctx = SslContext::builder(SslMethod::tls())
.map_err(|e| TunnelError::from(e, "Error when building the SSL context!"))?;
let identity = identity.into().into_bytes();
let psk = psk.into();
ctx.set_psk_client_callback(move |_ssl, _hint, identity_buf, psk_buf| {
&mut identity_buf[..identity.len()].copy_from_slice(&identity);
identity_buf[identity.len()] = b'\0';
&mut psk_buf[..psk.len()].copy_from_slice(&psk);
Ok(psk.len())
});
ctx.set_min_proto_version(Some(SslVersion::TLS1_3))
.map_err(|e| {
TunnelError::from(e, "Error setting the minimum protocol version to TLS 1.3")
})?;
ctx.set_cipher_list("PSK-CHACHA20-POLY1305").map_err(|e| {
TunnelError::from(
e,
"Error setting the cipher suite list to PSK-CHACHA20-POLY1305",
)
})?;
let ssl = Ssl::new(&ctx.build())
.map_err(|e| TunnelError::from(e, "Error when starting a new TLS session"))?;
let tls_stream = ssl
.connect(stream)
.map_err(|e| TunnelError::from(e, "Error when connecting to a TLS socket"))?;
Ok(TunnelStream { tls_stream })
}