1use std::collections::HashSet;
2
3use lance_context_api::{
4 AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
5 ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest,
6 RetrieveResultDto, SearchRequest, SearchResultDto, UpdateRecordRequest, UpdateRecordResponse,
7 UpsertRecordRequest, UpsertRecordResponse,
8};
9use lance_context_core::{
10 ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType,
11};
12
13#[cfg(feature = "remote")]
14use lance_context_client::RemoteContextStore;
15
/// A context-store handle that wraps exactly one backend.
///
/// Every [`ContextStoreApi`] call made on this type is forwarded to the wrapped
/// backend (see the `impl ContextStoreApi for ContextStore` in this file).
pub enum ContextStore {
    /// A store opened in-process via `lance_context_core::ContextStore`.
    /// NOTE(review): boxed, presumably to keep the enum small relative to the
    /// remote variant — confirm before unboxing.
    Local(Box<LocalStore>),
    /// A `lance_context_client::RemoteContextStore` handle; only compiled when
    /// the `remote` feature is enabled.
    #[cfg(feature = "remote")]
    Remote(RemoteContextStore),
}
21
22impl ContextStore {
23 pub async fn open(uri: &str) -> Result<Self, ContextError> {
24 let store = LocalStore::open(uri)
25 .await
26 .map_err(|e| ContextError::Internal(e.to_string()))?;
27 Ok(Self::Local(Box::new(store)))
28 }
29
30 pub async fn open_with_options(
31 uri: &str,
32 storage_options: Option<std::collections::HashMap<String, String>>,
33 id_index_type: Option<&str>,
34 blob_columns: Option<Vec<String>>,
35 distance_metric: Option<&str>,
36 ) -> Result<Self, ContextError> {
37 let id_idx = match id_index_type {
38 Some("btree") => IdIndexType::BTree,
39 Some("zonemap") => IdIndexType::ZoneMap,
40 Some("none") | None => IdIndexType::None,
41 Some(other) => {
42 return Err(ContextError::InvalidRequest(format!(
43 "Invalid id_index_type: '{other}'"
44 )));
45 }
46 };
47 let metric = match distance_metric {
48 Some(value) => Some(
49 DistanceMetric::parse(value)
50 .map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
51 ),
52 None => None,
53 };
54 let options = ContextStoreOptions {
55 storage_options,
56 blob_columns: blob_columns
57 .unwrap_or_default()
58 .into_iter()
59 .collect::<HashSet<_>>(),
60 id_index_type: id_idx,
61 distance_metric: metric,
62 ..Default::default()
63 };
64 let store = LocalStore::open_with_options(uri, options)
65 .await
66 .map_err(|e| ContextError::Internal(e.to_string()))?;
67 Ok(Self::Local(Box::new(store)))
68 }
69
70 #[cfg(feature = "remote")]
71 pub async fn connect(base_url: &str, context_name: &str) -> Result<Self, ContextError> {
72 let store = RemoteContextStore::connect(base_url, context_name)
73 .await
74 .map_err(|e| ContextError::Internal(e.to_string()))?;
75 Ok(Self::Remote(store))
76 }
77
78 #[cfg(feature = "remote")]
79 pub async fn connect_or_create(
80 base_url: &str,
81 req: &lance_context_api::CreateContextRequest,
82 ) -> Result<Self, ContextError> {
83 let store = RemoteContextStore::connect_or_create(base_url, req)
84 .await
85 .map_err(|e| ContextError::Internal(e.to_string()))?;
86 Ok(Self::Remote(store))
87 }
88}
89
/// Forwards an async `&mut self` [`ContextStoreApi`] call to whichever backend the
/// store holds and awaits it.
///
/// Usage: `dispatch_mut!(store, method, arg1, arg2, ...)`, where `store` is a
/// `&mut ContextStore`.
macro_rules! dispatch_mut {
    ($store:expr, $op:ident $(, $param:expr)*) => {
        match $store {
            // Reborrow through the Box so the trait sees a plain `&mut LocalStore`.
            ContextStore::Local(boxed) => ContextStoreApi::$op(&mut **boxed $(, $param)*).await,
            #[cfg(feature = "remote")]
            ContextStore::Remote(remote) => ContextStoreApi::$op(remote $(, $param)*).await,
        }
    };
}
99
/// Forwards an async `&self` [`ContextStoreApi`] call to whichever backend the
/// store holds and awaits it.
///
/// Usage: `dispatch_ref!(store, method, arg1, arg2, ...)`, where `store` is a
/// `&ContextStore`.
macro_rules! dispatch_ref {
    ($store:expr, $op:ident $(, $param:expr)*) => {
        match $store {
            // Reborrow through the Box so the trait sees a plain `&LocalStore`.
            ContextStore::Local(boxed) => ContextStoreApi::$op(&**boxed $(, $param)*).await,
            #[cfg(feature = "remote")]
            ContextStore::Remote(remote) => ContextStoreApi::$op(remote $(, $param)*).await,
        }
    };
}
109
/// Forwards a synchronous `&self` [`ContextStoreApi`] call (no `.await`) to
/// whichever backend the store holds.
///
/// Usage: `dispatch_sync!(store, method, arg1, arg2, ...)`, where `store` is a
/// `&ContextStore`.
macro_rules! dispatch_sync {
    ($store:expr, $op:ident $(, $param:expr)*) => {
        match $store {
            // Reborrow through the Box so the trait sees a plain `&LocalStore`.
            ContextStore::Local(boxed) => ContextStoreApi::$op(&**boxed $(, $param)*),
            #[cfg(feature = "remote")]
            ContextStore::Remote(remote) => ContextStoreApi::$op(remote $(, $param)*),
        }
    };
}
119
120impl ContextStoreApi for ContextStore {
121 async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
122 dispatch_mut!(self, add, records)
123 }
124
125 async fn upsert(
126 &mut self,
127 request: &UpsertRecordRequest,
128 ) -> ContextResult<UpsertRecordResponse> {
129 dispatch_mut!(self, upsert, request)
130 }
131
132 async fn update(
133 &mut self,
134 request: &UpdateRecordRequest,
135 ) -> ContextResult<UpdateRecordResponse> {
136 dispatch_mut!(self, update, request)
137 }
138
139 async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
140 dispatch_ref!(self, get, id)
141 }
142
143 async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
144 dispatch_ref!(self, get_by_external_id, external_id)
145 }
146
147 async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
148 dispatch_mut!(self, delete_by_id, id)
149 }
150
151 async fn delete_by_external_id(
152 &mut self,
153 external_id: &str,
154 ) -> ContextResult<DeleteRecordResponse> {
155 dispatch_mut!(self, delete_by_external_id, external_id)
156 }
157
158 async fn list(
159 &self,
160 limit: Option<usize>,
161 offset: Option<usize>,
162 filters: Option<serde_json::Value>,
163 include_expired: bool,
164 include_retired: bool,
165 ) -> ContextResult<Vec<RecordDto>> {
166 dispatch_ref!(
167 self,
168 list,
169 limit,
170 offset,
171 filters,
172 include_expired,
173 include_retired
174 )
175 }
176
177 async fn related(
178 &self,
179 target_id: &str,
180 relation: Option<&str>,
181 limit: Option<usize>,
182 include_expired: bool,
183 include_retired: bool,
184 ) -> ContextResult<Vec<RecordDto>> {
185 dispatch_ref!(
186 self,
187 related,
188 target_id,
189 relation,
190 limit,
191 include_expired,
192 include_retired
193 )
194 }
195
196 async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
197 dispatch_ref!(self, search, request)
198 }
199
200 async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
201 dispatch_ref!(self, retrieve, request)
202 }
203
204 fn version(&self) -> u64 {
205 dispatch_sync!(self, version)
206 }
207
208 async fn checkout(&mut self, version: u64) -> ContextResult<()> {
209 dispatch_mut!(self, checkout, version)
210 }
211
212 async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
213 dispatch_mut!(self, compact, options)
214 }
215
216 async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
217 dispatch_ref!(self, compaction_stats)
218 }
219}