1use thiserror::Error;
7
8pub type SageResult<T> = Result<T, SageError>;
10
/// Coarse, payload-free category of a [`SageError`].
///
/// The value is `Copy` and hashable, so it can be compared directly or used
/// as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// Category reported for `SageError::Llm` and `SageError::Json`.
    Llm,
    /// Category reported for `SageError::Agent`, `SageError::JoinError` and
    /// `SageError::Supervisor`.
    Agent,
    /// Category reported for `SageError::Type`.
    Runtime,
    /// Category reported for `SageError::Http`, `SageError::Tool` and
    /// `SageError::Io`.
    Tool,
    /// Category reported for `SageError::User`.
    User,
    /// Category reported for `SageError::Protocol`.
    Protocol,
}

impl std::fmt::Display for ErrorKind {
    /// Writes the variant name (`"Llm"`, `"Agent"`, ...).
    ///
    /// Width, fill, alignment and precision flags supplied by the caller are
    /// honoured, so `format!("{:>8}", kind)` yields a right-aligned column.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ErrorKind::Llm => "Llm",
            ErrorKind::Agent => "Agent",
            ErrorKind::Runtime => "Runtime",
            ErrorKind::Tool => "Tool",
            ErrorKind::User => "User",
            ErrorKind::Protocol => "Protocol",
        };
        // `pad` applies the formatter's flags; `write!` with a literal would
        // silently ignore them.
        f.pad(name)
    }
}
43
44#[derive(Debug, Error)]
49pub enum SageError {
50 #[error("LLM error: {0}")]
52 Llm(String),
53
54 #[error("Agent error: {0}")]
56 Agent(String),
57
58 #[error("Type error: expected {expected}, got {got}")]
60 Type { expected: String, got: String },
61
62 #[error("HTTP error: {0}")]
64 Http(#[from] reqwest::Error),
65
66 #[error("JSON error: {0}")]
68 Json(#[from] serde_json::Error),
69
70 #[error("Agent task failed: {0}")]
72 JoinError(String),
73
74 #[error("Tool error: {0}")]
76 Tool(String),
77
78 #[error("I/O error: {0}")]
80 Io(#[from] std::io::Error),
81
82 #[error("{0}")]
84 User(String),
85
86 #[error("Supervisor error: {0}")]
88 Supervisor(String),
89
90 #[error("Protocol error: {0}")]
92 Protocol(String),
93}
94
95impl SageError {
96 #[must_use]
100 pub fn message(&self) -> String {
101 self.to_string()
102 }
103
104 #[must_use]
108 pub fn kind(&self) -> ErrorKind {
109 match self {
110 SageError::Llm(_) | SageError::Json(_) => ErrorKind::Llm,
111 SageError::Agent(_) | SageError::JoinError(_) | SageError::Supervisor(_) => {
112 ErrorKind::Agent
113 }
114 SageError::Type { .. } => ErrorKind::Runtime,
115 SageError::Http(_) | SageError::Tool(_) | SageError::Io(_) => ErrorKind::Tool,
117 SageError::User(_) => ErrorKind::User,
118 SageError::Protocol(_) => ErrorKind::Protocol,
120 }
121 }
122
123 #[must_use]
125 pub fn llm(msg: impl Into<String>) -> Self {
126 SageError::Llm(msg.into())
127 }
128
129 #[must_use]
131 pub fn agent(msg: impl Into<String>) -> Self {
132 SageError::Agent(msg.into())
133 }
134
135 #[must_use]
137 pub fn type_error(expected: impl Into<String>, got: impl Into<String>) -> Self {
138 SageError::Type {
139 expected: expected.into(),
140 got: got.into(),
141 }
142 }
143
144 #[must_use]
146 pub fn tool(msg: impl Into<String>) -> Self {
147 SageError::Tool(msg.into())
148 }
149
150 #[must_use]
152 pub fn user(msg: impl Into<String>) -> Self {
153 SageError::User(msg.into())
154 }
155
156 #[must_use]
158 pub fn protocol(msg: impl Into<String>) -> Self {
159 SageError::Protocol(msg.into())
160 }
161}
162
163#[cfg(not(target_arch = "wasm32"))]
164impl From<tokio::task::JoinError> for SageError {
165 fn from(e: tokio::task::JoinError) -> Self {
166 SageError::JoinError(e.to_string())
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// String-payload and struct variants map to the expected categories.
    #[test]
    fn error_kind_classification() {
        assert_eq!(SageError::llm("test").kind(), ErrorKind::Llm);
        assert_eq!(SageError::agent("test").kind(), ErrorKind::Agent);
        assert_eq!(
            SageError::type_error("Int", "String").kind(),
            ErrorKind::Runtime
        );
    }

    /// Variants without a dedicated constructor still classify as `Agent`.
    #[test]
    fn agent_family_classification() {
        assert_eq!(
            SageError::Supervisor("restart limit".to_string()).kind(),
            ErrorKind::Agent
        );
        assert_eq!(
            SageError::JoinError("cancelled".to_string()).kind(),
            ErrorKind::Agent
        );
    }

    #[test]
    fn error_message() {
        let err = SageError::llm("inference failed");
        assert_eq!(err.message(), "LLM error: inference failed");
    }

    /// The struct variant interpolates both named fields.
    #[test]
    fn type_error_message() {
        assert_eq!(
            SageError::type_error("Int", "String").message(),
            "Type error: expected Int, got String"
        );
    }

    /// Every category, including `User`, renders as its variant name.
    #[test]
    fn error_kind_display() {
        assert_eq!(format!("{}", ErrorKind::Llm), "Llm");
        assert_eq!(format!("{}", ErrorKind::Agent), "Agent");
        assert_eq!(format!("{}", ErrorKind::Runtime), "Runtime");
        assert_eq!(format!("{}", ErrorKind::Tool), "Tool");
        assert_eq!(format!("{}", ErrorKind::User), "User");
        assert_eq!(format!("{}", ErrorKind::Protocol), "Protocol");
    }

    #[test]
    fn tool_error_classification() {
        assert_eq!(SageError::tool("http failed").kind(), ErrorKind::Tool);
        assert_eq!(SageError::tool("timeout").message(), "Tool error: timeout");
    }

    /// An `std::io::Error` converts through `From` and is classified as `Tool`.
    #[test]
    fn io_error_conversion() {
        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "missing file");
        let err = SageError::from(io);
        assert_eq!(err.kind(), ErrorKind::Tool);
        assert_eq!(err.message(), "I/O error: missing file");
    }

    /// User errors are shown verbatim, without any prefix.
    #[test]
    fn user_error_classification() {
        assert_eq!(SageError::user("bad input").kind(), ErrorKind::User);
        assert_eq!(SageError::user("bad input").message(), "bad input");
    }

    #[test]
    fn protocol_error_classification() {
        assert_eq!(
            SageError::protocol("unexpected message").kind(),
            ErrorKind::Protocol
        );
        assert_eq!(
            SageError::protocol("wrong sender").message(),
            "Protocol error: wrong sender"
        );
    }
}