Skip to main content

steer_tools/
error.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use thiserror::Error;
4
5use crate::tools::{
6    AST_GREP_TOOL_NAME, BASH_TOOL_NAME, DISPATCH_AGENT_TOOL_NAME, EDIT_TOOL_NAME, FETCH_TOOL_NAME,
7    GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, MULTI_EDIT_TOOL_NAME, REPLACE_TOOL_NAME,
8    TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME, VIEW_TOOL_NAME, astgrep::AstGrepError,
9    bash::BashError, dispatch_agent::DispatchAgentError, edit::EditError,
10    edit::multi_edit::MultiEditError, fetch::FetchError, glob::GlobError, grep::GrepError,
11    ls::LsError, replace::ReplaceError, todo::read::TodoReadError, todo::write::TodoWriteError,
12    view::ViewError,
13};
14
15#[derive(Error, Debug, Clone, Serialize, Deserialize, JsonSchema)]
16pub enum ToolError {
17    #[error("Unknown tool: {0}")]
18    UnknownTool(String),
19
20    #[error("Invalid parameters for {tool_name}: {message}")]
21    InvalidParams { tool_name: String, message: String },
22
23    #[error("{0}")]
24    Execution(ToolExecutionError),
25
26    #[error("{0} was cancelled")]
27    Cancelled(String),
28
29    #[error("{0} timed out")]
30    Timeout(String),
31
32    #[error("{0} requires approval to run")]
33    DeniedByUser(String),
34
35    #[error("{0} denied by approval policy")]
36    DeniedByPolicy(String),
37
38    #[error("Unexpected error: {0}")]
39    InternalError(String),
40}
41
42impl ToolError {
43    pub fn execution<T: Into<String>, M: Into<String>>(tool_name: T, message: M) -> Self {
44        ToolError::Execution(ToolExecutionError::External {
45            tool_name: tool_name.into(),
46            message: message.into(),
47        })
48    }
49
50    pub fn invalid_params<T: Into<String>, M: Into<String>>(tool_name: T, message: M) -> Self {
51        ToolError::InvalidParams {
52            tool_name: tool_name.into(),
53            message: message.into(),
54        }
55    }
56}
57
58#[derive(Error, Debug, Clone, Serialize, Deserialize, JsonSchema)]
59#[serde(tag = "tool", content = "error", rename_all = "snake_case")]
60pub enum ToolExecutionError {
61    #[error("{0}")]
62    AstGrep(AstGrepError),
63    #[error("{0}")]
64    Bash(BashError),
65    #[error("{0}")]
66    Edit(EditError),
67    #[error("{0}")]
68    MultiEdit(MultiEditError),
69    #[error("{0}")]
70    Fetch(FetchError),
71    #[error("{0}")]
72    Glob(GlobError),
73    #[error("{0}")]
74    Grep(GrepError),
75    #[error("{0}")]
76    Ls(LsError),
77    #[error("{0}")]
78    Replace(ReplaceError),
79    #[error("{0}")]
80    TodoRead(TodoReadError),
81    #[error("{0}")]
82    TodoWrite(TodoWriteError),
83    #[error("{0}")]
84    View(ViewError),
85    #[error("{0}")]
86    DispatchAgent(DispatchAgentError),
87
88    #[error("{tool_name} failed: {message}")]
89    External { tool_name: String, message: String },
90}
91
92impl ToolExecutionError {
93    pub fn tool_name(&self) -> &str {
94        match self {
95            ToolExecutionError::AstGrep(_) => AST_GREP_TOOL_NAME,
96            ToolExecutionError::Bash(_) => BASH_TOOL_NAME,
97            ToolExecutionError::Edit(_) => EDIT_TOOL_NAME,
98            ToolExecutionError::MultiEdit(_) => MULTI_EDIT_TOOL_NAME,
99            ToolExecutionError::Fetch(_) => FETCH_TOOL_NAME,
100            ToolExecutionError::Glob(_) => GLOB_TOOL_NAME,
101            ToolExecutionError::Grep(_) => GREP_TOOL_NAME,
102            ToolExecutionError::Ls(_) => LS_TOOL_NAME,
103            ToolExecutionError::Replace(_) => REPLACE_TOOL_NAME,
104            ToolExecutionError::TodoRead(_) => TODO_READ_TOOL_NAME,
105            ToolExecutionError::TodoWrite(_) => TODO_WRITE_TOOL_NAME,
106            ToolExecutionError::View(_) => VIEW_TOOL_NAME,
107            ToolExecutionError::DispatchAgent(_) => DISPATCH_AGENT_TOOL_NAME,
108            ToolExecutionError::External { tool_name, .. } => tool_name.as_str(),
109        }
110    }
111}
112
113#[derive(Error, Debug, Clone, Serialize, Deserialize, JsonSchema)]
114#[serde(tag = "code", rename_all = "snake_case")]
115pub enum WorkspaceOpError {
116    #[error("path is outside workspace")]
117    InvalidPath,
118
119    #[error("file not found")]
120    NotFound,
121
122    #[error("permission denied")]
123    PermissionDenied,
124
125    #[error("operation not supported: {message}")]
126    NotSupported { message: String },
127
128    #[error("io error: {message}")]
129    Io { message: String },
130
131    #[error("{message}")]
132    Other { message: String },
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use serde::de::DeserializeOwned;
    use serde_json::Value;

    /// Checks the wire format shared by every tool error's `Workspace` variant.
    ///
    /// `error` must serialize to exactly `{"code": "workspace", "details": <inner>}`, where
    /// `<inner>` is the JSON form of `inner` (the wrapped [`WorkspaceOpError`]). The JSON
    /// must then deserialize back into `T` and re-serialize to the identical value, so the
    /// roundtrip is actually verified rather than merely attempted.
    fn assert_workspace_error_roundtrip<T>(error: T, inner: &WorkspaceOpError)
    where
        T: Serialize + DeserializeOwned + std::fmt::Debug,
    {
        let serialized = serde_json::to_string(&error).expect("serialize error");
        let value: Value = serde_json::from_str(&serialized).expect("deserialize json");
        let obj = match &value {
            Value::Object(map) => map,
            other => panic!("expected object, got {other:?}"),
        };
        assert_eq!(obj.len(), 2, "unexpected fields in {value}");
        assert_eq!(
            obj.get("code"),
            Some(&Value::String("workspace".to_string()))
        );

        // Compare the whole payload, not just its tag, so variant fields are checked too.
        let expected_details =
            serde_json::to_value(inner).expect("serialize workspace error");
        assert_eq!(obj.get("details"), Some(&expected_details));

        // `T` is only bounded by Serialize/Deserialize here, so compare via JSON form.
        let roundtripped: T = serde_json::from_str(&serialized).expect("roundtrip error");
        let reserialized = serde_json::to_value(&roundtripped).expect("reserialize error");
        assert_eq!(reserialized, value, "roundtrip changed {error:?}");
    }

    #[test]
    fn workspace_op_error_serializes_with_code_tag() {
        let not_found = serde_json::to_value(WorkspaceOpError::NotFound).expect("serialize");
        assert_eq!(not_found, serde_json::json!({ "code": "not_found" }));

        let io = serde_json::to_value(WorkspaceOpError::Io {
            message: "disk full".to_string(),
        })
        .expect("serialize");
        assert_eq!(io, serde_json::json!({ "code": "io", "message": "disk full" }));
    }

    #[test]
    fn workspace_error_wrappers_roundtrip() {
        let cases = [
            WorkspaceOpError::NotFound,
            WorkspaceOpError::InvalidPath,
            WorkspaceOpError::NotSupported {
                message: "symlinks".to_string(),
            },
            WorkspaceOpError::Io {
                message: "disk full".to_string(),
            },
        ];
        for inner in &cases {
            assert_workspace_error_roundtrip(AstGrepError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(EditError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(MultiEditError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(GlobError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(GrepError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(LsError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(ReplaceError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(ViewError::Workspace(inner.clone()), inner);
            assert_workspace_error_roundtrip(
                DispatchAgentError::Workspace(inner.clone()),
                inner,
            );
        }
    }
}