1use std::marker::PhantomData;
4
5use prax_query::QueryResult;
6use prax_query::filter::FilterValue;
7use prax_query::traits::{BoxFuture, Model, QueryEngine};
8use tracing::debug;
9
10use crate::pool::MssqlPool;
11use crate::types::filter_value_to_sql;
12
13#[derive(Clone)]
15pub struct MssqlEngine {
16 pool: MssqlPool,
17}
18
19impl MssqlEngine {
20 pub fn new(pool: MssqlPool) -> Self {
22 Self { pool }
23 }
24
25 pub fn pool(&self) -> &MssqlPool {
27 &self.pool
28 }
29
30 fn to_params(
32 values: &[FilterValue],
33 ) -> Result<Vec<Box<dyn tiberius::ToSql>>, prax_query::QueryError> {
34 values
35 .iter()
36 .map(|v| {
37 filter_value_to_sql(v).map_err(|e| prax_query::QueryError::database(e.to_string()))
38 })
39 .collect()
40 }
41
42 fn convert_params(sql: &str) -> String {
44 let mut result = sql.to_string();
45 let mut i = 1;
46
47 while result.contains(&format!("${}", i)) {
48 result = result.replace(&format!("${}", i), &format!("@P{}", i));
49 i += 1;
50 }
51
52 result
53 }
54}
55
56impl QueryEngine for MssqlEngine {
57 fn query_many<T: Model + Send + 'static>(
58 &self,
59 sql: &str,
60 params: Vec<FilterValue>,
61 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
62 let sql = Self::convert_params(sql);
63 Box::pin(async move {
64 debug!(sql = %sql, "Executing query_many");
65
66 let mut conn = self
67 .pool
68 .get()
69 .await
70 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
71
72 let mssql_params = Self::to_params(¶ms)?;
73 let param_refs: Vec<&dyn tiberius::ToSql> =
74 mssql_params.iter().map(|p| p.as_ref()).collect();
75
76 let _rows = conn
77 .query(&sql, ¶m_refs)
78 .await
79 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
80
81 Ok(Vec::new())
83 })
84 }
85
86 fn query_one<T: Model + Send + 'static>(
87 &self,
88 sql: &str,
89 params: Vec<FilterValue>,
90 ) -> BoxFuture<'_, QueryResult<T>> {
91 let sql = Self::convert_params(sql);
92 Box::pin(async move {
93 debug!(sql = %sql, "Executing query_one");
94
95 let mut conn = self
96 .pool
97 .get()
98 .await
99 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
100
101 let mssql_params = Self::to_params(¶ms)?;
102 let param_refs: Vec<&dyn tiberius::ToSql> =
103 mssql_params.iter().map(|p| p.as_ref()).collect();
104
105 let _row = conn.query_one(&sql, ¶m_refs).await.map_err(|e| {
106 if e.to_string().contains("no rows") {
107 prax_query::QueryError::not_found(T::MODEL_NAME)
108 } else {
109 prax_query::QueryError::database(e.to_string())
110 }
111 })?;
112
113 Err(prax_query::QueryError::internal(
115 "deserialization not yet implemented".to_string(),
116 ))
117 })
118 }
119
120 fn query_optional<T: Model + Send + 'static>(
121 &self,
122 sql: &str,
123 params: Vec<FilterValue>,
124 ) -> BoxFuture<'_, QueryResult<Option<T>>> {
125 let sql = Self::convert_params(sql);
126 Box::pin(async move {
127 debug!(sql = %sql, "Executing query_optional");
128
129 let mut conn = self
130 .pool
131 .get()
132 .await
133 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
134
135 let mssql_params = Self::to_params(¶ms)?;
136 let param_refs: Vec<&dyn tiberius::ToSql> =
137 mssql_params.iter().map(|p| p.as_ref()).collect();
138
139 let row = conn
140 .query_opt(&sql, ¶m_refs)
141 .await
142 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
143
144 match row {
145 Some(_row) => {
146 Err(prax_query::QueryError::internal(
148 "deserialization not yet implemented".to_string(),
149 ))
150 }
151 None => Ok(None),
152 }
153 })
154 }
155
156 fn execute_insert<T: Model + Send + 'static>(
157 &self,
158 sql: &str,
159 params: Vec<FilterValue>,
160 ) -> BoxFuture<'_, QueryResult<T>> {
161 let sql = Self::convert_params(sql);
162 Box::pin(async move {
163 debug!(sql = %sql, "Executing insert");
164
165 let mut conn = self
166 .pool
167 .get()
168 .await
169 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
170
171 let mssql_params = Self::to_params(¶ms)?;
172 let param_refs: Vec<&dyn tiberius::ToSql> =
173 mssql_params.iter().map(|p| p.as_ref()).collect();
174
175 let _row = conn
177 .query_one(&sql, ¶m_refs)
178 .await
179 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
180
181 Err(prax_query::QueryError::internal(
183 "deserialization not yet implemented".to_string(),
184 ))
185 })
186 }
187
188 fn execute_update<T: Model + Send + 'static>(
189 &self,
190 sql: &str,
191 params: Vec<FilterValue>,
192 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
193 let sql = Self::convert_params(sql);
194 Box::pin(async move {
195 debug!(sql = %sql, "Executing update");
196
197 let mut conn = self
198 .pool
199 .get()
200 .await
201 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
202
203 let mssql_params = Self::to_params(¶ms)?;
204 let param_refs: Vec<&dyn tiberius::ToSql> =
205 mssql_params.iter().map(|p| p.as_ref()).collect();
206
207 let _rows = conn
208 .query(&sql, ¶m_refs)
209 .await
210 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
211
212 Ok(Vec::new())
214 })
215 }
216
217 fn execute_delete(
218 &self,
219 sql: &str,
220 params: Vec<FilterValue>,
221 ) -> BoxFuture<'_, QueryResult<u64>> {
222 let sql = Self::convert_params(sql);
223 Box::pin(async move {
224 debug!(sql = %sql, "Executing delete");
225
226 let mut conn = self
227 .pool
228 .get()
229 .await
230 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
231
232 let mssql_params = Self::to_params(¶ms)?;
233 let param_refs: Vec<&dyn tiberius::ToSql> =
234 mssql_params.iter().map(|p| p.as_ref()).collect();
235
236 let count = conn
237 .execute(&sql, ¶m_refs)
238 .await
239 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
240
241 Ok(count)
242 })
243 }
244
245 fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
246 let sql = Self::convert_params(sql);
247 Box::pin(async move {
248 debug!(sql = %sql, "Executing raw SQL");
249
250 let mut conn = self
251 .pool
252 .get()
253 .await
254 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
255
256 let mssql_params = Self::to_params(¶ms)?;
257 let param_refs: Vec<&dyn tiberius::ToSql> =
258 mssql_params.iter().map(|p| p.as_ref()).collect();
259
260 let count = conn
261 .execute(&sql, ¶m_refs)
262 .await
263 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
264
265 Ok(count)
266 })
267 }
268
269 fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
270 let sql = Self::convert_params(sql);
271 Box::pin(async move {
272 debug!(sql = %sql, "Executing count");
273
274 let mut conn = self
275 .pool
276 .get()
277 .await
278 .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
279
280 let mssql_params = Self::to_params(¶ms)?;
281 let param_refs: Vec<&dyn tiberius::ToSql> =
282 mssql_params.iter().map(|p| p.as_ref()).collect();
283
284 let row = conn
285 .query_one(&sql, ¶m_refs)
286 .await
287 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
288
289 let count: i32 = row.get(0).unwrap_or(0);
290 Ok(count as u64)
291 })
292 }
293}
294
295pub struct MssqlQueryBuilder<T: Model> {
297 engine: MssqlEngine,
298 _marker: PhantomData<T>,
299}
300
301impl<T: Model> MssqlQueryBuilder<T> {
302 pub fn new(engine: MssqlEngine) -> Self {
304 Self {
305 engine,
306 _marker: PhantomData,
307 }
308 }
309
310 pub fn engine(&self) -> &MssqlEngine {
312 &self.engine
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convert_params() {
        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users WHERE id = $1"),
            "SELECT * FROM users WHERE id = @P1"
        );

        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users WHERE id = $1 AND name = $2"),
            "SELECT * FROM users WHERE id = @P1 AND name = @P2"
        );

        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users"),
            "SELECT * FROM users"
        );
    }

    /// Multi-digit indices keep all their digits, and repeated placeholders are each
    /// rewritten.
    #[test]
    fn test_convert_params_multi_digit_and_repeated() {
        assert_eq!(
            MssqlEngine::convert_params("UPDATE t SET a = $1, b = $2 WHERE id = $10"),
            "UPDATE t SET a = @P1, b = @P2 WHERE id = @P10"
        );

        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM t WHERE a = $1 OR b = $1"),
            "SELECT * FROM t WHERE a = @P1 OR b = @P1"
        );
    }

    /// A `$` that is not followed by a digit is left untouched.
    #[test]
    fn test_convert_params_bare_dollar() {
        assert_eq!(
            MssqlEngine::convert_params("SELECT $ AS c"),
            "SELECT $ AS c"
        );
    }
}