use std::{error::Error as StdError, fmt, num::ParseIntError, sync::Arc};
use as_variant::as_variant;
use bytes::{BufMut, Bytes};
use serde::{Deserialize, Serialize};
use serde_json::{Value as JsonValue, from_slice as from_json_slice};
use thiserror::Error;
mod kind;
mod kind_serde;
#[cfg(test)]
mod tests;
pub use self::kind::*;
use super::{EndpointError, MatrixVersion, OutgoingResponse};
/// An HTTP error response: a status code together with the response body.
///
/// `EndpointError` is implemented to build this from a received `http::Response`, and
/// `OutgoingResponse` to turn it back into one.
#[derive(Clone, Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct Error {
/// The HTTP status code of the response.
pub status_code: http::StatusCode,
/// The body of the response; see `ErrorBody` for the possible shapes.
pub body: ErrorBody,
}
impl Error {
pub fn new(status_code: http::StatusCode, body: ErrorBody) -> Self {
Self { status_code, body }
}
pub fn error_kind(&self) -> Option<&ErrorKind> {
as_variant!(&self.body, ErrorBody::Standard(StandardErrorBody { kind, .. }) => kind)
}
pub fn is_endpoint_not_implemented(&self) -> bool {
self.status_code == http::StatusCode::NOT_FOUND
&& self
.error_kind()
.is_some_and(|error_kind| matches!(error_kind, ErrorKind::Unrecognized))
}
}
impl fmt::Display for Error {
    /// Formats as `[<status>] <body>`; standard bodies also show their `errcode` and message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = self.status_code.as_u16();
        match &self.body {
            ErrorBody::Json(json) => write!(f, "[{code}] {json}"),
            ErrorBody::NotJson { .. } => write!(f, "[{code}] <non-json bytes>"),
            ErrorBody::Standard(body) => {
                write!(f, "[{code} / {}] {}", body.kind.errcode(), body.message)
            }
        }
    }
}

