use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use futures_util::Stream;
use gcp_auth::TokenProvider;
use log::info;
use thiserror::Error;
use tokio::net::UnixStream;
use tonic::metadata::MetadataValue;
use tonic::transport::Endpoint;
use tonic::IntoRequest;
use tonic::{
codec::Streaming,
transport::{channel::Change, Channel, ClientTlsConfig},
Response,
};
use tower::ServiceBuilder;
use crate::auth_service::AuthSvc;
use crate::bigtable::read_rows::{decode_read_rows_response, decode_read_rows_response_stream};
use crate::{root_ca_certificate, util::get_row_range_from_prefix};
use googleapis_tonic_google_bigtable_v2::google::bigtable::v2::{
bigtable_client::BigtableClient, MutateRowRequest, MutateRowResponse, MutateRowsRequest,
MutateRowsResponse, ReadRowsRequest, RowSet, SampleRowKeysRequest, SampleRowKeysResponse,
};
use googleapis_tonic_google_bigtable_v2::google::bigtable::v2::{
CheckAndMutateRowRequest, CheckAndMutateRowResponse, ExecuteQueryRequest, ExecuteQueryResponse,
};
pub mod read_rows;

// Raw bytes of a Bigtable row key.
type RowKey = Vec<u8>;
// Module-local result alias over this module's `Error`.
type Result<T> = std::result::Result<T, Error>;

