Skip to main content

lance_context_core/
api_impl.rs

1use chrono::Utc;
2use serde_json::Value;
3use uuid::Uuid;
4
5use lance_context_api::{
6    AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
7    ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RecordPatchDto,
8    RelationshipDto, RetrieveRequest, RetrieveResultDto, SearchRequest, SearchResultDto,
9    StateMetadataDto, UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest,
10    UpsertRecordResponse,
11};
12
13use crate::record::{
14    ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, StateMetadata,
15    LIFECYCLE_ACTIVE,
16};
17use crate::store::{CompactionConfig, ContextStore};
18
19impl ContextStoreApi for ContextStore {
20    async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
21        let run_id = Uuid::new_v4().to_string();
22        let mut ids = Vec::with_capacity(records.len());
23        let mut core_records = Vec::with_capacity(records.len());
24
25        for r in records {
26            let id = Uuid::new_v4().to_string();
27            ids.push(id.clone());
28            core_records.push(record_from_add_request(r, id, run_id.clone()));
29        }
30
31        let count = core_records.len();
32        let version = self.add(&core_records).await.map_err(to_ctx_err)?;
33        Ok(AddRecordsResponse {
34            version,
35            ids,
36            count,
37        })
38    }
39
40    async fn upsert(
41        &mut self,
42        request: &UpsertRecordRequest,
43    ) -> ContextResult<UpsertRecordResponse> {
44        if request.key != "external_id" {
45            return Err(ContextError::InvalidRequest(format!(
46                "upsert key '{}' is not supported; use 'external_id'",
47                request.key
48            )));
49        }
50        if request
51            .record
52            .external_id
53            .as_deref()
54            .is_none_or(str::is_empty)
55        {
56            return Err(ContextError::InvalidRequest(
57                "upsert requires record.external_id".to_string(),
58            ));
59        }
60
61        let record = record_from_add_request(
62            &request.record,
63            Uuid::new_v4().to_string(),
64            Uuid::new_v4().to_string(),
65        );
66        let result = ContextStore::upsert_by_external_id(self, record)
67            .await
68            .map_err(to_ctx_err)?;
69        Ok(UpsertRecordResponse {
70            version: result.version,
71            inserted: result.inserted,
72            replaced_id: result.replaced_id,
73            record: record_to_dto(result.record),
74        })
75    }
76
77    async fn update(
78        &mut self,
79        request: &UpdateRecordRequest,
80    ) -> ContextResult<UpdateRecordResponse> {
81        if request.patch.is_empty() {
82            return Err(ContextError::InvalidRequest(
83                "update requires at least one patch field".to_string(),
84            ));
85        }
86
87        let patch = patch_from_dto(&request.patch);
88        let result = match (&request.id, &request.external_id) {
89            (Some(id), None) => ContextStore::update_by_id(self, id, patch).await,
90            (None, Some(external_id)) => {
91                ContextStore::update_by_external_id(self, external_id, patch).await
92            }
93            (None, None) => {
94                return Err(ContextError::InvalidRequest(
95                    "update requires either id or external_id".to_string(),
96                ));
97            }
98            (Some(_), Some(_)) => {
99                return Err(ContextError::InvalidRequest(
100                    "update accepts only one of id or external_id".to_string(),
101                ));
102            }
103        }
104        .map_err(to_ctx_err)?;
105
106        Ok(match result {
107            Some(result) => UpdateRecordResponse {
108                version: result.version,
109                updated: true,
110                replaced_id: Some(result.replaced_id),
111                record: Some(record_to_dto(result.record)),
112            },
113            None => UpdateRecordResponse {
114                version: ContextStore::version(self),
115                updated: false,
116                replaced_id: None,
117                record: None,
118            },
119        })
120    }
121
122    async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
123        let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?;
124        Ok(record.map(record_to_dto))
125    }
126
127    async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
128        let record = ContextStore::get_by_external_id(self, external_id)
129            .await
130            .map_err(to_ctx_err)?;
131        Ok(record.map(record_to_dto))
132    }
133
134    async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
135        let deleted = ContextStore::delete_by_id(self, id)
136            .await
137            .map_err(to_ctx_err)?;
138        Ok(DeleteRecordResponse {
139            deleted,
140            version: ContextStore::version(self),
141        })
142    }
143
144    async fn delete_by_external_id(
145        &mut self,
146        external_id: &str,
147    ) -> ContextResult<DeleteRecordResponse> {
148        let deleted = ContextStore::delete_by_external_id(self, external_id)
149            .await
150            .map_err(to_ctx_err)?;
151        Ok(DeleteRecordResponse {
152            deleted,
153            version: ContextStore::version(self),
154        })
155    }
156
157    async fn list(
158        &self,
159        limit: Option<usize>,
160        offset: Option<usize>,
161        filters: Option<Value>,
162        include_expired: bool,
163        include_retired: bool,
164    ) -> ContextResult<Vec<RecordDto>> {
165        let filters = filters
166            .map(RecordFilters::from_json_value)
167            .transpose()
168            .map_err(ContextError::InvalidRequest)?;
169        let options = LifecycleQueryOptions::new(include_expired, include_retired);
170        let records = ContextStore::list_filtered_with_options(
171            self,
172            limit,
173            offset,
174            filters.as_ref(),
175            options,
176        )
177        .await
178        .map_err(to_ctx_err)?;
179        Ok(records.into_iter().map(record_to_dto).collect())
180    }
181
182    async fn related(
183        &self,
184        target_id: &str,
185        relation: Option<&str>,
186        limit: Option<usize>,
187        include_expired: bool,
188        include_retired: bool,
189    ) -> ContextResult<Vec<RecordDto>> {
190        let options = LifecycleQueryOptions::new(include_expired, include_retired);
191        let records =
192            ContextStore::list_related_with_options(self, target_id, relation, limit, options)
193                .await
194                .map_err(to_ctx_err)?;
195        Ok(records.into_iter().map(record_to_dto).collect())
196    }
197
198    async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
199        let filters = request
200            .filters
201            .clone()
202            .map(RecordFilters::from_json_value)
203            .transpose()
204            .map_err(ContextError::InvalidRequest)?;
205        let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
206        let results = ContextStore::search_filtered_with_options(
207            self,
208            &request.query,
209            Some(request.limit),
210            filters.as_ref(),
211            options,
212        )
213        .await
214        .map_err(to_ctx_err)?;
215        Ok(results
216            .into_iter()
217            .map(|mut sr| {
218                if !request.include_relationships {
219                    sr.record.relationships.clear();
220                }
221                SearchResultDto {
222                    record: record_to_dto(sr.record),
223                    distance: sr.distance,
224                }
225            })
226            .collect())
227    }
228
229    async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
230        if request.fusion != "rrf" {
231            return Err(ContextError::InvalidRequest(
232                "retrieve fusion currently supports only 'rrf'".to_string(),
233            ));
234        }
235
236        let filters = request
237            .filters
238            .clone()
239            .map(RecordFilters::from_json_value)
240            .transpose()
241            .map_err(ContextError::InvalidRequest)?;
242        let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
243        let results = self
244            .retrieve_filtered_with_options(
245                request.text.as_deref(),
246                request.vector.as_deref(),
247                Some(request.limit),
248                filters.as_ref(),
249                options,
250            )
251            .await
252            .map_err(to_ctx_err)?;
253
254        Ok(results
255            .into_iter()
256            .map(|mut result| {
257                if !request.include_relationships {
258                    result.record.relationships.clear();
259                }
260                RetrieveResultDto {
261                    record: record_to_dto(result.record),
262                    score: result.score,
263                    vector_distance: result.vector_distance,
264                    text_score: result.text_score,
265                    matched_channels: result.matched_channels,
266                }
267            })
268            .collect())
269    }
270
271    fn version(&self) -> u64 {
272        ContextStore::version(self)
273    }
274
275    async fn checkout(&mut self, version: u64) -> ContextResult<()> {
276        ContextStore::checkout(self, version)
277            .await
278            .map_err(to_ctx_err)
279    }
280
281    async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
282        let config = options.map(|req| {
283            let mut c = CompactionConfig::default();
284            if let Some(v) = req.target_rows_per_fragment {
285                c.target_rows_per_fragment = v;
286            }
287            if let Some(v) = req.materialize_deletions {
288                c.materialize_deletions = v;
289            }
290            c
291        });
292
293        let metrics = ContextStore::compact(self, config)
294            .await
295            .map_err(to_ctx_err)?;
296        Ok(CompactResponse {
297            fragments_removed: metrics.fragments_removed,
298            fragments_added: metrics.fragments_added,
299            files_removed: metrics.files_removed,
300            files_added: metrics.files_added,
301        })
302    }
303
304    async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
305        let stats = ContextStore::compaction_stats(self)
306            .await
307            .map_err(to_ctx_err)?;
308        Ok(CompactStatsResponse {
309            total_fragments: stats.total_fragments,
310            is_compacting: stats.is_compacting,
311            last_compaction: stats.last_compaction,
312            last_error: stats.last_error,
313            total_compactions: stats.total_compactions,
314        })
315    }
316}
317
318fn dto_to_relationship(r: RelationshipDto) -> Relationship {
319    Relationship {
320        target_id: r.target_id,
321        relation: r.relation,
322        weight: r.weight,
323    }
324}
325
326fn relationship_to_dto(r: Relationship) -> RelationshipDto {
327    RelationshipDto {
328        target_id: r.target_id,
329        relation: r.relation,
330        weight: r.weight,
331    }
332}
333
334fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch {
335    RecordPatch {
336        bot_id: patch.bot_id.clone(),
337        session_id: patch.session_id.clone(),
338        tenant: patch.tenant.clone(),
339        source: patch.source.clone(),
340        state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata {
341            step: sm.step,
342            active_plan_id: sm.active_plan_id.clone(),
343            tokens_used: sm.tokens_used,
344            custom: sm.custom.clone(),
345        }),
346        metadata: patch.metadata.clone(),
347        relationships: patch.relationships.as_ref().map(|relationships| {
348            relationships
349                .iter()
350                .cloned()
351                .map(dto_to_relationship)
352                .collect()
353        }),
354        expires_at: patch.expires_at,
355        retention_policy: patch.retention_policy.clone(),
356        lifecycle_status: patch.lifecycle_status.clone(),
357        retired_at: patch.retired_at,
358        retired_reason: patch.retired_reason.clone(),
359        embedding: patch.embedding.clone(),
360    }
361}
362
363fn record_from_add_request(r: &AddRecordRequest, id: String, run_id: String) -> ContextRecord {
364    ContextRecord {
365        id,
366        external_id: r.external_id.clone(),
367        run_id,
368        bot_id: r.bot_id.clone(),
369        session_id: r.session_id.clone(),
370        tenant: r.tenant.clone(),
371        source: r.source.clone(),
372        created_at: Utc::now(),
373        role: r.role.clone(),
374        state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata {
375            step: sm.step,
376            active_plan_id: sm.active_plan_id.clone(),
377            tokens_used: sm.tokens_used,
378            custom: sm.custom.clone(),
379        }),
380        metadata: r.metadata.clone(),
381        relationships: r
382            .relationships
383            .iter()
384            .cloned()
385            .map(dto_to_relationship)
386            .collect(),
387        expires_at: r.expires_at,
388        retention_policy: r.retention_policy.clone(),
389        lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
390        retired_at: None,
391        retired_reason: None,
392        supersedes_id: r.supersedes_id.clone(),
393        superseded_by_id: None,
394        content_type: r.content_type.clone(),
395        text_payload: r.text_payload.clone(),
396        binary_payload: r.binary_payload.clone(),
397        embedding: r.embedding.clone(),
398    }
399}
400
401fn record_to_dto(r: ContextRecord) -> RecordDto {
402    RecordDto {
403        id: r.id,
404        external_id: r.external_id,
405        run_id: r.run_id,
406        bot_id: r.bot_id,
407        session_id: r.session_id,
408        tenant: r.tenant,
409        source: r.source,
410        created_at: r.created_at,
411        role: r.role,
412        content_type: r.content_type,
413        text_payload: r.text_payload,
414        binary_payload: r.binary_payload,
415        embedding: r.embedding,
416        state_metadata: r.state_metadata.map(|sm| StateMetadataDto {
417            step: sm.step,
418            active_plan_id: sm.active_plan_id,
419            tokens_used: sm.tokens_used,
420            custom: sm.custom,
421        }),
422        metadata: r.metadata,
423        relationships: r
424            .relationships
425            .into_iter()
426            .map(relationship_to_dto)
427            .collect(),
428        expires_at: r.expires_at,
429        retention_policy: r.retention_policy,
430        lifecycle_status: r.lifecycle_status,
431        retired_at: r.retired_at,
432        retired_reason: r.retired_reason,
433        supersedes_id: r.supersedes_id,
434        superseded_by_id: r.superseded_by_id,
435    }
436}
437
438fn to_ctx_err(err: lance::Error) -> ContextError {
439    let msg = err.to_string();
440    if msg.contains("already in progress") {
441        ContextError::CompactionInProgress
442    } else if msg.contains("not found") || msg.contains("DatasetNotFound") {
443        ContextError::NotFound(msg)
444    } else if msg.contains("Invalid") {
445        ContextError::InvalidRequest(msg)
446    } else {
447        ContextError::Internal(msg)
448    }
449}