use backtrace::Backtrace;
use std::fmt::Debug;
use thiserror::Error;
/// Errors produced by the checks and constructors in this module.
///
/// Every variant stores `name`, a label supplied by the caller (presumably
/// identifying the operation that failed — confirm against callers), and
/// `trace`, the backtrace text stored by the constructor, which is empty when
/// no backtrace was captured.
///
/// NOTE(review): `MissingId`, `MissingGradient` and `MissingNode` store `name`
/// and `trace`, but their display messages do not include either field —
/// confirm whether that is intentional.
#[derive(Error, Debug, Clone)]
pub enum Error {
/// A set of dimensions was rejected as incompatible.
#[error("Incompatible dimensions for {name}: {dimensions}\n{trace}")]
Dimensions {
name: String,
/// `Debug` rendering of the offending dimensions.
dimensions: String,
trace: String,
},
/// A pair of original and reduced dimensions was rejected.
#[error(
"Incorrect reduction of dimensions for {name}: {dimensions} {reduced_dimensions}\n{trace}"
)]
ReducedDimensions {
name: String,
/// `Debug` rendering of the original dimensions.
dimensions: String,
/// `Debug` rendering of the reduced dimensions.
reduced_dimensions: String,
trace: String,
},
/// A set of lengths was rejected as incompatible.
#[error("Incompatible lengths for {name}: {lengths}\n{trace}")]
Lengths {
name: String,
/// `Debug` rendering of the offending lengths.
lengths: String,
trace: String,
},
/// An input that must not be empty was empty.
#[error("Unexpected empty input for {name}\n{trace}")]
Empty { name: String, trace: String },
/// An `id` was requested from a constant value.
#[error("Trying to obtain `id` of a constant value.")]
MissingId { name: String, trace: String },
/// The gradient store held no gradient of the expected type.
#[error("No gradient of the expected type could be found in gradient store.")]
MissingGradient { name: String, trace: String },
/// A node was requested with an incorrect `id`.
#[error("Trying to obtain a node from an incorrect `id`.")]
MissingNode { name: String, trace: String },
}
/// Shorthand for a `std::result::Result` whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Expands to the path of the enclosing function as a `&'static str`.
///
/// The macro declares a local item `f` at the call site and asks
/// [`std::any::type_name`] for the name of its type, which is the path of the
/// enclosing function followed by `::f`; that trailing `::f` is then removed.
///
/// The exact text returned by `type_name` is not guaranteed by the standard
/// library, so the result is intended for diagnostics only. If the expected
/// `::f` suffix is missing, the name is returned unchanged.
#[macro_export]
macro_rules! func_name {
    () => {{
        fn f() {}
        fn type_name_of<T>(_: T) -> &'static str {
            // Absolute path so the expansion does not depend on what `std`
            // resolves to at the call site.
            ::std::any::type_name::<T>()
        }
        let name = type_name_of(f);
        // Remove the marker by its text rather than by a fixed byte count, so an
        // unexpected `type_name` format can neither panic nor truncate the path.
        name.strip_suffix("::f").unwrap_or(name)
    }};
}
impl Error {
    /// Renders the current call stack for embedding in an error, or returns an
    /// empty string when backtraces are not enabled.
    ///
    /// Capture follows the usual `RUST_BACKTRACE` convention: the variable must
    /// be set, and a value of `0` explicitly turns capture off.
    fn backtrace() -> String {
        // `var_os` is used so that a value that is not valid Unicode still
        // counts as "set".
        let enabled = std::env::var_os("RUST_BACKTRACE").map_or(false, |value| value != "0");
        if enabled {
            format!("{:?}", Backtrace::new())
        } else {
            String::new()
        }
    }

    /// Builds an [`Error::Dimensions`] for `name`, storing the `Debug`
    /// rendering of `dims`.
    pub fn dimensions<D>(name: &str, dims: D) -> Self
    where
        D: Debug,
    {
        Error::Dimensions {
            name: name.to_string(),
            dimensions: format!("{:?}", dims),
            trace: Self::backtrace(),
        }
    }

    /// Builds an [`Error::ReducedDimensions`] for `name`, storing the `Debug`
    /// renderings of the original dimensions `dims` and the reduced dimensions
    /// `rdims`.
    pub fn reduced_dimensions<D>(name: &str, dims: D, rdims: D) -> Self
    where
        D: Debug,
    {
        Error::ReducedDimensions {
            name: name.to_string(),
            dimensions: format!("{:?}", dims),
            reduced_dimensions: format!("{:?}", rdims),
            trace: Self::backtrace(),
        }
    }

    /// Builds an [`Error::Lengths`] for `name`, storing the `Debug` rendering
    /// of `lengths`.
    pub fn lengths<L>(name: &str, lengths: L) -> Self
    where
        L: Debug,
    {
        Error::Lengths {
            name: name.to_string(),
            lengths: format!("{:?}", lengths),
            trace: Self::backtrace(),
        }
    }

    /// Builds an [`Error::Empty`] for `name`.
    pub fn empty(name: &str) -> Self {
        Error::Empty {
            name: name.to_string(),
            trace: Self::backtrace(),
        }
    }

    /// Builds an [`Error::MissingId`] for `name`.
    pub fn missing_id(name: &str) -> Self {
        Error::MissingId {
            name: name.to_string(),
            trace: Self::backtrace(),
        }
    }

    /// Builds an [`Error::MissingGradient`] for `name`.
    pub fn missing_gradient(name: &str) -> Self {
        Error::MissingGradient {
            name: name.to_string(),
            trace: Self::backtrace(),
        }
    }

    /// Builds an [`Error::MissingNode`] for `name`.
    pub fn missing_node(name: &str) -> Self {
        Error::MissingNode {
            name: name.to_string(),
            trace: Self::backtrace(),
        }
    }
}
pub fn check_equal_dimensions<D>(name: &str, dims: &[&D]) -> Result<D>
where
D: PartialEq + Debug + Clone,
{
let mut it = dims.iter();
if let Some(first) = it.next() {
if it.all(|x| x == first) {
Ok((**first).clone())
} else {
Err(Error::dimensions(name, dims))
}
} else {
Err(Error::empty(name))
}
}
pub fn check_equal_lengths(name: &str, lengths: &[usize]) -> Result<usize> {
let mut it = lengths.iter();
if let Some(first) = it.next() {
if it.all(|x| x == first) {
Ok(*first)
} else {
Err(Error::lengths(name, lengths))
}
} else {
Err(Error::empty(name))
}
}
/// Checks that are only available when the `arrayfire` feature is enabled.
#[cfg(feature = "arrayfire")]
pub mod af {
    use super::*;
    use arrayfire as af;

    /// Accepts `rdims` when, along each of the four axes, it either equals the
    /// corresponding entry of `dims` or is `1`, and returns `rdims` unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`Error::reduced_dimensions`] (labelled with `name`, carrying
    /// both `dims` and `rdims`) when some axis of `rdims` is neither equal to
    /// that axis of `dims` nor `1`.
    pub fn check_reduced_dimensions(
        name: &str,
        dims: af::Dim4,
        rdims: af::Dim4,
    ) -> Result<af::Dim4> {
        let axis_is_valid = |axis: usize| rdims[axis] == dims[axis] || rdims[axis] == 1;
        if (0..4).all(axis_is_valid) {
            Ok(rdims)
        } else {
            Err(Error::reduced_dimensions(name, dims, rdims))
        }
    }
}