mod builder;
mod error;
mod handlers;
mod middleware;
pub use self::builder::ChannelBuilder;
pub(crate) use self::error::{ClientError, ClientErrorKind};
pub use self::middleware::HeaderInterceptor;
use self::handlers::{AnalyzeHandler, ExecuteHandler, InterruptHandler};
use crate::spark;
use crate::spark::spark_connect_service_client::SparkConnectServiceClient;
use crate::spark::execute_plan_response::ResponseType;
use arrow::array::RecordBatch;
use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::codec::Streaming;
use tonic::transport::Channel;
use uuid;
/// Transport type used by the gRPC stub: a tonic `Channel` wrapped in an
/// `InterceptedService` whose interceptor is [`HeaderInterceptor`].
type InterceptedChannel = tonic::service::interceptor::InterceptedService<Channel, HeaderInterceptor>;
/// Client-side state for talking to a Spark Connect service over gRPC.
///
/// Holds the shared gRPC stub, the identifiers sent with every request, and
/// the handlers that keep the results of the most recent analyze, execute and
/// interrupt calls.
#[derive(Clone, Debug)]
pub struct SparkClient {
/// Builder this client was created from; its `user_agent` is sent as the
/// `client_type` of every request.
pub(crate) builder: ChannelBuilder,
/// Shared gRPC stub; each RPC in this file takes the write lock for the
/// duration of the call.
stub: Arc<RwLock<SparkConnectServiceClient<InterceptedChannel>>>,
/// User context attached to every request (populated in `new`).
user_context: Option<spark::UserContext>,
/// Whether execute requests ask for a reattachable execution; set to
/// `true` in `new`.
use_reattachable_execute: bool,
/// Session id sent with every request and checked against responses.
session_id: String,
/// Operation id of the current/last execute request; generated per request
/// and overwritten by the id carried in each execute response.
operation_id: Option<String>,
/// Response id of the last execute response that was handled.
response_id: Option<String>,
/// Result of the most recent analyze call.
handler_analyze: AnalyzeHandler,
/// Accumulated results of the most recent execute call.
handler_execute: ExecuteHandler,
/// Result of the most recent interrupt call.
handler_interrupt: InterruptHandler,
}
impl SparkClient {
/// Creates a client around an existing gRPC `stub` and the `builder` it was
/// configured from.
///
/// The session id is taken from `builder.session_id`; the user context uses
/// `builder.user_id` (or an empty string when unset) for both the user id and
/// the user name. Reattachable execution is enabled and all handlers start
/// out in their default (empty) state.
pub(crate) fn new(
    stub: Arc<RwLock<SparkConnectServiceClient<InterceptedChannel>>>,
    builder: ChannelBuilder,
) -> Self {
    let session_id = builder.session_id.to_string();
    let user = builder.user_id.clone().unwrap_or_default();
    let user_context = spark::UserContext {
        user_id: user.clone(),
        user_name: user,
        extensions: Vec::new(),
    };

    Self {
        builder,
        stub,
        user_context: Some(user_context),
        use_reattachable_execute: true,
        session_id,
        operation_id: None,
        response_id: None,
        handler_analyze: AnalyzeHandler::default(),
        handler_execute: ExecuteHandler::default(),
        handler_interrupt: InterruptHandler::default(),
    }
}
/// Returns an owned copy of the session id this client sends with requests.
pub(crate) fn session_id(&self) -> String {
    self.session_id.clone()
}
/// Returns the Spark version stored by the most recent analyze call.
///
/// # Errors
/// Returns `ClientErrorKind::AnalyzeResponseNotFound` when the analyze
/// handler holds no version.
pub(crate) fn spark_version(&self) -> Result<String, ClientError> {
    match self.handler_analyze.spark_version.clone() {
        Some(version) => Ok(version),
        None => Err(ClientError::new(ClientErrorKind::AnalyzeResponseNotFound(
            "Spark version response is empty".to_string(),
        ))),
    }
}
/// Returns a copy of the ids reported by the most recent interrupt call.
pub(crate) fn interrupted_ids(&self) -> Vec<String> {
    self.handler_interrupt.interrupted_ids.clone()
}
/// Returns the relation stored by the most recent execute call.
///
/// # Errors
/// Returns `ClientErrorKind::AnalyzeResponseNotFound` when the execute
/// handler holds no relation.
pub(crate) fn relation(&self) -> Result<spark::Relation, ClientError> {
    match self.handler_execute.relation.clone() {
        Some(relation) => Ok(relation),
        None => Err(ClientError::new(ClientErrorKind::AnalyzeResponseNotFound(
            "relation response is empty".to_string(),
        ))),
    }
}
/// Returns a copy of the record batches collected by the most recent
/// execute call.
pub(crate) fn batches(&self) -> Vec<RecordBatch> {
    self.handler_execute.batches.clone()
}
/// Checks that `session_id` (taken from a server response) matches the
/// session id this client sends with its requests.
///
/// # Errors
/// Returns `ClientErrorKind::SessionIDMismatch` carrying both ids when they
/// differ. The reported `client_session_id` is the very value that was
/// compared (`self.session_id`), so the error always describes the actual
/// mismatch even if the builder's session id differs from the client's.
fn validate_session(&self, session_id: &str) -> Result<(), ClientError> {
    // Compare by reference; no need to allocate a copy of our own id.
    if self.session_id.as_str() == session_id {
        return Ok(());
    }
    Err(ClientError::new(ClientErrorKind::SessionIDMismatch {
        client_session_id: self.session_id.clone(),
        request_session_id: session_id.to_string(),
    }))
}
/// Sends an `AnalyzePlanRequest` for `analyze` and stores the outcome in the
/// analyze handler.
///
/// The write lock on the stub is held only while the RPC is in flight; it is
/// released before the response is processed.
///
/// # Errors
/// Returns `ClientErrorKind::AnalyzeRequest` (with the gRPC status and the
/// request) when the RPC fails, or whatever `handle_analyze_response`
/// reports for the response.
pub(crate) async fn analyze(
    &mut self,
    analyze: spark::analyze_plan_request::Analyze,
) -> Result<&mut Self, ClientError> {
    let request = spark::AnalyzePlanRequest {
        session_id: self.session_id(),
        user_context: self.user_context.clone(),
        client_type: self.builder.user_agent.clone(),
        analyze: Some(analyze),
    };

    let response = {
        let mut stub = self.stub.write().await;
        match stub.analyze_plan(request.clone()).await {
            Ok(reply) => reply.into_inner(),
            Err(status) => {
                return Err(ClientError::new(ClientErrorKind::AnalyzeRequest {
                    status,
                    request,
                }));
            }
        }
    };

    self.handle_analyze_response(response)?;
    Ok(self)
}
/// Validates an analyze response and records its payload.
///
/// The analyze handler is reset to its default state first, so only the
/// result of this response is visible afterwards. Schema and Spark-version
/// results are stored; a response without a result leaves the handler empty.
///
/// # Errors
/// Returns a session-mismatch error when the response belongs to another
/// session, or `ClientErrorKind::Unimplemented` for any other result kind.
fn handle_analyze_response(
    &mut self,
    resp: spark::AnalyzePlanResponse,
) -> Result<(), ClientError> {
    self.validate_session(&resp.session_id)?;
    self.handler_analyze = AnalyzeHandler::default();

    match resp.result {
        None => {}
        Some(spark::analyze_plan_response::Result::Schema(schema)) => {
            self.handler_analyze.schema = schema.schema;
        }
        Some(spark::analyze_plan_response::Result::SparkVersion(spark_version)) => {
            self.handler_analyze.spark_version = Some(spark_version.version);
        }
        Some(other) => {
            return Err(ClientError::new(ClientErrorKind::Unimplemented(format!(
                "Handling of analyze response {other:?} not implemented!"
            ))));
        }
    }

    Ok(())
}
/// Sends an `InterruptRequest` of the given `interrupt_type` and stores the
/// interrupted ids from the response in the interrupt handler.
///
/// `id_or_tag` is only read for `InterruptType::OperationId`, where it is
/// the operation id to interrupt.
///
/// # Errors
/// * `ClientErrorKind::Unimplemented` for `InterruptType::Tag`.
/// * `ClientErrorKind::UnspecifiedInterruptRequest` for
///   `InterruptType::Unspecified`.
/// * `ClientErrorKind::InterruptRequest` when the RPC fails.
///
/// # Panics
/// Panics when `interrupt_type` is `OperationId` and `id_or_tag` is `None`.
// NOTE(review): the panic above is a recoverable caller mistake; consider a
// dedicated error kind instead of `expect`.
pub(crate) async fn interrupt(
    &mut self,
    interrupt_type: spark::interrupt_request::InterruptType,
    id_or_tag: Option<String>,
) -> Result<&mut Self, ClientError> {
    // Decide the optional payload first; unsupported kinds bail out early.
    let interrupt = match interrupt_type {
        spark::interrupt_request::InterruptType::Unspecified => {
            return Err(ClientError::new(
                ClientErrorKind::UnspecifiedInterruptRequest,
            ));
        }
        spark::interrupt_request::InterruptType::Tag => {
            return Err(ClientError::new(ClientErrorKind::Unimplemented(
                "Tag interrupts are not implemented!".to_string(),
            )));
        }
        spark::interrupt_request::InterruptType::All => None,
        spark::interrupt_request::InterruptType::OperationId => {
            let op_id = id_or_tag.expect("Operation ID can not be empty");
            Some(spark::interrupt_request::Interrupt::OperationId(op_id))
        }
    };

    let request = spark::InterruptRequest {
        session_id: self.session_id(),
        user_context: self.user_context.clone(),
        client_type: self.builder.user_agent.clone(),
        interrupt_type: interrupt_type.into(),
        interrupt,
    };

    let response = {
        let mut stub = self.stub.write().await;
        match stub.interrupt(request.clone()).await {
            Ok(reply) => reply.into_inner(),
            Err(status) => {
                return Err(ClientError::new(ClientErrorKind::InterruptRequest {
                    status,
                    request,
                }));
            }
        }
    };

    // Start from a fresh handler so nothing from an earlier call survives.
    let mut handler = InterruptHandler::default();
    handler.interrupted_ids = response.interrupted_ids;
    self.handler_interrupt = handler;
    Ok(self)
}
/// Executes `plan` and collects the streamed responses into the execute
/// handler.
///
/// A fresh operation id is generated for the request. The execute handler is
/// reset once the RPC has been accepted, the response stream is drained via
/// `process_stream`, and — when reattachable execution is enabled and the
/// result was reported complete — the execution is released with
/// `release_all`.
///
/// # Errors
/// Returns `ClientErrorKind::ExecutePlanRequest` when the RPC fails, or any
/// error produced while processing the stream or releasing the execution.
pub(crate) async fn execute_plan(
    &mut self,
    plan: spark::Plan
) -> Result<&mut Self, ClientError> {
    let request = spark::ExecutePlanRequest {
        plan: Some(plan),
        ..self.new_execute_plan_request()
    };

    let mut stream = {
        let mut stub = self.stub.write().await;
        match stub.execute_plan(request.clone()).await {
            Ok(reply) => reply.into_inner(),
            Err(status) => {
                return Err(ClientError::new(ClientErrorKind::ExecutePlanRequest {
                    status,
                    request,
                }));
            }
        }
    };

    self.handler_execute = ExecuteHandler::default();
    self.process_stream(&mut stream).await?;

    let finished = self.handler_execute.result_complete;
    if self.use_reattachable_execute && finished {
        self.release_all().await?;
    }
    Ok(self)
}
/// Builds an `ExecutePlanRequest` without a plan and records its freshly
/// generated operation id on the client.
///
/// The operation id is a random (v4) UUID. The request carries a single
/// request option: `ReattachOptions` with `reattachable` set from
/// `use_reattachable_execute`.
fn new_execute_plan_request(&mut self) -> spark::ExecutePlanRequest {
    let operation_id = uuid::Uuid::new_v4().to_string();
    self.operation_id = Some(operation_id.clone());

    let reattach_options = spark::ReattachOptions {
        reattachable: self.use_reattachable_execute,
    };
    let reattach_option = spark::execute_plan_request::RequestOption {
        request_option: Some(
            spark::execute_plan_request::request_option::RequestOption::ReattachOptions(
                reattach_options,
            ),
        ),
    };

    spark::ExecutePlanRequest {
        session_id: self.session_id(),
        user_context: self.user_context.clone(),
        operation_id: Some(operation_id),
        plan: None,
        client_type: self.builder.user_agent.clone(),
        request_options: vec![reattach_option],
        tags: Vec::new(),
    }
}
/// Validates one streamed execute response and folds it into the execute
/// handler.
///
/// The client's operation id and response id are overwritten with the ones
/// carried by the response. Arrow batches are deserialized and appended (and
/// their row count added to the running total), a SQL command result stores
/// its relation, and a result-complete marker sets the completion flag. A
/// response without a payload only updates the ids.
///
/// # Errors
/// Returns a session-mismatch error, a deserialization error from
/// `crate::io::deserialize`, or `ClientErrorKind::Unimplemented` for any
/// other response kind.
fn handle_execute_response(
    &mut self,
    resp: spark::ExecutePlanResponse
) -> Result<(), ClientError> {
    self.validate_session(&resp.session_id)?;
    self.operation_id = Some(resp.operation_id);
    self.response_id = Some(resp.response_id);

    match resp.response_type {
        None => {}
        Some(ResponseType::ArrowBatch(arrow_batch)) => {
            let (batches, total_count) = crate::io::deserialize(
                arrow_batch.data.as_slice(),
                arrow_batch.row_count,
            )?;
            self.handler_execute.batches.extend(batches);
            self.handler_execute.total_count += total_count;
        }
        Some(ResponseType::SqlCommandResult(sql_cmd)) => {
            self.handler_execute.relation = sql_cmd.relation;
        }
        Some(ResponseType::ResultComplete(_)) => {
            self.handler_execute.result_complete = true;
        }
        Some(other) => {
            return Err(ClientError::new(ClientErrorKind::Unimplemented(
                format!("Handling of plan response {other:?} not implemented!"),
            )));
        }
    }

    Ok(())
}
/// Reattaches to the current operation and keeps draining its responses.
///
/// The request carries the client's current operation id and the id of the
/// last response that was handled. After the new stream has been processed,
/// the execution is released with `release_all` when reattachable execution
/// is enabled and the result was reported complete.
///
/// # Errors
/// Returns `ClientErrorKind::ReattachExecuteRequest` when the RPC fails, or
/// any error from processing the stream or releasing the execution.
///
/// # Panics
/// Panics when no operation id has been recorded on the client.
async fn reattach(&mut self) -> Result<(), ClientError> {
    let operation_id = self.operation_id.clone().unwrap();
    let request = spark::ReattachExecuteRequest {
        session_id: self.session_id(),
        user_context: self.user_context.clone(),
        operation_id,
        client_type: self.builder.user_agent.clone(),
        last_response_id: self.response_id.clone(),
    };

    let mut stream = {
        let mut stub = self.stub.write().await;
        match stub.reattach_execute(request.clone()).await {
            Ok(reply) => reply.into_inner(),
            Err(status) => {
                return Err(ClientError::new(
                    ClientErrorKind::ReattachExecuteRequest { status, request },
                ));
            }
        }
    };

    self.process_stream(&mut stream).await?;

    let finished = self.handler_execute.result_complete;
    if self.use_reattachable_execute && finished {
        self.release_all().await?;
    }
    Ok(())
}
/// Drains `stream`, feeding every response to `handle_execute_response`.
///
/// * On each message: the response is moved into the handler (no copy of
///   the payload is made).
/// * On end of stream: if reattachable execution is enabled and the result
///   has not been reported complete, the operation is reattached and the
///   new stream is processed recursively; then this call returns.
/// * On a stream error: if reattachable execution is enabled and at least
///   one response id has been recorded, responses up to that id are
///   released first; then the stream error is returned.
///
/// # Errors
/// Returns any error from handling a response or from reattaching, or
/// `ClientErrorKind::Stream` for a failed stream. If the release issued
/// after a stream error fails, that release error is returned instead of
/// the stream error.
async fn process_stream(
    &mut self, stream: &mut Streaming<spark::ExecutePlanResponse>
) -> Result<(), ClientError> {
    loop {
        match stream.message().await {
            Ok(Some(resp)) => self.handle_execute_response(resp)?,
            Ok(None) => {
                if self.use_reattachable_execute && !self.handler_execute.result_complete {
                    // `reattach` calls back into `process_stream`; boxing the
                    // future keeps this recursive async call finitely sized.
                    Box::pin(self.reattach()).await?;
                }
                return Ok(());
            }
            Err(status) => {
                if self.use_reattachable_execute && self.response_id.is_some() {
                    self.release_until().await?;
                }
                return Err(ClientError::new(ClientErrorKind::Stream(status)));
            }
        }
    }
}
/// Sends a release request covering responses up to the last response id
/// recorded on the client.
///
/// # Panics
/// Panics when no response id has been recorded yet.
async fn release_until(&mut self) -> Result<(), ClientError> {
    let response_id = self.response_id.clone().unwrap();
    let release = spark::release_execute_request::Release::ReleaseUntil(
        spark::release_execute_request::ReleaseUntil { response_id },
    );
    self.release_execute(release).await
}
/// Sends a release-all request for the client's current operation.
async fn release_all(&mut self) -> Result<(), ClientError> {
    let release = spark::release_execute_request::Release::ReleaseAll(
        spark::release_execute_request::ReleaseAll {},
    );
    self.release_execute(release).await
}
/// Sends a `ReleaseExecuteRequest` with the given `release` kind for the
/// client's current operation. The response body is discarded.
///
/// # Errors
/// Returns `ClientErrorKind::ReleaseExecuteRequest` (with the gRPC status
/// and the request) when the RPC fails.
///
/// # Panics
/// Panics when no operation id has been recorded on the client.
async fn release_execute(
    &mut self,
    release: spark::release_execute_request::Release,
) -> Result<(), ClientError> {
    let mut stub = self.stub.write().await;

    let operation_id = self.operation_id.clone().unwrap();
    let request = spark::ReleaseExecuteRequest {
        session_id: self.session_id(),
        user_context: self.user_context.clone(),
        operation_id,
        client_type: self.builder.user_agent.clone(),
        release: Some(release),
    };

    match stub.release_execute(request.clone()).await {
        Ok(_reply) => Ok(()),
        Err(status) => Err(ClientError::new(
            ClientErrorKind::ReleaseExecuteRequest { status, request },
        )),
    }
}
}
#[cfg(test)]
mod tests {
    use crate::spark;
    use crate::test_utils::test_utils::setup_session;

    /// An analyze call made with a session id that was overwritten on the
    /// client must come back as an error.
    #[tokio::test]
    async fn test_validate_session_error() {
        let session = setup_session().await.expect("Failed to create Spark session");

        let mut bad_client = session.client().clone();
        bad_client.session_id = "invalid-session-id".to_string();

        let version_request = spark::analyze_plan_request::Analyze::SparkVersion(
            spark::analyze_plan_request::SparkVersion {},
        );
        let outcome = bad_client.analyze(version_request).await;

        assert!(
            outcome.is_err(),
            "Expected an error due to invalid session ID"
        );
    }

    /// An "interrupt all" request succeeds and the returned client still
    /// reports the session's id.
    #[tokio::test]
    async fn test_interrupt_all_request() {
        let session = setup_session().await.expect("Failed to create Spark session");
        let mut client = session.client();

        let interrupt_all = spark::interrupt_request::InterruptType::All;
        let interrupted = client.interrupt(interrupt_all, None).await.unwrap();

        assert_eq!(interrupted.session_id(), session.session_id());
    }
}