Skip to main content

prax_mssql/
error.rs

//! Error types for Microsoft SQL Server operations.
2
3use prax_query::QueryError;
4use thiserror::Error;
5
6/// Result type for MSSQL operations.
7pub type MssqlResult<T> = Result<T, MssqlError>;
8
/// Errors that can occur during MSSQL operations.
///
/// Most string-carrying variants have a matching helper constructor on
/// `MssqlError` (for example `MssqlError::config`). Every variant converts into
/// `prax_query::QueryError` through the `From` impl in this module.
#[derive(Error, Debug)]
pub enum MssqlError {
    /// Connection pool error.
    ///
    /// Also produced when a `bb8::RunError::User` failure is converted into
    /// `MssqlError`; the payload is that error's display text.
    #[error("pool error: {0}")]
    Pool(String),

    /// Tiberius/SQL Server error.
    ///
    /// The `#[from]` attribute derives `From<tiberius::error::Error>`, so `?` on a
    /// Tiberius result lands here, and the wrapped error is reported as this
    /// error's `source()`.
    #[error("sql server error: {0}")]
    SqlServer(#[from] tiberius::error::Error),

    /// Configuration error (for example an invalid connection string).
    #[error("configuration error: {0}")]
    Config(String),

    /// Connection error.
    #[error("connection error: {0}")]
    Connection(String),

    /// Query execution error.
    #[error("query error: {0}")]
    Query(String),

    /// Row deserialization error.
    #[error("deserialization error: {0}")]
    Deserialization(String),

    /// Type conversion error.
    #[error("type conversion error: {0}")]
    TypeConversion(String),

    /// Timeout error; the payload is the timeout duration in milliseconds.
    ///
    /// Pool checkout timeouts converted from `bb8::RunError::TimedOut` always
    /// carry a fixed 30000 ms, since bb8 does not report the elapsed time.
    #[error("operation timed out after {0}ms")]
    Timeout(u64),

    /// Internal error.
    #[error("internal error: {0}")]
    Internal(String),

    /// RLS policy error.
    #[error("rls policy error: {0}")]
    RlsPolicy(String),
}
52
53impl MssqlError {
54    /// Create a pool error.
55    pub fn pool(message: impl Into<String>) -> Self {
56        Self::Pool(message.into())
57    }
58
59    /// Create a configuration error.
60    pub fn config(message: impl Into<String>) -> Self {
61        Self::Config(message.into())
62    }
63
64    /// Create a connection error.
65    pub fn connection(message: impl Into<String>) -> Self {
66        Self::Connection(message.into())
67    }
68
69    /// Create a query error.
70    pub fn query(message: impl Into<String>) -> Self {
71        Self::Query(message.into())
72    }
73
74    /// Create a deserialization error.
75    pub fn deserialization(message: impl Into<String>) -> Self {
76        Self::Deserialization(message.into())
77    }
78
79    /// Create a type conversion error.
80    pub fn type_conversion(message: impl Into<String>) -> Self {
81        Self::TypeConversion(message.into())
82    }
83
84    /// Create an RLS policy error.
85    pub fn rls_policy(message: impl Into<String>) -> Self {
86        Self::RlsPolicy(message.into())
87    }
88
89    /// Check if this is a connection error.
90    pub fn is_connection_error(&self) -> bool {
91        matches!(self, Self::Pool(_) | Self::Connection(_))
92    }
93
94    /// Check if this is a timeout error.
95    pub fn is_timeout(&self) -> bool {
96        matches!(self, Self::Timeout(_))
97    }
98}
99
100impl<E> From<bb8::RunError<E>> for MssqlError
101where
102    E: std::error::Error,
103{
104    fn from(err: bb8::RunError<E>) -> Self {
105        match err {
106            bb8::RunError::User(e) => MssqlError::Pool(e.to_string()),
107            bb8::RunError::TimedOut => MssqlError::Timeout(30000), // Default 30s
108        }
109    }
110}
111
112impl From<MssqlError> for QueryError {
113    fn from(err: MssqlError) -> Self {
114        match err {
115            MssqlError::Pool(msg) => QueryError::connection(msg),
116            MssqlError::SqlServer(e) => {
117                // Try to categorize SQL Server errors by error number
118                let msg = e.to_string();
119
120                // Unique constraint violation (error 2627)
121                if msg.contains("2627") || msg.contains("unique") || msg.contains("duplicate") {
122                    return QueryError::constraint_violation("", msg);
123                }
124
125                // Foreign key violation (error 547)
126                if msg.contains("547") || msg.contains("foreign key") {
127                    return QueryError::constraint_violation("", msg);
128                }
129
130                // Not null violation (error 515)
131                if msg.contains("515") || msg.contains("cannot insert") {
132                    return QueryError::invalid_input("", msg);
133                }
134
135                QueryError::database(msg)
136            }
137            MssqlError::Config(msg) => QueryError::connection(msg),
138            MssqlError::Connection(msg) => QueryError::connection(msg),
139            MssqlError::Query(msg) => QueryError::database(msg),
140            MssqlError::Deserialization(msg) => QueryError::serialization(msg),
141            MssqlError::TypeConversion(msg) => QueryError::serialization(msg),
142            MssqlError::Timeout(ms) => QueryError::timeout(ms),
143            MssqlError::Internal(msg) => QueryError::internal(msg),
144            MssqlError::RlsPolicy(msg) => QueryError::database(msg),
145        }
146    }
147}
148
// Unit tests for `MssqlError` construction, classification, display text, and
// conversion into `QueryError`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        // Helper constructors select the matching variant.
        assert!(matches!(
            MssqlError::config("invalid connection string"),
            MssqlError::Config(_)
        ));

        // `Connection` is one of the variants reported by `is_connection_error`.
        assert!(MssqlError::connection("connection refused").is_connection_error());

        // `Timeout` is recognized by `is_timeout`.
        assert!(MssqlError::Timeout(5000).is_timeout());
    }

    #[test]
    fn test_into_query_error() {
        let converted = QueryError::from(MssqlError::Timeout(1000));
        assert!(converted.is_timeout());
    }

    #[test]
    fn test_error_display() {
        let cases = [
            (
                MssqlError::config("test error"),
                "configuration error: test error",
            ),
            (
                MssqlError::Pool(String::from("pool exhausted")),
                "pool error: pool exhausted",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }
}