impl StdError for Error {}
impl OutgoingResponse for Error {
    /// Serializes this error into a JSON HTTP response whose status is `self.status_code`.
    ///
    /// A `LimitExceeded` error that carries a retry delay also gets a `Retry-After` header.
    ///
    /// # Errors
    ///
    /// Fails if the retry delay cannot be turned into a header value, if the body cannot be
    /// serialized (always the case for `ErrorBody::NotJson`), or if building the response
    /// fails.
    fn try_into_http_response<T: Default + BufMut>(
        self,
    ) -> Result<http::Response<T>, IntoHttpError> {
        let status = self.status_code;

        // Resolved before `self.body` is consumed below, because `error_kind` borrows `self`.
        let retry_after_value = match self.error_kind() {
            Some(ErrorKind::LimitExceeded(LimitExceededErrorData {
                retry_after: Some(retry_after),
            })) => Some(http::HeaderValue::try_from(retry_after)?),
            _ => None,
        };

        let body: T = match self.body {
            ErrorBody::Standard(standard_body) => {
                ruma_common::serde::json_to_buf(&standard_body)?
            }
            ErrorBody::Json(json) => ruma_common::serde::json_to_buf(&json)?,
            ErrorBody::NotJson { .. } => {
                return Err(IntoHttpError::Json(serde::ser::Error::custom(
                    "attempted to serialize ErrorBody::NotJson",
                )));
            }
        };

        let mut builder = http::Response::builder()
            .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
            .status(status);
        if let Some(value) = retry_after_value {
            builder = builder.header(http::header::RETRY_AFTER, value);
        }

        Ok(builder.body(body)?)
    }
}
impl EndpointError for Error {
fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
let status = response.status();
let body_bytes = &response.body().as_ref();
let error_body: ErrorBody = match from_json_slice::<StandardErrorBody>(body_bytes) {
Ok(mut standard_body) => {
let headers = response.headers();
if let ErrorKind::LimitExceeded(LimitExceededErrorData { retry_after }) =
&mut standard_body.kind
{
if let Some(Ok(retry_after_header)) =
headers.get(http::header::RETRY_AFTER).map(RetryAfter::try_from)
{
*retry_after = Some(retry_after_header);
}
}
ErrorBody::Standard(standard_body)
}
Err(_) => match from_json_slice(body_bytes) {
Ok(json) => ErrorBody::Json(json),
Err(error) => ErrorBody::NotJson {
bytes: Bytes::copy_from_slice(body_bytes),
deserialization_error: Arc::new(error),
},
},
};
error_body.into_error(status)
}
}
/// The body of an error response.
///
/// When a received response is parsed, the body is tried as `Standard` first, then as
/// arbitrary `Json`, and otherwise ends up as `NotJson`.
#[derive(Debug, Clone)]
// The clippy lint against exhaustive public enums is silenced because this enum is
// intentionally left without `non_exhaustive`.
#[allow(clippy::exhaustive_enums)]
pub enum ErrorBody {
/// A body in the standard error format.
Standard(StandardErrorBody),
/// A body that is valid JSON but does not deserialize as a `StandardErrorBody`.
Json(JsonValue),
/// A body that could not be deserialized as JSON at all.
NotJson {
/// The raw bytes of the body.
bytes: Bytes,
/// The error produced when deserializing the body as JSON.
///
/// Held in an `Arc` because `ErrorBody` derives `Clone` and `serde_json::Error` does
/// not implement `Clone`.
deserialization_error: Arc<serde_json::Error>,
},
}
impl ErrorBody {
pub fn into_error(self, status_code: http::StatusCode) -> Error {
Error { status_code, body: self }
}
}
/// An error body in the standard format: an error kind plus an error message.
///
/// In JSON, the fields of `kind` are flattened into the top-level object, and `message` is
/// stored under the `error` key.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct StandardErrorBody {
/// The kind of error, (de)serialized flattened into the surrounding object.
#[serde(flatten)]
pub kind: ErrorKind,
/// The error message, (de)serialized under the JSON key `error`.
#[serde(rename = "error")]
pub message: String,
}
impl StandardErrorBody {
pub fn new(kind: ErrorKind, message: String) -> Self {
Self { kind, message }
}
}
/// An error when converting a value into an HTTP request or response.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum IntoHttpError {
/// Adding the authentication scheme failed; holds the underlying error.
#[error("failed to add authentication scheme: {0}")]
Authentication(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The endpoint is not supported by the server-reported versions, and no unstable path
/// is defined to fall back to.
#[error(
"endpoint was not supported by server-reported versions, \
but no unstable path to fall back to was defined"
)]
NoUnstablePath,
/// No path could be created because the endpoint was removed in the contained version.
///
/// # Panics
///
/// Formatting this variant panics if `as_str()` returns `None` for the contained version.
/// Per the `expect` message, that is not expected to happen, because no endpoint was
/// removed in Matrix 1.0.
#[error(
"could not create any path variant for endpoint, as it was removed in version {}",
.0.as_str().expect("no endpoint was removed in Matrix 1.0")
)]
EndpointRemoved(MatrixVersion),
/// JSON serialization failed.
#[error("JSON serialization failed: {0}")]
Json(#[from] serde_json::Error),
/// Serializing the query parameters failed.
#[error("query parameter serialization failed: {0}")]
Query(#[from] serde_html_form::ser::Error),
/// Serializing a header failed.
#[error("header serialization failed: {0}")]
Header(#[from] HeaderSerializationError),
/// Building the value with the `http` crate failed.
///
/// The message says "request", but response-builder errors (e.g. from
/// `OutgoingResponse::try_into_http_response` for `Error`) are reported through this
/// variant too.
#[error("HTTP request construction failed: {0}")]
Http(#[from] http::Error),
}
impl From<http::header::InvalidHeaderValue> for IntoHttpError {
fn from(value: http::header::InvalidHeaderValue) -> Self {
Self::Header(value.into())
}
}
/// An error when converting from an HTTP request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FromHttpRequestError {
/// Deserializing the request failed.
///
/// There is no `#[from]` here: the blanket `impl<T: Into<DeserializationError>> From<T>`
/// for this type already covers `DeserializationError`, and a second impl would conflict.
#[error("deserialization failed: {0}")]
Deserialization(DeserializationError),
/// The request used a different HTTP method than the expected one.
#[error("http method mismatch: expected {expected}, received: {received}")]
MethodMismatch {
/// The expected HTTP method.
expected: http::method::Method,
/// The HTTP method of the received request.
received: http::method::Method,
},
}
impl<T> From<T> for FromHttpRequestError
where
T: Into<DeserializationError>,
{
fn from(err: T) -> Self {
Self::Deserialization(err.into())
}
}
/// An error when converting from an HTTP response.
///
/// `E` is the type used for errors returned by the server.
#[derive(Debug)]
#[non_exhaustive]
pub enum FromHttpResponseError<E> {
/// Deserializing the response failed.
Deserialization(DeserializationError),
/// The server returned an error.
Server(E),
}
impl<E> FromHttpResponseError<E> {
    /// Converts the server error with `f`, passing a deserialization error through unchanged.
    pub fn map<F>(self, f: impl FnOnce(E) -> F) -> FromHttpResponseError<F> {
        match self {
            Self::Server(server_error) => FromHttpResponseError::Server(f(server_error)),
            Self::Deserialization(de_error) => FromHttpResponseError::Deserialization(de_error),
        }
    }
}

impl<E, F> FromHttpResponseError<Result<E, F>> {
    /// Moves the `Result` in the server-error slot to the outside.
    ///
    /// - `Server(Ok(e))` becomes `Ok(Server(e))`;
    /// - `Server(Err(f))` becomes `Err(f)`;
    /// - `Deserialization(d)` becomes `Ok(Deserialization(d))`.
    pub fn transpose(self) -> Result<FromHttpResponseError<E>, F> {
        match self {
            Self::Server(Ok(server_error)) => Ok(FromHttpResponseError::Server(server_error)),
            Self::Server(Err(other)) => Err(other),
            Self::Deserialization(de_error) => {
                Ok(FromHttpResponseError::Deserialization(de_error))
            }
        }
    }
}
impl<E: fmt::Display> fmt::Display for FromHttpResponseError<E> {
    /// Prefixes the inner error's message with a short description of where it came from.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Server(server_error) => {
                write!(f, "the server returned an error: {server_error}")
            }
            Self::Deserialization(de_error) => write!(f, "deserialization failed: {de_error}"),
        }
    }
}
impl<E, T> From<T> for FromHttpResponseError<E>
where
T: Into<DeserializationError>,
{
fn from(err: T) -> Self {
Self::Deserialization(err.into())
}
}
impl<E: StdError> StdError for FromHttpResponseError<E> {}
/// Extension trait for reading the [`ErrorKind`] out of a [`FromHttpResponseError`].
pub trait FromHttpResponseErrorExt {
    /// Returns the [`ErrorKind`] if this is a server error with a standard error body.
    fn error_kind(&self) -> Option<&ErrorKind>;
}

