use async_trait::async_trait;
#[cfg(feature = "logging")]
use log::debug;
#[cfg(feature = "logging")]
use log::info;
use std::{
error,
io::{self},
time::Duration,
};
use tokio::{
fs::{remove_file, File, OpenOptions},
net::ToSocketAddrs,
time,
};
use crate::{read_position, recv, send_unil_recv, u8s_to_u64, write_position, Source};
/// Tracks which messages of a transfer have been received.
#[async_trait]
trait ProgressTracker {
    /// Records message `msg_num` as received.
    async fn recv_msg(&mut self, msg_num: u64) -> Result<(), Box<dyn error::Error + Send + Sync>>;
    /// Returns message numbers not yet recorded as received.
    /// NOTE(review): both implementations in this module return them in ascending order and stop
    /// after 63 entries, the most `gen_dropped_msg` accepts — keep that cap in new implementations.
    async fn get_unrecv(&self) -> Result<Vec<u64>, Box<dyn error::Error + Send + Sync>>;
    /// Releases whatever backs the tracker (the file-backed tracker deletes its file).
    /// NOTE(review): not called anywhere in this module.
    async fn destruct(&self);
}
/// Progress tracker backed by a file on disk holding one bit per message.
struct FileProgTrack {
    /// Path of the progress file; `destruct` removes it.
    filename: String,
    /// Handle to the progress file, opened for reading and writing.
    file: File,
    /// Size in bytes of the file being received; the message count is derived from it.
    size: u64,
}
impl FileProgTrack {
    /// Opens (creating it if missing) the progress file at `filename` for a transfer of `size`
    /// bytes, and sizes it to hold one bit per message.
    ///
    /// The file is not truncated on open, so existing contents within the new length are kept.
    ///
    /// # Errors
    /// Returns any I/O error from opening or resizing the file.
    async fn new(filename: String, size: u64) -> Result<Self, Box<dyn error::Error + Send + Sync>> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(&filename)
            .await?;
        let msgs = get_msg_amt(size);
        // Message `n` is bit `n % 8` of byte `n / 8`, so ceil(msgs / 8) bytes are enough.
        let bitmap_len = msgs / 8 + u64::from(msgs % 8 != 0);
        file.set_len(bitmap_len).await?;
        Ok(Self {
            filename,
            file,
            size,
        })
    }
}
#[async_trait]
impl ProgressTracker for FileProgTrack {
async fn recv_msg(&mut self, msg_num: u64) -> Result<(), Box<dyn error::Error + Send + Sync>> {
let (offset, pos_in_offset) = get_pos_of_num(msg_num);
let (offset_buf, _) = read_position(&self.file, [0u8; 1], offset).await?;
let mut offset_binary = to_binary(offset_buf[0]);
offset_binary[pos_in_offset as usize] = true;
let offset_buf = from_binary(offset_binary);
write_position(&self.file, [offset_buf], offset).await?;
Ok(())
}
async fn get_unrecv(&self) -> Result<Vec<u64>, Box<dyn error::Error + Send + Sync>> {
let mut dropped: Vec<u64> = Vec::new();
let total = if self.size % 500 == 0 {
self.size / 500
} else {
self.size / 500 + 1
};
for byte in 0..self.file.metadata().await?.len() {
let ([bin], _) = read_position(&self.file, [0u8], byte).await?;
let bin = to_binary(bin);
let mut bit_pos = 0;
for bit in bin {
let num = get_num_of_pos(byte, bit_pos);
if num == total {
return Ok(dropped);
}
if !bit {
dropped.push(num);
if dropped.len() == 63 {
return Ok(dropped);
}
}
bit_pos += 1;
}
}
Ok(dropped)
}
async fn destruct(&self) {
remove_file(&self.filename).await.unwrap()
}
}
/// Number of 500-byte messages needed to carry `file_len` bytes; a trailing partial chunk
/// counts as one more message.
fn get_msg_amt(file_len: u64) -> u64 {
    let full_msgs = file_len / 500;
    let has_partial = file_len % 500 != 0;
    full_msgs + u64::from(has_partial)
}
struct MemProgTracker {
tracker: Vec<u8>,
}
impl MemProgTracker {
fn new(size: u64) -> Self {
let tracker_size = get_msg_amt(size) as usize;
let tracker = vec![0u8; tracker_size];
Self { tracker }
}
}
#[async_trait]
impl ProgressTracker for MemProgTracker {
async fn recv_msg(&mut self, msg_num: u64) -> Result<(), Box<dyn error::Error + Send + Sync>> {
let (offset, pos_in_offset) = get_pos_of_num(msg_num);
let mut offset_binary = to_binary(self.tracker[offset as usize]);
offset_binary[pos_in_offset as usize] = true;
let offset_buf = from_binary(offset_binary);
self.tracker[offset as usize] = offset_buf;
Ok(())
}
async fn get_unrecv(&self) -> Result<Vec<u64>, Box<dyn error::Error + Send + Sync>> {
let mut dropped: Vec<u64> = Vec::new();
let total = get_msg_amt(self.tracker.len() as u64);
let mut i = 0;
for byte in &self.tracker {
let bin = to_binary(*byte);
let mut bit_pos = 0;
for bit in bin {
let num = get_num_of_pos(i, bit_pos);
if num == total {
return Ok(dropped);
}
if !bit {
dropped.push(num);
if dropped.len() == 63 {
return Ok(dropped);
}
}
bit_pos += 1;
}
i += 1;
}
Ok(dropped)
}
async fn destruct(&self) {}
}
/// Where `recv_file` keeps its record of which messages have been received.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProgressTracking {
    /// Keep the progress bitmap in the file at this path (created if missing, not truncated).
    File(String),
    /// Keep the progress bitmap in memory only.
    Memory,
}
/// Byte offset in the output file where message `msg_num`'s payload begins
/// (every message carries 500 payload bytes).
fn get_offset(msg_num: u64) -> u64 {
    const MSG_PAYLOAD_LEN: u64 = 500;
    MSG_PAYLOAD_LEN * msg_num
}
/// Locates message `num` in the progress bitmap: returns `(byte index, bit index in that byte)`.
fn get_pos_of_num(num: u64) -> (u64, u8) {
    let bit = (num % 8) as u8;
    (num / 8, bit)
}
/// Inverse of `get_pos_of_num`: the message number stored at bit `pos` of bitmap byte `byte`.
fn get_num_of_pos(byte: u64, pos: u8) -> u64 {
    let bits_before = 8 * byte;
    bits_before + u64::from(pos)
}
/// Expands `num` into its 8 bits, most significant first: `[0]` is the 128 bit, `[7]` the 1 bit.
fn to_binary(num: u8) -> [bool; 8] {
    let mut bits = [false; 8];
    for (i, bit) in bits.iter_mut().enumerate() {
        *bit = ((num >> (7 - i)) & 1) == 1;
    }
    bits
}
/// Packs 8 bits, most significant first, back into a byte (the inverse of `to_binary`).
fn from_binary(bin: [bool; 8]) -> u8 {
    bin.iter().fold(0u8, |acc, &bit| (acc << 1) | u8::from(bit))
}
async fn write_msg(
buf: &[u8],
out_file: &File,
prog_tracker: &mut Box<dyn ProgressTracker>,
) -> Result<u64, Box<dyn error::Error + Send + Sync>> {
let msg_num = u8s_to_u64(&buf[0..8])?;
let msg_offset = get_offset(msg_num);
let rest = buf[8..].to_owned();
write_position(out_file, rest, msg_offset).await.unwrap();
prog_tracker.recv_msg(msg_num).await?;
Ok(msg_num)
}
pub async fn recv_file<T>(
source: Source,
file: &mut File,
sender: T,
progress_tracking: ProgressTracking,
) -> Result<(), Box<dyn error::Error + Send + Sync>>
where
T: 'static + Clone + ToSocketAddrs + std::marker::Send + Copy, {
let sock = source.into_socket().await;
let sock_ = sock.clone();
let sender_ = sender.clone();
let holepuncher = tokio::task::spawn(async move {
let sock = sock_;
let sender = sender_;
let mut holepunch_interval = time::interval(Duration::from_secs(5));
loop {
sock.send_to(&[255u8], sender).await.unwrap();
holepunch_interval.tick().await;
}
});
#[cfg(feature = "logging")]
debug!("getting file size");
let buf: [u8; 508];
let amt = loop {
let mut new_buf = [0u8; 508];
let amt = send_unil_recv(&*sock, &[9], &sender, &mut new_buf, 500).await?;
#[cfg(feature = "logging")]
debug!("got size msg: {:?}", &new_buf[0..amt]);
if amt == 9 && new_buf[0] == 8 {
buf = new_buf;
break amt;
}
};
let buf = &buf[0..amt];
let size_be_bytes = &buf[1..];
let size = u8s_to_u64(size_be_bytes)?;
#[cfg(feature = "logging")]
debug!("size: {}", size);
let mut prog_tracker: Box<dyn ProgressTracker> = match progress_tracking {
ProgressTracking::File(filename) => {
Box::new(FileProgTrack::new(filename, size).await.unwrap())
}
ProgressTracking::Memory => Box::new(MemProgTracker::new(size)),
};
let mut first = true;
'pass: loop {
let mut first_data: Option<([u8; 508], usize)> = None;
if !first {
let dropped = prog_tracker.get_unrecv().await?;
if dropped.len() == 0 {
#[cfg(feature = "logging")]
debug!("everything recieved correctly");
loop {
let sleep = time::sleep(Duration::from_millis(1500));
let mut buf = [0u8; 508];
tokio::select! {
_ = sleep => {
break;
}
amt = recv(&sock, &sender, &mut buf) => {
let amt = amt?;
let buf = &buf[0..amt];
if buf[0] == 5 {
sock.send_to(&[7], sender).await?;
}
}
}
}
break;
}
#[cfg(feature = "logging")]
debug!("everything was not recieved correctly");
let dropped_msg = gen_dropped_msg(dropped)?;
loop {
let mut buf = [0u8; 508];
let amt = send_unil_recv(&sock, &dropped_msg, &sender, &mut buf, 100).await?;
let msg_buf = &buf[0..amt];
if msg_buf.len() > 1 && msg_buf[0] != 5 {
first_data = Some((buf, amt));
break;
}
}
}
loop {
let wait_time = time::sleep(Duration::from_millis(2000));
let mut buf = [0; 508];
let amt = if let Some((new_buf, amt)) = first_data {
buf = new_buf;
first_data = None;
amt
} else {
let amt = tokio::select! {
_ = wait_time => {
break;
}
amt = recv(&sock, &sender, &mut buf) => {
let amt = amt?;
amt
}
};
amt
};
let buf = &buf[0..amt];
if buf.len() == 1 && buf[0] == 255 {
continue;
} else if buf.len() == 1 && buf[0] == 5 {
continue 'pass;
}
if first && buf[0] == 8 {
continue;
}
#[cfg(feature = "logging")]
let msg_num = write_msg(buf, file, &mut prog_tracker).await?;
#[cfg(not(feature = "logging"))]
write_msg(buf, file, &mut prog_tracker).await?;
#[cfg(feature = "logging")]
info!("msg {} / {}, {}%", msg_num, size / 500, msg_num * 100 / (size / 500));
first = false;
}
}
holepuncher.abort();
Ok(())
}
/// Builds a dropped-messages request: a `6` tag byte followed by each message number as
/// 8 big-endian bytes.
///
/// NOTE(review): the 63 cap presumably keeps the request within the 508-byte buffers used in
/// this module (1 + 63 * 8 = 505 bytes) — confirm against the sender.
///
/// # Errors
/// Returns an `InvalidInput` I/O error if more than 63 message numbers are given.
fn gen_dropped_msg(dropped: Vec<u64>) -> Result<Vec<u8>, Box<dyn error::Error + Send + Sync>> {
    const MAX_DROPPED: usize = 63;
    if dropped.len() > MAX_DROPPED {
        return Err(Box::new(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "maximum amount of dropped messages is 63, got {}",
                dropped.len()
            ),
        )));
    }
    let mut msg: Vec<u8> = Vec::with_capacity(1 + dropped.len() * 8);
    msg.push(6);
    for num in dropped {
        msg.extend_from_slice(&num.to_be_bytes());
    }
    Ok(msg)
}