use thiserror::Error;
#[derive(Error, Debug)]
pub enum RragError {
#[error("Document processing failed: {message}")]
DocumentProcessing {
message: String,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
#[error("Embedding generation failed for {content_type}: {message}")]
Embedding {
content_type: String,
message: String,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
#[error("Vector storage operation failed: {operation}")]
Storage {
operation: String,
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
#[error("rsllm client error: {operation}")]
RsllmClient {
operation: String,
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
#[error("Retrieval failed: {query}")]
Retrieval {
query: String,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
#[error("Tool '{tool}' execution failed: {message}")]
ToolExecution {
tool: String,
message: String,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
#[error("Configuration error: {field}")]
Configuration {
field: String,
expected: String,
actual: String,
},
#[error("Network operation failed: {operation}")]
Network {
operation: String,
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
#[error("Serialization error: {data_type}")]
Serialization {
data_type: String,
#[source]
source: serde_json::Error,
},
#[error("Operation timed out after {duration_ms}ms: {operation}")]
Timeout {
operation: String,
duration_ms: u64
},
#[error("Memory operation failed: {operation}")]
Memory {
operation: String,
message: String
},
#[error("Stream error in {context}: {message}")]
Stream {
context: String,
message: String
},
#[error("Agent execution failed: {agent_id}")]
Agent {
agent_id: String,
message: String,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
#[error("Validation failed: {field}")]
Validation {
field: String,
constraint: String,
value: String,
},
}
impl RragError {
    /// `agent_id` that [`RragError::evaluation`] puts on its `Agent` errors.
    const EVALUATION_AGENT_ID: &'static str = "evaluation";
    /// `agent_id` that [`RragError::serialization_with_message`] puts on its
    /// `Agent` errors. A real `Serialization` variant requires a `serde_json::Error`.
    const SERIALIZATION_AGENT_ID: &'static str = "serialization";

    /// Creates a [`RragError::DocumentProcessing`] error without an underlying cause.
    pub fn document_processing(message: impl Into<String>) -> Self {
        Self::DocumentProcessing {
            message: message.into(),
            source: None,
        }
    }

    /// Creates a [`RragError::DocumentProcessing`] error that keeps `source` as its cause.
    pub fn document_processing_with_source(
        message: impl Into<String>,
        source: impl std::error::Error + Send + Sync + 'static,
    ) -> Self {
        Self::DocumentProcessing {
            message: message.into(),
            source: Some(Box::new(source)),
        }
    }

    /// Creates a [`RragError::Embedding`] error for `content_type` without an underlying cause.
    pub fn embedding(content_type: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Embedding {
            content_type: content_type.into(),
            message: message.into(),
            source: None,
        }
    }

    /// Creates a [`RragError::Storage`] error for the failed `operation`, wrapping `source`.
    pub fn storage(
        operation: impl Into<String>,
        source: impl std::error::Error + Send + Sync + 'static,
    ) -> Self {
        Self::Storage {
            operation: operation.into(),
            source: Box::new(source),
        }
    }

    /// Creates a [`RragError::RsllmClient`] error for the failed `operation`, wrapping `source`.
    pub fn rsllm_client(
        operation: impl Into<String>,
        source: impl std::error::Error + Send + Sync + 'static,
    ) -> Self {
        Self::RsllmClient {
            operation: operation.into(),
            source: Box::new(source),
        }
    }

    /// Creates a [`RragError::Retrieval`] error for `query` without an underlying cause.
    pub fn retrieval(query: impl Into<String>) -> Self {
        Self::Retrieval {
            query: query.into(),
            source: None,
        }
    }

    /// Creates an evaluation failure.
    ///
    /// There is no dedicated variant. This builds a [`RragError::Agent`] whose
    /// `agent_id` is `"evaluation"`, which [`RragError::category`] reports as
    /// `"evaluation"`.
    pub fn evaluation(message: impl Into<String>) -> Self {
        Self::Agent {
            agent_id: Self::EVALUATION_AGENT_ID.to_string(),
            message: message.into(),
            source: None,
        }
    }

    /// Creates a [`RragError::ToolExecution`] error for `tool` without an underlying cause.
    pub fn tool_execution(tool: impl Into<String>, message: impl Into<String>) -> Self {
        Self::ToolExecution {
            tool: tool.into(),
            message: message.into(),
            source: None,
        }
    }

    /// Creates a [`RragError::Configuration`] error: `field` held `actual` instead of `expected`.
    pub fn config(
        field: impl Into<String>,
        expected: impl Into<String>,
        actual: impl Into<String>,
    ) -> Self {
        Self::Configuration {
            field: field.into(),
            expected: expected.into(),
            actual: actual.into(),
        }
    }

    /// Creates a [`RragError::Timeout`] for `operation` after `duration_ms` milliseconds.
    pub fn timeout(operation: impl Into<String>, duration_ms: u64) -> Self {
        Self::Timeout {
            operation: operation.into(),
            duration_ms,
        }
    }

    /// Creates a [`RragError::Memory`] error for `operation`.
    pub fn memory(operation: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Memory {
            operation: operation.into(),
            message: message.into(),
        }
    }

    /// Creates a [`RragError::Stream`] error raised within `context`.
    pub fn stream(context: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Stream {
            context: context.into(),
            message: message.into(),
        }
    }

    /// Creates a [`RragError::Agent`] error for `agent_id` without an underlying cause.
    pub fn agent(agent_id: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Agent {
            agent_id: agent_id.into(),
            message: message.into(),
            source: None,
        }
    }

    /// Creates a [`RragError::Validation`] error: `value` for `field` violated `constraint`.
    pub fn validation(
        field: impl Into<String>,
        constraint: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        Self::Validation {
            field: field.into(),
            constraint: constraint.into(),
            value: value.into(),
        }
    }

    /// Creates a [`RragError::Network`] error for the failed `operation`, wrapping `source`.
    pub fn network(
        operation: impl Into<String>,
        source: impl std::error::Error + Send + Sync + 'static,
    ) -> Self {
        Self::Network {
            operation: operation.into(),
            source: Box::new(source),
        }
    }

    /// Creates a [`RragError::Configuration`] error from a free-form message.
    ///
    /// `field` is fixed to `"configuration"` and `expected` to `"valid configuration"`.
    /// `message` is stored as `actual`.
    pub fn configuration(message: impl Into<String>) -> Self {
        Self::Configuration {
            field: "configuration".to_string(),
            expected: "valid configuration".to_string(),
            actual: message.into(),
        }
    }

    /// Creates a serialization failure from a plain message.
    ///
    /// [`RragError::Serialization`] needs a `serde_json::Error`, so this builds a
    /// [`RragError::Agent`] instead. Its `agent_id` is `"serialization"` and its
    /// message is `"{data_type}: {message}"`. [`RragError::category`] reports it
    /// as `"serialization"`.
    pub fn serialization_with_message(
        data_type: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        Self::Agent {
            agent_id: Self::SERIALIZATION_AGENT_ID.to_string(),
            message: format!("{}: {}", data_type.into(), message.into()),
            source: None,
        }
    }

    /// Creates an I/O failure from a message.
    ///
    /// This is represented as a [`RragError::Network`] error with operation
    /// `"io_operation"`, wrapping a `std::io::Error` of kind `Other`. As a result
    /// it counts as retryable and has [`ErrorSeverity::Low`].
    pub fn io_error(message: impl Into<String>) -> Self {
        Self::Network {
            operation: "io_operation".to_string(),
            source: Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                message.into(),
            )),
        }
    }

    /// Returns `true` for the kinds treated as retryable: `Network`, `Timeout`,
    /// `RsllmClient` and `Stream`.
    pub fn is_retryable(&self) -> bool {
        matches!(
            self,
            Self::Network { .. }
                | Self::Timeout { .. }
                | Self::RsllmClient { .. }
                | Self::Stream { .. }
        )
    }

    /// Returns a stable snake_case label for the error kind.
    ///
    /// An `Agent` error whose `agent_id` is `"evaluation"` or `"serialization"`
    /// reports that id. Those are the ids used by [`RragError::evaluation`] and
    /// [`RragError::serialization_with_message`]. Every other `Agent` error
    /// reports `"agent"`.
    pub fn category(&self) -> &'static str {
        match self {
            Self::DocumentProcessing { .. } => "document_processing",
            Self::Embedding { .. } => "embedding",
            Self::Storage { .. } => "storage",
            Self::RsllmClient { .. } => "rsllm_client",
            Self::Retrieval { .. } => "retrieval",
            Self::ToolExecution { .. } => "tool_execution",
            Self::Configuration { .. } => "configuration",
            Self::Network { .. } => "network",
            Self::Serialization { .. } => "serialization",
            Self::Timeout { .. } => "timeout",
            Self::Memory { .. } => "memory",
            Self::Stream { .. } => "stream",
            Self::Agent { agent_id, .. } if agent_id == Self::EVALUATION_AGENT_ID => {
                "evaluation"
            }
            // Previously fell through to "agent", misclassifying errors built by
            // `serialization_with_message`.
            Self::Agent { agent_id, .. } if agent_id == Self::SERIALIZATION_AGENT_ID => {
                "serialization"
            }
            Self::Agent { .. } => "agent",
            Self::Validation { .. } => "validation",
        }
    }

    /// Classifies how serious the error is (see [`ErrorSeverity`]).
    pub fn severity(&self) -> ErrorSeverity {
        match self {
            Self::Configuration { .. } | Self::Validation { .. } => ErrorSeverity::Critical,
            Self::Storage { .. } | Self::RsllmClient { .. } => ErrorSeverity::High,
            Self::DocumentProcessing { .. } | Self::Embedding { .. } | Self::Retrieval { .. } => {
                ErrorSeverity::Medium
            }
            Self::ToolExecution { .. } | Self::Agent { .. } => ErrorSeverity::Medium,
            Self::Network { .. } | Self::Timeout { .. } | Self::Stream { .. } => ErrorSeverity::Low,
            Self::Serialization { .. } | Self::Memory { .. } => ErrorSeverity::Low,
        }
    }
}
/// How serious an error is, from `Low` (least) to `Critical` (most).
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `ErrorSeverity::Low < ErrorSeverity::Critical`. The explicit discriminants
/// (1 through 4) also fix the value of an `as` cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ErrorSeverity {
    Low = 1,
    Medium = 2,
    High = 3,
    Critical = 4,
}

/// Renders the severity as an upper-case label (`LOW`, `MEDIUM`, `HIGH`, `CRITICAL`).
impl std::fmt::Display for ErrorSeverity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Low => "LOW",
            Self::Medium => "MEDIUM",
            Self::High => "HIGH",
            Self::Critical => "CRITICAL",
        };
        // `pad` honours width/alignment/fill flags such as `{:<8}`, which `write!`
        // ignores. This keeps columns aligned when severities are logged.
        f.pad(label)
    }
}
/// Shorthand for a `Result` whose error type is [`RragError`].
pub type RragResult<T> = std::result::Result<T, RragError>;

/// Adapters that turn an arbitrary `Result` into an [`RragResult`].
pub trait RragResultExt<T> {
    /// Converts an `Err` into an [`RragError`] labelled with `context`.
    fn with_rrag_context(self, context: &str) -> RragResult<T>;

