use std::borrow::Cow;
use bytes::{Buf, BufMut, Bytes};
use serde::{ser::SerializeTuple, Deserialize, Serialize};
use crate::wire::{Error, SftpDecoder, SftpEncoder};
mod attrs;
mod close;
mod data;
mod extended;
mod extended_reply;
mod fsetstat;
mod fstat;
mod handle;
mod init;
mod lstat;
mod mkdir;
mod name;
mod open;
mod opendir;
mod path;
mod read;
mod readdir;
mod readlink;
mod realpath;
mod remove;
mod rename;
mod rmdir;
mod setstat;
mod stat;
mod status;
mod symlink;
mod version;
mod write;
#[cfg(test)]
mod test_utils;
pub use attrs::{Attrs, Owner, Permisions, Time};
pub use close::Close;
pub use data::Data;
pub use extended::Extended;
pub use extended_reply::ExtendedReply;
pub use fsetstat::FSetStat;
pub use fstat::FStat;
pub use handle::Handle;
pub use init::Init;
pub use lstat::LStat;
pub use mkdir::MkDir;
pub use name::{Name, NameEntry};
pub use open::{Open, PFlags};
pub use opendir::OpenDir;
pub use path::Path;
pub use read::Read;
pub use readdir::ReadDir;
pub use readlink::ReadLink;
pub use realpath::RealPath;
pub use remove::Remove;
pub use rename::Rename;
pub use rmdir::RmDir;
pub use setstat::SetStat;
pub use stat::Stat;
pub use status::{Status, StatusCode};
pub use symlink::Symlink;
pub use version::Version;
pub use write::Write;
/// Generates the SFTP packet enums and their serde glue from a list of
/// `TypeName: type_code` pairs.
///
/// For every pair the expansion produces:
/// - a `Message::TypeName(TypeName)` variant and a fieldless `MessageKind::TypeName`
///   variant, both using the given code as their discriminant;
/// - `TypeName::DISCRIMINANT`, `From<TypeName> for Message`, and
///   `TryFrom<Message> for TypeName` (which hands the message back on mismatch).
///
/// It also implements `Serialize`/`Deserialize` for `Message` as a
/// `(code, content)` tuple, and for `MessageWithId` as a `(code, id, content)`
/// tuple in which the id slot holds `None` for `Init` and `Version`.
macro_rules! messages {
    ($($name:ident: $discriminant:expr)*) => {
        /// An SFTP packet; each variant's discriminant is its type code.
        #[derive(Debug, PartialEq, Eq, Clone)]
        #[repr(u8)]
        #[non_exhaustive]
        pub enum Message {
            $($name($name) = $discriminant,)*
        }

        /// The packet type of a [`Message`], without its content.
        #[derive(Debug, PartialEq, Eq, Clone, Copy)]
        #[repr(u8)]
        #[non_exhaustive]
        pub enum MessageKind {
            $($name = $discriminant,)*
        }

        impl Message {
            /// Returns the packet type of this message.
            pub fn kind(&self) -> MessageKind {
                match self {
                    $(Self::$name(_) => MessageKind::$name,)*
                }
            }
        }

        impl MessageKind {
            /// Returns the type code used for this packet type.
            pub fn code(&self) -> u8 {
                match self {
                    $(Self::$name => $discriminant,)*
                }
            }
        }

        impl From<Message> for MessageKind {
            fn from(value: Message) -> Self {
                value.kind()
            }
        }

        impl Serialize for Message {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                // Two elements: the type code, then the packet content.
                let mut state = serializer.serialize_tuple(2)?;
                state.serialize_element(&self.code())?;
                match self {
                    $(Message::$name(value) => state.serialize_element(value)?,)*
                }
                state.end()
            }
        }

        impl<'de> Deserialize<'de> for Message {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                struct Visitor;

                impl<'de> serde::de::Visitor<'de> for Visitor {
                    type Value = Message;

                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                        write!(formatter, "a type code and a message content")
                    }

                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                    where
                        A: serde::de::SeqAccess<'de>,
                    {
                        let no_value = || ::serde::de::Error::custom("no value found");
                        let code = seq.next_element::<u8>()?.ok_or_else(no_value)?;
                        let content = match code {
                            $($discriminant => seq.next_element::<$name>()?.ok_or_else(no_value)?.into(),)*
                            _ => return Err(::serde::de::Error::custom("invalid message type")),
                        };
                        Ok(content)
                    }
                }

                // Must match the 2-element tuple written by `Serialize for Message`.
                deserializer.deserialize_tuple(2, Visitor)
            }
        }

        impl<'a> Serialize for MessageWithId<'a> {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                // `Init` and `Version` are serialized without a request id (as `None`).
                let id = match self.message.as_ref() {
                    Message::Init(_) | Message::Version(_) => None,
                    _ => Some(self.id),
                };
                let mut state = serializer.serialize_tuple(3)?;
                state.serialize_element(&self.message.code())?;
                state.serialize_element(&id)?;
                match self.message.as_ref() {
                    $(Message::$name(value) => state.serialize_element(value)?,)*
                }
                state.end()
            }
        }

        impl<'de> Deserialize<'de> for MessageWithId<'de> {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                struct Visitor;

                impl<'de> serde::de::Visitor<'de> for Visitor {
                    type Value = MessageWithId<'de>;

                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                        write!(formatter, "a type code, an id, and a message content")
                    }

                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                    where
                        A: serde::de::SeqAccess<'de>,
                    {
                        let no_value = || ::serde::de::Error::custom("no value found");
                        let code = seq.next_element::<u8>()?.ok_or_else(no_value)?;
                        let (id, message) = if code == Init::DISCRIMINANT || code == Version::DISCRIMINANT {
                            // The serializer writes `None` (an `Option<u32>`) in the id slot for
                            // these two packets; read it back as that same type and report id 0.
                            let _absent_id: Option<u32> = seq.next_element()?.ok_or_else(no_value)?;
                            let message = if code == Init::DISCRIMINANT {
                                Message::Init(seq.next_element()?.ok_or_else(no_value)?)
                            } else {
                                Message::Version(seq.next_element()?.ok_or_else(no_value)?)
                            };
                            (0, message)
                        } else {
                            let id = seq.next_element::<u32>()?.ok_or_else(no_value)?;
                            let message = match code {
                                $($discriminant => seq.next_element::<$name>()?.ok_or_else(no_value)?.into(),)*
                                _ => return Err(::serde::de::Error::custom("invalid message type")),
                            };
                            (id, message)
                        };
                        Ok(MessageWithId { id, message: Cow::Owned(message) })
                    }
                }

                deserializer.deserialize_tuple(3, Visitor)
            }
        }

        $(
            impl $name {
                /// The type code of this packet type.
                #[allow(dead_code)] // not every packet type's constant is referenced
                const DISCRIMINANT: u8 = $discriminant;
            }

            impl From<$name> for Message {
                fn from(value: $name) -> Self {
                    Self::$name(value)
                }
            }

            impl TryFrom<Message> for $name {
                /// On mismatch the original message is returned unchanged.
                type Error = Message;

                fn try_from(value: Message) -> Result<Self, Self::Error> {
                    if let Message::$name(value) = value {
                        Ok(value)
                    } else {
                        Err(value)
                    }
                }
            }
        )*
    };
}
// Packet type table. Each `Type: code` pair becomes a `Message::Type` /
// `MessageKind::Type` variant whose discriminant is `code`, and `code` is the
// type byte written first when a message is serialized.
// These values correspond to the SSH_FXP_* packet types of SFTP version 3
// (draft-ietf-secsh-filexfer-02); changing them breaks wire compatibility.
messages! {
Init: 1
Version: 2
Open: 3
Close: 4
Read: 5
Write: 6
LStat: 7
FStat: 8
SetStat: 9
FSetStat: 10
OpenDir: 11
ReadDir: 12
Remove: 13
MkDir: 14
RmDir: 15
RealPath: 16
Stat: 17
Rename: 18
ReadLink: 19
Symlink: 20
// Reply packets.
Status: 101
Handle: 102
Data: 103
Name: 104
Attrs: 105
// Extension request and its reply.
Extended: 200
ExtendedReply: 201
}
impl From<Init> for Version {
fn from(value: Init) -> Self {
Self {
version: value.version,
extensions: value.extensions,
}
}
}
impl From<Version> for Init {
fn from(value: Version) -> Self {
Self {
version: value.version,
extensions: value.extensions,
}
}
}
/// A [`Message`] together with the request id it is (de)serialized with.
///
/// The message is held in a [`Cow`] so encoding can borrow an existing
/// `Message`, while decoding produces an owned one.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct MessageWithId<'a> {
/// Request id. Serialized as `None` for `Init`/`Version`, and reported as 0
/// when those two packet types are decoded.
id: u32,
/// The packet itself.
message: Cow<'a, Message>,
}
impl Message {
pub fn code(&self) -> u8 {
self.kind().code()
}
pub fn encode(&self, id: u32) -> Result<Bytes, Error> {
let mut encoder = SftpEncoder::with_vec(Vec::with_capacity(16));
encoder.buf.put_u32(0);
MessageWithId {
id,
message: Cow::Borrowed(self),
}
.serialize(&mut encoder)?;
let frame_length = (encoder.buf.len() - std::mem::size_of::<u32>()) as u32;
let mut buf = encoder.buf.as_mut_slice();
buf.put_u32(frame_length);
Ok(encoder.buf.into())
}
pub fn decode(mut buf: &[u8]) -> Result<(u32, Self), Error> {
let frame_length = buf.get_u32() as usize;
let mut decoder = SftpDecoder::new(&buf[0..frame_length]);
let message_with_id = MessageWithId::deserialize(&mut decoder).map_err(Into::into)?;
Ok((message_with_id.id, message_with_id.message.into_owned()))
}
}