use crate::client::ChannelBuilder;
use crate::client::HeaderInterceptor;
use crate::client::SparkClient;
use crate::spark;
use crate::spark::spark_connect_service_client::SparkConnectServiceClient;
use crate::spark::expression::Literal;
use crate::query::SqlQueryBuilder;
use crate::{SparkError, error::SparkErrorKind};
use arrow::record_batch::RecordBatch;
use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::transport::Channel;
#[cfg(feature = "tls")]
use tonic::transport::ClientTlsConfig;
use tower::ServiceBuilder;
/// Builder for a [`SparkSession`].
///
/// It carries the [`ChannelBuilder`] that describes how to reach the server.
#[derive(Debug, Clone)]
pub struct SparkSessionBuilder {
    /// Channel configuration used when the session is built.
    channel_builder: ChannelBuilder,
}
impl SparkSessionBuilder {
    /// Creates a builder from a connection string.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid Spark connection string"` when
    /// [`ChannelBuilder::new`] returns an error for `connection`.
    pub fn new(connection: &str) -> Self {
        let channel_builder =
            ChannelBuilder::new(connection).expect("Invalid Spark connection string");
        Self { channel_builder }
    }

    /// Connects to the configured endpoint and wraps the resulting gRPC client
    /// in a [`SparkSession`].
    ///
    /// When the crate is compiled with the `tls` feature and the channel
    /// builder has `use_ssl` set, the endpoint is configured with native root
    /// certificates and the builder's host as the TLS domain name. Without the
    /// `tls` feature, `use_ssl` is not consulted by this method.
    ///
    /// # Errors
    ///
    /// * [`SparkErrorKind::InvalidConnectionUri`] when the endpoint string is
    ///   rejected by [`Channel::from_shared`].
    /// * [`SparkErrorKind::Transport`] when applying the TLS configuration or
    ///   establishing the connection fails.
    pub async fn build(&self) -> Result<SparkSession, SparkError> {
        let endpoint =
            Channel::from_shared(self.channel_builder.endpoint()).map_err(|source| {
                SparkError::new(SparkErrorKind::InvalidConnectionUri {
                    source,
                    uri: self.channel_builder.endpoint(),
                })
            })?;

        // Shadow the binding instead of declaring it `mut`: the only mutation
        // lives behind the `tls` feature, so a `let mut` would trigger an
        // `unused_mut` warning whenever that feature is disabled.
        #[cfg(feature = "tls")]
        let endpoint = if self.channel_builder.use_ssl {
            let tls_config = ClientTlsConfig::new()
                .domain_name(&self.channel_builder.host)
                .with_native_roots();
            endpoint
                .tls_config(tls_config)
                .map_err(|source| SparkError::new(SparkErrorKind::Transport(source)))?
        } else {
            endpoint
        };

        let connection = endpoint
            .connect()
            .await
            .map_err(|source| SparkError::new(SparkErrorKind::Transport(source)))?;
        let channel = ServiceBuilder::new().service(connection);

        // Headers come from the channel builder; its `Default` value is used
        // when `headers()` does not yield one.
        let interceptor =
            HeaderInterceptor::new(self.channel_builder.headers().unwrap_or_default());
        let grpc_client = SparkConnectServiceClient::with_interceptor(channel, interceptor);

        let spark_client = SparkClient::new(
            Arc::new(RwLock::new(grpc_client)),
            self.channel_builder.clone(),
        );
        Ok(SparkSession::new(spark_client))
    }
}
/// A session backed by a [`SparkClient`].
///
/// The session id is read from the client once, at construction time, and
/// stored alongside it.
#[derive(Debug, Clone)]
pub struct SparkSession {
    /// Client used for every request issued through this session.
    client: SparkClient,
    /// Session id copied from `client` when the session was created.
    session_id: String,
}
impl SparkSession {
    /// Wraps `client` in a session, storing the client's session id as a `String`.
    pub(crate) fn new(client: SparkClient) -> Self {
        let session_id = client.session_id().to_string();
        Self { client, session_id }
    }

    /// Returns an owned copy of the session id stored at construction.
    pub fn session_id(&self) -> String {
        self.session_id.clone()
    }

    /// Returns a clone of the underlying client.
    pub(crate) fn client(&self) -> SparkClient {
        self.client.clone()
    }

    /// Executes `query` as a SQL command with positional arguments `params`.
    ///
    /// The command is sent through the client; the relation extracted from the
    /// result is wrapped in a new root [`spark::Plan`], which can then be
    /// handed to [`SparkSession::collect`].
    ///
    /// # Errors
    ///
    /// Propagates any error from executing the plan or from extracting the
    /// relation out of the result.
    pub async fn sql(
        &self,
        query: &str,
        params: Vec<Literal>,
    ) -> Result<spark::Plan, SparkError> {
        let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand {
            sql: query.to_string(),
            args: Default::default(),
            pos_args: params,
        });
        let plan = spark::Plan {
            op_type: Some(spark::plan::OpType::Command(spark::Command {
                command_type: Some(sql_cmd),
            })),
        };

