use crate::Codec;
use std::fmt;
/// Prefix shared by the string names of every [`RpcIntErr`] variant
/// (e.g. `"rpc_timeout"`).
/// NOTE(review): presumably used to tell internal RPC errors apart from user
/// error strings in encoded form — confirm at the use sites.
pub const RPC_ERR_PREFIX: &str = "rpc_";
/// Error returned from an RPC call: either a user-defined error `E` or an
/// internal RPC-layer error ([`RpcIntErr`]).
///
/// `#[from]` generates `From<E>` and `From<RpcIntErr>`. `Display` and `Debug`
/// are implemented by hand (no `#[error]` attributes are used here).
#[derive(thiserror::Error)]
pub enum RpcError<E: RpcErrCodec> {
/// User-defined error of type `E`.
User(#[from] E),
/// Internal error raised by the RPC layer.
Rpc(#[from] RpcIntErr),
}
impl<E: RpcErrCodec> fmt::Display for RpcError<E> {
    /// Renders the wrapped error. Internal errors use their own `Display`;
    /// user errors use [`RpcErrCodec::fmt`]. The caller's formatter is
    /// forwarded in both cases.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Rpc(internal) => fmt::Display::fmt(internal, f),
            Self::User(user) => RpcErrCodec::fmt(user, f),
        }
    }
}
impl<E: RpcErrCodec> fmt::Debug for RpcError<E> {
    /// `Debug` output is identical to `Display`, including any width/fill
    /// options carried by the formatter.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
/// An `RpcError` equals an `RpcIntErr` only when it is the `Rpc` variant
/// wrapping an equal internal error; `User` errors never match.
impl<E: RpcErrCodec> std::cmp::PartialEq<RpcIntErr> for RpcError<E> {
    #[inline]
    fn eq(&self, other: &RpcIntErr) -> bool {
        matches!(self, Self::Rpc(inner) if inner == other)
    }
}
/// An `RpcError<E>` equals a bare `E` only when it is the `User` variant
/// holding an equal value; internal errors never match.
impl<E: RpcErrCodec + PartialEq> std::cmp::PartialEq<E> for RpcError<E> {
    #[inline]
    fn eq(&self, other: &E) -> bool {
        match self {
            Self::User(user) => user == other,
            Self::Rpc(_) => false,
        }
    }
}
/// Two `RpcError`s are equal when they are the same variant with equal
/// payloads; a `User` error never equals an `Rpc` error.
impl<E: RpcErrCodec + PartialEq> std::cmp::PartialEq<RpcError<E>> for RpcError<E> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::User(lhs), Self::User(rhs)) => lhs == rhs,
            (Self::Rpc(lhs), Self::Rpc(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}
impl From<&str> for RpcError<String> {
#[inline]
fn from(e: &str) -> Self {
Self::User(e.to_string())
}
}
/// Encoding, decoding and formatting contract for the user error type carried
/// in `RpcError::User`.
///
/// The value handed to `decode` is either a numeric code (`Ok(u32)`) or a raw
/// byte payload (`Err(&[u8])`). The impls in this file read numeric errors
/// from the former and `String` from the latter.
pub trait RpcErrCodec: Sized + Send + Unpin + 'static {
    /// Converts `self` into its encoded form.
    fn encode<C: Codec>(&self, codec: &C) -> EncodedErr;

    /// Rebuilds a value from its encoded form, returning `Err(())` when `buf`
    /// does not hold a value of this type.
    fn decode<C: Codec>(codec: &C, buf: Result<u32, &[u8]>) -> Result<Self, ()>;

    /// Human-readable rendering, used by `RpcError`'s `Display`/`Debug`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result;

    /// Failover hint for this error. The default returns `Err(())`.
    /// NOTE(review): the meaning of `Ok(None)` vs `Ok(Some(_))` is defined by
    /// callers outside this file — confirm before overriding.
    #[inline(always)]
    fn should_failover(&self) -> Result<Option<&str>, ()> {
        Err(())
    }
}
// Implements `RpcErrCodec` for a primitive integer type.
//
// The value is sent as `EncodedErr::Num(u32)` via an `as` cast. That cast
// sign-extends signed values (e.g. `-1i8` becomes `0xFFFF_FFFF`), so `decode`
// must undo it. It truncates the wire value to the target type and accepts
// the result only if re-encoding it reproduces the same `u32`. This makes
// `decode` the exact inverse of `encode` for both signed and unsigned types
// and rejects out-of-range numbers. Byte payloads (`Err(&[u8])`) are
// rejected. `fmt` renders the value as `errno <n>`.
macro_rules! impl_rpc_error_for_num {
    ($t: tt) => {
        impl RpcErrCodec for $t {
            #[inline(always)]
            fn encode<C: Codec>(&self, _codec: &C) -> EncodedErr {
                // Signed values are sign-extended by the cast.
                EncodedErr::Num(*self as u32)
            }
            #[inline(always)]
            fn decode<C: Codec>(_codec: &C, buf: Result<u32, &[u8]>) -> Result<Self, ()> {
                if let Ok(i) = buf {
                    // A plain `i <= MAX` check would reject the sign-extended
                    // values `encode` emits for negative numbers. Round-tripping
                    // accepts exactly the values `encode` can produce.
                    let v = i as $t;
                    if v as u32 == i {
                        return Ok(v);
                    }
                }
                Err(())
            }
            #[inline(always)]
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "errno {}", self)
            }
        }
    };
}
// Integer error codes are sent as `EncodedErr::Num`.
// Unsigned widths:
impl_rpc_error_for_num!(u8);
impl_rpc_error_for_num!(u16);
impl_rpc_error_for_num!(u32);
// Signed widths:
impl_rpc_error_for_num!(i8);
impl_rpc_error_for_num!(i16);
impl_rpc_error_for_num!(i32);
impl RpcErrCodec for nix::errno::Errno {
    /// Sends the errno as its raw numeric value.
    #[inline(always)]
    fn encode<C: Codec>(&self, _codec: &C) -> EncodedErr {
        let raw = *self as u32;
        EncodedErr::Num(raw)
    }

