Skip to main content

aegis_client/
connection.rs

//! Aegis Client Connection
//!
//! Real HTTP-based database connection to an Aegis server.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team
7
8use crate::config::ConnectionConfig;
9use crate::error::ClientError;
10use crate::result::{Column, DataType, QueryResult, Row, Value};
11use reqwest::Client;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::Instant;
15
16// =============================================================================
17// Connection
18// =============================================================================
19
20/// A real database connection to an Aegis server.
21pub struct Connection {
22    id: u64,
23    config: ConnectionConfig,
24    http_client: Client,
25    base_url: String,
26    auth_token: std::sync::RwLock<Option<String>>,
27    connected: AtomicBool,
28    in_transaction: AtomicBool,
29    created_at: Instant,
30    last_used: std::sync::RwLock<Instant>,
31    queries_executed: AtomicU64,
32}
33
34impl Connection {
35    /// Create a new connection.
36    pub async fn new(config: ConnectionConfig) -> Result<Self, ClientError> {
37        static CONN_ID: AtomicU64 = AtomicU64::new(1);
38
39        let base_url = format!("http://{}:{}", config.host, config.port);
40
41        let http_client = Client::builder()
42            .timeout(std::time::Duration::from_secs(30))
43            .build()
44            .map_err(|e| ClientError::ConnectionFailed(e.to_string()))?;
45
46        let conn = Self {
47            id: CONN_ID.fetch_add(1, Ordering::SeqCst),
48            config,
49            http_client,
50            base_url,
51            auth_token: std::sync::RwLock::new(None),
52            connected: AtomicBool::new(false),
53            in_transaction: AtomicBool::new(false),
54            created_at: Instant::now(),
55            last_used: std::sync::RwLock::new(Instant::now()),
56            queries_executed: AtomicU64::new(0),
57        };
58
59        conn.connect().await?;
60        Ok(conn)
61    }
62
63    /// Get the connection ID.
64    pub fn id(&self) -> u64 {
65        self.id
66    }
67
68    /// Connect to the database server.
69    async fn connect(&self) -> Result<(), ClientError> {
70        // Check server health
71        let health_url = format!("{}/health", self.base_url);
72        let response = self
73            .http_client
74            .get(&health_url)
75            .send()
76            .await
77            .map_err(|e| ClientError::ConnectionFailed(format!("Failed to connect: {}", e)))?;
78
79        if !response.status().is_success() {
80            return Err(ClientError::ConnectionFailed(format!(
81                "Server returned status: {}",
82                response.status()
83            )));
84        }
85
86        // Authenticate if credentials provided
87        if let (Some(ref username), Some(ref password)) =
88            (&self.config.username, &self.config.password)
89        {
90            let login_url = format!("{}/api/v1/auth/login", self.base_url);
91            let login_body = serde_json::json!({
92                "username": username,
93                "password": password
94            });
95
96            let response = self
97                .http_client
98                .post(&login_url)
99                .json(&login_body)
100                .send()
101                .await
102                .map_err(|e| ClientError::AuthenticationFailed(e.to_string()))?;
103
104            if response.status().is_success() {
105                let auth_response: serde_json::Value = response
106                    .json()
107                    .await
108                    .map_err(|e| ClientError::AuthenticationFailed(e.to_string()))?;
109
110                if let Some(token) = auth_response.get("token").and_then(|t| t.as_str()) {
111                    *self.auth_token.write().expect("auth_token RwLock poisoned") =
112                        Some(token.to_string());
113                }
114            } else {
115                return Err(ClientError::AuthenticationFailed(
116                    "Invalid credentials".to_string(),
117                ));
118            }
119        }
120
121        self.connected.store(true, Ordering::SeqCst);
122        Ok(())
123    }
124
125    /// Check if connected.
126    pub fn is_connected(&self) -> bool {
127        self.connected.load(Ordering::SeqCst)
128    }
129
130    /// Check if in a transaction.
131    pub fn in_transaction(&self) -> bool {
132        self.in_transaction.load(Ordering::SeqCst)
133    }
134
135    /// Get connection age.
136    pub fn age(&self) -> std::time::Duration {
137        self.created_at.elapsed()
138    }
139
140    /// Get idle time.
141    pub fn idle_time(&self) -> std::time::Duration {
142        self.last_used
143            .read()
144            .expect("last_used RwLock poisoned")
145            .elapsed()
146    }
147
148    /// Mark as used.
149    fn mark_used(&self) {
150        *self.last_used.write().expect("last_used RwLock poisoned") = Instant::now();
151    }
152
153    /// Add auth header to request if we have a token.
154    fn add_auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
155        if let Some(ref token) = *self.auth_token.read().expect("auth_token RwLock poisoned") {
156            request.header("Authorization", format!("Bearer {}", token))
157        } else {
158            request
159        }
160    }
161
162    /// Execute a query.
163    pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
164        self.query_with_params(sql, vec![]).await
165    }
166
167    /// Execute a query with parameters.
168    pub async fn query_with_params(
169        &self,
170        sql: &str,
171        params: Vec<Value>,
172    ) -> Result<QueryResult, ClientError> {
173        if !self.is_connected() {
174            return Err(ClientError::NotConnected);
175        }
176
177        self.mark_used();
178        self.queries_executed.fetch_add(1, Ordering::SeqCst);
179
180        let url = format!("{}/api/v1/query", self.base_url);
181        let body = serde_json::json!({
182            "database": &self.config.database,
183            "sql": sql,
184            "params": params.iter().map(value_to_json).collect::<Vec<_>>()
185        });
186
187        let request = self.http_client.post(&url).json(&body);
188        let request = self.add_auth(request);
189
190        let response = request
191            .send()
192            .await
193            .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
194
195        let status = response.status();
196        let response_body: serde_json::Value = response
197            .json()
198            .await
199            .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
200
201        if !status.is_success() {
202            let error = response_body
203                .get("error")
204                .and_then(|e| e.as_str())
205                .unwrap_or("Unknown error");
206            return Err(ClientError::QueryFailed(error.to_string()));
207        }
208
209        // Parse the response into QueryResult
210        let data = response_body.get("data");
211
212        let columns: Vec<Column> = data
213            .and_then(|d| d.get("columns"))
214            .and_then(|c| c.as_array())
215            .map(|cols| {
216                cols.iter()
217                    .map(|c| {
218                        Column::new(
219                            c.as_str().unwrap_or(""),
220                            DataType::Text, // Default to text, server doesn't send types
221                        )
222                    })
223                    .collect()
224            })
225            .unwrap_or_default();
226
227        let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
228
229        let rows: Vec<Row> = data
230            .and_then(|d| d.get("rows"))
231            .and_then(|r| r.as_array())
232            .map(|rows| {
233                rows.iter()
234                    .map(|row| {
235                        let values: Vec<Value> = row
236                            .as_array()
237                            .map(|arr| arr.iter().map(json_to_value).collect())
238                            .unwrap_or_default();
239                        Row::new(column_names.clone(), values)
240                    })
241                    .collect()
242            })
243            .unwrap_or_default();
244
245        Ok(QueryResult::new(columns, rows))
246    }
247
248    /// Execute a statement (INSERT, UPDATE, DELETE).
249    pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
250        self.execute_with_params(sql, vec![]).await
251    }
252
253    /// Execute a statement with parameters.
254    pub async fn execute_with_params(
255        &self,
256        sql: &str,
257        params: Vec<Value>,
258    ) -> Result<u64, ClientError> {
259        if !self.is_connected() {
260            return Err(ClientError::NotConnected);
261        }
262
263        self.mark_used();
264        self.queries_executed.fetch_add(1, Ordering::SeqCst);
265
266        let sql_upper = sql.trim().to_uppercase();
267
268        // Handle transaction commands locally
269        if sql_upper.starts_with("BEGIN") {
270            self.in_transaction.store(true, Ordering::SeqCst);
271            return Ok(0);
272        } else if sql_upper.starts_with("COMMIT") || sql_upper.starts_with("ROLLBACK") {
273            self.in_transaction.store(false, Ordering::SeqCst);
274            return Ok(0);
275        }
276
277        let url = format!("{}/api/v1/query", self.base_url);
278        let body = serde_json::json!({
279            "database": &self.config.database,
280            "sql": sql,
281            "params": params.iter().map(value_to_json).collect::<Vec<_>>()
282        });
283
284        let request = self.http_client.post(&url).json(&body);
285        let request = self.add_auth(request);
286
287        let response = request
288            .send()
289            .await
290            .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
291
292        let status = response.status();
293        let response_body: serde_json::Value = response
294            .json()
295            .await
296            .map_err(|e| ClientError::QueryFailed(e.to_string()))?;
297
298        if !status.is_success() {
299            let error = response_body
300                .get("error")
301                .and_then(|e| e.as_str())
302                .unwrap_or("Unknown error");
303            return Err(ClientError::QueryFailed(error.to_string()));
304        }
305
306        let rows_affected = response_body
307            .get("data")
308            .and_then(|d| d.get("rows_affected"))
309            .and_then(|r| r.as_u64())
310            .unwrap_or(0);
311
312        Ok(rows_affected)
313    }
314
315    /// Begin a transaction.
316    pub async fn begin_transaction(&self) -> Result<(), ClientError> {
317        if self.in_transaction() {
318            return Err(ClientError::TransactionAlreadyStarted);
319        }
320        self.execute("BEGIN").await?;
321        Ok(())
322    }
323
324    /// Commit a transaction.
325    pub async fn commit(&self) -> Result<(), ClientError> {
326        if !self.in_transaction() {
327            return Err(ClientError::NoTransaction);
328        }
329        self.execute("COMMIT").await?;
330        Ok(())
331    }
332
333    /// Rollback a transaction.
334    pub async fn rollback(&self) -> Result<(), ClientError> {
335        if !self.in_transaction() {
336            return Err(ClientError::NoTransaction);
337        }
338        self.execute("ROLLBACK").await?;
339        Ok(())
340    }
341
342    /// Ping the connection.
343    pub async fn ping(&self) -> Result<(), ClientError> {
344        let health_url = format!("{}/health", self.base_url);
345        let response = self
346            .http_client
347            .get(&health_url)
348            .send()
349            .await
350            .map_err(|e| ClientError::ConnectionFailed(e.to_string()))?;
351
352        if response.status().is_success() {
353            self.mark_used();
354            Ok(())
355        } else {
356            self.connected.store(false, Ordering::SeqCst);
357            Err(ClientError::NotConnected)
358        }
359    }
360
361    /// Close the connection.
362    pub async fn close(&self) {
363        // Clone token before await to avoid holding lock across await point
364        let token = self
365            .auth_token
366            .read()
367            .expect("auth_token RwLock poisoned")
368            .clone();
369        if let Some(ref token) = token {
370            let logout_url = format!("{}/api/v1/auth/logout", self.base_url);
371            let body = serde_json::json!({ "token": token });
372            let _ = self.http_client.post(&logout_url).json(&body).send().await;
373        }
374        self.connected.store(false, Ordering::SeqCst);
375    }
376
377    /// Get connection statistics.
378    pub fn stats(&self) -> ConnectionStats {
379        ConnectionStats {
380            id: self.id,
381            connected: self.is_connected(),
382            in_transaction: self.in_transaction(),
383            age_ms: self.age().as_millis() as u64,
384            idle_ms: self.idle_time().as_millis() as u64,
385            queries_executed: self.queries_executed.load(Ordering::SeqCst),
386        }
387    }
388
389    /// Get the base URL of the server.
390    pub fn base_url(&self) -> &str {
391        &self.base_url
392    }
393}
394
395// =============================================================================
396// Value Conversion
397// =============================================================================
398
399fn value_to_json(value: &Value) -> serde_json::Value {
400    match value {
401        Value::Null => serde_json::Value::Null,
402        Value::Bool(b) => serde_json::Value::Bool(*b),
403        Value::Int(i) => serde_json::Value::Number((*i).into()),
404        Value::Float(f) => serde_json::Number::from_f64(*f)
405            .map(serde_json::Value::Number)
406            .unwrap_or(serde_json::Value::Null),
407        Value::String(s) => serde_json::Value::String(s.clone()),
408        Value::Bytes(b) => serde_json::Value::String(base64_encode(b)),
409        Value::Timestamp(t) => serde_json::Value::Number((*t).into()),
410        Value::Array(arr) => serde_json::Value::Array(arr.iter().map(value_to_json).collect()),
411        Value::Object(obj) => {
412            let map: serde_json::Map<String, serde_json::Value> = obj
413                .iter()
414                .map(|(k, v)| (k.clone(), value_to_json(v)))
415                .collect();
416            serde_json::Value::Object(map)
417        }
418    }
419}
420
421fn json_to_value(json: &serde_json::Value) -> Value {
422    match json {
423        serde_json::Value::Null => Value::Null,
424        serde_json::Value::Bool(b) => Value::Bool(*b),
425        serde_json::Value::Number(n) => {
426            if let Some(i) = n.as_i64() {
427                Value::Int(i)
428            } else if let Some(f) = n.as_f64() {
429                Value::Float(f)
430            } else {
431                Value::Null
432            }
433        }
434        serde_json::Value::String(s) => Value::String(s.clone()),
435        serde_json::Value::Array(arr) => Value::Array(arr.iter().map(json_to_value).collect()),
436        serde_json::Value::Object(obj) => Value::Object(
437            obj.iter()
438                .map(|(k, v)| (k.clone(), json_to_value(v)))
439                .collect(),
440        ),
441    }
442}
443
/// Encode `data` as standard base64 (RFC 4648 alphabet, `+` and `/`) with
/// `=` padding.
///
/// Every started 3-byte group produces exactly four output characters, so
/// the output buffer is allocated once at its final size.
fn base64_encode(data: &[u8]) -> String {
    const CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let chunks = data.chunks(3);
    let mut result = String::with_capacity(chunks.len() * 4);

    for chunk in chunks {
        // Pack up to three bytes into one 24-bit group; missing bytes read as zero.
        let b0 = u32::from(chunk[0]);
        let b1 = chunk.get(1).map_or(0, |&b| u32::from(b));
        let b2 = chunk.get(2).map_or(0, |&b| u32::from(b));
        let group = (b0 << 16) | (b1 << 8) | b2;

        // A chunk of n bytes carries n + 1 significant 6-bit digits; the rest is padding.
        for digit in 0..4usize {
            if digit <= chunk.len() {
                let sextet = (group >> (18 - 6 * digit)) & 0x3f;
                result.push(CHARS[sextet as usize] as char);
            } else {
                result.push('=');
            }
        }
    }

    result
}
471
472// =============================================================================
473// Connection Statistics
474// =============================================================================
475
/// Point-in-time snapshot of a connection's state and usage counters.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. in
/// tests). `Default` yields an all-zero/false snapshot.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ConnectionStats {
    /// Connection identifier.
    pub id: u64,
    /// Whether the connection was marked connected when the snapshot was taken.
    pub connected: bool,
    /// Whether the client-side transaction flag was set.
    pub in_transaction: bool,
    /// Milliseconds since the connection was created.
    pub age_ms: u64,
    /// Milliseconds since the connection was last used.
    pub idle_ms: u64,
    /// Number of queries/statements issued through the connection.
    pub queries_executed: u64,
}
486
487// =============================================================================
488// Pooled Connection
489// =============================================================================
490
491/// A connection managed by a pool.
492///
493/// This struct is thread-safe (`Sync`) because the return callback is protected by a Mutex.
494pub struct PooledConnection {
495    connection: Arc<Connection>,
496    pool_return: std::sync::Mutex<Option<Box<dyn FnOnce(Arc<Connection>) + Send>>>,
497}
498
499impl PooledConnection {
500    /// Create a new pooled connection.
501    pub fn new<F>(connection: Arc<Connection>, on_return: F) -> Self
502    where
503        F: FnOnce(Arc<Connection>) + Send + 'static,
504    {
505        Self {
506            connection,
507            pool_return: std::sync::Mutex::new(Some(Box::new(on_return))),
508        }
509    }
510
511    /// Get a reference to the underlying connection.
512    pub fn connection(&self) -> &Connection {
513        &self.connection
514    }
515
516    /// Get the underlying connection (alias for connection()).
517    pub fn inner(&self) -> &Connection {
518        &self.connection
519    }
520
521    /// Execute a query.
522    pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
523        self.connection.query(sql).await
524    }
525
526    /// Execute a statement.
527    pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
528        self.connection.execute(sql).await
529    }
530}
531
532impl std::ops::Deref for PooledConnection {
533    type Target = Connection;
534
535    fn deref(&self) -> &Self::Target {
536        &self.connection
537    }
538}
539
540impl Drop for PooledConnection {
541    fn drop(&mut self) {
542        if let Ok(mut guard) = self.pool_return.lock() {
543            if let Some(return_fn) = guard.take() {
544                return_fn(Arc::clone(&self.connection));
545            }
546        }
547    }
548}
549
550// =============================================================================
551// Tests
552// =============================================================================
553
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_stats() {
        let stats = ConnectionStats {
            id: 1,
            connected: true,
            in_transaction: false,
            age_ms: 1000,
            idle_ms: 100,
            queries_executed: 5,
        };
        assert_eq!(stats.id, 1);
        assert!(stats.connected);
    }

    #[test]
    fn test_json_to_value() {
        let json = serde_json::json!({"name": "test", "count": 42});
        let value = json_to_value(&json);
        if let Value::Object(map) = value {
            assert!(map.contains_key("name"));
            assert!(map.contains_key("count"));
        } else {
            panic!("Expected Object");
        }
    }

    #[test]
    fn test_json_to_value_scalars() {
        assert!(matches!(json_to_value(&serde_json::Value::Null), Value::Null));
        assert!(matches!(
            json_to_value(&serde_json::json!(true)),
            Value::Bool(true)
        ));
        assert!(matches!(json_to_value(&serde_json::json!(-7)), Value::Int(-7)));
        assert!(matches!(
            json_to_value(&serde_json::json!(1.5)),
            Value::Float(f) if f == 1.5
        ));
        assert!(matches!(
            json_to_value(&serde_json::json!("x")),
            Value::String(ref s) if s == "x"
        ));
    }

    #[test]
    fn test_value_to_json() {
        let value = Value::String("hello".to_string());
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::Value::String("hello".to_string()));
    }

    #[test]
    fn test_value_to_json_scalars() {
        assert_eq!(value_to_json(&Value::Null), serde_json::Value::Null);
        assert_eq!(value_to_json(&Value::Bool(false)), serde_json::json!(false));
        assert_eq!(value_to_json(&Value::Int(42)), serde_json::json!(42));
        // Non-finite floats have no JSON representation and degrade to null.
        assert_eq!(
            value_to_json(&Value::Float(f64::NAN)),
            serde_json::Value::Null
        );
    }

    #[test]
    fn test_json_round_trip() {
        let json = serde_json::json!({"a": [1, "two", null, false]});
        assert_eq!(value_to_json(&json_to_value(&json)), json);
    }

    #[test]
    fn test_base64_encode_padding() {
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
        assert_eq!(base64_encode(&[0xfb, 0xff]), "+/8=");
    }

    #[tokio::test]
    async fn test_connection_create() {
        // Needs a server on 127.0.0.1:7001; when none is running this test
        // passes vacuously.
        let config = ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port: 7001,
            ..Default::default()
        };

        match Connection::new(config).await {
            Ok(conn) => {
                assert!(conn.is_connected());
            }
            Err(_) => {
                // Server not running, skip test
            }
        }
    }
}