sftp-protocol 0.1.0

A pure Rust implementation of the SFTP protocol
Documentation
use std::io;
use std::io::Write;

use circular::Buffer;
use nom_derive::Parse;

use crate::packet::Packet;
use crate::packet::PacketHeader;
use crate::Error;

/// Incremental decoder that reassembles SFTP packets from an arbitrarily
/// chunked byte stream.
///
/// Bytes received from the peer are appended with [`Parser::write`]; complete
/// packets are drained with [`Parser::get_packet`].
#[derive(Clone, Debug)]
pub struct Parser {
	/// Bytes received but not yet consumed as complete frames.
	buffer: Buffer,
	/// Header of the frame at the front of `buffer`, cached once its bytes
	/// have arrived so it is not re-parsed on every poll.
	current_header: Option<PacketHeader>,
}

impl Default for Parser {
	/// Creates a parser backed by a 1 MiB buffer.
	fn default() -> Self {
		const ONE_MIB: usize = 1 << 20;
		Self::with_capacity(ONE_MIB)
	}
}

impl Parser {
	pub fn with_capacity(capacity: usize) -> Self {
		Self{
			buffer: Buffer::with_capacity(capacity),
			current_header: None
		}
	}

	#[inline]
	pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
		self.buffer.write(buf)
	}

	fn check_packet(&mut self) -> bool {
		match &self.current_header {
			None => {
				if(self.buffer.available_data() >= 5) {
					let header_bytes = &self.buffer.data()[..5];
					self.current_header = Some(PacketHeader::parse(header_bytes).unwrap().1);
					self.check_packet()
				} else {
					false
				}
			},
			Some(header) => {
				self.buffer.available_data() >= header.length as usize + 4
			}
		}
	}

	pub fn get_packet(&mut self) -> Result<Option<Packet>, Error> {
		match (self.check_packet(), &self.current_header) {
			(false, _) => Ok(None),
			(true, None) => Err(Error::NoHeader),
			(true, Some(header)) => {
				let len = header.length as usize + 4; // Exclude type byte, we already have it
				let data = &self.buffer.data()[..len];
				let packet = Packet::parse(data)?.1;
				self.buffer.consume(len);
				self.current_header = None;
				Ok(Some(packet))
			}
		}
	}
}

#[cfg(test)]
pub(crate) fn encode(packet: &Packet) -> Vec<u8> {
	use bincode::Options;
	// Big-endian, fixed-width integers. This presumably mirrors the SFTP wire
	// encoding so the output can be fed straight into `Parser`.
	bincode::DefaultOptions::new()
		.with_big_endian()
		.with_fixint_encoding()
		.serialize(packet)
		.unwrap()
}

#[cfg(test)]
mod test {
	use rand::Rng;
	use test_strategy::proptest;
	use crate::packet::*;
	use super::*;

	/// Splits `slice` into a random number of contiguous, in-order pieces
	/// whose concatenation is exactly `slice`. This simulates data arriving in
	/// arbitrarily sized chunks.
	///
	/// Inputs shorter than 6 elements are returned as a single piece, because
	/// the random ranges below would be empty for them (and `gen_range` panics
	/// on an empty range).
	fn random_slices<T>(slice: &[T]) -> Vec<&[T]> /* {{{ */ {
		if slice.len() < 6 {
			return vec![slice];
		}
		let mut rng = rand::thread_rng();
		// `divisor` lies in 2..len/2, so `len / divisor` is at least 2 and the
		// piece-count range is never empty. It also keeps `count < len`.
		let divisor = rng.gen_range(2..std::cmp::min(16, slice.len() / 2));
		let count = rng.gen_range(1..(slice.len() / divisor));
		let approx_size = slice.len() / count;
		let mut lengths: Vec<usize> = (0..count)
			.map(|_| rng.gen_range((approx_size / 2)..((approx_size * 3) / 2)))
			.collect();
		// Nudge randomly chosen pieces one element at a time until the total
		// matches. Each adjustment moves the total by exactly one, so it
		// cannot overshoot and the loops perform exactly |total - len|
		// adjustments. (Adjusting many pieces per pass made the sum oscillate
		// around the target for large inputs.)
		let mut total: usize = lengths.iter().sum();
		while total < slice.len() {
			lengths[rng.gen_range(0..count)] += 1;
			total += 1;
		}
		// Terminates: if every piece were <= 1, then total <= count < len.
		while total > slice.len() {
			let piece = &mut lengths[rng.gen_range(0..count)];
			if *piece > 1 {
				*piece -= 1;
				total -= 1;
			}
		}
		let mut pieces = Vec::with_capacity(count);
		let mut start = 0;
		for len in lengths {
			pieces.push(&slice[start..(start + len)]);
			start += len;
		}
		pieces
	} // }}}

