Skip to main content

hdbconnect_mcp/
error.rs

1use std::time::Duration;
2
3use rmcp::ErrorData;
4use thiserror::Error;
5
6use crate::config::DmlOperation;
7
8#[derive(Error, Debug)]
9pub enum Error {
10    #[error("Connection error: {0}")]
11    Connection(#[from] hdbconnect::HdbError),
12
13    #[error("Query error: {0}")]
14    Query(String),
15
16    #[error("Configuration error: {0}")]
17    Config(String),
18
19    #[error("Connection pool exhausted")]
20    PoolExhausted,
21
22    #[error("Read-only mode: {0}")]
23    ReadOnlyViolation(String),
24
25    #[error("Query timeout after {0:?}")]
26    QueryTimeout(Duration),
27
28    #[error("Schema access denied: {0}")]
29    SchemaAccessDenied(String),
30
31    #[error("Transport error: {0}")]
32    Transport(String),
33
34    // DML-specific errors
35    #[error("DML operations are disabled. Set allow_dml=true in configuration")]
36    DmlDisabled,
37
38    #[error("DML operation not allowed: {0}")]
39    DmlOperationNotAllowed(DmlOperation),
40
41    #[error("WHERE clause required for {0} statements")]
42    DmlWhereClauseRequired(DmlOperation),
43
44    #[error("Affected rows ({actual}) exceeds limit ({limit})")]
45    DmlRowLimitExceeded { actual: u64, limit: u32 },
46
47    #[error("DML operation cancelled by user")]
48    DmlCancelled,
49
50    #[error("Not a valid DML statement. Use INSERT, UPDATE, or DELETE")]
51    DmlNotAStatement,
52
53    // Procedure-specific errors
54    #[error("Stored procedure execution is disabled. Set allow_procedures=true in configuration")]
55    ProcedureDisabled,
56
57    #[error("Procedure not found: {schema}.{name}")]
58    ProcedureNotFound { schema: String, name: String },
59
60    #[error("Invalid procedure name: {0}")]
61    InvalidProcedureName(String),
62
63    #[error("Missing required parameter: {0}")]
64    ProcedureMissingParameter(String),
65
66    #[error("Procedure execution cancelled by user")]
67    ProcedureCancelled,
68
69    #[error("Procedure returned too many result sets ({actual}), limit is {limit}")]
70    ProcedureResultSetLimitExceeded { actual: usize, limit: u32 },
71
72    #[error("Procedure execution failed: {0}")]
73    ProcedureExecutionFailed(String),
74}
75
76impl Error {
77    pub const fn read_only_violation(msg: String) -> Self {
78        Self::ReadOnlyViolation(msg)
79    }
80
81    #[must_use]
82    pub const fn is_read_only_violation(&self) -> bool {
83        matches!(self, Self::ReadOnlyViolation(_))
84    }
85
86    #[must_use]
87    pub const fn is_timeout(&self) -> bool {
88        matches!(self, Self::QueryTimeout(_))
89    }
90
91    #[must_use]
92    pub const fn is_schema_denied(&self) -> bool {
93        matches!(self, Self::SchemaAccessDenied(_))
94    }
95
96    #[must_use]
97    pub const fn is_config(&self) -> bool {
98        matches!(self, Self::Config(_))
99    }
100
101    #[must_use]
102    pub const fn is_pool_exhausted(&self) -> bool {
103        matches!(self, Self::PoolExhausted)
104    }
105
106    #[must_use]
107    pub const fn is_transport(&self) -> bool {
108        matches!(self, Self::Transport(_))
109    }
110
111    #[must_use]
112    pub const fn is_query(&self) -> bool {
113        matches!(self, Self::Query(_))
114    }
115
116    #[must_use]
117    pub const fn is_dml_error(&self) -> bool {
118        matches!(
119            self,
120            Self::DmlDisabled
121                | Self::DmlOperationNotAllowed(_)
122                | Self::DmlWhereClauseRequired(_)
123                | Self::DmlRowLimitExceeded { .. }
124                | Self::DmlCancelled
125                | Self::DmlNotAStatement
126        )
127    }
128
129    #[must_use]
130    pub const fn is_procedure_error(&self) -> bool {
131        matches!(
132            self,
133            Self::ProcedureDisabled
134                | Self::ProcedureNotFound { .. }
135                | Self::InvalidProcedureName(_)
136                | Self::ProcedureMissingParameter(_)
137                | Self::ProcedureCancelled
138                | Self::ProcedureResultSetLimitExceeded { .. }
139                | Self::ProcedureExecutionFailed(_)
140        )
141    }
142}
143
144/// Convert our Error type to rmcp `ErrorData`
145impl From<Error> for ErrorData {
146    fn from(err: Error) -> Self {
147        match err {
148            Error::Connection(e) => {
149                Self::internal_error(format!("Database connection error: {e}"), None)
150            }
151            Error::Query(msg) => Self::internal_error(format!("Query error: {msg}"), None),
152            Error::Config(msg) => Self::invalid_params(format!("Configuration error: {msg}"), None),
153            Error::PoolExhausted => Self::internal_error("Connection pool exhausted", None),
154            Error::ReadOnlyViolation(msg) => {
155                Self::invalid_params(format!("Read-only mode violation: {msg}"), None)
156            }
157            Error::QueryTimeout(duration) => {
158                Self::internal_error(format!("Query timeout after {duration:?}"), None)
159            }
160            Error::SchemaAccessDenied(schema) => {
161                Self::invalid_params(format!("Schema access denied: {schema}"), None)
162            }
163            Error::Transport(msg) => Self::internal_error(format!("Transport error: {msg}"), None),
164            // DML errors
165            Error::DmlDisabled => Self::invalid_params(
166                "DML operations are disabled. Set allow_dml=true in configuration",
167                None,
168            ),
169            Error::DmlOperationNotAllowed(op) => {
170                Self::invalid_params(format!("DML operation not allowed: {op}"), None)
171            }
172            Error::DmlWhereClauseRequired(op) => {
173                Self::invalid_params(format!("WHERE clause required for {op} statements"), None)
174            }
175            Error::DmlRowLimitExceeded { actual, limit } => Self::invalid_params(
176                format!("Affected rows ({actual}) exceeds limit ({limit})"),
177                None,
178            ),
179            Error::DmlCancelled => Self::invalid_params("DML operation cancelled by user", None),
180            Error::DmlNotAStatement => Self::invalid_params(
181                "Not a valid DML statement. Use INSERT, UPDATE, or DELETE",
182                None,
183            ),
184            // Procedure errors
185            Error::ProcedureDisabled => Self::invalid_params(
186                "Stored procedure execution is disabled. Set allow_procedures=true in configuration",
187                None,
188            ),
189            Error::ProcedureNotFound { schema, name } => {
190                Self::invalid_params(format!("Procedure not found: {schema}.{name}"), None)
191            }
192            Error::InvalidProcedureName(name) => {
193                Self::invalid_params(format!("Invalid procedure name: {name}"), None)
194            }
195            Error::ProcedureMissingParameter(name) => {
196                Self::invalid_params(format!("Missing required parameter: {name}"), None)
197            }
198            Error::ProcedureCancelled => {
199                Self::invalid_params("Procedure execution cancelled by user", None)
200            }
201            Error::ProcedureResultSetLimitExceeded { actual, limit } => Self::invalid_params(
202                format!("Procedure returned too many result sets ({actual}), limit is {limit}"),
203                None,
204            ),
205            Error::ProcedureExecutionFailed(msg) => {
206                Self::internal_error(format!("Procedure execution failed: {msg}"), None)
207            }
208        }
209    }
210}
211
212pub type Result<T> = std::result::Result<T, Error>;
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `err` converts to the same JSON-RPC code as `ErrorData::invalid_params`.
    fn assert_maps_to_invalid_params(err: Error) {
        let expected = ErrorData::invalid_params("", None).code;
        let data: ErrorData = err.into();
        assert_eq!(data.code, expected, "unexpected code for: {}", data.message);
    }

    /// Asserts that `err` converts to the same JSON-RPC code as `ErrorData::internal_error`.
    fn assert_maps_to_internal_error(err: Error) {
        let expected = ErrorData::internal_error("", None).code;
        let data: ErrorData = err.into();
        assert_eq!(data.code, expected, "unexpected code for: {}", data.message);
    }

    #[test]
    fn test_read_only_violation_predicate() {
        let err = Error::ReadOnlyViolation("test".to_string());
        assert!(err.is_read_only_violation());
        assert!(!err.is_timeout());
        assert!(!err.is_schema_denied());
        assert!(!err.is_config());
    }

    #[test]
    fn test_timeout_predicate() {
        let err = Error::QueryTimeout(Duration::from_secs(30));
        assert!(err.is_timeout());
        assert!(!err.is_read_only_violation());
    }

    #[test]
    fn test_schema_denied_predicate() {
        let err = Error::SchemaAccessDenied("SYS".to_string());
        assert!(err.is_schema_denied());
        assert!(!err.is_timeout());
    }

    #[test]
    fn test_config_predicate() {
        let err = Error::Config("invalid config".to_string());
        assert!(err.is_config());
        assert!(!err.is_schema_denied());
    }

    #[test]
    fn test_pool_exhausted_predicate() {
        let err = Error::PoolExhausted;
        assert!(err.is_pool_exhausted());
        assert!(!err.is_config());
    }

    #[test]
    fn test_transport_predicate() {
        let err = Error::Transport("connection refused".to_string());
        assert!(err.is_transport());
        assert!(!err.is_pool_exhausted());
    }

    #[test]
    fn test_query_predicate() {
        let err = Error::Query("syntax error".to_string());
        assert!(err.is_query());
        assert!(!err.is_transport());
    }

    #[test]
    fn test_error_display() {
        let err = Error::ReadOnlyViolation("INSERT not allowed".to_string());
        assert_eq!(err.to_string(), "Read-only mode: INSERT not allowed");

        let err = Error::QueryTimeout(Duration::from_secs(30));
        assert!(err.to_string().contains("30"));

        let err = Error::SchemaAccessDenied("SECRET".to_string());
        assert!(err.to_string().contains("SECRET"));
    }

    #[test]
    fn test_read_only_violation_constructor() {
        let err = Error::read_only_violation("DML blocked".to_string());
        assert!(err.is_read_only_violation());
        assert!(err.to_string().contains("DML blocked"));
    }

    #[test]
    fn test_error_to_error_data_read_only() {
        let err = Error::ReadOnlyViolation("test".to_string());
        let data: ErrorData = err.into();
        assert!(data.message.contains("Read-only mode violation"));
    }

    #[test]
    fn test_error_to_error_data_timeout() {
        let err = Error::QueryTimeout(Duration::from_secs(60));
        let data: ErrorData = err.into();
        assert!(data.message.contains("timeout"));
    }

    #[test]
    fn test_error_to_error_data_schema_denied() {
        let err = Error::SchemaAccessDenied("PRIVATE".to_string());
        let data: ErrorData = err.into();
        assert!(data.message.contains("Schema access denied"));
        assert!(data.message.contains("PRIVATE"));
    }

    #[test]
    fn test_error_to_error_data_config() {
        let err = Error::Config("missing URL".to_string());
        let data: ErrorData = err.into();
        assert!(data.message.contains("Configuration error"));
    }

    #[test]
    fn test_error_to_error_data_pool_exhausted() {
        let err = Error::PoolExhausted;
        let data: ErrorData = err.into();
        assert!(data.message.contains("pool exhausted"));
    }

    #[test]
    fn test_error_to_error_data_transport() {
        let err = Error::Transport("connection refused".to_string());
        let data: ErrorData = err.into();
        assert!(data.message.contains("Transport error"));
    }

    #[test]
    fn test_error_to_error_data_query() {
        let err = Error::Query("invalid SQL".to_string());
        let data: ErrorData = err.into();
        assert!(data.message.contains("Query error"));
    }

    // DML error tests
    #[test]
    fn test_dml_disabled_error() {
        let err = Error::DmlDisabled;
        assert!(err.is_dml_error());
        assert!(err.to_string().contains("disabled"));
    }

    #[test]
    fn test_dml_operation_not_allowed_error() {
        let err = Error::DmlOperationNotAllowed(DmlOperation::Delete);
        assert!(err.is_dml_error());
        assert!(err.to_string().contains("DELETE"));
    }

    #[test]
    fn test_dml_where_clause_required_error() {
        let err = Error::DmlWhereClauseRequired(DmlOperation::Update);
        assert!(err.is_dml_error());
        assert!(err.to_string().contains("WHERE"));
        assert!(err.to_string().contains("UPDATE"));
    }

    #[test]
    fn test_dml_row_limit_exceeded_error() {
        let err = Error::DmlRowLimitExceeded {
            actual: 5000,
            limit: 1000,
        };
        assert!(err.is_dml_error());
        assert!(err.to_string().contains("5000"));
        assert!(err.to_string().contains("1000"));
    }

    #[test]
    fn test_dml_cancelled_error() {
        let err = Error::DmlCancelled;
        assert!(err.is_dml_error());
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn test_dml_not_a_statement_error() {
        let err = Error::DmlNotAStatement;
        assert!(err.is_dml_error());
        assert!(err.to_string().contains("INSERT"));
    }

    #[test]
    fn test_error_to_error_data_dml_disabled() {
        let err = Error::DmlDisabled;
        let data: ErrorData = err.into();
        assert!(data.message.contains("disabled"));
        assert!(data.message.contains("allow_dml"));
    }

    #[test]
    fn test_error_to_error_data_dml_row_limit() {
        let err = Error::DmlRowLimitExceeded {
            actual: 2000,
            limit: 500,
        };
        let data: ErrorData = err.into();
        assert!(data.message.contains("2000"));
        assert!(data.message.contains("500"));
    }

    #[test]
    fn test_error_to_error_data_dml_operation_details() {
        let data: ErrorData = Error::DmlWhereClauseRequired(DmlOperation::Update).into();
        assert!(data.message.contains("WHERE"));
        assert!(data.message.contains("UPDATE"));

        let data: ErrorData = Error::DmlOperationNotAllowed(DmlOperation::Delete).into();
        assert!(data.message.contains("DELETE"));
    }

    // Procedure error tests
    #[test]
    fn test_procedure_disabled_error() {
        let err = Error::ProcedureDisabled;
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("disabled"));
    }

    #[test]
    fn test_procedure_not_found_error() {
        let err = Error::ProcedureNotFound {
            schema: "APP".to_string(),
            name: "MISSING_PROC".to_string(),
        };
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("APP"));
        assert!(err.to_string().contains("MISSING_PROC"));
    }

    #[test]
    fn test_invalid_procedure_name_error() {
        let err = Error::InvalidProcedureName("bad;name".to_string());
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("bad;name"));
    }

    #[test]
    fn test_procedure_missing_parameter_error() {
        let err = Error::ProcedureMissingParameter("USER_ID".to_string());
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("USER_ID"));
    }

    #[test]
    fn test_procedure_cancelled_error() {
        let err = Error::ProcedureCancelled;
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("cancelled"));
    }

    #[test]
    fn test_procedure_result_set_limit_exceeded_error() {
        let err = Error::ProcedureResultSetLimitExceeded {
            actual: 15,
            limit: 10,
        };
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("15"));
        assert!(err.to_string().contains("10"));
    }

    #[test]
    fn test_procedure_execution_failed_error() {
        let err = Error::ProcedureExecutionFailed("division by zero".to_string());
        assert!(err.is_procedure_error());
        assert!(err.to_string().contains("division by zero"));
    }

    #[test]
    fn test_error_to_error_data_procedure_disabled() {
        let err = Error::ProcedureDisabled;
        let data: ErrorData = err.into();
        assert!(data.message.contains("disabled"));
        assert!(data.message.contains("allow_procedures"));
    }

    #[test]
    fn test_error_to_error_data_procedure_not_found() {
        let err = Error::ProcedureNotFound {
            schema: "TEST".to_string(),
            name: "MY_PROC".to_string(),
        };
        let data: ErrorData = err.into();
        assert!(data.message.contains("TEST"));
        assert!(data.message.contains("MY_PROC"));
    }

    // Cross-cutting tests: family predicates and ErrorData code/message mapping.
    #[test]
    fn test_dml_and_procedure_families_are_disjoint() {
        assert!(!Error::DmlDisabled.is_procedure_error());
        assert!(!Error::DmlCancelled.is_procedure_error());
        assert!(!Error::ProcedureDisabled.is_dml_error());
        assert!(!Error::ProcedureCancelled.is_dml_error());
        assert!(!Error::Query("x".to_string()).is_dml_error());
        assert!(!Error::Query("x".to_string()).is_procedure_error());
    }

    #[test]
    fn test_error_data_codes_for_invalid_params_variants() {
        assert_maps_to_invalid_params(Error::Config("x".to_string()));
        assert_maps_to_invalid_params(Error::ReadOnlyViolation("x".to_string()));
        assert_maps_to_invalid_params(Error::SchemaAccessDenied("x".to_string()));
        assert_maps_to_invalid_params(Error::DmlDisabled);
        assert_maps_to_invalid_params(Error::DmlOperationNotAllowed(DmlOperation::Delete));
        assert_maps_to_invalid_params(Error::DmlWhereClauseRequired(DmlOperation::Update));
        assert_maps_to_invalid_params(Error::DmlRowLimitExceeded {
            actual: 2,
            limit: 1,
        });
        assert_maps_to_invalid_params(Error::DmlCancelled);
        assert_maps_to_invalid_params(Error::DmlNotAStatement);
        assert_maps_to_invalid_params(Error::ProcedureDisabled);
        assert_maps_to_invalid_params(Error::ProcedureNotFound {
            schema: "S".to_string(),
            name: "P".to_string(),
        });
        assert_maps_to_invalid_params(Error::InvalidProcedureName("x".to_string()));
        assert_maps_to_invalid_params(Error::ProcedureMissingParameter("x".to_string()));
        assert_maps_to_invalid_params(Error::ProcedureCancelled);
        assert_maps_to_invalid_params(Error::ProcedureResultSetLimitExceeded {
            actual: 2,
            limit: 1,
        });
    }

    #[test]
    fn test_error_data_codes_for_internal_error_variants() {
        assert_maps_to_internal_error(Error::Query("x".to_string()));
        assert_maps_to_internal_error(Error::PoolExhausted);
        assert_maps_to_internal_error(Error::QueryTimeout(Duration::from_secs(1)));
        assert_maps_to_internal_error(Error::Transport("x".to_string()));
        assert_maps_to_internal_error(Error::ProcedureExecutionFailed("x".to_string()));
    }

    #[test]
    fn test_error_data_message_matches_display() {
        // Connection and ReadOnlyViolation deliberately use different client wording.
        for err in [
            Error::Query("bad".to_string()),
            Error::Config("bad".to_string()),
            Error::PoolExhausted,
            Error::QueryTimeout(Duration::from_secs(5)),
            Error::SchemaAccessDenied("SYS".to_string()),
            Error::Transport("down".to_string()),
            Error::DmlDisabled,
            Error::DmlOperationNotAllowed(DmlOperation::Delete),
            Error::DmlWhereClauseRequired(DmlOperation::Update),
            Error::DmlRowLimitExceeded {
                actual: 7,
                limit: 3,
            },
            Error::DmlCancelled,
            Error::DmlNotAStatement,
            Error::ProcedureDisabled,
            Error::ProcedureNotFound {
                schema: "S".to_string(),
                name: "P".to_string(),
            },
            Error::InvalidProcedureName("x;y".to_string()),
            Error::ProcedureMissingParameter("ID".to_string()),
            Error::ProcedureCancelled,
            Error::ProcedureResultSetLimitExceeded {
                actual: 4,
                limit: 2,
            },
            Error::ProcedureExecutionFailed("boom".to_string()),
        ] {
            let expected = err.to_string();
            let data: ErrorData = err.into();
            assert_eq!(&*data.message, expected.as_str());
        }
    }
}
481}