use rcgen::generate_simple_self_signed;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::ring;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::version::TLS13;
use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
use sfo_cmd_server::client::{CmdClient, CmdTunnelFactory, DefaultCmdClient};
use sfo_cmd_server::errors::{CmdErrorCode, CmdResult, into_cmd_err};
use sfo_cmd_server::{CmdBody, CmdHeader, CmdTunnel, CmdTunnelRead, CmdTunnelWrite, PeerId};
use sha2::Digest;
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, split};
use tokio_rustls::TlsConnector;
struct TlsStreamRead {
local_id: PeerId,
remote_id: PeerId,
read: Option<tokio::io::ReadHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>>,
}
impl TlsStreamRead {
pub fn new(
local_id: PeerId,
remote_id: PeerId,
read: tokio::io::ReadHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>,
) -> Self {
Self {
local_id,
remote_id,
read: Some(read),
}
}
}
impl AsyncRead for TlsStreamRead {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
if let Some(read) = this.read.as_mut() {
Pin::new(read).poll_read(cx, buf)
} else {
Poll::Ready(Ok(()))
}
}
}
impl CmdTunnelRead<()> for TlsStreamRead {
fn get_local_peer_id(&self) -> PeerId {
self.local_id.clone()
}
fn get_remote_peer_id(&self) -> PeerId {
self.remote_id.clone()
}
}
struct TlsStreamWrite {
local_id: PeerId,
remote_id: PeerId,
write: Option<tokio::io::WriteHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>>,
}
impl TlsStreamWrite {
pub fn new(
local_id: PeerId,
remote_id: PeerId,
write: tokio::io::WriteHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>,
) -> Self {
Self {
local_id,
remote_id,
write: Some(write),
}
}
}
impl AsyncWrite for TlsStreamWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let this = self.get_mut();
if let Some(write) = this.write.as_mut() {
Pin::new(write).poll_write(cx, buf)
} else {
Poll::Ready(Ok(0))
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
if let Some(write) = this.write.as_mut() {
Pin::new(write).poll_flush(cx)
} else {
Poll::Ready(Ok(()))
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
if let Some(write) = this.write.as_mut() {
Pin::new(write).poll_shutdown(cx)
} else {
Poll::Ready(Ok(()))
}
}
}
impl CmdTunnelWrite<()> for TlsStreamWrite {
fn get_local_peer_id(&self) -> PeerId {
self.local_id.clone()
}
fn get_remote_peer_id(&self) -> PeerId {
self.remote_id.clone()
}
}
/// Owns a client TLS stream together with the DER bytes of the server's leaf certificate.
///
/// NOTE(review): this type is not constructed anywhere in this file (the tunnel factory splits
/// the stream directly); consider removing it if nothing else uses it.
struct TlsConnection {
    /// DER encoding of the server's end-entity certificate (a certificate, despite the name).
    tls_key: Vec<u8>,
    stream: Mutex<Option<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>>,
}

impl TlsConnection {
    /// Records the server's leaf certificate from a completed handshake and stores the stream.
    ///
    /// # Panics
    /// Panics if the handshake did not yield any server certificate.
    pub fn new(stream: tokio_rustls::client::TlsStream<tokio::net::TcpStream>) -> Self {
        let tls_key = stream
            .get_ref()
            .1
            .peer_certificates()
            .and_then(|chain| chain.first())
            .expect("a completed handshake is expected to carry a server certificate")
            .to_vec();
        Self {
            tls_key,
            stream: Mutex::new(Some(stream)),
        }
    }
}
/// Generates a throwaway self-signed certificate for `127.0.0.1`.
///
/// Returns the certificate as a one-element chain plus its private key, both round-tripped
/// through PEM into rustls' DER types.
///
/// # Panics
/// Panics if generation or PEM parsing fails.
fn generate_cert() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
    let generated = generate_simple_self_signed(vec!["127.0.0.1".to_string()]).unwrap();

    let cert_pem = generated.cert.pem();
    let cert = CertificateDer::from_pem_slice(cert_pem.as_bytes()).unwrap();

    let key_pem = generated.key_pair.serialize_pem();
    let key = PrivateKeyDer::from_pem_slice(key_pem.as_bytes()).unwrap();

    (vec![cert], key)
}
pub struct TlsServerCertVerifier {}
impl Debug for TlsServerCertVerifier {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str("TlsServerCertVerifier")
}
}
impl ServerCertVerifier for TlsServerCertVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![SignatureScheme::ECDSA_NISTP256_SHA256]
}
}
/// Builds TLS tunnels from this client to the command server.
///
/// Holds this side's `PeerId`, the SHA-256 of its freshly generated client certificate, and a
/// connector configured to present that certificate.
pub struct TlsConnectionFactory {
    local_id: PeerId,
    tls_connector: TlsConnector,
}