    /// Accepts only numeric payloads that fit in a non-negative `i32`;
    /// byte payloads are rejected.
    #[inline(always)]
    fn decode<C: Codec>(_codec: &C, buf: Result<u32, &[u8]>) -> Result<Self, ()> {
        match buf {
            Ok(code) if code <= i32::MAX as u32 => Ok(Self::from_raw(code as i32)),
            _ => Err(()),
        }
    }

    /// Renders the errno using its `Debug` representation.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self)
    }
}
impl RpcErrCodec for () {
    /// The unit error has a single value, always sent as code 0.
    #[inline(always)]
    fn encode<C: Codec>(&self, _codec: &C) -> EncodedErr {
        EncodedErr::Num(0)
    }

    /// Accepts only the numeric code 0.
    #[inline(always)]
    fn decode<C: Codec>(_codec: &C, buf: Result<u32, &[u8]>) -> Result<Self, ()> {
        match buf {
            Ok(0) => Ok(()),
            _ => Err(()),
        }
    }

    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "err")
    }
}
impl RpcErrCodec for String {
    /// Sends the string's UTF-8 bytes as an owned buffer.
    #[inline(always)]
    fn encode<C: Codec>(&self, _codec: &C) -> EncodedErr {
        EncodedErr::Buf(self.as_bytes().to_vec())
    }

