Skip to main content

lance_graph_catalog/
unity_catalog.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Unity Catalog REST API client implementing the [`CatalogProvider`] trait.
5//!
6//! Connects to an OSS Unity Catalog server and provides catalog browsing
7//! capabilities. This is the first `CatalogProvider` implementation.
8
9use std::collections::HashMap;
10
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14
15use crate::catalog_provider::*;
16
/// Configuration for connecting to a Unity Catalog server.
///
/// Built with [`UnityCatalogConfig::new`] and refined with the `with_*`
/// builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone)]
pub struct UnityCatalogConfig {
    /// Base URL of the UC server (e.g., `http://localhost:8080/api/2.1/unity-catalog`).
    pub base_url: String,
    /// Optional bearer token for authenticated access.
    pub bearer_token: Option<String>,
    /// Optional request timeout in seconds (default: 30).
    pub timeout_secs: Option<u64>,
}

impl UnityCatalogConfig {
    /// Creates a configuration pointing at `base_url`, with no token and no
    /// explicit timeout.
    ///
    /// All trailing `/` characters are stripped so that request paths
    /// (which start with `/`) can be appended directly.
    pub fn new(base_url: impl Into<String>) -> Self {
        let raw: String = base_url.into();
        let normalized = raw.trim_end_matches('/').to_string();
        Self {
            base_url: normalized,
            bearer_token: None,
            timeout_secs: None,
        }
    }

    /// Returns the config with `token` set as the bearer token.
    pub fn with_token(self, token: impl Into<String>) -> Self {
        Self {
            bearer_token: Some(token.into()),
            ..self
        }
    }

