use crate::error::RusTorchError;
use std::collections::HashMap;
use std::fmt;
/// Diagnostic details that can be rendered with `format_context` and appended to a
/// [`RusTorchError`] message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorContext {
    /// Name of the operation being described; rendered as `Operation: …`.
    pub operation: String,
    /// Source position of the failure, if recorded (the `error_context!` macro fills it
    /// from `file!()`, `line!()` and `column!()`).
    pub location: Option<ErrorLocation>,
    /// Free-form `key => value` annotations, e.g. tensor shapes.
    pub metadata: HashMap<String, String>,
    /// Operations recorded with `push_operation`, stored in push order (oldest first).
    pub stack_trace: Vec<String>,
}

/// A `file:line:column` source position.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ErrorLocation {
    /// Source file path, as passed to `with_location` (e.g. the output of `file!()`).
    pub file: String,
    /// Line number (1-based when produced by `line!()`).
    pub line: u32,
    /// Column number (1-based when produced by `column!()`).
    pub column: u32,
}
impl ErrorContext {
    /// Creates a context for `operation` with no location, metadata, or stack trace.
    pub fn new(operation: impl Into<String>) -> Self {
        Self {
            operation: operation.into(),
            location: None,
            metadata: HashMap::new(),
            stack_trace: Vec::new(),
        }
    }

    /// Inserts a `key`/`value` metadata entry and returns the updated context.
    ///
    /// An existing entry with the same key is overwritten.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Sets the source location (`file:line:column`), replacing any previous one.
    #[must_use]
    pub fn with_location(mut self, file: impl Into<String>, line: u32, column: u32) -> Self {
        self.location = Some(ErrorLocation {
            file: file.into(),
            line,
            column,
        });
        self
    }

    /// Appends `operation` to the stack trace and returns the updated context.
    ///
    /// Entries are stored oldest-first; [`format_context`](Self::format_context)
    /// prints them newest-first.
    #[must_use]
    pub fn push_operation(mut self, operation: impl Into<String>) -> Self {
        self.stack_trace.push(operation.into());
        self
    }

    /// Renders the context as a multi-line, human-readable report.
    ///
    /// The first line is always `Operation: <operation>`. Each optional section
    /// follows only when it has content:
    /// - `Location: file:line:column`
    /// - `Metadata:` with one ` key: value` line per entry, sorted by key so the
    ///   output is deterministic (a `HashMap` iterates in arbitrary order)
    /// - `Stack trace:` with one ` index: operation` line per entry, newest first
    ///   (index 0 is the most recently pushed operation)
    pub fn format_context(&self) -> String {
        use std::fmt::Write as _;

        let mut out = format!("Operation: {}", self.operation);
        // `write!` into a `String` cannot fail, so its `fmt::Result` is discarded below.
        if let Some(location) = &self.location {
            let _ = write!(out, "\nLocation: {}", location);
        }
        if !self.metadata.is_empty() {
            out.push_str("\nMetadata:");
            // Sort so identical contexts always render identically (stable logs/tests).
            let mut entries: Vec<(&String, &String)> = self.metadata.iter().collect();
            entries.sort_unstable();
            for (key, value) in entries {
                let _ = write!(out, "\n {}: {}", key, value);
            }
        }
        if !self.stack_trace.is_empty() {
            out.push_str("\nStack trace:");
            for (index, operation) in self.stack_trace.iter().rev().enumerate() {
                let _ = write!(out, "\n {}: {}", index, operation);
            }
        }
        out
    }
}
impl fmt::Display for ErrorContext {
    /// Writes the same multi-line report produced by `format_context`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let report = self.format_context();
        f.write_str(&report)
    }
}
impl fmt::Display for ErrorLocation {
    /// Formats the position as `file:line:column`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { file, line, column } = self;
        write!(f, "{}:{}:{}", file, line, column)
    }
}
/// Extension trait for attaching an [`ErrorContext`] to a failing result.
///
/// Both methods produce a `Result` whose error type is [`RusTorchError`].
pub trait WithContext<T> {
    /// Converts the error into a [`RusTorchError`] and enriches it with the context
    /// built by `f`.
    fn with_context<F: FnOnce() -> ErrorContext>(self, f: F) -> Result<T, RusTorchError>;

