Skip to main content

lance_context/
unified.rs

1use std::collections::HashSet;
2
3use lance_context_api::{
4    AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
5    ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest,
6    RetrieveResultDto, SearchRequest, SearchResultDto, UpdateRecordRequest, UpdateRecordResponse,
7    UpsertRecordRequest, UpsertRecordResponse,
8};
9use lance_context_core::{
10    ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType,
11};
12
13#[cfg(feature = "remote")]
14use lance_context_client::RemoteContextStore;
15
16pub enum ContextStore {
17    Local(Box<LocalStore>),
18    #[cfg(feature = "remote")]
19    Remote(RemoteContextStore),
20}
21
22impl ContextStore {
23    pub async fn open(uri: &str) -> Result<Self, ContextError> {
24        let store = LocalStore::open(uri)
25            .await
26            .map_err(|e| ContextError::Internal(e.to_string()))?;
27        Ok(Self::Local(Box::new(store)))
28    }
29
30    pub async fn open_with_options(
31        uri: &str,
32        storage_options: Option<std::collections::HashMap<String, String>>,
33        id_index_type: Option<&str>,
34        blob_columns: Option<Vec<String>>,
35        distance_metric: Option<&str>,
36    ) -> Result<Self, ContextError> {
37        let id_idx = match id_index_type {
38            Some("btree") => IdIndexType::BTree,
39            Some("zonemap") => IdIndexType::ZoneMap,
40            Some("none") | None => IdIndexType::None,
41            Some(other) => {
42                return Err(ContextError::InvalidRequest(format!(
43                    "Invalid id_index_type: '{other}'"
44                )));
45            }
46        };
47        let metric = match distance_metric {
48            Some(value) => Some(
49                DistanceMetric::parse(value)
50                    .map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
51            ),
52            None => None,
53        };
54        let options = ContextStoreOptions {
55            storage_options,
56            blob_columns: blob_columns
57                .unwrap_or_default()
58                .into_iter()
59                .collect::<HashSet<_>>(),
60            id_index_type: id_idx,
61            distance_metric: metric,
62            ..Default::default()
63        };
64        let store = LocalStore::open_with_options(uri, options)
65            .await
66            .map_err(|e| ContextError::Internal(e.to_string()))?;
67        Ok(Self::Local(Box::new(store)))
68    }
69
70    #[cfg(feature = "remote")]
71    pub async fn connect(base_url: &str, context_name: &str) -> Result<Self, ContextError> {
72        let store = RemoteContextStore::connect(base_url, context_name)
73            .await
74            .map_err(|e| ContextError::Internal(e.to_string()))?;
75        Ok(Self::Remote(store))
76    }
77
78    #[cfg(feature = "remote")]
79    pub async fn connect_or_create(
80        base_url: &str,
81        req: &lance_context_api::CreateContextRequest,
82    ) -> Result<Self, ContextError> {
83        let store = RemoteContextStore::connect_or_create(base_url, req)
84            .await
85            .map_err(|e| ContextError::Internal(e.to_string()))?;
86        Ok(Self::Remote(store))
87    }
88}
89
/// Forwards an `async` [`ContextStoreApi`] method that takes `&mut self` to the backend
/// wrapped by `$store`, awaiting the result.
///
/// Any extra expressions after the method name are passed through after the receiver.
macro_rules! dispatch_mut {
    ($store:expr, $name:ident $(, $extra:expr)*) => {
        match $store {
            // `inner` is `&mut Box<LocalStore>`; reborrow the boxed store itself.
            ContextStore::Local(inner) => ContextStoreApi::$name(&mut **inner $(, $extra)*).await,
            #[cfg(feature = "remote")]
            ContextStore::Remote(inner) => ContextStoreApi::$name(inner $(, $extra)*).await,
        }
    };
}
99
/// Forwards an `async` [`ContextStoreApi`] method that takes `&self` to the backend
/// wrapped by `$store`, awaiting the result.
///
/// Any extra expressions after the method name are passed through after the receiver.
macro_rules! dispatch_ref {
    ($store:expr, $name:ident $(, $extra:expr)*) => {
        match $store {
            // `inner` is `&Box<LocalStore>`; borrow the boxed store itself.
            ContextStore::Local(inner) => ContextStoreApi::$name(&**inner $(, $extra)*).await,
            #[cfg(feature = "remote")]
            ContextStore::Remote(inner) => ContextStoreApi::$name(inner $(, $extra)*).await,
        }
    };
}
109
/// Forwards a synchronous [`ContextStoreApi`] method that takes `&self` to the backend
/// wrapped by `$store`, returning its value directly (no `.await`).
///
/// Any extra expressions after the method name are passed through after the receiver.
macro_rules! dispatch_sync {
    ($store:expr, $name:ident $(, $extra:expr)*) => {
        match $store {
            // `inner` is `&Box<LocalStore>`; borrow the boxed store itself.
            ContextStore::Local(inner) => ContextStoreApi::$name(&**inner $(, $extra)*),
            #[cfg(feature = "remote")]
            ContextStore::Remote(inner) => ContextStoreApi::$name(inner $(, $extra)*),
        }
    };
}
119
120impl ContextStoreApi for ContextStore {
121    async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
122        dispatch_mut!(self, add, records)
123    }
124
125    async fn upsert(
126        &mut self,
127        request: &UpsertRecordRequest,
128    ) -> ContextResult<UpsertRecordResponse> {
129        dispatch_mut!(self, upsert, request)
130    }
131
132    async fn update(
133        &mut self,
134        request: &UpdateRecordRequest,
135    ) -> ContextResult<UpdateRecordResponse> {
136        dispatch_mut!(self, update, request)
137    }
138
139    async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
140        dispatch_ref!(self, get, id)
141    }
142
143    async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
144        dispatch_ref!(self, get_by_external_id, external_id)
145    }
146
147    async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
148        dispatch_mut!(self, delete_by_id, id)
149    }
150
151    async fn delete_by_external_id(
152        &mut self,
153        external_id: &str,
154    ) -> ContextResult<DeleteRecordResponse> {
155        dispatch_mut!(self, delete_by_external_id, external_id)
156    }
157
158    async fn list(
159        &self,
160        limit: Option<usize>,
161        offset: Option<usize>,
162        filters: Option<serde_json::Value>,
163        include_expired: bool,
164        include_retired: bool,
165    ) -> ContextResult<Vec<RecordDto>> {
166        dispatch_ref!(
167            self,
168            list,
169            limit,
170            offset,
171            filters,
172            include_expired,
173            include_retired
174        )
175    }
176
177    async fn related(
178        &self,
179        target_id: &str,
180        relation: Option<&str>,
181        limit: Option<usize>,
182        include_expired: bool,
183        include_retired: bool,
184    ) -> ContextResult<Vec<RecordDto>> {
185        dispatch_ref!(
186            self,
187            related,
188            target_id,
189            relation,
190            limit,
191            include_expired,
192            include_retired
193        )
194    }
195
196    async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
197        dispatch_ref!(self, search, request)
198    }
199
200    async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
201        dispatch_ref!(self, retrieve, request)
202    }
203
204    fn version(&self) -> u64 {
205        dispatch_sync!(self, version)
206    }
207
208    async fn checkout(&mut self, version: u64) -> ContextResult<()> {
209        dispatch_mut!(self, checkout, version)
210    }
211
212    async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
213        dispatch_mut!(self, compact, options)
214    }
215
216    async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
217        dispatch_ref!(self, compaction_stats)
218    }
219}