	#[test]
	fn test_random_slices() {
		for &size in &[0usize, 1, 5, 6, 32, 1024, 1024 * 1024] {
			// Distinct-ish values so that reordered or duplicated pieces are caught.
			let pile_of_bytes: Vec<u8> = (0..size).map(|i| i as u8).collect();
			let slices = random_slices(&pile_of_bytes);
			assert_eq!(slices.concat(), pile_of_bytes);
		}
	}

	#[ignore] // It's broken and doesn't have any actual functional impact
	#[test]
	fn validate_packet_result_size() {
		assert_eq!(std::mem::size_of::<Result<Packet, Error>>(), std::mem::size_of::<Packet>() + std::mem::size_of::<*const ()>());
		assert_eq!(std::mem::size_of::<Result<Option<Packet>, Error>>(), std::mem::size_of::<Result<Packet, Error>>());
	}

	#[test]
	fn single_init() {
		let mut stream = Parser::default();
		assert_eq!(stream.get_packet(), Ok(None));
		let packet = Payload::init(1, vec![]).into_packet();
		stream.write(&encode(&packet)).unwrap();
		assert_eq!(stream.get_packet(), Ok(Some(packet)));
		assert_eq!(stream.get_packet(), Ok(None));
	}

	#[test]
	fn multipart_init() {
		let mut stream = Parser::default();
		let packet = Payload::init(2, vec![]).into_packet();
		let bytes = encode(&packet);
		stream.write(&bytes[0..3]).unwrap();
		assert_eq!(stream.get_packet(), Ok(None));
		stream.write(&bytes[3..bytes.len()]).unwrap();
		assert_eq!(stream.get_packet(), Ok(Some(packet)));
	}

	#[test]
	fn handshake() {
		let mut stream = Parser::default();
		let init = Payload::init(32768, (0..100).collect()).into_packet();
		stream.write(&encode(&init)).unwrap();
		assert_eq!(stream.get_packet(), Ok(Some(init)));
		let version = Payload::version(3, (100..150).collect()).into_packet();
		stream.write(&encode(&version)).unwrap();
		assert_eq!(stream.get_packet(), Ok(Some(version)));
		assert_eq!(stream.get_packet(), Ok(None));
	}

	#[test]
	fn handshake_queued() {
		let mut stream = Parser::default();
		let init = Payload::init(32768, (0..100).collect()).into_packet();
		stream.write(&encode(&init)).unwrap();
		let version = Payload::version(3, (100..150).collect()).into_packet();
		stream.write(&encode(&version)).unwrap();
		assert_eq!(stream.get_packet(), Ok(Some(init)));
		assert_eq!(stream.get_packet(), Ok(Some(version)));
		assert_eq!(stream.get_packet(), Ok(None));
	}

	#[proptest]
	fn arbitrary_sequence(input: Vec<Packet>) {
		let mut stream = Parser::default();
		assert_eq!(stream.get_packet(), Ok(None));
		for packet in input {
			stream.write(&encode(&packet)).unwrap();
			assert_eq!(stream.get_packet(), Ok(Some(packet)));
			assert_eq!(stream.get_packet(), Ok(None));
		}
		assert_eq!(stream.get_packet(), Ok(None));
	}

	#[proptest]
	fn arbitrary_sequence_queued(input: Vec<Packet>) {
		let mut stream = Parser::default();
		assert_eq!(stream.get_packet(), Ok(None));
		for packet in &input {
			stream.write(&encode(packet)).unwrap();
		}
		for packet in input {
			assert_eq!(stream.get_packet(), Ok(Some(packet)));
		}
		assert_eq!(stream.get_packet(), Ok(None));
	}

	#[proptest]
	fn arbitrary_sequence_multipart(input: Vec<Packet>) {
		let mut stream = Parser::default();
		assert_eq!(stream.get_packet(), Ok(None));
		for packet in input {
			let bytes = encode(&packet);
			let slices = random_slices(&bytes);
			for slice in slices {
				stream.write(slice).unwrap();
			}
			assert_eq!(stream.get_packet(), Ok(Some(packet)));
			assert_eq!(stream.get_packet(), Ok(None));
		}
		assert_eq!(stream.get_packet(), Ok(None));
	}

	#[proptest]
	fn arbitrary_sequence_multipart_queued(input: Vec<Packet>) {
		let mut stream = Parser::default();
		assert_eq!(stream.get_packet(), Ok(None));
		for packet in &input {
			let bytes = encode(packet);
			let slices = random_slices(&bytes);
			for slice in slices {
				stream.write(slice).unwrap();
			}
		}
		for packet in input {
			assert_eq!(stream.get_packet(), Ok(Some(packet)));
		}
		assert_eq!(stream.get_packet(), Ok(None));
	}
}