use std::sync::Arc;
use bytes::{BufMut, BytesMut};
use native_tls::TlsConnector;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex};
use tokio_native_tls::TlsStream;
use crate::client_helper::read_loop;
use crate::error::Error;
/// Write-side handle to a TLS connection that carries length-prefixed frames.
///
/// Only the write half of the split stream is stored here; the read half is
/// handed to the background task started by `Transport::connect`.
pub struct Transport {
    /// Write half of the TLS stream, behind an async mutex so `&self` methods
    /// can write while holding exclusive access across `.await` points.
    writer: Arc<Mutex<TlsWriteHalf>>,
}

/// Write half of a TLS-over-TCP stream, as produced by `tokio::io::split`.
type TlsWriteHalf = tokio::io::WriteHalf<TlsStream<TcpStream>>;
impl Transport {
pub async fn connect(
host: &str,
port: u16,
frame_tx: mpsc::UnboundedSender<Vec<u8>>,
) -> Result<Self, Error> {
let addr = format!("{host}:{port}");
tracing::debug!("connecting to {addr}");
let tcp = TcpStream::connect(&addr).await?;
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(false)
.build()
.map_err(Error::Tls)?;
let connector = tokio_native_tls::TlsConnector::from(connector);
let tls = connector
.connect(host, tcp)
.await
.map_err(|e| Error::Tls(e.into()))?;
let (reader, writer) = tokio::io::split(tls);
let writer = Arc::new(Mutex::new(writer));
tokio::spawn(async move {
if let Err(e) = read_loop(reader, frame_tx).await {
tracing::error!("transport read loop error: {e}");
}
});
Ok(Self { writer })
}
pub async fn send(&self, payload: &[u8]) -> Result<(), Error> {
let mut buf = BytesMut::with_capacity(4 + payload.len());
buf.put_u32(payload.len() as u32);
buf.put_slice(payload);
let mut w = self.writer.lock().await;
w.write_all(&buf).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::Transport;
    use tokio::net::TcpListener;
    use tokio::sync::mpsc;

    /// Binds an ephemeral loopback listener and returns it with its port.
    async fn loopback_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let port = listener.local_addr().expect("listener address").port();
        (listener, port)
    }

    /// With nothing listening on the port, the TCP connect step must fail.
    #[tokio::test]
    async fn connect_fails_when_nothing_is_listening() {
        let (listener, port) = loopback_listener().await;
        drop(listener);
        let (tx, _rx) = mpsc::unbounded_channel();
        let result = Transport::connect("127.0.0.1", port, tx).await;
        assert!(result.is_err(), "connect to a closed port should fail");
    }

    /// A peer that accepts TCP and immediately hangs up cannot complete a TLS
    /// handshake, so `connect` must return an error rather than a Transport.
    #[tokio::test]
    async fn connect_fails_when_peer_is_not_tls() {
        let (listener, port) = loopback_listener().await;
        tokio::spawn(async move {
            if let Ok((socket, _)) = listener.accept().await {
                drop(socket);
            }
        });
        let (tx, _rx) = mpsc::unbounded_channel();
        let result = Transport::connect("127.0.0.1", port, tx).await;
        assert!(result.is_err(), "TLS handshake with a non-TLS peer should fail");
    }
}