use rcgen::generate_simple_self_signed;
use rustls::client::danger::HandshakeSignatureValid;
use rustls::crypto::ring;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime};
use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
use rustls::version::TLS13;
use rustls::{DigitallySignedStruct, DistinguishedName, Error, ServerConfig, SignatureScheme};
use sfo_cmd_server::errors::{CmdErrorCode, CmdResult, into_cmd_err};
use sfo_cmd_server::server::{CmdServer, CmdTunnelListener, DefaultCmdServer};
use sfo_cmd_server::{CmdBody, CmdHeader, CmdTunnel, CmdTunnelRead, CmdTunnelWrite, PeerId};
use sha2::Digest;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, split};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
/// Read half produced by splitting an accepted server-side TLS stream over TCP.
type ServerTlsReadHalf = tokio::io::ReadHalf<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>;

/// Read side of a TLS tunnel, tagged with the peer ids of both endpoints.
struct TlsStreamRead {
    /// Id reported for this end of the tunnel.
    local_id: PeerId,
    /// Id reported for the remote end of the tunnel.
    remote_id: PeerId,
    /// Underlying read half; `new` always stores `Some`.
    read: Option<ServerTlsReadHalf>,
}

impl TlsStreamRead {
    /// Wraps `read` together with the ids of the two tunnel endpoints.
    pub fn new(local_id: PeerId, remote_id: PeerId, read: ServerTlsReadHalf) -> Self {
        let read = Some(read);
        Self {
            local_id,
            remote_id,
            read,
        }
    }
}
impl AsyncRead for TlsStreamRead {
    /// Forwards the read to the wrapped half.
    ///
    /// If no half is present the call completes immediately without filling
    /// `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match self.get_mut().read.as_mut() {
            Some(half) => Pin::new(half).poll_read(cx, buf),
            None => Poll::Ready(Ok(())),
        }
    }
}
impl CmdTunnelRead<()> for TlsStreamRead {
fn get_local_peer_id(&self) -> PeerId {
self.local_id.clone()
}
fn get_remote_peer_id(&self) -> PeerId {
self.remote_id.clone()
}
}
/// Write half produced by splitting an accepted server-side TLS stream over TCP.
type ServerTlsWriteHalf =
    tokio::io::WriteHalf<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>;

/// Write side of a TLS tunnel, tagged with the peer ids of both endpoints.
struct TlsStreamWrite {
    /// Id reported for this end of the tunnel.
    local_id: PeerId,
    /// Id reported for the remote end of the tunnel.
    remote_id: PeerId,
    /// Underlying write half; `new` always stores `Some`.
    write: Option<ServerTlsWriteHalf>,
}

