use std::fmt;
use std::io::{ErrorKind, Read, Write};
/// Everything that can go wrong while speaking the SCP wire protocol.
#[derive(Debug)]
pub enum ScpError {
    /// A read, write, or flush on an underlying stream failed.
    Io(std::io::Error),
    /// A control line or status byte did not have the expected shape.
    BadHeader(&'static str),
    /// The peer sent a `0x02` status; the payload is its message text.
    Remote(String),
    /// The peer sent a `0x01` status; the payload is its message text.
    Warning(String),
    /// A file or directory name was rejected by `validate_name`.
    BadName(&'static str),
    /// A path would leave the base directory.
    PathEscape,
    /// The stream ended or behaved in a way not allowed at this point.
    Unexpected(&'static str),
}

impl fmt::Display for ScpError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "scp: io: {}", err),
            Self::BadHeader(why) => write!(f, "scp: bad header: {}", why),
            Self::Remote(text) => write!(f, "scp: remote error: {}", text),
            Self::Warning(text) => write!(f, "scp: remote warning: {}", text),
            Self::BadName(why) => write!(f, "scp: bad name: {}", why),
            Self::PathEscape => f.write_str("scp: path escapes base directory"),
            Self::Unexpected(what) => write!(f, "scp: unexpected: {}", what),
        }
    }
}

impl std::error::Error for ScpError {}

/// Lets `?` lift raw I/O failures into [`ScpError::Io`].
impl From<std::io::Error> for ScpError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
/// One SCP control message, as produced by `write_header` and parsed by
/// `read_header`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Header {
    /// `C` line announcing a file.
    File {
        /// Mode bits; `write_header` sends only the low 12 bits (`& 0o7777`).
        mode: u32,
        /// Byte count carried in the header.
        size: u64,
        /// Single path component, checked by `validate_name`.
        name: String,
    },
    /// `D` line announcing a directory (its size field on the wire is `0`).
    Dir {
        /// Mode bits; `write_header` sends only the low 12 bits (`& 0o7777`).
        mode: u32,
        /// Single path component, checked by `validate_name`.
        name: String,
    },
    /// Bare `E` line; presumably closes the directory opened by a `D`.
    EndDir,
    /// `T` line carrying two timestamps. The microsecond fields on the wire
    /// are always `0` (presumably Unix seconds — confirm with callers).
    Times {
        mtime: i64,
        atime: i64,
    },
}
/// Checks that `name` is acceptable as the single name component of a
/// `C`/`D` header.
///
/// # Errors
///
/// Returns [`ScpError::BadName`] when `name`:
/// - is empty;
/// - contains a NUL byte or a newline (a newline would end the header line
///   early);
/// - starts with `-` (presumably to guard against option injection if the
///   name ever reaches a command line);
/// - is `.` or `..`, or contains `/` (it would then name something other than
///   a new entry of its own).
pub fn validate_name(name: &str) -> Result<(), ScpError> {
    if name.is_empty() {
        return Err(ScpError::BadName("empty"));
    }
    if name.contains('\0') {
        return Err(ScpError::BadName("contains NUL"));
    }
    if name.contains('\n') {
        return Err(ScpError::BadName("contains newline"));
    }
    if name.starts_with('-') {
        return Err(ScpError::BadName("starts with '-'"));
    }
    // `.` refers to the receiving directory itself. Accepting it would let a
    // hostile `D .` header retarget the destination directory's metadata
    // (the flaw fixed in OpenSSH for CVE-2018-20685).
    if name == "." || name == ".." || name.contains('/') {
        return Err(ScpError::BadName("path component"));
    }
    Ok(())
}
/// Serializes `h` as one newline-terminated SCP control line on `w`, then
/// flushes `w`.
///
/// For `C`/`D` headers the name is checked with [`validate_name`] before any
/// byte is written. Only `mode & 0o7777` is sent, as at least four octal
/// digits.
///
/// # Errors
///
/// - [`ScpError::BadName`] if a `C`/`D` name fails validation (nothing is
///   written in that case).
/// - [`ScpError::Io`] if writing or flushing fails.
pub fn write_header<W: Write>(w: &mut W, h: &Header) -> Result<(), ScpError> {
    match *h {
        Header::EndDir => w.write_all(b"E\n")?,
        Header::Times { mtime, atime } => writeln!(w, "T{} 0 {} 0", mtime, atime)?,
        Header::Dir { mode, ref name } => {
            validate_name(name)?;
            writeln!(w, "D{:04o} 0 {}", mode & 0o7777, name)?;
        }
        Header::File {
            mode,
            size,
            ref name,
        } => {
            validate_name(name)?;
            writeln!(w, "C{:04o} {} {}", mode & 0o7777, size, name)?;
        }
    }
    w.flush()?;
    Ok(())
}
/// Reads the next SCP control message from `r`.
///
/// Returns `Ok(None)` when the stream ends before the first byte of a message.
/// Returns `Ok(Some(header))` for a well-formed `C`, `D`, `E`, or `T` line.
/// Bytes are pulled one at a time (here and in [`read_line`]), so nothing
/// after the header's newline is consumed from `r`.
///
/// # Errors
///
/// - [`ScpError::Warning`] / [`ScpError::Remote`] when the peer sends a
///   `0x01` / `0x02` status line instead of a header. The payload is the
///   message text.
/// - [`ScpError::BadHeader`] for an unknown message kind or a malformed line.
///   This includes a mode that is not plain octal digits within `0o7777`, and
///   a size that is not plain decimal digits.
/// - [`ScpError::BadName`] when a `C`/`D` name fails [`validate_name`].
/// - [`ScpError::Io`] for read failures. Reads interrupted by a signal are
///   retried.
pub fn read_header<R: Read>(r: &mut R) -> Result<Option<Header>, ScpError> {
    let mut first = [0u8; 1];
    // `ErrorKind::Interrupted` is transient; the `Read` contract says to retry.
    let got = loop {
        match r.read(&mut first) {
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            other => break other,
        }
    };
    match got {
        Ok(0) => return Ok(None),
        Ok(_) => {}
        // An `UnexpectedEof` error is also treated as a clean end of stream.
        Err(e) if e.kind() == ErrorKind::UnexpectedEof => return Ok(None),
        Err(e) => return Err(ScpError::Io(e)),
    }
    let kind = first[0];
    match kind {
        0x01 => return Err(ScpError::Warning(read_line(r)?)),
        0x02 => return Err(ScpError::Remote(read_line(r)?)),
        b'C' | b'D' | b'E' | b'T' => {}
        // Rejected before consuming the rest of the line.
        _ => return Err(ScpError::BadHeader("unknown header kind")),
    }
    let rest = read_line(r)?;
    let header = match kind {
        b'E' if rest.is_empty() => Header::EndDir,
        b'E' => return Err(ScpError::BadHeader("E has trailing data")),
        b'C' => parse_entry_fields(&rest, true)?,
        b'D' => parse_entry_fields(&rest, false)?,
        b'T' => parse_times_fields(&rest)?,
        _ => unreachable!("header kind was validated above"),
    };
    Ok(Some(header))
}

