1use crate::config::ConnectionConfig;
9use crate::error::ClientError;
10use crate::result::{Column, DataType, QueryResult, Row, Value};
11use reqwest::Client;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::Instant;
15
16pub struct Connection {
22 id: u64,
23 config: ConnectionConfig,
24 http_client: Client,
25 base_url: String,
26 auth_token: std::sync::RwLock<Option<String>>,
27 connected: AtomicBool,
28 in_transaction: AtomicBool,
29 created_at: Instant,
30 last_used: std::sync::RwLock<Instant>,
31 queries_executed: AtomicU64,
32}
33
34impl Connection {
35 pub async fn new(config: ConnectionConfig) -> Result<Self, ClientError> {
37 static CONN_ID: AtomicU64 = AtomicU64::new(1);
38
39 let base_url = format!("http://{}:{}", config.host, config.port);
40
41 let http_client = Client::builder()
42 .timeout(std::time::Duration::from_secs(30))
43 .build()
44 .map_err(|e| ClientError::ConnectionFailed(e.to_string()))?;
45
46 let conn = Self {
47 id: CONN_ID.fetch_add(1, Ordering::SeqCst),
48 config,
49 http_client,
50 base_url,
51 auth_token: std::sync::RwLock::new(None),
52 connected: AtomicBool::new(false),
53 in_transaction: AtomicBool::new(false),
54 created_at: Instant::now(),
55 last_used: std::sync::RwLock::new(Instant::now()),
56 queries_executed: AtomicU64::new(0),
57 };
58
59 conn.connect().await?;
60 Ok(conn)
61 }
62
63 pub fn id(&self) -> u64 {
65 self.id
66 }
67
68 async fn connect(&self) -> Result<(), ClientError> {
70 let health_url = format!("{}/health", self.base_url);
72 let response = self
73 .http_client
74 .get(&health_url)
75 .send()
76 .await
77 .map_err(|e| ClientError::ConnectionFailed(format!("Failed to connect: {}", e)))?;
78
79 if !response.status().is_success() {
80 return Err(ClientError::ConnectionFailed(format!(
81 "Server returned status: {}",
82 response.status()
83 )));
84 }
85
86 if let (Some(ref username), Some(ref password)) =
88 (&self.config.username, &self.config.password)
89 {
90 let login_url = format!("{}/api/v1/auth/login", self.base_url);
91 let login_body = serde_json::json!({
92 "username": username,
93 "password": password
94 });
95
96 let response = self
97 .http_client
98 .post(&login_url)
99 .json(&login_body)
100 .send()
101 .await
102 .map_err(|e| ClientError::AuthenticationFailed(e.to_string()))?;
103
104 if response.status().is_success() {
105 let auth_response: serde_json::Value = response
106 .json()
107 .await
108 .map_err(|e| ClientError::AuthenticationFailed(e.to_string()))?;
109
110 if let Some(token) = auth_response.get("token").and_then(|t| t.as_str()) {
111 *self.auth_token.write().expect("auth_token RwLock poisoned") =
112 Some(token.to_string());
113 }
114 } else {
115 return Err(ClientError::AuthenticationFailed(
116 "Invalid credentials".to_string(),
117 ));
118 }
119 }
120
121 self.connected.store(true, Ordering::SeqCst);
122 Ok(())
123 }
124
125 pub fn is_connected(&self) -> bool {
127 self.connected.load(Ordering::SeqCst)
128 }
129
130 pub fn in_transaction(&self) -> bool {
132 self.in_transaction.load(Ordering::SeqCst)
133 }
134
135 pub fn age(&self) -> std::time::Duration {
137 self.created_at.elapsed()
138 }
139
140 pub fn idle_time(&self) -> std::time::Duration {
142 self.last_used
143 .read()
144 .expect("last_used RwLock poisoned")
145 .elapsed()
146 }
147
148 fn mark_used(&self) {
150 *self.last_used.write().expect("last_used RwLock poisoned") = Instant::now();
151 }
152
153 fn add_auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
155 if let Some(ref token) = *self.auth_token.read().expect("auth_token RwLock poisoned") {
156 request.header("Authorization", format!("Bearer {}", token))
157 } else {
158 request
159 }
160 }
161
162 pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
164 self.query_with_params(sql, vec![]).await
165 }
166
167 pub async fn query_with_params(
169 &self,
170 sql: &str,
171 params: Vec<Value>,
172 ) -> Result<QueryResult, ClientError> {
173 if !self.is_connected() {
174 return Err(ClientError::NotConnected);
175 }
176
177 self.mark_used();
178 self.queries_executed.fetch_add(1, Ordering::SeqCst);
179
180 let url = format!("{}/api/v1/query", self.base_url);
181 let body = serde_json::json!({
182 "database": &self.config.database,
183 "sql": sql,
184 "params": params.iter().map(value_to_json).collect::<Vec<_>>()
185 });
186
187 let request = self.http_client.post(&url).json(&body);
188 let request = self.add_auth(request);
189
190 let response = request
191 .send()
192 .await
193 .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
194
195 let status = response.status();
196 let response_body: serde_json::Value = response
197 .json()
198 .await
199 .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
200
201 if !status.is_success() {
202 let error = response_body
203 .get("error")
204 .and_then(|e| e.as_str())
205 .unwrap_or("Unknown error");
206 return Err(ClientError::QueryFailed(error.to_string()));
207 }
208
209 let data = response_body.get("data");
211
212 let columns: Vec<Column> = data
213 .and_then(|d| d.get("columns"))
214 .and_then(|c| c.as_array())
215 .map(|cols| {
216 cols.iter()
217 .map(|c| {
218 Column::new(
219 c.as_str().unwrap_or(""),
220 DataType::Text, )
222 })
223 .collect()
224 })
225 .unwrap_or_default();
226
227 let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
228
229 let rows: Vec<Row> = data
230 .and_then(|d| d.get("rows"))
231 .and_then(|r| r.as_array())
232 .map(|rows| {
233 rows.iter()
234 .map(|row| {
235 let values: Vec<Value> = row
236 .as_array()
237 .map(|arr| arr.iter().map(json_to_value).collect())
238 .unwrap_or_default();
239 Row::new(column_names.clone(), values)
240 })
241 .collect()
242 })
243 .unwrap_or_default();
244
245 Ok(QueryResult::new(columns, rows))
246 }
247
248 pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
250 self.execute_with_params(sql, vec![]).await
251 }
252
253 pub async fn execute_with_params(
255 &self,
256 sql: &str,
257 params: Vec<Value>,
258 ) -> Result<u64, ClientError> {
259 if !self.is_connected() {
260 return Err(ClientError::NotConnected);
261 }
262
263 self.mark_used();
264 self.queries_executed.fetch_add(1, Ordering::SeqCst);
265
266 let sql_upper = sql.trim().to_uppercase();
267
268 if sql_upper.starts_with("BEGIN") {
270 self.in_transaction.store(true, Ordering::SeqCst);
271 return Ok(0);
272 } else if sql_upper.starts_with("COMMIT") || sql_upper.starts_with("ROLLBACK") {
273 self.in_transaction.store(false, Ordering::SeqCst);
274 return Ok(0);
275 }
276
277 let url = format!("{}/api/v1/query", self.base_url);
278 let body = serde_json::json!({
279 "database": &self.config.database,
280 "sql": sql,
281 "params": params.iter().map(value_to_json).collect::<Vec<_>>()
282 });
283
284 let request = self.http_client.post(&url).json(&body);
285 let request = self.add_auth(request);
286
287 let response = request
288 .send()
289 .await
290 .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
291
292 let status = response.status();
293 let response_body: serde_json::Value = response
294 .json()
295 .await
296 .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
297
298 if !status.is_success() {
299 let error = response_body
300 .get("error")
301 .and_then(|e| e.as_str())
302 .unwrap_or("Unknown error");
303 return Err(ClientError::QueryFailed(error.to_string()));
304 }
305
306 let rows_affected = response_body
307 .get("data")
308 .and_then(|d| d.get("rows_affected"))
309 .and_then(|r| r.as_u64())
310 .unwrap_or(0);
311
312 Ok(rows_affected)
313 }
314
315 pub async fn begin_transaction(&self) -> Result<(), ClientError> {
317 if self.in_transaction() {
318 return Err(ClientError::TransactionAlreadyStarted);
319 }
320 self.execute("BEGIN").await?;
321 Ok(())
322 }
323
324 pub async fn commit(&self) -> Result<(), ClientError> {
326 if !self.in_transaction() {
327 return Err(ClientError::NoTransaction);
328 }
329 self.execute("COMMIT").await?;
330 Ok(())
331 }
332
333 pub async fn rollback(&self) -> Result<(), ClientError> {
335 if !self.in_transaction() {
336 return Err(ClientError::NoTransaction);
337 }
338 self.execute("ROLLBACK").await?;
339 Ok(())
340 }
341
342 pub async fn ping(&self) -> Result<(), ClientError> {
344 let health_url = format!("{}/health", self.base_url);
345 let response = self
346 .http_client
347 .get(&health_url)
348 .send()
349 .await
350 .map_err(|e| ClientError::ConnectionFailed(e.to_string()))?;
351
352 if response.status().is_success() {
353 self.mark_used();
354 Ok(())
355 } else {
356 self.connected.store(false, Ordering::SeqCst);
357 Err(ClientError::NotConnected)
358 }
359 }
360
361 pub async fn close(&self) {
363 let token = self
365 .auth_token
366 .read()
367 .expect("auth_token RwLock poisoned")
368 .clone();
369 if let Some(ref token) = token {
370 let logout_url = format!("{}/api/v1/auth/logout", self.base_url);
371 let body = serde_json::json!({ "token": token });
372 let _ = self.http_client.post(&logout_url).json(&body).send().await;
373 }
374 self.connected.store(false, Ordering::SeqCst);
375 }
376
377 pub fn stats(&self) -> ConnectionStats {
379 ConnectionStats {
380 id: self.id,
381 connected: self.is_connected(),
382 in_transaction: self.in_transaction(),
383 age_ms: self.age().as_millis() as u64,
384 idle_ms: self.idle_time().as_millis() as u64,
385 queries_executed: self.queries_executed.load(Ordering::SeqCst),
386 }
387 }
388
389 pub fn base_url(&self) -> &str {
391 &self.base_url
392 }
393}
394
395fn value_to_json(value: &Value) -> serde_json::Value {
400 match value {
401 Value::Null => serde_json::Value::Null,
402 Value::Bool(b) => serde_json::Value::Bool(*b),
403 Value::Int(i) => serde_json::Value::Number((*i).into()),
404 Value::Float(f) => serde_json::Number::from_f64(*f)
405 .map(serde_json::Value::Number)
406 .unwrap_or(serde_json::Value::Null),
407 Value::String(s) => serde_json::Value::String(s.clone()),
408 Value::Bytes(b) => serde_json::Value::String(base64_encode(b)),
409 Value::Timestamp(t) => serde_json::Value::Number((*t).into()),
410 Value::Array(arr) => serde_json::Value::Array(arr.iter().map(value_to_json).collect()),
411 Value::Object(obj) => {
412 let map: serde_json::Map<String, serde_json::Value> = obj
413 .iter()
414 .map(|(k, v)| (k.clone(), value_to_json(v)))
415 .collect();
416 serde_json::Value::Object(map)
417 }
418 }
419}
420
421fn json_to_value(json: &serde_json::Value) -> Value {
422 match json {
423 serde_json::Value::Null => Value::Null,
424 serde_json::Value::Bool(b) => Value::Bool(*b),
425 serde_json::Value::Number(n) => {
426 if let Some(i) = n.as_i64() {
427 Value::Int(i)
428 } else if let Some(f) = n.as_f64() {
429 Value::Float(f)
430 } else {
431 Value::Null
432 }
433 }
434 serde_json::Value::String(s) => Value::String(s.clone()),
435 serde_json::Value::Array(arr) => Value::Array(arr.iter().map(json_to_value).collect()),
436 serde_json::Value::Object(obj) => Value::Object(
437 obj.iter()
438 .map(|(k, v)| (k.clone(), json_to_value(v)))
439 .collect(),
440 ),
441 }
442}
443
/// Encodes `data` as standard base64 (RFC 4648 alphabet, `=` padding).
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Every started 3-byte group produces exactly 4 output characters.
    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);

    for group in data.chunks(3) {
        // Pack up to three bytes into the low 24 bits, zero-filling missing ones.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        // A group of n bytes carries n + 1 significant 6-bit digits; pad the rest.
        let significant = group.len() + 1;
        for k in 0..4 {
            if k < significant {
                let index = ((word >> (18 - 6 * k)) & 0x3f) as usize;
                encoded.push(char::from(ALPHABET[index]));
            } else {
                encoded.push('=');
            }
        }
    }

    encoded
}
471
/// Point-in-time snapshot of a connection's state and counters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionStats {
    /// Connection id.
    pub id: u64,
    /// Whether the connection currently accepts queries.
    pub connected: bool,
    /// Whether a client-side transaction is open.
    pub in_transaction: bool,
    /// Milliseconds since the connection was created.
    pub age_ms: u64,
    /// Milliseconds since the connection was last used.
    pub idle_ms: u64,
    /// Number of query/execute calls made so far.
    pub queries_executed: u64,
}
486
487pub struct PooledConnection {
495 connection: Arc<Connection>,
496 pool_return: std::sync::Mutex<Option<Box<dyn FnOnce(Arc<Connection>) + Send>>>,
497}
498
499impl PooledConnection {
500 pub fn new<F>(connection: Arc<Connection>, on_return: F) -> Self
502 where
503 F: FnOnce(Arc<Connection>) + Send + 'static,
504 {
505 Self {
506 connection,
507 pool_return: std::sync::Mutex::new(Some(Box::new(on_return))),
508 }
509 }
510
511 pub fn connection(&self) -> &Connection {
513 &self.connection
514 }
515
516 pub fn inner(&self) -> &Connection {
518 &self.connection
519 }
520
521 pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
523 self.connection.query(sql).await
524 }
525
526 pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
528 self.connection.execute(sql).await
529 }
530}
531
532impl std::ops::Deref for PooledConnection {
533 type Target = Connection;
534
535 fn deref(&self) -> &Self::Target {
536 &self.connection
537 }
538}
539
540impl Drop for PooledConnection {
541 fn drop(&mut self) {
542 if let Ok(mut guard) = self.pool_return.lock() {
543 if let Some(return_fn) = guard.take() {
544 return_fn(Arc::clone(&self.connection));
545 }
546 }
547 }
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_stats() {
        let stats = ConnectionStats {
            id: 1,
            connected: true,
            in_transaction: false,
            age_ms: 1000,
            idle_ms: 100,
            queries_executed: 5,
        };
        assert_eq!(stats.id, 1);
        assert!(stats.connected);
        assert!(!stats.in_transaction);
        assert_eq!(stats.queries_executed, 5);
    }

    #[test]
    fn test_json_to_value() {
        let json = serde_json::json!({"name": "test", "count": 42});
        let value = json_to_value(&json);
        if let Value::Object(map) = value {
            assert!(map.contains_key("name"));
            assert!(map.contains_key("count"));
        } else {
            panic!("Expected Object");
        }
    }

    #[test]
    fn test_json_numbers_map_to_int_or_float() {
        assert!(matches!(json_to_value(&serde_json::json!(42)), Value::Int(42)));
        assert!(matches!(json_to_value(&serde_json::json!(-7)), Value::Int(-7)));
        assert!(matches!(
            json_to_value(&serde_json::json!(1.5)),
            Value::Float(f) if f == 1.5
        ));
        assert!(matches!(json_to_value(&serde_json::Value::Null), Value::Null));
    }

    #[test]
    fn test_value_to_json() {
        let value = Value::String("hello".to_string());
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::Value::String("hello".to_string()));
    }

    #[test]
    fn test_value_to_json_non_finite_float_is_null() {
        assert_eq!(value_to_json(&Value::Float(f64::NAN)), serde_json::Value::Null);
        assert_eq!(value_to_json(&Value::Float(f64::INFINITY)), serde_json::Value::Null);
        assert_eq!(value_to_json(&Value::Int(3)), serde_json::json!(3));
    }

    #[test]
    fn test_base64_encode_rfc4648_vectors() {
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    /// Best-effort integration check: passes whether or not a server listens on
    /// 127.0.0.1:7001, but pins the only error variants `Connection::new` may return.
    #[tokio::test]
    async fn test_connection_create() {
        let config = ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port: 7001,
            ..Default::default()
        };

        match Connection::new(config).await {
            Ok(conn) => {
                assert!(conn.is_connected());
                conn.close().await;
                assert!(!conn.is_connected());
            }
            Err(e) => assert!(
                matches!(
                    e,
                    ClientError::ConnectionFailed(_) | ClientError::AuthenticationFailed(_)
                ),
                "Connection::new returned an unexpected error variant"
            ),
        }
    }
}