#![warn(clippy::all)]
#![warn(rust_2018_idioms)]
// Abort compilation with an explicit message unless at least one of the `aws`, `azure`,
// `gcp` or `r2` crate features is enabled.
#[cfg(not(any(feature = "aws", feature = "azure", feature = "gcp", feature = "r2")))]
compile_error!(
"At least one of the following crate features `aws`, `azure`, `gcp`, or `r2` must be enabled \
for this crate to function properly."
);
use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION};
use std::str::FromStr;
use crate::credential::{
AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider, WorkspaceOAuthProvider,
};
use crate::models::{
ErrorResponse, GetSchemaResponse, GetTableResponse, ListCatalogsResponse, ListSchemasResponse,
ListTableSummariesResponse, TableTempCredentialsResponse, TemporaryTableCredentialsRequest,
};
use deltalake_core::data_catalog::DataCatalogResult;
use deltalake_core::{DataCatalog, DataCatalogError};
use crate::client::retry::*;
use deltalake_core::storage::str_is_truthy;
pub mod client;
pub mod credential;
#[cfg(feature = "datafusion")]
pub mod datafusion;
pub mod models;
pub mod prelude;
#[derive(thiserror::Error, Debug)]
pub enum UnityCatalogError {
#[error("GET request error: {source}")]
RequestError {
#[from]
source: reqwest::Error,
},
#[error("Error in middleware: {source}")]
RequestMiddlewareError {
#[from]
source: reqwest_middleware::Error,
},
#[error("Invalid table error: {error_code}: {message}")]
InvalidTable {
error_code: String,
message: String,
},
#[error("Invalid token for auth header: {header_error}")]
InvalidHeader {
#[from]
header_error: InvalidHeaderValue,
},
#[error("Missing configuration key: {0}")]
MissingConfiguration(String),
#[error("Failed to get a credential from UnityCatalog client configuration.")]
MissingCredential,
#[error("Azure CLI error: {message}")]
AzureCli {
message: String,
},
#[error("Missing or corrupted federated token file for WorkloadIdentity.")]
FederatedTokenFile,
#[cfg(feature = "datafusion")]
#[error("Datafusion error: {0}")]
DatafusionError(#[from] datafusion_common::DataFusionError),
}
impl From<ErrorResponse> for UnityCatalogError {
fn from(value: ErrorResponse) -> Self {
UnityCatalogError::InvalidTable {
error_code: value.error_code,
message: value.message,
}
}
}
impl From<UnityCatalogError> for DataCatalogError {
fn from(value: UnityCatalogError) -> Self {
DataCatalogError::Generic {
catalog: "Unity",
source: Box::new(value),
}
}
}
/// Configuration keys understood by [`UnityCatalogBuilder`].
///
/// Keys are parsed from their string names via `FromStr` and rendered back via `AsRef<str>`.
/// The derives give keys plain value semantics: they can be copied, compared, hashed and
/// printed for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnityCatalogConfigKey {
    /// Workspace URL (deprecated alias of `Host`).
    #[deprecated(since = "0.17.0", note = "Please use the DATABRICKS_HOST env variable")]
    WorkspaceUrl,
    /// Workspace host; stored by the builder as its workspace URL.
    Host,
    /// Bearer token (deprecated alias of `Token`).
    #[deprecated(
        since = "0.17.0",
        note = "Please use the DATABRICKS_TOKEN env variable"
    )]
    AccessToken,
    /// Bearer token used for the `Authorization` header.
    Token,
    /// OAuth client ID.
    ClientId,
    /// OAuth client secret.
    ClientSecret,
    /// Authority (tenant) ID for the client-secret flow.
    AuthorityId,
    /// Authority host for the client-secret flow.
    AuthorityHost,
    /// MSI endpoint setting.
    MsiEndpoint,
    /// Object ID setting.
    ObjectId,
    /// MSI resource ID setting.
    MsiResourceId,
    /// Federated token file setting.
    FederatedTokenFile,
    /// Flag (parsed as truthy/falsy) enabling the Azure CLI credential.
    UseAzureCli,
}
/// Parses a configuration key name into an [`UnityCatalogConfigKey`].
///
/// Matching is exact (case-sensitive). Most keys accept a bare name plus `unity_`- and
/// `databricks_`-prefixed aliases.
///
/// `databricks_host` and `databricks_token` are the names the deprecation notes point users
/// to. They parse to the non-deprecated `Host` and `Token` variants, so parsing the string
/// returned by `AsRef<str>` yields the original variant for every key.
///
/// # Errors
/// Returns `DataCatalogError::UnknownConfigKey` for any unrecognised string.
impl FromStr for UnityCatalogConfigKey {
    type Err = DataCatalogError;