    /// Accepts only byte payloads that are valid UTF-8; numeric codes are
    /// rejected.
    #[inline(always)]
    fn decode<C: Codec>(_codec: &C, buf: Result<u32, &[u8]>) -> Result<Self, ()> {
        let Err(bytes) = buf else {
            return Err(());
        };
        std::str::from_utf8(bytes)
            .map(String::from)
            .map_err(|_| ())
    }

    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self)
    }
}
/// Internal (RPC-layer) error code.
///
/// Each variant has a fixed string name given by its `strum(serialize)`
/// attribute, and every name starts with [`RPC_ERR_PREFIX`]. `Display`,
/// `FromStr` and `AsRef<str>` (all derived through strum) use that name.
/// The `u8` discriminants are explicit. The derived `PartialOrd` orders
/// variants by declaration order, which matches discriminant order here.
///
/// As a fieldless enum it is also `Copy`, `Eq` and `Hash`. It can be passed
/// by value and used as a map/set key without cloning.
#[derive(
    strum::Display,
    strum::EnumString,
    strum::AsRefStr,
    PartialEq,
    Eq,
    PartialOrd,
    Hash,
    Clone,
    Copy,
    thiserror::Error,
)]
#[repr(u8)]
pub enum RpcIntErr {
    #[strum(serialize = "rpc_unreachable")]
    Unreachable = 0,
    #[strum(serialize = "rpc_io_err")]
    IO = 1,
    #[strum(serialize = "rpc_timeout")]
    Timeout = 2,
    #[strum(serialize = "rpc_method_notfound")]
    Method = 3,
    #[strum(serialize = "rpc_service_notfound")]
    Service = 4,
    #[strum(serialize = "rpc_encode")]
    Encode = 5,
    #[strum(serialize = "rpc_decode")]
    Decode = 6,
    #[strum(serialize = "rpc_internal_err")]
    Internal = 7,
    #[strum(serialize = "rpc_invalid_ver")]
    Version = 8,
}
impl fmt::Debug for RpcIntErr {
    /// `Debug` prints the same string name as `Display`, forwarding the
    /// caller's formatter options.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl RpcIntErr {
    /// Returns the variant's string name (its `AsRef<str>` form) as bytes.
    #[inline]
    pub fn as_bytes(&self) -> &[u8] {
        let name: &str = self.as_ref();
        name.as_bytes()
    }
}
impl From<std::io::Error> for RpcIntErr {
#[inline(always)]
fn from(_e: std::io::Error) -> Self {
Self::IO
}
}
/// Encoded form of an error, produced by `RpcErrCodec::encode` or converted
/// from an [`RpcIntErr`].
#[derive(Debug)]
pub enum EncodedErr {
/// Internal RPC error.
Rpc(RpcIntErr),
/// Numeric error code (emitted by the integer, `Errno` and `()` codecs in this file).
Num(u32),
/// Static string message.
Static(&'static str),
/// Owned byte payload (emitted by the `String` codec). It need not be UTF-8;
/// `Display` falls back to reporting its length when it is not.
Buf(Vec<u8>),
}
impl EncodedErr {
    /// Borrows the error as text.
    ///
    /// `Static` always succeeds and `Buf` succeeds when its bytes are valid
    /// UTF-8. `Rpc` and `Num` always return `Err(())`.
    #[inline]
    pub fn try_as_str(&self) -> Result<&str, ()> {
        let bytes = match self {
            Self::Static(text) => return Ok(*text),
            Self::Buf(buf) => buf.as_slice(),
            Self::Rpc(_) | Self::Num(_) => return Err(()),
        };
        std::str::from_utf8(bytes).map_err(|_| ())
    }
}
/// Equality between encoded errors.
///
/// `Rpc`, `Num` and `Buf` compare equal to the same variant with an equal
/// payload. `Static` and `Buf` also compare by text with each other: a
/// `Static` equals a `Buf` whose bytes are the same UTF-8 string. Every
/// other cross-variant comparison is `false`.
impl std::cmp::PartialEq<EncodedErr> for EncodedErr {
    fn eq(&self, other: &EncodedErr) -> bool {
        match (self, other) {
            (Self::Rpc(lhs), Self::Rpc(rhs)) => lhs == rhs,
            (Self::Num(lhs), Self::Num(rhs)) => lhs == rhs,
            (Self::Buf(lhs), Self::Buf(rhs)) => lhs == rhs,
            (Self::Static(lhs), _) => other.try_as_str().is_ok_and(|rhs| *lhs == rhs),
            (Self::Buf(lhs), _) => match (other.try_as_str(), std::str::from_utf8(lhs)) {
                (Ok(rhs), Ok(text)) => text == rhs,
                _ => false,
            },
            _ => false,
        }
    }
}
impl fmt::Display for EncodedErr {
    /// Renders the encoded error for humans:
    /// - `Rpc` uses the internal error's `Display`;
    /// - `Num` becomes `errno <n>`;
    /// - `Static` is printed as-is;
    /// - `Buf` is printed as text when valid UTF-8, otherwise as
    ///   `err blob <len> length`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Rpc(internal) => fmt::Display::fmt(internal, f),
            Self::Num(code) => write!(f, "errno {}", code),
            Self::Static(text) => write!(f, "{}", text),
            Self::Buf(bytes) => {
                if let Ok(text) = std::str::from_utf8(bytes) {
                    write!(f, "{}", text)
                } else {
                    write!(f, "err blob {} length", bytes.len())
                }
            }
        }
    }
}
impl From<RpcIntErr> for EncodedErr {
#[inline(always)]
fn from(e: RpcIntErr) -> Self {
Self::Rpc(e)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use nix::errno::Errno;
    use std::str::FromStr;

    /// Checks `RpcIntErr` naming, parsing and ordering.
    #[test]
    fn test_internal_error() {
        assert_eq!(RpcIntErr::Internal.to_string(), "rpc_internal_err");
        assert_eq!(format!("{:?}", RpcIntErr::Internal), "rpc_internal_err");
        let s = RpcIntErr::Timeout.as_ref();
        assert_eq!(s, "rpc_timeout");
        assert_eq!(RpcIntErr::Timeout.as_bytes(), b"rpc_timeout");
        let e = RpcIntErr::from_str(s).expect("parse");
        assert_eq!(e, RpcIntErr::Timeout);
        assert!(RpcIntErr::from_str("timeoutss").is_err());
        assert!(RpcIntErr::Timeout < RpcIntErr::Method);
        assert!(RpcIntErr::IO < RpcIntErr::Method);
        assert!(RpcIntErr::Unreachable < RpcIntErr::Method);

        // Every variant name carries the shared prefix and round-trips
        // through `FromStr`; `Display` and `Debug` both print that name.
        for v in [
            RpcIntErr::Unreachable,
            RpcIntErr::IO,
            RpcIntErr::Timeout,
            RpcIntErr::Method,
            RpcIntErr::Service,
            RpcIntErr::Encode,
            RpcIntErr::Decode,
            RpcIntErr::Internal,
            RpcIntErr::Version,
        ] {
            let name: &str = v.as_ref();
            assert!(name.starts_with(RPC_ERR_PREFIX), "{name}");
            assert_eq!(RpcIntErr::from_str(name).expect("parse"), v);
            assert_eq!(v.to_string(), name);
            assert_eq!(format!("{:?}", v), name);
        }
    }

    /// Checks `RpcError` formatting, conversions and equality impls.
    #[test]
    fn test_rpc_error_default() {
        let e = RpcError::<i32>::from(1i32);
        assert_eq!(e.to_string(), "errno 1");
        assert_eq!(format!("{:?}", e), "errno 1");
        assert!(e == 1i32);
        assert!(e != RpcIntErr::IO);

        let e = RpcError::<Errno>::from(Errno::EIO);
        assert_eq!(e.to_string(), "EIO");

        let e = RpcError::<String>::from("err_str");
        assert_eq!(e.to_string(), "err_str");
        let e2 = RpcError::<String>::from("err_str".to_string());
        assert_eq!(e, e2);

        let e = RpcError::<String>::from(RpcIntErr::IO);
        assert_eq!(e.to_string(), "rpc_io_err");
        assert!(e == RpcIntErr::IO);
        assert!(e != e2);

        let _e: Result<(), RpcIntErr> = Err(RpcIntErr::IO);
        let e: Result<(), RpcError<String>> = _e.map_err(|e| e.into());
        assert!(matches!(e, Err(RpcError::Rpc(RpcIntErr::IO))));
    }
}