use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use crate::buffer::{InputBuffer, OutputBuffer};
use crate::error::RnsError;
use crate::iface::RxMessage;
use crate::packet::Packet;
use crate::serde::Serialize;
use tokio::io::AsyncReadExt;
use alloc::string::String;
use super::hdlc::Hdlc;
use super::{Interface, InterfaceContext};
const PACKET_TRACE: bool = false;
pub struct TcpClient {
addr: String,
stream: Option<TcpStream>,
}
impl TcpClient {
pub fn new<T: Into<String>>(addr: T) -> Self {
Self {
addr: addr.into(),
stream: None,
}
}
pub fn new_from_stream<T: Into<String>>(addr: T, stream: TcpStream) -> Self {
Self {
addr: addr.into(),
stream: Some(stream),
}
}
pub async fn spawn(context: InterfaceContext<TcpClient>) {
let iface_stop = context.channel.stop.clone();
let addr = { context.inner.lock().unwrap().addr.clone() };
let iface_address = context.channel.address;
let mut stream = { context.inner.lock().unwrap().stream.take() };
let (rx_channel, tx_channel) = context.channel.split();
let tx_channel = Arc::new(tokio::sync::Mutex::new(tx_channel));
let mut running = true;
loop {
if !running || context.cancel.is_cancelled() {
break;
}
let stream = {
match stream.take() {
Some(stream) => {
running = false;
Ok(stream)
}
None => TcpStream::connect(addr.clone())
.await
.map_err(|_| RnsError::ConnectionError),
}
};
if stream.is_err() {
log::info!("tcp_client: couldn't connect to <{}>", addr);
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
let cancel = context.cancel.clone();
let stop = CancellationToken::new();
let stream = stream.unwrap();
let (read_stream, write_stream) = stream.into_split();
log::info!("tcp_client connected to <{}>", addr);
const BUFFER_SIZE: usize = core::mem::size_of::<Packet>() * 2;
let rx_task = {
let cancel = cancel.clone();
let stop = stop.clone();
let mut stream = read_stream;
let rx_channel = rx_channel.clone();
tokio::spawn(async move {
let mut hdlc_rx_buffer = [0u8; BUFFER_SIZE];
let mut rx_buffer = [0u8; BUFFER_SIZE + (BUFFER_SIZE / 2)];
let mut tcp_buffer = [0u8; (BUFFER_SIZE * 16)];
loop {
tokio::select! {
_ = cancel.cancelled() => {
break;
}
_ = stop.cancelled() => {
break;
}
result = stream.read(&mut tcp_buffer[..]) => {
match result {
Ok(0) => {
log::warn!("tcp_client: connection closed");
stop.cancel();
break;
}
Ok(n) => {
for &byte in tcp_buffer.iter().take(n) {
rx_buffer[BUFFER_SIZE - 1] = byte;
let frame = Hdlc::find(&rx_buffer[..]);
if let Some(frame) = frame {
let frame_buffer = &mut rx_buffer[frame.0..frame.1+1];
let mut output = OutputBuffer::new(&mut hdlc_rx_buffer[..]);
if Hdlc::decode(frame_buffer, &mut output).is_ok() {
if let Ok(packet) = Packet::deserialize(&mut InputBuffer::new(output.as_slice())) {
if PACKET_TRACE {
log::trace!("tcp_client: rx << ({}) {}", iface_address, packet);
}
let _ = rx_channel.send(RxMessage { address: iface_address, packet }).await;
} else {
log::warn!("tcp_client: couldn't decode packet");
}
} else {
log::warn!("tcp_client: couldn't decode hdlc frame");
}
frame_buffer.fill(0);
} else {
rx_buffer.copy_within(1.., 0);
}
}
}
Err(e) => {
log::warn!("tcp_client: connection error {}", e);
break;
}
}
},
};
}
})
};
let tx_task = {
let cancel = cancel.clone();
let tx_channel = tx_channel.clone();
let mut stream = write_stream;
tokio::spawn(async move {
loop {
if stop.is_cancelled() {
break;
}
let mut hdlc_tx_buffer = [0u8; BUFFER_SIZE];
let mut tx_buffer = [0u8; BUFFER_SIZE];
let mut tx_channel = tx_channel.lock().await;
tokio::select! {
_ = cancel.cancelled() => {
break;
}
_ = stop.cancelled() => {
break;
}
Some(message) = tx_channel.recv() => {
let packet = message.packet;
if PACKET_TRACE {
log::trace!("tcp_client: tx >> ({}) {}", iface_address, packet);
}
let mut output = OutputBuffer::new(&mut tx_buffer);
if packet.serialize(&mut output).is_ok() {
let mut hdlc_output = OutputBuffer::new(&mut hdlc_tx_buffer[..]);
if Hdlc::encode(output.as_slice(), &mut hdlc_output).is_ok() {
let _ = stream.write_all(hdlc_output.as_slice()).await;
let _ = stream.flush().await;
}
}
}
};
}
})
};
tx_task.await.unwrap();
rx_task.await.unwrap();
log::info!("tcp_client: disconnected from <{}>", addr);
}
iface_stop.cancel();
}
}
impl Interface for TcpClient {
fn mtu() -> usize {
2048
}
}