use std::{sync::Arc, time::Duration};
use arrow_flight::sql::client::FlightSqlServiceClient;
use datafusion::{
catalog::SchemaProvider,
error::{DataFusionError, Result},
execution::{
context::{SessionContext, SessionState},
options::CsvReadOptions,
},
};
use datafusion_federation::sql::{SQLFederationProvider, SQLSchemaProvider};
use datafusion_flight_sql_server::service::FlightSqlService;
use datafusion_flight_sql_table_provider::FlightSQLExecutor;
use tokio::time::sleep;
use tonic::transport::Endpoint;
#[tokio::main]
async fn main() -> Result<()> {
let dsn: String = "0.0.0.0:50051".to_string();
let remote_ctx = SessionContext::new();
let csv_path = format!("{}/examples/test.csv", env!("CARGO_MANIFEST_DIR"));
remote_ctx
.register_csv("test", &csv_path, CsvReadOptions::new())
.await?;
tokio::spawn(async move {
FlightSqlService::new(remote_ctx.state())
.serve(dsn.clone())
.await
.unwrap();
});
sleep(Duration::from_secs(3)).await;
let state = datafusion_federation::default_session_state();
let known_tables: Vec<String> = ["test"].iter().map(|&x| x.into()).collect();
let dsn: String = "http://localhost:50051".to_string();
let client = new_client(dsn.clone()).await?;
let executor = Arc::new(FlightSQLExecutor::new(dsn, client));
let provider = Arc::new(SQLFederationProvider::new(executor));
let schema_provider =
Arc::new(SQLSchemaProvider::new_with_tables(provider, known_tables).await?);
overwrite_default_schema(&state, schema_provider)?;
let ctx = SessionContext::new_with_state(state);
let query = r#"SELECT * from test"#;
let df = ctx.sql(query).await?;
df.show().await
}
fn overwrite_default_schema(state: &SessionState, schema: Arc<dyn SchemaProvider>) -> Result<()> {
let options = &state.config().options().catalog;
let catalog = state
.catalog_list()
.catalog(options.default_catalog.as_str())
.unwrap();
catalog.register_schema(options.default_schema.as_str(), schema)?;
Ok(())
}
/// Opens a tonic channel to `dsn` and wraps it in a Flight SQL client.
///
/// # Errors
/// Both an invalid endpoint string and a failed connection are mapped to a
/// `DataFusionError` via `tx_error_to_df`.
async fn new_client(dsn: String) -> Result<FlightSqlServiceClient<tonic::transport::Channel>> {
    let channel = Endpoint::new(dsn)
        .map_err(tx_error_to_df)?
        .connect()
        .await
        .map_err(tx_error_to_df)?;
    Ok(FlightSqlServiceClient::new(channel))
}
/// Wraps a tonic transport error in `DataFusionError::External`, keeping the
/// original error's `Debug` rendering in the message.
fn tx_error_to_df(err: tonic::transport::Error) -> DataFusionError {
    let message = format!("failed to connect: {err:?}");
    DataFusionError::External(message.into())
}