    #[allow(deprecated)]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "access_token" | "unity_access_token" | "databricks_access_token" => {
                Ok(UnityCatalogConfigKey::AccessToken)
            }
            "authority_host" | "unity_authority_host" | "databricks_authority_host" => {
                Ok(UnityCatalogConfigKey::AuthorityHost)
            }
            "authority_id" | "unity_authority_id" | "databricks_authority_id" => {
                Ok(UnityCatalogConfigKey::AuthorityId)
            }
            "client_id" | "unity_client_id" | "databricks_client_id" => {
                Ok(UnityCatalogConfigKey::ClientId)
            }
            "client_secret" | "unity_client_secret" | "databricks_client_secret" => {
                Ok(UnityCatalogConfigKey::ClientSecret)
            }
            "federated_token_file"
            | "unity_federated_token_file"
            | "databricks_federated_token_file" => Ok(UnityCatalogConfigKey::FederatedTokenFile),
            // Recommended name for the workspace host (see the deprecation note on
            // `WorkspaceUrl`).
            "host" | "databricks_host" => Ok(UnityCatalogConfigKey::Host),
            "msi_endpoint" | "unity_msi_endpoint" | "databricks_msi_endpoint" => {
                Ok(UnityCatalogConfigKey::MsiEndpoint)
            }
            "msi_resource_id" | "unity_msi_resource_id" | "databricks_msi_resource_id" => {
                Ok(UnityCatalogConfigKey::MsiResourceId)
            }
            "object_id" | "unity_object_id" | "databricks_object_id" => {
                Ok(UnityCatalogConfigKey::ObjectId)
            }
            // Recommended name for the bearer token (see the deprecation note on
            // `AccessToken`).
            "token" | "databricks_token" => Ok(UnityCatalogConfigKey::Token),
            "use_azure_cli" | "unity_use_azure_cli" | "databricks_use_azure_cli" => {
                Ok(UnityCatalogConfigKey::UseAzureCli)
            }
            "workspace_url" | "unity_workspace_url" | "databricks_workspace_url" => {
                Ok(UnityCatalogConfigKey::WorkspaceUrl)
            }
            _ => Err(DataCatalogError::UnknownConfigKey {
                catalog: "unity",
                key: s.to_string(),
            }),
        }
    }
}
#[allow(deprecated)]
impl AsRef<str> for UnityCatalogConfigKey {
fn as_ref(&self) -> &str {
match self {
UnityCatalogConfigKey::AccessToken => "unity_access_token",
UnityCatalogConfigKey::AuthorityHost => "unity_authority_host",
UnityCatalogConfigKey::AuthorityId => "unity_authority_id",
UnityCatalogConfigKey::ClientId => "unity_client_id",
UnityCatalogConfigKey::ClientSecret => "unity_client_secret",
UnityCatalogConfigKey::FederatedTokenFile => "unity_federated_token_file",
UnityCatalogConfigKey::Host => "databricks_host",
UnityCatalogConfigKey::MsiEndpoint => "unity_msi_endpoint",
UnityCatalogConfigKey::MsiResourceId => "unity_msi_resource_id",
UnityCatalogConfigKey::ObjectId => "unity_object_id",
UnityCatalogConfigKey::UseAzureCli => "unity_use_azure_cli",
UnityCatalogConfigKey::Token => "databricks_token",
UnityCatalogConfigKey::WorkspaceUrl => "unity_workspace_url",
}
}
}
/// Builder for a [`UnityCatalog`] client.
///
/// All options start unset (`#[derive(Default)]`). Populate them with the `with_*` setters,
/// `try_with_option(s)` or `from_env`, then call `build`.
#[derive(Default)]
pub struct UnityCatalogBuilder {
/// Workspace base URL; `build` fails without it.
workspace_url: Option<String>,
/// Static bearer token; when set, `get_credential_provider` returns it before any other source.
bearer_token: Option<String>,
/// OAuth client ID, used together with `client_secret`.
client_id: Option<String>,
/// OAuth client secret, used together with `client_id`.
client_secret: Option<String>,
/// Authority (tenant) ID for the `ClientSecretOAuthProvider` flow.
authority_id: Option<String>,
/// Optional authority host passed to `ClientSecretOAuthProvider`.
authority_host: Option<String>,
// NOTE(review): `msi_endpoint`, `object_id`, `msi_resource_id` and `federated_token_file`
// are stored by `try_with_option` but are not read by `get_credential_provider` in this
// file. Confirm whether they are meant to select a credential source.
/// MSI endpoint setting.
msi_endpoint: Option<String>,
/// Object ID setting.
object_id: Option<String>,
/// MSI resource ID setting.
msi_resource_id: Option<String>,
/// Federated token file setting.
federated_token_file: Option<String>,
/// Use the Azure CLI credential when no other credential source is configured.
use_azure_cli: bool,
// NOTE(review): `retry_config` is not referenced by `build` in this file; confirm it is
// applied elsewhere (e.g. via `client_options`).
/// Retry settings set via `with_retry_config`.
retry_config: RetryConfig,
/// HTTP client options from which `build` creates the client.
client_options: client::ClientOptions,
}
#[allow(deprecated)]
impl UnityCatalogBuilder {
pub fn new() -> Self {
Default::default()
}
pub fn try_with_option(
mut self,
key: impl AsRef<str>,
value: impl Into<String>,
) -> DataCatalogResult<Self> {
match UnityCatalogConfigKey::from_str(key.as_ref())? {
UnityCatalogConfigKey::AccessToken => self.bearer_token = Some(value.into()),
UnityCatalogConfigKey::ClientId => self.client_id = Some(value.into()),
UnityCatalogConfigKey::ClientSecret => self.client_secret = Some(value.into()),
UnityCatalogConfigKey::AuthorityId => self.authority_id = Some(value.into()),
UnityCatalogConfigKey::AuthorityHost => self.authority_host = Some(value.into()),
UnityCatalogConfigKey::Host => self.workspace_url = Some(value.into()),
UnityCatalogConfigKey::MsiEndpoint => self.msi_endpoint = Some(value.into()),
UnityCatalogConfigKey::ObjectId => self.object_id = Some(value.into()),
UnityCatalogConfigKey::MsiResourceId => self.msi_resource_id = Some(value.into()),
UnityCatalogConfigKey::FederatedTokenFile => {
self.federated_token_file = Some(value.into())
}
UnityCatalogConfigKey::Token => self.bearer_token = Some(value.into()),
UnityCatalogConfigKey::UseAzureCli => self.use_azure_cli = str_is_truthy(&value.into()),
UnityCatalogConfigKey::WorkspaceUrl => self.workspace_url = Some(value.into()),
};
Ok(self)
}
pub fn try_with_options<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
mut self,
options: I,
) -> DataCatalogResult<Self> {
for (key, value) in options {
self = self.try_with_option(key, value)?;
}
Ok(self)
}
pub fn from_env() -> Self {
let mut builder = Self::default();
for (os_key, os_value) in std::env::vars_os() {
if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
if key.starts_with("UNITY_") || key.starts_with("DATABRICKS_") {
tracing::debug!("Found relevant env: {}", key);
if let Ok(config_key) =
UnityCatalogConfigKey::from_str(&key.to_ascii_lowercase())
{
tracing::debug!("Trying: {} with {}", key, value);
builder = builder.try_with_option(config_key, value).unwrap();
}
}
}
}
builder
}
pub fn with_workspace_url(mut self, url: impl Into<String>) -> Self {
self.workspace_url = Some(url.into());
self
}
pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
self.client_id = Some(client_id.into());
self
}
pub fn with_client_secret(mut self, client_secret: impl Into<String>) -> Self {
self.client_secret = Some(client_secret.into());
self
}
pub fn with_authority_id(mut self, tenant_id: impl Into<String>) -> Self {
self.authority_id = Some(tenant_id.into());
self
}
pub fn with_bearer_token(mut self, bearer_token: impl Into<String>) -> Self {
self.bearer_token = Some(bearer_token.into());
self
}
pub fn with_access_token(self, access_token: impl Into<String>) -> Self {
self.with_bearer_token(access_token)
}
pub fn with_client_options(mut self, options: client::ClientOptions) -> Self {
self.client_options = options;
self
}
pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
self.retry_config = config;
self
}
fn get_credential_provider(&self) -> Option<CredentialProvider> {
if let Some(token) = self.bearer_token.as_ref() {
return Some(CredentialProvider::BearerToken(token.clone()));
}
if let (Some(client_id), Some(client_secret), Some(workspace_host)) =
(&self.client_id, &self.client_secret, &self.workspace_url)
{
return Some(CredentialProvider::TokenCredential(
Default::default(),
Box::new(WorkspaceOAuthProvider::new(
client_id,
client_secret,
workspace_host,
)),
));
}
if let (Some(client_id), Some(client_secret), Some(authority_id)) = (
self.client_id.as_ref(),
self.client_secret.as_ref(),
self.authority_id.as_ref(),
) {
return Some(CredentialProvider::TokenCredential(
Default::default(),
Box::new(ClientSecretOAuthProvider::new(
client_id,
client_secret,
authority_id,
self.authority_host.as_ref(),
)),
));
}
if self.use_azure_cli {
return Some(CredentialProvider::TokenCredential(
Default::default(),
Box::new(AzureCliCredential::new()),
));
}
None
}
pub fn build(self) -> DataCatalogResult<UnityCatalog> {
let credential = self
.get_credential_provider()
.ok_or(UnityCatalogError::MissingCredential)?;
let workspace_url = self
.workspace_url
.ok_or(UnityCatalogError::MissingConfiguration(
"workspace_url".into(),
))?
.trim_end_matches('/')
.to_string();
let client = self.client_options.client()?;
Ok(UnityCatalog {
client,
workspace_url,
credential,
})
}
}
/// Client for the Unity Catalog REST API; created by `UnityCatalogBuilder::build`.
pub struct UnityCatalog {
/// HTTP client (with middleware) used for every request.
client: reqwest_middleware::ClientWithMiddleware,
/// Source of the `Authorization` header value attached to each request.
credential: CredentialProvider,
/// Workspace base URL (`UnityCatalogBuilder::build` trims trailing slashes).
workspace_url: String,
}
impl UnityCatalog {
async fn get_credential(&self) -> Result<HeaderValue, UnityCatalogError> {
match &self.credential {
CredentialProvider::BearerToken(token) => {
Ok(HeaderValue::from_str(&format!("Bearer {token}"))?)
}
CredentialProvider::TokenCredential(cache, cred) => {
let token = cache
.get_or_insert_with(|| cred.fetch_token(&self.client))
.await?;
Ok(HeaderValue::from_str(&format!("Bearer {token}"))?)
}
}
}
fn catalog_url(&self) -> String {
format!("{}/api/2.1/unity-catalog", &self.workspace_url)
}
pub async fn list_catalogs(&self) -> Result<ListCatalogsResponse, UnityCatalogError> {
let token = self.get_credential().await?;
let resp = self
.client
.get(format!("{}/catalogs", self.catalog_url()))
.header(AUTHORIZATION, token)
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn list_schemas(
&self,
catalog_name: impl AsRef<str>,
) -> Result<ListSchemasResponse, UnityCatalogError> {
let token = self.get_credential().await?;
let resp = self
.client
.get(format!("{}/schemas", self.catalog_url()))
.header(AUTHORIZATION, token)
.query(&[("catalog_name", catalog_name.as_ref())])
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn get_schema(
&self,
catalog_name: impl AsRef<str>,
schema_name: impl AsRef<str>,
) -> Result<GetSchemaResponse, UnityCatalogError> {
let token = self.get_credential().await?;
let resp = self
.client
.get(format!(
"{}/schemas/{}.{}",
self.catalog_url(),
catalog_name.as_ref(),
schema_name.as_ref()
))
.header(AUTHORIZATION, token)
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn list_table_summaries(
&self,
catalog_name: impl AsRef<str>,
schema_name_pattern: impl AsRef<str>,
) -> Result<ListTableSummariesResponse, UnityCatalogError> {
let token = self.get_credential().await?;
let resp = self
.client
.get(format!("{}/table-summaries", self.catalog_url()))
.query(&[
("catalog_name", catalog_name.as_ref()),
("schema_name_pattern", schema_name_pattern.as_ref()),
])
.header(AUTHORIZATION, token)
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn get_table(
&self,
catalog_id: impl AsRef<str>,
database_name: impl AsRef<str>,
table_name: impl AsRef<str>,
) -> Result<GetTableResponse, UnityCatalogError> {
let token = self.get_credential().await?;
let resp = self
.client
.get(format!(
"{}/tables/{}.{}.{}",
self.catalog_url(),
catalog_id.as_ref(),
database_name.as_ref(),
table_name.as_ref()
))
.header(AUTHORIZATION, token)
.send()
.await?;
Ok(resp.json().await?)
}
pub async fn get_temp_table_credentials(
&self,
catalog_id: impl AsRef<str>,
database_name: impl AsRef<str>,
table_name: impl AsRef<str>,
) -> Result<TableTempCredentialsResponse, UnityCatalogError> {
let token = self.get_credential().await?;
let table_info = self
.get_table(catalog_id, database_name, table_name)
.await?;
let response = match table_info {
GetTableResponse::Success(table) => {
let request = TemporaryTableCredentialsRequest::new(&table.table_id, "READ");
Ok(self
.client
.post(format!(
"{}/temporary-table-credentials",
self.catalog_url()
))
.header(AUTHORIZATION, token)
.json(&request)
.send()
.await?)
}
GetTableResponse::Error(err) => Err(UnityCatalogError::InvalidTable {
error_code: err.error_code,
message: err.message,
}),
}?;
Ok(response.json().await?)
}
}
/// [`DataCatalog`] implementation that resolves table storage locations through Unity Catalog.
#[async_trait::async_trait]
impl DataCatalog for UnityCatalog {
    type Error = UnityCatalogError;

    /// Returns the storage location of table `catalog.database_name.table_name`.
    ///
    /// The `main` catalog is used when `catalog_id` is `None`.
    ///
    /// # Errors
    /// Propagates errors from [`UnityCatalog::get_table`]. An error payload from the service
    /// becomes [`UnityCatalogError::InvalidTable`].
    async fn get_table_storage_location(
        &self,
        catalog_id: Option<String>,
        database_name: &str,
        table_name: &str,
    ) -> Result<String, UnityCatalogError> {
        let catalog = catalog_id.unwrap_or_else(|| String::from("main"));
        let response = self.get_table(catalog, database_name, table_name).await?;

        match response {
            GetTableResponse::Error(err) => Err(UnityCatalogError::InvalidTable {
                error_code: err.error_code,
                message: err.message,
            }),
            GetTableResponse::Success(table) => Ok(table.storage_location),
        }
    }
}
/// Prints only the type name `UnityCatalog`; no fields (in particular the credential) are
/// included in the output.
impl std::fmt::Debug for UnityCatalog {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("UnityCatalog")
    }
}
#[cfg(test)]
mod tests {
    use crate::client::ClientOptions;
    use crate::models::tests::{GET_SCHEMA_RESPONSE, GET_TABLE_RESPONSE, LIST_SCHEMAS_RESPONSE};
    use crate::models::*;
    use crate::UnityCatalogBuilder;
    use httpmock::prelude::*;

    /// Exercises schema and table lookups against a local mock of the Unity Catalog REST API.
    #[tokio::test]
    async fn test_unity_client() {
        let server = MockServer::start_async().await;
        let client = UnityCatalogBuilder::new()
            .with_workspace_url(server.url(""))
            .with_bearer_token("bearer_token")
            .with_client_options(ClientOptions::default().with_allow_http(true))
            .build()
            .unwrap();

        // Canned GET responses, keyed by request path.
        let routes = [
            ("/api/2.1/unity-catalog/schemas", LIST_SCHEMAS_RESPONSE),
            (
                "/api/2.1/unity-catalog/schemas/catalog_name.schema_name",
                GET_SCHEMA_RESPONSE,
            ),
            (
                "/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name",
                GET_TABLE_RESPONSE,
            ),
        ];
        for &(path, body) in routes.iter() {
            server
                .mock_async(|when, then| {
                    when.path(path).method("GET");
                    then.body(body);
                })
                .await;
        }

        let schemas = client.list_schemas("catalog_name").await.unwrap();
        assert!(matches!(schemas, ListSchemasResponse::Success { .. }));

        let schema = client
            .get_schema("catalog_name", "schema_name")
            .await
            .unwrap();
        assert!(matches!(schema, GetSchemaResponse::Success(_)));

        let table = client
            .get_table("catalog_name", "schema_name", "table_name")
            .await
            .unwrap();
        assert!(matches!(table, GetTableResponse::Success(_)));
    }
}