use crate::error::{Result, SqlStreamError};
use datafusion::arrow::util::pretty::print_batches;
use datafusion::prelude::*;
use std::path::Path;
use tracing::{debug, info, instrument};
/// SQL query engine that wraps a DataFusion [`SessionContext`].
///
/// Files are registered as named tables with [`QueryEngine::register_file`] and
/// queried with [`QueryEngine::execute_query`].
pub struct QueryEngine {
    /// DataFusion session holding the registered tables and executing queries.
    ctx: SessionContext,
}
impl QueryEngine {
    /// Creates a new engine backed by a fresh DataFusion [`SessionContext`].
    ///
    /// # Errors
    ///
    /// The current implementation never fails; the `Result` return type leaves
    /// room for fallible setup without breaking callers.
    #[instrument]
    pub fn new() -> Result<Self> {
        info!("Initializing query engine");
        let ctx = SessionContext::new();
        Ok(Self { ctx })
    }

    /// Registers a local CSV or newline-delimited JSON file as the table `table_name`.
    ///
    /// The format is chosen from the file extension, compared case-insensitively:
    /// `csv` is registered with [`CsvReadOptions`], while `json`, `jsonl` and
    /// `ndjson` are registered as newline-delimited JSON with [`NdJsonReadOptions`].
    ///
    /// # Errors
    ///
    /// - [`SqlStreamError::FileNotFound`] if `file_path` does not exist.
    /// - [`SqlStreamError::UnsupportedFormat`] if the path has no UTF-8 extension
    ///   (carrying the whole path) or an unsupported one (carrying the extension).
    /// - [`SqlStreamError::TableRegistration`] if DataFusion rejects the registration.
    #[instrument(skip(self))]
    pub async fn register_file(&mut self, file_path: &str, table_name: &str) -> Result<()> {
        let path = Path::new(file_path);
        if !path.exists() {
            return Err(SqlStreamError::FileNotFound(path.to_path_buf()));
        }
        info!("Registering file: {} as table: {}", file_path, table_name);

        let extension = path
            .extension()
            .and_then(|ext| ext.to_str())
            .ok_or_else(|| SqlStreamError::UnsupportedFormat(path.to_string_lossy().to_string()))?;

        // DataFusion only picks up files whose names end with the read options'
        // `file_extension` (default ".csv" / ".json", matched case-sensitively).
        // The format match below is case-insensitive, so hand DataFusion the
        // file's real suffix; otherwise e.g. `data.CSV` would not be read.
        let file_extension = format!(".{extension}");

        let registered = match extension.to_lowercase().as_str() {
            "csv" => {
                debug!("Detected CSV format");
                let options = CsvReadOptions::new().file_extension(&file_extension);
                self.ctx.register_csv(table_name, file_path, options).await
            }
            "json" | "jsonl" | "ndjson" => {
                debug!("Detected JSON format");
                let options = NdJsonReadOptions::default().file_extension(&file_extension);
                self.ctx.register_json(table_name, file_path, options).await
            }
            _ => return Err(SqlStreamError::UnsupportedFormat(extension.to_string())),
        };
        registered.map_err(|e| {
            SqlStreamError::TableRegistration(table_name.to_string(), e.to_string())
        })?;

        info!("Successfully registered table: {}", table_name);
        Ok(())
    }

    /// Plans `sql` against the registered tables and returns the resulting
    /// [`DataFrame`].
    ///
    /// # Errors
    ///
    /// Returns [`SqlStreamError::QueryExecution`] with DataFusion's message if
    /// the query cannot be parsed or planned.
    #[instrument(skip(self))]
    pub async fn execute_query(&self, sql: &str) -> Result<DataFrame> {
        info!("Executing SQL query");
        debug!("Query: {}", sql);
        self.ctx
            .sql(sql)
            .await
            .map_err(|e| SqlStreamError::QueryExecution(e.to_string()))
    }

    /// Collects every batch of `dataframe`, pretty-prints it with Arrow's
    /// [`print_batches`], then logs the total row count.
    ///
    /// # Errors
    ///
    /// - Propagates the error from collecting the batches, converted into
    ///   [`SqlStreamError`] via `?`.
    /// - Returns [`SqlStreamError::QueryExecution`] if printing fails.
    #[instrument(skip(self, dataframe))]
    pub async fn print_results(&self, dataframe: DataFrame) -> Result<()> {
        info!("Collecting and printing results");
        let batches = dataframe.collect().await?;
        print_batches(&batches).map_err(|e| {
            SqlStreamError::QueryExecution(format!("Failed to print results: {}", e))
        })?;
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        info!("Query returned {} rows", total_rows);
        Ok(())
    }
}
impl Default for QueryEngine {
fn default() -> Self {
Self::new().expect("Failed to create default QueryEngine")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to a file in the system temp dir whose name is unique
    /// per process and per `name`, and returns its path.
    fn write_temp_file(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir()
            .join(format!("query_engine_{}_{}", std::process::id(), name));
        std::fs::write(&path, contents).expect("failed to write temp test file");
        path
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let engine = QueryEngine::new();
        assert!(engine.is_ok());
    }

    #[tokio::test]
    async fn test_file_not_found() {
        let mut engine = QueryEngine::new().unwrap();
        let result = engine.register_file("nonexistent.csv", "test").await;
        assert!(matches!(result, Err(SqlStreamError::FileNotFound(_))));
    }

    #[tokio::test]
    async fn test_unsupported_extension() {
        let path = write_temp_file("data.txt", "a,b\n1,2\n");
        let mut engine = QueryEngine::new().unwrap();
        let result = engine.register_file(path.to_str().unwrap(), "test").await;
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(SqlStreamError::UnsupportedFormat(_))));
    }

    #[tokio::test]
    async fn test_missing_extension() {
        let path = write_temp_file("noextension", "a,b\n1,2\n");
        let mut engine = QueryEngine::new().unwrap();
        let result = engine.register_file(path.to_str().unwrap(), "test").await;
        let _ = std::fs::remove_file(&path);
        assert!(matches!(result, Err(SqlStreamError::UnsupportedFormat(_))));
    }

    #[tokio::test]
    async fn test_register_and_query_csv() {
        let path = write_temp_file("people.csv", "id,name\n1,a\n2,b\n");
        let mut engine = QueryEngine::new().unwrap();
        engine
            .register_file(path.to_str().unwrap(), "people")
            .await
            .unwrap();
        let df = engine.execute_query("SELECT * FROM people").await.unwrap();
        let batches = df.collect().await.unwrap();
        let _ = std::fs::remove_file(&path);
        let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2);
    }
}