use std::{
net::{Ipv4Addr, SocketAddr},
time::Duration,
};
use netstack::{CreateSocket, netcore::Channel};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
time::timeout,
};
/// Deadline for the overlay TCP dial to the peer.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Deadline for receiving the peer's response head once the body has been flushed.
const RESPONSE_TIMEOUT: Duration = Duration::from_secs(30);
/// Deadline applied separately to each write to the peer and to the final flush.
const WRITE_IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Soft cap, in bytes, on the buffered response head; once exceeded the send fails.
const MAX_RESP_HEADERS: usize = 8192;
/// Reasons a Taildrop file send can fail.
///
/// Every variant is fieldless or carries a plain status code, so the type is cheap to copy
/// and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaildropSendError {
    /// The overlay TCP dial to the peer failed.
    Connect,
    /// An I/O failure while reading the local source or talking to the peer, or a
    /// malformed/truncated response head.
    Io,
    /// The file name was rejected before any connection was made.
    InvalidName,
    /// The peer answered `403`.
    Forbidden,
    /// The peer answered `409`.
    Conflict,
    /// The peer answered with a non-2xx status other than 403/409.
    UnexpectedStatus(u16),
    /// A connect, write, flush, or response deadline elapsed.
    Timeout,
}

impl core::fmt::Display for TaildropSendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::Connect => "failed to dial peer over the overlay",
            Self::Io => "taildrop send I/O error",
            Self::InvalidName => "invalid taildrop file name",
            Self::Forbidden => "peer rejected transfer: file-send capability denied (403)",
            Self::Conflict => "a transfer for this file is already in progress (409)",
            Self::Timeout => "taildrop send timed out",
            // The only variant whose message needs formatting.
            Self::UnexpectedStatus(code) => {
                return write!(f, "peer returned unexpected status {code}");
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for TaildropSendError {}
/// Percent-encodes `name` so it can be used as a single URL path segment.
///
/// Bytes in the RFC 3986 unreserved set (ASCII letters, digits, `-`, `_`, `.`, `~`) are
/// copied through unchanged; every other byte of the UTF-8 encoding becomes `%XX` with
/// uppercase hex digits.
pub(crate) fn path_escape(name: &str) -> String {
    name.bytes()
        .fold(String::with_capacity(name.len()), |mut escaped, byte| {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                    escaped.push(char::from(byte));
                }
                _ => {
                    escaped.push('%');
                    escaped.push(hex_upper(byte >> 4));
                    escaped.push(hex_upper(byte & 0x0F));
                }
            }
            escaped
        })
}

