pub use tokio::fs::File;
use std::{
    error,
    error::Error,
    io::{self},
    mem::ManuallyDrop,
    net::{IpAddr, Ipv4Addr, SocketAddr},
    sync::Arc,
    time::Duration,
};
use tokio::{
    net::{lookup_host, ToSocketAddrs, UdpSocket},
    time,
};
// Gated on `unix` / `windows` (the same gates `with_std_file` uses) so every Unix target,
// not only Linux, gets the positional-I/O trait and the raw-fd conversions.
#[cfg(unix)]
use std::os::unix::{
    fs::FileExt,
    io::{AsRawFd, FromRawFd},
};
#[cfg(windows)]
use std::os::windows::{
    fs::FileExt,
    io::{AsRawHandle, FromRawHandle},
};
pub mod reciever;
pub mod sender;
pub use reciever::recv_file;
pub use sender::send_file;
/// Apparently a sketch of the transfer protocol's message kinds.
///
/// NOTE(review): nothing in the code shown here constructs or matches on this enum (the
/// leading underscore keeps the dead-code lint quiet), and no wire encoding is defined here.
/// The per-variant notes below are inferred from the names only; confirm against the
/// `sender` / `reciever` modules.
enum _Message {
    // presumably: the sender has finished transmitting
    DoneSending,
    // presumably: message numbers the receiver has not yet received
    MissedMessages(Vec<u64>),
    // presumably: the receiver reports the whole file has arrived
    FileRecieved,
    // presumably: the file length in bytes
    FileSize(u64),
    // presumably: a request for the peer to send `FileSize`
    SendFileSize,
    // presumably: a hole-punching probe (cf. `punch_hole`, which sends a lone 255 byte)
    HolePunch,
}
/// Interprets an 8-byte slice as a big-endian `u64`.
///
/// # Errors
///
/// Any slice whose length is not exactly 8 is rejected with `io::ErrorKind::InvalidInput`.
fn u8s_to_u64(nums: &[u8]) -> io::Result<u64> {
    // The slice pattern only matches a slice of exactly eight elements, which doubles as
    // the length check.
    match *nums {
        [b0, b1, b2, b3, b4, b5, b6, b7] => {
            Ok(u64::from_be_bytes([b0, b1, b2, b3, b4, b5, b6, b7]))
        }
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "nums must be 8 bytes long",
        )),
    }
}
async fn send_unil_recv<T: ToSocketAddrs>(
sock: &UdpSocket,
msg: &[u8],
addr: &T,
buf: &mut [u8],
interval: u64,
) -> Result<usize, Box<dyn error::Error + Send + Sync>> {
let mut send_interval = time::interval(Duration::from_millis(interval));
let amt = loop {
tokio::select! {
_ = send_interval.tick() => {
sock.send_to(msg, addr).await?;
}
result = sock.recv_from(buf) => {
let (amt, src) = result?;
if &src != &lookup_host(addr).await?.next().unwrap() {
continue;
}
break amt;
}
}
};
Ok(amt)
}
async fn punch_hole<T: ToSocketAddrs>(
sock: &UdpSocket,
addr: T,
) -> Result<(), Box<dyn Error + Send + Sync>> {
sock.send_to(&[255u8], addr).await?;
Ok(())
}
/// Builds a packet: the 8 big-endian bytes of `msg_num` followed by `file_buf`.
fn get_buf(msg_num: &u64, file_buf: &[u8]) -> Vec<u8> {
    let header = msg_num.to_be_bytes();
    let mut packet = Vec::with_capacity(header.len() + file_buf.len());
    packet.extend_from_slice(&header);
    packet.extend_from_slice(file_buf);
    packet
}
/// Reads into `buf` from byte `offset` of `file` using a positional read (`read_at` on Unix,
/// `seek_read` on Windows). The blocking call runs on tokio's blocking pool via
/// `with_std_file`.
///
/// `buf` is moved into the blocking task, so it is handed back together with the number of
/// bytes read. As with the underlying calls, that count may be smaller than `buf`'s length.
///
/// # Errors
///
/// Propagates the I/O error from the positional read.
async fn read_position<Buf>(
    file: &File,
    mut buf: Buf,
    offset: u64,
) -> Result<(Buf, usize), Box<dyn error::Error + Send + Sync>>
where
    Buf: AsMut<[u8]> + Send + 'static,
{
    with_std_file(file, move |file| {
        // Gated on `unix`/`windows` (not `target_os = "linux"`) to match `with_std_file`, so
        // `amt` is defined on every Unix target, not just Linux.
        #[cfg(unix)]
        let amt = file.read_at(buf.as_mut(), offset)?;
        #[cfg(windows)]
        let amt = file.seek_read(buf.as_mut(), offset)?;
        Ok((buf, amt))
    })
    .await
}
async fn write_position<Buf>(
file: &File,
buf: Buf,
offset: u64,
) -> Result<usize, Box<dyn error::Error + Send + Sync>>
where
Buf: AsRef<[u8]> + Send + 'static,
{
with_std_file(file, move |file| {
#[cfg(target_os = "linux")]
let amt = file.write_at(buf.as_ref(), offset)?;
#[cfg(target_os = "windows")]
let amt = file.seek_write(buf.as_ref(), offset)?;
Ok(amt)
})
.await
}
/// Runs `f` on tokio's blocking thread pool, passing a `std::fs::File` that aliases the OS
/// handle of `file`, and returns `f`'s result.
///
/// The aliasing `std::fs::File` is wrapped in `ManuallyDrop`, so it never closes the
/// handle. The tokio `File` remains its owner.
///
/// # Panics
///
/// Panics (via `unwrap` on the `JoinError`) if the blocking task does not complete, for
/// example because `f` panicked.
///
/// NOTE(review): if this future is dropped before the blocking task finishes (e.g. it loses
/// a `select!` race), the task keeps using the raw handle after the caller's borrow of
/// `file` has ended. The handle could then be closed, and its number reused, underneath it.
/// Presumably this also bypasses any operation tokio's `File` still has in flight. Consider
/// handing the task an owned duplicate (e.g. via `File::try_clone`) instead. TODO: confirm
/// that callers never cancel this future or mix it with async `File` I/O.
async fn with_std_file<F, O>(file: &File, f: F) -> O
where
    F: FnOnce(&std::fs::File) -> O + Send + 'static,
    O: Send + 'static,
{
    // SAFETY: the raw fd/handle comes from a live `File` borrowed for this call, and the
    // `ManuallyDrop` wrapper below guarantees the alias never closes it. (See the NOTE
    // above for the cancellation caveat.)
    #[cfg(unix)]
    let file = unsafe { std::fs::File::from_raw_fd(file.as_raw_fd()) };
    #[cfg(windows)]
    let file = unsafe { std::fs::File::from_raw_handle(file.as_raw_handle()) };
    // Suppress `Drop` on the alias; dropping it would close the handle owned by `file`.
    let file = ManuallyDrop::new(file);
    // Positional std I/O blocks, so it runs off the async worker threads.
    tokio::task::spawn_blocking(move || f(&*file))
        .await
        .unwrap()
}
async fn recv<T>(
sock: &UdpSocket,
from: &T,
buf: &mut [u8],
) -> Result<usize, Box<dyn Error + Send + Sync>>
where
T: ToSocketAddrs,
{
loop {
let (amt, src) = sock.recv_from(buf).await?;
if &src == &lookup_host(from).await?.next().unwrap() {
return Ok(amt);
}
}
}
/// Where the UDP socket to use comes from.
///
/// `Source::into_socket` turns any variant into a shared `Arc<UdpSocket>`.
#[derive(Debug)]
pub enum Source {
    /// An already-shared socket, used as-is.
    SocketArc(Arc<UdpSocket>),
    /// An owned socket, wrapped in a new `Arc`.
    Socket(UdpSocket),
    /// A local port: a new socket is bound on `0.0.0.0:<port>` (all IPv4 interfaces).
    Port(u16),
}
impl Source {
    /// Turns this source into a shared socket.
    ///
    /// `SocketArc` is returned unchanged, `Socket` is wrapped in a new `Arc`, and `Port`
    /// binds a fresh socket on `0.0.0.0:<port>`.
    ///
    /// # Panics
    ///
    /// Panics if binding the socket for `Source::Port` fails (for example, when the port is
    /// already in use). The panic message names the address and the underlying I/O error.
    async fn into_socket(self) -> Arc<UdpSocket> {
        match self {
            Source::SocketArc(s) => s,
            Source::Socket(s) => Arc::new(s),
            Source::Port(port) => {
                let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
                // The signature has no `Result`, so a bind failure can only panic; at least
                // make the panic say what failed and why.
                let sock = UdpSocket::bind(local).await.unwrap_or_else(|e| {
                    panic!("failed to bind UDP socket on {}: {}", local, e)
                });
                Arc::new(sock)
            }
        }
    }
}