/// Parses the `MODE SIZE NAME` remainder of a `C` line (`is_file`) or a
/// `D` line.
fn parse_entry_fields(rest: &str, is_file: bool) -> Result<Header, ScpError> {
    // `splitn(3, ..)` keeps any spaces inside the name.
    let mut parts = rest.splitn(3, ' ');
    let mode_s = parts.next().ok_or(ScpError::BadHeader("missing mode"))?;
    let size_s = parts.next().ok_or(ScpError::BadHeader("missing size"))?;
    let name = parts.next().ok_or(ScpError::BadHeader("missing name"))?;
    // `from_str_radix`/`parse` would also accept a leading `+`, so require bare
    // digits. Modes above 0o7777 carry non-permission bits that `write_header`
    // never sends; refuse them rather than hand them to the caller.
    if mode_s.is_empty() || !mode_s.bytes().all(|b| matches!(b, b'0'..=b'7')) {
        return Err(ScpError::BadHeader("bad mode"));
    }
    let mode = u32::from_str_radix(mode_s, 8)
        .ok()
        .filter(|&m| m <= 0o7777)
        .ok_or(ScpError::BadHeader("bad mode"))?;
    if size_s.is_empty() || !size_s.bytes().all(|b| b.is_ascii_digit()) {
        return Err(ScpError::BadHeader("bad size"));
    }
    let size = size_s
        .parse::<u64>()
        .map_err(|_| ScpError::BadHeader("bad size"))?;
    validate_name(name)?;
    if is_file {
        return Ok(Header::File {
            mode,
            size,
            name: name.to_string(),
        });
    }
    if size != 0 {
        return Err(ScpError::BadHeader("D header with non-zero size"));
    }
    Ok(Header::Dir {
        mode,
        name: name.to_string(),
    })
}

