1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use arrow_array::RecordBatch;
6use arrow_array::RecordBatchReader;
7use arrow_schema::{Schema as ArrowSchema, SchemaRef};
8use async_trait::async_trait;
9use chrono::{DateTime, TimeZone, Utc};
10use futures::StreamExt;
11use lance::Dataset as InnerLance;
12use lance::dataset::ProjectionRequest;
13use lance_index::DatasetIndexExt as _;
14
15use crate::Result;
16use crate::cli::LanceArgs;
17use crate::dataset::{
18 BatchStream, BranchInfo, Dataset, IndexInfo, LanceCapabilities, TagInfo, VersionInfo,
19};
20use crate::error::Error;
21
/// Name of the default branch. Tags and branches whose recorded branch is absent are
/// attributed to this branch throughout this module.
const MAIN_BRANCH: &str = "main";
23
24#[derive(Debug)]
25pub struct LanceDataset {
26 inner: InnerLance,
27 origin: PathBuf,
28 arrow_schema: SchemaRef,
29}
30
31impl LanceDataset {
32 pub async fn open(path: &Path, lance: Option<&LanceArgs>) -> Result<Self> {
33 let uri = path.to_string_lossy().into_owned();
34 let inner = InnerLance::open(&uri).await.map_err(|e| Error::LanceOpen {
35 path: path.to_path_buf(),
36 source: Box::new(e),
37 })?;
38 let inner = apply_checkout(inner, lance).await?;
39 let arrow_schema: SchemaRef = Arc::new(ArrowSchema::from(inner.schema()));
40 Ok(Self {
41 inner,
42 origin: path.to_path_buf(),
43 arrow_schema,
44 })
45 }
46
47 fn projection_request(&self, projection: Option<&[String]>) -> ProjectionRequest {
48 match projection {
49 Some(cols) => ProjectionRequest::from_columns(cols.iter(), self.inner.schema()),
50 None => ProjectionRequest::from_schema(self.inner.schema().clone()),
51 }
52 }
53}
54
55async fn apply_checkout(mut ds: InnerLance, lance: Option<&LanceArgs>) -> Result<InnerLance> {
56 let Some(args) = lance else { return Ok(ds) };
57
58 if let Some(tag) = &args.tag {
59 if let Some(requested) = &args.branch {
62 let content = ds
63 .tags()
64 .get(tag)
65 .await
66 .map_err(|e| Error::Lance(Box::new(e)))?;
67 let tag_branch = content.branch.as_deref().unwrap_or(MAIN_BRANCH);
68 if tag_branch != requested.as_str() {
69 return Err(Error::TagBranchMismatch {
70 tag: tag.clone(),
71 tag_branch: tag_branch.to_string(),
72 requested_branch: requested.clone(),
73 });
74 }
75 }
76 ds = ds
78 .checkout_version(tag.as_str())
79 .await
80 .map_err(|e| Error::Lance(Box::new(e)))?;
81 return Ok(ds);
82 }
83
84 if let Some(branch) = &args.branch {
85 ds = ds
86 .checkout_branch(branch)
87 .await
88 .map_err(|e| Error::Lance(Box::new(e)))?;
89 }
90 if let Some(version) = args.version {
91 ds = ds
92 .checkout_version(version)
93 .await
94 .map_err(|e| Error::Lance(Box::new(e)))?;
95 }
96 Ok(ds)
97}
98
99#[async_trait]
100impl Dataset for LanceDataset {
101 fn origin(&self) -> &Path {
102 &self.origin
103 }
104
105 fn arrow_schema(&self) -> SchemaRef {
106 self.arrow_schema.clone()
107 }
108
109 fn physical_schema_debug(&self, projection: Option<&[String]>) -> Result<String> {
110 match projection {
111 None => Ok(format!("{:#?}", self.inner.schema())),
112 Some(cols) => {
113 let projected = self
114 .inner
115 .schema()
116 .project(cols)
117 .map_err(|e| Error::Lance(Box::new(e)))?;
118 Ok(format!("{projected:#?}"))
119 }
120 }
121 }
122
123 async fn count_rows(&self) -> Result<u64> {
124 let n = self
125 .inner
126 .count_rows(None)
127 .await
128 .map_err(|e| Error::Lance(Box::new(e)))?;
129 Ok(n as u64)
130 }
131
132 async fn scan(&self, projection: Option<&[String]>) -> Result<BatchStream> {
133 let mut scanner = self.inner.scan();
134 if let Some(cols) = projection {
135 scanner
136 .project(cols)
137 .map_err(|e| Error::Lance(Box::new(e)))?;
138 }
139 let stream = scanner
140 .try_into_stream()
141 .await
142 .map_err(|e| Error::Lance(Box::new(e)))?;
143 let stream = stream.map(|r| r.map_err(|e| Error::Lance(Box::new(e))));
144 Ok(Box::pin(stream))
145 }
146
147 async fn take(&self, indices: &[u64], projection: Option<&[String]>) -> Result<RecordBatch> {
148 let req = self.projection_request(projection);
149 self.inner
150 .take(indices, req)
151 .await
152 .map_err(|e| Error::Lance(Box::new(e)))
153 }
154
155 fn lance(&self) -> Option<&dyn LanceCapabilities> {
156 Some(self)
157 }
158}
159
160#[async_trait]
161impl LanceCapabilities for LanceDataset {
162 async fn list_versions(
163 &self,
164 branch: Option<&str>,
165 tagged_only: bool,
166 ) -> Result<Vec<VersionInfo>> {
167 let scoped = match branch {
169 Some(b) if b != MAIN_BRANCH => self
170 .inner
171 .clone()
172 .checkout_branch(b)
173 .await
174 .map_err(|e| Error::Lance(Box::new(e)))?,
175 _ => self.inner.clone(),
176 };
177 let target_branch = branch.unwrap_or(MAIN_BRANCH);
178
179 let versions = scoped
180 .versions()
181 .await
182 .map_err(|e| Error::Lance(Box::new(e)))?;
183
184 let tags = self
187 .inner
188 .tags()
189 .list()
190 .await
191 .map_err(|e| Error::Lance(Box::new(e)))?;
192 let mut tags_for_version: HashMap<u64, Vec<String>> = HashMap::new();
193 for (name, content) in tags {
194 let content_branch = content.branch.as_deref().unwrap_or(MAIN_BRANCH);
195 if content_branch == target_branch {
196 tags_for_version
197 .entry(content.version)
198 .or_default()
199 .push(name);
200 }
201 }
202
203 let mut out: Vec<VersionInfo> = versions
204 .into_iter()
205 .map(|v| {
206 let mut tag_names = tags_for_version.remove(&v.version).unwrap_or_default();
207 tag_names.sort();
208 let tag = if tag_names.is_empty() {
209 None
210 } else {
211 Some(tag_names.join(","))
212 };
213 let message = v.metadata.get("message").cloned();
214 VersionInfo {
215 version: v.version,
216 timestamp: v.timestamp,
217 tag,
218 message,
219 }
220 })
221 .collect();
222
223 if tagged_only {
224 out.retain(|v| v.tag.is_some());
225 }
226 Ok(out)
227 }
228
229 async fn list_branches(&self) -> Result<Vec<BranchInfo>> {
230 let map = self
231 .inner
232 .list_branches()
233 .await
234 .map_err(|e| Error::Lance(Box::new(e)))?;
235
236 let mut out: Vec<BranchInfo> = map
240 .into_iter()
241 .map(|(name, content)| BranchInfo {
242 name,
243 parent_branch: Some(
244 content
245 .parent_branch
246 .unwrap_or_else(|| MAIN_BRANCH.to_string()),
247 ),
248 parent_version: Some(content.parent_version),
249 created_at: unix_seconds_to_utc(content.create_at),
250 })
251 .collect();
252
253 if !out.iter().any(|b| b.name == MAIN_BRANCH) {
258 let main_inner = self
259 .inner
260 .clone()
261 .checkout_branch(MAIN_BRANCH)
262 .await
263 .map_err(|e| Error::Lance(Box::new(e)))?;
264 let main_created_at = main_inner
265 .versions()
266 .await
267 .map_err(|e| Error::Lance(Box::new(e)))?
268 .into_iter()
269 .next()
270 .map(|v| v.timestamp);
271 out.insert(
272 0,
273 BranchInfo {
274 name: MAIN_BRANCH.to_string(),
275 parent_branch: None,
276 parent_version: None,
277 created_at: main_created_at,
278 },
279 );
280 }
281 out.sort_by(|a, b| a.name.cmp(&b.name));
282 Ok(out)
283 }
284
285 async fn list_tags(&self) -> Result<Vec<TagInfo>> {
286 let tags = self
287 .inner
288 .tags()
289 .list()
290 .await
291 .map_err(|e| Error::Lance(Box::new(e)))?;
292 let mut out: Vec<TagInfo> = tags
293 .into_iter()
294 .map(|(name, content)| TagInfo {
295 name,
296 branch: content.branch.unwrap_or_else(|| MAIN_BRANCH.to_string()),
297 version: content.version,
298 })
299 .collect();
300 out.sort_by(|a, b| a.name.cmp(&b.name));
301 Ok(out)
302 }
303
304 async fn list_indices(&self) -> Result<Vec<IndexInfo>> {
305 let indices = self
306 .inner
307 .load_indices()
308 .await
309 .map_err(|e| Error::Lance(Box::new(e)))?;
310 let schema = self.inner.schema();
311
312 Ok(indices
313 .iter()
314 .map(|m| {
315 let columns = m
316 .fields
317 .iter()
318 .map(|id| {
319 schema
320 .field_by_id(*id)
321 .map(|f| f.name.clone())
322 .unwrap_or_else(|| format!("<field_id={id}>"))
323 })
324 .collect();
325 IndexInfo {
326 name: m.name.clone(),
327 uuid: m.uuid.to_string(),
328 columns,
329 dataset_version: m.dataset_version,
330 created_at: m.created_at,
331 }
332 })
333 .collect())
334 }
335}
336
337fn unix_seconds_to_utc(seconds: u64) -> Option<DateTime<Utc>> {
338 let secs = i64::try_from(seconds).ok()?;
339 Utc.timestamp_opt(secs, 0).single()
340}
341
342pub async fn write_dataset<R>(path: &Path, reader: R) -> Result<()>
346where
347 R: RecordBatchReader + Send + 'static,
348{
349 let uri = path.to_string_lossy().into_owned();
350 InnerLance::write(reader, uri.as_str(), None)
351 .await
352 .map_err(|e| Error::Lance(Box::new(e)))?;
353 Ok(())
354}