use std::time::Duration;
use tonic::transport::{Channel, Endpoint};
use tracing::{debug, info, warn};
use std::collections::HashMap;
use crate::client::error::{Error, ErrorKind, Result};
use hyperdb_api_salesforce::{DataCloudToken, SharedTokenProvider};
use super::error::from_grpc_status;
use super::executor::GrpcQueryExecutor;
use super::params::{ParameterStyle, QueryParameters};
use super::proto::hyper_service::query_param::TransferMode;
use super::proto::{
AttachedDatabase, CancelQueryParam, HyperServiceClient, OutputFormat, QueryParam,
};
use super::result::GrpcQueryResult;
/// Metadata describing a single relation (table, view, or materialized view)
/// visible to the client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableInfo {
    /// Schema (namespace) the relation lives in.
    pub schema: String,
    /// Bare relation name within the schema.
    pub name: String,
    /// Relation kind, e.g. "TABLE", "VIEW", "MATERIALIZED VIEW".
    pub table_type: String,
    /// Optional human-friendly label; `display_name()` falls back to `name`.
    pub display_name: Option<String>,
}

impl TableInfo {
    /// Returns the schema-qualified name in `schema.name` form.
    #[must_use]
    pub fn full_name(&self) -> String {
        let TableInfo { schema, name, .. } = self;
        format!("{schema}.{name}")
    }

