1use crate::error::QailError;
7use crate::parser;
8use crate::transpiler::ToSql;
9
10use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
11use sqlx::{Column, Row, TypeInfo};
12use std::collections::HashMap;
13
14#[derive(Clone)]
16pub struct QailDB {
17 pool: PgPool,
18}
19
20impl QailDB {
21 pub async fn connect(url: &str) -> Result<Self, QailError> {
32 let pool = PgPoolOptions::new()
33 .max_connections(5)
34 .connect(url)
35 .await
36 .map_err(|e| QailError::Connection(e.to_string()))?;
37
38 Ok(Self { pool })
39 }
40
41 pub fn query(&self, qail: &str) -> QailQuery {
53 QailQuery::new(self.pool.clone(), qail.to_string())
54 }
55
56 pub fn raw(&self, sql: &str) -> QailQuery {
58 QailQuery::raw(self.pool.clone(), sql.to_string())
59 }
60
61 pub fn pool(&self) -> &PgPool {
63 &self.pool
64 }
65}
66
67pub struct QailQuery {
69 pool: PgPool,
70 qail: String,
71 sql: Option<String>,
72 bindings: Vec<QailValue>,
73 is_raw: bool,
74}
75
/// A value that can be bound to a query parameter.
///
/// Derives `PartialEq` so bound values can be compared directly (for example
/// in tests). `Eq` is not derived because `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum QailValue {
    /// SQL `NULL`.
    Null,
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating point number.
    Float(f64),
    /// Owned text value.
    String(String),
}
85
86impl QailQuery {
87 fn new(pool: PgPool, qail: String) -> Self {
88 Self {
89 pool,
90 qail,
91 sql: None,
92 bindings: Vec::new(),
93 is_raw: false,
94 }
95 }
96
97 fn raw(pool: PgPool, sql: String) -> Self {
98 Self {
99 pool,
100 qail: String::new(),
101 sql: Some(sql),
102 bindings: Vec::new(),
103 is_raw: true,
104 }
105 }
106
107 pub fn bind_bool(mut self, value: bool) -> Self {
109 self.bindings.push(QailValue::Bool(value));
110 self
111 }
112
113 pub fn bind_int(mut self, value: i64) -> Self {
115 self.bindings.push(QailValue::Int(value));
116 self
117 }
118
119 pub fn bind_float(mut self, value: f64) -> Self {
121 self.bindings.push(QailValue::Float(value));
122 self
123 }
124
125 pub fn bind_str(mut self, value: &str) -> Self {
127 self.bindings.push(QailValue::String(value.to_string()));
128 self
129 }
130
131 pub fn bind<T: Into<QailValue>>(mut self, value: T) -> Self {
133 self.bindings.push(value.into());
134 self
135 }
136
137 pub fn sql(&self) -> Result<String, QailError> {
139 if self.is_raw {
140 return Ok(self.sql.clone().unwrap_or_default());
141 }
142 let cmd = parser::parse(&self.qail)?;
143 Ok(cmd.to_sql())
144 }
145
146 pub async fn fetch_all(&self) -> Result<Vec<HashMap<String, serde_json::Value>>, QailError> {
148 let sql = self.sql()?;
149 let mut query = sqlx::query(&sql);
150
151 for binding in &self.bindings {
153 query = match binding {
154 QailValue::Null => query,
155 QailValue::Bool(v) => query.bind(*v),
156 QailValue::Int(v) => query.bind(*v),
157 QailValue::Float(v) => query.bind(*v),
158 QailValue::String(v) => query.bind(v.as_str()),
159 };
160 }
161
162 let rows: Vec<PgRow> = query
163 .fetch_all(&self.pool)
164 .await
165 .map_err(|e| QailError::Execution(e.to_string()))?;
166
167 let results: Vec<HashMap<String, serde_json::Value>> = rows
169 .iter()
170 .map(|row| row_to_map(row))
171 .collect();
172
173 Ok(results)
174 }
175
176 pub async fn fetch_one(&self) -> Result<HashMap<String, serde_json::Value>, QailError> {
178 let sql = self.sql()?;
179 let mut query = sqlx::query(&sql);
180
181 for binding in &self.bindings {
182 query = match binding {
183 QailValue::Null => query,
184 QailValue::Bool(v) => query.bind(*v),
185 QailValue::Int(v) => query.bind(*v),
186 QailValue::Float(v) => query.bind(*v),
187 QailValue::String(v) => query.bind(v.as_str()),
188 };
189 }
190
191 let row: PgRow = query
192 .fetch_one(&self.pool)
193 .await
194 .map_err(|e| QailError::Execution(e.to_string()))?;
195
196 Ok(row_to_map(&row))
197 }
198
199 pub async fn execute(&self) -> Result<u64, QailError> {
202 let sql = self.sql()?;
203 let mut query = sqlx::query(&sql);
204
205 for binding in &self.bindings {
206 query = match binding {
207 QailValue::Null => query,
208 QailValue::Bool(v) => query.bind(*v),
209 QailValue::Int(v) => query.bind(*v),
210 QailValue::Float(v) => query.bind(*v),
211 QailValue::String(v) => query.bind(v.as_str()),
212 };
213 }
214
215 let result = query
216 .execute(&self.pool)
217 .await
218 .map_err(|e| QailError::Execution(e.to_string()))?;
219
220 Ok(result.rows_affected())
221 }
222}
223
224fn row_to_map(row: &PgRow) -> HashMap<String, serde_json::Value> {
226 use sqlx::ValueRef;
227
228 let mut map = HashMap::new();
229
230 for (i, column) in row.columns().iter().enumerate() {
231 let name = column.name().to_string();
232 let type_name = column.type_info().name();
233
234 let value_ref = row.try_get_raw(i);
236 if value_ref.is_err() || value_ref.as_ref().map(|v| v.is_null()).unwrap_or(true) {
237 map.insert(name, serde_json::Value::Null);
238 continue;
239 }
240
241 let value: serde_json::Value = match type_name {
242 "BOOL" => row
243 .try_get::<bool, _>(i)
244 .map(serde_json::Value::Bool)
245 .unwrap_or(serde_json::Value::Null),
246 "INT2" | "INT4" => row
247 .try_get::<i32, _>(i)
248 .map(|v| serde_json::Value::Number(v.into()))
249 .unwrap_or(serde_json::Value::Null),
250 "INT8" => row
251 .try_get::<i64, _>(i)
252 .map(|v| serde_json::Value::Number(v.into()))
253 .unwrap_or(serde_json::Value::Null),
254 "FLOAT4" => row
255 .try_get::<f32, _>(i)
256 .ok()
257 .and_then(|v| serde_json::Number::from_f64(v as f64))
258 .map(serde_json::Value::Number)
259 .unwrap_or(serde_json::Value::Null),
260 "FLOAT8" => row
261 .try_get::<f64, _>(i)
262 .ok()
263 .and_then(|v| serde_json::Number::from_f64(v))
264 .map(serde_json::Value::Number)
265 .unwrap_or(serde_json::Value::Null),
266 "UUID" => row
267 .try_get::<sqlx::types::Uuid, _>(i)
268 .map(|v| serde_json::Value::String(v.to_string()))
269 .unwrap_or(serde_json::Value::Null),
270 "TIMESTAMPTZ" | "TIMESTAMP" => row
271 .try_get::<chrono::DateTime<chrono::Utc>, _>(i)
272 .map(|v| serde_json::Value::String(v.to_rfc3339()))
273 .or_else(|_| {
274 row.try_get::<chrono::NaiveDateTime, _>(i)
275 .map(|v| serde_json::Value::String(v.to_string()))
276 })
277 .unwrap_or(serde_json::Value::Null),
278 "DATE" => row
279 .try_get::<chrono::NaiveDate, _>(i)
280 .map(|v| serde_json::Value::String(v.to_string()))
281 .unwrap_or(serde_json::Value::Null),
282 "TEXT" | "VARCHAR" | "CHAR" | "NAME" => row
283 .try_get::<String, _>(i)
284 .map(serde_json::Value::String)
285 .unwrap_or(serde_json::Value::Null),
286 "JSONB" | "JSON" => row
287 .try_get::<serde_json::Value, _>(i)
288 .unwrap_or(serde_json::Value::Null),
289 _ => {
290 row.try_get::<String, _>(i)
292 .map(serde_json::Value::String)
293 .unwrap_or_else(|_| serde_json::Value::String(format!("<{}>", type_name)))
294 }
295 };
296
297 map.insert(name, value);
298 }
299
300 map
301}
302
303impl From<bool> for QailValue {
305 fn from(v: bool) -> Self {
306 QailValue::Bool(v)
307 }
308}
309
310impl From<i32> for QailValue {
311 fn from(v: i32) -> Self {
312 QailValue::Int(v as i64)
313 }
314}
315
316impl From<i64> for QailValue {
317 fn from(v: i64) -> Self {
318 QailValue::Int(v)
319 }
320}
321
322impl From<f64> for QailValue {
323 fn from(v: f64) -> Self {
324 QailValue::Float(v)
325 }
326}
327
328impl From<&str> for QailValue {
329 fn from(v: &str) -> Self {
330 QailValue::String(v.to_string())
331 }
332}
333
334impl From<String> for QailValue {
335 fn from(v: String) -> Self {
336 QailValue::String(v)
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `From` conversion must produce the matching variant with the
    /// original payload, not merely compile.
    #[test]
    fn test_qail_value_from() {
        assert!(matches!(QailValue::from(true), QailValue::Bool(true)));
        assert!(matches!(QailValue::from(42i32), QailValue::Int(42)));
        assert!(matches!(QailValue::from(42i64), QailValue::Int(42)));
        // 2.5 is exactly representable, so the float comparison is exact.
        assert!(matches!(QailValue::from(2.5f64), QailValue::Float(f) if f == 2.5));
        assert!(matches!(QailValue::from("hello"), QailValue::String(s) if s == "hello"));
        assert!(matches!(
            QailValue::from(String::from("hello")),
            QailValue::String(s) if s == "hello"
        ));
    }
}