use crate::{recordset::ColumnInfo, DatabaseError, DatabaseResult, Recordset};
use odbc_api::{
    buffers::{BufferDesc, ColumnarAnyBuffer, TextRowSet},
    Connection, ConnectionOptions, Cursor, Environment, IntoParameter,
};
use serde_json::Value as JsonValue;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
#[derive(Debug)]
pub struct ODBCConnection {
connection: Option<Connection<'static>>,
environment: Option<Arc<Environment>>,
}
impl ODBCConnection {
pub async fn connect(connection_string: &str) -> DatabaseResult<Self> {
let environment = Environment::new().map_err(|e| {
DatabaseError::connection_error(format!("Failed to create ODBC environment: {}", e))
})?;
let environment = Arc::new(environment);
let env_clone = environment.clone();
let connection = tokio::task::spawn_blocking(move || {
env_clone
.connect_with_connection_string(connection_string, ConnectionOptions::default())
})
.await
.map_err(|e| DatabaseError::connection_error(format!("Task join error: {}", e)))?
.map_err(|e| DatabaseError::connection_error(format!("ODBC connection failed: {}", e)))?;
let connection_static =
unsafe { std::mem::transmute::<Connection<'_>, Connection<'static>>(connection) };
Ok(Self {
connection: Some(connection_static),
environment: Some(environment),
})
}
pub async fn execute(&self, query: &str) -> DatabaseResult<Recordset> {
let conn = self
.connection
.as_ref()
.ok_or(DatabaseError::NotConnected)?;
let query_owned = query.to_string();
let result =
tokio::task::spawn_blocking(move || Self::execute_query_sync(conn, &query_owned))
.await
.map_err(|e| DatabaseError::query_error(format!("Task join error: {}", e)))??;
Ok(result)
}
fn execute_query_sync(conn: &Connection<'static>, query: &str) -> DatabaseResult<Recordset> {
let cursor = conn
.execute(query, ())
.map_err(|e| DatabaseError::query_error(format!("Query execution failed: {}", e)))?
.ok_or_else(|| DatabaseError::query_error("No cursor returned from query"))?;
let num_cols = cursor.num_result_cols().map_err(|e| {
DatabaseError::query_error(format!("Failed to get column count: {}", e))
})?;
let mut columns = Vec::new();
for i in 1..=num_cols {
let mut name_buffer = vec![0u8; 256];
let col_info = cursor
.describe_col(i as u16, &mut name_buffer)
.map_err(|e| {
DatabaseError::query_error(format!("Failed to describe column {}: {}", i, e))
})?;
let name = String::from_utf8_lossy(&name_buffer[..col_info.name_length]).to_string();
columns.push(ColumnInfo {
name,
data_type: format!("{:?}", col_info.data_type),
ordinal: (i - 1) as usize,
});
}
let mut rows_data = Vec::new();
let buffer_desc: Vec<BufferDesc> = (0..num_cols)
.map(|_| BufferDesc::Text { max_str_len: 1024 })
.collect();
let row_set_buffer = ColumnarAnyBuffer::try_from_descs(100, buffer_desc.iter())
.map_err(|e| DatabaseError::query_error(format!("Failed to create buffer: {}", e)))?;
let mut cursor = cursor
.bind_buffer(row_set_buffer)
.map_err(|e| DatabaseError::query_error(format!("Failed to bind buffer: {}", e)))?;
while let Some(row_set) = cursor
.fetch()
.map_err(|e| DatabaseError::query_error(format!("Failed to fetch rows: {}", e)))?
{
for row_idx in 0..row_set.num_rows() {
let mut row_data = Vec::new();
for col_idx in 0..num_cols {
let col_view = row_set.column(col_idx);
let value = match col_view {
odbc_api::buffers::AnyColumnView::Text(text_col) => {
match text_col.get(row_idx) {
Some(text_bytes) => {
let text = String::from_utf8_lossy(text_bytes).to_string();
if text.is_empty() {
JsonValue::Null
} else if let Ok(num) = text.parse::<i64>() {
JsonValue::from(num)
} else if let Ok(num) = text.parse::<f64>() {
JsonValue::from(num)
} else {
JsonValue::from(text)
}
}
None => JsonValue::Null,
}
}
_ => JsonValue::String("(binary data)".to_string()),
};
row_data.push(value);
}
rows_data.push(row_data);
}
}
Ok(Recordset::new(columns, rows_data))
}
pub async fn execute_command(&self, command: &str) -> DatabaseResult<u64> {
let conn = self
.connection
.as_ref()
.ok_or(DatabaseError::NotConnected)?;
let command_owned = command.to_string();
let rows_affected = tokio::task::spawn_blocking(move || {
conn.execute(&command_owned, ())
.map_err(|e| DatabaseError::query_error(format!("Command execution failed: {}", e)))
})
.await
.map_err(|e| DatabaseError::query_error(format!("Task join error: {}", e)))??;
match rows_affected {
Some(_cursor) => Ok(0), None => Ok(0), }
}
pub async fn close(&mut self) -> DatabaseResult<()> {
self.connection = None;
self.environment = None;
Ok(())
}
pub async fn is_connected(&self) -> bool {
if let Some(conn) = &self.connection {
tokio::task::spawn_blocking(move || conn.execute("SELECT 1", ()))
.await
.is_ok()
} else {
false
}
}
}
/// Intentionally empty.
///
/// Rust calls this `drop` first and then drops the struct's fields in declaration
/// order, so all real cleanup is performed by the fields' own destructors.
// NOTE(review): an empty `Drop` impl adds nothing and forbids moving fields out of
// `ODBCConnection` by destructuring; consider removing it.
impl Drop for ODBCConnection {
fn drop(&mut self) {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test against a real SQLite ODBC driver.
    ///
    /// It asserts nothing, because the driver is usually not installed. It is ignored by
    /// default because, when the driver *is* present, connecting may create `test.db` in
    /// the working directory.
    #[tokio::test]
    #[ignore = "requires the SQLite3 ODBC driver and may create test.db"]
    async fn test_odbc_connection_string_parsing() {
        let conn_str = "DRIVER={SQLite3};Database=test.db";
        let result = ODBCConnection::connect(conn_str).await;
        println!("Connection attempt result: {:?}", result.is_ok());
    }

    /// A driver that cannot exist must surface as an error, not a panic or a hang.
    #[tokio::test]
    async fn connect_with_unknown_driver_fails() {
        let result =
            ODBCConnection::connect("DRIVER={__no_such_odbc_driver__};Database=nowhere").await;
        assert!(result.is_err());
    }
}