    /// Returns the display label, falling back to the raw relation name
    /// when no label is set.
    #[must_use]
    pub fn display_name(&self) -> &str {
        match self.display_name.as_deref() {
            Some(label) => label,
            None => self.name.as_str(),
        }
    }
}
/// Proactively refresh the DC JWT when fewer than this many seconds of
/// lifetime remain.
const DEFAULT_DC_JWT_EXPIRY_THRESHOLD_SECS: i64 = 300;
/// Refresh the DC JWT once it is older than this many seconds, regardless
/// of its stated expiry.
const DEFAULT_DC_JWT_MAX_AGE_SECS: i64 = 900;
/// Extra attempts allowed after an authentication failure; the retry loops
/// run `0..=MAX_AUTH_RETRIES` iterations in total.
const MAX_AUTH_RETRIES: u32 = 1;
/// gRPC client for the Hyper service that transparently acquires, refreshes,
/// and re-attaches a Data Cloud JWT around every call.
pub struct AuthenticatedGrpcClient {
    // Source of DC JWTs; shared so multiple clients can reuse one provider.
    token_provider: SharedTokenProvider,
    // Optional dataspace used when resolving the lakehouse name per query.
    dataspace: Option<String>,
    // Lazily (re)established gRPC channel to the tenant endpoint.
    channel: Option<Channel>,
    // Token currently attached to outgoing requests, if any.
    current_token: Option<DataCloudToken>,
    // Refresh when remaining token lifetime drops below this many seconds.
    dc_jwt_expiry_threshold_secs: i64,
    // Refresh when the token is older than this many seconds.
    dc_jwt_max_age_secs: i64,
    // Default transfer mode for queries that don't specify one.
    transfer_mode: TransferMode,
    // Timeout for establishing the TCP/TLS connection.
    connect_timeout: Duration,
    // Per-request timeout applied on the endpoint.
    request_timeout: Duration,
    // Upper bounds (bytes) on inbound / outbound gRPC message sizes.
    max_decoding_message_size: usize,
    max_encoding_message_size: usize,
}
impl AuthenticatedGrpcClient {
/// Connects to Data Cloud, eagerly fetching a DC JWT and opening the gRPC
/// channel so that configuration errors surface immediately.
///
/// # Errors
/// Fails if a token cannot be obtained or the tenant endpoint is
/// unreachable.
pub async fn connect(
    token_provider: SharedTokenProvider,
    dataspace: Option<String>,
) -> Result<Self> {
    use super::config::DEFAULT_MAX_MESSAGE_SIZE;
    let mut client = AuthenticatedGrpcClient {
        token_provider,
        dataspace,
        channel: None,
        current_token: None,
        dc_jwt_expiry_threshold_secs: DEFAULT_DC_JWT_EXPIRY_THRESHOLD_SECS,
        dc_jwt_max_age_secs: DEFAULT_DC_JWT_MAX_AGE_SECS,
        transfer_mode: TransferMode::Adaptive,
        connect_timeout: Duration::from_secs(30),
        request_timeout: Duration::from_secs(300),
        max_decoding_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        max_encoding_message_size: DEFAULT_MAX_MESSAGE_SIZE,
    };
    // Acquire the initial token and establish the channel up front.
    client.ensure_connected().await?;
    Ok(client)
}
/// Sets how many seconds of remaining lifetime trigger a proactive DC JWT
/// refresh.
#[must_use]
pub fn with_dc_jwt_expiry_threshold(mut self, secs: i64) -> Self {
    self.dc_jwt_expiry_threshold_secs = secs;
    self
}
/// Sets the maximum DC JWT age (seconds) before a forced refresh.
#[must_use]
pub fn with_dc_jwt_max_age(mut self, secs: i64) -> Self {
    self.dc_jwt_max_age_secs = secs;
    self
}
/// Sets the default transfer mode used by [`Self::execute_query`].
#[must_use]
pub fn with_transfer_mode(mut self, mode: TransferMode) -> Self {
    self.transfer_mode = mode;
    self
}
/// Sets the connection-establishment timeout. Note: applies to future
/// reconnects; the initial connection was made in `connect`.
#[must_use]
pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
    self.connect_timeout = timeout;
    self
}
/// Sets the per-request timeout (applies to future reconnects).
#[must_use]
pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
    self.request_timeout = timeout;
    self
}
/// Sets both the decoding and encoding gRPC message-size limits.
#[must_use]
pub fn with_max_message_size(mut self, size: usize) -> Self {
    self.max_decoding_message_size = size;
    self.max_encoding_message_size = size;
    self
}
/// Sets only the decoding (inbound) message-size limit.
#[must_use]
pub fn with_max_decoding_message_size(mut self, size: usize) -> Self {
    self.max_decoding_message_size = size;
    self
}
/// Sets only the encoding (outbound) message-size limit.
#[must_use]
pub fn with_max_encoding_message_size(mut self, size: usize) -> Self {
    self.max_encoding_message_size = size;
    self
}
/// Returns the currently held DC JWT, if any.
#[must_use]
pub fn current_token(&self) -> Option<&DataCloudToken> {
    self.current_token.as_ref()
}
/// Returns the tenant URL from the current token, if connected.
#[must_use]
pub fn tenant_url(&self) -> Option<&str> {
    self.current_token
        .as_ref()
        .map(hyperdb_api_salesforce::DataCloudToken::tenant_url_str)
}
/// Executes `sql` with the default output format (Arrow IPC) and the
/// client's configured transfer mode.
///
/// # Errors
/// Propagates connection, authentication, and query errors.
pub async fn execute_query(&mut self, sql: &str) -> Result<GrpcQueryResult> {
    self.execute_query_with_options(sql, OutputFormat::ArrowIpc, self.transfer_mode)
        .await
}
/// Executes `sql` and returns the raw Arrow IPC bytes of the result.
///
/// # Errors
/// Propagates connection, authentication, and query errors.
pub async fn execute_query_to_arrow(&mut self, sql: &str) -> Result<bytes::Bytes> {
    let result = self.execute_query(sql).await?;
    Ok(result.into_arrow_data())
}
/// Executes a parameterized query with the default output format and
/// transfer mode.
///
/// # Errors
/// Propagates connection, authentication, and query errors.
pub async fn execute_query_with_params(
    &mut self,
    sql: &str,
    params: QueryParameters,
    style: ParameterStyle,
) -> Result<GrpcQueryResult> {
    self.execute_query_with_params_and_options(
        sql,
        params,
        style,
        OutputFormat::ArrowIpc,
        self.transfer_mode,
    )
    .await
}
/// Executes a parameterized query and returns the raw Arrow IPC bytes.
///
/// # Errors
/// Propagates connection, authentication, and query errors.
pub async fn execute_query_with_params_to_arrow(
    &mut self,
    sql: &str,
    params: QueryParameters,
    style: ParameterStyle,
) -> Result<bytes::Bytes> {
    let result = self.execute_query_with_params(sql, params, style).await?;
    Ok(result.into_arrow_data())
}
/// Executes a parameterized query with explicit output format and transfer
/// mode, transparently refreshing the DC JWT and retrying once on
/// authentication failures.
///
/// # Errors
/// Returns the underlying query error, or an authentication error if the
/// token could not be refreshed.
pub async fn execute_query_with_params_and_options(
    &mut self,
    sql: &str,
    params: QueryParameters,
    style: ParameterStyle,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcQueryResult> {
    // Proactive refresh before the first attempt.
    self.ensure_token_valid().await?;
    let mut last_error = None;
    for attempt in 0..=MAX_AUTH_RETRIES {
        if attempt > 0 {
            info!(
                "Retrying parameterized query after token refresh (attempt {})",
                attempt + 1
            );
        }
        // The internal call consumes the parameters, so clone per attempt.
        let params_clone = params.clone();
        match self
            .execute_query_with_params_internal(
                sql,
                params_clone,
                style,
                output_format,
                transfer_mode,
            )
            .await
        {
            Ok(result) => return Ok(result),
            Err(e) => {
                if Self::is_auth_error(&e) && attempt < MAX_AUTH_RETRIES {
                    warn!(
                        error = %e,
                        "Authentication error, refreshing token and retrying"
                    );
                    // If the refresh itself fails, surface the original
                    // query error rather than the refresh error.
                    if let Err(refresh_err) = self.force_refresh_and_reconnect().await {
                        warn!(error = %refresh_err, "Failed to refresh token");
                        return Err(e);
                    }
                    last_error = Some(e);
                    continue;
                }
                return Err(e);
            }
        }
    }
    // Defensive fallback: the loop always returns under the current retry
    // bounds; only reachable if MAX_AUTH_RETRIES semantics change.
    Err(last_error.unwrap_or_else(|| {
        Error::new(
            ErrorKind::Authentication,
            "Parameterized query failed after token refresh",
        )
    }))
}
/// Executes `sql` with explicit output format and transfer mode,
/// transparently refreshing the DC JWT and retrying once on
/// authentication failures.
///
/// # Errors
/// Returns the underlying query error, or an authentication error if the
/// token could not be refreshed.
pub async fn execute_query_with_options(
    &mut self,
    sql: &str,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcQueryResult> {
    // Proactive refresh before the first attempt.
    self.ensure_token_valid().await?;
    let mut last_error = None;
    for attempt in 0..=MAX_AUTH_RETRIES {
        if attempt > 0 {
            info!(
                "Retrying query after token refresh (attempt {})",
                attempt + 1
            );
        }
        match self
            .execute_query_internal(sql, output_format, transfer_mode)
            .await
        {
            Ok(result) => return Ok(result),
            Err(e) => {
                if Self::is_auth_error(&e) && attempt < MAX_AUTH_RETRIES {
                    warn!(
                        error = %e,
                        "Authentication error, refreshing token and retrying"
                    );
                    // If the refresh itself fails, surface the original
                    // query error rather than the refresh error.
                    if let Err(refresh_err) = self.force_refresh_and_reconnect().await {
                        warn!(error = %refresh_err, "Failed to refresh token");
                        return Err(e);
                    }
                    last_error = Some(e);
                    continue;
                }
                return Err(e);
            }
        }
    }
    // Defensive fallback: unreachable under the current retry bounds.
    Err(last_error.unwrap_or_else(|| {
        Error::new(
            ErrorKind::Authentication,
            "Query failed after token refresh",
        )
    }))
}
/// Forces a DC JWT refresh and reconnects the channel.
///
/// # Errors
/// Fails if the provider cannot refresh or the reconnect fails.
pub async fn refresh_token(&mut self) -> Result<()> {
    self.force_refresh_and_reconnect().await
}
/// Cancels a running query by its server-assigned id, retrying once after
/// a token refresh on authentication failures.
///
/// # Errors
/// Propagates connection, authentication, and cancellation errors.
pub async fn cancel_query(&mut self, query_id: &str) -> Result<()> {
    self.ensure_token_valid().await?;
    self.ensure_connected().await?;
    for attempt in 0..=MAX_AUTH_RETRIES {
        if attempt > 0 {
            info!(
                "Retrying cancel after token refresh (attempt {})",
                attempt + 1
            );
        }
        match self.cancel_query_internal(query_id).await {
            Ok(()) => return Ok(()),
            Err(e) => {
                if Self::is_auth_error(&e) && attempt < MAX_AUTH_RETRIES {
                    warn!(error = %e, "Auth error during cancel, refreshing token");
                    // Surface the original error if the refresh itself fails.
                    if let Err(refresh_err) = self.force_refresh_and_reconnect().await {
                        warn!(error = %refresh_err, "Failed to refresh token");
                        return Err(e);
                    }
                    continue;
                }
                return Err(e);
            }
        }
    }
    // Defensive fallback; the loop always returns under current bounds.
    Err(Error::new(
        ErrorKind::Other,
        "Cancel failed after DC JWT refresh",
    ))
}
/// Closes the client. Dropping the channel is sufficient today; this
/// method exists for API symmetry with asynchronous teardown.
#[expect(
    clippy::unused_async,
    reason = "async fn retained for API symmetry; callers await regardless of whether the current body is synchronous"
)]
pub async fn close(self) -> Result<()> {
    debug!("Closing authenticated gRPC connection");
    Ok(())
}
/// Lists user-visible schema names, excluding system namespaces.
///
/// # Errors
/// Propagates query and Arrow-parsing errors.
pub async fn list_schemas(&mut self) -> Result<Vec<String>> {
    let query = r"
    SELECT nspname
    FROM pg_catalog.pg_namespace
    WHERE nspname NOT IN ('pg_catalog', 'pg_temp', 'pg_toast', 'information_schema')
    ORDER BY nspname
    ";
    let result = self.execute_query(query).await?;
    Self::extract_string_column(&result, 0)
}
/// Lists all tables, views, and materialized views visible to the client.
///
/// # Errors
/// Propagates query and Arrow-parsing errors.
pub async fn list_tables(&mut self) -> Result<Vec<TableInfo>> {
    self.list_tables_with_limit(None).await
}
/// Lists tables/views, optionally capped at `limit` rows.
///
/// # Errors
/// Propagates query and Arrow-parsing errors.
pub async fn list_tables_with_limit(&mut self, limit: Option<u32>) -> Result<Vec<TableInfo>> {
    // `limit` is numeric (u32), so direct interpolation is injection-safe.
    let limit_clause = limit.map(|l| format!("LIMIT {l}")).unwrap_or_default();
    let query = format!(
        r"
        SELECT
            n.nspname AS table_schema,
            c.relname AS table_name,
            CASE c.relkind
                WHEN 'r' THEN 'TABLE'
                WHEN 'v' THEN 'VIEW'
                WHEN 'm' THEN 'MATERIALIZED VIEW'
                ELSE 'OTHER'
            END AS table_type
        FROM
            pg_catalog.pg_class c
            JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
        WHERE
            c.relkind IN ('r', 'v', 'm')
            AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
        ORDER BY
            n.nspname, c.relname
        {limit_clause}
        "
    );
    let result = self.execute_query(&query).await?;
    Self::extract_table_info(&result)
}
/// Lists relation names (tables, views, materialized views) in `schema`.
///
/// # Errors
/// Propagates query and Arrow-parsing errors.
pub async fn list_tables_in_schema(&mut self, schema: &str) -> Result<Vec<String>> {
    let query = format!(
        r"
        SELECT c.relname AS table_name
        FROM pg_catalog.pg_class c
        JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
        WHERE c.relkind IN ('r', 'v', 'm')
        AND n.nspname = '{}'
        ORDER BY c.relname
        ",
        // Escape single quotes so the SQL literal stays well-formed.
        schema.replace('\'', "''")
    );
    let result = self.execute_query(&query).await?;
    Self::extract_string_column(&result, 0)
}
pub async fn has_table(&mut self, schema: &str, table: &str) -> Result<bool> {
let query = format!(
r"
SELECT 1
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname = '{}' AND c.relname = '{}'
AND c.relkind IN ('r', 'v', 'm')
",
schema.replace('\'', "''"),
table.replace('\'', "''")
);
let result = self.execute_query(&query).await?;
Ok(!result.arrow_data().is_empty() && result.arrow_data().len() > 8)
}
fn extract_string_column(result: &GrpcQueryResult, column_idx: usize) -> Result<Vec<String>> {
use arrow::array::Array;
use arrow::ipc::reader::StreamReader;
use std::io::Cursor;
let arrow_data = result.arrow_data();
if arrow_data.is_empty() {
return Ok(Vec::new());
}
let reader = StreamReader::try_new(Cursor::new(arrow_data), None).map_err(|e| {
Error::new(
ErrorKind::Protocol,
format!("Failed to parse Arrow data: {e}"),
)
})?;
let mut values = Vec::new();
for batch_result in reader {
let batch = batch_result.map_err(|e| {
Error::new(
ErrorKind::Protocol,
format!("Failed to read Arrow batch: {e}"),
)
})?;
if let Some(arr) = batch
.column(column_idx)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
{
for i in 0..arr.len() {
if !arr.is_null(i) {
values.push(arr.value(i).to_string());
}
}
}
}
Ok(values)
}
fn extract_table_info(result: &GrpcQueryResult) -> Result<Vec<TableInfo>> {
use arrow::array::Array;
use arrow::ipc::reader::StreamReader;
use std::io::Cursor;
let arrow_data = result.arrow_data();
if arrow_data.is_empty() {
return Ok(Vec::new());
}
let reader = StreamReader::try_new(Cursor::new(arrow_data), None).map_err(|e| {
Error::new(
ErrorKind::Protocol,
format!("Failed to parse Arrow data: {e}"),
)
})?;
let mut tables = Vec::new();
for batch_result in reader {
let batch = batch_result.map_err(|e| {
Error::new(
ErrorKind::Protocol,
format!("Failed to read Arrow batch: {e}"),
)
})?;
let schema_col = batch
.column(0)
.as_any()
.downcast_ref::<arrow::array::StringArray>();
let name_col = batch
.column(1)
.as_any()
.downcast_ref::<arrow::array::StringArray>();
let type_col = batch
.column(2)
.as_any()
.downcast_ref::<arrow::array::StringArray>();
if let (Some(schemas), Some(names), Some(types)) = (schema_col, name_col, type_col) {
for i in 0..batch.num_rows() {
if !schemas.is_null(i) && !names.is_null(i) && !types.is_null(i) {
tables.push(TableInfo {
schema: schemas.value(i).to_string(),
name: names.value(i).to_string(),
table_type: types.value(i).to_string(),
display_name: None, });
}
}
}
}
Ok(tables)
}
/// Fetches human-readable labels for all relations in `schema`, keyed by
/// relation name. Labels come from `pg_description`; JSON payloads with a
/// `displayName` field are unwrapped, otherwise the raw text is used.
///
/// # Errors
/// Fails if the query fails or the Arrow response cannot be parsed.
pub async fn get_table_labels(
    &mut self,
    schema: &str,
) -> Result<std::collections::HashMap<String, String>> {
    // Descriptions may be JSON blobs carrying a "displayName"; unwrap
    // them, falling back to the raw text on any parse failure.
    fn resolve_label(raw: &str) -> String {
        if raw.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) {
                return value
                    .get("displayName")
                    .and_then(|v| v.as_str())
                    .map_or_else(|| raw.to_string(), std::string::ToString::to_string);
            }
        }
        raw.to_string()
    }
    let mut labels = std::collections::HashMap::new();
    // Escape single quotes in the identifier (consistent with
    // `list_tables_in_schema`); previously `schema` was interpolated raw,
    // which broke the statement for quoted input.
    let query = format!(
        r"SELECT c.relname as table_name,
        COALESCE(d.description, c.relname) as label
        FROM pg_catalog.pg_class c
        LEFT JOIN pg_catalog.pg_description d ON d.objoid = c.oid AND d.objsubid = 0
        JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
        WHERE n.nspname = '{}' AND c.relkind IN ('r', 'v', 'm')
        ORDER BY c.relname",
        schema.replace('\'', "''")
    );
    let result = self.execute_query(&query).await?;
    let reader = arrow::ipc::reader::StreamReader::try_new(
        std::io::Cursor::new(result.arrow_data()),
        None,
    )
    .map_err(|e| crate::client::Error::other(format!("Failed to parse Arrow data: {e}")))?;
    #[expect(
        clippy::manual_flatten,
        reason = "explicit if let Ok matches the rest of the Arrow-IPC stream consumers in this module; refactoring to .flatten() on StreamReader would hide the error discard"
    )]
    for batch_result in reader {
        if let Ok(batch) = batch_result {
            if let (Some(name_arr), Some(label_arr)) = (
                batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>(),
                batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>(),
            ) {
                for i in 0..batch.num_rows() {
                    use arrow::array::Array;
                    if !name_arr.is_null(i) && !label_arr.is_null(i) {
                        labels.insert(
                            name_arr.value(i).to_string(),
                            resolve_label(label_arr.value(i)),
                        );
                    }
                }
            }
        }
    }
    Ok(labels)
}
/// Fetches human-readable labels for the columns of `schema.table`, keyed
/// by column name. Labels come from `pg_description`; JSON payloads with a
/// `displayName` field are unwrapped, otherwise the raw text is used.
///
/// # Errors
/// Fails if the query fails or the Arrow response cannot be parsed.
pub async fn get_column_labels(
    &mut self,
    schema: &str,
    table: &str,
) -> Result<std::collections::HashMap<String, String>> {
    // Descriptions may be JSON blobs carrying a "displayName"; unwrap
    // them, falling back to the raw text on any parse failure.
    fn resolve_label(raw: &str) -> String {
        if raw.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<serde_json::Value>(raw) {
                return value
                    .get("displayName")
                    .and_then(|v| v.as_str())
                    .map_or_else(|| raw.to_string(), std::string::ToString::to_string);
            }
        }
        raw.to_string()
    }
    let mut labels = std::collections::HashMap::new();
    // Escape single quotes in both identifiers (consistent with
    // `list_tables_in_schema`/`has_table`); previously they were
    // interpolated raw, which broke the statement for quoted input.
    let query = format!(
        r"SELECT a.attname as column_name,
        COALESCE(d.description, a.attname) as label
        FROM pg_catalog.pg_attribute a
        LEFT JOIN pg_catalog.pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
        JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
        JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
        WHERE n.nspname = '{}' AND c.relname = '{}' AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum",
        schema.replace('\'', "''"),
        table.replace('\'', "''")
    );
    let result = self.execute_query(&query).await?;
    let reader = arrow::ipc::reader::StreamReader::try_new(
        std::io::Cursor::new(result.arrow_data()),
        None,
    )
    .map_err(|e| crate::client::Error::other(format!("Failed to parse Arrow data: {e}")))?;
    #[expect(
        clippy::manual_flatten,
        reason = "explicit if let Ok matches the rest of the Arrow-IPC stream consumers in this module; refactoring to .flatten() on StreamReader would hide the error discard"
    )]
    for batch_result in reader {
        if let Ok(batch) = batch_result {
            if let (Some(name_arr), Some(label_arr)) = (
                batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>(),
                batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>(),
            ) {
                for i in 0..batch.num_rows() {
                    use arrow::array::Array;
                    if !name_arr.is_null(i) && !label_arr.is_null(i) {
                        labels.insert(
                            name_arr.value(i).to_string(),
                            resolve_label(label_arr.value(i)),
                        );
                    }
                }
            }
        }
    }
    Ok(labels)
}
/// Refreshes the DC JWT (and reconnects) when it is missing, near expiry,
/// or older than the configured maximum age.
async fn ensure_token_valid(&mut self) -> Result<()> {
    let needs_refresh = match &self.current_token {
        Some(token) => {
            let needs = token
                .needs_refresh(self.dc_jwt_expiry_threshold_secs, self.dc_jwt_max_age_secs);
            if needs {
                let age_secs = token.age().num_seconds();
                let remaining_secs = token.remaining_lifetime().num_seconds();
                debug!(
                    age_secs,
                    remaining_secs,
                    max_age_secs = self.dc_jwt_max_age_secs,
                    threshold_secs = self.dc_jwt_expiry_threshold_secs,
                    "DC JWT needs proactive refresh (expiring or too old)"
                );
            }
            needs
        }
        None => true,
    };
    if needs_refresh {
        self.force_refresh_and_reconnect().await?;
    }
    Ok(())
}
/// Establishes the channel and token if either is missing; no-op when both
/// are present (freshness is the job of `ensure_token_valid`).
async fn ensure_connected(&mut self) -> Result<()> {
    if self.channel.is_some() && self.current_token.is_some() {
        return Ok(());
    }
    let token = self.token_provider.get_token().await.map_err(|e| {
        Error::new(
            ErrorKind::Authentication,
            format!("Failed to get DC JWT: {e}"),
        )
    })?;
    // Connect first, then commit the token, so a failed connect leaves the
    // previous state untouched.
    self.connect_to_tenant(&token).await?;
    self.current_token = Some(token);
    Ok(())
}
/// Unconditionally refreshes the DC JWT via the provider and reconnects
/// the channel to the (possibly changed) tenant endpoint.
async fn force_refresh_and_reconnect(&mut self) -> Result<()> {
    info!("Refreshing DC JWT");
    let token = self.token_provider.refresh_token().await.map_err(|e| {
        Error::new(
            ErrorKind::Authentication,
            format!("Failed to refresh DC JWT: {e}"),
        )
    })?;
    self.connect_to_tenant(&token).await?;
    self.current_token = Some(token);
    info!("DC JWT refreshed and reconnected successfully");
    Ok(())
}
/// Opens a TLS gRPC channel to the tenant host named in `token` (port 443)
/// and stores it on success.
async fn connect_to_tenant(&mut self, token: &DataCloudToken) -> Result<()> {
    let tenant_url = token.tenant_url();
    let hostname = tenant_url
        .host_str()
        .ok_or_else(|| Error::new(ErrorKind::Config, "No hostname in tenant URL"))?;
    let grpc_endpoint = format!("https://{hostname}:443");
    info!(endpoint = %grpc_endpoint, "Connecting to Data Cloud");
    let endpoint = Endpoint::from_shared(grpc_endpoint.clone())
        .map_err(|e| Error::new(ErrorKind::Config, format!("Invalid gRPC endpoint: {e}")))?;
    let endpoint = endpoint
        .connect_timeout(self.connect_timeout)
        .timeout(self.request_timeout);
    // Verify the server against the platform's trusted root certificates.
    let tls_config = tonic::transport::ClientTlsConfig::new().with_enabled_roots();
    let endpoint = endpoint
        .tls_config(tls_config)
        .map_err(|e| Error::new(ErrorKind::Config, format!("TLS configuration error: {e}")))?;
    let channel = endpoint.connect().await.map_err(|e| {
        Error::new(
            ErrorKind::Connection,
            format!("Failed to connect to {grpc_endpoint}: {e}"),
        )
    })?;
    self.channel = Some(channel);
    debug!("gRPC channel established");
    Ok(())
}
/// Single attempt at an unparameterized query; no token refresh or retry.
async fn execute_query_internal(
    &mut self,
    sql: &str,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcQueryResult> {
    self.execute_query_with_params_internal(
        sql,
        None,
        ParameterStyle::default(),
        output_format,
        transfer_mode,
    )
    .await
}
/// Single attempt at a (possibly parameterized) query: builds the
/// `QueryParam`, streams partial results through `GrpcQueryExecutor`, and
/// accumulates them into one `GrpcQueryResult`. Refresh/retry is the
/// caller's responsibility.
async fn execute_query_with_params_internal(
    &mut self,
    sql: &str,
    params: impl Into<Option<QueryParameters>>,
    style: ParameterStyle,
    output_format: OutputFormat,
    transfer_mode: TransferMode,
) -> Result<GrpcQueryResult> {
    let channel = self
        .channel
        .as_ref()
        .ok_or_else(|| Error::new(ErrorKind::Connection, "Not connected"))?;
    let token = self
        .current_token
        .as_ref()
        .ok_or_else(|| Error::new(ErrorKind::Authentication, "No token available"))?;
    let params = params.into();
    debug!(
        sql = %sql,
        has_params = params.is_some(),
        format = ?output_format,
        mode = ?transfer_mode,
        "Executing query"
    );
    // The lakehouse database is attached per query, derived from the token
    // and the optional dataspace.
    let lakehouse = token
        .lakehouse_name(self.dataspace.as_deref())
        .map_err(|e| {
            Error::new(
                ErrorKind::Authentication,
                format!("Failed to get lakehouse name: {e}"),
            )
        })?;
    let query_param = QueryParam {
        query: sql.to_string(),
        databases: vec![AttachedDatabase {
            path: lakehouse,
            alias: String::new(),
        }],
        output_format: output_format.into(),
        settings: HashMap::new(),
        transfer_mode: transfer_mode.into(),
        param_style: i32::from(style),
        parameters: params.map(super::params::QueryParameters::into_proto),
        result_range: None,
        query_row_limit: None,
    };
    let headers = vec![
        ("Authorization".to_string(), token.bearer_token()),
        ("audience".to_string(), token.tenant_url_str().to_string()),
    ];
    let client = HyperServiceClient::new(channel.clone())
        .max_decoding_message_size(self.max_decoding_message_size)
        .max_encoding_message_size(self.max_encoding_message_size);
    let mut executor = GrpcQueryExecutor::new(client, headers, transfer_mode);
    executor.execute(query_param).await?;
    // Drain the executor, merging chunks and metadata from each partial
    // result until the stream reports completion.
    let mut final_result = GrpcQueryResult::default();
    loop {
        if let Some(mut partial_result) = executor.next_result().await? {
            while let Some(chunk) = partial_result.take_chunk() {
                final_result.chunks.push_back(chunk);
            }
            if partial_result.query_id.is_some() {
                final_result.query_id = partial_result.query_id;
            }
            if partial_result.schema.is_some() {
                final_result.schema = partial_result.schema;
            }
            if partial_result.rows_affected.is_some() {
                final_result.rows_affected = partial_result.rows_affected;
            }
            if partial_result.is_complete {
                final_result.is_complete = true;
                break;
            }
        } else {
            // Stream exhausted without an explicit completion marker.
            final_result.is_complete = true;
            break;
        }
    }
    Ok(final_result)
}
/// Single cancel attempt: sends `CancelQuery` with auth metadata attached.
/// Refresh/retry is handled by `cancel_query`.
async fn cancel_query_internal(&self, query_id: &str) -> Result<()> {
    let channel = self
        .channel
        .as_ref()
        .ok_or_else(|| Error::new(ErrorKind::Connection, "Not connected"))?;
    let token = self
        .current_token
        .as_ref()
        .ok_or_else(|| Error::new(ErrorKind::Authentication, "No token available"))?;
    debug!(query_id = %query_id, "Cancelling query");
    let param = CancelQueryParam {
        query_id: query_id.to_string(),
    };
    let mut request = tonic::Request::new(param);
    // Best-effort header: only attached when the id is a valid metadata
    // value; the id also travels in the request body.
    if let Ok(value) = query_id.parse() {
        request.metadata_mut().insert("x-hyperdb-query-id", value);
    }
    request.metadata_mut().insert(
        "authorization",
        token
            .bearer_token()
            .parse()
            .map_err(|_| Error::new(ErrorKind::Authentication, "Invalid token format"))?,
    );
    request.metadata_mut().insert(
        "audience",
        token
            .tenant_url_str()
            .parse()
            .map_err(|_| Error::new(ErrorKind::Config, "Invalid tenant URL"))?,
    );
    let mut client = HyperServiceClient::new(channel.clone())
        .max_decoding_message_size(self.max_decoding_message_size)
        .max_encoding_message_size(self.max_encoding_message_size);
    client
        .cancel_query(request)
        .await
        .map_err(from_grpc_status)?;
    info!(query_id = %query_id, "Query cancelled successfully");
    Ok(())
}
/// Heuristically classifies `error` as an authentication failure, either by
/// its kind or by auth-related markers in its rendered message.
fn is_auth_error(error: &Error) -> bool {
    match error.kind() {
        ErrorKind::Authentication => true,
        _ => {
            let text = error.to_string().to_lowercase();
            ["unauthenticated", "unauthorized", "401"]
                .iter()
                .any(|needle| text.contains(needle))
        }
    }
}
}
// Manual Debug: only presence flags are shown for the channel and token so
// that credentials never leak into logs.
impl std::fmt::Debug for AuthenticatedGrpcClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthenticatedGrpcClient")
            .field("dataspace", &self.dataspace)
            .field("has_channel", &self.channel.is_some())
            .field("has_dc_jwt", &self.current_token.is_some())
            .field(
                "dc_jwt_expiry_threshold_secs",
                &self.dc_jwt_expiry_threshold_secs,
            )
            .field("dc_jwt_max_age_secs", &self.dc_jwt_max_age_secs)
            .finish_non_exhaustive()
    }
}
/// Blocking facade over [`AuthenticatedGrpcClient`]: owns a current-thread
/// Tokio runtime and `block_on`s every call.
#[derive(Debug)]
pub struct AuthenticatedGrpcClientSync {
    // Async client all calls delegate to.
    inner: AuthenticatedGrpcClient,
    // Dedicated single-threaded runtime driving the async client.
    runtime: tokio::runtime::Runtime,
}
impl AuthenticatedGrpcClientSync {
    /// Builds the runtime and connects synchronously.
    ///
    /// # Errors
    /// Fails if the runtime cannot be built or the connection fails.
    pub fn connect(token_provider: SharedTokenProvider, dataspace: Option<String>) -> Result<Self> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| Error::new(ErrorKind::Other, format!("Failed to create runtime: {e}")))?;
        let inner =
            runtime.block_on(AuthenticatedGrpcClient::connect(token_provider, dataspace))?;
        Ok(AuthenticatedGrpcClientSync { inner, runtime })
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::execute_query`].
    pub fn execute_query(&mut self, sql: &str) -> Result<GrpcQueryResult> {
        self.runtime.block_on(self.inner.execute_query(sql))
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::execute_query_to_arrow`].
    pub fn execute_query_to_arrow(&mut self, sql: &str) -> Result<bytes::Bytes> {
        self.runtime
            .block_on(self.inner.execute_query_to_arrow(sql))
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::execute_query_with_params`].
    pub fn execute_query_with_params(
        &mut self,
        sql: &str,
        params: QueryParameters,
        style: ParameterStyle,
    ) -> Result<GrpcQueryResult> {
        self.runtime
            .block_on(self.inner.execute_query_with_params(sql, params, style))
    }
    /// Blocking wrapper for
    /// [`AuthenticatedGrpcClient::execute_query_with_params_to_arrow`].
    pub fn execute_query_with_params_to_arrow(
        &mut self,
        sql: &str,
        params: QueryParameters,
        style: ParameterStyle,
    ) -> Result<bytes::Bytes> {
        self.runtime.block_on(
            self.inner
                .execute_query_with_params_to_arrow(sql, params, style),
        )
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::refresh_token`].
    pub fn refresh_token(&mut self) -> Result<()> {
        self.runtime.block_on(self.inner.refresh_token())
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::cancel_query`].
    pub fn cancel_query(&mut self, query_id: &str) -> Result<()> {
        self.runtime.block_on(self.inner.cancel_query(query_id))
    }
    /// Returns the currently held DC JWT, if any.
    pub fn current_token(&self) -> Option<&DataCloudToken> {
        self.inner.current_token()
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::close`].
    pub fn close(self) -> Result<()> {
        self.runtime.block_on(self.inner.close())
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::list_schemas`].
    pub fn list_schemas(&mut self) -> Result<Vec<String>> {
        self.runtime.block_on(self.inner.list_schemas())
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::list_tables`].
    pub fn list_tables(&mut self) -> Result<Vec<TableInfo>> {
        self.runtime.block_on(self.inner.list_tables())
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::list_tables_with_limit`].
    pub fn list_tables_with_limit(&mut self, limit: Option<u32>) -> Result<Vec<TableInfo>> {
        self.runtime
            .block_on(self.inner.list_tables_with_limit(limit))
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::list_tables_in_schema`].
    pub fn list_tables_in_schema(&mut self, schema: &str) -> Result<Vec<String>> {
        self.runtime
            .block_on(self.inner.list_tables_in_schema(schema))
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::has_table`].
    pub fn has_table(&mut self, schema: &str, table: &str) -> Result<bool> {
        self.runtime.block_on(self.inner.has_table(schema, table))
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::get_table_labels`].
    pub fn get_table_labels(
        &mut self,
        schema: &str,
    ) -> Result<std::collections::HashMap<String, String>> {
        self.runtime.block_on(self.inner.get_table_labels(schema))
    }
    /// Blocking wrapper for [`AuthenticatedGrpcClient::get_column_labels`].
    pub fn get_column_labels(
        &mut self,
        schema: &str,
        table: &str,
    ) -> Result<std::collections::HashMap<String, String>> {
        self.runtime
            .block_on(self.inner.get_column_labels(schema, table))
    }
}