1use prax_query::QueryError;
4use thiserror::Error;
5
6pub type MssqlResult<T> = Result<T, MssqlError>;
8
/// Errors produced by the SQL Server (MSSQL) backend.
///
/// Each variant renders a short, prefixed message through `thiserror`, and the whole enum
/// converts into the backend-agnostic `QueryError` via a `From` impl in this module.
#[derive(Error, Debug)]
pub enum MssqlError {
    /// Failure reported by the connection pool (e.g. the connection manager failed).
    #[error("pool error: {0}")]
    Pool(String),

    /// Error returned by the `tiberius` driver; built automatically from
    /// `tiberius::error::Error` through the `#[from]` attribute, so `?` works on driver calls.
    #[error("sql server error: {0}")]
    SqlServer(#[from] tiberius::error::Error),

    /// Invalid configuration (for example, a malformed connection string).
    #[error("configuration error: {0}")]
    Config(String),

    /// Connection failure described by a message (for example, "connection refused").
    #[error("connection error: {0}")]
    Connection(String),

    /// Query-level error described by a message.
    #[error("query error: {0}")]
    Query(String),

    /// Data could not be deserialized (presumably values read back from the server — confirm).
    #[error("deserialization error: {0}")]
    Deserialization(String),

    /// A value could not be converted to or from the expected type.
    #[error("type conversion error: {0}")]
    TypeConversion(String),

    /// An operation did not complete in time; the payload is the timeout in milliseconds.
    #[error("operation timed out after {0}ms")]
    Timeout(u64),

    /// Unexpected internal failure inside this backend.
    #[error("internal error: {0}")]
    Internal(String),

    /// Error related to an RLS policy (presumably SQL Server row-level security — confirm).
    #[error("rls policy error: {0}")]
    RlsPolicy(String),
}
52
53impl MssqlError {
54 pub fn pool(message: impl Into<String>) -> Self {
56 Self::Pool(message.into())
57 }
58
59 pub fn config(message: impl Into<String>) -> Self {
61 Self::Config(message.into())
62 }
63
64 pub fn connection(message: impl Into<String>) -> Self {
66 Self::Connection(message.into())
67 }
68
69 pub fn query(message: impl Into<String>) -> Self {
71 Self::Query(message.into())
72 }
73
74 pub fn deserialization(message: impl Into<String>) -> Self {
76 Self::Deserialization(message.into())
77 }
78
79 pub fn type_conversion(message: impl Into<String>) -> Self {
81 Self::TypeConversion(message.into())
82 }
83
84 pub fn rls_policy(message: impl Into<String>) -> Self {
86 Self::RlsPolicy(message.into())
87 }
88
89 pub fn is_connection_error(&self) -> bool {
91 matches!(self, Self::Pool(_) | Self::Connection(_))
92 }
93
94 pub fn is_timeout(&self) -> bool {
96 matches!(self, Self::Timeout(_))
97 }
98}
99
100impl<E> From<bb8::RunError<E>> for MssqlError
101where
102 E: std::error::Error,
103{
104 fn from(err: bb8::RunError<E>) -> Self {
105 match err {
106 bb8::RunError::User(e) => MssqlError::Pool(e.to_string()),
107 bb8::RunError::TimedOut => MssqlError::Timeout(30000), }
109 }
110}
111
112impl From<MssqlError> for QueryError {
113 fn from(err: MssqlError) -> Self {
114 match err {
115 MssqlError::Pool(msg) => QueryError::connection(msg),
116 MssqlError::SqlServer(e) => {
117 let msg = e.to_string();
119
120 if msg.contains("2627") || msg.contains("unique") || msg.contains("duplicate") {
122 return QueryError::constraint_violation("", msg);
123 }
124
125 if msg.contains("547") || msg.contains("foreign key") {
127 return QueryError::constraint_violation("", msg);
128 }
129
130 if msg.contains("515") || msg.contains("cannot insert") {
132 return QueryError::invalid_input("", msg);
133 }
134
135 QueryError::database(msg)
136 }
137 MssqlError::Config(msg) => QueryError::connection(msg),
138 MssqlError::Connection(msg) => QueryError::connection(msg),
139 MssqlError::Query(msg) => QueryError::database(msg),
140 MssqlError::Deserialization(msg) => QueryError::serialization(msg),
141 MssqlError::TypeConversion(msg) => QueryError::serialization(msg),
142 MssqlError::Timeout(ms) => QueryError::timeout(ms),
143 MssqlError::Internal(msg) => QueryError::internal(msg),
144 MssqlError::RlsPolicy(msg) => QueryError::database(msg),
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let err = MssqlError::config("invalid connection string");
        assert!(matches!(err, MssqlError::Config(_)));
        // Configuration problems are neither connection errors nor timeouts.
        assert!(!err.is_connection_error());
        assert!(!err.is_timeout());

        let err = MssqlError::connection("connection refused");
        assert!(err.is_connection_error());
        assert!(!err.is_timeout());

        let err = MssqlError::pool("pool exhausted");
        assert!(matches!(err, MssqlError::Pool(_)));
        assert!(err.is_connection_error());

        let err = MssqlError::Timeout(5000);
        assert!(err.is_timeout());
        assert!(!err.is_connection_error());
    }

    #[test]
    fn test_into_query_error() {
        let mssql_err = MssqlError::Timeout(1000);
        let query_err: QueryError = mssql_err.into();
        assert!(query_err.is_timeout());
    }

    #[test]
    fn test_from_pool_run_error() {
        let user_err = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
        let err: MssqlError = bb8::RunError::User(user_err).into();
        assert!(matches!(err, MssqlError::Pool(ref msg) if msg == "refused"));
        assert!(err.is_connection_error());

        let err: MssqlError = bb8::RunError::<std::io::Error>::TimedOut.into();
        assert!(matches!(err, MssqlError::Timeout(30_000)));
        assert!(err.is_timeout());
    }

    #[test]
    fn test_error_display() {
        let err = MssqlError::config("test error");
        assert_eq!(err.to_string(), "configuration error: test error");

        let err = MssqlError::Pool("pool exhausted".to_string());
        assert_eq!(err.to_string(), "pool error: pool exhausted");

        let err = MssqlError::Timeout(250);
        assert_eq!(err.to_string(), "operation timed out after 250ms");

        let err = MssqlError::rls_policy("denied");
        assert_eq!(err.to_string(), "rls policy error: denied");
    }
}