    /// Like [`with_context`](Self::with_context), using a context that only names
    /// `operation`.
    fn with_operation(self, operation: &str) -> Result<T, RusTorchError>;
}
/// Blanket implementation for any `Result` whose error converts into [`RusTorchError`].
///
/// `Ok` values pass through untouched (via `map_err`); the context closure is only
/// invoked on the `Err` path.
impl<T, E> WithContext<T> for Result<T, E>
where
    E: Into<RusTorchError>,
{
    fn with_context<F>(self, f: F) -> Result<T, RusTorchError>
    where
        F: FnOnce() -> ErrorContext,
    {
        self.map_err(|source| {
            let mut converted: RusTorchError = source.into();
            // Built before inspecting the variant, so `f` runs for every error.
            let ctx = f();
            match &mut converted {
                // Append the rendered context on a new line after the existing message.
                RusTorchError::TensorOp { message, .. }
                | RusTorchError::Device { message, .. }
                | RusTorchError::Gpu { message, .. } => {
                    message.push('\n');
                    message.push_str(&ctx.format_context());
                }
                // NOTE(review): every other variant is returned unchanged and the
                // context is silently discarded — confirm this is intended.
                _ => {}
            }
            converted
        })
    }

    fn with_operation(self, operation: &str) -> Result<T, RusTorchError> {
        <Self as WithContext<T>>::with_context(self, || ErrorContext::new(operation))
    }
}
/// Builds a `crate::error::context::ErrorContext` stamped with the invocation site's
/// `file!()`, `line!()` and `column!()`.
///
/// Takes an operation name, optionally followed by `key => value` metadata pairs.
/// A trailing comma is accepted.
///
/// ```ignore
/// let ctx = error_context!("tensor add", "shape1" => "[2, 3]", "shape2" => "[2, 3]",);
/// ```
#[macro_export]
macro_rules! error_context {
    ($operation:expr $(,)?) => {
        $crate::error::context::ErrorContext::new($operation)
            .with_location(file!(), line!(), column!())
    };
    ($operation:expr, $($key:expr => $value:expr),+ $(,)?) => {
        // Metadata pairs are applied left to right, so a repeated key keeps the last value.
        $crate::error::context::ErrorContext::new($operation)
            .with_location(file!(), line!(), column!())
            $(.with_metadata($key, $value))+
    };
}
/// Attaches an `error_context!`-built context to `$result` via
/// `crate::error::context::WithContext::with_context`.
///
/// The context (including the `$operation` and metadata expressions) is built inside
/// a closure, so it is only evaluated on the error path. A trailing comma is accepted.
///
/// ```ignore
/// let out = with_context!(op(), "matmul", "lhs" => "[2, 3]",);
/// ```
#[macro_export]
macro_rules! with_context {
    ($result:expr, $operation:expr $(,)?) => {
        $crate::error::context::WithContext::with_context($result, || {
            $crate::error_context!($operation)
        })
    };
    ($result:expr, $operation:expr, $($key:expr => $value:expr),+ $(,)?) => {
        $crate::error::context::WithContext::with_context($result, || {
            $crate::error_context!($operation, $($key => $value),+)
        })
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RusTorchError;

    /// Builder methods populate every section that `format_context` renders.
    #[test]
    fn test_error_context_creation() {
        let rendered = ErrorContext::new("matrix multiplication")
            .with_metadata("input_shape", "[2, 3]")
            .with_metadata("weight_shape", "[3, 4]")
            .with_location("tensor.rs", 42, 10)
            .format_context();

        for expected in [
            "Operation: matrix multiplication",
            "input_shape: [2, 3]",
            "Location: tensor.rs:42:10",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    /// `with_operation` keeps the error and mentions the operation in its message.
    #[test]
    fn test_with_context_trait() {
        let failing: Result<(), _> = Err(RusTorchError::empty_tensor());
        let message = failing
            .with_operation("test operation")
            .expect_err("with_operation must preserve the error")
            .to_string();
        assert!(message.contains("test operation"));
    }

    /// The macro records the operation, every metadata key, and a source location.
    #[test]
    fn test_error_context_macro() {
        let ctx = error_context!("tensor add", "shape1" => "[2, 3]", "shape2" => "[2, 3]");
        assert_eq!(ctx.operation, "tensor add");
        assert!(["shape1", "shape2"]
            .iter()
            .all(|key| ctx.metadata.contains_key(*key)));
        assert!(ctx.location.is_some());
    }

    /// Stack-trace entries are listed newest-first, starting at index 0.
    #[test]
    fn test_stack_trace() {
        let rendered = ErrorContext::new("outer operation")
            .push_operation("middle operation")
            .push_operation("inner operation")
            .format_context();
        assert!(rendered.contains("Stack trace:"));
        assert!(rendered.contains("0: inner operation"));
        assert!(rendered.contains("1: middle operation"));
    }
}