use std::{fmt, sync::Arc};
use serde::Deserialize;
use thiserror::Error;
use crate::{bson::Document, options::StreamAddress};
/// Error codes treated as "node is recovering" state-change errors (checked by
/// `ErrorKind::is_recovering`).
///
/// Code names per the MongoDB server's error_codes.yml (NOTE(review): verify against the
/// server version in use): 11600 InterruptedAtShutdown, 11602 InterruptedDueToReplStateChange,
/// 13436 NotMasterOrSecondary, 189 PrimarySteppedDown, 91 ShutdownInProgress.
const RECOVERING_CODES: [i32; 5] = [11600, 11602, 13436, 189, 91];
/// Error codes treated as "not master" state-change errors (checked by
/// `ErrorKind::is_not_master`).
///
/// Code names (NOTE(review): verify): 10107 NotMaster, 13435 NotMasterNoSlaveOk,
/// 10058 LegacyNotPrimary.
const NOTMASTER_CODES: [i32; 3] = [10107, 13435, 10058];
/// The subset of the recovering codes checked by `ErrorKind::is_shutting_down`.
///
/// Code names (NOTE(review): verify): 11600 InterruptedAtShutdown, 91 ShutdownInProgress.
const SHUTTING_DOWN_CODES: [i32; 2] = [11600, 91];
/// Codes for which `Error::is_read_retryable` reports a failed read as retryable.
///
/// Contains every recovering code and every not-master code except 10058, plus
/// (NOTE(review): names to verify) 7 HostNotFound, 6 HostUnreachable, 89 NetworkTimeout and
/// 9001 SocketException.
const RETRYABLE_READ_CODES: [i32; 11] =
    [11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001];