/// Parses the `MTIME 0 ATIME 0` remainder of a `T` line.
fn parse_times_fields(rest: &str) -> Result<Header, ScpError> {
    let mut parts = rest.split(' ');
    let mtime_s = parts.next().ok_or(ScpError::BadHeader("T missing mtime"))?;
    let zero1 = parts.next().ok_or(ScpError::BadHeader("T missing zero"))?;
    let atime_s = parts.next().ok_or(ScpError::BadHeader("T missing atime"))?;
    let zero2 = parts
        .next()
        .ok_or(ScpError::BadHeader("T missing trailing zero"))?;
    if parts.next().is_some() {
        return Err(ScpError::BadHeader("T trailing junk"));
    }
    if zero1 != "0" || zero2 != "0" {
        return Err(ScpError::BadHeader("T non-zero microseconds"));
    }
    let mtime = mtime_s
        .parse::<i64>()
        .map_err(|_| ScpError::BadHeader("T bad mtime"))?;
    let atime = atime_s
        .parse::<i64>()
        .map_err(|_| ScpError::BadHeader("T bad atime"))?;
    Ok(Header::Times { mtime, atime })
}
fn read_line<R: Read>(r: &mut R) -> Result<String, ScpError> {
const MAX_LINE: usize = 4096;
let mut out = Vec::with_capacity(64);
let mut byte = [0u8; 1];
loop {
match r.read(&mut byte) {
Ok(0) => return Err(ScpError::BadHeader("EOF mid-line")),
Ok(_) => {}
Err(e) => return Err(ScpError::Io(e)),
}
if byte[0] == b'\n' {
break;
}
if out.len() >= MAX_LINE {
return Err(ScpError::BadHeader("line too long"));
}
out.push(byte[0]);
}
String::from_utf8(out).map_err(|_| ScpError::BadHeader("non-UTF-8 line"))
}
pub fn write_ok<W: Write>(w: &mut W) -> Result<(), ScpError> {
w.write_all(&[0x00])?;
w.flush()?;
Ok(())
}
/// Sends a fatal status to the peer as a single write: `0x02`, the message
/// text, and `\n`. Then flushes `w`.
///
/// Newline and NUL bytes in `msg` become spaces, so the message cannot end
/// its line early.
///
/// # Errors
///
/// [`ScpError::Io`] if the write or the flush fails.
pub fn write_fatal<W: Write>(w: &mut W, msg: &str) -> Result<(), ScpError> {
    let mut frame: Vec<u8> = Vec::with_capacity(msg.len() + 2);
    frame.push(0x02);
    frame.extend(
        msg.bytes()
            .map(|b| if matches!(b, b'\n' | 0) { b' ' } else { b }),
    );
    frame.push(b'\n');
    w.write_all(&frame)?;
    w.flush()?;
    Ok(())
}
pub fn read_ack<R: Read>(r: &mut R) -> Result<(), ScpError> {
let mut b = [0u8; 1];
match r.read(&mut b) {
Ok(0) => return Err(ScpError::Unexpected("EOF awaiting ack")),
Ok(_) => {}
Err(e) => return Err(ScpError::Io(e)),
}
match b[0] {
0x00 => Ok(()),
0x01 => Err(ScpError::Warning(read_line(r)?)),
0x02 => Err(ScpError::Remote(read_line(r)?)),
_ => Err(ScpError::BadHeader("non-ack byte")),
}
}
pub fn write_payload_term<R: Read, W: Write>(
w: &mut W,
src: &mut R,
size: u64,
) -> Result<(), ScpError> {
let mut remaining = size;
let mut buf = [0u8; 32 * 1024];
while remaining > 0 {
let take = (remaining as usize).min(buf.len());
let n = src.read(&mut buf[..take])?;
if n == 0 {
return Err(ScpError::Unexpected("local payload shorter than header"));
}
w.write_all(&buf[..n])?;
remaining -= n as u64;
}
w.write_all(&[0x00])?;
w.flush()?;
Ok(())
}
pub fn read_payload_term<R: Read, W: Write>(
r: &mut R,
dst: &mut W,
size: u64,
) -> Result<(), ScpError> {
let mut remaining = size;
let mut buf = [0u8; 32 * 1024];
while remaining > 0 {
let take = (remaining as usize).min(buf.len());
let n = r.read(&mut buf[..take])?;
if n == 0 {
return Err(ScpError::Unexpected("EOF mid-payload"));
}
dst.write_all(&buf[..n])?;
remaining -= n as u64;
}
let mut term = [0u8; 1];
match r.read(&mut term) {
Ok(0) => return Err(ScpError::Unexpected("EOF awaiting payload terminator")),
Ok(_) => {}
Err(e) => return Err(ScpError::Io(e)),
}
if term[0] != 0x00 {
return Err(ScpError::BadHeader("payload terminator was not 0x00"));
}
Ok(())
}
// Round-trip and rejection tests for the wire helpers in this file.
#[cfg(test)]
mod unit {
    use super::*;