impl TlsStreamWrite {
    /// Wraps `write` together with the ids of the two tunnel endpoints.
    pub fn new(local_id: PeerId, remote_id: PeerId, write: ServerTlsWriteHalf) -> Self {
        let write = Some(write);
        Self {
            local_id,
            remote_id,
            write,
        }
    }
}
impl AsyncWrite for TlsStreamWrite {
    /// Forwards the write to the wrapped half; reports zero bytes written if
    /// no half is present.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut().write.as_mut() {
            Some(half) => Pin::new(half).poll_write(cx, buf),
            None => Poll::Ready(Ok(0)),
        }
    }

    /// Forwards the flush to the wrapped half; succeeds immediately if no half
    /// is present.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut().write.as_mut() {
            Some(half) => Pin::new(half).poll_flush(cx),
            None => Poll::Ready(Ok(())),
        }
    }

    /// Forwards the shutdown to the wrapped half; succeeds immediately if no
    /// half is present.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut().write.as_mut() {
            Some(half) => Pin::new(half).poll_shutdown(cx),
            None => Poll::Ready(Ok(())),
        }
    }
}
impl CmdTunnelWrite<()> for TlsStreamWrite {
fn get_local_peer_id(&self) -> PeerId {
self.local_id.clone()
}
fn get_remote_peer_id(&self) -> PeerId {
self.remote_id.clone()
}
}
/// Client-certificate verifier installed on the server's TLS configuration.
pub struct TlsClientCertVerifier {
    /// Distinguished names handed back from `root_hint_subjects`.
    pub subjects: Vec<DistinguishedName>,
}
impl Debug for TlsClientCertVerifier {
    /// Prints only the type name; the `subjects` field is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TlsClientCertVerifier")
    }
}
impl ClientCertVerifier for TlsClientCertVerifier {
fn root_hint_subjects(&self) -> &[DistinguishedName] {
self.subjects.as_slice()
}
fn verify_client_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_now: UnixTime,
) -> Result<ClientCertVerified, Error> {
Ok(ClientCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![SignatureScheme::ECDSA_NISTP256_SHA256]
}
}
/// TCP listener that upgrades every accepted connection to TLS.
struct TunnelListener {
    /// Id of this server: the SHA-256 digest of its own certificate (set in `bind`).
    local_id: PeerId,
    /// Acceptor that performs the server-side TLS handshake.
    tls_acceptor: TlsAcceptor,
    /// Socket on which incoming TCP connections are accepted.
    tcp_listener: TcpListener,
}
/// Creates a fresh self-signed certificate for `127.0.0.1` and its private key.
///
/// Returns a one-element certificate chain together with the matching key.
///
/// # Panics
/// Panics if certificate generation fails or the generated PEM cannot be
/// parsed back.
fn generate_cert() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
    let cert_key = generate_simple_self_signed(vec!["127.0.0.1".to_string()]).unwrap();

    let cert_pem = cert_key.cert.pem();
    let certificate = CertificateDer::from_pem_slice(cert_pem.as_bytes()).unwrap();

    let key_pem = cert_key.key_pair.serialize_pem();
    let private_key = PrivateKeyDer::from_pem_slice(key_pem.as_bytes()).unwrap();

    (vec![certificate], private_key)
}
impl TunnelListener {
    /// Binds a TCP listener on `addr` and prepares a TLS 1.3 acceptor backed
    /// by a freshly generated self-signed certificate.
    ///
    /// The listener's local peer id is the SHA-256 digest of that certificate.
    ///
    /// # Errors
    /// * `CmdErrorCode::IoError` if the address cannot be bound.
    /// * `CmdErrorCode::TlsError` if the TLS configuration cannot be built.
    pub async fn bind(addr: &str) -> CmdResult<Self> {
        let listener = TcpListener::bind(addr)
            .await
            .map_err(into_cmd_err!(CmdErrorCode::IoError, "bind failed"))?;

        let (certs, key) = generate_cert();

        // `generate_cert` always returns exactly one certificate, so the
        // index cannot go out of bounds.
        let mut sha256 = sha2::Sha256::new();
        sha256.update(certs[0].as_ref());
        let local_id = PeerId::from(sha256.finalize().to_vec());

        let config = ServerConfig::builder_with_provider(ring::default_provider().into())
            .with_protocol_versions(&[&TLS13])
            // Report a provider/version mismatch to the caller instead of panicking.
            .map_err(into_cmd_err!(
                CmdErrorCode::TlsError,
                "select tls version failed"
            ))?
            .with_client_cert_verifier(Arc::new(TlsClientCertVerifier { subjects: vec![] }))
            .with_single_cert(certs, key)
            .map_err(into_cmd_err!(
                CmdErrorCode::TlsError,
                "create tls config failed"
            ))?;

        Ok(Self {
            local_id,
            tls_acceptor: TlsAcceptor::from(Arc::new(config)),
            tcp_listener: listener,
        })
    }
}
#[async_trait::async_trait]
impl CmdTunnelListener<(), TlsStreamRead, TlsStreamWrite> for TunnelListener {
    /// Accepts one TCP connection, completes the TLS handshake on it and
    /// returns the resulting tunnel.
    ///
    /// The remote peer id is the SHA-256 digest of the first certificate the
    /// client presented.
    ///
    /// # Errors
    /// * `CmdErrorCode::IoError` if the TCP accept fails.
    /// * `CmdErrorCode::TlsError` if the handshake fails or the client
    ///   presented no certificate.
    ///
    /// NOTE(review): the handshake is awaited inline, so this call does not
    /// return until the connecting client finishes (or fails) it.
    async fn accept(
        &self,
    ) -> sfo_cmd_server::errors::CmdResult<CmdTunnel<TlsStreamRead, TlsStreamWrite>> {
        let (stream, _) = self
            .tcp_listener
            .accept()
            .await
            .map_err(into_cmd_err!(CmdErrorCode::IoError, "accept failed"))?;
        let tls_stream = self
            .tls_acceptor
            .accept(stream)
            .await
            .map_err(into_cmd_err!(CmdErrorCode::TlsError, "tls accept failed"))?;

        let peer_id = {
            let (_, session) = tls_stream.get_ref();
            // The verifier's defaults should make a client certificate
            // mandatory, so a missing one is unexpected; still, turn it into
            // an error rather than panicking on a network-facing path.
            let leaf = session
                .peer_certificates()
                .and_then(|chain| chain.first())
                .ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "peer presented no certificate",
                    )
                })
                .map_err(into_cmd_err!(
                    CmdErrorCode::TlsError,
                    "peer certificate missing"
                ))?;
            let mut sha256 = sha2::Sha256::new();
            sha256.update(leaf.as_ref());
            PeerId::from(sha256.finalize().to_vec())
        };

        let (r, w) = split(tls_stream);
        Ok(CmdTunnel::new(
            TlsStreamRead::new(self.local_id.clone(), peer_id.clone(), r),
            TlsStreamWrite::new(self.local_id.clone(), peer_id, w),
        ))
    }
}
/// Starts the TLS command server on `127.0.0.1:4453` and runs until Ctrl-C.
///
/// Two handlers are registered:
/// * command `0x01`: sends command `0x02` with an empty body back to the
///   peer, then, in a background task, sends command `0x06` with the body
///   `"server"` and waits up to 10 seconds for the reply.
/// * command `0x03`: prints the received body and answers with
///   `"server resp 0x03"`.
///
/// # Panics
/// Panics if the listener cannot be bound or the Ctrl-C handler cannot be
/// installed.
#[tokio::main]
async fn main() {
    let listener = TunnelListener::bind("127.0.0.1:4453").await.unwrap();
    let server = DefaultCmdServer::<(), TlsStreamRead, TlsStreamWrite, u16, u8, _>::new(listener);
    // Handle captured by the 0x01 handler so it can send commands to peers.
    let sender = server.clone();
    server.register_cmd_handler(
        0x01,
        move |_local_id, peer_id, _tunnel_id, _header: CmdHeader<u16, u8>, _body_read| {
            let sender = sender.clone();
            async move {
                sender.send(&peer_id, 0x02, 0, vec![].as_slice()).await?;
                // The request/response round trip runs in its own task so the
                // handler returns without waiting for the client's reply.
                tokio::spawn(async move {
                    match sender
                        .send_with_resp(
                            &peer_id,
                            0x06,
                            0,
                            "server".as_bytes(),
                            Duration::from_secs(10),
                        )
                        .await
                    {
                        // The reply comes from a remote peer: log a body that
                        // cannot be decoded instead of panicking the task.
                        Ok(resp) => match resp.into_string().await {
                            Ok(text) => {
                                println!("recv client resp. cmd {} data {}", 0x06, text);
                            }
                            Err(e) => {
                                println!("read resp err {:?}", e);
                            }
                        },
                        Err(e) => {
                            println!("send err {}", e.msg());
                        }
                    }
                });
                Ok(None)
            }
        },
    );
    server.register_cmd_handler(
        0x03,
        move |_local_id,
              _peer_id,
              _tunnel_id,
              header: CmdHeader<u16, u8>,
              mut _body_read: CmdBody| {
            async move {
                println!(
                    "recv cmd {} body {}",
                    header.cmd_code(),
                    _body_read.into_string().await?
                );
                Ok(Some(CmdBody::from_string("server resp 0x03".to_string())))
            }
        },
    );
    server.start();
    // Keep the process alive until the user interrupts it.
    tokio::signal::ctrl_c().await.unwrap();
}