// lance_context/unified.rs

1use std::collections::HashSet;
2
3use lance_context_api::{
4    AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
5    ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest,
6    RetrieveResultDto, SearchRequest, SearchResultDto, UpdateRecordRequest, UpdateRecordResponse,
7    UpsertRecordRequest, UpsertRecordResponse, UpsertRecordsRequest, UpsertRecordsResponse,
8};
9use lance_context_core::{
10    ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType,
11};
12
13#[cfg(feature = "remote")]
14use lance_context_client::RemoteContextStore;
15
16pub enum ContextStore {
17    Local(Box<LocalStore>),
18    #[cfg(feature = "remote")]
19    Remote(RemoteContextStore),
20}
21
22impl ContextStore {
23    pub async fn open(uri: &str) -> Result<Self, ContextError> {
24        let store = LocalStore::open(uri)
25            .await
26            .map_err(|e| ContextError::Internal(e.to_string()))?;
27        Ok(Self::Local(Box::new(store)))
28    }
29
30    pub async fn open_with_options(
31        uri: &str,
32        storage_options: Option<std::collections::HashMap<String, String>>,
33        id_index_type: Option<&str>,
34        blob_columns: Option<Vec<String>>,
35        distance_metric: Option<&str>,
36    ) -> Result<Self, ContextError> {
37        let id_idx = match id_index_type {
38            Some("btree") => IdIndexType::BTree,
39            Some("zonemap") => IdIndexType::ZoneMap,
40            Some("none") | None => IdIndexType::None,
41            Some(other) => {
42                return Err(ContextError::InvalidRequest(format!(
43                    "Invalid id_index_type: '{other}'"
44                )));
45            }
46        };
47        let metric = match distance_metric {
48            Some(value) => Some(
49                DistanceMetric::parse(value)
50                    .map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
51            ),
52            None => None,
53        };
54        let options = ContextStoreOptions {
55            storage_options,
56            blob_columns: blob_columns
57                .unwrap_or_default()
58                .into_iter()
59                .collect::<HashSet<_>>(),
60            id_index_type: id_idx,
61            distance_metric: metric,
62            ..Default::default()
63        };
64        let store = LocalStore::open_with_options(uri, options)
65            .await
66            .map_err(|e| ContextError::Internal(e.to_string()))?;
67        Ok(Self::Local(Box::new(store)))
68    }
69
70    #[cfg(feature = "remote")]
71    pub async fn connect(base_url: &str, context_name: &str) -> Result<Self, ContextError> {
72        let store = RemoteContextStore::connect(base_url, context_name)
73            .await
74            .map_err(|e| ContextError::Internal(e.to_string()))?;
75        Ok(Self::Remote(store))
76    }
77
78    #[cfg(feature = "remote")]
79    pub async fn connect_or_create(
80        base_url: &str,
81        req: &lance_context_api::CreateContextRequest,
82    ) -> Result<Self, ContextError> {
83        let store = RemoteContextStore::connect_or_create(base_url, req)
84            .await
85            .map_err(|e| ContextError::Internal(e.to_string()))?;
86        Ok(Self::Remote(store))
87    }
88}
89
/// Forwards an async `&mut self` [`ContextStoreApi`] call to whichever backend
/// the given `ContextStore` holds, then awaits the result.
///
/// Invocation: `dispatch_mut!(store, method $(, arg)*)`.
macro_rules! dispatch_mut {
    ($store:expr, $op:ident $(, $param:expr)*) => {
        match $store {
            // Reborrow through the Box so the trait sees `&mut LocalStore`.
            ContextStore::Local(boxed) => ContextStoreApi::$op(&mut **boxed $(, $param)*).await,
            #[cfg(feature = "remote")]
            ContextStore::Remote(remote) => ContextStoreApi::$op(remote $(, $param)*).await,
        }
    };
}
99
/// Forwards an async `&self` [`ContextStoreApi`] call to whichever backend
/// the given `ContextStore` holds, then awaits the result.
///
/// Invocation: `dispatch_ref!(store, method $(, arg)*)`.
macro_rules! dispatch_ref {
    ($store:expr, $op:ident $(, $param:expr)*) => {
        match $store {
            // Reborrow through the Box so the trait sees `&LocalStore`.
            ContextStore::Local(boxed) => ContextStoreApi::$op(&**boxed $(, $param)*).await,
            #[cfg(feature = "remote")]
            ContextStore::Remote(remote) => ContextStoreApi::$op(remote $(, $param)*).await,
        }
    };
}
109
/// Forwards a synchronous `&self` [`ContextStoreApi`] call to whichever
/// backend the given `ContextStore` holds. Nothing is awaited.
///
/// Invocation: `dispatch_sync!(store, method $(, arg)*)`.
macro_rules! dispatch_sync {
    ($store:expr, $op:ident $(, $param:expr)*) => {
        match $store {
            // Reborrow through the Box so the trait sees `&LocalStore`.
            ContextStore::Local(boxed) => ContextStoreApi::$op(&**boxed $(, $param)*),
            #[cfg(feature = "remote")]
            ContextStore::Remote(remote) => ContextStoreApi::$op(remote $(, $param)*),
        }
    };
}
119
120impl ContextStoreApi for ContextStore {
121    async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
122        dispatch_mut!(self, add, records)
123    }
124
125    async fn upsert(
126        &mut self,
127        request: &UpsertRecordRequest,
128    ) -> ContextResult<UpsertRecordResponse> {
129        dispatch_mut!(self, upsert, request)
130    }
131
132    async fn upsert_many(
133        &mut self,
134        request: &UpsertRecordsRequest,
135    ) -> ContextResult<UpsertRecordsResponse> {
136        dispatch_mut!(self, upsert_many, request)
137    }
138
139    async fn update(
140        &mut self,
141        request: &UpdateRecordRequest,
142    ) -> ContextResult<UpdateRecordResponse> {
143        dispatch_mut!(self, update, request)
144    }
145
146    async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
147        dispatch_ref!(self, get, id)
148    }
149
150    async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
151        dispatch_ref!(self, get_by_external_id, external_id)
152    }
153
154    async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
155        dispatch_mut!(self, delete_by_id, id)
156    }
157
158    async fn delete_by_external_id(
159        &mut self,
160        external_id: &str,
161    ) -> ContextResult<DeleteRecordResponse> {
162        dispatch_mut!(self, delete_by_external_id, external_id)
163    }
164
165    async fn list(
166        &self,
167        limit: Option<usize>,
168        offset: Option<usize>,
169        filters: Option<serde_json::Value>,
170        include_expired: bool,
171        include_retired: bool,
172    ) -> ContextResult<Vec<RecordDto>> {
173        dispatch_ref!(
174            self,
175            list,
176            limit,
177            offset,
178            filters,
179            include_expired,
180            include_retired
181        )
182    }
183
184    async fn related(
185        &self,
186        target_id: &str,
187        relation: Option<&str>,
188        limit: Option<usize>,
189        include_expired: bool,
190        include_retired: bool,
191    ) -> ContextResult<Vec<RecordDto>> {
192        dispatch_ref!(
193            self,
194            related,
195            target_id,
196            relation,
197            limit,
198            include_expired,
199            include_retired
200        )
201    }
202
203    async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
204        dispatch_ref!(self, search, request)
205    }
206
207    async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
208        dispatch_ref!(self, retrieve, request)
209    }
210
211    fn version(&self) -> u64 {
212        dispatch_sync!(self, version)
213    }
214
215    async fn checkout(&mut self, version: u64) -> ContextResult<()> {
216        dispatch_mut!(self, checkout, version)
217    }
218
219    async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
220        dispatch_mut!(self, compact, options)
221    }
222
223    async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
224        dispatch_ref!(self, compaction_stats)
225    }
226}