use std::time::Duration;
use arrow_flight::flight_service_server::FlightServiceServer;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::sql::CommandGetTables;
use async_trait::async_trait;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState; use datafusion::prelude::*; use datafusion_flight_sql_server::service::FlightSqlService;
use datafusion_flight_sql_server::session::SessionStateProvider;
use tokio::time::sleep;
use tonic::transport::{Channel, Endpoint, Server};
use tonic::{Request, Status};
#[derive(Clone, Debug)]
pub struct UserData {
pub user_id: u32,
}
async fn bearer_auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
let auth_header = req
.metadata()
.get("authorization")
.ok_or_else(|| Status::unauthenticated("no authorization provided"))?;
let auth_str = auth_header
.to_str()
.map_err(|_| Status::unauthenticated("invalid authorization header encoding"))?;
if !auth_str.starts_with("Bearer ") {
return Err(Status::unauthenticated(
"invalid authorization header format",
));
}
let token = &auth_str["Bearer ".len()..];
let user_data = match token {
"token1" => UserData { user_id: 1 },
"token2" => UserData { user_id: 2 },
_ => return Err(Status::unauthenticated("invalid token")),
};
req.extensions_mut().insert(user_data);
Ok(req)
}
pub struct MySessionStateProvider {
base_context: SessionContext,
}
impl MySessionStateProvider {
async fn try_new() -> Result<Self> {
let ctx = SessionContext::new();
let csv_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/test.csv");
ctx.register_csv("test", csv_path, CsvReadOptions::new())
.await?;
Ok(Self { base_context: ctx })
}
}
#[async_trait]
impl SessionStateProvider for MySessionStateProvider {
async fn new_context(&self, request: &Request<()>) -> Result<SessionState, Status> {
if let Some(user_data) = request.extensions().get::<UserData>() {
println!(
"Session context for user_id: {}. Cloning base context.",
user_data.user_id
);
let state = self.base_context.state().clone();
Ok(state)
} else {
Err(Status::unauthenticated(
"User data not found in request extensions (MySessionStateProvider)",
))
}
}
}
async fn new_client_with_auth(
dsn: String,
token: Option<String>,
) -> Result<FlightSqlServiceClient<Channel>> {
let endpoint = Endpoint::from_shared(dsn.clone())
.map_err(|e| DataFusionError::External(format!("Invalid DSN {}: {}", dsn, e).into()))?
.connect_timeout(std::time::Duration::from_secs(10));
let channel = endpoint.connect().await.map_err(|e| {
DataFusionError::External(format!("Failed to connect to {}: {}", dsn, e).into())
})?;
let mut service_client = FlightSqlServiceClient::new(channel);
if let Some(token_str) = token.clone() {
service_client.set_header("authorization", format!("Bearer {}", token_str));
}
Ok(service_client)
}
#[tokio::main]
async fn main() -> Result<()> {
let dsn: String = "0.0.0.0:50051".to_string();
let state_provider = Box::new(MySessionStateProvider::try_new().await?);
let base_service = FlightSqlService::new_with_provider(state_provider);
let svc: FlightServiceServer<FlightSqlService> = FlightServiceServer::new(base_service);
let addr: std::net::SocketAddr = dsn.parse().map_err(|e| {
DataFusionError::External(format!("Invalid address format {}: {}", dsn, e).into())
})?;
tokio::spawn(async move {
println!(
"Bearer Authentication Flight SQL server listening on {}",
addr
);
if let Err(e) = Server::builder()
.layer(tonic_async_interceptor::async_interceptor(
bearer_auth_interceptor,
))
.add_service(svc)
.serve(addr)
.await
{
eprintln!("Server error: {}", e);
}
});
sleep(Duration::from_secs(3)).await;
let client_dsn = "http://localhost:50051".to_string();
println!("\nAttempting GetTables with valid token (token1)...");
match new_client_with_auth(client_dsn.clone(), Some("token1".to_string())).await {
Ok(mut client) => {
let request = CommandGetTables {
catalog: None,
db_schema_filter_pattern: None,
table_name_filter_pattern: None,
table_types: vec![],
include_schema: false,
};
match client.get_tables(request).await {
Ok(response) => {
println!("GetTables with token1 SUCCEEDED. Response: {:?}", response)
}
Err(e) => eprintln!("GetTables with token1 FAILED: {}", e),
}
}
Err(e) => eprintln!("Failed to create client with token1: {}", e),
}
println!("\nAttempting GetTables with invalid token (invalidtoken)...");
match new_client_with_auth(client_dsn.clone(), Some("invalidtoken".to_string())).await {
Ok(mut client) => {
let request = CommandGetTables {
catalog: None,
db_schema_filter_pattern: None,
table_name_filter_pattern: None,
table_types: vec![],
include_schema: false,
};
match client.get_tables(request).await {
Ok(response) => println!(
"GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}",
response
),
Err(e) => eprintln!("GetTables with invalidtoken FAILED (as expected): {:?}", e),
}
}
Err(e) => eprintln!("Failed to create client with invalidtoken: {}", e),
}
println!("\nAttempting GetTables with no token...");
match new_client_with_auth(client_dsn.clone(), None).await {
Ok(mut client) => {
let request = CommandGetTables {
catalog: None,
db_schema_filter_pattern: None,
table_name_filter_pattern: None,
table_types: vec![],
include_schema: false,
};
match client.get_tables(request).await {
Ok(response) => println!(
"GetTables with no token SUCCEEDED (unexpected). Response: {:?}",
response
),
Err(e) => eprintln!("GetTables with no token FAILED (as expected): {:?}", e),
}
}
Err(e) => eprintln!("Failed to create client with no token: {}", e),
}
Ok(())
}