    #[test]
    fn round_trip_file_header() {
        let mut buf = Vec::new();
        write_header(
            &mut buf,
            &Header::File {
                mode: 0o644,
                size: 1234,
                name: "hello.txt".into(),
            },
        )
        .unwrap();
        assert_eq!(buf, b"C0644 1234 hello.txt\n");
        let mut cur = std::io::Cursor::new(buf);
        match read_header(&mut cur).unwrap().unwrap() {
            Header::File { mode, size, name } => {
                assert_eq!(mode, 0o644);
                assert_eq!(size, 1234);
                assert_eq!(name, "hello.txt");
            }
            _ => panic!("expected File"),
        }
    }

    #[test]
    fn round_trip_dir_and_end() {
        let mut buf = Vec::new();
        write_header(
            &mut buf,
            &Header::Dir {
                mode: 0o755,
                name: "sub".into(),
            },
        )
        .unwrap();
        write_header(&mut buf, &Header::EndDir).unwrap();
        assert_eq!(buf, b"D0755 0 sub\nE\n");
        let mut cur = std::io::Cursor::new(buf);
        assert!(matches!(
            read_header(&mut cur).unwrap().unwrap(),
            Header::Dir { .. }
        ));
        assert!(matches!(
            read_header(&mut cur).unwrap().unwrap(),
            Header::EndDir
        ));
    }

    #[test]
    fn round_trip_times() {
        let mut buf = Vec::new();
        write_header(
            &mut buf,
            &Header::Times {
                mtime: 1_700_000_000,
                atime: 1_700_000_001,
            },
        )
        .unwrap();
        assert_eq!(buf, b"T1700000000 0 1700000001 0\n");
        let mut cur = std::io::Cursor::new(buf);
        match read_header(&mut cur).unwrap().unwrap() {
            Header::Times { mtime, atime } => {
                assert_eq!(mtime, 1_700_000_000);
                assert_eq!(atime, 1_700_000_001);
            }
            _ => panic!("expected Times"),
        }
    }

    #[test]
    fn reject_leading_dash_name() {
        assert!(matches!(validate_name("-rf"), Err(ScpError::BadName(_))));
    }

    #[test]
    fn reject_newline_in_name() {
        assert!(matches!(validate_name("a\nb"), Err(ScpError::BadName(_))));
    }

    #[test]
    fn reject_slash_in_name() {
        assert!(matches!(validate_name("a/b"), Err(ScpError::BadName(_))));
    }

    #[test]
    fn reject_dotdot_name() {
        assert!(matches!(validate_name(".."), Err(ScpError::BadName(_))));
    }

    #[test]
    fn ack_ok() {
        let mut cur = std::io::Cursor::new(vec![0x00u8]);
        assert!(read_ack(&mut cur).is_ok());
    }

    #[test]
    fn ack_remote_error_carries_message() {
        let mut cur = std::io::Cursor::new(b"\x02boom\n".to_vec());
        match read_ack(&mut cur) {
            Err(ScpError::Remote(m)) => assert_eq!(m, "boom"),
            other => panic!("unexpected {:?}", other),
        }
    }

    #[test]
    fn ack_warning_carries_message() {
        let mut cur = std::io::Cursor::new(b"\x01hmm\n".to_vec());
        match read_ack(&mut cur) {
            Err(ScpError::Warning(m)) => assert_eq!(m, "hmm"),
            other => panic!("unexpected {:?}", other),
        }
    }

    #[test]
    fn ack_rejects_unknown_byte() {
        let mut cur = std::io::Cursor::new(vec![b'x']);
        assert!(matches!(read_ack(&mut cur), Err(ScpError::BadHeader(_))));
    }

    #[test]
    fn ack_eof_is_unexpected() {
        let mut cur = std::io::Cursor::new(Vec::<u8>::new());
        assert!(matches!(read_ack(&mut cur), Err(ScpError::Unexpected(_))));
    }

