use crate::types::status::NtStatus;
use crate::types::Command;
use thiserror::Error;
/// Error type returned by this module's operations (see the [`Result`] alias).
///
/// [`Error::kind`] gives a coarse [`ErrorKind`] classification, [`Error::status`]
/// exposes the NT status of protocol failures, and [`Error::is_retryable`] reports
/// whether the error is classified as retryable.
#[derive(Debug, Error)]
pub enum Error {
/// Invalid data was encountered; `message` describes the problem.
/// Usually built with [`Error::invalid_data`].
#[error("Invalid data: {message}")]
InvalidData {
message: String,
},
/// A protocol-level failure: `status` is the NT status associated with the
/// failed `command`.
#[error("Protocol error: {status} during {command:?}")]
Protocol {
status: NtStatus,
command: Command,
},
/// Authentication failed; `message` describes why.
#[error("Authentication failed: {message}")]
Auth {
message: String,
},
/// Wrapped I/O error. `#[from]` derives `From<std::io::Error>` (so `?` converts it)
/// and exposes the wrapped error through `std::error::Error::source`.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// The operation did not complete in time.
#[error("Operation timed out")]
Timeout,
/// The connection to the server was lost.
#[error("Disconnected from server")]
Disconnected,
/// Accessing `path` requires a DFS referral first.
#[error("DFS referral required for path: {path}")]
DfsReferralRequired {
path: String,
},
/// The operation was cancelled.
#[error("Operation cancelled")]
Cancelled,
/// The session expired and reauthentication did not succeed.
#[error("Session expired and reauthentication failed")]
SessionExpired,
}
impl Error {
    /// Creates an [`Error::InvalidData`] carrying `msg` as its message.
    pub fn invalid_data(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self::InvalidData { message }
    }

