use std::{fmt, sync::Arc};
use err_derive::Error;
use lazy_static::lazy_static;
use serde::Deserialize;
use time::OutOfRangeError;
use crate::{bson::Document, options::StreamAddress};
lazy_static! {
    /// Server error codes that `is_recovering` classifies as "node is recovering".
    static ref RECOVERING_CODES: Vec<i32> = vec![11600, 11602, 13436, 189, 91];
    /// Server error codes that `is_not_master` classifies as "not master".
    static ref NOTMASTER_CODES: Vec<i32> = vec![10107, 13435];
    /// Server error codes that `ErrorKind::is_shutting_down` reports as a shutdown.
    static ref SHUTTING_DOWN_CODES: Vec<i32> = vec![11600, 91];
    /// Server error codes for which `Error::is_read_retryable` allows a retry.
    static ref RETRYABLE_READ_CODES: Vec<i32> =
        vec![11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001];
    /// Every retryable-read code followed by 262; consulted by
    /// `Error::should_add_retryable_write_label` when the max wire version is 8 or lower.
    static ref RETRYABLE_WRITE_CODES: Vec<i32> = {
        let mut codes = RETRYABLE_READ_CODES.to_vec();
        codes.push(262);
        codes
    };
}
/// Shorthand for `std::result::Result` whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// The error type carried by this module's `Result` alias.
///
/// The `kind` is held behind an `Arc`, so cloning an `Error` shares the underlying
/// [`ErrorKind`] rather than copying it. `Display` forwards to `kind`.
#[derive(Clone, Debug, Error)]
#[error(display = "{}", kind)]
#[non_exhaustive]
pub struct Error {
/// The category and details of this error.
pub kind: Arc<ErrorKind>,
/// Labels for kinds that have no label list of their own. `labels()` prefers the labels stored
/// on a command error or write-concern error when the kind has one.
labels: Vec<String>,
}
impl Error {
pub(crate) fn pool_cleared_error(address: &StreamAddress) -> Self {
ErrorKind::ConnectionPoolClearedError {
message: format!(
"Connection pool for {} cleared during operation execution",
address
),
}
.into()
}
pub(crate) fn authentication_error(mechanism_name: &str, reason: &str) -> Self {
ErrorKind::AuthenticationError {
message: format!("{} failure: {}", mechanism_name, reason),
}
.into()
}
pub(crate) fn unknown_authentication_error(mechanism_name: &str) -> Error {
Error::authentication_error(mechanism_name, "internal error")
}
pub(crate) fn invalid_authentication_response(mechanism_name: &str) -> Error {
Error::authentication_error(mechanism_name, "invalid server response")
}
pub(crate) fn is_ns_not_found(&self) -> bool {
matches!(self.kind.as_ref(), ErrorKind::CommandError(err) if err.code == 26)
}
pub(crate) fn is_read_retryable(&self) -> bool {
if self.is_network_error() {
return true;
}
match &self.kind.code_and_message() {
Some((code, message)) => {
if RETRYABLE_READ_CODES.contains(&code) {
return true;
}
if is_not_master(*code, message) || is_recovering(*code, message) {
return true;
}
false
}
None => false,
}
}
pub(crate) fn is_write_retryable(&self) -> bool {
self.contains_label("RetryableWriteError")
}
pub(crate) fn should_add_retryable_write_label(&self, max_wire_version: i32) -> bool {
if max_wire_version > 8 {
return self.is_network_error();
}
if self.is_network_error() {
return true;
}
match &self.kind.code_and_message() {
Some((code, _)) => RETRYABLE_WRITE_CODES.contains(&code),
None => false,
}
}
pub(crate) fn is_server_error(&self) -> bool {
matches!(
self.kind.as_ref(),
ErrorKind::AuthenticationError { .. }
| ErrorKind::BulkWriteError(_)
| ErrorKind::CommandError(_)
| ErrorKind::WriteError(_)
)
}
pub fn labels(&self) -> &[String] {
match self.kind.as_ref() {
ErrorKind::CommandError(err) => &err.labels,
ErrorKind::WriteError(err) => match err {
WriteFailure::WriteError(_) => &self.labels,
WriteFailure::WriteConcernError(err) => &err.labels,
},
ErrorKind::BulkWriteError(err) => match err.write_concern_error {
Some(ref err) => &err.labels,
None => &self.labels,
},
_ => &self.labels,
}
}
pub fn contains_label<T: AsRef<str>>(&self, label: T) -> bool {
self.labels()
.iter()
.any(|actual_label| actual_label.as_str() == label.as_ref())
}
pub(crate) fn with_label<T: AsRef<str>>(mut self, label: T) -> Self {
let label = label.as_ref().to_string();
match self.kind.as_ref() {
ErrorKind::CommandError(err) => {
let mut err = err.clone();
err.labels.push(label);
ErrorKind::CommandError(err).into()
}
ErrorKind::WriteError(err) => match err {
WriteFailure::WriteError(_) => {
self.labels.push(label);
self
}
WriteFailure::WriteConcernError(err) => {
let mut err = err.clone();
err.labels.push(label);
ErrorKind::WriteError(WriteFailure::WriteConcernError(err)).into()
}
},
ErrorKind::BulkWriteError(err) => match err.write_concern_error {
Some(ref write_concern_error) => {
let mut err = err.clone();
let mut write_concern_error = write_concern_error.clone();
write_concern_error.labels.push(label);
err.write_concern_error = Some(write_concern_error);
ErrorKind::BulkWriteError(err).into()
}
None => {
self.labels.push(label);
self
}
},
_ => {
self.labels.push(label);
self
}
}
}
}
impl<E> From<E> for Error
where
ErrorKind: From<E>,
{
fn from(err: E) -> Self {
Self {
kind: Arc::new(err.into()),
labels: Vec::new(),
}
}
}
/// Exposes the shared `Arc<ErrorKind>`, so `ErrorKind` methods can be called on an `Error`
/// directly.
impl std::ops::Deref for Error {
    type Target = Arc<ErrorKind>;

