use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Error type for this crate: handle lookups, parameter validation, world capacity,
/// simulation stability, (de)serialization, and argument checks on the Python-facing API.
///
/// Serialized with serde's adjacently tagged representation
/// (`#[serde(tag = "type", content = "data")]`), i.e.
/// `{"type": "<VariantName>", "data": <payload>}`.
///
/// NOTE(review): serde assigns variant indices in declaration order, so formats that encode
/// the index rather than the name would break if variants are reordered. Append new variants
/// instead of reshuffling.
/// NOTE(review): serde_json writes non-finite `f64` payloads (NaN, ±inf) as `null`, which
/// presumably cannot be decoded back into the `f64` field. Confirm before relying on JSON
/// round-trips of e.g. `InvalidMass(f64::NAN)`.
#[derive(Debug, Error, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum Error {
// --- Handle lookups -------------------------------------------------------
/// A raw handle value was rejected as invalid.
#[error("invalid handle: {0}")]
InvalidHandle(u32),
/// No body exists for the given handle.
#[error("body not found: handle={0}")]
BodyNotFound(u32),
/// No collider exists for the given handle.
#[error("collider not found: handle={0}")]
ColliderNotFound(u32),
// --- Parameter validation -------------------------------------------------
/// The parameter `name` failed validation; `message` explains why.
#[error("invalid parameter '{name}': {message}")]
InvalidParameter {
name: String,
message: String,
},
/// A mass that is not positive; carries the rejected value.
#[error("mass must be positive, got {0}")]
InvalidMass(f64),
/// A time step that is not positive; carries the rejected value.
#[error("time step must be positive, got {0}")]
InvalidTimeStep(f64),
/// A dimension that is not positive; carries the rejected value.
#[error("dimension must be positive, got {0}")]
InvalidDimension(f64),
// --- Capacity -------------------------------------------------------------
/// The world already holds its maximum of `max` bodies.
#[error("world at capacity: max {max} bodies")]
CapacityExceeded {
max: usize,
},
// --- Numerical stability --------------------------------------------------
/// The simulation diverged at step index `step`.
#[error("simulation diverged at step {step}")]
SimulationDiverged {
step: u64,
},
/// The solver stopped after `iterations` iterations without converging.
#[error("solver did not converge after {iterations} iterations")]
SolverConvergenceFailed {
iterations: u32,
},
/// The body with this handle is sleeping (presumably the requested operation needs an
/// awake body — confirm at the call sites).
#[error("body {0} is sleeping")]
BodySleeping(u32),
// --- (De)serialization ----------------------------------------------------
/// Serialization failed; carries a description of the failure.
#[error("serialization error: {0}")]
Serialization(String),
/// A snapshot was rejected during validation; carries the reason.
#[error("snapshot validation failed: {0}")]
SnapshotValidation(String),
// --- Argument checks ------------------------------------------------------
/// An argument had the wrong type; both fields are type names as strings.
#[error("type error: expected {expected}, got {got}")]
TypeError {
expected: String,
got: String,
},
/// A required argument (named by the payload) was not supplied.
#[error("missing argument: '{0}'")]
MissingArgument(String),
/// An array argument had `got` elements where `expected` were required.
#[error("wrong array length: expected {expected}, got {got}")]
WrongArrayLength {
expected: usize,
got: usize,
},
// --- Fallback -------------------------------------------------------------
/// Free-form error; `Display` prints the message verbatim.
#[error("{0}")]
General(String),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
    /// Creates an [`Error::InvalidParameter`] naming the offending parameter and why it was
    /// rejected.
    pub fn invalid_param(name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::InvalidParameter {
            name: name.into(),
            message: message.into(),
        }
    }

    /// Creates an [`Error::General`]; its `Display` output is `msg` verbatim.
    pub fn general(msg: impl Into<String>) -> Self {
        Self::General(msg.into())
    }

    /// Creates an [`Error::TypeError`] from the expected and received type names.
    pub fn type_error(expected: impl Into<String>, got: impl Into<String>) -> Self {
        Self::TypeError {
            expected: expected.into(),
            got: got.into(),
        }
    }

    /// Creates an [`Error::WrongArrayLength`] from the expected and actual lengths.
    pub fn wrong_len(expected: usize, got: usize) -> Self {
        Self::WrongArrayLength { expected, got }
    }

    // The category predicates below are mutually exclusive; `BodySleeping` and `General`
    // belong to none of them.

    /// Returns `true` for handle lookup failures: `InvalidHandle`, `BodyNotFound`,
    /// `ColliderNotFound`.
    pub fn is_handle_error(&self) -> bool {
        matches!(
            self,
            Error::InvalidHandle(_) | Error::BodyNotFound(_) | Error::ColliderNotFound(_)
        )
    }

    /// Returns `true` for rejected input values: `InvalidParameter`, `InvalidMass`,
    /// `InvalidTimeStep`, `InvalidDimension`.
    pub fn is_parameter_error(&self) -> bool {
        matches!(
            self,
            Error::InvalidParameter { .. }
                | Error::InvalidMass(_)
                | Error::InvalidTimeStep(_)
                | Error::InvalidDimension(_)
        )
    }

    /// Returns `true` for `CapacityExceeded`.
    pub fn is_capacity_error(&self) -> bool {
        matches!(self, Error::CapacityExceeded { .. })
    }

    /// Returns `true` for numerical failures: `SimulationDiverged`,
    /// `SolverConvergenceFailed`.
    pub fn is_stability_error(&self) -> bool {
        matches!(
            self,
            Error::SimulationDiverged { .. } | Error::SolverConvergenceFailed { .. }
        )
    }

    /// Returns `true` for `Serialization` and `SnapshotValidation`.
    pub fn is_serialization_error(&self) -> bool {
        matches!(self, Error::Serialization(_) | Error::SnapshotValidation(_))
    }

    /// Returns `true` for argument-shape errors: `TypeError`, `MissingArgument`,
    /// `WrongArrayLength`.
    pub fn is_type_error(&self) -> bool {
        matches!(
            self,
            Error::TypeError { .. } | Error::MissingArgument(_) | Error::WrongArrayLength { .. }
        )
    }

    /// Serializes the error as a JSON envelope:
    /// `{"error": <adjacently tagged variant>, "message": <Display text>}`.
    ///
    /// This never fails. If the variant cannot be serialized, `"error"` holds the string
    /// `"<serialization failed>"`. The `"message"` value is always a valid JSON string.
    pub fn to_json(&self) -> String {
        let variant_json = serde_json::to_string(self)
            .unwrap_or_else(|_| "\"<serialization failed>\"".to_owned());
        // Fall back to an empty JSON *string* rather than an empty token, so the envelope
        // below is valid JSON on every path.
        let message_json =
            serde_json::to_string(&self.to_string()).unwrap_or_else(|_| "\"\"".to_owned());
        format!(
            r#"{{"error":{variant_json},"message":{message_json}}}"#,
            variant_json = variant_json,
            message_json = message_json,
        )
    }

    /// Parses an error from JSON.
    ///
    /// Two shapes are accepted:
    /// - the bare adjacently tagged form (`serde_json::to_string(&err)`);
    /// - the envelope produced by [`Error::to_json`], of which only `"error"` is read.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if `json` is not valid JSON, or if it matches
    /// neither accepted shape.
    pub fn from_json(json: &str) -> std::result::Result<Self, String> {
        // Parse the text once and try both shapes against the same tree.
        let value: serde_json::Value = serde_json::from_str(json).map_err(|e| e.to_string())?;
        let direct_err = match Self::deserialize(&value) {
            Ok(err) => return Ok(err),
            Err(e) => e,
        };
        match value.get("error") {
            Some(inner) => Self::deserialize(inner).map_err(|e| e.to_string()),
            // Not an envelope either. Report why the bare decode failed (e.g. a payload of
            // the wrong type) instead of discarding that reason.
            None => Err(format!(
                "missing 'error' field (bare decode failed: {})",
                direct_err
            )),
        }
    }

    /// Name of the Python built-in exception class this error should be raised as.
    ///
    /// The match is exhaustive on purpose: adding a variant forces a decision here instead of
    /// silently falling into `RuntimeError`.
    pub fn python_exception_class(&self) -> &'static str {
        match self {
            Error::TypeError { .. } | Error::WrongArrayLength { .. } => "TypeError",
            // NOTE(review): `is_type_error` groups MissingArgument with the type errors, but
            // it is surfaced as ValueError here. It is kept unchanged because Python callers may
            // already catch ValueError for it.
            Error::MissingArgument(_) => "ValueError",
            Error::InvalidParameter { .. }
            | Error::InvalidMass(_)
            | Error::InvalidTimeStep(_)
            | Error::InvalidDimension(_) => "ValueError",
            Error::Serialization(_) | Error::SnapshotValidation(_) => "ValueError",
            Error::InvalidHandle(_) | Error::BodyNotFound(_) | Error::ColliderNotFound(_) => {
                "KeyError"
            }
            Error::CapacityExceeded { .. } => "MemoryError",
            Error::SimulationDiverged { .. }
            | Error::SolverConvergenceFailed { .. }
            | Error::BodySleeping(_)
            | Error::General(_) => "RuntimeError",
        }
    }

    /// A short, user-facing suggestion for fixing the error. Returns an empty string when
    /// there is no specific advice for the variant.
    pub fn recovery_hint(&self) -> String {
        match self {
            Error::InvalidTimeStep(_) => {
                "Use a positive dt such as 1/60 for a 60 Hz simulation.".into()
            }
            Error::InvalidMass(_) => "Mass must be strictly positive (e.g. 1.0).".into(),
            Error::InvalidDimension(_) => {
                "Dimensions must be strictly positive (e.g. 0.5).".into()
            }
            Error::BodyNotFound(h) => format!(
                "Body handle {} not found. Re-add the body or check it has not been removed.",
                h
            ),
            Error::ColliderNotFound(h) => format!(
                "Collider handle {} not found. Re-add the collider or check it has not been removed.",
                h
            ),
            Error::CapacityExceeded { max } => format!(
                "World capacity of {} bodies reached. Remove unused bodies first.",
                max
            ),
            Error::SimulationDiverged { step } => format!(
                "Divergence detected at step {}. Reduce dt or applied forces.",
                step
            ),
            Error::TypeError { expected, got } => {
                format!("Expected a Python '{}' but received '{}'.", expected, got)
            }
            Error::WrongArrayLength { expected, got } => {
                format!("Array length should be {} but is {}.", expected, got)
            }
            Error::MissingArgument(name) => {
                format!("Pass a value for the required argument '{}'.", name)
            }
            // No specific advice for these; listed explicitly so new variants get a decision.
            Error::InvalidHandle(_)
            | Error::InvalidParameter { .. }
            | Error::SolverConvergenceFailed { .. }
            | Error::BodySleeping(_)
            | Error::Serialization(_)
            | Error::SnapshotValidation(_)
            | Error::General(_) => String::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One instance of every variant, with finite float payloads so JSON round-trips hold.
    fn all_variants() -> Vec<Error> {
        vec![
            Error::InvalidHandle(1),
            Error::BodyNotFound(2),
            Error::ColliderNotFound(3),
            Error::invalid_param("dt", "must be finite"),
            Error::InvalidMass(-2.5),
            Error::InvalidTimeStep(0.0),
            Error::InvalidDimension(-1.0),
            Error::CapacityExceeded { max: 8 },
            Error::SimulationDiverged { step: 12 },
            Error::SolverConvergenceFailed { iterations: 4 },
            Error::BodySleeping(5),
            Error::Serialization("eof".into()),
            Error::SnapshotValidation("bad version".into()),
            Error::type_error("float", "str"),
            Error::MissingArgument("mass".into()),
            Error::wrong_len(3, 2),
            Error::general("boom"),
        ]
    }

    #[test]
    fn test_invalid_handle() {
        let e = Error::InvalidHandle(5);
        assert!(e.to_string().contains("5"));
        assert!(e.is_handle_error());
    }

    #[test]
    fn test_body_not_found() {
        let e = Error::BodyNotFound(10);
        assert!(e.to_string().contains("10"));
        assert!(e.is_handle_error());
    }

    #[test]
    fn test_collider_not_found() {
        let e = Error::ColliderNotFound(3);
        assert!(e.is_handle_error());
        assert!(e.to_string().contains("3"));
    }

    #[test]
    fn test_invalid_param() {
        let e = Error::invalid_param("mass", "must be positive");
        assert!(e.to_string().contains("mass"));
        assert!(e.to_string().contains("positive"));
        assert!(e.is_parameter_error());
    }

    #[test]
    fn test_invalid_mass() {
        let e = Error::InvalidMass(-1.0);
        assert!(e.is_parameter_error());
        assert!(e.to_string().contains("-1"));
    }

    #[test]
    fn test_invalid_time_step() {
        let e = Error::InvalidTimeStep(0.0);
        assert!(e.is_parameter_error());
        assert!(e.to_string().contains("0"));
    }

    #[test]
    fn test_invalid_dimension() {
        let e = Error::InvalidDimension(-0.5);
        assert!(e.is_parameter_error());
        assert!(e.to_string().contains("-0.5"));
    }

    #[test]
    fn test_capacity_exceeded() {
        let e = Error::CapacityExceeded { max: 1000 };
        assert!(e.is_capacity_error());
        assert!(e.to_string().contains("1000"));
    }

    #[test]
    fn test_simulation_diverged() {
        let e = Error::SimulationDiverged { step: 99 };
        assert!(e.is_stability_error());
        assert!(e.to_string().contains("99"));
    }

    #[test]
    fn test_solver_convergence_failed() {
        let e = Error::SolverConvergenceFailed { iterations: 50 };
        assert!(e.is_stability_error());
        assert!(e.to_string().contains("50"));
    }

    #[test]
    fn test_body_sleeping() {
        let e = Error::BodySleeping(7);
        assert!(e.to_string().contains("7"));
    }

    #[test]
    fn test_serialization_error() {
        let e = Error::Serialization("unexpected eof".into());
        assert!(e.is_serialization_error());
    }

    #[test]
    fn test_snapshot_validation_error() {
        let e = Error::SnapshotValidation("missing version".into());
        assert!(e.is_serialization_error());
    }

    #[test]
    fn test_type_error() {
        let e = Error::type_error("list", "int");
        assert!(e.is_type_error());
        assert!(e.to_string().contains("list"));
        assert!(e.to_string().contains("int"));
    }

    #[test]
    fn test_missing_argument() {
        let e = Error::MissingArgument("mass".into());
        assert!(e.is_type_error());
        assert!(e.to_string().contains("mass"));
    }

    #[test]
    fn test_wrong_array_length() {
        let e = Error::wrong_len(3, 2);
        assert!(e.is_type_error());
        assert!(e.to_string().contains("3"));
        assert!(e.to_string().contains("2"));
    }

    #[test]
    fn test_general_error() {
        let e = Error::general("oops");
        // `General` displays its message verbatim, with no prefix.
        assert_eq!(e.to_string(), "oops");
    }

    #[test]
    fn test_categories_are_mutually_exclusive() {
        for e in &all_variants() {
            let hits = [
                e.is_handle_error(),
                e.is_parameter_error(),
                e.is_capacity_error(),
                e.is_stability_error(),
                e.is_serialization_error(),
                e.is_type_error(),
            ]
            .iter()
            .filter(|&&hit| hit)
            .count();
            assert!(hits <= 1, "{:?} is in {} categories", e, hits);
        }
    }

    #[test]
    fn test_clone_eq() {
        let e1 = Error::InvalidHandle(42);
        let e2 = e1.clone();
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_to_json_contains_type() {
        let e = Error::InvalidTimeStep(-0.01);
        let json = e.to_json();
        assert!(json.contains("InvalidTimeStep"), "json={}", json);
        assert!(json.contains("message"), "json={}", json);
    }

    #[test]
    fn test_to_json_is_valid_json_with_escaped_message() {
        let e = Error::general("quote \" and backslash \\");
        let v: serde_json::Value = serde_json::from_str(&e.to_json()).unwrap();
        assert_eq!(v["message"], "quote \" and backslash \\");
        assert_eq!(v["error"]["type"], "General");
    }

    #[test]
    fn test_from_json_direct() {
        let original = Error::BodyNotFound(7);
        let json = serde_json::to_string(&original).unwrap();
        let recovered = Error::from_json(&json).unwrap();
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_from_json_envelope() {
        let e = Error::CapacityExceeded { max: 256 };
        let envelope = e.to_json();
        let recovered = Error::from_json(&envelope).unwrap();
        assert_eq!(recovered, e);
    }

    #[test]
    fn test_json_round_trip_all_variants() {
        for e in &all_variants() {
            let envelope = e.to_json();
            assert_eq!(&Error::from_json(&envelope).unwrap(), e, "envelope={}", envelope);
            let direct = serde_json::to_string(e).unwrap();
            assert_eq!(&Error::from_json(&direct).unwrap(), e, "direct={}", direct);
        }
    }

    #[test]
    fn test_from_json_invalid() {
        assert!(Error::from_json("{bad json").is_err());
    }

    #[test]
    fn test_from_json_rejects_non_error_json() {
        assert!(Error::from_json("{}").is_err());
        assert!(Error::from_json("42").is_err());
        assert!(Error::from_json(r#"{"error":42}"#).is_err());
    }

    #[test]
    fn test_python_exception_class_value_error() {
        assert_eq!(
            Error::InvalidTimeStep(0.0).python_exception_class(),
            "ValueError"
        );
        assert_eq!(
            Error::InvalidMass(-1.0).python_exception_class(),
            "ValueError"
        );
        assert_eq!(
            Error::invalid_param("x", "y").python_exception_class(),
            "ValueError"
        );
    }

    #[test]
    fn test_python_exception_class_key_error() {
        assert_eq!(Error::BodyNotFound(1).python_exception_class(), "KeyError");
        assert_eq!(Error::InvalidHandle(0).python_exception_class(), "KeyError");
    }

    #[test]
    fn test_python_exception_class_type_error() {
        assert_eq!(
            Error::type_error("list", "int").python_exception_class(),
            "TypeError"
        );
        assert_eq!(Error::wrong_len(3, 1).python_exception_class(), "TypeError");
    }

    #[test]
    fn test_python_exception_class_runtime_error() {
        assert_eq!(
            Error::SimulationDiverged { step: 1 }.python_exception_class(),
            "RuntimeError"
        );
        assert_eq!(
            Error::SolverConvergenceFailed { iterations: 10 }.python_exception_class(),
            "RuntimeError"
        );
    }

    #[test]
    fn test_python_exception_class_remaining_variants() {
        assert_eq!(
            Error::MissingArgument("x".into()).python_exception_class(),
            "ValueError"
        );
        assert_eq!(
            Error::Serialization("x".into()).python_exception_class(),
            "ValueError"
        );
        assert_eq!(
            Error::SnapshotValidation("x".into()).python_exception_class(),
            "ValueError"
        );
        assert_eq!(Error::ColliderNotFound(2).python_exception_class(), "KeyError");
        assert_eq!(
            Error::CapacityExceeded { max: 1 }.python_exception_class(),
            "MemoryError"
        );
        assert_eq!(Error::BodySleeping(1).python_exception_class(), "RuntimeError");
        assert_eq!(Error::general("x").python_exception_class(), "RuntimeError");
    }

    #[test]
    fn test_recovery_hint_time_step() {
        let hint = Error::InvalidTimeStep(0.0).recovery_hint();
        assert!(!hint.is_empty());
    }

    #[test]
    fn test_recovery_hint_body_not_found() {
        let hint = Error::BodyNotFound(42).recovery_hint();
        assert!(hint.contains("42"));
    }

    #[test]
    fn test_recovery_hint_capacity() {
        let hint = Error::CapacityExceeded { max: 500 }.recovery_hint();
        assert!(hint.contains("500"));
    }

    #[test]
    fn test_recovery_hint_diverged() {
        let hint = Error::SimulationDiverged { step: 7 }.recovery_hint();
        assert!(hint.contains("7"));
    }

    #[test]
    fn test_recovery_hint_type_error() {
        let hint = Error::type_error("ndarray", "str").recovery_hint();
        assert!(hint.contains("ndarray"));
    }

    #[test]
    fn test_recovery_hint_wrong_len() {
        let hint = Error::wrong_len(3, 1).recovery_hint();
        assert!(hint.contains("3"));
        assert!(hint.contains("1"));
    }

    #[test]
    fn test_recovery_hint_general_empty() {
        // Previously this test asserted nothing; pin the documented "no advice" result.
        let hint = Error::general("x").recovery_hint();
        assert!(hint.is_empty(), "hint={:?}", hint);
    }
}