impl TlsConnectionFactory {
    /// Creates a factory with a new self-signed client identity.
    ///
    /// The client config uses the *ring* provider, allows only TLS 1.3, installs
    /// `TlsServerCertVerifier`, and authenticates with the generated certificate.
    ///
    /// # Panics
    /// Panics if certificate generation or client configuration fails.
    pub fn new() -> Self {
        let (cert_chain, private_key) = generate_cert();

        // This side's identity is the digest of its own leaf certificate.
        let local_id = PeerId::from(sha2::Sha256::digest(&cert_chain[0]).to_vec());

        let provider = Arc::new(ring::default_provider());
        let client_config = ClientConfig::builder_with_provider(provider)
            .with_protocol_versions(&[&TLS13])
            .unwrap()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(TlsServerCertVerifier {}))
            .with_client_auth_cert(cert_chain, private_key)
            .unwrap();
        let tls_connector = TlsConnector::from(Arc::new(client_config));

        Self {
            local_id,
            tls_connector,
        }
    }
}
#[async_trait::async_trait]
impl CmdTunnelFactory<(), TlsStreamRead, TlsStreamWrite> for TlsConnectionFactory {
async fn create_tunnel(&self) -> CmdResult<CmdTunnel<TlsStreamRead, TlsStreamWrite>> {
let socket = tokio::net::TcpStream::connect("127.0.0.1:4453")
.await
.map_err(into_cmd_err!(
CmdErrorCode::IoError,
"connect to server failed"
))?;
let tls_stream = self
.tls_connector
.connect("127.0.0.1".to_string().try_into().unwrap(), socket)
.await
.map_err(into_cmd_err!(CmdErrorCode::IoError, "tls handshake failed"))?;
let tls_key = tls_stream
.get_ref()
.1
.peer_certificates()
.unwrap()
.get(0)
.unwrap()
.to_vec();
let mut sha256 = sha2::Sha256::new();
sha256.update(tls_key.as_slice());
let peer_id = PeerId::from(sha256.finalize().as_slice().to_vec());
let (r, w) = split(tls_stream);
Ok(CmdTunnel::new(
TlsStreamRead::new(self.local_id.clone(), peer_id.clone(), r),
TlsStreamWrite::new(self.local_id.clone(), peer_id, w),
))
}
}
/// Example client for the TLS command server.
///
/// The run proceeds as follows:
/// 1. Register handlers for commands 0x02 and 0x06.
/// 2. Send command 0x01.
/// 3. Send command 0x03 and print the server's reply.
/// 4. Wait 10 seconds before exiting.
#[tokio::main]
async fn main() {
    let factory = TlsConnectionFactory::new();
    let client =
        DefaultCmdClient::<(), TlsStreamRead, TlsStreamWrite, _, u16, u8>::new(factory, 5);

    // 0x02: log the command code and send no response body.
    client.register_cmd_handler(
        0x02,
        move |_local, _peer, _tunnel, header: CmdHeader<u16, u8>, _body| async move {
            println!("recv cmd {}", header.cmd_code());
            Ok(None)
        },
    );

    // 0x06: log the command code and answer with a short text body.
    client.register_cmd_handler(
        0x06,
        move |_local, _peer, _tunnel, header: CmdHeader<u16, u8>, _body| async move {
            println!("recv cmd {}", header.cmd_code());
            let reply = CmdBody::from_string("client resp 6".to_string());
            Ok(Some(reply))
        },
    );

    client.send(0x01, 0, "client".as_bytes()).await.unwrap();

    let timeout = Duration::from_secs(1000);
    let resp = client
        .send_with_resp(0x03, 0, "client send 0x03".as_bytes(), timeout)
        .await
        .unwrap();
    let text = resp.into_string().await.unwrap();
    println!("recv server resp. cmd {} data {}", 0x03, text);

    // Presumably keeps the process alive so server-initiated commands can still arrive.
    tokio::time::sleep(Duration::from_secs(10)).await;
}