use std::collections::HashMap;
use crate::api::rest_client::HttpClient;
use crate::catalog::Identifier;
use crate::common::{CatalogOptions, Options};
use crate::spec::{PartitionStatistics, Schema, Snapshot};
use crate::Result;
use super::api_request::{
AlterDatabaseRequest, CreateDatabaseRequest, CreateTableRequest, RenameTableRequest,
};
use super::api_response::{
ConfigResponse, GetDatabaseResponse, GetTableResponse, ListDatabasesResponse,
ListTablesResponse, PagedList,
};
use super::auth::{AuthProviderFactory, RESTAuthFunction};
use super::resource_paths::ResourcePaths;
use super::rest_util::RESTUtil;
fn validate_non_empty(value: &str, field_name: &str) -> Result<()> {
if value.trim().is_empty() {
return Err(crate::Error::ConfigInvalid {
message: format!("{field_name} cannot be empty"),
});
}
Ok(())
}
/// Runs [`validate_non_empty`] over each `(value, field_name)` pair in order.
///
/// Stops at the first blank value and returns its error.
fn validate_non_empty_multi(values: &[(&str, &str)]) -> Result<()> {
    values
        .iter()
        .try_for_each(|&(value, field_name)| validate_non_empty(value, field_name))
}
/// Client for the catalog REST API.
///
/// Holds the HTTP client used for every request, the resolver for endpoint paths, and the
/// effective catalog options. When built with `config_required = true`, those options are
/// merged with the server's config response.
pub struct RESTApi {
    /// HTTP transport; an auth function is installed on it when the client is constructed.
    client: HttpClient,
    /// Produces request paths for databases, tables, renames, table tokens and commits.
    resource_paths: ResourcePaths,
    /// Effective options this client was built from.
    options: Options,
}
impl RESTApi {
pub const HEADER_PREFIX: &'static str = "header.";
pub const MAX_RESULTS: &'static str = "maxResults";
pub const PAGE_TOKEN: &'static str = "pageToken";
pub const DATABASE_NAME_PATTERN: &'static str = "databaseNamePattern";
pub const TABLE_NAME_PATTERN: &'static str = "tableNamePattern";
pub const TABLE_TYPE: &'static str = "tableType";
pub async fn new(options: Options, config_required: bool) -> Result<Self> {
let uri = options
.get(CatalogOptions::URI)
.ok_or_else(|| crate::Error::ConfigInvalid {
message: "URI cannot be empty".to_string(),
})?;
if uri.trim().is_empty() {
return Err(crate::Error::ConfigInvalid {
message: "URI cannot be empty".to_string(),
});
}
let auth_provider = AuthProviderFactory::create_auth_provider(&options)?;
let mut base_headers: HashMap<String, String> =
RESTUtil::extract_prefix_map(&options, Self::HEADER_PREFIX);
let rest_auth_function = RESTAuthFunction::new(base_headers.clone(), auth_provider);
let mut client = HttpClient::new(uri, Some(rest_auth_function))?;
let options = if config_required {
let warehouse = options.get(CatalogOptions::WAREHOUSE).ok_or_else(|| {
crate::Error::ConfigInvalid {
message: "Warehouse name cannot be empty".to_string(),
}
})?;
if warehouse.trim().is_empty() {
return Err(crate::Error::ConfigInvalid {
message: "Warehouse name cannot be empty".to_string(),
});
}
let query_params: Vec<(&str, String)> = vec![(
CatalogOptions::WAREHOUSE,
RESTUtil::encode_string(warehouse),
)];
let config_response: ConfigResponse = client
.get(&ResourcePaths::config(), Some(&query_params))
.await?;
let merged = config_response.merge_options(&options);
base_headers.extend(RESTUtil::extract_prefix_map(&merged, Self::HEADER_PREFIX));
let auth_provider = AuthProviderFactory::create_auth_provider(&merged)?;
let rest_auth_function = RESTAuthFunction::new(base_headers, auth_provider);
client.set_auth_function(rest_auth_function);
merged
} else {
options
};
let resource_paths = ResourcePaths::for_catalog_properties(&options);
Ok(RESTApi {
client,
resource_paths,
options,
})
}
pub fn options(&self) -> &Options {
&self.options
}
pub async fn list_databases(&self) -> Result<Vec<String>> {
let mut results = Vec::new();
let mut page_token: Option<String> = None;
loop {
let paged = self
.list_databases_paged(None, page_token.as_deref(), None)
.await?;
let is_empty = paged.elements.is_empty();
results.extend(paged.elements);
page_token = paged.next_page_token;
if page_token.is_none() || is_empty {
break;
}
}
Ok(results)
}
pub async fn list_databases_paged(
&self,
max_results: Option<u32>,
page_token: Option<&str>,
database_name_pattern: Option<&str>,
) -> Result<PagedList<String>> {
let path = self.resource_paths.databases();
let mut params: Vec<(&str, String)> = Vec::new();
if let Some(max) = max_results {
params.push((Self::MAX_RESULTS, max.to_string()));
}
if let Some(token) = page_token {
params.push((Self::PAGE_TOKEN, token.to_string()));
}
if let Some(pattern) = database_name_pattern {
params.push((Self::DATABASE_NAME_PATTERN, pattern.to_string()));
}
let response: ListDatabasesResponse = if params.is_empty() {
self.client.get(&path, None::<&[(&str, &str)]>).await?
} else {
self.client.get(&path, Some(¶ms)).await?
};
Ok(PagedList::new(response.databases, response.next_page_token))
}
pub async fn create_database(
&self,
name: &str,
options: Option<HashMap<String, String>>,
) -> Result<()> {
validate_non_empty(name, "database name")?;
let path = self.resource_paths.databases();
let request = CreateDatabaseRequest::new(name.to_string(), options.unwrap_or_default());
let _resp: serde_json::Value = self.client.post(&path, &request).await?;
Ok(())
}
pub async fn get_database(&self, name: &str) -> Result<GetDatabaseResponse> {
validate_non_empty(name, "database name")?;
let path = self.resource_paths.database(name);
self.client.get(&path, None::<&[(&str, &str)]>).await
}
pub async fn alter_database(
&self,
name: &str,
removals: Vec<String>,
updates: HashMap<String, String>,
) -> Result<()> {
validate_non_empty(name, "database name")?;
let path = self.resource_paths.database(name);
let request = AlterDatabaseRequest::new(removals, updates);
let _resp: serde_json::Value = self.client.post(&path, &request).await?;
Ok(())
}
pub async fn drop_database(&self, name: &str) -> Result<()> {
validate_non_empty(name, "database name")?;
let path = self.resource_paths.database(name);
let _resp: serde_json::Value = self.client.delete(&path, None::<&[(&str, &str)]>).await?;
Ok(())
}
pub async fn list_tables(&self, database: &str) -> Result<Vec<String>> {
validate_non_empty(database, "database name")?;
let mut results = Vec::new();
let mut page_token: Option<String> = None;
loop {
let paged = self
.list_tables_paged(database, None, page_token.as_deref(), None, None)
.await?;
let is_empty = paged.elements.is_empty();
results.extend(paged.elements);
page_token = paged.next_page_token;
if page_token.is_none() || is_empty {
break;
}
}
Ok(results)
}
pub async fn list_tables_paged(
&self,
database: &str,
max_results: Option<u32>,
page_token: Option<&str>,
table_name_pattern: Option<&str>,
table_type: Option<&str>,
) -> Result<PagedList<String>> {
validate_non_empty(database, "database name")?;
let path = self.resource_paths.tables(Some(database));
let mut params: Vec<(&str, String)> = Vec::new();
if let Some(max) = max_results {
params.push((Self::MAX_RESULTS, max.to_string()));
}
if let Some(token) = page_token {
params.push((Self::PAGE_TOKEN, token.to_string()));
}
if let Some(pattern) = table_name_pattern {
params.push((Self::TABLE_NAME_PATTERN, pattern.to_string()));
}
if let Some(ttype) = table_type {
params.push((Self::TABLE_TYPE, ttype.to_string()));
}
let response: ListTablesResponse = if params.is_empty() {
self.client.get(&path, None::<&[(&str, &str)]>).await?
} else {
self.client.get(&path, Some(¶ms)).await?
};
Ok(PagedList::new(
response.tables.unwrap_or_default(),
response.next_page_token,
))
}
pub async fn create_table(&self, identifier: &Identifier, schema: Schema) -> Result<()> {
let database = identifier.database();
let table = identifier.object();
validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
let path = self.resource_paths.tables(Some(database));
let request = CreateTableRequest::new(identifier.clone(), schema);
let _resp: serde_json::Value = self.client.post(&path, &request).await?;
Ok(())
}
pub async fn get_table(&self, identifier: &Identifier) -> Result<GetTableResponse> {
let database = identifier.database();
let table = identifier.object();
validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
let path = self.resource_paths.table(database, table);
self.client.get(&path, None::<&[(&str, &str)]>).await
}
pub async fn rename_table(&self, source: &Identifier, destination: &Identifier) -> Result<()> {
validate_non_empty_multi(&[
(source.database(), "source database name"),
(source.object(), "source table name"),
(destination.database(), "destination database name"),
(destination.object(), "destination table name"),
])?;
let path = self.resource_paths.rename_table();
let request = RenameTableRequest::new(source.clone(), destination.clone());
let _resp: serde_json::Value = self.client.post(&path, &request).await?;
Ok(())
}
pub async fn drop_table(&self, identifier: &Identifier) -> Result<()> {
let database = identifier.database();
let table = identifier.object();
validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
let path = self.resource_paths.table(database, table);
let _resp: serde_json::Value = self.client.delete(&path, None::<&[(&str, &str)]>).await?;
Ok(())
}
pub async fn load_table_token(
&self,
identifier: &Identifier,
) -> Result<super::api_response::GetTableTokenResponse> {
let database = identifier.database();
let table = identifier.object();
validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
let path = self.resource_paths.table_token(database, table);
self.client.get(&path, None::<&[(&str, &str)]>).await
}
pub async fn commit_snapshot(
&self,
identifier: &Identifier,
table_uuid: &str,
snapshot: &Snapshot,
statistics: &[PartitionStatistics],
) -> Result<bool> {
let database = identifier.database();
let table = identifier.object();
validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
let path = self.resource_paths.commit_table(database, table);
let request = serde_json::json!({
"tableUuid": table_uuid,
"snapshot": snapshot,
"statistics": statistics,
});
let resp: serde_json::Value = self.client.post(&path, &request).await?;
Ok(resp
.get("success")
.and_then(|v| v.as_bool())
.unwrap_or(false))
}
}