/// A single decoded cell from a Bigtable row.
#[derive(Debug)]
pub struct RowCell {
    // Column family the cell belongs to.
    pub family_name: String,
    // Column qualifier (raw bytes, not guaranteed UTF-8).
    pub qualifier: Vec<u8>,
    // Cell payload bytes.
    pub value: Vec<u8>,
    // Server-side cell timestamp, in microseconds.
    pub timestamp_micros: i64,
    // Labels attached by the read filter, if any.
    pub labels: Vec<String>,
}
/// Unified error type for every Bigtable operation in this module.
#[derive(Debug, Error)]
pub enum Error {
    /// Failure while obtaining an OAuth access token.
    #[error("AccessToken error: {0}")]
    AccessTokenError(String),
    /// Failure loading or parsing the root CA certificate.
    #[error("Certificate error: {0}")]
    CertificateError(String),
    /// Underlying I/O failure (e.g. Unix-socket connect).
    #[error("I/O Error: {0}")]
    IoError(std::io::Error),
    /// tonic transport-layer failure (TLS config, URI, connection).
    #[error("Transport error: {0}")]
    TransportError(tonic::transport::Error),
    /// Malformed chunk encountered while decoding a ReadRows stream.
    #[error("Chunk error")]
    ChunkError(String),
    /// The requested row does not exist.
    #[error("Row not found")]
    RowNotFound,
    /// A row mutation was not applied.
    #[error("Row write failed")]
    RowWriteFailed,
    /// A stored object could not be located.
    #[error("Object not found: {0}")]
    ObjectNotFound(String),
    /// A stored object was found but failed validation/decoding.
    #[error("Object is corrupt: {0}")]
    ObjectCorrupt(String),
    /// gRPC status returned by the server.
    #[error("RPC error: {0}")]
    RpcError(tonic::Status),
    /// Client-side deadline elapsed while decoding a response.
    #[error("Timeout error after {0} seconds")]
    TimeoutError(u64),
    /// Credential discovery / token refresh failure from `gcp_auth`.
    #[error("GCPAuthError error: {0}")]
    GCPAuthError(#[from] gcp_auth::Error),
    /// A header value could not be encoded as gRPC metadata.
    #[error("Invalid metadata")]
    MetadataError(tonic::metadata::errors::InvalidMetadataValue),
}
impl std::convert::From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Self::IoError(err)
}
}
impl std::convert::From<tonic::transport::Error> for Error {
fn from(err: tonic::transport::Error) -> Self {
Self::TransportError(err)
}
}
impl std::convert::From<tonic::Status> for Error {
fn from(err: tonic::Status) -> Self {
Self::RpcError(err)
}
}
/// A connection to a Bigtable instance. Cloning is cheap: clones share the
/// underlying balanced channel and prefix strings.
#[derive(Clone)]
pub struct BigTableConnection {
    // Generated gRPC client wrapped in the auth middleware.
    client: BigtableClient<AuthSvc>,
    // "projects/{project}/instances/{instance}/tables/" — prepended to table names.
    table_prefix: Arc<String>,
    // "projects/{project}/instances/{instance}" — used for request routing metadata.
    instance_prefix: Arc<String>,
    // Optional per-request timeout shared with `BigTable` handles from `client()`.
    timeout: Arc<Option<Duration>>,
}
impl BigTableConnection {
pub async fn new(
project_id: &str,
instance_name: &str,
is_read_only: bool,
channel_size: usize,
timeout: Option<Duration>,
) -> Result<Self> {
match std::env::var("BIGTABLE_EMULATOR_HOST") {
Ok(endpoint) => Self::new_with_emulator(
endpoint.as_str(),
project_id,
instance_name,
is_read_only,
channel_size,
timeout,
),
Err(_) => {
let token_provider = gcp_auth::provider().await?;
Self::new_with_token_provider(
project_id,
instance_name,
is_read_only,
channel_size,
timeout,
token_provider,
)
}
}
}
pub fn new_with_token_provider(
project_id: &str,
instance_name: &str,
is_read_only: bool,
channel_size: usize,
timeout: Option<Duration>,
token_provider: Arc<dyn TokenProvider>,
) -> Result<Self> {
match std::env::var("BIGTABLE_EMULATOR_HOST") {
Ok(endpoint) => Self::new_with_emulator(
endpoint.as_str(),
project_id,
instance_name,
is_read_only,
channel_size,
timeout,
),
Err(_) => {
let instance_prefix = format!("projects/{project_id}/instances/{instance_name}");
let table_prefix = format!("{instance_prefix}/tables/");
let channel_size = channel_size.max(1);
let (channel, tx) = Channel::balance_channel(channel_size);
for i in 0..channel_size {
let endpoint = Channel::from_static("https://bigtable.googleapis.com")
.tls_config(
ClientTlsConfig::new()
.ca_certificate(
root_ca_certificate::load()
.map_err(Error::CertificateError)
.expect("root certificate error"),
)
.domain_name("bigtable.googleapis.com"),
)
.map_err(Error::TransportError)?
.http2_keep_alive_interval(Duration::from_secs(60))
.keep_alive_while_idle(true);
let endpoint = if let Some(timeout) = timeout {
endpoint.timeout(timeout)
} else {
endpoint
};
tx.try_send(Change::Insert(i, endpoint)).unwrap();
}
let token_provider = Some(token_provider);
Ok(Self {
client: create_client(channel, token_provider, is_read_only),
table_prefix: Arc::new(table_prefix),
instance_prefix: Arc::new(instance_prefix),
timeout: Arc::new(timeout),
})
}
}
}
pub fn new_with_emulator(
emulator_endpoint: &str,
project_id: &str,
instance_name: &str,
is_read_only: bool,
channel_size: usize,
timeout: Option<Duration>,
) -> Result<Self> {
info!("Connecting to bigtable emulator at {}", emulator_endpoint);
fn configure_endpoint(endpoint: Endpoint, timeout: Option<Duration>) -> Endpoint {
let endpoint = endpoint
.http2_keep_alive_interval(Duration::from_secs(60))
.keep_alive_while_idle(true);
if let Some(timeout) = timeout {
endpoint.timeout(timeout)
} else {
endpoint
}
}
let channel = if let Some(path) = emulator_endpoint.strip_prefix("unix://") {
let endpoint = Endpoint::from_static("http://[::]:50051");
let endpoint = configure_endpoint(endpoint, timeout);
let path: String = path.to_string();
let connector = tower::service_fn({
move |_: tonic::transport::Uri| {
let path = path.clone();
async move {
let stream = UnixStream::connect(path).await?;
Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(stream))
}
}
});
endpoint.connect_with_connector_lazy(connector)
} else {
let channel_size = channel_size.max(1);
let (channel, tx) = Channel::balance_channel(channel_size);
for i in 0..channel_size {
let endpoint = Channel::from_shared(format!("http://{}", emulator_endpoint))
.expect("invalid connection emulator uri");
let endpoint = configure_endpoint(endpoint, timeout);
tx.try_send(Change::Insert(i, endpoint)).unwrap();
}
channel
};
Ok(Self {
client: create_client(channel, None, is_read_only),
table_prefix: Arc::new(format!(
"projects/{}/instances/{}/tables/",
project_id, instance_name
)),
instance_prefix: Arc::new(format!(
"projects/{}/instances/{}",
project_id, instance_name
)),
timeout: Arc::new(timeout),
})
}
pub fn client(&self) -> BigTable {
BigTable {
client: self.client.clone(),
instance_prefix: self.instance_prefix.clone(),
table_prefix: self.table_prefix.clone(),
timeout: self.timeout.clone(),
}
}
pub fn configure_inner_client(
&mut self,
config_fn: fn(BigtableClient<AuthSvc>) -> BigtableClient<AuthSvc>,
) {
self.client = config_fn(self.client.clone());
}
}
/// Wraps `channel` in the [`AuthSvc`] middleware and builds the generated
/// Bigtable gRPC client on top of it.
///
/// `token_provider` is `None` for unauthenticated (emulator) connections.
/// `read_only` selects the narrowest OAuth scope that still permits the
/// caller's intended operations.
fn create_client(
    channel: Channel,
    token_provider: Option<Arc<dyn TokenProvider>>,
    read_only: bool,
) -> BigtableClient<AuthSvc> {
    let scopes = if read_only {
        "https://www.googleapis.com/auth/bigtable.data.readonly"
    } else {
        "https://www.googleapis.com/auth/bigtable.data"
    };
    let auth_svc = ServiceBuilder::new()
        .layer_fn(|c| AuthSvc::new(c, token_provider.clone(), scopes.to_string()))
        .service(channel);
    // Tail expression instead of an explicit `return` (idiomatic Rust).
    BigtableClient::new(auth_svc)
}
/// A per-caller handle for issuing Bigtable RPCs; obtained from
/// [`BigTableConnection::client`]. Cloning shares the underlying channel.
#[derive(Clone)]
pub struct BigTable {
    // Generated gRPC client wrapped in the auth middleware.
    client: BigtableClient<AuthSvc>,
    // "projects/{project}/instances/{instance}" — used in execute_query routing metadata.
    instance_prefix: Arc<String>,
    // "projects/{project}/instances/{instance}/tables/" — see `get_full_table_name`.
    table_prefix: Arc<String>,
    // Optional deadline passed to the ReadRows response decoder.
    timeout: Arc<Option<Duration>>,
}
impl BigTable {
    /// Performs a conditional (check-and-mutate) row update and returns the
    /// server's response.
    pub async fn check_and_mutate_row(
        &mut self,
        request: CheckAndMutateRowRequest,
    ) -> Result<CheckAndMutateRowResponse> {
        let reply = self.client.check_and_mutate_row(request).await?;
        Ok(reply.into_inner())
    }

