use std::{fmt, sync::Arc};
use err_derive::Error;
use lazy_static::lazy_static;
use serde::Deserialize;
use time::OutOfRangeError;
use crate::options::StreamAddress;
lazy_static! {
    /// Server error codes that `is_recovering` treats as "node is recovering".
    static ref RECOVERING_CODES: Vec<i32> = [11600, 11602, 13436, 189, 91].to_vec();
    /// Server error codes that `is_not_master` treats as "not master".
    static ref NOTMASTER_CODES: Vec<i32> = [10107, 13435].to_vec();
    /// Server error codes that `ErrorKind::is_shutting_down` treats as a shutdown.
    static ref SHUTTING_DOWN_CODES: Vec<i32> = [11600, 91].to_vec();
}

/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// The error type used by this module's `Result` alias.
///
/// The concrete [`ErrorKind`] is stored behind an `Arc`. Because of that, cloning an `Error`
/// shares the same kind instead of copying it.
#[derive(Clone, Debug, Error)]
#[error(display = "{}", kind)]
pub struct Error {
    /// The specific kind of error. An `Error` displays exactly as this kind does.
    pub kind: Arc<ErrorKind>,
}
impl Error {
pub(crate) fn authentication_error(mechanism_name: &str, reason: &str) -> Self {
ErrorKind::AuthenticationError {
message: format!("{} failure: {}", mechanism_name, reason),
}
.into()
}
pub(crate) fn unknown_authentication_error(mechanism_name: &str) -> Error {
Error::authentication_error(mechanism_name, "internal error")
}
pub(crate) fn invalid_authentication_response(mechanism_name: &str) -> Error {
Error::authentication_error(mechanism_name, "invalid server response")
}
pub(crate) fn into_io_error(self) -> std::io::Error {
match Arc::try_unwrap(self.kind) {
Ok(ErrorKind::Io(io_error)) => io_error,
Ok(other_error_kind) => {
let error: Error = other_error_kind.into();
std::io::Error::new(std::io::ErrorKind::Other, Box::new(error))
}
Err(e) => std::io::Error::new(std::io::ErrorKind::Other, Box::new(Error { kind: e })),
}
}
pub(crate) fn is_ns_not_found(&self) -> bool {
match self.kind.as_ref() {
ErrorKind::CommandError(err) if err.code == 26 => true,
_ => false,
}
}
}
impl<E> From<E> for Error
where
ErrorKind: From<E>,
{
fn from(err: E) -> Self {
Self {
kind: Arc::new(err.into()),
}
}
}
impl std::ops::Deref for Error {
type Target = Arc<ErrorKind>;
fn deref(&self) -> &Self::Target {
&self.kind
}
}
/// The specific kinds of error that an [`Error`] can wrap.
///
/// Each variant's `#[error(display = ...)]` attribute defines how it is displayed. Fields
/// marked `#[error(source)]` are the wrapped underlying errors.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ErrorKind {
    /// Wraps a failure to parse a network address.
    #[error(display = "{}", _0)]
    AddrParse(#[error(source)] std::net::AddrParseError),
    /// An invalid argument was provided to a database operation.
    #[error(
        display = "An invalid argument was provided to a database operation: {}",
        message
    )]
    #[non_exhaustive]
    ArgumentError { message: String },
    /// Wraps an `async-std` timeout. Only present with the `async-std-runtime` feature.
    #[cfg(feature = "async-std-runtime")]
    #[error(display = "{}", _0)]
    AsyncStdTimeout(#[error(source)] async_std::future::TimeoutError),
    /// Authentication failed. `Error::authentication_error` builds the message as
    /// `"<mechanism> failure: <reason>"`.
    #[error(display = "{}", message)]
    #[non_exhaustive]
    AuthenticationError { message: String },
    /// Wraps a BSON deserialization error.
    #[error(display = "{}", _0)]
    BsonDecode(#[error(source)] crate::bson::de::Error),
    /// Wraps a BSON serialization error.
    #[error(display = "{}", _0)]
    BsonEncode(#[error(source)] crate::bson::ser::Error),
    /// A bulk write failed. Displayed using the failure's `Debug` representation.
    #[error(
        display = "An error occurred when trying to execute a write operation: {:?}",
        _0
    )]
    BulkWriteError(BulkWriteFailure),
    /// The server reported that a command failed.
    #[error(display = "Command failed {}", _0)]
    CommandError(CommandError),
    /// Wraps a DNS resolution error.
    ///
    /// NOTE(review): unlike the other wrapping variants, this field is not marked
    /// `#[error(source)]`, so it is presumably not reported via `source()`. Confirm this is intended.
    #[error(display = "{}", _0)]
    DnsResolve(trust_dns_resolver::error::ResolveError),
    /// An internal error, displayed as `Internal error: <message>`.
    #[error(display = "Internal error: {}", message)]
    #[non_exhaustive]
    InternalError { message: String },
    /// Wraps a `webpki` invalid-DNS-name error.
    #[error(display = "{}", _0)]
    InvalidDnsName(#[error(source)] webpki::InvalidDNSNameError),
    /// A hostname could not be parsed.
    #[error(display = "Unable to parse hostname: {}", hostname)]
    #[non_exhaustive]
    InvalidHostname { hostname: String },
    /// Wraps an I/O error. `ErrorKind::is_network_error` treats every `Io` error as a network
    /// error.
    #[error(display = "{}", _0)]
    Io(#[error(source)] std::io::Error),
    /// A DNS lookup for the given address returned no results.
    #[error(display = "No DNS results for domain {}", _0)]
    NoDnsResults(StreamAddress),
    /// A database operation failed to send or receive a reply.
    #[error(
        display = "A database operation failed to send or receive a reply: {}",
        message
    )]
    #[non_exhaustive]
    OperationError { message: String },
    /// Wraps a `time::OutOfRangeError`.
    #[error(display = "{}", _0)]
    OutOfRangeError(#[error(source)] OutOfRangeError),
    /// Data of type `data_type` could not be parsed from the file at `file_path`.
    #[error(display = "Unable to parse {} data from {}", data_type, file_path)]
    #[non_exhaustive]
    ParseError {
        data_type: String,
        file_path: String,
    },
    /// The server returned an invalid reply to a database operation.
    #[error(
        display = "The server returned an invalid reply to a database operation: {}",
        message
    )]
    #[non_exhaustive]
    ResponseError { message: String },
    /// Server selection failed. Displayed as the bare message.
    #[error(display = "{}", message)]
    #[non_exhaustive]
    ServerSelectionError { message: String },
    /// An SRV record lookup failed.
    #[error(display = "An error occurred during SRV record lookup: {}", message)]
    #[non_exhaustive]
    SrvLookupError { message: String },
    /// Wraps a tokio timeout. Only present with the `tokio-runtime` feature.
    #[cfg(feature = "tokio-runtime")]
    #[error(display = "{}", _0)]
    TokioTimeoutElapsed(#[error(source)] tokio::time::Elapsed),
    /// Wraps a `rustls` TLS error.
    #[error(display = "{}", _0)]
    RustlsConfig(#[error(source)] rustls::TLSError),
    /// A TXT record lookup failed.
    #[error(display = "An error occurred during TXT record lookup: {}", message)]
    #[non_exhaustive]
    TxtLookupError { message: String },
    /// Timed out waiting to check out a connection from the pool for `address`.
    #[error(
        display = "Timed out while checking out a connection from connection pool with address {}",
        address
    )]
    #[non_exhaustive]
    WaitQueueTimeoutError { address: StreamAddress },
    /// A single write failed. Displayed using the failure's `Debug` representation.
    #[error(
        display = "An error occurred when trying to execute a write operation: {:?}",
        _0
    )]
    WriteError(WriteFailure),
}
impl From<trust_dns_resolver::error::ResolveError> for ErrorKind {
fn from(error: trust_dns_resolver::error::ResolveError) -> Self {
Self::DnsResolve(error)
}
}
impl ErrorKind {
pub(crate) fn is_non_timeout_network_error(&self) -> bool {
match self {
ErrorKind::Io(ref io_err) if io_err.kind() != std::io::ErrorKind::TimedOut => true,
_ => false,
}
}
pub(crate) fn is_network_error(&self) -> bool {
match self {
ErrorKind::Io(..) => true,
_ => false,
}
}
pub(crate) fn is_authentication_error(&self) -> bool {
match self {
ErrorKind::AuthenticationError { .. } => true,
_ => false,
}
}
fn code_and_message(&self) -> Option<(i32, &str)> {
match self {
ErrorKind::CommandError(ref cmd_err) => Some((cmd_err.code, cmd_err.message.as_str())),
ErrorKind::WriteError(WriteFailure::WriteConcernError(ref wc_err)) => {
Some((wc_err.code, wc_err.message.as_str()))
}
ErrorKind::BulkWriteError(ref bwe) => bwe
.write_concern_error
.as_ref()
.map(|wc_err| (wc_err.code, wc_err.message.as_str())),
_ => None,
}
}
pub(crate) fn is_not_master(&self) -> bool {
self.code_and_message()
.map(|(code, msg)| is_not_master(code, msg))
.unwrap_or(false)
}
pub(crate) fn is_recovering(&self) -> bool {
self.code_and_message()
.map(|(code, msg)| is_recovering(code, msg))
.unwrap_or(false)
}
pub(crate) fn is_shutting_down(&self) -> bool {
self.code_and_message()
.map(|(code, _)| SHUTTING_DOWN_CODES.contains(&code))
.unwrap_or(false)
}
}
/// Classifies a server error as "not master".
///
/// A not-master code always qualifies. Otherwise, anything that counts as "recovering" is
/// excluded, because "not master or secondary" also contains "not master". For the remaining
/// cases, the message is searched for "not master".
fn is_not_master(code: i32, message: &str) -> bool {
    NOTMASTER_CODES.contains(&code)
        || (!is_recovering(code, message) && message.contains("not master"))
}
/// Classifies a server error as "node is recovering", by code or by message text.
fn is_recovering(code: i32, message: &str) -> bool {
    const RECOVERING_MESSAGES: [&str; 2] = ["not master or secondary", "node is recovering"];

    RECOVERING_CODES.contains(&code)
        || RECOVERING_MESSAGES
            .iter()
            .any(|&needle| message.contains(needle))
}
/// An error reported by the server for a failed command, deserialized from the reply.
#[derive(Clone, Debug, Deserialize)]
#[non_exhaustive]
pub struct CommandError {
    /// Numeric server error code (for example, `Error::is_ns_not_found` checks for 26).
    pub code: i32,
    /// Symbolic name of the error code. Empty when the reply has no `codeName` field.
    #[serde(rename = "codeName", default)]
    pub code_name: String,
    /// Error message, read from the reply's `errmsg` field.
    #[serde(rename = "errmsg")]
    pub message: String,
    /// Labels read from the reply's `errorLabels` field. Empty when the field is absent.
    #[serde(rename = "errorLabels", default)]
    pub labels: Vec<String>,
}
/// Formats a command error as `(<code_name>): <message>`.
///
/// `code_name` may be empty when the server reply omitted `codeName`. In that case the output
/// is `(): <message>`.
impl fmt::Display for CommandError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        // Parentheses must balance: only the code name is wrapped, never the message.
        write!(fmt, "({}): {}", self.code_name, self.message)
    }
}
/// A write concern error, as deserialized from a server reply.
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[non_exhaustive]
pub struct WriteConcernError {
    /// Numeric server error code.
    pub code: i32,
    /// Symbolic name of the error code.
    ///
    /// Defaults to an empty string when the reply omits `codeName`. A missing field then no
    /// longer fails deserialization, matching how `CommandError` and `BulkWriteError` treat
    /// `codeName`.
    #[serde(rename = "codeName", default)]
    pub code_name: String,
    /// Error message, read from the reply's `errmsg` field.
    #[serde(rename = "errmsg")]
    pub message: String,
}
/// A single write error, carried by [`WriteFailure::WriteError`].
///
/// `WriteFailure::from_bulk_failure` builds one from the first [`BulkWriteError`] of a bulk
/// failure.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct WriteError {
    /// Numeric server error code.
    pub code: i32,
    /// Symbolic name of the error code, if one was provided.
    pub code_name: Option<String>,
    /// Error message.
    pub message: String,
}
/// An error for one write within a bulk write, deserialized from a server reply.
#[derive(Debug, PartialEq, Clone, Deserialize)]
#[non_exhaustive]
pub struct BulkWriteError {
    /// Index of the failed write. Presumably its position within the submitted batch; confirm.
    pub index: usize,
    /// Numeric server error code.
    pub code: i32,
    /// Symbolic name of the error code. `None` when the reply omits `codeName`.
    #[serde(rename = "codeName", default)]
    pub code_name: Option<String>,
    /// Error message, read from the reply's `errmsg` field.
    #[serde(rename = "errmsg")]
    pub message: String,
}
/// The errors produced by a failed bulk write: any per-write errors, plus an optional write
/// concern error.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BulkWriteFailure {
    /// Per-write errors, if any were reported.
    pub write_errors: Option<Vec<BulkWriteError>>,
    /// The write concern error, if one was reported.
    pub write_concern_error: Option<WriteConcernError>,
}
impl BulkWriteFailure {
pub(crate) fn new() -> Self {
BulkWriteFailure {
write_errors: None,
write_concern_error: None,
}
}
}
/// A write failure reduced to a single cause: either a write concern error or a write error.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum WriteFailure {
    /// The write concern could not be satisfied.
    WriteConcernError(WriteConcernError),
    /// The write itself failed.
    WriteError(WriteError),
}
impl WriteFailure {
fn from_bulk_failure(bulk: BulkWriteFailure) -> Result<Self> {
if let Some(bulk_write_error) = bulk.write_errors.and_then(|es| es.into_iter().next()) {
let write_error = WriteError {
code: bulk_write_error.code,
code_name: bulk_write_error.code_name,
message: bulk_write_error.message,
};
Ok(WriteFailure::WriteError(write_error))
} else if let Some(wc_error) = bulk.write_concern_error {
Ok(WriteFailure::WriteConcernError(wc_error))
} else {
Err(ErrorKind::ResponseError {
message: "error missing write errors and write concern errors".to_string(),
}
.into())
}
}
}
pub(crate) fn convert_bulk_errors(error: Error) -> Error {
match *error.kind {
ErrorKind::BulkWriteError(ref bulk_failure) => {
match WriteFailure::from_bulk_failure(bulk_failure.clone()) {
Ok(failure) => ErrorKind::WriteError(failure).into(),
Err(e) => e,
}
}
_ => error,
}
}