use std::{borrow::Cow, error::Error as StdError, fmt};
use crate::CommandExit;
/// Boxed `Send + Sync + 'static` error trait object; the type used to store an
/// optional underlying cause inside [`CategoryError`].
pub type BoxError = Box<dyn StdError + Send + Sync + 'static>;
/// `std::result::Result` alias whose error type defaults to the [`Error`] enum
/// defined in this file.
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// An error tagged with a classification `kind` of type `K`.
///
/// Besides the kind it stores a message (which is exactly what `Display`
/// prints) and, optionally, a boxed lower-level error that is exposed through
/// `std::error::Error::source`.
pub struct CategoryError<K> {
    /// Classification tag; returned by value from [`CategoryError::kind`].
    kind: K,
    /// Message text; borrowed for `'static` strings, owned otherwise.
    message: Cow<'static, str>,
    /// Underlying cause, if one was supplied at construction time.
    source: Option<BoxError>,
}

impl<K> CategoryError<K> {
    /// Creates an error with the given `kind` and `message` and no source.
    pub fn new(kind: K, message: impl Into<Cow<'static, str>>) -> Self {
        let message = message.into();
        Self {
            kind,
            message,
            source: None,
        }
    }

    /// Creates an error caused by `source`, boxing `source` before storing it.
    pub fn with_source<E>(kind: K, message: impl Into<Cow<'static, str>>, source: E) -> Self
    where
        E: StdError + Send + Sync + 'static,
    {
        let boxed: BoxError = Box::new(source);
        Self::with_boxed_source(kind, message, boxed)
    }

    /// Creates an error caused by an already boxed `source`.
    pub fn with_boxed_source(
        kind: K,
        message: impl Into<Cow<'static, str>>,
        source: BoxError,
    ) -> Self {
        let message = message.into();
        Self {
            kind,
            message,
            source: Some(source),
        }
    }

    /// Returns a copy of the classification tag (only available when `K: Copy`).
    pub fn kind(&self) -> K
    where
        K: Copy,
    {
        let Self { kind, .. } = self;
        *kind
    }

    /// Returns the message text.
    pub fn message(&self) -> &str {
        &*self.message
    }

    /// Returns `true` when an underlying cause is attached.
    pub fn has_source(&self) -> bool {
        matches!(self.source, Some(_))
    }
}

/// Prints `kind`, `message` and a `has_source` flag. The source itself is
/// never formatted; only its presence is reported.
impl<K> fmt::Debug for CategoryError<K>
where
    K: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_source = self.has_source();
        let mut builder = f.debug_struct("CategoryError");
        builder.field("kind", &self.kind);
        builder.field("message", &self.message);
        builder.field("has_source", &has_source);
        builder.finish()
    }
}

/// Writes the message verbatim via `write_str`, so formatter width/padding
/// flags are not applied.
impl<K> fmt::Display for CategoryError<K> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.message())
    }
}