    /// Converts an `Err` into the [`RragError`] produced by `make_err`.
    fn map_to_rrag_error<MakeErr>(self, make_err: MakeErr) -> RragResult<T>
    where
        MakeErr: FnOnce() -> RragError;
}
/// Blanket implementation for any `Result` whose error is a `Send + Sync + 'static`
/// [`std::error::Error`].
impl<T, E> RragResultExt<T> for std::result::Result<T, E>
where
    E: std::error::Error + Send + Sync + 'static,
{
    /// Wraps the error in [`RragError::Agent`].
    ///
    /// `agent_id` is `context`, `message` is the error's `Display` text, and the
    /// error itself is kept as the source.
    fn with_rrag_context(self, context: &str) -> RragResult<T> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => {
                let message = cause.to_string();
                Err(RragError::Agent {
                    agent_id: context.to_owned(),
                    message,
                    source: Some(Box::new(cause)),
                })
            }
        }
    }

    /// Discards the original error and returns the one built by `f` instead.
    /// `f` is only invoked on the `Err` path.
    fn map_to_rrag_error<F>(self, f: F) -> RragResult<T>
    where
        F: FnOnce() -> RragError,
    {
        self.or_else(|_discarded| Err(f()))
    }
}
impl From<serde_json::Error> for RragError {
fn from(err: serde_json::Error) -> Self {
Self::Serialization {
data_type: "json".to_string(),
source: err,
}
}
}
/// With the `http` feature, converts an HTTP client failure into a
/// [`RragError::Network`] error for the `"http_request"` operation.
#[cfg(feature = "http")]
impl From<reqwest::Error> for RragError {
    fn from(source: reqwest::Error) -> Self {
        let operation = String::from("http_request");
        Self::Network {
            operation,
            source: Box::new(source),
        }
    }
}
impl From<tokio::time::error::Elapsed> for RragError {
fn from(_err: tokio::time::error::Elapsed) -> Self {
Self::Timeout {
operation: "async_operation".to_string(),
duration_ms: 0, }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A storage error wrapping a throwaway I/O error.
    fn sample_storage_error() -> RragError {
        RragError::storage("op", std::io::Error::new(std::io::ErrorKind::Other, "test"))
    }

    #[test]
    fn test_error_categories() {
        let cases = [
            (RragError::document_processing("test"), "document_processing"),
            (RragError::timeout("op", 1000), "timeout"),
            (RragError::config("field", "expected", "actual"), "configuration"),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.category(), *expected);
        }
    }

    #[test]
    fn test_error_severity() {
        let config_err = RragError::config("field", "expected", "actual");
        assert_eq!(config_err.severity(), ErrorSeverity::Critical);

        let timeout_err = RragError::timeout("op", 1000);
        assert_eq!(timeout_err.severity(), ErrorSeverity::Low);

        assert_eq!(sample_storage_error().severity(), ErrorSeverity::High);
    }

    #[test]
    fn test_retryable() {
        let timeout_err = RragError::timeout("op", 1000);
        let config_err = RragError::config("field", "expected", "actual");
        assert!(timeout_err.is_retryable());
        assert!(!config_err.is_retryable());
    }

    #[test]
    fn test_error_construction() {
        match RragError::tool_execution("calculator", "invalid input") {
            RragError::ToolExecution { tool, message, .. } => {
                assert_eq!(tool, "calculator");
                assert_eq!(message, "invalid input");
            }
            _ => panic!("Wrong error type"),
        }
    }
}