1use std::collections::HashMap;
10
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14
15use crate::catalog_provider::*;
16
/// Connection settings for a Unity Catalog REST endpoint.
///
/// Build one with [`UnityCatalogConfig::new`] and refine it with the `with_*`
/// builder methods. `Debug` is implemented by hand so that formatting a config
/// (e.g. in logs) never reveals the bearer token.
#[derive(Clone)]
pub struct UnityCatalogConfig {
    /// Base URL that request paths such as `/catalogs` are appended to.
    /// [`UnityCatalogConfig::new`] strips any trailing `/` characters.
    pub base_url: String,
    /// Optional token attached to every request as bearer authentication.
    pub bearer_token: Option<String>,
    /// Optional timeout in seconds handed to the HTTP client builder;
    /// `None` leaves the client's default in place.
    pub timeout_secs: Option<u64>,
}

impl std::fmt::Debug for UnityCatalogConfig {
    /// Formats every field except the token, which is shown only as
    /// `Some("<redacted>")` / `None` so the secret cannot leak through `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UnityCatalogConfig")
            .field("base_url", &self.base_url)
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("timeout_secs", &self.timeout_secs)
            .finish()
    }
}

impl UnityCatalogConfig {
    /// Creates a config for `base_url` with no token and no timeout.
    ///
    /// All trailing `/` characters are removed so that appending request
    /// paths (which begin with `/`) does not produce `//`.
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into().trim_end_matches('/').to_string(),
            bearer_token: None,
            timeout_secs: None,
        }
    }

    /// Returns the config with `token` set as the bearer token.
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.bearer_token = Some(token.into());
        self
    }

    /// Returns the config with the request timeout set to `secs` seconds.
    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = Some(secs);
        self
    }
}
47
48pub struct UnityCatalogProvider {
50 config: UnityCatalogConfig,
51 client: Client,
52}
53
54impl UnityCatalogProvider {
55 pub fn new(config: UnityCatalogConfig) -> CatalogResult<Self> {
56 let mut builder = Client::builder();
57 if let Some(timeout) = config.timeout_secs {
58 builder = builder.timeout(std::time::Duration::from_secs(timeout));
59 }
60 let client = builder.build().map_err(|e| {
61 CatalogError::ConnectionError(format!("Failed to build HTTP client: {}", e))
62 })?;
63
64 Ok(Self { config, client })
65 }
66
67 fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}{}", self.config.base_url, path);
69 let mut req = self.client.request(method, &url);
70 if let Some(ref token) = self.config.bearer_token {
71 req = req.bearer_auth(token);
72 }
73 req
74 }
75
76 async fn handle_response<T: serde::de::DeserializeOwned>(
77 &self,
78 resp: reqwest::Response,
79 resource_name: &str,
80 ) -> CatalogResult<T> {
81 let status = resp.status();
82
83 if status == reqwest::StatusCode::NOT_FOUND {
84 return Err(CatalogError::NotFound(format!(
85 "{} not found",
86 resource_name
87 )));
88 }
89 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
90 let body = resp.text().await.unwrap_or_default();
91 return Err(CatalogError::AuthError(format!(
92 "HTTP {}: {}",
93 status, body
94 )));
95 }
96 if !status.is_success() {
97 let body = resp.text().await.unwrap_or_default();
98 return Err(CatalogError::ConnectionError(format!(
99 "HTTP {}: {}",
100 status, body
101 )));
102 }
103
104 resp.json::<T>()
105 .await
106 .map_err(|e| CatalogError::InvalidResponse(e.to_string()))
107 }
108}
109
110#[derive(Deserialize)]
113struct ListCatalogsResponse {
114 #[serde(default)]
115 catalogs: Vec<UcCatalog>,
116}
117
118#[derive(Deserialize)]
119struct UcCatalog {
120 name: String,
121 comment: Option<String>,
122 #[serde(default)]
123 properties: HashMap<String, String>,
124 created_at: Option<i64>,
125 updated_at: Option<i64>,
126}
127
128#[derive(Deserialize)]
129struct ListSchemasResponse {
130 #[serde(default)]
131 schemas: Vec<UcSchema>,
132}
133
134#[derive(Deserialize)]
135struct UcSchema {
136 name: String,
137 catalog_name: String,
138 comment: Option<String>,
139 #[serde(default)]
140 properties: HashMap<String, String>,
141 created_at: Option<i64>,
142 updated_at: Option<i64>,
143}
144
145#[derive(Deserialize)]
146struct ListTablesResponse {
147 #[serde(default)]
148 tables: Vec<UcTable>,
149}
150
151#[derive(Deserialize)]
152struct UcTable {
153 name: String,
154 catalog_name: String,
155 schema_name: String,
156 table_type: String,
157 data_source_format: Option<String>,
158 #[serde(default)]
159 columns: Vec<UcColumn>,
160 storage_location: Option<String>,
161 comment: Option<String>,
162 #[serde(default)]
163 properties: HashMap<String, String>,
164 created_at: Option<i64>,
165 updated_at: Option<i64>,
166}
167
168#[derive(Deserialize)]
169struct UcColumn {
170 name: String,
171 type_text: String,
172 type_name: String,
173 position: i32,
174 #[serde(default = "default_nullable")]
175 nullable: bool,
176 comment: Option<String>,
177}
178
179fn default_nullable() -> bool {
180 true
181}
182
183impl From<UcCatalog> for CatalogInfo {
186 fn from(uc: UcCatalog) -> Self {
187 Self {
188 name: uc.name,
189 comment: uc.comment,
190 properties: uc.properties,
191 created_at: uc.created_at,
192 updated_at: uc.updated_at,
193 }
194 }
195}
196
197impl From<UcSchema> for SchemaInfo {
198 fn from(uc: UcSchema) -> Self {
199 Self {
200 name: uc.name,
201 catalog_name: uc.catalog_name,
202 comment: uc.comment,
203 properties: uc.properties,
204 created_at: uc.created_at,
205 updated_at: uc.updated_at,
206 }
207 }
208}
209
210impl From<UcTable> for TableInfo {
211 fn from(uc: UcTable) -> Self {
212 Self {
213 name: uc.name,
214 catalog_name: uc.catalog_name,
215 schema_name: uc.schema_name,
216 table_type: match uc.table_type.as_str() {
217 "EXTERNAL" => TableType::External,
218 _ => TableType::Managed,
219 },
220 data_source_format: match uc.data_source_format.as_deref() {
221 Some("DELTA") => DataSourceFormat::Delta,
222 Some("PARQUET") => DataSourceFormat::Parquet,
223 Some("CSV") => DataSourceFormat::Csv,
224 Some("JSON") => DataSourceFormat::Json,
225 Some("AVRO") => DataSourceFormat::Avro,
226 Some("ORC") => DataSourceFormat::Orc,
227 Some("TEXT") => DataSourceFormat::Text,
228 Some(other) => DataSourceFormat::Other(other.to_string()),
229 None => DataSourceFormat::Other("UNKNOWN".to_string()),
230 },
231 columns: uc.columns.into_iter().map(Into::into).collect(),
232 storage_location: uc.storage_location,
233 comment: uc.comment,
234 properties: uc.properties,
235 created_at: uc.created_at,
236 updated_at: uc.updated_at,
237 }
238 }
239}
240
241impl From<UcColumn> for ColumnInfo {
242 fn from(uc: UcColumn) -> Self {
243 Self {
244 name: uc.name,
245 type_text: uc.type_text,
246 type_name: uc.type_name,
247 position: uc.position,
248 nullable: uc.nullable,
249 comment: uc.comment,
250 }
251 }
252}
253
254#[async_trait]
257impl CatalogProvider for UnityCatalogProvider {
258 fn name(&self) -> &str {
259 "unity-catalog"
260 }
261
262 async fn list_catalogs(&self) -> CatalogResult<Vec<CatalogInfo>> {
263 let resp = self
264 .request(reqwest::Method::GET, "/catalogs")
265 .send()
266 .await
267 .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
268
269 let body: ListCatalogsResponse = self.handle_response(resp, "catalogs").await?;
270 Ok(body.catalogs.into_iter().map(Into::into).collect())
271 }
272
273 async fn get_catalog(&self, name: &str) -> CatalogResult<CatalogInfo> {
274 let resp = self
275 .request(reqwest::Method::GET, &format!("/catalogs/{}", name))
276 .send()
277 .await
278 .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
279
280 let body: UcCatalog = self
281 .handle_response(resp, &format!("catalog '{}'", name))
282 .await?;
283 Ok(body.into())
284 }
285
286 async fn list_schemas(&self, catalog_name: &str) -> CatalogResult<Vec<SchemaInfo>> {
287 let resp = self
288 .request(reqwest::Method::GET, "/schemas")
289 .query(&[("catalog_name", catalog_name)])
290 .send()
291 .await
292 .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
293
294 let body: ListSchemasResponse = self
295 .handle_response(resp, &format!("schemas in '{}'", catalog_name))
296 .await?;
297 Ok(body.schemas.into_iter().map(Into::into).collect())
298 }
299
300 async fn get_schema(&self, catalog_name: &str, schema_name: &str) -> CatalogResult<SchemaInfo> {
301 let full_name = format!("{}.{}", catalog_name, schema_name);
302 let resp = self
303 .request(reqwest::Method::GET, &format!("/schemas/{}", full_name))
304 .send()
305 .await
306 .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
307
308 let body: UcSchema = self
309 .handle_response(resp, &format!("schema '{}'", full_name))
310 .await?;
311 Ok(body.into())
312 }
313
314 async fn list_tables(
315 &self,
316 catalog_name: &str,
317 schema_name: &str,
318 ) -> CatalogResult<Vec<TableInfo>> {
319 let resp = self
320 .request(reqwest::Method::GET, "/tables")
321 .query(&[("catalog_name", catalog_name), ("schema_name", schema_name)])
322 .send()
323 .await
324 .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
325
326 let body: ListTablesResponse = self
327 .handle_response(
328 resp,
329 &format!("tables in '{}.{}'", catalog_name, schema_name),
330 )
331 .await?;
332 Ok(body.tables.into_iter().map(Into::into).collect())
333 }
334
335 async fn get_table(
336 &self,
337 catalog_name: &str,
338 schema_name: &str,
339 table_name: &str,
340 ) -> CatalogResult<TableInfo> {
341 let full_name = format!("{}.{}.{}", catalog_name, schema_name, table_name);
342 let resp = self
343 .request(reqwest::Method::GET, &format!("/tables/{}", full_name))
344 .send()
345 .await
346 .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
347
348 let body: UcTable = self
349 .handle_response(resp, &format!("table '{}'", full_name))
350 .await?;
351 Ok(body.into())
352 }
353}