/// Codes for which `Error::should_add_retryable_write_label` adds the retryable-write label.
/// The codes are only consulted when `max_wire_version <= 8`.
///
/// Same as `RETRYABLE_READ_CODES` plus 262 (ExceededTimeLimit, NOTE(review): verify).
const RETRYABLE_WRITE_CODES: [i32; 12] = [
    11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001, 262,
];
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// The error type used by this module's [`Result`] alias.
///
/// Pairs an [`ErrorKind`] with a list of error labels and displays exactly as its `kind`.
/// `Error` derefs to its `ErrorKind`, so `ErrorKind` methods can be called on it directly.
#[derive(Clone, Debug, Error)]
#[error("{kind}")]
#[non_exhaustive]
pub struct Error {
    /// The category of the error, together with any data specific to that category.
    pub kind: ErrorKind,
    /// Labels stored directly on this error.
    ///
    /// For command errors, and for errors that carry a write concern error, `labels()` reads
    /// the inner error's labels instead of this list.
    labels: Vec<String>,
}
impl Error {
pub(crate) fn pool_cleared_error(address: &StreamAddress) -> Self {
ErrorKind::ConnectionPoolClearedError {
message: format!(
"Connection pool for {} cleared during operation execution",
address
),
}
.into()
}
pub(crate) fn authentication_error(mechanism_name: &str, reason: &str) -> Self {
ErrorKind::AuthenticationError {
message: format!("{} failure: {}", mechanism_name, reason),
}
.into()
}
pub(crate) fn unknown_authentication_error(mechanism_name: &str) -> Error {
Error::authentication_error(mechanism_name, "internal error")
}
pub(crate) fn invalid_authentication_response(mechanism_name: &str) -> Error {
Error::authentication_error(mechanism_name, "invalid server response")
}
pub(crate) fn is_state_change_error(&self) -> bool {
self.is_recovering() || self.is_not_master()
}
pub(crate) fn is_auth_error(&self) -> bool {
matches!(self.kind, ErrorKind::AuthenticationError { .. })
}
pub(crate) fn is_command_error(&self) -> bool {
matches!(self.kind, ErrorKind::CommandError(_))
}
pub(crate) fn is_network_timeout(&self) -> bool {
matches!(self.kind, ErrorKind::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::TimedOut)
}
pub(crate) fn is_ns_not_found(&self) -> bool {
matches!(self.kind, ErrorKind::CommandError(ref err) if err.code == 26)
}
pub(crate) fn is_read_retryable(&self) -> bool {
if self.is_network_error() {
return true;
}
match &self.kind.code() {
Some(code) => {
RETRYABLE_READ_CODES.contains(&code)
}
None => false,
}
}
pub(crate) fn is_write_retryable(&self) -> bool {
self.contains_label("RetryableWriteError")
}
pub(crate) fn should_add_retryable_write_label(&self, max_wire_version: i32) -> bool {
if max_wire_version > 8 {
return self.is_network_error();
}
if self.is_network_error() {
return true;
}
match &self.kind.code() {
Some(code) => RETRYABLE_WRITE_CODES.contains(&code),
None => false,
}
}
pub(crate) fn is_server_error(&self) -> bool {
matches!(
self.kind,
ErrorKind::AuthenticationError { .. }
| ErrorKind::BulkWriteError(_)
| ErrorKind::CommandError(_)
| ErrorKind::WriteError(_)
)
}
pub fn labels(&self) -> &[String] {
match self.kind {
ErrorKind::CommandError(ref err) => &err.labels,
ErrorKind::WriteError(ref err) => match err {
WriteFailure::WriteError(_) => &self.labels,
WriteFailure::WriteConcernError(ref err) => &err.labels,
},
ErrorKind::BulkWriteError(ref err) => match err.write_concern_error {
Some(ref err) => &err.labels,
None => &self.labels,
},
_ => &self.labels,
}
}
pub fn contains_label<T: AsRef<str>>(&self, label: T) -> bool {
self.labels()
.iter()
.any(|actual_label| actual_label.as_str() == label.as_ref())
}
pub(crate) fn with_label<T: AsRef<str>>(mut self, label: T) -> Self {
let label = label.as_ref().to_string();
match self.kind {
ErrorKind::CommandError(ref err) => {
let mut err = err.clone();
err.labels.push(label);
ErrorKind::CommandError(err).into()
}
ErrorKind::WriteError(ref err) => match err {
WriteFailure::WriteError(_) => {
self.labels.push(label);
self
}
WriteFailure::WriteConcernError(ref err) => {
let mut err = err.clone();
err.labels.push(label);
ErrorKind::WriteError(WriteFailure::WriteConcernError(err)).into()
}
},
ErrorKind::BulkWriteError(ref err) => match err.write_concern_error {
Some(ref write_concern_error) => {
let mut err = err.clone();
let mut write_concern_error = write_concern_error.clone();
write_concern_error.labels.push(label);
err.write_concern_error = Some(write_concern_error);
ErrorKind::BulkWriteError(err).into()
}
None => {
self.labels.push(label);
self
}
},
_ => {
self.labels.push(label);
self
}
}
}
}
impl<E> From<E> for Error
where
ErrorKind: From<E>,
{
fn from(err: E) -> Self {
Self {
kind: err.into(),
labels: Vec::new(),
}
}
}
impl From<bson::de::Error> for ErrorKind {
fn from(err: bson::de::Error) -> Self {
Self::BsonDecode(Arc::new(err))
}
}
impl From<bson::ser::Error> for ErrorKind {
fn from(err: bson::ser::Error) -> Self {
Self::BsonEncode(Arc::new(err))
}
}
impl From<std::io::Error> for ErrorKind {
fn from(err: std::io::Error) -> Self {
Self::Io(Arc::new(err))
}
}
impl From<std::io::ErrorKind> for ErrorKind {
fn from(err: std::io::ErrorKind) -> Self {
Self::Io(Arc::new(err.into()))
}
}
impl std::ops::Deref for Error {
    type Target = ErrorKind;

