use std::{
borrow::Cow,
io::{self, Read, Write},
net::SocketAddr,
path::Path,
};
use anyhow::Context;
use crate::{
pb::ProgressBar,
tf::{receiver_receive_file, receiver_send_file},
};
/// Customization points for an `fs-share` v1.0 session driven by `run_v1_0`.
///
/// An implementor picks the connection types and supplies the hooks that are
/// invoked while a connection is accepted, authenticated, upgraded and used.
pub trait App {
/// Connection type yielded by the listener. The version handshake,
/// `preprocess_connection` and `auth` all operate on this type.
type Stream: Read + Write;
/// Stream type returned by `upgrade_stream`. The marker / file exchange in
/// `run_v1_0` is performed on this type.
type UpgradeStream: Read + Write;
/// Returns a string prefix.
///
/// NOTE(review): not used in this file; its meaning is defined by callers
/// elsewhere — confirm before relying on it.
fn prefix(&self) -> &str;
/// Returns a socket address.
///
/// NOTE(review): not used in this file; presumably the address the
/// broadcaster announces to — confirm against the implementations.
fn broadcast_addr(&self) -> SocketAddr;
/// Returns a directory path, either borrowed from `self` or owned.
///
/// NOTE(review): not used in this file; presumably where received files are
/// stored — confirm against `receiver_receive_file`.
fn download_dir<'a>(&'a self) -> Cow<'a, Path>;
/// When this returns `true`, `run_v1_0` does not call `start_broadcaster`.
/// The default implementation returns `false`.
fn disable_broadcaster(&self) -> bool {
false
}
/// Hook run on every successfully accepted connection before the version
/// handshake is read.
///
/// A connection is skipped when this returns `Ok(false)` or `Err`. The
/// default implementation ignores the stream and returns `Ok(true)`.
fn preprocess_connection(&self, stream: &mut Self::Stream) -> anyhow::Result<bool> {
// Explicitly discard the argument; the default hook does nothing with it.
let _ = stream;
Ok(true)
}
/// Hook run on a connection after its version handshake has been accepted.
///
/// Only `Ok(true)` lets the connection through; any other result makes the
/// accept loop move on to the next connection. The default implementation
/// ignores the stream and returns `Ok(true)`.
fn auth(&self, stream: &mut Self::Stream) -> anyhow::Result<bool> {
// Explicitly discard the argument; the default hook does nothing with it.
let _ = stream;
Ok(true)
}
/// Consumes the authenticated connection and returns the stream that the
/// rest of the session uses. An error aborts `run_v1_0`.
fn upgrade_stream(&self, stream: Self::Stream) -> anyhow::Result<Self::UpgradeStream>;
/// Hook run once on the upgraded stream before any protocol marker is read.
///
/// An error aborts `run_v1_0`. The default implementation ignores the stream
/// and returns `Ok(())`.
fn postprocess_connection(&self, stream: &mut Self::UpgradeStream) -> anyhow::Result<()> {
// Explicitly discard the argument; the default hook does nothing with it.
let _ = stream;
Ok(())
}
/// Creates a boxed `ProgressBar` for the given `total`.
///
/// NOTE(review): not called in this file; `total` is presumably a byte
/// count for one transfer — confirm against the `tf` module.
fn create_progress_bar(&self, total: u64) -> Box<dyn ProgressBar>;
/// Starts the broadcaster for the given listener address.
///
/// Returns a one-shot closure and the join handle of a thread. `run_v1_0`
/// calls the closure and then joins the handle, so the closure is expected
/// to make that thread finish.
fn start_broadcaster(
&self,
listener_addr: SocketAddr,
) -> (impl FnOnce(), std::thread::JoinHandle<()>);
}
/// Runs one `fs-share` v1.0 session on the listening side.
///
/// The session proceeds in this order:
///
/// 1. `create_listener` is called with the app and yields the listening
///    address plus an iterator of incoming connections.
/// 2. Unless `app.disable_broadcaster()` returns `true`, the broadcaster is
///    started for that address.
/// 3. The first connection that passes `accept_authenticated_stream` is taken.
///    The broadcaster is then stopped and joined, whether or not a connection
///    was found, so that its thread never outlives a failed accept.
/// 4. The connection is upgraded via `App::upgrade_stream` and handed to
///    `App::postprocess_connection`.
/// 5. Five-byte markers are read in a loop: `:fff:` hands the stream to
///    `receiver_receive_file`, `:eof:` ends the receiving phase.
/// 6. Every path from `files_to_send` is passed to `receiver_send_file`, then
///    `:eof:` is written and the stream is flushed.
///
/// # Errors
///
/// Returns an error if the listener cannot be created, no connection could be
/// authenticated, the broadcaster thread panicked, any hook or transfer step
/// fails, an I/O error occurs on the stream, or the peer sends a marker other
/// than `:fff:` / `:eof:`.
pub fn run_v1_0<A, P, I, F>(
    app: A,
    files_to_send: impl Iterator<Item = P>,
    create_listener: F,
) -> anyhow::Result<()>
where
    A: App,
    P: AsRef<Path>,
    I: Iterator<Item = io::Result<A::Stream>> + Send + 'static,
    F: Fn(&A) -> anyhow::Result<(SocketAddr, I)>,
{
    let (listen_addr, incoming_streams) = create_listener(&app)?;
    let broadcaster = if app.disable_broadcaster() {
        None
    } else {
        Some(app.start_broadcaster(listen_addr))
    };

    // Keep the outcome instead of returning right away: the broadcaster must be
    // shut down on the failure path too.
    let accepted = accept_authenticated_stream(&app, incoming_streams).with_context(|| {
        format!(
            "Failed to accept authenticated connection on {}",
            listen_addr
        )
    });

    let broadcaster_shutdown = match broadcaster {
        Some((stop, handle)) => {
            stop();
            handle
                .join()
                .map_err(|_| anyhow::anyhow!("Broadcaster thread panicked"))
        }
        None => Ok(()),
    };

    // An accept failure is the more informative error, so report it first.
    let stream = accepted?;
    broadcaster_shutdown?;

    let mut stream = app.upgrade_stream(stream)?;
    app.postprocess_connection(&mut stream)
        .context("postprocess failed")?;

    // Receiving phase: the peer announces each file with `:fff:` and finishes
    // with `:eof:`.
    loop {
        let mut marker = [0u8; 5];
        stream
            .read_exact(&mut marker)
            .context("Failed to read protocol marker")?;
        match &marker {
            b":fff:" => {
                receiver_receive_file(&app, &mut stream)?;
            }
            b":eof:" => break,
            // The marker comes from the remote peer, so an unknown value is a
            // protocol error to report, not an internal invariant violation.
            other => anyhow::bail!(
                "Invalid protocol marker: {:?}",
                String::from_utf8_lossy(other)
            ),
        }
    }

    // Sending phase: push our files, then signal the end of the session.
    for path in files_to_send {
        receiver_send_file(&app, path, &mut stream)?;
    }
    stream.write_all(b":eof:")?;
    stream.flush()?;
    Ok(())
}
/// Walks `incoming` and returns the first connection that completes the whole
/// acceptance sequence.
///
/// For each connection, in order:
///
/// 1. A listener-level error skips the connection.
/// 2. `App::preprocess_connection` must return `Ok(true)`.
/// 3. The peer must send exactly `fs-share:v1.0\n`. On a match `:accept:` is
///    written and flushed; on a mismatch `:reject:` is sent on a best-effort
///    basis.
/// 4. `App::auth` must return `Ok(true)`.
///
/// Any failure in these steps, including a failure to deliver `:accept:`, only
/// discards that one connection and moves on to the next, so a single
/// misbehaving peer cannot abort the accept loop.
///
/// # Errors
///
/// Returns an error only when `incoming` is exhausted without any connection
/// having passed all steps.
fn accept_authenticated_stream<A: App, L>(app: &A, incoming: L) -> anyhow::Result<A::Stream>
where
    L: Iterator<Item = io::Result<A::Stream>>,
{
    for stream in incoming {
        let Ok(mut stream) = stream else { continue };

        if !matches!(app.preprocess_connection(&mut stream), Ok(true)) {
            continue;
        }

        match match_bytes("fs-share:v1.0\n", &mut stream) {
            Ok(true) => {
                // A peer that vanished right after its handshake must not take
                // the whole listener down; treat it like any other bad peer.
                let acknowledged = stream
                    .write_all(b":accept:")
                    .and_then(|()| stream.flush());
                if acknowledged.is_err() {
                    continue;
                }
            }
            Ok(false) => {
                // Best effort: the connection is being dropped anyway.
                let _ = stream.write_all(b":reject:");
                let _ = stream.flush();
                continue;
            }
            Err(_) => continue,
        }

        if matches!(app.auth(&mut stream), Ok(true)) {
            return Ok(stream);
        }
    }
    anyhow::bail!("No authenticated connection found")
}
/// Reads exactly as many bytes from `reader` as `bytes` contains and reports
/// whether they are identical to `bytes`.
///
/// # Errors
///
/// Returns an error when `reader` cannot supply the required number of bytes.
fn match_bytes<B: AsRef<[u8]>, R: Read>(bytes: B, mut reader: R) -> anyhow::Result<bool> {
    let wanted = bytes.as_ref();
    // Scratch buffer sized to the expected sequence so `read_exact` consumes
    // precisely that many bytes from the reader.
    let mut received = vec![0u8; wanted.len()];
    reader
        .read_exact(&mut received)
        .context("Failed to read bytes from reader")
        .map(|()| received.as_slice() == wanted)
}