        let mut client = self.client();
        let result = client.execute_plan(plan).await?;
        Ok(spark::Plan {
            op_type: Some(spark::plan::OpType::Root(result.relation()?)),
        })
    }

    /// Starts a [`SqlQueryBuilder`] for `query` that borrows this session.
    pub fn query(&self, query: &str) -> SqlQueryBuilder<'_> {
        // `self` is already a reference; passing `&self` would create a
        // needless `&&SparkSession`.
        SqlQueryBuilder::new(self, query)
    }

    /// Executes `plan` and returns the record batches of the result.
    ///
    /// # Errors
    ///
    /// Propagates any error from executing the plan.
    pub async fn collect(&self, plan: spark::Plan) -> Result<Vec<RecordBatch>, SparkError> {
        let mut client = self.client();
        Ok(client.execute_plan(plan).await?.batches())
    }

    /// Sends an interrupt request of type `All` and returns the interrupted
    /// ids reported in the response.
    ///
    /// # Errors
    ///
    /// Propagates any error from the interrupt request.
    pub async fn interrupt_all(&self) -> Result<Vec<String>, SparkError> {
        Ok(self
            .client()
            .interrupt(spark::interrupt_request::InterruptType::All, None)
            .await?
            .interrupted_ids())
    }

    /// Sends an interrupt request of type `OperationId` for `op_id` and
    /// returns the interrupted ids reported in the response.
    ///
    /// # Errors
    ///
    /// Propagates any error from the interrupt request.
    pub async fn interrupt_operation(&self, op_id: &str) -> Result<Vec<String>, SparkError> {
        Ok(self
            .client()
            .interrupt(
                spark::interrupt_request::InterruptType::OperationId,
                Some(op_id.to_string()),
            )
            .await?
            .interrupted_ids())
    }

    /// Asks the server for its Spark version via an analyze request.
    ///
    /// # Errors
    ///
    /// Propagates any error from the analyze request or from reading the
    /// version out of the response.
    pub async fn version(&self) -> Result<String, SparkError> {
        let version = spark::analyze_plan_request::Analyze::SparkVersion(
            spark::analyze_plan_request::SparkVersion {},
        );
        // Use the same accessor as the other methods for consistency.
        let mut client = self.client();
        Ok(client.analyze(version).await?.spark_version()?)
    }
}
#[cfg(test)]
mod tests {
    use crate::test_utils::test_utils::setup_session;
    use crate::SparkError;
    use arrow::array::{Int32Array, StringArray};
    use regex::Regex;

    /// Building a session through the shared test helper succeeds.
    #[tokio::test]
    async fn test_session_create() {
        let outcome = setup_session().await;
        assert!(outcome.is_ok());
    }

    /// The reported version consists of three dot-separated numeric components.
    #[tokio::test]
    async fn test_session_version() -> Result<(), SparkError> {
        let session = setup_session().await?;
        let reported = session.version().await?;

        let pattern = Regex::new(r"^\d+\.\d+\.\d+$").unwrap();
        assert!(pattern.is_match(&reported), "Version {} invalid", reported);
        Ok(())
    }

    /// A literal SELECT yields one batch with one row, two columns, and the
    /// expected integer in the first column.
    #[tokio::test]
    async fn test_sql() {
        let spark = setup_session()
            .await
            .expect("Failed to create Spark session");

        let plan = spark
            .sql("SELECT 1 AS id, 'hello' AS text", Vec::new())
            .await
            .expect("SQL query failed");
        let collected = spark
            .collect(plan)
            .await
            .expect("Failed to collect batches");

        assert_eq!(collected.len(), 1, "Expected exactly one RecordBatch");
        let only_batch = &collected[0];
        assert_eq!(only_batch.num_rows(), 1, "Expected one row");
        assert_eq!(only_batch.num_columns(), 2, "Expected two columns");

        let ids = only_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("Column 0 should be an Int32Array");
        assert_eq!(ids.value(0), 1);
    }

    /// Values bound through the query builder come back in the result columns.
    #[tokio::test]
    async fn test_sql_query_builder_bind() -> Result<(), SparkError> {
        let spark = setup_session().await?;

        let collected = spark
            .query("SELECT ? AS id, ? AS text")
            .bind(42_i32)
            .bind("world")
            .execute()
            .await?;

        assert_eq!(collected.len(), 1);
        let only_batch = &collected[0];
        assert_eq!(only_batch.num_rows(), 1);
        assert_eq!(only_batch.num_columns(), 2);

        let ids = only_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 42);

        let texts = only_batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(texts.value(0), "world");

        Ok(())
    }
}