    #[test]
    fn write_fatal_round_trips_through_read_ack() {
        let mut buf = Vec::new();
        write_fatal(&mut buf, "boom").unwrap();
        let mut cur = std::io::Cursor::new(buf);
        match read_ack(&mut cur) {
            Err(ScpError::Remote(m)) => assert_eq!(m, "boom"),
            other => panic!("unexpected {:?}", other),
        }
    }

    #[test]
    fn read_header_returns_none_on_eof() {
        let mut cur = std::io::Cursor::new(Vec::<u8>::new());
        assert!(read_header(&mut cur).unwrap().is_none());
    }

    #[test]
    fn read_header_rejects_dotdot_name() {
        let mut cur = std::io::Cursor::new(b"C0644 1 ..\n".to_vec());
        assert!(matches!(read_header(&mut cur), Err(ScpError::BadName(_))));
    }

    #[test]
    fn read_header_surfaces_remote_warning() {
        let mut cur = std::io::Cursor::new(b"\x01careful\n".to_vec());
        match read_header(&mut cur) {
            Err(ScpError::Warning(m)) => assert_eq!(m, "careful"),
            other => panic!("unexpected {:?}", other),
        }
    }

    #[test]
    fn read_header_rejects_unknown_kind() {
        let mut cur = std::io::Cursor::new(b"X junk\n".to_vec());
        assert!(matches!(read_header(&mut cur), Err(ScpError::BadHeader(_))));
    }

    #[test]
    fn read_header_rejects_end_with_trailing_data() {
        let mut cur = std::io::Cursor::new(b"Ejunk\n".to_vec());
        assert!(matches!(read_header(&mut cur), Err(ScpError::BadHeader(_))));
    }

    #[test]
    fn read_header_rejects_dir_with_size() {
        let mut cur = std::io::Cursor::new(b"D0755 5 sub\n".to_vec());
        assert!(matches!(read_header(&mut cur), Err(ScpError::BadHeader(_))));
    }

    #[test]
    fn read_header_rejects_nonzero_microseconds() {
        let mut cur = std::io::Cursor::new(b"T1 5 2 0\n".to_vec());
        assert!(matches!(read_header(&mut cur), Err(ScpError::BadHeader(_))));
    }

    #[test]
    fn read_header_rejects_overlong_line() {
        let mut line = b"C0644 1 ".to_vec();
        line.extend_from_slice(&[b'a'; 5000]);
        line.push(b'\n');
        let mut cur = std::io::Cursor::new(line);
        assert!(matches!(read_header(&mut cur), Err(ScpError::BadHeader(_))));
    }

    #[test]
    fn write_header_rejects_bad_name_without_writing() {
        let mut buf = Vec::new();
        let h = Header::File {
            mode: 0o644,
            size: 1,
            name: "a/b".into(),
        };
        assert!(matches!(
            write_header(&mut buf, &h),
            Err(ScpError::BadName(_))
        ));
        assert!(buf.is_empty());
    }

    #[test]
    fn write_fatal_strips_newlines() {
        let mut buf = Vec::new();
        write_fatal(&mut buf, "line1\nline2").unwrap();
        assert_eq!(buf, b"\x02line1 line2\n");
    }

    #[test]
    fn payload_round_trip() {
        let mut wire = Vec::new();
        let mut src = std::io::Cursor::new(b"hello".to_vec());
        write_payload_term(&mut wire, &mut src, 5).unwrap();
        assert_eq!(wire, b"hello\x00");
        let mut out = Vec::new();
        read_payload_term(&mut std::io::Cursor::new(wire), &mut out, 5).unwrap();
        assert_eq!(out, b"hello");
    }

    #[test]
    fn zero_length_payload_is_just_terminator() {
        let mut wire = Vec::new();
        write_payload_term(&mut wire, &mut std::io::empty(), 0).unwrap();
        assert_eq!(wire, b"\x00");
        let mut out = Vec::new();
        read_payload_term(&mut std::io::Cursor::new(wire), &mut out, 0).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn short_local_payload_is_error() {
        let mut wire = Vec::new();
        let mut src = std::io::Cursor::new(b"abc".to_vec());
        assert!(matches!(
            write_payload_term(&mut wire, &mut src, 5),
            Err(ScpError::Unexpected(_))
        ));
    }

    #[test]
    fn eof_mid_payload_is_error() {
        let mut out = Vec::new();
        let mut cur = std::io::Cursor::new(b"ab".to_vec());
        assert!(matches!(
            read_payload_term(&mut cur, &mut out, 5),
            Err(ScpError::Unexpected(_))
        ));
    }
}