use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
use std::time::Duration;
use crate::error::{Error, Result};
use crate::url::Url;
/// Read request opcode (first two bytes of every TFTP packet, big-endian).
const OP_RRQ: u16 = 0x0001;
/// Write request opcode.
const OP_WRQ: u16 = 0x0002;
/// DATA packet opcode.
const OP_DATA: u16 = 0x0003;
/// Acknowledgement opcode.
const OP_ACK: u16 = 0x0004;
/// ERROR packet opcode.
const OP_ERROR: u16 = 0x0005;
/// Payload size of a full DATA block; a shorter block ends the transfer.
const BLOCK_SIZE: usize = 0x200;
/// How long each receive waits before the last packet is retransmitted.
const READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Retransmissions attempted before a silent peer is given up on.
const MAX_RETRIES: u32 = 3;
/// Cap on bytes accumulated by `fetch` or accepted by `store` (256 MiB).
const MAX_TOTAL_BYTES: usize = 1 << 28;
/// A decoded DATA packet whose payload borrows from the receive buffer.
#[derive(Debug, PartialEq, Eq)]
struct DataPacket<'buf> {
    /// Block number taken from bytes 2..4 of the datagram.
    block: u16,
    /// Payload bytes following the 4-byte header.
    data: &'buf [u8],
}
/// Encodes an RRQ/WRQ-style request: the big-endian `opcode`, then the
/// filename and the transfer mode `octet`, each terminated by a NUL byte.
fn build_request(opcode: u16, filename: &str) -> Vec<u8> {
    const MODE: &[u8] = b"octet";
    let name = filename.as_bytes();
    let mut packet = Vec::with_capacity(2 + name.len() + 1 + MODE.len() + 1);
    packet.extend_from_slice(&opcode.to_be_bytes());
    // Both string fields share the same "bytes then NUL" framing.
    for field in [name, MODE] {
        packet.extend_from_slice(field);
        packet.push(0);
    }
    packet
}
/// Encodes a read request (RRQ) for `filename` in octet mode.
fn build_rrq(filename: &str) -> Vec<u8> {
    let opcode = OP_RRQ;
    build_request(opcode, filename)
}
/// Encodes a write request (WRQ) for `filename` in octet mode.
fn build_wrq(filename: &str) -> Vec<u8> {
    let opcode = OP_WRQ;
    build_request(opcode, filename)
}
/// Encodes a DATA packet: big-endian opcode, big-endian block number, then
/// `payload` verbatim.
fn build_data(block: u16, payload: &[u8]) -> Vec<u8> {
    let mut packet = Vec::with_capacity(payload.len() + 4);
    for word in [OP_DATA, block] {
        packet.extend_from_slice(&word.to_be_bytes());
    }
    packet.extend_from_slice(payload);
    packet
}
/// Encodes a 4-byte ACK packet for `block`: opcode then block number, both
/// big-endian.
fn build_ack(block: u16) -> [u8; 4] {
    let [op_hi, op_lo] = OP_ACK.to_be_bytes();
    let [block_hi, block_lo] = block.to_be_bytes();
    [op_hi, op_lo, block_hi, block_lo]
}
/// Reads the big-endian opcode from the first two bytes of `buf`.
///
/// Returns `None` when the packet is shorter than two bytes.
fn parse_opcode(buf: &[u8]) -> Option<u16> {
    match buf {
        [hi, lo, ..] => Some(u16::from_be_bytes([*hi, *lo])),
        _ => None,
    }
}
fn parse_data(buf: &[u8]) -> Result<DataPacket<'_>> {
if buf.len() < 4 {
return Err(Error::BadResponse("tftp: short DATA packet".into()));
}
if parse_opcode(buf) != Some(OP_DATA) {
return Err(Error::BadResponse("tftp: not a DATA packet".into()));
}
let block = u16::from_be_bytes([buf[2], buf[3]]);
Ok(DataPacket {
block,
data: &buf[4..],
})
}
fn parse_ack(buf: &[u8]) -> Result<u16> {
if buf.len() < 4 {
return Err(Error::BadResponse("tftp: short ACK packet".into()));
}
if parse_opcode(buf) != Some(OP_ACK) {
return Err(Error::BadResponse("tftp: not an ACK packet".into()));
}
Ok(u16::from_be_bytes([buf[2], buf[3]]))
}
fn parse_error(buf: &[u8]) -> Result<String> {
if buf.len() < 4 {
return Err(Error::BadResponse("tftp: short ERROR packet".into()));
}
if parse_opcode(buf) != Some(OP_ERROR) {
return Err(Error::BadResponse("tftp: not an ERROR packet".into()));
}
let msg_bytes = &buf[4..];
let end = msg_bytes
.iter()
.position(|&b| b == 0)
.unwrap_or(msg_bytes.len());
Ok(String::from_utf8_lossy(&msg_bytes[..end]).into_owned())
}
fn resolve(host: &str, port: u16) -> Result<SocketAddr> {
(host, port)
.to_socket_addrs()?
.next()
.ok_or_else(|| Error::BadResponse(format!("tftp: cannot resolve {host}:{port}")))
}
/// Extracts the TFTP filename from `url.path`, dropping one leading `/`.
///
/// # Errors
///
/// Returns `Error::InvalidUrl` when the resulting filename is empty, or when
/// it contains a NUL byte. A NUL is rejected because the request packet uses
/// NUL to terminate the name.
fn filename_of(url: &Url) -> Result<&str> {
    let filename = url.path.strip_prefix('/').unwrap_or(&url.path);
    if filename.is_empty() {
        // The path is "" or "/" here. Render the URL with exactly one trailing
        // slash rather than appending the raw path after another "/", which
        // produced "tftp://host//".
        return Err(Error::InvalidUrl(format!(
            "tftp: empty filename in {}://{}/",
            url.scheme, url.host
        )));
    }
    if filename.as_bytes().contains(&0) {
        return Err(Error::InvalidUrl("tftp: filename contains NUL".into()));
    }
    Ok(filename)
}
/// Downloads the file named by `url` using a TFTP read request in octet mode.
///
/// The request is sent to `url.host:url.port`. Until the first in-order DATA
/// block arrives, only packets from the server's IP are considered. That
/// block's source address then becomes the peer, and packets from any other
/// address are ignored. Each in-order block is acknowledged, and the transfer
/// ends with the first block shorter than `BLOCK_SIZE`. A repeat of the
/// previous block is re-acknowledged, so a lost ACK does not stall the server.
///
/// # Errors
///
/// - `Error::InvalidUrl` if the URL path yields no usable filename.
/// - `Error::UnexpectedEof` if `MAX_RETRIES` retransmissions go unanswered.
/// - `Error::BadResponse` in any of these cases:
///   - the server sends an ERROR packet;
///   - a packet is malformed or unexpected;
///   - a block carries more than `BLOCK_SIZE` bytes;
///   - the download exceeds `MAX_TOTAL_BYTES`;
///   - the download would need more than 65535 blocks.
/// - `Error::Io` for other receive failures. Resolution, bind and send
///   failures are propagated with `?`.
pub fn fetch(url: &Url) -> Result<Vec<u8>> {
    let filename = filename_of(url)?;
    let server = resolve(&url.host, url.port)?;
    // Bind a wildcard address of the server's family. An IPv4 socket cannot
    // send to an IPv6 destination (and vice versa), so a host resolving to
    // IPv6 first would otherwise fail on the very first send.
    let local_addr = if server.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let socket = UdpSocket::bind(local_addr)?;
    socket.set_read_timeout(Some(READ_TIMEOUT))?;
    let rrq = build_rrq(filename);
    let mut peer: Option<SocketAddr> = None;
    let mut out: Vec<u8> = Vec::new();
    let mut expected_block: u16 = 1;
    // The buffer leaves 16 bytes of slack beyond a full DATA packet. An
    // oversized (or truncated) datagram therefore shows up as a payload longer
    // than BLOCK_SIZE instead of passing as a normal full block.
    let mut buf = [0u8; 4 + BLOCK_SIZE + 16];
    // The packet and destination to retransmit when a receive times out.
    let mut last_packet: Vec<u8> = rrq;
    let mut last_dest: SocketAddr = server;
    socket.send_to(&last_packet, last_dest)?;
    let mut retries: u32 = 0;
    loop {
        let (n, from) = match socket.recv_from(&mut buf) {
            Ok(v) => v,
            Err(e) => {
                // Read timeouts surface as WouldBlock on Unix and TimedOut on
                // Windows.
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                ) {
                    if retries >= MAX_RETRIES {
                        return Err(Error::UnexpectedEof);
                    }
                    retries += 1;
                    socket.send_to(&last_packet, last_dest)?;
                    continue;
                }
                return Err(Error::Io(e));
            }
        };
        if let Some(p) = peer {
            if from != p {
                continue;
            }
        } else if from.ip() != server.ip() {
            continue;
        }
        let pkt = &buf[..n];
        match parse_opcode(pkt) {
            Some(OP_DATA) => {
                let data = parse_data(pkt)?;
                if data.block != expected_block {
                    // A repeat of the block we just acknowledged means our ACK
                    // was probably lost, so send it again. Any other
                    // out-of-order block is dropped.
                    if data.block.wrapping_add(1) == expected_block {
                        let ack = build_ack(data.block);
                        socket.send_to(&ack, from)?;
                    }
                    continue;
                }
                // DATA payloads are at most BLOCK_SIZE bytes. A longer one is a
                // protocol violation (or a truncated oversized datagram) and
                // would otherwise be misread as a non-final block.
                if data.data.len() > BLOCK_SIZE {
                    return Err(Error::BadResponse(format!(
                        "tftp: DATA block {} carries {} bytes (max {BLOCK_SIZE})",
                        data.block,
                        data.data.len()
                    )));
                }
                if peer.is_none() {
                    peer = Some(from);
                }
                if out.len() + data.data.len() > MAX_TOTAL_BYTES {
                    return Err(Error::BadResponse(format!(
                        "tftp: transfer exceeds {} bytes",
                        MAX_TOTAL_BYTES
                    )));
                }
                out.extend_from_slice(data.data);
                let ack = build_ack(data.block);
                socket.send_to(&ack, from)?;
                let is_last = data.data.len() < BLOCK_SIZE;
                if is_last {
                    return Ok(out);
                }
                last_packet = ack.to_vec();
                last_dest = from;
                retries = 0;
                expected_block = match expected_block.checked_add(1) {
                    Some(b) => b,
                    None => {
                        return Err(Error::BadResponse(
                            "tftp: block number wrapped; refusing oversized transfer".into(),
                        ));
                    }
                };
            }
            Some(OP_ERROR) => {
                let msg = parse_error(pkt)?;
                return Err(Error::BadResponse(format!("tftp: {msg}")));
            }
            Some(op) => {
                return Err(Error::BadResponse(format!("tftp: unexpected opcode {op}")));
            }
            None => {
                return Err(Error::BadResponse("tftp: packet too short".into()));
            }
        }
    }
}
/// Uploads `data` to the file named by `url` using a TFTP write request in
/// octet mode.
///
/// The server must first answer the request with ACK 0, sent from its IP.
/// That reply's source address becomes the peer. `data` is then sent in
/// `BLOCK_SIZE` chunks numbered from 1, and each chunk waits for its matching
/// ACK. The last block is shorter than `BLOCK_SIZE`; when the length is an
/// exact multiple, it is an empty block. Its ACK completes the upload. ACKs
/// for other block numbers are ignored rather than triggering a resend.
///
/// # Errors
///
/// - `Error::InvalidUrl` if the URL path yields no usable filename.
/// - `Error::BadResponse` in any of these cases:
///   - `data` exceeds `MAX_TOTAL_BYTES`, checked before anything is sent;
///   - `data` would need more than 65535 blocks, also checked before sending;
///   - the server sends an ERROR packet;
///   - a packet is malformed or unexpected.
/// - `Error::UnexpectedEof` if `MAX_RETRIES` retransmissions go unanswered.
/// - `Error::Io` for other receive failures. Resolution, bind and send
///   failures are propagated with `?`.
pub fn store(url: &Url, data: &[u8]) -> Result<()> {
    let filename = filename_of(url)?;
    if data.len() > MAX_TOTAL_BYTES {
        return Err(Error::BadResponse(format!(
            "tftp: upload exceeds {MAX_TOTAL_BYTES} bytes"
        )));
    }
    // Block numbers are u16 starting at 1, and an exact multiple of BLOCK_SIZE
    // needs a trailing empty block, so `len` bytes take `len / BLOCK_SIZE + 1`
    // blocks. Rejecting here avoids failing only after ~32 MiB have been sent,
    // which could leave a partial file on the server.
    let blocks_needed = data.len() / BLOCK_SIZE + 1;
    if blocks_needed > usize::from(u16::MAX) {
        return Err(Error::BadResponse(format!(
            "tftp: upload of {} bytes needs {blocks_needed} blocks (max {})",
            data.len(),
            u16::MAX
        )));
    }
    let server = resolve(&url.host, url.port)?;
    // Bind a wildcard address of the server's family. An IPv4 socket cannot
    // send to an IPv6 destination (and vice versa).
    let local_addr = if server.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let socket = UdpSocket::bind(local_addr)?;
    socket.set_read_timeout(Some(READ_TIMEOUT))?;
    let wrq = build_wrq(filename);
    let mut peer: Option<SocketAddr> = None;
    let mut buf = [0u8; 4 + BLOCK_SIZE + 16];
    // The packet and destination to retransmit when a receive times out.
    let mut last_packet: Vec<u8> = wrq;
    let mut last_dest: SocketAddr = server;
    socket.send_to(&last_packet, last_dest)?;
    let mut retries: u32 = 0;
    // The most recently sent block number. 0 means only the WRQ is out.
    let mut block: u16 = 0;
    let mut sent_offset: usize = 0;
    let mut sent_final = false;
    loop {
        let (n, from) = match socket.recv_from(&mut buf) {
            Ok(v) => v,
            Err(e) => {
                // Read timeouts surface as WouldBlock on Unix and TimedOut on
                // Windows.
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                ) {
                    if retries >= MAX_RETRIES {
                        return Err(Error::UnexpectedEof);
                    }
                    retries += 1;
                    socket.send_to(&last_packet, last_dest)?;
                    continue;
                }
                return Err(Error::Io(e));
            }
        };
        if let Some(p) = peer {
            if from != p {
                continue;
            }
        } else if from.ip() != server.ip() {
            continue;
        }
        let pkt = &buf[..n];
        match parse_opcode(pkt) {
            Some(OP_ACK) => {
                let acked = parse_ack(pkt)?;
                // Ignore stale or duplicate ACKs. Resending on them would
                // duplicate every subsequent block.
                if acked != block {
                    continue;
                }
                if peer.is_none() {
                    peer = Some(from);
                }
                if sent_final {
                    return Ok(());
                }
                let next_offset = if block == 0 {
                    0
                } else {
                    sent_offset + BLOCK_SIZE
                };
                // Defensive: the up-front block-count check should make this
                // unreachable.
                let next_block = match block.checked_add(1) {
                    Some(b) => b,
                    None => {
                        return Err(Error::BadResponse(
                            "tftp: block number wrapped; refusing oversized transfer".into(),
                        ));
                    }
                };
                let end = (next_offset + BLOCK_SIZE).min(data.len());
                let payload = &data[next_offset..end];
                let dgram = build_data(next_block, payload);
                socket.send_to(&dgram, from)?;
                block = next_block;
                sent_offset = next_offset;
                sent_final = payload.len() < BLOCK_SIZE;
                last_packet = dgram;
                last_dest = from;
                retries = 0;
            }
            Some(OP_ERROR) => {
                let msg = parse_error(pkt)?;
                return Err(Error::BadResponse(format!("tftp: {msg}")));
            }
            Some(op) => {
                return Err(Error::BadResponse(format!("tftp: unexpected opcode {op}")));
            }
            None => {
                return Err(Error::BadResponse("tftp: packet too short".into()));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Packet encode/decode unit tests, plus loopback round-trips of `fetch`
    // and `store` against in-process mock servers.
    use super::*;

    #[test]
    fn rrq_builds_in_octet_mode() {
        let p = build_rrq("hello.txt");
        assert_eq!(p[0..2], [0x00, 0x01]);
        assert_eq!(&p[2..2 + b"hello.txt".len()], b"hello.txt");
        let after_name = 2 + b"hello.txt".len();
        assert_eq!(p[after_name], 0);
        assert_eq!(&p[after_name + 1..after_name + 1 + 5], b"octet");
        assert_eq!(*p.last().unwrap(), 0);
        assert_eq!(p.len(), 2 + 9 + 1 + 5 + 1);
    }

    #[test]
    fn rrq_handles_empty_filename_shape() {
        let p = build_rrq("");
        assert_eq!(
            p,
            vec![0x00, 0x01, 0x00, b'o', b'c', b't', b'e', b't', 0x00]
        );
    }

    #[test]
    fn ack_encodes_block_number_big_endian() {
        assert_eq!(build_ack(0), [0x00, 0x04, 0x00, 0x00]);
        assert_eq!(build_ack(1), [0x00, 0x04, 0x00, 0x01]);
        assert_eq!(build_ack(0x0102), [0x00, 0x04, 0x01, 0x02]);
        assert_eq!(build_ack(0xFFFF), [0x00, 0x04, 0xFF, 0xFF]);
    }

    #[test]
    fn parse_opcode_handles_short_input() {
        assert_eq!(parse_opcode(&[]), None);
        assert_eq!(parse_opcode(&[0x00]), None);
        assert_eq!(parse_opcode(&[0x00, 0x03]), Some(3));
        assert_eq!(parse_opcode(&[0x00, 0x05, 0xAA]), Some(5));
    }

    #[test]
    fn parse_data_extracts_block_and_payload() {
        let pkt = [0x00, 0x03, 0x00, 0x07, b'a', b'b', b'c'];
        let d = parse_data(&pkt).unwrap();
        assert_eq!(d.block, 7);
        assert_eq!(d.data, b"abc");
    }

    #[test]
    fn parse_data_allows_empty_payload() {
        let pkt = [0x00, 0x03, 0x00, 0x42];
        let d = parse_data(&pkt).unwrap();
        assert_eq!(d.block, 0x42);
        assert_eq!(d.data, b"");
    }

    #[test]
    fn parse_data_rejects_short_header() {
        assert!(parse_data(&[]).is_err());
        assert!(parse_data(&[0x00, 0x03]).is_err());
        assert!(parse_data(&[0x00, 0x03, 0x00]).is_err());
    }

    #[test]
    fn parse_data_rejects_wrong_opcode() {
        let pkt = [0x00, 0x04, 0x00, 0x01];
        assert!(parse_data(&pkt).is_err());
    }

    #[test]
    fn parse_error_strips_trailing_nul() {
        let pkt = [0x00, 0x05, 0x00, 0x01, b'N', b'o', b'p', b'e', 0x00];
        let m = parse_error(&pkt).unwrap();
        assert_eq!(m, "Nope");
    }

    #[test]
    fn parse_error_tolerates_missing_nul() {
        let pkt = [0x00, 0x05, 0x00, 0x02, b'h', b'i'];
        let m = parse_error(&pkt).unwrap();
        assert_eq!(m, "hi");
    }

    #[test]
    fn parse_error_handles_empty_message() {
        let pkt = [0x00, 0x05, 0x00, 0x03, 0x00];
        let m = parse_error(&pkt).unwrap();
        assert_eq!(m, "");
    }

    #[test]
    fn parse_error_rejects_short_header() {
        assert!(parse_error(&[0x00, 0x05]).is_err());
        assert!(parse_error(&[0x00, 0x05, 0x00]).is_err());
    }

    #[test]
    fn parse_error_rejects_wrong_opcode() {
        let pkt = [0x00, 0x03, 0x00, 0x01, b'x', 0x00];
        assert!(parse_error(&pkt).is_err());
    }

    #[test]
    fn parse_error_invalid_utf8_lossy() {
        let pkt = [0x00, 0x05, 0x00, 0x01, 0xFF, 0x00];
        let m = parse_error(&pkt).unwrap();
        assert!(m.contains('\u{FFFD}'));
    }

    #[test]
    fn wrq_builds_in_octet_mode() {
        let p = build_wrq("hello.txt");
        assert_eq!(p[0..2], [0x00, 0x02]);
        assert_eq!(&p[2..2 + b"hello.txt".len()], b"hello.txt");
        let after_name = 2 + b"hello.txt".len();
        assert_eq!(p[after_name], 0);
        assert_eq!(&p[after_name + 1..after_name + 1 + 5], b"octet");
        assert_eq!(*p.last().unwrap(), 0);
        assert_eq!(p.len(), 2 + 9 + 1 + 5 + 1);
    }

    #[test]
    fn wrq_and_rrq_share_framing_only_opcode_differs() {
        let r = build_rrq("f");
        let w = build_wrq("f");
        assert_eq!(r[0..2], [0x00, 0x01]);
        assert_eq!(w[0..2], [0x00, 0x02]);
        assert_eq!(&r[2..], &w[2..]);
    }

    #[test]
    fn data_packet_builds_header_and_payload() {
        let p = build_data(1, b"abc");
        assert_eq!(p, vec![0x00, 0x03, 0x00, 0x01, b'a', b'b', b'c']);
    }

    #[test]
    fn data_packet_block_number_big_endian() {
        let p = build_data(0x0102, b"");
        assert_eq!(p, vec![0x00, 0x03, 0x01, 0x02]);
    }

    #[test]
    fn data_packet_full_block_is_516_bytes() {
        let payload = vec![0x5Au8; BLOCK_SIZE];
        let p = build_data(7, &payload);
        assert_eq!(p.len(), 4 + BLOCK_SIZE);
        assert_eq!(p[0..4], [0x00, 0x03, 0x00, 0x07]);
        assert_eq!(&p[4..], &payload[..]);
    }

    #[test]
    fn parse_ack_extracts_block_number() {
        assert_eq!(parse_ack(&[0x00, 0x04, 0x00, 0x00]).unwrap(), 0);
        assert_eq!(parse_ack(&[0x00, 0x04, 0x00, 0x01]).unwrap(), 1);
        assert_eq!(parse_ack(&[0x00, 0x04, 0xFF, 0xFF]).unwrap(), 0xFFFF);
        assert_eq!(parse_ack(&[0x00, 0x04, 0x12, 0x34, 0xAA]).unwrap(), 0x1234);
    }

    #[test]
    fn parse_ack_rejects_short_header() {
        assert!(parse_ack(&[]).is_err());
        assert!(parse_ack(&[0x00, 0x04]).is_err());
        assert!(parse_ack(&[0x00, 0x04, 0x00]).is_err());
    }

    #[test]
    fn parse_ack_rejects_wrong_opcode() {
        assert!(parse_ack(&[0x00, 0x03, 0x00, 0x01]).is_err());
    }

    /// Mirrors the offset/block arithmetic `store` uses to split an upload, so
    /// its edge cases can be checked without sockets.
    fn split_blocks(data: &[u8]) -> Vec<(u16, &[u8])> {
        let mut out = Vec::new();
        let mut block: u16 = 0;
        let mut offset = 0usize;
        loop {
            let next_offset = if block == 0 { 0 } else { offset + BLOCK_SIZE };
            let next_block = block.checked_add(1).expect("no wrap in test sizes");
            let end = (next_offset + BLOCK_SIZE).min(data.len());
            let payload = &data[next_offset..end];
            out.push((next_block, payload));
            block = next_block;
            offset = next_offset;
            if payload.len() < BLOCK_SIZE {
                break;
            }
        }
        out
    }

    #[test]
    fn split_blocks_short_single_block() {
        let data = b"hello";
        let blocks = split_blocks(data);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].0, 1);
        assert_eq!(blocks[0].1, b"hello");
    }

    #[test]
    fn split_blocks_empty_input_sends_one_empty_block() {
        let blocks = split_blocks(b"");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0], (1u16, &b""[..]));
    }

    #[test]
    fn split_blocks_exact_multiple_appends_trailing_empty_block() {
        let data = vec![0xABu8; BLOCK_SIZE];
        let blocks = split_blocks(&data);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].0, 1);
        assert_eq!(blocks[0].1.len(), BLOCK_SIZE);
        assert_eq!(blocks[1].0, 2);
        assert_eq!(blocks[1].1.len(), 0);
    }

    #[test]
    fn split_blocks_two_full_blocks_plus_empty() {
        let data = vec![1u8; 2 * BLOCK_SIZE];
        let blocks = split_blocks(&data);
        assert_eq!(blocks.len(), 3);
        assert_eq!(blocks[0].0, 1);
        assert_eq!(blocks[1].0, 2);
        assert_eq!(blocks[2].0, 3);
        assert_eq!(blocks[2].1.len(), 0);
    }

    #[test]
    fn split_blocks_partial_final_block_no_trailing_empty() {
        let mut data = vec![0u8; BLOCK_SIZE];
        data.extend_from_slice(b"tail");
        let blocks = split_blocks(&data);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[1].0, 2);
        assert_eq!(blocks[1].1, b"tail");
    }

    #[test]
    fn split_blocks_reassemble_matches_input() {
        for &len in &[0usize, 1, 511, 512, 513, 1024, 1500, 4096] {
            let data: Vec<u8> = (0..len).map(|i| (i % 251) as u8).collect();
            let blocks = split_blocks(&data);
            let joined: Vec<u8> = blocks.iter().flat_map(|(_, p)| p.iter().copied()).collect();
            assert_eq!(joined, data, "len {len}");
        }
    }

    use std::net::UdpSocket;
    use std::thread;

    /// Mock TFTP server for `store`. It accepts one WRQ on `server`, then
    /// answers from a fresh socket (its own transfer port) and returns the
    /// reassembled upload. With `inject_foreign_tid`, a third socket sends a
    /// bogus ACK once, and the client must ignore it.
    fn run_mock_wrq_server(server: UdpSocket, inject_foreign_tid: bool) -> Vec<u8> {
        let mut buf = [0u8; 4 + BLOCK_SIZE + 16];
        let (n, from) = server.recv_from(&mut buf).unwrap();
        assert_eq!(parse_opcode(&buf[..n]), Some(OP_WRQ));
        let tid_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        tid_sock
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();
        tid_sock.send_to(&build_ack(0), from).unwrap();
        let foreign = UdpSocket::bind("127.0.0.1:0").unwrap();
        let mut injected = false;
        let mut collected = Vec::new();
        let mut expected: u16 = 1;
        loop {
            let (n, peer) = tid_sock.recv_from(&mut buf).unwrap();
            let d = parse_data(&buf[..n]).unwrap();
            assert_eq!(d.block, expected);
            collected.extend_from_slice(d.data);
            if inject_foreign_tid && !injected {
                foreign.send_to(&build_ack(d.block), peer).unwrap();
                injected = true;
            }
            tid_sock.send_to(&build_ack(d.block), peer).unwrap();
            let last = d.data.len() < BLOCK_SIZE;
            expected = expected.wrapping_add(1);
            if last {
                break;
            }
        }
        collected
    }

    /// Builds a `tftp://127.0.0.1:<port><path>` URL for the mock servers.
    fn test_url(port: u16, path: &str) -> Url {
        Url {
            scheme: "tftp".into(),
            userinfo: None,
            host: "127.0.0.1".into(),
            port,
            path: path.into(),
        }
    }

    fn upload_roundtrip_inner(payload: Vec<u8>, inject_foreign_tid: bool) {
        let server = UdpSocket::bind("127.0.0.1:0").unwrap();
        let server_port = server.local_addr().unwrap().port();
        let handle = thread::spawn(move || run_mock_wrq_server(server, inject_foreign_tid));
        let url = test_url(server_port, "/upload.bin");
        store(&url, &payload).unwrap();
        let got = handle.join().unwrap();
        assert_eq!(got, payload);
    }

    fn upload_roundtrip(payload: Vec<u8>) {
        upload_roundtrip_inner(payload, false);
    }

    #[test]
    fn store_uploads_short_file() {
        upload_roundtrip(b"hello, tftp world".to_vec());
    }

    #[test]
    fn store_uploads_empty_file() {
        upload_roundtrip(Vec::new());
    }

    #[test]
    fn store_uploads_exact_multiple_of_block() {
        upload_roundtrip(vec![0x7Eu8; BLOCK_SIZE]);
    }

    #[test]
    fn store_uploads_multi_block_file() {
        let payload: Vec<u8> = (0..(2 * BLOCK_SIZE + 100))
            .map(|i| (i % 256) as u8)
            .collect();
        upload_roundtrip(payload);
    }

    #[test]
    fn store_ignores_foreign_tid_acks() {
        let payload: Vec<u8> = (0..(2 * BLOCK_SIZE + 7)).map(|i| (i % 256) as u8).collect();
        upload_roundtrip_inner(payload, true);
    }

    #[test]
    fn store_rejects_empty_filename() {
        let url = test_url(69, "/");
        assert!(matches!(store(&url, b"x"), Err(Error::InvalidUrl(_))));
    }

    #[test]
    fn store_surfaces_server_error_packet() {
        let server = UdpSocket::bind("127.0.0.1:0").unwrap();
        let server_port = server.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let mut buf = [0u8; 64];
            let (n, from) = server.recv_from(&mut buf).unwrap();
            assert_eq!(parse_opcode(&buf[..n]), Some(OP_WRQ));
            let err = [0x00, 0x05, 0x00, 0x02, b'n', b'o', 0x00];
            server.send_to(&err, from).unwrap();
        });
        let url = test_url(server_port, "/denied");
        let err = store(&url, b"data").unwrap_err();
        match err {
            Error::BadResponse(m) => assert!(m.contains("no"), "got {m}"),
            other => panic!("expected BadResponse, got {other:?}"),
        }
        handle.join().unwrap();
    }

    /// Mock TFTP server for `fetch`. It checks that the first datagram is
    /// exactly `expected_request`. It then serves `content` from a fresh
    /// socket (its own transfer port), waiting for the matching ACK after each
    /// DATA block.
    ///
    /// With `resend_first_block`, DATA 1 is sent a second time after its ACK,
    /// and a re-ACK of block 1 is expected. Only use that with at least one
    /// full block of content: for a shorter file the client has already
    /// returned, and the second ACK never arrives.
    fn run_mock_rrq_server(
        server: UdpSocket,
        expected_request: Vec<u8>,
        content: Vec<u8>,
        resend_first_block: bool,
    ) {
        let mut buf = [0u8; 4 + BLOCK_SIZE + 16];
        let (n, from) = server.recv_from(&mut buf).unwrap();
        assert_eq!(&buf[..n], &expected_request[..]);
        let tid_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        tid_sock
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();
        let mut block: u16 = 1;
        let mut offset = 0usize;
        loop {
            let end = (offset + BLOCK_SIZE).min(content.len());
            let dgram = build_data(block, &content[offset..end]);
            let sends = if resend_first_block && block == 1 { 2 } else { 1 };
            for _ in 0..sends {
                tid_sock.send_to(&dgram, from).unwrap();
                let (n, _) = tid_sock.recv_from(&mut buf).unwrap();
                assert_eq!(parse_ack(&buf[..n]).unwrap(), block);
            }
            if end - offset < BLOCK_SIZE {
                break;
            }
            offset = end;
            block += 1;
        }
    }

    fn download_roundtrip_inner(content: Vec<u8>, resend_first_block: bool) {
        let server = UdpSocket::bind("127.0.0.1:0").unwrap();
        let server_port = server.local_addr().unwrap().port();
        let expected = content.clone();
        let request = build_rrq("download.bin");
        let handle = thread::spawn(move || {
            run_mock_rrq_server(server, request, content, resend_first_block)
        });
        let url = test_url(server_port, "/download.bin");
        let got = fetch(&url).unwrap();
        handle.join().unwrap();
        assert_eq!(got, expected);
    }

    fn download_roundtrip(content: Vec<u8>) {
        download_roundtrip_inner(content, false);
    }

    #[test]
    fn fetch_downloads_short_file() {
        download_roundtrip(b"hello, tftp world".to_vec());
    }

    #[test]
    fn fetch_downloads_empty_file() {
        download_roundtrip(Vec::new());
    }

    #[test]
    fn fetch_downloads_exact_multiple_of_block() {
        download_roundtrip(vec![0x7Eu8; BLOCK_SIZE]);
    }

    #[test]
    fn fetch_downloads_multi_block_file() {
        let payload: Vec<u8> = (0..(2 * BLOCK_SIZE + 100))
            .map(|i| (i % 256) as u8)
            .collect();
        download_roundtrip(payload);
    }

    #[test]
    fn fetch_reacks_duplicate_block() {
        let payload: Vec<u8> = (0..(2 * BLOCK_SIZE + 3)).map(|i| (i % 251) as u8).collect();
        download_roundtrip_inner(payload, true);
    }

    #[test]
    fn fetch_rejects_empty_filename() {
        let url = test_url(69, "/");
        assert!(matches!(fetch(&url), Err(Error::InvalidUrl(_))));
    }

    #[test]
    fn fetch_surfaces_server_error_packet() {
        let server = UdpSocket::bind("127.0.0.1:0").unwrap();
        let server_port = server.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let mut buf = [0u8; 64];
            let (n, from) = server.recv_from(&mut buf).unwrap();
            assert_eq!(&buf[..n], &build_rrq("missing")[..]);
            let err = [0x00, 0x05, 0x00, 0x01, b'g', b'o', b'n', b'e', 0x00];
            server.send_to(&err, from).unwrap();
        });
        let url = test_url(server_port, "/missing");
        let err = fetch(&url).unwrap_err();
        match err {
            Error::BadResponse(m) => assert!(m.contains("gone"), "got {m}"),
            other => panic!("expected BadResponse, got {other:?}"),
        }
        handle.join().unwrap();
    }
}