    /// Reads rows matching `request` and collects the fully decoded result
    /// set into memory.
    pub async fn read_rows(
        &mut self,
        request: ReadRowsRequest,
    ) -> Result<Vec<(RowKey, Vec<RowCell>)>> {
        let streaming = self.client.read_rows(request).await?.into_inner();
        decode_read_rows_response(self.timeout.as_ref(), streaming).await
    }

    /// Reads all rows whose keys start with `prefix`, replacing any row set
    /// already present on `request`.
    pub async fn read_rows_with_prefix(
        &mut self,
        mut request: ReadRowsRequest,
        prefix: Vec<u8>,
    ) -> Result<Vec<(RowKey, Vec<RowCell>)>> {
        request.rows = Some(RowSet {
            row_keys: Vec::new(),
            row_ranges: vec![get_row_range_from_prefix(prefix)],
        });
        let streaming = self.client.read_rows(request).await?.into_inner();
        decode_read_rows_response(self.timeout.as_ref(), streaming).await
    }

    /// Reads rows matching `request`, yielding decoded rows lazily as a
    /// stream instead of collecting them.
    pub async fn stream_rows(
        &mut self,
        request: ReadRowsRequest,
    ) -> Result<impl Stream<Item = Result<(RowKey, Vec<RowCell>)>>> {
        let streaming = self.client.read_rows(request).await?.into_inner();
        Ok(decode_read_rows_response_stream(streaming).await)
    }