/// Maps a 4-bit value (`0..=15`) to its uppercase hexadecimal ASCII digit.
fn hex_upper(nibble: u8) -> char {
    let ascii = if nibble < 10 {
        b'0' + nibble
    } else {
        b'A' + (nibble - 10)
    };
    char::from(ascii)
}
/// Turns the peer's HTTP status code into the send outcome.
///
/// Any 2xx code is success; 403 and 409 have dedicated error variants; every other code is
/// reported as [`TaildropSendError::UnexpectedStatus`].
fn classify_status(code: u16) -> Result<(), TaildropSendError> {
    let failure = match code {
        200..=299 => return Ok(()),
        403 => TaildropSendError::Forbidden,
        409 => TaildropSendError::Conflict,
        other => TaildropSendError::UnexpectedStatus(other),
    };
    Err(failure)
}
/// Extracts the three-digit status code from an HTTP/1.x response head.
///
/// `head` must begin with the status line (e.g. `HTTP/1.1 200 OK\r\n…`). Only that first
/// line is inspected. The code must be exactly three ASCII digits directly after the first
/// space, followed by a space (before the reason phrase), a CR, or the end of input.
///
/// Returns `None` if `head` does not have that shape.
fn parse_status_line(head: &[u8]) -> Option<u16> {
    if !head.starts_with(b"HTTP/") {
        return None;
    }
    // Confine parsing to the status line so a space inside a later header line is never
    // mistaken for the version/status separator.
    let line_end = head.iter().position(|&b| b == b'\n').unwrap_or(head.len());
    let line = &head[..line_end];
    let space = line.iter().position(|&b| b == b' ')?;
    let rest = &line[space + 1..];
    let digits = rest.get(..3)?;
    if !digits.iter().all(u8::is_ascii_digit) {
        return None;
    }
    // Reject codes longer than three digits (e.g. "2000"), which would otherwise be
    // silently truncated.
    if !matches!(rest.get(3), None | Some(b' ' | b'\r')) {
        return None;
    }
    Some(
        digits
            .iter()
            .fold(0u16, |acc, &d| acc * 10 + u16::from(d - b'0')),
    )
}
pub async fn send_file<R>(
channel: &Channel,
self_ipv4: Ipv4Addr,
dst: SocketAddr,
name: &str,
content_length: u64,
mut reader: R,
) -> Result<(), TaildropSendError>
where
R: AsyncRead + Unpin,
{
crate::taildrop::validate_base_name(name).ok_or(TaildropSendError::InvalidName)?;
let local = SocketAddr::new(self_ipv4.into(), 0);
tracing::debug!(%dst, name, content_length, "taildrop send: dialing peer over overlay");
let mut stream = timeout(CONNECT_TIMEOUT, channel.tcp_connect(local, dst))
.await
.map_err(|_| TaildropSendError::Timeout)?
.map_err(|_| TaildropSendError::Connect)?;
let head = format!(
"PUT /v0/put/{} HTTP/1.1\r\nHost: {dst}\r\nContent-Length: {content_length}\r\nConnection: close\r\n\r\n",
path_escape(name),
);
write_all_bounded(&mut stream, head.as_bytes()).await?;
let mut buf = [0u8; 64 * 1024];
loop {
let n = reader
.read(&mut buf)
.await
.map_err(|_| TaildropSendError::Io)?;
if n == 0 {
break;
}
write_all_bounded(&mut stream, &buf[..n]).await?;
}
timeout(WRITE_IDLE_TIMEOUT, stream.flush())
.await
.map_err(|_| TaildropSendError::Timeout)?
.map_err(|_| TaildropSendError::Io)?;
let code = timeout(RESPONSE_TIMEOUT, read_response_status(&mut stream))
.await
.map_err(|_| TaildropSendError::Timeout)??;
match classify_status(code) {
Ok(()) => Ok(()),
Err(e) => {
tracing::warn!(%dst, name, status = code, "taildrop send: peer rejected transfer");
Err(e)
}
}
}
/// Writes all of `data` to `stream` within [`WRITE_IDLE_TIMEOUT`].
///
/// Yields [`TaildropSendError::Timeout`] if the deadline elapses first, and
/// [`TaildropSendError::Io`] if the underlying write fails.
async fn write_all_bounded<S>(stream: &mut S, data: &[u8]) -> Result<(), TaildropSendError>
where
    S: AsyncWriteExt + Unpin,
{
    match timeout(WRITE_IDLE_TIMEOUT, stream.write_all(data)).await {
        Err(_elapsed) => Err(TaildropSendError::Timeout),
        Ok(Err(_io)) => Err(TaildropSendError::Io),
        Ok(Ok(())) => Ok(()),
    }
}
/// Reads the response head from `stream` and returns its status code.
///
/// The function reads in 1 KiB chunks until `crate::peerapi_doh::find_header_end` reports a
/// complete head. That helper presumably locates the blank line ending the headers; confirm
/// against its definition. The status code is then parsed from the first line.
///
/// Fails with [`TaildropSendError::Io`] in these cases:
/// - a read error;
/// - EOF before the head completes;
/// - the buffered head growing past [`MAX_RESP_HEADERS`];
/// - a status line that does not parse.
async fn read_response_status<S>(stream: &mut S) -> Result<u16, TaildropSendError>
where
    S: AsyncRead + Unpin,
{
    let mut head: Vec<u8> = Vec::with_capacity(1024);
    let mut chunk = [0u8; 1024];
    while crate::peerapi_doh::find_header_end(&head).is_none() {
        if head.len() > MAX_RESP_HEADERS {
            return Err(TaildropSendError::Io);
        }
        let got = stream
            .read(&mut chunk)
            .await
            .map_err(|_| TaildropSendError::Io)?;
        if got == 0 {
            // Peer closed before finishing the head.
            return Err(TaildropSendError::Io);
        }
        head.extend_from_slice(&chunk[..got]);
    }
    parse_status_line(&head).ok_or(TaildropSendError::Io)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn path_escape_leaves_unreserved_verbatim() {
        assert_eq!(path_escape("photo.jpg"), "photo.jpg");
        assert_eq!(
            path_escape("AZaz09-_.~"),
            "AZaz09-_.~",
            "all unreserved bytes pass through"
        );
        assert_eq!(path_escape(""), "", "empty name stays empty");
    }

    #[test]
    fn path_escape_encodes_reserved() {
        assert_eq!(path_escape("my file.txt"), "my%20file.txt");
        assert_eq!(path_escape("a/b"), "a%2Fb");
        assert_eq!(path_escape("é"), "%C3%A9");
        // '%' must itself be escaped so the encoded name round-trips.
        assert_eq!(path_escape("100%"), "100%25");
        assert_eq!(path_escape("a+b=c"), "a%2Bb%3Dc");
        // Zero high nibble plus a letter low nibble exercises both hex branches.
        assert_eq!(path_escape("\n"), "%0A");
    }

    #[test]
    fn classify_status_maps_codes() {
        assert!(classify_status(200).is_ok());
        assert!(classify_status(204).is_ok());
        assert!(classify_status(299).is_ok());
        assert!(matches!(
            classify_status(199),
            Err(TaildropSendError::UnexpectedStatus(199))
        ));
        assert!(matches!(
            classify_status(300),
            Err(TaildropSendError::UnexpectedStatus(300))
        ));
        assert!(matches!(
            classify_status(403),
            Err(TaildropSendError::Forbidden)
        ));
        assert!(matches!(
            classify_status(404),
            Err(TaildropSendError::UnexpectedStatus(404))
        ));
        assert!(matches!(
            classify_status(409),
            Err(TaildropSendError::Conflict)
        ));
        assert!(matches!(
            classify_status(500),
            Err(TaildropSendError::UnexpectedStatus(500))
        ));
    }

    #[test]
    fn parse_status_line_extracts_code() {
        assert_eq!(
            parse_status_line(b"HTTP/1.1 200 OK\r\nX: 1\r\n\r\n"),
            Some(200)
        );
        assert_eq!(parse_status_line(b"HTTP/1.1 409 Conflict\r\n"), Some(409));
        assert_eq!(parse_status_line(b"HTTP/1.1 204\r\n\r\n"), Some(204));
        assert_eq!(
            parse_status_line(b"HTTP/1.0 404 Not Found\r\n"),
            Some(404)
        );
        assert_eq!(parse_status_line(b"not http at all"), None);
        assert_eq!(parse_status_line(b"HTTP/1.1 XX OK\r\n"), None);
        assert_eq!(parse_status_line(b"HTTP/1.1 20"), None);
        assert_eq!(parse_status_line(b""), None);
    }

    #[test]
    fn send_error_display_is_non_empty() {
        for e in [
            TaildropSendError::Connect,
            TaildropSendError::Io,
            TaildropSendError::InvalidName,
            TaildropSendError::Forbidden,
            TaildropSendError::Conflict,
            TaildropSendError::UnexpectedStatus(418),
            TaildropSendError::Timeout,
        ] {
            assert!(!e.to_string().is_empty());
        }
    }
}