Skip to main content

prax_mssql/
engine.rs

1//! Microsoft SQL Server query engine implementation.
2
3use std::marker::PhantomData;
4
5use prax_query::QueryResult;
6use prax_query::filter::FilterValue;
7use prax_query::traits::{BoxFuture, Model, QueryEngine};
8use tracing::debug;
9
10use crate::pool::MssqlPool;
11use crate::types::filter_value_to_sql;
12
13/// Microsoft SQL Server query engine that implements the Prax QueryEngine trait.
14#[derive(Clone)]
15pub struct MssqlEngine {
16    pool: MssqlPool,
17}
18
19impl MssqlEngine {
20    /// Create a new MSSQL engine with the given connection pool.
21    pub fn new(pool: MssqlPool) -> Self {
22        Self { pool }
23    }
24
25    /// Get a reference to the connection pool.
26    pub fn pool(&self) -> &MssqlPool {
27        &self.pool
28    }
29
30    /// Convert filter values to MSSQL parameters.
31    fn to_params(
32        values: &[FilterValue],
33    ) -> Result<Vec<Box<dyn tiberius::ToSql>>, prax_query::QueryError> {
34        values
35            .iter()
36            .map(|v| {
37                filter_value_to_sql(v).map_err(|e| prax_query::QueryError::database(e.to_string()))
38            })
39            .collect()
40    }
41
42    /// Convert PostgreSQL-style parameter placeholders ($1, $2) to MSSQL-style (@P1, @P2).
43    fn convert_params(sql: &str) -> String {
44        let mut result = sql.to_string();
45        let mut i = 1;
46
47        while result.contains(&format!("${}", i)) {
48            result = result.replace(&format!("${}", i), &format!("@P{}", i));
49            i += 1;
50        }
51
52        result
53    }
54}
55
56impl QueryEngine for MssqlEngine {
57    fn query_many<T: Model + Send + 'static>(
58        &self,
59        sql: &str,
60        params: Vec<FilterValue>,
61    ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
62        let sql = Self::convert_params(sql);
63        Box::pin(async move {
64            debug!(sql = %sql, "Executing query_many");
65
66            let mut conn = self
67                .pool
68                .get()
69                .await
70                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
71
72            let mssql_params = Self::to_params(&params)?;
73            let param_refs: Vec<&dyn tiberius::ToSql> =
74                mssql_params.iter().map(|p| p.as_ref()).collect();
75
76            let _rows = conn
77                .query(&sql, &param_refs)
78                .await
79                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
80
81            // Placeholder - would deserialize rows into Vec<T>
82            Ok(Vec::new())
83        })
84    }
85
86    fn query_one<T: Model + Send + 'static>(
87        &self,
88        sql: &str,
89        params: Vec<FilterValue>,
90    ) -> BoxFuture<'_, QueryResult<T>> {
91        let sql = Self::convert_params(sql);
92        Box::pin(async move {
93            debug!(sql = %sql, "Executing query_one");
94
95            let mut conn = self
96                .pool
97                .get()
98                .await
99                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
100
101            let mssql_params = Self::to_params(&params)?;
102            let param_refs: Vec<&dyn tiberius::ToSql> =
103                mssql_params.iter().map(|p| p.as_ref()).collect();
104
105            let _row = conn.query_one(&sql, &param_refs).await.map_err(|e| {
106                if e.to_string().contains("no rows") {
107                    prax_query::QueryError::not_found(T::MODEL_NAME)
108                } else {
109                    prax_query::QueryError::database(e.to_string())
110                }
111            })?;
112
113            // Placeholder - would deserialize row into T
114            Err(prax_query::QueryError::internal(
115                "deserialization not yet implemented".to_string(),
116            ))
117        })
118    }
119
120    fn query_optional<T: Model + Send + 'static>(
121        &self,
122        sql: &str,
123        params: Vec<FilterValue>,
124    ) -> BoxFuture<'_, QueryResult<Option<T>>> {
125        let sql = Self::convert_params(sql);
126        Box::pin(async move {
127            debug!(sql = %sql, "Executing query_optional");
128
129            let mut conn = self
130                .pool
131                .get()
132                .await
133                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
134
135            let mssql_params = Self::to_params(&params)?;
136            let param_refs: Vec<&dyn tiberius::ToSql> =
137                mssql_params.iter().map(|p| p.as_ref()).collect();
138
139            let row = conn
140                .query_opt(&sql, &param_refs)
141                .await
142                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
143
144            match row {
145                Some(_row) => {
146                    // Placeholder - would deserialize row into T
147                    Err(prax_query::QueryError::internal(
148                        "deserialization not yet implemented".to_string(),
149                    ))
150                }
151                None => Ok(None),
152            }
153        })
154    }
155
156    fn execute_insert<T: Model + Send + 'static>(
157        &self,
158        sql: &str,
159        params: Vec<FilterValue>,
160    ) -> BoxFuture<'_, QueryResult<T>> {
161        let sql = Self::convert_params(sql);
162        Box::pin(async move {
163            debug!(sql = %sql, "Executing insert");
164
165            let mut conn = self
166                .pool
167                .get()
168                .await
169                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
170
171            let mssql_params = Self::to_params(&params)?;
172            let param_refs: Vec<&dyn tiberius::ToSql> =
173                mssql_params.iter().map(|p| p.as_ref()).collect();
174
175            // For INSERT with RETURNING, MSSQL uses OUTPUT clause
176            let _row = conn
177                .query_one(&sql, &param_refs)
178                .await
179                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
180
181            // Placeholder - would deserialize row into T
182            Err(prax_query::QueryError::internal(
183                "deserialization not yet implemented".to_string(),
184            ))
185        })
186    }
187
188    fn execute_update<T: Model + Send + 'static>(
189        &self,
190        sql: &str,
191        params: Vec<FilterValue>,
192    ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
193        let sql = Self::convert_params(sql);
194        Box::pin(async move {
195            debug!(sql = %sql, "Executing update");
196
197            let mut conn = self
198                .pool
199                .get()
200                .await
201                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
202
203            let mssql_params = Self::to_params(&params)?;
204            let param_refs: Vec<&dyn tiberius::ToSql> =
205                mssql_params.iter().map(|p| p.as_ref()).collect();
206
207            let _rows = conn
208                .query(&sql, &param_refs)
209                .await
210                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
211
212            // Placeholder - would deserialize rows into Vec<T>
213            Ok(Vec::new())
214        })
215    }
216
217    fn execute_delete(
218        &self,
219        sql: &str,
220        params: Vec<FilterValue>,
221    ) -> BoxFuture<'_, QueryResult<u64>> {
222        let sql = Self::convert_params(sql);
223        Box::pin(async move {
224            debug!(sql = %sql, "Executing delete");
225
226            let mut conn = self
227                .pool
228                .get()
229                .await
230                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
231
232            let mssql_params = Self::to_params(&params)?;
233            let param_refs: Vec<&dyn tiberius::ToSql> =
234                mssql_params.iter().map(|p| p.as_ref()).collect();
235
236            let count = conn
237                .execute(&sql, &param_refs)
238                .await
239                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
240
241            Ok(count)
242        })
243    }
244
245    fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
246        let sql = Self::convert_params(sql);
247        Box::pin(async move {
248            debug!(sql = %sql, "Executing raw SQL");
249
250            let mut conn = self
251                .pool
252                .get()
253                .await
254                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
255
256            let mssql_params = Self::to_params(&params)?;
257            let param_refs: Vec<&dyn tiberius::ToSql> =
258                mssql_params.iter().map(|p| p.as_ref()).collect();
259
260            let count = conn
261                .execute(&sql, &param_refs)
262                .await
263                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
264
265            Ok(count)
266        })
267    }
268
269    fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
270        let sql = Self::convert_params(sql);
271        Box::pin(async move {
272            debug!(sql = %sql, "Executing count");
273
274            let mut conn = self
275                .pool
276                .get()
277                .await
278                .map_err(|e| prax_query::QueryError::connection(e.to_string()))?;
279
280            let mssql_params = Self::to_params(&params)?;
281            let param_refs: Vec<&dyn tiberius::ToSql> =
282                mssql_params.iter().map(|p| p.as_ref()).collect();
283
284            let row = conn
285                .query_one(&sql, &param_refs)
286                .await
287                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
288
289            let count: i32 = row.get(0).unwrap_or(0);
290            Ok(count as u64)
291        })
292    }
293}
294
295/// A typed query builder that uses the MSSQL engine.
296pub struct MssqlQueryBuilder<T: Model> {
297    engine: MssqlEngine,
298    _marker: PhantomData<T>,
299}
300
301impl<T: Model> MssqlQueryBuilder<T> {
302    /// Create a new query builder.
303    pub fn new(engine: MssqlEngine) -> Self {
304        Self {
305            engine,
306            _marker: PhantomData,
307        }
308    }
309
310    /// Get the underlying engine.
311    pub fn engine(&self) -> &MssqlEngine {
312        &self.engine
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convert_params() {
        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users WHERE id = $1"),
            "SELECT * FROM users WHERE id = @P1"
        );

        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users WHERE id = $1 AND name = $2"),
            "SELECT * FROM users WHERE id = @P1 AND name = @P2"
        );

        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn test_convert_params_repeated_placeholder() {
        assert_eq!(
            MssqlEngine::convert_params("SELECT * FROM users WHERE id = $1 OR parent_id = $1"),
            "SELECT * FROM users WHERE id = @P1 OR parent_id = @P1"
        );
    }

    #[test]
    fn test_convert_params_multi_digit_placeholders() {
        // `$10` and `$11` contain `$1` as a prefix; make sure they still end up
        // as `@P10` and `@P11`.
        let pg: Vec<String> = (1..=11).map(|i| format!("${}", i)).collect();
        let ms: Vec<String> = (1..=11).map(|i| format!("@P{}", i)).collect();
        let sql = format!("INSERT INTO t VALUES ({})", pg.join(", "));
        let expected = format!("INSERT INTO t VALUES ({})", ms.join(", "));
        assert_eq!(MssqlEngine::convert_params(&sql), expected);
    }
}