use crate::config::ConnectionConfig;
use crate::error::ClientError;
use crate::result::{QueryResult, Row, Value};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
/// JSON body sent to `POST /api/v1/query`.
///
/// This type is only ever serialized, so both fields are always emitted; `params` is
/// written as `[]` when the caller supplied no bind parameters.
#[derive(Debug, Serialize)]
struct QueryRequest {
    /// SQL text, sent verbatim.
    sql: String,
    /// Bind parameters, already converted to JSON, in the order the caller supplied them.
    params: Vec<serde_json::Value>,
}
/// JSON envelope returned by `POST /api/v1/query`.
#[derive(Debug, Deserialize)]
struct QueryResponse {
    /// Whether the server reports the statement as successful.
    success: bool,
    /// Result payload. It may be absent: `query` treats that as an error on success,
    /// while `execute` treats it as 0 rows affected.
    #[serde(default)]
    data: Option<QueryResultData>,
    /// Server-side error text, used as the error message when `success` is false.
    error: Option<String>,
    /// Server-reported execution time. The client never consumes it, so it defaults to 0
    /// when missing instead of failing the whole reply with a decode error.
    #[serde(default)]
    // Only surfaced through the derived `Debug` output; not read anywhere in this client.
    #[allow(dead_code)]
    execution_time_ms: u64,
}
#[derive(Debug, Deserialize)]
struct QueryResultData {
columns: Vec<String>,
rows: Vec<Vec<serde_json::Value>>,
rows_affected: u64,
}
/// A single client connection to the query server, which is reached as JSON over HTTP.
///
/// Mutable state lives in atomics and an `RwLock`, so every method takes `&self`.
pub struct Connection {
    /// Id handed out by a process-wide counter in `Connection::new`.
    id: u64,
    /// Configuration the connection was created from. Only `host` and `port` are read
    /// (to build `base_url` in `Connection::new`); the stored field itself is never read.
    // NOTE(review): expose it (e.g. via a getter) or stop storing it.
    #[allow(dead_code)]
    config: ConnectionConfig,
    /// Set by a successful health check in `connect`, cleared by `close`.
    connected: AtomicBool,
    /// Client-side transaction flag toggled by `execute` on `BEGIN`/`COMMIT`/`ROLLBACK`.
    in_transaction: AtomicBool,
    /// Creation time, used to report the connection's age.
    created_at: Instant,
    /// Time of the most recent statement or ping; drives the reported idle time.
    last_used: std::sync::RwLock<Instant>,
    /// Number of statements issued through this connection, including transaction-control
    /// statements that `execute` handles locally.
    queries_executed: AtomicU64,
    /// HTTP client used for every request (built with a 30 s timeout).
    http_client: reqwest::Client,
    /// `http://{host}:{port}`, the prefix for `/health` and `/api/v1/query`.
    base_url: String,
}
impl Connection {
pub async fn new(config: ConnectionConfig) -> Result<Self, ClientError> {
static CONN_ID: AtomicU64 = AtomicU64::new(1);
let base_url = format!("http://{}:{}", config.host, config.port);
let http_client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| ClientError::ConnectionFailed(e.to_string()))?;
let conn = Self {
id: CONN_ID.fetch_add(1, Ordering::SeqCst),
config,
connected: AtomicBool::new(false),
in_transaction: AtomicBool::new(false),
created_at: Instant::now(),
last_used: std::sync::RwLock::new(Instant::now()),
queries_executed: AtomicU64::new(0),
http_client,
base_url,
};
conn.connect().await?;
Ok(conn)
}
pub fn id(&self) -> u64 {
self.id
}
async fn connect(&self) -> Result<(), ClientError> {
let health_url = format!("{}/health", self.base_url);
let response = self.http_client
.get(&health_url)
.send()
.await
.map_err(|e| ClientError::ConnectionFailed(format!("Failed to connect to {}: {}", self.base_url, e)))?;
if !response.status().is_success() {
return Err(ClientError::ConnectionFailed(
format!("Server returned status: {}", response.status())
));
}
self.connected.store(true, Ordering::SeqCst);
Ok(())
}
pub fn is_connected(&self) -> bool {
self.connected.load(Ordering::SeqCst)
}
pub fn in_transaction(&self) -> bool {
self.in_transaction.load(Ordering::SeqCst)
}
pub fn age(&self) -> std::time::Duration {
self.created_at.elapsed()
}
pub fn idle_time(&self) -> std::time::Duration {
self.last_used.read().unwrap().elapsed()
}
fn mark_used(&self) {
*self.last_used.write().unwrap() = Instant::now();
}
pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
self.query_with_params(sql, vec![]).await
}
pub async fn query_with_params(
&self,
sql: &str,
params: Vec<Value>,
) -> Result<QueryResult, ClientError> {
if !self.is_connected() {
return Err(ClientError::NotConnected);
}
self.mark_used();
self.queries_executed.fetch_add(1, Ordering::SeqCst);
let query_url = format!("{}/api/v1/query", self.base_url);
let json_params: Vec<serde_json::Value> = params.into_iter()
.map(|v| value_to_json(&v))
.collect();
let request = QueryRequest {
sql: sql.to_string(),
params: json_params,
};
let response = self.http_client
.post(&query_url)
.json(&request)
.send()
.await
.map_err(|e| ClientError::QueryFailed(format!("HTTP request failed: {}", e)))?;
let status = response.status();
let query_response: QueryResponse = response
.json()
.await
.map_err(|e| ClientError::QueryFailed(format!("Failed to parse response: {}", e)))?;
if !query_response.success {
return Err(ClientError::QueryFailed(
query_response.error.unwrap_or_else(|| format!("Query failed with status {}", status))
));
}
let data = query_response.data
.ok_or_else(|| ClientError::QueryFailed("No data in successful response".to_string()))?;
let mut rows = Vec::new();
for row_data in data.rows {
let values: Vec<Value> = row_data.into_iter()
.map(|v| json_to_value(v))
.collect();
rows.push(Row::new(data.columns.clone(), values));
}
Ok(QueryResult::with_rows_affected(rows, data.columns, data.rows_affected as usize))
}
pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
self.execute_with_params(sql, vec![]).await
}
pub async fn execute_with_params(
&self,
sql: &str,
params: Vec<Value>,
) -> Result<u64, ClientError> {
if !self.is_connected() {
return Err(ClientError::NotConnected);
}
self.mark_used();
self.queries_executed.fetch_add(1, Ordering::SeqCst);
let sql_upper = sql.trim().to_uppercase();
if sql_upper.starts_with("BEGIN") {
self.in_transaction.store(true, Ordering::SeqCst);
return Ok(0);
} else if sql_upper.starts_with("COMMIT") || sql_upper.starts_with("ROLLBACK") {
self.in_transaction.store(false, Ordering::SeqCst);
return Ok(0);
}
let query_url = format!("{}/api/v1/query", self.base_url);
let json_params: Vec<serde_json::Value> = params.into_iter()
.map(|v| value_to_json(&v))
.collect();
let request = QueryRequest {
sql: sql.to_string(),
params: json_params,
};
let response = self.http_client
.post(&query_url)
.json(&request)
.send()
.await
.map_err(|e| ClientError::QueryFailed(format!("HTTP request failed: {}", e)))?;
let query_response: QueryResponse = response
.json()
.await
.map_err(|e| ClientError::QueryFailed(format!("Failed to parse response: {}", e)))?;
if !query_response.success {
return Err(ClientError::QueryFailed(
query_response.error.unwrap_or_else(|| "Query failed".to_string())
));
}
Ok(query_response.data
.map(|d| d.rows_affected)
.unwrap_or(0))
}
pub async fn begin_transaction(&self) -> Result<(), ClientError> {
if self.in_transaction() {
return Err(ClientError::TransactionAlreadyStarted);
}
self.execute("BEGIN").await?;
Ok(())
}
pub async fn commit(&self) -> Result<(), ClientError> {
if !self.in_transaction() {
return Err(ClientError::NoTransaction);
}
self.execute("COMMIT").await?;
Ok(())
}
pub async fn rollback(&self) -> Result<(), ClientError> {
if !self.in_transaction() {
return Err(ClientError::NoTransaction);
}
self.execute("ROLLBACK").await?;
Ok(())
}
pub async fn ping(&self) -> Result<(), ClientError> {
if !self.is_connected() {
return Err(ClientError::NotConnected);
}
let health_url = format!("{}/health", self.base_url);
self.http_client
.get(&health_url)
.send()
.await
.map_err(|e| ClientError::ConnectionFailed(format!("Ping failed: {}", e)))?;
self.mark_used();
Ok(())
}
pub async fn close(&self) {
self.connected.store(false, Ordering::SeqCst);
}
pub fn stats(&self) -> ConnectionStats {
ConnectionStats {
id: self.id,
connected: self.is_connected(),
in_transaction: self.in_transaction(),
age_ms: self.age().as_millis() as u64,
idle_ms: self.idle_time().as_millis() as u64,
queries_executed: self.queries_executed.load(Ordering::SeqCst),
}
}
}
fn value_to_json(value: &Value) -> serde_json::Value {
match value {
Value::Null => serde_json::Value::Null,
Value::Bool(b) => serde_json::Value::Bool(*b),
Value::Int(i) => serde_json::json!(i),
Value::Float(f) => serde_json::json!(f),
Value::String(s) => serde_json::Value::String(s.clone()),
Value::Bytes(b) => serde_json::json!(b),
Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(value_to_json).collect())
}
Value::Object(obj) => {
let map: serde_json::Map<String, serde_json::Value> = obj.iter()
.map(|(k, v)| (k.clone(), value_to_json(v)))
.collect();
serde_json::Value::Object(map)
}
Value::Timestamp(ts) => serde_json::json!(ts),
}
}
fn json_to_value(json: serde_json::Value) -> Value {
match json {
serde_json::Value::Null => Value::Null,
serde_json::Value::Bool(b) => Value::Bool(b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Value::Int(i)
} else if let Some(f) = n.as_f64() {
Value::Float(f)
} else {
Value::Null
}
}
serde_json::Value::String(s) => Value::String(s),
serde_json::Value::Array(arr) => {
Value::Array(arr.into_iter().map(json_to_value).collect())
}
serde_json::Value::Object(obj) => {
let map: std::collections::HashMap<String, Value> = obj.into_iter()
.map(|(k, v)| (k, json_to_value(v)))
.collect();
Value::Object(map)
}
}
}
/// Point-in-time snapshot of a connection's state, as returned by `Connection::stats`.
///
/// Equality compares every field, which makes snapshots easy to assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionStats {
    /// The connection's id.
    pub id: u64,
    /// Whether the connection was open when the snapshot was taken.
    pub connected: bool,
    /// Whether the client-side transaction flag was set.
    pub in_transaction: bool,
    /// Milliseconds since the connection was created.
    pub age_ms: u64,
    /// Milliseconds since the connection was last used.
    pub idle_ms: u64,
    /// Number of statements issued through the connection.
    pub queries_executed: u64,
}
/// A connection checked out of a pool.
///
/// When dropped, it hands the shared connection back through the stored callback.
pub struct PooledConnection {
    /// The shared underlying connection.
    connection: Arc<Connection>,
    /// One-shot "give it back" callback; `None` once it has run.
    pool_return: Option<Box<dyn FnOnce(Arc<Connection>) + Send + 'static>>,
}
impl PooledConnection {
pub fn new<F>(connection: Arc<Connection>, on_return: F) -> Self
where
F: FnOnce(Arc<Connection>) + Send + 'static,
{
Self {
connection,
pool_return: Some(Box::new(on_return)),
}
}
pub async fn query(&self, sql: &str) -> Result<QueryResult, ClientError> {
self.connection.query(sql).await
}
pub async fn query_with_params(
&self,
sql: &str,
params: Vec<Value>,
) -> Result<QueryResult, ClientError> {
self.connection.query_with_params(sql, params).await
}
pub async fn execute(&self, sql: &str) -> Result<u64, ClientError> {
self.connection.execute(sql).await
}
pub async fn execute_with_params(
&self,
sql: &str,
params: Vec<Value>,
) -> Result<u64, ClientError> {
self.connection.execute_with_params(sql, params).await
}
pub fn inner(&self) -> &Connection {
&self.connection
}
}
impl Drop for PooledConnection {
fn drop(&mut self) {
if let Some(return_fn) = self.pool_return.take() {
return_fn(Arc::clone(&self.connection));
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // The conversions are synchronous, so plain `#[test]`s suffice; no async runtime needed.

    #[test]
    fn test_value_conversion() {
        let val = Value::Int(42);
        let json = value_to_json(&val);
        assert_eq!(json, serde_json::json!(42));
        let back = json_to_value(json);
        assert!(matches!(back, Value::Int(42)));
    }

    #[test]
    fn json_scalars_map_to_matching_variants() {
        assert!(matches!(json_to_value(serde_json::Value::Null), Value::Null));
        assert!(matches!(json_to_value(serde_json::json!(true)), Value::Bool(true)));
        assert!(matches!(json_to_value(serde_json::json!(1.5)), Value::Float(f) if f == 1.5));
        assert!(matches!(
            json_to_value(serde_json::json!("hi")),
            Value::String(ref s) if s == "hi"
        ));
    }

    #[test]
    fn composite_values_round_trip() {
        let cases = vec![
            serde_json::json!([1, "a", null]),
            serde_json::json!({"k": 2, "nested": {"flag": false}}),
            serde_json::json!(-7),
            serde_json::json!(2.25),
        ];
        for original in cases {
            assert_eq!(value_to_json(&json_to_value(original.clone())), original);
        }
    }

    #[test]
    fn integers_beyond_i64_fall_back_to_float() {
        assert!(matches!(
            json_to_value(serde_json::json!(u64::MAX)),
            Value::Float(_)
        ));
    }
}