impl<K> StdError for CategoryError<K>
where
    K: fmt::Debug + 'static,
{
    /// Returns the attached cause, if any, as a plain `dyn Error` reference.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let boxed = self.source.as_ref()?;
        // Drop the `Send + Sync` auto-trait bounds to match the trait's return type.
        let cause: &(dyn StdError + 'static) = &**boxed;
        Some(cause)
    }
}
/// Classification tag carried by `TransportError`
/// (`CategoryError<TransportErrorKind>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportErrorKind {
    /// DNS-related failure.
    Dns,
    /// TCP connect failure.
    TcpConnect,
    /// Negotiation failure.
    Negotiation,
    /// Keepalive failure.
    Keepalive,
    /// Encryption failure.
    Encryption,
    /// I/O failure.
    Io,
    /// Anything not covered by the other variants.
    Other,
}
/// Classification tag carried by `HostKeyError`
/// (`CategoryError<HostKeyErrorKind>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HostKeyErrorKind {
    /// The host key is unknown.
    Unknown,
    /// The host key changed.
    Changed,
    /// The host key was rejected.
    Rejected,
    /// The host key is unsupported.
    Unsupported,
    /// The host key is unavailable.
    Unavailable,
}
/// Classification tag carried by `AuthenticationError`
/// (`CategoryError<AuthenticationErrorKind>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthenticationErrorKind {
    /// Authentication was rejected. This is the kind `Error::authentication`
    /// uses when no kind is given.
    Rejected,
    /// Authentication options were exhausted.
    Exhausted,
    /// Authentication was only partial.
    Partial,
    /// The authentication method is unsupported.
    UnsupportedMethod,
    /// Authentication is unavailable.
    Unavailable,
}
/// Classification tag carried by `ChannelError`
/// (`CategoryError<ChannelErrorKind>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChannelErrorKind {
    /// Failure while opening a channel.
    Open,
    /// Failure of a channel request.
    Request,
    /// Failure while reading.
    Read,
    /// Failure while writing.
    Write,
    /// End-of-file condition.
    Eof,
    /// Failure while closing.
    Close,
    /// Protocol failure. This is the kind `Error::channel` uses when no kind
    /// is given.
    Protocol,
}
/// Classification tag carried by `SftpError` (`CategoryError<SftpErrorKind>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SftpErrorKind {
    /// The remote side reported a status.
    RemoteStatus,
    /// Protocol failure.
    Protocol,
    /// Channel I/O failure.
    ChannelIo,
    /// An unexpected response was received.
    UnexpectedResponse,
    /// The version is unsupported.
    UnsupportedVersion,
    /// The operation is unsupported.
    Unsupported,
    /// No such file.
    NoSuchFile,
    /// Permission denied.
    PermissionDenied,
}
/// Classification tag carried by `ForwardingError`
/// (`CategoryError<ForwardingErrorKind>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForwardingErrorKind {
    /// Failure while binding.
    Bind,
    /// Failure while listening.
    Listen,
    /// Failure while accepting.
    Accept,
    /// Failure while connecting.
    Connect,
    /// Failure of a global request.
    GlobalRequest,
    /// Failure while opening a channel.
    ChannelOpen,
    /// Failure while copying a stream.
    StreamCopy,
    /// Failure while cancelling.
    Cancel,
    /// Failure while shutting down.
    Shutdown,
}
/// Operation tag used as the kind of `TimeoutError`, `CancelledError` and
/// `DisconnectedError` (all aliases of `CategoryError<Operation>`).
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names only; confirm them against the call sites.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Connecting.
    Connect,
    /// Authenticating.
    Authentication,
    /// Opening a channel.
    ChannelOpen,
    /// Running a command.
    Command,
    /// Running a shell.
    Shell,
    /// An SFTP operation.
    Sftp,
    /// A forwarding operation.
    Forwarding,
    /// A server operation.
    Server,
    /// Shutting down.
    Shutdown,
    /// Anything not covered by the other variants.
    Other,
}
/// Classification tag carried by `SshError` (`CategoryError<SshErrorKind>`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SshErrorKind {
    /// NOTE(review): presumably errors originating from the `russh` crate;
    /// nothing in this file constructs this kind, so confirm against callers.
    Russh,
    /// Unclassified error. This is the kind used by `Error::ssh_with_source`
    /// and by the `From<BoxError>` conversion in this file.
    Other,
}
// Concrete category error types: each is `CategoryError` instantiated with the
// kind enum for its category.
/// [`CategoryError`] tagged with a [`TransportErrorKind`].
pub type TransportError = CategoryError<TransportErrorKind>;
/// [`CategoryError`] tagged with a [`HostKeyErrorKind`].
pub type HostKeyError = CategoryError<HostKeyErrorKind>;
/// [`CategoryError`] tagged with an [`AuthenticationErrorKind`].
pub type AuthenticationError = CategoryError<AuthenticationErrorKind>;
/// [`CategoryError`] tagged with a [`ChannelErrorKind`].
pub type ChannelError = CategoryError<ChannelErrorKind>;
/// [`CategoryError`] tagged with an [`SftpErrorKind`].
pub type SftpError = CategoryError<SftpErrorKind>;
/// [`CategoryError`] tagged with a [`ForwardingErrorKind`].
pub type ForwardingError = CategoryError<ForwardingErrorKind>;
// NOTE: the next three aliases all name the same underlying type,
// `CategoryError<Operation>`; they are told apart only by the `Error` variant
// that wraps them.
/// [`CategoryError`] tagged with the [`Operation`] that timed out.
pub type TimeoutError = CategoryError<Operation>;
/// [`CategoryError`] tagged with the [`Operation`] that was cancelled.
pub type CancelledError = CategoryError<Operation>;
/// [`CategoryError`] tagged with the [`Operation`] during which the remote disconnected.
pub type DisconnectedError = CategoryError<Operation>;
/// [`CategoryError`] tagged with an [`SshErrorKind`].
pub type SshError = CategoryError<SshErrorKind>;
/// Top-level error enum; the default error type of the `Result` alias in this file.
///
/// `Display` and `std::error::Error` are generated by `thiserror`. For the
/// variants that wrap a category error, the category error's message is
/// appended to the variant's prefix and the category error is also reported
/// as the `source`.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Invalid configuration; the payload is the explanatory message.
#[error("invalid configuration: {0}")]
InvalidConfig(Cow<'static, str>),
/// Transport-category failure.
#[error("transport error: {0}")]
Transport(#[source] TransportError),
/// Host key verification failure.
#[error("host key verification failed: {0}")]
HostKey(#[source] HostKeyError),
/// Authentication failure.
#[error("authentication failed: {0}")]
Authentication(#[source] AuthenticationError),
/// Channel-category failure.
#[error("channel error: {0}")]
Channel(#[source] ChannelError),
/// A remote command exited unsuccessfully; `exit` is rendered with its
/// `Debug` form in the message.
#[error("remote command exited unsuccessfully: {exit:?}")]
CommandExit {
exit: CommandExit,
},
/// SFTP-category failure.
#[error("sftp error: {0}")]
Sftp(#[source] SftpError),
/// Forwarding-category failure.
#[error("forwarding error: {0}")]
Forwarding(#[source] ForwardingError),
/// An operation timed out; the wrapped error's kind is the `Operation`.
#[error("operation timed out: {0}")]
Timeout(#[source] TimeoutError),
/// An operation was cancelled; the wrapped error's kind is the `Operation`.
#[error("operation cancelled: {0}")]
Cancelled(#[source] CancelledError),
/// The remote disconnected; the wrapped error's kind is the `Operation`.
#[error("remote disconnected: {0}")]
Disconnected(#[source] DisconnectedError),
/// Unsupported operation; the payload is the explanatory message.
#[error("unsupported operation: {0}")]
Unsupported(Cow<'static, str>),
/// A `std::io::Error`, converted automatically via `From` (`#[from]`) and
/// displayed transparently (the I/O error's own message is used).
#[error(transparent)]
Io(#[from] std::io::Error),
/// SSH-category failure tagged with an `SshErrorKind`.
#[error("ssh error: {0}")]
Ssh(#[source] SshError),
}
impl Error {
pub fn invalid_config(message: impl Into<Cow<'static, str>>) -> Self {
Self::InvalidConfig(message.into())
}
pub fn transport(kind: TransportErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
Self::Transport(TransportError::new(kind, message))
}
pub fn transport_with_source<E>(
kind: TransportErrorKind,
message: impl Into<Cow<'static, str>>,
source: E,
) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::Transport(TransportError::with_source(kind, message, source))
}
pub fn host_key(kind: HostKeyErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
Self::HostKey(HostKeyError::new(kind, message))
}
pub fn host_key_with_source<E>(
kind: HostKeyErrorKind,
message: impl Into<Cow<'static, str>>,
source: E,
) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::HostKey(HostKeyError::with_source(kind, message, source))
}
pub fn authentication(message: impl Into<Cow<'static, str>>) -> Self {
Self::Authentication(AuthenticationError::new(
AuthenticationErrorKind::Rejected,
message,
))
}
pub fn authentication_kind(
kind: AuthenticationErrorKind,
message: impl Into<Cow<'static, str>>,
) -> Self {
Self::Authentication(AuthenticationError::new(kind, message))
}
pub fn authentication_with_source<E>(
kind: AuthenticationErrorKind,
message: impl Into<Cow<'static, str>>,
source: E,
) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::Authentication(AuthenticationError::with_source(kind, message, source))
}
pub fn channel(message: impl Into<Cow<'static, str>>) -> Self {
Self::Channel(ChannelError::new(ChannelErrorKind::Protocol, message))
}
pub fn channel_kind(kind: ChannelErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
Self::Channel(ChannelError::new(kind, message))
}
pub fn channel_with_source<E>(
kind: ChannelErrorKind,
message: impl Into<Cow<'static, str>>,
source: E,
) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::Channel(ChannelError::with_source(kind, message, source))
}
pub fn command_exit(exit: CommandExit) -> Self {
Self::CommandExit { exit }
}
pub fn sftp(kind: SftpErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
Self::Sftp(SftpError::new(kind, message))
}
pub fn sftp_with_source<E>(
kind: SftpErrorKind,
message: impl Into<Cow<'static, str>>,
source: E,
) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::Sftp(SftpError::with_source(kind, message, source))
}
pub fn forwarding(kind: ForwardingErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
Self::Forwarding(ForwardingError::new(kind, message))
}
pub fn forwarding_with_source<E>(
kind: ForwardingErrorKind,
source: E,
message: impl Into<Cow<'static, str>>,
) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::Forwarding(ForwardingError::with_source(kind, message, source))
}
pub fn timeout(operation: Operation, message: impl Into<Cow<'static, str>>) -> Self {
Self::Timeout(TimeoutError::new(operation, message))
}
pub fn cancelled(operation: Operation, message: impl Into<Cow<'static, str>>) -> Self {
Self::Cancelled(CancelledError::new(operation, message))
}
pub fn disconnected(operation: Operation, message: impl Into<Cow<'static, str>>) -> Self {
Self::Disconnected(DisconnectedError::new(operation, message))
}
pub fn unsupported(message: impl Into<Cow<'static, str>>) -> Self {
Self::Unsupported(message.into())
}
pub fn ssh_with_source<E>(message: impl Into<Cow<'static, str>>, source: E) -> Self
where
E: StdError + Send + Sync + 'static,
{
Self::Ssh(SshError::with_source(SshErrorKind::Other, message, source))
}
pub fn is_timeout(&self) -> bool {
matches!(self, Self::Timeout(_))
|| matches!(self, Self::Io(error) if error.kind() == std::io::ErrorKind::TimedOut)
}
pub fn is_cancelled(&self) -> bool {
matches!(self, Self::Cancelled(_))
}
pub fn is_disconnected(&self) -> bool {
matches!(self, Self::Disconnected(_))
}
}
impl From<BoxError> for Error {
fn from(source: BoxError) -> Self {
Self::Ssh(SshError::with_boxed_source(
SshErrorKind::Other,
"lower-level SSH error",
source,
))
}
}
#[cfg(test)]
mod tests {
    use std::{error::Error as StdError, fmt};

    use crate::{
        AuthenticationErrorKind, Error, HostKeyError, HostKeyErrorKind, Operation, SshErrorKind,
        TransportError, TransportErrorKind,
    };

    /// Plain source error with a fixed `Display` text.
    #[derive(Debug)]
    struct SourceError;

    impl fmt::Display for SourceError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("source display")
        }
    }

    impl StdError for SourceError {}

    /// Source error whose `Display` text must not appear in the wrapper's
    /// `Debug` output.
    #[derive(Debug)]
    struct SecretSource;

    impl fmt::Display for SecretSource {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("secret-source-display")
        }
    }

    impl StdError for SecretSource {}

    /// A category error keeps its source but its `Debug` output only reports
    /// that a source exists.
    #[test]
    fn category_errors_preserve_source_without_debugging_it() {
        let wrapped = TransportError::with_source(
            TransportErrorKind::Dns,
            "failed to resolve host",
            SecretSource,
        );

        assert_eq!(wrapped.kind(), TransportErrorKind::Dns);
        assert_eq!(wrapped.message(), "failed to resolve host");
        assert!(wrapped.has_source());
        assert!(StdError::source(&wrapped).is_some());

        let rendered = format!("{wrapped:?}");
        for needle in ["Dns", "has_source: true"] {
            assert!(rendered.contains(needle));
        }
        assert!(!rendered.contains("secret-source-display"));
    }

    /// The top-level error exposes the category error as its source, which
    /// in turn exposes the original cause.
    #[test]
    fn top_level_errors_preserve_category_sources() {
        let top = Error::transport_with_source(
            TransportErrorKind::TcpConnect,
            "tcp connect failed",
            SourceError,
        );

        let category = StdError::source(&top).expect("category source");
        assert_eq!(category.to_string(), "tcp connect failed");
        assert!(category.source().is_some());
    }

    /// The `is_*` predicates recognise their variants (and I/O timeouts).
    #[test]
    fn helper_predicates_classify_common_control_flow() {
        let timed_out = Error::timeout(Operation::Connect, "connect timed out");
        assert!(timed_out.is_timeout());

        let io_timed_out: Error =
            std::io::Error::new(std::io::ErrorKind::TimedOut, "io").into();
        assert!(io_timed_out.is_timeout());

        let cancelled = Error::cancelled(Operation::Command, "cancelled");
        assert!(cancelled.is_cancelled());

        let disconnected = Error::disconnected(Operation::Sftp, "disconnect");
        assert!(disconnected.is_disconnected());

        let unsupported = Error::unsupported("not yet");
        assert!(!unsupported.is_timeout());
    }

    /// Kinds passed to the constructors are readable back from the variants.
    #[test]
    fn typed_variants_expose_stable_kinds() {
        let auth = match Error::authentication_kind(
            AuthenticationErrorKind::Exhausted,
            "no credentials",
        ) {
            Error::Authentication(inner) => inner,
            _ => panic!("expected authentication error"),
        };
        assert_eq!(auth.kind(), AuthenticationErrorKind::Exhausted);

        let host_key = match Error::HostKey(HostKeyError::new(
            HostKeyErrorKind::Changed,
            "host key changed",
        )) {
            Error::HostKey(inner) => inner,
            _ => panic!("expected host key error"),
        };
        assert_eq!(host_key.kind(), HostKeyErrorKind::Changed);
    }

    /// A boxed error converts into an `Ssh` error of kind `Other` that keeps
    /// the boxed error as its source.
    #[test]
    fn boxed_sources_convert_to_unclassified_ssh_errors() {
        let boxed: crate::BoxError = Box::new(SourceError);
        let ssh = match Error::from(boxed) {
            Error::Ssh(inner) => inner,
            _ => panic!("expected ssh error"),
        };
        assert_eq!(ssh.kind(), SshErrorKind::Other);
        assert!(ssh.has_source());
    }
}