    /// Returns `true` when the error is classified as retryable.
    ///
    /// That covers [`Error::Timeout`], [`Error::Disconnected`], and protocol failures whose
    /// status is `INSUFFICIENT_RESOURCES` or `INSUFF_SERVER_RESOURCES`. Every other error
    /// (including every other protocol status) returns `false`.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Timeout | Self::Disconnected => true,
            Self::Protocol { status, .. } => matches!(
                *status,
                NtStatus::INSUFFICIENT_RESOURCES | NtStatus::INSUFF_SERVER_RESOURCES
            ),
            _ => false,
        }
    }

    /// Returns the NT status of an [`Error::Protocol`], or `None` for any other variant.
    pub fn status(&self) -> Option<NtStatus> {
        if let Self::Protocol { status, .. } = self {
            Some(*status)
        } else {
            None
        }
    }
}
/// Coarse classification of an [`Error`], as returned by [`Error::kind`].
///
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a wildcard arm;
/// that lets new kinds be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ErrorKind {
/// Authentication failed: [`Error::Auth`], or a `LOGON_FAILURE` / `ACCOUNT_DISABLED` status.
AuthRequired,
/// Signing is required. NOTE(review): not produced by [`Error::kind`] in this module —
/// confirm where (or whether) this kind is constructed.
SigningRequired,
/// An `ACCESS_DENIED` status.
AccessDenied,
/// The target was not found: `NO_SUCH_FILE`, `OBJECT_NAME_NOT_FOUND`,
/// `OBJECT_PATH_NOT_FOUND`, or `BAD_NETWORK_NAME`.
NotFound,
/// An `OBJECT_NAME_COLLISION` status.
AlreadyExists,
/// A `SHARING_VIOLATION` or `FILE_LOCK_CONFLICT` status.
SharingViolation,
/// A `FILE_IS_A_DIRECTORY` status.
IsADirectory,
/// A `NOT_A_DIRECTORY` status.
NotADirectory,
/// A `DISK_FULL` status.
DiskFull,
/// [`Error::Disconnected`], or a `NETWORK_NAME_DELETED` / `USER_SESSION_DELETED` status.
ConnectionLost,
/// [`Error::Timeout`].
TimedOut,
/// [`Error::Cancelled`].
Cancelled,
/// [`Error::SessionExpired`], or a `NETWORK_SESSION_EXPIRED` status.
SessionExpired,
/// [`Error::DfsReferralRequired`], or a `PATH_NOT_COVERED` status.
DfsReferral,
/// [`Error::InvalidData`].
InvalidData,
/// [`Error::Io`], whatever the wrapped `std::io::ErrorKind` is.
Io,
/// Any protocol status without a dedicated mapping above.
Other,
}
impl Error {
    /// Classifies this error into a coarse [`ErrorKind`].
    ///
    /// [`Error::Protocol`] is classified by its NT status via `classify_status`; every
    /// other variant maps to one fixed kind.
    pub fn kind(&self) -> ErrorKind {
        match self {
            Self::Protocol { status, .. } => classify_status(*status),
            Self::Auth { .. } => ErrorKind::AuthRequired,
            Self::DfsReferralRequired { .. } => ErrorKind::DfsReferral,
            Self::InvalidData { .. } => ErrorKind::InvalidData,
            // Always `Io`, even for interrupted/reset I/O: it is never reported as
            // `ConnectionLost` (the tests pin this).
            Self::Io(_) => ErrorKind::Io,
            Self::Timeout => ErrorKind::TimedOut,
            Self::Disconnected => ErrorKind::ConnectionLost,
            Self::Cancelled => ErrorKind::Cancelled,
            Self::SessionExpired => ErrorKind::SessionExpired,
        }
    }
}
/// Maps an NT status to the [`ErrorKind`] reported for an [`Error::Protocol`].
///
/// Any status without an explicit arm falls back to [`ErrorKind::Other`].
fn classify_status(status: NtStatus) -> ErrorKind {
    match status {
        // Session / connection state.
        NtStatus::NETWORK_SESSION_EXPIRED => ErrorKind::SessionExpired,
        NtStatus::USER_SESSION_DELETED | NtStatus::NETWORK_NAME_DELETED => {
            ErrorKind::ConnectionLost
        }
        NtStatus::PATH_NOT_COVERED => ErrorKind::DfsReferral,

        // Credentials and permissions.
        NtStatus::ACCOUNT_DISABLED | NtStatus::LOGON_FAILURE => ErrorKind::AuthRequired,
        NtStatus::ACCESS_DENIED => ErrorKind::AccessDenied,

        // Name lookup.
        NtStatus::BAD_NETWORK_NAME
        | NtStatus::OBJECT_PATH_NOT_FOUND
        | NtStatus::OBJECT_NAME_NOT_FOUND
        | NtStatus::NO_SUCH_FILE => ErrorKind::NotFound,
        NtStatus::OBJECT_NAME_COLLISION => ErrorKind::AlreadyExists,

        // File vs. directory mismatches.
        NtStatus::NOT_A_DIRECTORY => ErrorKind::NotADirectory,
        NtStatus::FILE_IS_A_DIRECTORY => ErrorKind::IsADirectory,

        // Contention and capacity.
        NtStatus::FILE_LOCK_CONFLICT | NtStatus::SHARING_VIOLATION => ErrorKind::SharingViolation,
        NtStatus::DISK_FULL => ErrorKind::DiskFull,

        _ => ErrorKind::Other,
    }
}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected kind for every status `classify_status` maps explicitly, plus a few
    /// statuses that must fall through to `ErrorKind::Other`.
    const STATUS_CLASSIFICATION_CONTRACT: &[(NtStatus, ErrorKind)] = &[
        (NtStatus::LOGON_FAILURE, ErrorKind::AuthRequired),
        (NtStatus::ACCOUNT_DISABLED, ErrorKind::AuthRequired),
        (NtStatus::ACCESS_DENIED, ErrorKind::AccessDenied),
        (NtStatus::NO_SUCH_FILE, ErrorKind::NotFound),
        (NtStatus::OBJECT_NAME_NOT_FOUND, ErrorKind::NotFound),
        (NtStatus::OBJECT_PATH_NOT_FOUND, ErrorKind::NotFound),
        (NtStatus::BAD_NETWORK_NAME, ErrorKind::NotFound),
        (NtStatus::OBJECT_NAME_COLLISION, ErrorKind::AlreadyExists),
        (NtStatus::FILE_IS_A_DIRECTORY, ErrorKind::IsADirectory),
        (NtStatus::NOT_A_DIRECTORY, ErrorKind::NotADirectory),
        (NtStatus::SHARING_VIOLATION, ErrorKind::SharingViolation),
        (NtStatus::FILE_LOCK_CONFLICT, ErrorKind::SharingViolation),
        (NtStatus::DISK_FULL, ErrorKind::DiskFull),
        (NtStatus::NETWORK_NAME_DELETED, ErrorKind::ConnectionLost),
        (NtStatus::USER_SESSION_DELETED, ErrorKind::ConnectionLost),
        (NtStatus::NETWORK_SESSION_EXPIRED, ErrorKind::SessionExpired),
        (NtStatus::PATH_NOT_COVERED, ErrorKind::DfsReferral),
        (NtStatus::NOT_IMPLEMENTED, ErrorKind::Other),
        (NtStatus::INVALID_PARAMETER, ErrorKind::Other),
        (NtStatus::DELETE_PENDING, ErrorKind::Other),
        (NtStatus::INSUFFICIENT_RESOURCES, ErrorKind::Other),
        (NtStatus::INSUFF_SERVER_RESOURCES, ErrorKind::Other),
    ];

    /// Builds a protocol error for `status`; the command is irrelevant to these tests.
    fn protocol_error(status: NtStatus) -> Error {
        Error::Protocol {
            status,
            command: Command::Create,
        }
    }

    #[test]
    fn classify_status_contract() {
        for (status, expected) in STATUS_CLASSIFICATION_CONTRACT {
            let err = Error::Protocol {
                status: *status,
                command: Command::Create,
            };
            assert_eq!(
                err.kind(),
                *expected,
                "{status} should classify as {expected:?}"
            );
        }
    }

    #[test]
    fn kind_maps_non_protocol_errors() {
        assert_eq!(Error::Timeout.kind(), ErrorKind::TimedOut);
        assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost);
        assert_eq!(Error::Cancelled.kind(), ErrorKind::Cancelled);
        assert_eq!(Error::SessionExpired.kind(), ErrorKind::SessionExpired);
        assert_eq!(Error::invalid_data("test").kind(), ErrorKind::InvalidData);
        assert_eq!(
            Error::DfsReferralRequired {
                path: "test".into()
            }
            .kind(),
            ErrorKind::DfsReferral
        );
        assert_eq!(
            Error::Auth {
                message: "test".into()
            }
            .kind(),
            ErrorKind::AuthRequired
        );
    }

    #[test]
    fn kind_maps_io_error_to_io_not_connection_lost() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "cancelled",
        ));
        assert_eq!(err.kind(), ErrorKind::Io);
        assert_ne!(err.kind(), ErrorKind::ConnectionLost);
    }

    #[test]
    fn kind_disconnected_is_connection_lost() {
        assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost);
    }

    #[test]
    fn kind_maps_dfs_referral_required_to_dfs_referral() {
        let err = Error::DfsReferralRequired {
            path: r"\\server\share\path".into(),
        };
        assert_eq!(err.kind(), ErrorKind::DfsReferral);
    }

    #[test]
    fn dfs_referral_is_not_retryable() {
        let err = Error::Protocol {
            status: NtStatus::PATH_NOT_COVERED,
            command: Command::Create,
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn transport_and_resource_errors_are_retryable() {
        assert!(Error::Timeout.is_retryable());
        assert!(Error::Disconnected.is_retryable());
        for &status in &[NtStatus::INSUFFICIENT_RESOURCES, NtStatus::INSUFF_SERVER_RESOURCES] {
            assert!(protocol_error(status).is_retryable(), "{status} should be retryable");
        }
    }

    #[test]
    fn other_errors_are_not_retryable() {
        let errors = [
            Error::Cancelled,
            Error::SessionExpired,
            Error::invalid_data("test"),
            Error::Auth {
                message: "test".into(),
            },
            Error::DfsReferralRequired {
                path: "test".into(),
            },
            protocol_error(NtStatus::ACCESS_DENIED),
            protocol_error(NtStatus::LOGON_FAILURE),
        ];
        for err in &errors {
            assert!(!err.is_retryable(), "{err:?} should not be retryable");
        }
    }

    #[test]
    fn status_is_only_exposed_for_protocol_errors() {
        assert!(matches!(
            protocol_error(NtStatus::ACCESS_DENIED).status(),
            Some(NtStatus::ACCESS_DENIED)
        ));
        assert!(Error::Timeout.status().is_none());
        assert!(Error::invalid_data("test").status().is_none());
    }

    #[test]
    fn display_messages_match_error_attributes() {
        assert_eq!(
            Error::invalid_data("bad header").to_string(),
            "Invalid data: bad header"
        );
        assert_eq!(
            Error::Auth {
                message: "bad password".into()
            }
            .to_string(),
            "Authentication failed: bad password"
        );
        assert_eq!(Error::Timeout.to_string(), "Operation timed out");
        assert_eq!(Error::Disconnected.to_string(), "Disconnected from server");
        assert_eq!(Error::Cancelled.to_string(), "Operation cancelled");
        assert_eq!(
            Error::SessionExpired.to_string(),
            "Session expired and reauthentication failed"
        );
        assert_eq!(
            Error::DfsReferralRequired {
                path: r"\\server\share".into()
            }
            .to_string(),
            r"DFS referral required for path: \\server\share"
        );
        let io = Error::Io(std::io::Error::new(std::io::ErrorKind::Other, "boom"));
        assert_eq!(io.to_string(), "I/O error: boom");
    }

    #[test]
    fn io_error_converts_via_from_and_is_exposed_as_source() {
        let err: Error = std::io::Error::new(std::io::ErrorKind::Other, "boom").into();
        assert!(matches!(err, Error::Io(_)));
        assert!(std::error::Error::source(&err).is_some());
    }
}