    fn deref(&self) -> &Arc<ErrorKind> {
        &self.kind
    }
}
/// The category of an [`Error`], together with the data describing it.
///
/// Each variant's `#[error(display = ...)]` attribute supplies its `Display` text.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ErrorKind {
/// Wraps a [`std::net::AddrParseError`].
#[error(display = "{}", _0)]
AddrParse(#[error(source)] std::net::AddrParseError),
/// An invalid argument was supplied to a database operation.
#[error(
display = "An invalid argument was provided to a database operation: {}",
message
)]
#[non_exhaustive]
ArgumentError { message: String },
/// An async-std timeout; only compiled with the `async-std-runtime` feature.
#[cfg(feature = "async-std-runtime")]
#[error(display = "{}", _0)]
AsyncStdTimeout(#[error(source)] async_std::future::TimeoutError),
/// Authentication failed. `Error::authentication_error` builds this with a message of the form
/// `"<mechanism> failure: <reason>"`.
#[error(display = "{}", message)]
#[non_exhaustive]
AuthenticationError { message: String },
/// Wraps a BSON decoding error (`bson::de::Error`).
#[error(display = "{}", _0)]
BsonDecode(#[error(source)] crate::bson::de::Error),
/// Wraps a BSON encoding error (`bson::ser::Error`).
#[error(display = "{}", _0)]
BsonEncode(#[error(source)] crate::bson::ser::Error),
/// A bulk write failed. The display text uses the `Debug` form of the failure.
#[error(
display = "An error occurred when trying to execute a write operation: {:?}",
_0
)]
BulkWriteError(BulkWriteFailure),
/// A command failed; code, code name, message, and labels are in [`CommandError`].
#[error(display = "Command failed {}", _0)]
CommandError(CommandError),
/// Wraps a trust-dns `ResolveError`; converted through the explicit `From` impl in this module.
#[error(display = "{}", _0)]
DnsResolve(trust_dns_resolver::error::ResolveError),
/// An internal error.
#[error(display = "Internal error: {}", message)]
#[non_exhaustive]
InternalError { message: String },
/// Wraps a `webpki::InvalidDNSNameError`.
#[error(display = "{}", _0)]
InvalidDnsName(#[error(source)] webpki::InvalidDNSNameError),
/// A hostname could not be parsed.
#[error(display = "Unable to parse hostname: {}", hostname)]
#[non_exhaustive]
InvalidHostname { hostname: String },
/// An I/O error. `is_network_error` treats every `Io` error as a network error, while
/// `is_non_timeout_network_error` excludes those of kind `TimedOut`.
#[error(display = "{}", _0)]
Io(#[error(source)] std::io::Error),
/// DNS lookup for the given address returned no results.
#[error(display = "No DNS results for domain {}", _0)]
NoDnsResults(StreamAddress),
/// A database operation failed to send or receive a reply.
#[error(
display = "A database operation failed to send or receive a reply: {}",
message
)]
#[non_exhaustive]
OperationError { message: String },
/// Wraps a `time::OutOfRangeError`.
#[error(display = "{}", _0)]
OutOfRangeError(#[error(source)] OutOfRangeError),
/// Data of kind `data_type` could not be parsed from the file at `file_path`.
#[error(display = "Unable to parse {} data from {}", data_type, file_path)]
#[non_exhaustive]
ParseError {
data_type: String,
file_path: String,
},
/// The connection pool was cleared during an operation (see `Error::pool_cleared_error`).
/// `is_network_error` classifies this as a network error.
#[error(display = "{}", message)]
#[non_exhaustive]
ConnectionPoolClearedError { message: String },
/// The server returned an invalid reply. `WriteFailure::from_bulk_failure` also produces this
/// when a bulk failure has neither write errors nor a write-concern error.
#[error(
display = "The server returned an invalid reply to a database operation: {}",
message
)]
#[non_exhaustive]
ResponseError { message: String },
/// A server-selection error; the message is displayed verbatim.
#[error(display = "{}", message)]
#[non_exhaustive]
ServerSelectionError { message: String },
/// SRV record lookup failed.
#[error(display = "An error occurred during SRV record lookup: {}", message)]
#[non_exhaustive]
SrvLookupError { message: String },
/// A tokio timeout elapsed; only compiled with the `tokio-runtime` feature.
#[cfg(feature = "tokio-runtime")]
#[error(display = "{}", _0)]
TokioTimeoutElapsed(#[error(source)] tokio::time::error::Elapsed),
/// Wraps a `rustls::TLSError`.
#[error(display = "{}", _0)]
RustlsConfig(#[error(source)] rustls::TLSError),
/// TXT record lookup failed.
#[error(display = "An error occurred during TXT record lookup: {}", message)]
#[non_exhaustive]
TxtLookupError { message: String },
/// Timed out checking out a connection from the pool for `address`.
#[error(
display = "Timed out while checking out a connection from connection pool with address {}",
address
)]
#[non_exhaustive]
WaitQueueTimeoutError { address: StreamAddress },
/// A single write failed. The display text uses the `Debug` form of the failure.
#[error(
display = "An error occurred when trying to execute a write operation: {:?}",
_0
)]
WriteError(WriteFailure),
}
impl From<trust_dns_resolver::error::ResolveError> for ErrorKind {
fn from(error: trust_dns_resolver::error::ResolveError) -> Self {
Self::DnsResolve(error)
}
}
impl ErrorKind {
pub(crate) fn is_non_timeout_network_error(&self) -> bool {
matches!(self, ErrorKind::Io(ref io_err) if io_err.kind() != std::io::ErrorKind::TimedOut)
}
pub(crate) fn is_network_error(&self) -> bool {
matches!(
self,
ErrorKind::Io(..) | ErrorKind::ConnectionPoolClearedError { .. }
)
}
pub(crate) fn code_and_message(&self) -> Option<(i32, &str)> {
match self {
ErrorKind::CommandError(ref cmd_err) => Some((cmd_err.code, cmd_err.message.as_str())),
ErrorKind::WriteError(WriteFailure::WriteConcernError(ref wc_err)) => {
Some((wc_err.code, wc_err.message.as_str()))
}
ErrorKind::BulkWriteError(ref bwe) => bwe
.write_concern_error
.as_ref()
.map(|wc_err| (wc_err.code, wc_err.message.as_str())),
_ => None,
}
}
#[cfg(test)]
pub(crate) fn code_name(&self) -> Option<&str> {
match self {
ErrorKind::CommandError(ref cmd_err) => Some(cmd_err.code_name.as_str()),
ErrorKind::WriteError(ref failure) => match failure {
WriteFailure::WriteConcernError(ref wce) => Some(wce.code_name.as_str()),
WriteFailure::WriteError(ref we) => we.code_name.as_deref(),
},
ErrorKind::BulkWriteError(ref bwe) => bwe
.write_concern_error
.as_ref()
.map(|wce| wce.code_name.as_str()),
_ => None,
}
}
pub(crate) fn is_not_master(&self) -> bool {
self.code_and_message()
.map(|(code, msg)| is_not_master(code, msg))
.unwrap_or(false)
}
pub(crate) fn is_recovering(&self) -> bool {
self.code_and_message()
.map(|(code, msg)| is_recovering(code, msg))
.unwrap_or(false)
}
pub(crate) fn is_shutting_down(&self) -> bool {
self.code_and_message()
.map(|(code, _)| SHUTTING_DOWN_CODES.contains(&code))
.unwrap_or(false)
}
}
/// Classifies a server error as "not master".
///
/// A code in `NOTMASTER_CODES` is decisive. Otherwise anything `is_recovering` accepts is ruled
/// out first: its "not master or secondary" text also contains "not master". Only after that is
/// the message searched for "not master".
fn is_not_master(code: i32, message: &str) -> bool {
    NOTMASTER_CODES.contains(&code)
        || (!is_recovering(code, message) && message.contains("not master"))
}
/// Classifies a server error as "node is recovering".
///
/// The check passes in either case:
/// - the code is in `RECOVERING_CODES`;
/// - the message mentions "not master or secondary" or "node is recovering".
fn is_recovering(code: i32, message: &str) -> bool {
    const RECOVERING_MESSAGES: [&str; 2] = ["not master or secondary", "node is recovering"];
    RECOVERING_CODES.contains(&code)
        || RECOVERING_MESSAGES
            .iter()
            .any(|needle| message.contains(*needle))
}
/// A failed command as reported in a server reply.
#[derive(Clone, Debug, Deserialize)]
#[non_exhaustive]
pub struct CommandError {
/// The numeric error code.
pub code: i32,
/// The symbolic name of the code (`codeName`); empty when the reply omits it.
#[serde(rename = "codeName", default)]
pub code_name: String,
/// The error message (`errmsg`).
#[serde(rename = "errmsg")]
pub message: String,
/// Error labels (`errorLabels`); empty when the reply omits them.
#[serde(rename = "errorLabels", default)]
pub labels: Vec<String>,
}
/// Formats as `(<code name>): <message>`, with balanced parentheses.
impl fmt::Display for CommandError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "({}): {}", self.code_name, self.message)
    }
}
/// A write-concern error from a server reply.
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[non_exhaustive]
pub struct WriteConcernError {
/// The numeric error code.
pub code: i32,
/// The symbolic name of the code (`codeName`); empty when the reply omits it.
#[serde(rename = "codeName", default)]
pub code_name: String,
/// The error message (`errmsg`).
#[serde(rename = "errmsg")]
pub message: String,
/// Additional error information (`errInfo`), if present.
#[serde(rename = "errInfo")]
pub details: Option<Document>,
/// Error labels (`errorLabels`); empty when the reply omits them.
#[serde(rename = "errorLabels", default)]
pub labels: Vec<String>,
}
/// An error from a single write.
///
/// Not deserialized directly: `WriteFailure::from_bulk_failure` builds it from the first
/// [`BulkWriteError`] of a bulk failure, dropping that error's `index`.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct WriteError {
/// The numeric error code.
pub code: i32,
/// The symbolic name of the code, if one was reported.
pub code_name: Option<String>,
/// The error message.
pub message: String,
}
/// An error for one write within a bulk write.
#[derive(Debug, PartialEq, Clone, Deserialize)]
#[non_exhaustive]
pub struct BulkWriteError {
/// Presumably the position of the failed write within the submitted batch — TODO confirm.
pub index: usize,
/// The numeric error code.
pub code: i32,
/// The symbolic name of the code (`codeName`), if present.
#[serde(rename = "codeName", default)]
pub code_name: Option<String>,
/// The error message (`errmsg`).
#[serde(rename = "errmsg")]
pub message: String,
}
/// The errors from a failed bulk write: per-write errors and/or a write-concern error.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BulkWriteFailure {
/// Per-write errors, if any were reported.
pub write_errors: Option<Vec<BulkWriteError>>,
/// The write-concern error, if one was reported. When present, `Error::labels` reads its labels.
pub write_concern_error: Option<WriteConcernError>,
}
impl BulkWriteFailure {
pub(crate) fn new() -> Self {
BulkWriteFailure {
write_errors: None,
write_concern_error: None,
}
}
}
/// The failure of a single write: a write-concern error or a write error.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum WriteFailure {
/// A write-concern error.
WriteConcernError(WriteConcernError),
/// A write error.
WriteError(WriteError),
}
impl WriteFailure {
fn from_bulk_failure(bulk: BulkWriteFailure) -> Result<Self> {
if let Some(bulk_write_error) = bulk.write_errors.and_then(|es| es.into_iter().next()) {
let write_error = WriteError {
code: bulk_write_error.code,
code_name: bulk_write_error.code_name,
message: bulk_write_error.message,
};
Ok(WriteFailure::WriteError(write_error))
} else if let Some(wc_error) = bulk.write_concern_error {
Ok(WriteFailure::WriteConcernError(wc_error))
} else {
Err(ErrorKind::ResponseError {
message: "error missing write errors and write concern errors".to_string(),
}
.into())
}
}
}
pub(crate) fn convert_bulk_errors(error: Error) -> Error {
match *error.kind {
ErrorKind::BulkWriteError(ref bulk_failure) => {
match WriteFailure::from_bulk_failure(bulk_failure.clone()) {
Ok(failure) => ErrorKind::WriteError(failure).into(),
Err(e) => e,
}
}
_ => error,
}
}