    /// Exposes the wrapped [`ErrorKind`], so its methods can be called on an `Error` directly.
    fn deref(&self) -> &ErrorKind {
        &self.kind
    }
}
/// The categories of [`Error`].
///
/// Variants that wrap non-`Clone` error types (BSON and I/O errors) hold them in an `Arc`, so
/// that `ErrorKind` can derive `Clone`.
#[allow(missing_docs)]
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum ErrorKind {
    /// A network address failed to parse (convertible from `std::net::AddrParseError`).
    #[error("{0}")]
    AddrParse(#[from] std::net::AddrParseError),
    /// An invalid argument was passed to a database operation.
    #[error("An invalid argument was provided to a database operation: {message}")]
    #[non_exhaustive]
    ArgumentError { message: String },
    /// Authentication failed. `Error::authentication_error` formats the message as
    /// `"<mechanism> failure: <reason>"`.
    #[error("{message}")]
    #[non_exhaustive]
    AuthenticationError { message: String },
    /// BSON deserialization failed.
    #[error("{0}")]
    BsonDecode(Arc<crate::bson::de::Error>),
    /// BSON serialization failed.
    #[error("{0}")]
    BsonEncode(Arc<crate::bson::ser::Error>),
    /// A write failed with write errors and/or a write concern error. Displayed using the
    /// failure's `Debug` output.
    #[error("An error occurred when trying to execute a write operation: {0:?}")]
    BulkWriteError(BulkWriteFailure),
    /// A command failed; code, code name, message and labels are in the [`CommandError`].
    #[error("Command failed {0}")]
    CommandError(CommandError),
    /// DNS resolution failed in `trust_dns_resolver`.
    #[error("{0}")]
    DnsResolve(trust_dns_resolver::error::ResolveError),
    /// An internal error.
    #[error("Internal error: {message}")]
    #[non_exhaustive]
    InternalError { message: String },
    /// `webpki` rejected a name as an invalid DNS name.
    #[error("{0}")]
    InvalidDnsName(#[from] webpki::InvalidDNSNameError),
    /// A hostname could not be parsed.
    #[error("Unable to parse hostname: {hostname}")]
    #[non_exhaustive]
    InvalidHostname { hostname: String },
    /// An I/O error. `is_network_error` treats every `Io` error as a network error.
    #[error("{0}")]
    Io(Arc<std::io::Error>),
    /// A DNS lookup for the given address produced no results.
    #[error("No DNS results for domain {0}")]
    NoDnsResults(StreamAddress),
    /// A database operation failed to send a request or receive a reply.
    #[error("A database operation failed to send or receive a reply: {message}")]
    #[non_exhaustive]
    OperationError { message: String },
    /// Data of type `data_type` could not be parsed from the file at `file_path`.
    #[error("Unable to parse {data_type} data from {file_path}")]
    #[non_exhaustive]
    ParseError {
        data_type: String,
        file_path: String,
    },
    /// A connection pool was cleared while an operation was executing (built by
    /// `Error::pool_cleared_error`). `is_network_error` counts this as a network error.
    #[error("{message}")]
    #[non_exhaustive]
    ConnectionPoolClearedError { message: String },
    /// A reply to a database operation was invalid.
    #[error("The server returned an invalid reply to a database operation: {message}")]
    #[non_exhaustive]
    ResponseError { message: String },
    /// Server selection failed.
    #[error("{message}")]
    #[non_exhaustive]
    ServerSelectionError { message: String },
    /// An SRV record lookup failed.
    #[error("An error occurred during SRV record lookup: {message}")]
    #[non_exhaustive]
    SrvLookupError { message: String },
    /// A session was requested from a deployment that does not support sessions.
    #[error("Attempted to start a session on a deployment that does not support sessions")]
    SessionsNotSupported,
    /// TLS configuration error from `rustls`.
    #[error("{0}")]
    RustlsConfig(#[from] rustls::TLSError),
    /// A TXT record lookup failed.
    #[error("An error occurred during TXT record lookup: {message}")]
    #[non_exhaustive]
    TxtLookupError { message: String },
    /// Checking out a connection from the pool for `address` timed out.
    #[error(
        "Timed out while checking out a connection from connection pool with address {address}"
    )]
    #[non_exhaustive]
    WaitQueueTimeoutError { address: StreamAddress },
    /// A single write failure (a write error or a write concern error). Displayed using the
    /// failure's `Debug` output.
    #[error("An error occurred when trying to execute a write operation: {0:?}")]
    WriteError(WriteFailure),
}
impl From<trust_dns_resolver::error::ResolveError> for ErrorKind {
fn from(error: trust_dns_resolver::error::ResolveError) -> Self {
Self::DnsResolve(error)
}
}
impl ErrorKind {
    /// Returns `true` for I/O errors whose kind is anything other than `TimedOut`.
    pub(crate) fn is_non_timeout_network_error(&self) -> bool {
        if let ErrorKind::Io(io_err) = self {
            io_err.kind() != std::io::ErrorKind::TimedOut
        } else {
            false
        }
    }

    /// Returns `true` for I/O errors and for connection-pool-cleared errors.
    pub(crate) fn is_network_error(&self) -> bool {
        matches!(
            self,
            ErrorKind::ConnectionPoolClearedError { .. } | ErrorKind::Io(_)
        )
    }

    /// Returns the numeric error code carried by this kind, if any.
    ///
    /// Command errors use their own code. Write errors and bulk write errors expose a code only
    /// through their write concern error.
    pub(crate) fn code(&self) -> Option<i32> {
        match self {
            ErrorKind::CommandError(cmd_err) => Some(cmd_err.code),
            ErrorKind::WriteError(WriteFailure::WriteConcernError(wce))
            | ErrorKind::BulkWriteError(BulkWriteFailure {
                write_concern_error: Some(wce),
                ..
            }) => Some(wce.code),
            _ => None,
        }
    }

    /// Test helper: the message(s) carried by command and write errors.
    ///
    /// For bulk write errors, the write concern message comes first, followed by each write
    /// error's message, concatenated with no separator.
    #[cfg(test)]
    pub(crate) fn server_message(&self) -> Option<String> {
        match self {
            ErrorKind::CommandError(cmd_err) => Some(cmd_err.message.clone()),
            ErrorKind::BulkWriteError(BulkWriteFailure {
                write_concern_error,
                write_errors,
            }) => {
                let wce_messages = write_concern_error.iter().map(|wce| wce.message.as_str());
                let write_error_messages = write_errors
                    .iter()
                    .flatten()
                    .map(|we| we.message.as_str());
                Some(wce_messages.chain(write_error_messages).collect())
            }
            ErrorKind::WriteError(WriteFailure::WriteConcernError(wce)) => {
                Some(wce.message.clone())
            }
            ErrorKind::WriteError(WriteFailure::WriteError(we)) => Some(we.message.clone()),
            _ => None,
        }
    }

    /// Test helper: the symbolic code name carried by command and write errors, if any.
    #[cfg(test)]
    pub(crate) fn code_name(&self) -> Option<&str> {
        match self {
            ErrorKind::CommandError(cmd_err) => Some(cmd_err.code_name.as_str()),
            ErrorKind::WriteError(WriteFailure::WriteConcernError(wce))
            | ErrorKind::BulkWriteError(BulkWriteFailure {
                write_concern_error: Some(wce),
                ..
            }) => Some(wce.code_name.as_str()),
            ErrorKind::WriteError(WriteFailure::WriteError(we)) => we.code_name.as_deref(),
            _ => None,
        }
    }

    /// Returns `true` if the error code is one of `NOTMASTER_CODES`.
    pub(crate) fn is_not_master(&self) -> bool {
        matches!(self.code(), Some(code) if NOTMASTER_CODES.contains(&code))
    }

    /// Returns `true` if the error code is one of `RECOVERING_CODES`.
    pub(crate) fn is_recovering(&self) -> bool {
        matches!(self.code(), Some(code) if RECOVERING_CODES.contains(&code))
    }

    /// Returns `true` if the error code is one of `SHUTTING_DOWN_CODES`.
    pub(crate) fn is_shutting_down(&self) -> bool {
        matches!(self.code(), Some(code) if SHUTTING_DOWN_CODES.contains(&code))
    }
}
/// A failed command, deserialized from a reply document.
///
/// NOTE(review): presumably the server's error reply — confirm against callers.
#[derive(Clone, Debug, Deserialize)]
#[non_exhaustive]
pub struct CommandError {
    /// Numeric error code (`code` key). `Error::is_ns_not_found` checks it against 26.
    pub code: i32,
    /// Symbolic error code name (`codeName` key); an empty string if the key is absent.
    #[serde(rename = "codeName", default)]
    pub code_name: String,
    /// Error message (`errmsg` key).
    #[serde(rename = "errmsg")]
    pub message: String,
    /// Error labels (`errorLabels` key); empty if the key is absent. `Error::labels` returns
    /// these for command errors.
    #[serde(rename = "errorLabels", default)]
    pub labels: Vec<String>,
}
impl fmt::Display for CommandError {
    /// Formats as `(<code_name>): <message>`.
    ///
    /// The previous format string had an unmatched trailing `)`, which appended a stray
    /// parenthesis to every command error message.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "({}): {}", self.code_name, self.message)
    }
}
/// A write concern error, deserialized from a reply document.
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[non_exhaustive]
pub struct WriteConcernError {
    /// Numeric error code (`code` key); used by `ErrorKind::code` for write and bulk write errors.
    pub code: i32,
    /// Symbolic error code name (`codeName` key); an empty string if the key is absent.
    #[serde(rename = "codeName", default)]
    pub code_name: String,
    /// Error message (`errmsg` key).
    #[serde(rename = "errmsg")]
    pub message: String,
    /// Additional details (`errInfo` key), if present.
    #[serde(rename = "errInfo")]
    pub details: Option<Document>,
    /// Error labels (`errorLabels` key); empty if the key is absent. `Error::labels` returns
    /// these whenever the error carries a write concern error.
    #[serde(rename = "errorLabels", default)]
    pub labels: Vec<String>,
}
/// A single write error.
///
/// In this module it is only built by `WriteFailure::from_bulk_failure`, from the first
/// [`BulkWriteError`].
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct WriteError {
    /// Numeric error code.
    pub code: i32,
    /// Symbolic error code name, if one was provided.
    pub code_name: Option<String>,
    /// Error message.
    pub message: String,
}
/// One entry of a bulk write's write errors, deserialized from a reply document.
#[derive(Debug, PartialEq, Clone, Deserialize)]
#[non_exhaustive]
pub struct BulkWriteError {
    /// `index` key; 0 if absent. NOTE(review): presumably the position of the failed write
    /// within the batch — confirm.
    #[serde(default)]
    pub index: usize,
    /// Numeric error code (`code` key).
    pub code: i32,
    /// Symbolic error code name (`codeName` key); `None` if absent.
    #[serde(rename = "codeName", default)]
    pub code_name: Option<String>,
    /// Error message (`errmsg` key).
    #[serde(rename = "errmsg")]
    pub message: String,
}
/// The failures of a write operation: any write errors and/or a write concern error.
/// Deserialized from the `writeErrors` and `writeConcernError` keys (camelCase rename).
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct BulkWriteFailure {
    /// Individual write errors, if any.
    pub write_errors: Option<Vec<BulkWriteError>>,
    /// The write concern error, if any.
    pub write_concern_error: Option<WriteConcernError>,
}
impl BulkWriteFailure {
pub(crate) fn new() -> Self {
BulkWriteFailure {
write_errors: None,
write_concern_error: None,
}
}
}
/// A single write failure: either a write concern error or an individual write error.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum WriteFailure {
    /// The write concern could not be satisfied.
    WriteConcernError(WriteConcernError),
    /// An individual write failed.
    WriteError(WriteError),
}
impl WriteFailure {
fn from_bulk_failure(bulk: BulkWriteFailure) -> Result<Self> {
if let Some(bulk_write_error) = bulk.write_errors.and_then(|es| es.into_iter().next()) {
let write_error = WriteError {
code: bulk_write_error.code,
code_name: bulk_write_error.code_name,
message: bulk_write_error.message,
};
Ok(WriteFailure::WriteError(write_error))
} else if let Some(wc_error) = bulk.write_concern_error {
Ok(WriteFailure::WriteConcernError(wc_error))
} else {
Err(ErrorKind::ResponseError {
message: "error missing write errors and write concern errors".to_string(),
}
.into())
}
}
}
pub(crate) fn convert_bulk_errors(error: Error) -> Error {
match error.kind {
ErrorKind::BulkWriteError(ref bulk_failure) => {
match WriteFailure::from_bulk_failure(bulk_failure.clone()) {
Ok(failure) => ErrorKind::WriteError(failure).into(),
Err(e) => e,
}
}
_ => error,
}
}