impl FromHttpResponseErrorExt for FromHttpResponseError<Error> {
    fn error_kind(&self) -> Option<&ErrorKind> {
        match self {
            Self::Server(server_error) => server_error.error_kind(),
            Self::Deserialization(_) => None,
        }
    }
}
/// An error when deserializing an HTTP request or response.
///
/// Every variant is `#[error(transparent)]`, so `Display` and `source` forward to the
/// wrapped error.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DeserializationError {
/// Bytes were not valid UTF-8.
#[error(transparent)]
Utf8(#[from] std::str::Utf8Error),
/// JSON deserialization failed.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// Query string deserialization failed.
#[error(transparent)]
Query(#[from] serde_html_form::de::Error),
/// Parsing an identifier failed.
#[error(transparent)]
Ident(#[from] crate::IdParseError),
/// Deserializing a header failed.
#[error(transparent)]
Header(#[from] HeaderDeserializationError),
/// Deserializing a `multipart/mixed` body failed.
#[error(transparent)]
MultipartMixed(#[from] MultipartMixedDeserializationError),
}
impl From<std::convert::Infallible> for DeserializationError {
    /// Unreachable at runtime: `Infallible` has no values, so there is never anything to
    /// convert. The impl exists so that `?` type-checks on `Result<_, Infallible>`.
    fn from(never: std::convert::Infallible) -> Self {
        match never {}
    }
}

impl From<http::header::ToStrError> for DeserializationError {
    /// Reports a header-to-string failure as a header deserialization error.
    fn from(err: http::header::ToStrError) -> Self {
        Self::Header(HeaderDeserializationError::from(err))
    }
}
/// An error when deserializing values from HTTP headers.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderDeserializationError {
/// A header value could not be converted to a string.
#[error("{0}")]
ToStrError(#[from] http::header::ToStrError),
/// A header value could not be parsed as an integer.
#[error("{0}")]
ParseIntError(#[from] ParseIntError),
/// A header value could not be parsed as an HTTP date.
#[error("failed to parse HTTP date")]
InvalidHttpDate,
/// A required header is missing; holds the header's name.
#[error("missing header `{0}`")]
MissingHeader(String),
/// A header is invalid; holds the underlying error.
#[error("invalid header: {0}")]
InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
/// A header was received with a value other than the expected one.
#[error(
"The {header} header was received with an unexpected value, \
expected {expected}, received {unexpected}"
)]
InvalidHeaderValue {
/// The name of the header.
header: String,
/// The expected value.
expected: String,
/// The value that was actually received.
unexpected: String,
},
/// The `Content-Type` header of a `multipart/mixed` response has no `boundary` attribute.
#[error(
"The `Content-Type` header for a `multipart/mixed` response is missing the `boundary` attribute"
)]
MissingMultipartBoundary,
}
/// An error when deserializing a `multipart/mixed` response body.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum MultipartMixedDeserializationError {
/// The body has fewer parts than expected.
#[error(
"multipart/mixed response does not have enough body parts, \
expected {expected}, found {found}"
)]
MissingBodyParts {
/// The number of body parts that were expected.
expected: usize,
/// The number of body parts that were found.
found: usize,
},
/// A body part has no separator between its headers and its content.
#[error("multipart/mixed body part is missing separator between headers and content")]
MissingBodyPartInnerSeparator,
/// A header inside a body part has no separator between its name and its value.
#[error("multipart/mixed body part header is missing separator between name and value")]
MissingHeaderSeparator,
/// A header inside a body part is invalid; holds the underlying error.
#[error("invalid multipart/mixed header: {0}")]
InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
}
/// Error for a version string that was not recognized.
#[derive(Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UnknownVersionError;

impl fmt::Display for UnknownVersionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("version string was unknown")
    }
}

impl StdError for UnknownVersionError {}
/// Error for a path that had a different number of arguments than expected.
#[derive(Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct IncorrectArgumentCount {
    /// The number of path arguments that were expected.
    pub expected: usize,
    /// The number of path arguments that were actually received.
    pub got: usize,
}

impl fmt::Display for IncorrectArgumentCount {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { expected, got } = self;
        write!(f, "incorrect path argument count, expected {expected}, got {got}")
    }
}

impl StdError for IncorrectArgumentCount {}
/// An error when serializing a value into an HTTP header.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderSerializationError {
/// The value is not a valid header value; `Display` and `source` forward to the wrapped error.
#[error(transparent)]
ToHeaderValue(#[from] http::header::InvalidHeaderValue),
/// The date could not be serialized as a valid HTTP date.
#[error("invalid HTTP date")]
InvalidHttpDate,
}