use std::pin::Pin;
use std::task::{Context, Poll};
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use crate::Error;
use crate::transport::tls_config::make_rustls_client_config;
pub(crate) enum MaybeTls {
Plain(TcpStream),
Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
}
/// Reads are forwarded verbatim to whichever transport the connection currently wraps.
impl AsyncRead for MaybeTls {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `MaybeTls` is `Unpin` (both payloads are), so a plain `&mut` reborrow is fine
        // and each inner stream can be re-pinned on the spot.
        match &mut *self {
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
        }
    }
}
/// Writes, flushes and shutdowns are forwarded to whichever transport the connection wraps.
///
/// Vectored writes are forwarded as well. Without that, tokio's default `is_write_vectored`
/// returns `false` and the default `poll_write_vectored` writes only the first non-empty
/// buffer, hiding any vectored-write support of the inner stream.
impl AsyncWrite for MaybeTls {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_write(cx, buf),
            MaybeTls::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            MaybeTls::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            MaybeTls::Plain(s) => s.is_write_vectored(),
            MaybeTls::Tls(s) => s.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_flush(cx),
            MaybeTls::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_shutdown(cx),
            MaybeTls::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}
async fn connect(ip: &str, port: u16, secured: bool) -> Result<MaybeTls, Error> {
let tcp = TcpStream::connect((ip, port)).await?;
if secured {
let config = make_rustls_client_config();
let connector = TlsConnector::from(config);
let server_name = ServerName::try_from(ip.to_string()).unwrap_or_else(|_| {
ServerName::IpAddress(ip.parse::<std::net::IpAddr>().unwrap().into())
});
let stream = connector
.connect(server_name, tcp)
.await
.map_err(|e| Error::ConnectionFailed(format!("TLS connect error: {e}")))?;
Ok(MaybeTls::Tls(Box::new(stream)))
} else {
Ok(MaybeTls::Plain(tcp))
}
}
pub(crate) async fn upload_image(
ip: &str,
port: u16,
sec_key: &str,
file_type: &str,
image_data: &[u8],
secured: bool,
) -> Result<MaybeTls, Error> {
let mut stream = connect(ip, port, secured).await?;
let header = serde_json::json!({
"num": 0,
"total": 1,
"fileLength": image_data.len(),
"fileName": "framesmith_upload",
"fileType": file_type,
"secKey": sec_key,
"version": "0.0.1"
});
let header_bytes = header.to_string().into_bytes();
let header_len = (header_bytes.len() as u32).to_be_bytes();
stream.write_all(&header_len).await?;
stream.write_all(&header_bytes).await?;
const CHUNK_SIZE: usize = 64 * 1024;
for chunk in image_data.chunks(CHUNK_SIZE) {
stream.write_all(chunk).await?;
}
stream.flush().await?;
Ok(stream)
}
pub(crate) async fn download_thumbnail(
ip: &str,
port: u16,
secured: bool,
) -> Result<Vec<u8>, Error> {
let mut stream = connect(ip, port, secured).await?;
let mut header_len_buf = [0u8; 4];
stream.read_exact(&mut header_len_buf).await?;
let header_len = u32::from_be_bytes(header_len_buf) as usize;
let mut header_buf = vec![0u8; header_len];
stream.read_exact(&mut header_buf).await?;
let header: serde_json::Value = serde_json::from_slice(&header_buf)?;
let file_length = header["fileLength"]
.as_u64()
.ok_or_else(|| Error::ConnectionFailed("missing fileLength in thumbnail header".into()))?
as usize;
let mut data = vec![0u8; file_length];
stream.read_exact(&mut data).await?;
Ok(data)
}