1use std::collections::HashSet;
2
3use lance_context_api::{
4 AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
5 ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest,
6 RetrieveResultDto, SearchRequest, SearchResultDto, UpdateRecordRequest, UpdateRecordResponse,
7 UpsertRecordRequest, UpsertRecordResponse, UpsertRecordsRequest, UpsertRecordsResponse,
8};
9use lance_context_core::{
10 ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType,
11};
12
13#[cfg(feature = "remote")]
14use lance_context_client::RemoteContextStore;
15
16pub enum ContextStore {
17 Local(Box<LocalStore>),
18 #[cfg(feature = "remote")]
19 Remote(RemoteContextStore),
20}
21
22impl ContextStore {
23 pub async fn open(uri: &str) -> Result<Self, ContextError> {
24 let store = LocalStore::open(uri)
25 .await
26 .map_err(|e| ContextError::Internal(e.to_string()))?;
27 Ok(Self::Local(Box::new(store)))
28 }
29
30 pub async fn open_with_options(
31 uri: &str,
32 storage_options: Option<std::collections::HashMap<String, String>>,
33 id_index_type: Option<&str>,
34 blob_columns: Option<Vec<String>>,
35 distance_metric: Option<&str>,
36 ) -> Result<Self, ContextError> {
37 let id_idx = match id_index_type {
38 Some("btree") => IdIndexType::BTree,
39 Some("zonemap") => IdIndexType::ZoneMap,
40 Some("none") | None => IdIndexType::None,
41 Some(other) => {
42 return Err(ContextError::InvalidRequest(format!(
43 "Invalid id_index_type: '{other}'"
44 )));
45 }
46 };
47 let metric = match distance_metric {
48 Some(value) => Some(
49 DistanceMetric::parse(value)
50 .map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
51 ),
52 None => None,
53 };
54 let options = ContextStoreOptions {
55 storage_options,
56 blob_columns: blob_columns
57 .unwrap_or_default()
58 .into_iter()
59 .collect::<HashSet<_>>(),
60 id_index_type: id_idx,
61 distance_metric: metric,
62 ..Default::default()
63 };
64 let store = LocalStore::open_with_options(uri, options)
65 .await
66 .map_err(|e| ContextError::Internal(e.to_string()))?;
67 Ok(Self::Local(Box::new(store)))
68 }
69
70 #[cfg(feature = "remote")]
71 pub async fn connect(base_url: &str, context_name: &str) -> Result<Self, ContextError> {
72 let store = RemoteContextStore::connect(base_url, context_name)
73 .await
74 .map_err(|e| ContextError::Internal(e.to_string()))?;
75 Ok(Self::Remote(store))
76 }
77
78 #[cfg(feature = "remote")]
79 pub async fn connect_or_create(
80 base_url: &str,
81 req: &lance_context_api::CreateContextRequest,
82 ) -> Result<Self, ContextError> {
83 let store = RemoteContextStore::connect_or_create(base_url, req)
84 .await
85 .map_err(|e| ContextError::Internal(e.to_string()))?;
86 Ok(Self::Remote(store))
87 }
88}
89
/// Forwards a `&mut self` async [`ContextStoreApi`] method to the backend wrapped by
/// `$store` and awaits it.
///
/// `$store` is expected to be a `&mut ContextStore`. The backend is passed as the
/// receiver, followed by every extra argument in the order given.
macro_rules! dispatch_mut {
    ($store:expr, $name:ident $(, $extra:expr)*) => {
        match $store {
            ContextStore::Local(backend) => {
                // Reborrow through the Box to get `&mut LocalStore`.
                ContextStoreApi::$name(&mut **backend $(, $extra)*).await
            }
            #[cfg(feature = "remote")]
            ContextStore::Remote(backend) => {
                ContextStoreApi::$name(backend $(, $extra)*).await
            }
        }
    };
}
99
/// Forwards a `&self` async [`ContextStoreApi`] method to the backend wrapped by
/// `$store` and awaits it.
///
/// `$store` is expected to be a `&ContextStore`. The backend is passed as the receiver,
/// followed by every extra argument in the order given.
macro_rules! dispatch_ref {
    ($store:expr, $name:ident $(, $extra:expr)*) => {
        match $store {
            ContextStore::Local(backend) => {
                // Reborrow through the Box to get `&LocalStore`.
                ContextStoreApi::$name(&**backend $(, $extra)*).await
            }
            #[cfg(feature = "remote")]
            ContextStore::Remote(backend) => {
                ContextStoreApi::$name(backend $(, $extra)*).await
            }
        }
    };
}
109
/// Forwards a synchronous `&self` [`ContextStoreApi`] method to the backend wrapped by
/// `$store` (nothing is awaited).
///
/// `$store` is expected to be a `&ContextStore`. The backend is passed as the receiver,
/// followed by every extra argument in the order given.
macro_rules! dispatch_sync {
    ($store:expr, $name:ident $(, $extra:expr)*) => {
        match $store {
            ContextStore::Local(backend) => ContextStoreApi::$name(&**backend $(, $extra)*),
            #[cfg(feature = "remote")]
            ContextStore::Remote(backend) => ContextStoreApi::$name(backend $(, $extra)*),
        }
    };
}
119
120impl ContextStoreApi for ContextStore {
121 async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
122 dispatch_mut!(self, add, records)
123 }
124
125 async fn upsert(
126 &mut self,
127 request: &UpsertRecordRequest,
128 ) -> ContextResult<UpsertRecordResponse> {
129 dispatch_mut!(self, upsert, request)
130 }
131
132 async fn upsert_many(
133 &mut self,
134 request: &UpsertRecordsRequest,
135 ) -> ContextResult<UpsertRecordsResponse> {
136 dispatch_mut!(self, upsert_many, request)
137 }
138
139 async fn update(
140 &mut self,
141 request: &UpdateRecordRequest,
142 ) -> ContextResult<UpdateRecordResponse> {
143 dispatch_mut!(self, update, request)
144 }
145
146 async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
147 dispatch_ref!(self, get, id)
148 }
149
150 async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
151 dispatch_ref!(self, get_by_external_id, external_id)
152 }
153
154 async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
155 dispatch_mut!(self, delete_by_id, id)
156 }
157
158 async fn delete_by_external_id(
159 &mut self,
160 external_id: &str,
161 ) -> ContextResult<DeleteRecordResponse> {
162 dispatch_mut!(self, delete_by_external_id, external_id)
163 }
164
165 async fn list(
166 &self,
167 limit: Option<usize>,
168 offset: Option<usize>,
169 filters: Option<serde_json::Value>,
170 include_expired: bool,
171 include_retired: bool,
172 ) -> ContextResult<Vec<RecordDto>> {
173 dispatch_ref!(
174 self,
175 list,
176 limit,
177 offset,
178 filters,
179 include_expired,
180 include_retired
181 )
182 }
183
184 async fn related(
185 &self,
186 target_id: &str,
187 relation: Option<&str>,
188 limit: Option<usize>,
189 include_expired: bool,
190 include_retired: bool,
191 ) -> ContextResult<Vec<RecordDto>> {
192 dispatch_ref!(
193 self,
194 related,
195 target_id,
196 relation,
197 limit,
198 include_expired,
199 include_retired
200 )
201 }
202
203 async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
204 dispatch_ref!(self, search, request)
205 }
206
207 async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
208 dispatch_ref!(self, retrieve, request)
209 }
210
211 fn version(&self) -> u64 {
212 dispatch_sync!(self, version)
213 }
214
215 async fn checkout(&mut self, version: u64) -> ContextResult<()> {
216 dispatch_mut!(self, checkout, version)
217 }
218
219 async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
220 dispatch_mut!(self, compact, options)
221 }
222
223 async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
224 dispatch_ref!(self, compaction_stats)
225 }
226}