use crate::types::status::NtStatus;
use crate::types::Command;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum Error {
#[error("Invalid data: {message}")]
InvalidData {
message: String,
},
#[error("Protocol error: {status} during {command:?}")]
Protocol {
status: NtStatus,
command: Command,
},
#[error("Authentication failed: {message}")]
Auth {
message: String,
},
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("Operation timed out")]
Timeout,
#[error("Disconnected from server")]
Disconnected,
#[error("DFS referral required for path: {path}")]
DfsReferralRequired {
path: String,
},
#[error("Operation cancelled")]
Cancelled,
#[error("Session expired and reauthentication failed")]
SessionExpired,
}
impl Error {
    /// Builds an [`Error::InvalidData`] from any message convertible into a `String`.
    pub fn invalid_data(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self::InvalidData { message }
    }

    /// Reports whether retrying the failed operation is considered worthwhile.
    ///
    /// Returns `true` for [`Error::Timeout`], for [`Error::Disconnected`], and for
    /// [`Error::Protocol`] errors whose status is `INSUFFICIENT_RESOURCES` or
    /// `INSUFF_SERVER_RESOURCES`. Returns `false` for every other error.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Timeout | Self::Disconnected => true,
            Self::Protocol { status, .. } => matches!(
                *status,
                NtStatus::INSUFFICIENT_RESOURCES | NtStatus::INSUFF_SERVER_RESOURCES
            ),
            _ => false,
        }
    }

    /// Returns the NT status carried by an [`Error::Protocol`].
    ///
    /// Returns `None` for every other variant.
    pub fn status(&self) -> Option<NtStatus> {
        if let Self::Protocol { status, .. } = self {
            Some(*status)
        } else {
            None
        }
    }
}
/// Coarse, `match`-friendly classification of an error, as returned by `Error::kind`.
///
/// Derives `Hash` (in addition to `Eq`) so kinds can be used as `HashMap`/`HashSet` keys,
/// e.g. for per-category error counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// Authentication was rejected: `Error::Auth`, or the `LOGON_FAILURE` /
    /// `ACCOUNT_DISABLED` statuses.
    AuthRequired,
    /// NOTE(review): not produced by `Error::kind` in this module; presumably reserved for
    /// message-signing failures — confirm where it is emitted.
    SigningRequired,
    /// The `ACCESS_DENIED` status.
    AccessDenied,
    /// `NO_SUCH_FILE`, `OBJECT_NAME_NOT_FOUND`, `OBJECT_PATH_NOT_FOUND` or `BAD_NETWORK_NAME`.
    NotFound,
    /// `SHARING_VIOLATION` or `FILE_LOCK_CONFLICT`.
    SharingViolation,
    /// The `DISK_FULL` status.
    DiskFull,
    /// `Error::Disconnected`, or the `NETWORK_NAME_DELETED` / `USER_SESSION_DELETED` statuses.
    ConnectionLost,
    /// `Error::Timeout`.
    TimedOut,
    /// `Error::Cancelled`.
    Cancelled,
    /// `Error::SessionExpired`, or the `NETWORK_SESSION_EXPIRED` status.
    SessionExpired,
    /// `Error::DfsReferralRequired`, or the `PATH_NOT_COVERED` status.
    DfsReferral,
    /// `Error::InvalidData`.
    InvalidData,
    /// Any `Error::Io`, regardless of the wrapped `std::io::ErrorKind`.
    Io,
    /// A protocol status with no dedicated category.
    Other,
}
impl Error {
    /// Classifies this error into an [`ErrorKind`].
    ///
    /// [`Error::Protocol`] errors are classified by their NT status via `classify_status`.
    /// Every other variant maps to one fixed kind. The match is exhaustive on purpose, so
    /// adding a variant forces a decision here.
    pub fn kind(&self) -> ErrorKind {
        match self {
            Self::Protocol { status, .. } => classify_status(*status),
            Self::Auth { .. } => ErrorKind::AuthRequired,
            Self::InvalidData { .. } => ErrorKind::InvalidData,
            Self::DfsReferralRequired { .. } => ErrorKind::DfsReferral,
            Self::Io(_) => ErrorKind::Io,
            Self::Timeout => ErrorKind::TimedOut,
            Self::Disconnected => ErrorKind::ConnectionLost,
            Self::SessionExpired => ErrorKind::SessionExpired,
            Self::Cancelled => ErrorKind::Cancelled,
        }
    }
}
/// Maps the NT status of a protocol error onto an [`ErrorKind`].
///
/// Statuses not listed explicitly fall through to [`ErrorKind::Other`].
fn classify_status(code: NtStatus) -> ErrorKind {
    match code {
        // Credentials rejected by the server.
        NtStatus::LOGON_FAILURE | NtStatus::ACCOUNT_DISABLED => ErrorKind::AuthRequired,

        NtStatus::ACCESS_DENIED => ErrorKind::AccessDenied,

        // Missing object; BAD_NETWORK_NAME covers an unknown share on tree connect.
        NtStatus::NO_SUCH_FILE
        | NtStatus::OBJECT_NAME_NOT_FOUND
        | NtStatus::OBJECT_PATH_NOT_FOUND
        | NtStatus::BAD_NETWORK_NAME => ErrorKind::NotFound,

        // Contention with another opener or a byte-range lock.
        NtStatus::SHARING_VIOLATION | NtStatus::FILE_LOCK_CONFLICT => {
            ErrorKind::SharingViolation
        }

        NtStatus::DISK_FULL => ErrorKind::DiskFull,

        NtStatus::NETWORK_SESSION_EXPIRED => ErrorKind::SessionExpired,

        // The tree or session was torn down underneath us.
        NtStatus::NETWORK_NAME_DELETED | NtStatus::USER_SESSION_DELETED => {
            ErrorKind::ConnectionLost
        }

        NtStatus::PATH_NOT_COVERED => ErrorKind::DfsReferral,

        _ => ErrorKind::Other,
    }
}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an [`Error::Protocol`] in tests.
    fn protocol(status: NtStatus, command: Command) -> Error {
        Error::Protocol { status, command }
    }

    #[test]
    fn kind_maps_protocol_not_found() {
        let err = protocol(NtStatus::OBJECT_NAME_NOT_FOUND, Command::Create);
        assert_eq!(err.kind(), ErrorKind::NotFound);
    }

    #[test]
    fn kind_maps_protocol_access_denied() {
        let err = protocol(NtStatus::ACCESS_DENIED, Command::Create);
        assert_eq!(err.kind(), ErrorKind::AccessDenied);
    }

    #[test]
    fn kind_maps_protocol_sharing_violation() {
        let err = protocol(NtStatus::SHARING_VIOLATION, Command::Create);
        assert_eq!(err.kind(), ErrorKind::SharingViolation);
    }

    #[test]
    fn kind_maps_protocol_logon_failure() {
        let err = protocol(NtStatus::LOGON_FAILURE, Command::SessionSetup);
        assert_eq!(err.kind(), ErrorKind::AuthRequired);
    }

    #[test]
    fn kind_maps_protocol_bad_network_name() {
        let err = protocol(NtStatus::BAD_NETWORK_NAME, Command::TreeConnect);
        assert_eq!(err.kind(), ErrorKind::NotFound);
    }

    #[test]
    fn kind_maps_protocol_disk_full() {
        let err = protocol(NtStatus::DISK_FULL, Command::Write);
        assert_eq!(err.kind(), ErrorKind::DiskFull);
    }

    #[test]
    fn kind_maps_remaining_grouped_statuses() {
        assert_eq!(
            protocol(NtStatus::FILE_LOCK_CONFLICT, Command::Write).kind(),
            ErrorKind::SharingViolation
        );
        assert_eq!(
            protocol(NtStatus::ACCOUNT_DISABLED, Command::SessionSetup).kind(),
            ErrorKind::AuthRequired
        );
        assert_eq!(
            protocol(NtStatus::NO_SUCH_FILE, Command::Create).kind(),
            ErrorKind::NotFound
        );
        assert_eq!(
            protocol(NtStatus::OBJECT_PATH_NOT_FOUND, Command::Create).kind(),
            ErrorKind::NotFound
        );
    }

    #[test]
    fn kind_maps_session_and_connection_statuses() {
        assert_eq!(
            protocol(NtStatus::NETWORK_SESSION_EXPIRED, Command::Create).kind(),
            ErrorKind::SessionExpired
        );
        assert_eq!(
            protocol(NtStatus::USER_SESSION_DELETED, Command::Create).kind(),
            ErrorKind::ConnectionLost
        );
        assert_eq!(
            protocol(NtStatus::NETWORK_NAME_DELETED, Command::Create).kind(),
            ErrorKind::ConnectionLost
        );
    }

    #[test]
    fn kind_maps_non_protocol_errors() {
        assert_eq!(Error::Timeout.kind(), ErrorKind::TimedOut);
        assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost);
        assert_eq!(Error::Cancelled.kind(), ErrorKind::Cancelled);
        assert_eq!(Error::SessionExpired.kind(), ErrorKind::SessionExpired);
        assert_eq!(Error::invalid_data("test").kind(), ErrorKind::InvalidData);
        assert_eq!(
            Error::DfsReferralRequired {
                path: "test".into()
            }
            .kind(),
            ErrorKind::DfsReferral
        );
        assert_eq!(
            Error::Auth {
                message: "test".into()
            }
            .kind(),
            ErrorKind::AuthRequired
        );
    }

    #[test]
    fn kind_maps_io_error_to_io_not_connection_lost() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "cancelled",
        ));
        assert_eq!(err.kind(), ErrorKind::Io);
        assert_ne!(err.kind(), ErrorKind::ConnectionLost);
    }

    #[test]
    fn io_error_converts_via_from() {
        let err: Error = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe").into();
        assert_eq!(err.kind(), ErrorKind::Io);
        assert!(!err.is_retryable());
        assert!(err.status().is_none());
    }

    #[test]
    fn kind_disconnected_is_connection_lost() {
        assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost);
    }

    #[test]
    fn kind_unknown_status_maps_to_other() {
        let err = protocol(NtStatus::NOT_IMPLEMENTED, Command::Ioctl);
        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[test]
    fn kind_maps_path_not_covered_to_dfs_referral() {
        let err = protocol(NtStatus::PATH_NOT_COVERED, Command::Create);
        assert_eq!(err.kind(), ErrorKind::DfsReferral);
    }

    #[test]
    fn kind_maps_dfs_referral_required_to_dfs_referral() {
        let err = Error::DfsReferralRequired {
            path: r"\\server\share\path".into(),
        };
        assert_eq!(err.kind(), ErrorKind::DfsReferral);
    }

    #[test]
    fn dfs_referral_is_not_retryable() {
        let err = protocol(NtStatus::PATH_NOT_COVERED, Command::Create);
        assert!(!err.is_retryable());
    }

    #[test]
    fn transient_errors_are_retryable() {
        assert!(Error::Timeout.is_retryable());
        assert!(Error::Disconnected.is_retryable());
        assert!(protocol(NtStatus::INSUFFICIENT_RESOURCES, Command::Write).is_retryable());
        assert!(protocol(NtStatus::INSUFF_SERVER_RESOURCES, Command::Write).is_retryable());
    }

    #[test]
    fn other_errors_are_not_retryable() {
        let auth = Error::Auth {
            message: "denied".into(),
        };
        assert!(!auth.is_retryable());
        assert!(!Error::Cancelled.is_retryable());
        assert!(!Error::SessionExpired.is_retryable());
        assert!(!Error::invalid_data("bad").is_retryable());
        assert!(!protocol(NtStatus::ACCESS_DENIED, Command::Create).is_retryable());
    }

    #[test]
    fn status_is_exposed_only_for_protocol_errors() {
        // `matches!` avoids assuming `NtStatus: PartialEq`.
        assert!(matches!(
            protocol(NtStatus::ACCESS_DENIED, Command::Create).status(),
            Some(NtStatus::ACCESS_DENIED)
        ));
        assert!(Error::Timeout.status().is_none());
        assert!(Error::invalid_data("x").status().is_none());
    }

    #[test]
    fn display_messages_match_error_attributes() {
        assert_eq!(Error::Timeout.to_string(), "Operation timed out");
        assert_eq!(Error::Disconnected.to_string(), "Disconnected from server");
        assert_eq!(
            Error::invalid_data("bad header").to_string(),
            "Invalid data: bad header"
        );
        assert_eq!(
            Error::DfsReferralRequired {
                path: "a".into()
            }
            .to_string(),
            "DFS referral required for path: a"
        );
    }
}