    /// Streams all rows whose keys start with `prefix`, replacing any row
    /// set already present on `request`.
    pub async fn stream_rows_with_prefix(
        &mut self,
        mut request: ReadRowsRequest,
        prefix: Vec<u8>,
    ) -> Result<impl Stream<Item = Result<(RowKey, Vec<RowCell>)>>> {
        request.rows = Some(RowSet {
            row_keys: Vec::new(),
            row_ranges: vec![get_row_range_from_prefix(prefix)],
        });
        let streaming = self.client.read_rows(request).await?.into_inner();
        Ok(decode_read_rows_response_stream(streaming).await)
    }

    /// Requests sample row keys for a table; results arrive as a server
    /// stream.
    pub async fn sample_row_keys(
        &mut self,
        request: SampleRowKeysRequest,
    ) -> Result<Streaming<SampleRowKeysResponse>> {
        Ok(self.client.sample_row_keys(request).await?.into_inner())
    }

    /// Applies a single-row mutation, returning the full tonic response
    /// (metadata included).
    pub async fn mutate_row(
        &mut self,
        request: MutateRowRequest,
    ) -> Result<Response<MutateRowResponse>> {
        Ok(self.client.mutate_row(request).await?)
    }

    /// Applies a batch of row mutations; per-entry results stream back.
    pub async fn mutate_rows(
        &mut self,
        request: MutateRowsRequest,
    ) -> Result<Streaming<MutateRowsResponse>> {
        Ok(self.client.mutate_rows(request).await?.into_inner())
    }

    /// Executes a SQL query against the instance, attaching the
    /// `x-goog-request-params` routing header the service expects.
    pub async fn execute_query(
        &mut self,
        request: ExecuteQueryRequest,
    ) -> Result<Streaming<ExecuteQueryResponse>> {
        let app_profile_id = request.app_profile_id.clone();
        let routing_params = format!(
            "name={}&app_profile_id={}",
            self.instance_prefix, app_profile_id
        );
        let mut tonic_req: tonic::Request<_> = request.into_request();
        tonic_req.metadata_mut().insert(
            "x-goog-request-params",
            MetadataValue::from_str(&routing_params).map_err(Error::MetadataError)?,
        );
        Ok(self.client.execute_query(tonic_req).await?.into_inner())
    }

    /// Exposes the underlying generated gRPC client for direct use.
    pub fn get_client(&mut self) -> &mut BigtableClient<AuthSvc> {
        &mut self.client
    }

    /// Replaces the inner client with a reconfigured copy produced by
    /// `config_fn` (e.g. to adjust message size limits).
    pub fn configure_inner_client(
        &mut self,
        config_fn: fn(BigtableClient<AuthSvc>) -> BigtableClient<AuthSvc>,
    ) {
        self.client = config_fn(self.client.clone());
    }

    /// Prepends the connection's table prefix
    /// (`projects/{p}/instances/{i}/tables/`) to `table_name`, producing the
    /// fully qualified table name.
    pub fn get_full_table_name(&self, table_name: &str) -> String {
        format!("{}{}", self.table_prefix, table_name)
    }
}