    /// Returns the config with the request timeout set to `secs` seconds.
    pub fn with_timeout(self, secs: u64) -> Self {
        Self {
            timeout_secs: Some(secs),
            ..self
        }
    }
}
47
48/// Unity Catalog REST API client.
49pub struct UnityCatalogProvider {
50    config: UnityCatalogConfig,
51    client: Client,
52}
53
54impl UnityCatalogProvider {
55    pub fn new(config: UnityCatalogConfig) -> CatalogResult<Self> {
56        let mut builder = Client::builder();
57        if let Some(timeout) = config.timeout_secs {
58            builder = builder.timeout(std::time::Duration::from_secs(timeout));
59        }
60        let client = builder.build().map_err(|e| {
61            CatalogError::ConnectionError(format!("Failed to build HTTP client: {}", e))
62        })?;
63
64        Ok(Self { config, client })
65    }
66
67    fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
68        let url = format!("{}{}", self.config.base_url, path);
69        let mut req = self.client.request(method, &url);
70        if let Some(ref token) = self.config.bearer_token {
71            req = req.bearer_auth(token);
72        }
73        req
74    }
75
76    async fn handle_response<T: serde::de::DeserializeOwned>(
77        &self,
78        resp: reqwest::Response,
79        resource_name: &str,
80    ) -> CatalogResult<T> {
81        let status = resp.status();
82
83        if status == reqwest::StatusCode::NOT_FOUND {
84            return Err(CatalogError::NotFound(format!(
85                "{} not found",
86                resource_name
87            )));
88        }
89        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
90            let body = resp.text().await.unwrap_or_default();
91            return Err(CatalogError::AuthError(format!(
92                "HTTP {}: {}",
93                status, body
94            )));
95        }
96        if !status.is_success() {
97            let body = resp.text().await.unwrap_or_default();
98            return Err(CatalogError::ConnectionError(format!(
99                "HTTP {}: {}",
100                status, body
101            )));
102        }
103
104        resp.json::<T>()
105            .await
106            .map_err(|e| CatalogError::InvalidResponse(e.to_string()))
107    }
108}
109
110// ---- Serde models for UC REST API JSON responses ----
111
112#[derive(Deserialize)]
113struct ListCatalogsResponse {
114    #[serde(default)]
115    catalogs: Vec<UcCatalog>,
116}
117
118#[derive(Deserialize)]
119struct UcCatalog {
120    name: String,
121    comment: Option<String>,
122    #[serde(default)]
123    properties: HashMap<String, String>,
124    created_at: Option<i64>,
125    updated_at: Option<i64>,
126}
127
128#[derive(Deserialize)]
129struct ListSchemasResponse {
130    #[serde(default)]
131    schemas: Vec<UcSchema>,
132}
133
134#[derive(Deserialize)]
135struct UcSchema {
136    name: String,
137    catalog_name: String,
138    comment: Option<String>,
139    #[serde(default)]
140    properties: HashMap<String, String>,
141    created_at: Option<i64>,
142    updated_at: Option<i64>,
143}
144
145#[derive(Deserialize)]
146struct ListTablesResponse {
147    #[serde(default)]
148    tables: Vec<UcTable>,
149}
150
151#[derive(Deserialize)]
152struct UcTable {
153    name: String,
154    catalog_name: String,
155    schema_name: String,
156    table_type: String,
157    data_source_format: Option<String>,
158    #[serde(default)]
159    columns: Vec<UcColumn>,
160    storage_location: Option<String>,
161    comment: Option<String>,
162    #[serde(default)]
163    properties: HashMap<String, String>,
164    created_at: Option<i64>,
165    updated_at: Option<i64>,
166}
167
168#[derive(Deserialize)]
169struct UcColumn {
170    name: String,
171    type_text: String,
172    type_name: String,
173    position: i32,
174    #[serde(default = "default_nullable")]
175    nullable: bool,
176    comment: Option<String>,
177}
178
179fn default_nullable() -> bool {
180    true
181}
182
183// ---- Conversion helpers ----
184
185impl From<UcCatalog> for CatalogInfo {
186    fn from(uc: UcCatalog) -> Self {
187        Self {
188            name: uc.name,
189            comment: uc.comment,
190            properties: uc.properties,
191            created_at: uc.created_at,
192            updated_at: uc.updated_at,
193        }
194    }
195}
196
197impl From<UcSchema> for SchemaInfo {
198    fn from(uc: UcSchema) -> Self {
199        Self {
200            name: uc.name,
201            catalog_name: uc.catalog_name,
202            comment: uc.comment,
203            properties: uc.properties,
204            created_at: uc.created_at,
205            updated_at: uc.updated_at,
206        }
207    }
208}
209
210impl From<UcTable> for TableInfo {
211    fn from(uc: UcTable) -> Self {
212        Self {
213            name: uc.name,
214            catalog_name: uc.catalog_name,
215            schema_name: uc.schema_name,
216            table_type: match uc.table_type.as_str() {
217                "EXTERNAL" => TableType::External,
218                _ => TableType::Managed,
219            },
220            data_source_format: match uc.data_source_format.as_deref() {
221                Some("DELTA") => DataSourceFormat::Delta,
222                Some("PARQUET") => DataSourceFormat::Parquet,
223                Some("CSV") => DataSourceFormat::Csv,
224                Some("JSON") => DataSourceFormat::Json,
225                Some("AVRO") => DataSourceFormat::Avro,
226                Some("ORC") => DataSourceFormat::Orc,
227                Some("TEXT") => DataSourceFormat::Text,
228                Some(other) => DataSourceFormat::Other(other.to_string()),
229                None => DataSourceFormat::Other("UNKNOWN".to_string()),
230            },
231            columns: uc.columns.into_iter().map(Into::into).collect(),
232            storage_location: uc.storage_location,
233            comment: uc.comment,
234            properties: uc.properties,
235            created_at: uc.created_at,
236            updated_at: uc.updated_at,
237        }
238    }
239}
240
241impl From<UcColumn> for ColumnInfo {
242    fn from(uc: UcColumn) -> Self {
243        Self {
244            name: uc.name,
245            type_text: uc.type_text,
246            type_name: uc.type_name,
247            position: uc.position,
248            nullable: uc.nullable,
249            comment: uc.comment,
250        }
251    }
252}
253
254// ---- CatalogProvider implementation ----
255
256#[async_trait]
257impl CatalogProvider for UnityCatalogProvider {
258    fn name(&self) -> &str {
259        "unity-catalog"
260    }
261
262    async fn list_catalogs(&self) -> CatalogResult<Vec<CatalogInfo>> {
263        let resp = self
264            .request(reqwest::Method::GET, "/catalogs")
265            .send()
266            .await
267            .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
268
269        let body: ListCatalogsResponse = self.handle_response(resp, "catalogs").await?;
270        Ok(body.catalogs.into_iter().map(Into::into).collect())
271    }
272
273    async fn get_catalog(&self, name: &str) -> CatalogResult<CatalogInfo> {
274        let resp = self
275            .request(reqwest::Method::GET, &format!("/catalogs/{}", name))
276            .send()
277            .await
278            .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
279
280        let body: UcCatalog = self
281            .handle_response(resp, &format!("catalog '{}'", name))
282            .await?;
283        Ok(body.into())
284    }
285
286    async fn list_schemas(&self, catalog_name: &str) -> CatalogResult<Vec<SchemaInfo>> {
287        let resp = self
288            .request(reqwest::Method::GET, "/schemas")
289            .query(&[("catalog_name", catalog_name)])
290            .send()
291            .await
292            .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
293
294        let body: ListSchemasResponse = self
295            .handle_response(resp, &format!("schemas in '{}'", catalog_name))
296            .await?;
297        Ok(body.schemas.into_iter().map(Into::into).collect())
298    }
299
300    async fn get_schema(&self, catalog_name: &str, schema_name: &str) -> CatalogResult<SchemaInfo> {
301        let full_name = format!("{}.{}", catalog_name, schema_name);
302        let resp = self
303            .request(reqwest::Method::GET, &format!("/schemas/{}", full_name))
304            .send()
305            .await
306            .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
307
308        let body: UcSchema = self
309            .handle_response(resp, &format!("schema '{}'", full_name))
310            .await?;
311        Ok(body.into())
312    }
313
314    async fn list_tables(
315        &self,
316        catalog_name: &str,
317        schema_name: &str,
318    ) -> CatalogResult<Vec<TableInfo>> {
319        let resp = self
320            .request(reqwest::Method::GET, "/tables")
321            .query(&[("catalog_name", catalog_name), ("schema_name", schema_name)])
322            .send()
323            .await
324            .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
325
326        let body: ListTablesResponse = self
327            .handle_response(
328                resp,
329                &format!("tables in '{}.{}'", catalog_name, schema_name),
330            )
331            .await?;
332        Ok(body.tables.into_iter().map(Into::into).collect())
333    }
334
335    async fn get_table(
336        &self,
337        catalog_name: &str,
338        schema_name: &str,
339        table_name: &str,
340    ) -> CatalogResult<TableInfo> {
341        let full_name = format!("{}.{}.{}", catalog_name, schema_name, table_name);
342        let resp = self
343            .request(reqwest::Method::GET, &format!("/tables/{}", full_name))
344            .send()
345            .await
346            .map_err(|e| CatalogError::ConnectionError(e.to_string()))?;
347
348        let body: UcTable = self
349            .handle_response(resp, &format!("table '{}'", full_name))
350            .await?;
351        Ok(body.into())
352    }
353}