use std::sync::Arc;
use tonic::transport::{Channel, Endpoint};
use tracing::{debug, info, warn};
use crate::client::error::{Error, ErrorKind, Result};
use super::config::GrpcConfig;
use super::error::from_grpc_status;
use super::executor::{GrpcChunkStream, GrpcQueryExecutor};
use super::params::{ParameterStyle, QueryParameters};
use super::proto::hyper_service::query_param::TransferMode;
use super::proto::{
AttachedDatabase, CancelQueryParam, HyperServiceClient, OutputFormat, QueryParam,
};
use super::result::GrpcQueryResult;
/// Asynchronous gRPC client for the Hyper query service.
///
/// Holds a tonic [`Channel`] (cheap to clone; a fresh `HyperServiceClient`
/// is built on top of it per request) together with the [`GrpcConfig`]
/// used to establish it.
#[derive(Debug)]
pub struct GrpcClient {
// Shared transport channel; cloned into per-request service clients.
channel: Channel,
// Connection settings, headers, limits, and transfer-mode defaults.
config: GrpcConfig,
}
impl GrpcClient {
pub async fn connect(config: GrpcConfig) -> Result<Self> {
info!(endpoint = %config.endpoint, "Connecting to Hyper via gRPC");
let endpoint = Endpoint::from_shared(config.endpoint.clone())
.map_err(|e| Error::new(ErrorKind::Config, format!("Invalid gRPC endpoint: {e}")))?;
let endpoint = endpoint
.connect_timeout(config.connect_timeout)
.timeout(config.request_timeout);
let endpoint = if config.use_tls {
let tls_config = tonic::transport::ClientTlsConfig::new().with_enabled_roots();
endpoint.tls_config(tls_config).map_err(|e| {
Error::new(ErrorKind::Config, format!("TLS configuration error: {e}"))
})?
} else {
endpoint
};
let channel = endpoint.connect().await.map_err(|e| {
debug!("gRPC connection error details: {:?}", e);
Error::new(
ErrorKind::Connection,
format!("Failed to connect to gRPC endpoint: {e} (details: {e:?})"),
)
})?;
debug!("gRPC channel established");
Ok(GrpcClient { channel, config })
}
/// Returns a reference to the underlying tonic [`Channel`].
#[must_use]
pub fn channel(&self) -> &Channel {
&self.channel
}
/// Returns the configuration this client was connected with.
// `#[must_use]` added for consistency with the sibling `channel()` getter.
#[must_use]
pub fn config(&self) -> &GrpcConfig {
    &self.config
}
/// Execute `sql` with the default output format (Arrow IPC) and the
/// client's configured transfer mode, collecting the full result.
pub async fn execute_query(&mut self, sql: &str) -> Result<GrpcQueryResult> {
    let mode = self.config.transfer_mode;
    self.execute_query_with_options(sql, OutputFormat::ArrowIpc, mode)
        .await
}
pub async fn execute_query_to_arrow(&mut self, sql: &str) -> Result<bytes::Bytes> {
let result = self.execute_query(sql).await?;
Ok(result.into_arrow_data())
}
/// Execute a parameterized query with the default output format
/// (Arrow IPC) and the client's configured transfer mode.
pub async fn execute_query_with_params(
    &mut self,
    sql: &str,
    params: QueryParameters,
    style: ParameterStyle,
) -> Result<GrpcQueryResult> {
    let mode = self.config.transfer_mode;
    self.execute_query_with_params_and_options(sql, params, style, OutputFormat::ArrowIpc, mode)
        .await
}
pub async fn execute_query_with_params_to_arrow(
&mut self,
sql: &str,
params: QueryParameters,
style: ParameterStyle,
) -> Result<bytes::Bytes> {
let result = self.execute_query_with_params(sql, params, style).await?;
Ok(result.into_arrow_data())
}
/// Execute a parameterized query with explicit output format and
/// transfer mode, draining every partial result before returning.
///
/// Later partial results override earlier metadata (`query_id`,
/// `schema`, `rows_affected`) when they carry a value.
///
/// # Errors
/// Propagates any error from `GrpcQueryExecutor::execute` or
/// `next_result` (transport and server-side failures).
pub async fn execute_query_with_params_and_options(
    &mut self,
    sql: &str,
    params: QueryParameters,
    style: ParameterStyle,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcQueryResult> {
    debug!(
        sql = %sql,
        param_style = ?style,
        format = ?output_format,
        mode = ?transfer_mode,
        "Executing parameterized query"
    );
    let query_param = QueryParam {
        query: sql.to_string(),
        databases: self.build_attached_databases(),
        output_format: output_format.into(),
        settings: self.config.settings.clone(),
        transfer_mode: transfer_mode.into(),
        param_style: i32::from(style),
        parameters: Some(params.into_proto()),
        result_range: None,
        query_row_limit: None,
    };
    let headers = self.build_headers();
    let client = HyperServiceClient::new(self.channel.clone())
        .max_decoding_message_size(self.config.max_decoding_message_size)
        .max_encoding_message_size(self.config.max_encoding_message_size);
    let mut executor = GrpcQueryExecutor::new(client, headers, transfer_mode);
    executor.execute(query_param).await?;
    let mut final_result = GrpcQueryResult::default();
    // Drain partial results until the executor signals completion, either
    // via an explicit `is_complete` flag or by running out of results.
    while let Some(mut partial) = executor.next_result().await? {
        while let Some(chunk) = partial.take_chunk() {
            final_result.chunks.push_back(chunk);
        }
        if partial.query_id.is_some() {
            final_result.query_id = partial.query_id;
        }
        if partial.schema.is_some() {
            final_result.schema = partial.schema;
        }
        if partial.rows_affected.is_some() {
            final_result.rows_affected = partial.rows_affected;
        }
        if partial.is_complete {
            break;
        }
    }
    // Both loop exits mean the stream finished, so the result is always
    // complete here. (The previous `chunks.is_empty() && !is_complete`
    // "No result from query" error path was unreachable and was removed.)
    final_result.is_complete = true;
    Ok(final_result)
}
/// Execute an unparameterized query with explicit output format and
/// transfer mode, draining every partial result before returning.
///
/// Later partial results override earlier metadata (`query_id`,
/// `schema`, `rows_affected`) when they carry a value.
///
/// # Errors
/// Propagates any error from `GrpcQueryExecutor::execute` or
/// `next_result` (transport and server-side failures).
pub async fn execute_query_with_options(
    &mut self,
    sql: &str,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcQueryResult> {
    debug!(sql = %sql, format = ?output_format, mode = ?transfer_mode, "Executing query");
    let query_param = QueryParam {
        query: sql.to_string(),
        databases: self.build_attached_databases(),
        output_format: output_format.into(),
        settings: self.config.settings.clone(),
        transfer_mode: transfer_mode.into(),
        // Default parameter style; no parameters are attached.
        param_style: 0,
        parameters: None,
        result_range: None,
        query_row_limit: None,
    };
    let headers = self.build_headers();
    let client = HyperServiceClient::new(self.channel.clone())
        .max_decoding_message_size(self.config.max_decoding_message_size)
        .max_encoding_message_size(self.config.max_encoding_message_size);
    let mut executor = GrpcQueryExecutor::new(client, headers, transfer_mode);
    executor.execute(query_param).await?;
    let mut final_result = GrpcQueryResult::default();
    // Drain partial results until the executor signals completion, either
    // via an explicit `is_complete` flag or by running out of results.
    while let Some(mut partial) = executor.next_result().await? {
        while let Some(chunk) = partial.take_chunk() {
            final_result.chunks.push_back(chunk);
        }
        if partial.query_id.is_some() {
            final_result.query_id = partial.query_id;
        }
        if partial.schema.is_some() {
            final_result.schema = partial.schema;
        }
        if partial.rows_affected.is_some() {
            final_result.rows_affected = partial.rows_affected;
        }
        if partial.is_complete {
            break;
        }
    }
    // Both loop exits mean the stream finished, so the result is always
    // complete here. (The previous `chunks.is_empty() && !is_complete`
    // "No result from query" error path was unreachable and was removed.)
    final_result.is_complete = true;
    Ok(final_result)
}
/// Start a streaming query with the default output format (Arrow IPC)
/// and the configured transfer mode; chunks are pulled lazily from the
/// returned stream.
pub async fn execute_query_stream(&mut self, sql: &str) -> Result<GrpcChunkStream> {
    let mode = self.config.transfer_mode;
    self.execute_query_stream_with_options(sql, OutputFormat::ArrowIpc, mode)
        .await
}
/// Start a streaming query with explicit output format and transfer
/// mode. The query is submitted here; the returned [`GrpcChunkStream`]
/// yields result chunks on demand.
pub async fn execute_query_stream_with_options(
    &mut self,
    sql: &str,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcChunkStream> {
    debug!(sql = %sql, format = ?output_format, mode = ?transfer_mode, "Executing streaming query");
    let request = QueryParam {
        query: sql.to_string(),
        databases: self.build_attached_databases(),
        output_format: output_format.into(),
        settings: self.config.settings.clone(),
        transfer_mode: transfer_mode.into(),
        param_style: 0,
        parameters: None,
        result_range: None,
        query_row_limit: None,
    };
    let service = HyperServiceClient::new(self.channel.clone())
        .max_decoding_message_size(self.config.max_decoding_message_size)
        .max_encoding_message_size(self.config.max_encoding_message_size);
    let mut executor = GrpcQueryExecutor::new(service, self.build_headers(), transfer_mode);
    executor.execute(request).await?;
    Ok(GrpcChunkStream::new(executor))
}
/// Request cancellation of an in-flight query identified by `query_id`.
///
/// The configured headers and an `x-hyperdb-query-id` metadata entry are
/// attached to the `CancelQuery` RPC on a best-effort basis: entries
/// that fail to parse as gRPC metadata are logged and dropped rather
/// than failing the cancel.
///
/// # Errors
/// Returns the error mapped from the gRPC status when the RPC fails.
pub async fn cancel_query(&mut self, query_id: &str) -> Result<()> {
debug!(query_id = %query_id, "Cancelling gRPC query");
let param = CancelQueryParam {
query_id: query_id.to_string(),
};
let mut request = tonic::Request::new(param);
// Attach the client's standard headers as gRPC metadata. Keys and
// values are validated by `parse`; invalid pairs are skipped with a
// warning instead of aborting the cancel.
for (key, value) in self.build_headers() {
match (
key.parse::<tonic::metadata::MetadataKey<_>>(),
value.parse(),
) {
(Ok(k), Ok(v)) => {
request.metadata_mut().insert(k, v);
}
(key_res, value_res) => {
warn!(
target: "hyperdb_api_core::client",
query_id = %query_id,
header_key = %key,
key_parse_ok = key_res.is_ok(),
value_parse_ok = value_res.is_ok(),
"cancel: header parse failed, dropping header from cancel request",
);
}
}
}
// Also carry the query id as a metadata header; per the warning below,
// a parse failure is non-fatal and routing may fall back to the
// payload-based lookup.
match query_id.parse() {
Ok(value) => {
request.metadata_mut().insert("x-hyperdb-query-id", value);
}
Err(e) => {
warn!(
target: "hyperdb_api_core::client",
query_id = %query_id,
error = %e,
"cancel: x-hyperdb-query-id header parse failed; \
cancel routing may fall back to payload-based lookup",
);
}
}
// Fresh service client over the shared channel, with the same message
// size limits used for queries.
let mut client = HyperServiceClient::new(self.channel.clone())
.max_decoding_message_size(self.config.max_decoding_message_size)
.max_encoding_message_size(self.config.max_encoding_message_size);
client
.cancel_query(request)
.await
.map_err(from_grpc_status)?;
info!(query_id = %query_id, "gRPC query cancelled");
Ok(())
}
/// Close the client.
///
/// Currently only logs and returns `Ok(())`; the channel is released
/// when `self` is dropped. Kept `async` so a real shutdown handshake
/// can be added later without breaking callers.
#[expect(
clippy::unused_async,
reason = "async fn retained for API symmetry; callers await regardless of whether the current body is synchronous"
)]
pub async fn close(self) -> Result<()> {
debug!("Closing gRPC connection");
Ok(())
}
/// Build the list of databases to attach for a query.
///
/// Returns one attachment (with an empty alias) for the configured
/// database path, or an empty list when no database is configured.
fn build_attached_databases(&self) -> Vec<AttachedDatabase> {
    match &self.config.database {
        Some(db_path) => {
            debug!(db_path = %db_path, "Attaching database for query");
            // NOTE(review): the original branched on
            // `db_path.starts_with('[')`, but both arms were identical,
            // so the dead conditional was removed.
            vec![AttachedDatabase {
                path: db_path.clone(),
                alias: String::new(),
            }]
        }
        None => {
            debug!("No database configured on gRPC client — query will run without attachment");
            vec![]
        }
    }
}
fn build_headers(&self) -> Vec<(String, String)> {
let mut headers: Vec<(String, String)> = self
.config
.headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
if let Some(ref db) = self.config.database {
headers.push(("x-hyper-database".to_string(), db.clone()));
}
headers
}
}
/// Blocking facade over [`GrpcClient`].
///
/// Owns a current-thread Tokio runtime and drives the async client via
/// `block_on`. The runtime is stored in an `Arc` so it can be shared
/// with streams handed out by `execute_query_stream`.
#[derive(Debug)]
pub struct GrpcClientSync {
inner: GrpcClient,
runtime: Arc<tokio::runtime::Runtime>,
}
impl GrpcClientSync {
    /// Connect synchronously: builds a current-thread Tokio runtime and
    /// blocks on [`GrpcClient::connect`].
    ///
    /// # Errors
    /// `ErrorKind::Other` when the runtime cannot be created, plus any
    /// error from the underlying async connect.
    pub fn connect(config: GrpcConfig) -> Result<Self> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| {
                Error::new(
                    ErrorKind::Other,
                    format!("Failed to create Tokio runtime: {e}"),
                )
            })?;
        let inner = runtime.block_on(GrpcClient::connect(config))?;
        Ok(Self {
            inner,
            runtime: Arc::new(runtime),
        })
    }
    /// Blocking wrapper around [`GrpcClient::execute_query`].
    pub fn execute_query(&mut self, sql: &str) -> Result<GrpcQueryResult> {
        let fut = self.inner.execute_query(sql);
        self.runtime.block_on(fut)
    }
    /// Blocking wrapper around [`GrpcClient::execute_query_to_arrow`].
    pub fn execute_query_to_arrow(&mut self, sql: &str) -> Result<bytes::Bytes> {
        let fut = self.inner.execute_query_to_arrow(sql);
        self.runtime.block_on(fut)
    }
    /// Start a streaming query and wrap the stream so chunks can be
    /// pulled synchronously; the stream shares this client's runtime.
    pub fn execute_query_stream(&mut self, sql: &str) -> Result<GrpcChunkStreamSync> {
        let fut = self.inner.execute_query_stream(sql);
        let stream = self.runtime.block_on(fut)?;
        Ok(GrpcChunkStreamSync {
            inner: stream,
            runtime: Arc::clone(&self.runtime),
        })
    }
    /// Blocking wrapper around [`GrpcClient::execute_query_with_params`].
    pub fn execute_query_with_params(
        &mut self,
        sql: &str,
        params: QueryParameters,
        style: ParameterStyle,
    ) -> Result<GrpcQueryResult> {
        let fut = self.inner.execute_query_with_params(sql, params, style);
        self.runtime.block_on(fut)
    }
    /// Blocking wrapper around
    /// [`GrpcClient::execute_query_with_params_to_arrow`].
    pub fn execute_query_with_params_to_arrow(
        &mut self,
        sql: &str,
        params: QueryParameters,
        style: ParameterStyle,
    ) -> Result<bytes::Bytes> {
        let fut = self
            .inner
            .execute_query_with_params_to_arrow(sql, params, style);
        self.runtime.block_on(fut)
    }
    /// Blocking wrapper around [`GrpcClient::cancel_query`].
    pub fn cancel_query(&mut self, query_id: &str) -> Result<()> {
        let fut = self.inner.cancel_query(query_id);
        self.runtime.block_on(fut)
    }
    /// Returns the configuration of the wrapped async client.
    pub fn config(&self) -> &GrpcConfig {
        self.inner.config()
    }
    /// Close the wrapped client, blocking until its `close` completes.
    pub fn close(self) -> Result<()> {
        let Self { inner, runtime } = self;
        runtime.block_on(inner.close())
    }
}
/// Blocking wrapper around [`GrpcChunkStream`].
///
/// Shares the Tokio runtime of the [`GrpcClientSync`] that created it,
/// so chunks can be awaited synchronously via `block_on`.
#[derive(Debug)]
pub struct GrpcChunkStreamSync {
inner: GrpcChunkStream,
runtime: Arc<tokio::runtime::Runtime>,
}
impl GrpcChunkStreamSync {
    /// Block until the next chunk arrives; `Ok(None)` signals the end of
    /// the stream.
    pub fn next_chunk(&mut self) -> Result<Option<bytes::Bytes>> {
        let fut = self.inner.next_chunk();
        self.runtime.block_on(fut)
    }
    /// Result schema from the underlying stream, once known.
    pub fn schema(&self) -> Option<&super::proto::QueryResultSchema> {
        self.inner.schema()
    }
    /// Query id from the underlying stream, once known.
    pub fn query_id(&self) -> Option<&str> {
        self.inner.query_id()
    }
    /// Rows-affected count from the underlying stream, if reported.
    pub fn rows_affected(&self) -> Option<u64> {
        self.inner.rows_affected()
    }
}