Skip to main content

lance_context_core/
api_impl.rs

1use chrono::Utc;
2use serde_json::Value;
3use uuid::Uuid;
4
5use lance_context_api::{
6    AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
7    ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RecordPatchDto,
8    RelationshipDto, RetrieveRequest, RetrieveResultDto, SearchRequest, SearchResultDto,
9    StateMetadataDto, UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest,
10    UpsertRecordResponse, UpsertRecordsRequest, UpsertRecordsResponse, UpsertResultDto,
11};
12
13use crate::record::{
14    ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, StateMetadata,
15    LIFECYCLE_ACTIVE,
16};
17use crate::store::{CompactionConfig, ContextStore};
18
19impl ContextStoreApi for ContextStore {
20    async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
21        let run_id = Uuid::new_v4().to_string();
22        let mut ids = Vec::with_capacity(records.len());
23        let mut core_records = Vec::with_capacity(records.len());
24
25        for r in records {
26            let id = Uuid::new_v4().to_string();
27            ids.push(id.clone());
28            core_records.push(record_from_add_request(r, id, run_id.clone()));
29        }
30
31        let count = core_records.len();
32        let version = self.add(&core_records).await.map_err(to_ctx_err)?;
33        Ok(AddRecordsResponse {
34            version,
35            ids,
36            count,
37        })
38    }
39
40    async fn upsert(
41        &mut self,
42        request: &UpsertRecordRequest,
43    ) -> ContextResult<UpsertRecordResponse> {
44        if request.key != "external_id" {
45            return Err(ContextError::InvalidRequest(format!(
46                "upsert key '{}' is not supported; use 'external_id'",
47                request.key
48            )));
49        }
50        if request
51            .record
52            .external_id
53            .as_deref()
54            .is_none_or(str::is_empty)
55        {
56            return Err(ContextError::InvalidRequest(
57                "upsert requires record.external_id".to_string(),
58            ));
59        }
60
61        let record = record_from_add_request(
62            &request.record,
63            Uuid::new_v4().to_string(),
64            Uuid::new_v4().to_string(),
65        );
66        let result = ContextStore::upsert_by_external_id(self, record)
67            .await
68            .map_err(to_ctx_err)?;
69        Ok(UpsertRecordResponse {
70            version: result.version,
71            inserted: result.inserted,
72            replaced_id: result.replaced_id,
73            record: record_to_dto(result.record),
74        })
75    }
76
77    async fn upsert_many(
78        &mut self,
79        request: &UpsertRecordsRequest,
80    ) -> ContextResult<UpsertRecordsResponse> {
81        if request.key != "external_id" {
82            return Err(ContextError::InvalidRequest(format!(
83                "upsert key '{}' is not supported; use 'external_id'",
84                request.key
85            )));
86        }
87        if request.records.is_empty() {
88            return Err(ContextError::InvalidRequest(
89                "upsert_many requires at least one record".to_string(),
90            ));
91        }
92        for (index, record) in request.records.iter().enumerate() {
93            if record.external_id.as_deref().is_none_or(str::is_empty) {
94                return Err(ContextError::InvalidRequest(format!(
95                    "upsert_many requires record.external_id (records[{index}])"
96                )));
97            }
98        }
99
100        let core_records: Vec<ContextRecord> = request
101            .records
102            .iter()
103            .map(|r| {
104                record_from_add_request(r, Uuid::new_v4().to_string(), Uuid::new_v4().to_string())
105            })
106            .collect();
107
108        let results = ContextStore::upsert_many_by_external_id(self, core_records)
109            .await
110            .map_err(to_ctx_err)?;
111        let version = results
112            .last()
113            .map(|r| r.version)
114            .unwrap_or_else(|| ContextStore::version(self));
115        Ok(UpsertRecordsResponse {
116            version,
117            results: results
118                .into_iter()
119                .map(|r| UpsertResultDto {
120                    inserted: r.inserted,
121                    replaced_id: r.replaced_id,
122                    record: record_to_dto(r.record),
123                })
124                .collect(),
125        })
126    }
127
128    async fn update(
129        &mut self,
130        request: &UpdateRecordRequest,
131    ) -> ContextResult<UpdateRecordResponse> {
132        if request.patch.is_empty() {
133            return Err(ContextError::InvalidRequest(
134                "update requires at least one patch field".to_string(),
135            ));
136        }
137
138        let patch = patch_from_dto(&request.patch);
139        let result = match (&request.id, &request.external_id) {
140            (Some(id), None) => ContextStore::update_by_id(self, id, patch).await,
141            (None, Some(external_id)) => {
142                ContextStore::update_by_external_id(self, external_id, patch).await
143            }
144            (None, None) => {
145                return Err(ContextError::InvalidRequest(
146                    "update requires either id or external_id".to_string(),
147                ));
148            }
149            (Some(_), Some(_)) => {
150                return Err(ContextError::InvalidRequest(
151                    "update accepts only one of id or external_id".to_string(),
152                ));
153            }
154        }
155        .map_err(to_ctx_err)?;
156
157        Ok(match result {
158            Some(result) => UpdateRecordResponse {
159                version: result.version,
160                updated: true,
161                replaced_id: Some(result.replaced_id),
162                record: Some(record_to_dto(result.record)),
163            },
164            None => UpdateRecordResponse {
165                version: ContextStore::version(self),
166                updated: false,
167                replaced_id: None,
168                record: None,
169            },
170        })
171    }
172
173    async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
174        let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?;
175        Ok(record.map(record_to_dto))
176    }
177
178    async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
179        let record = ContextStore::get_by_external_id(self, external_id)
180            .await
181            .map_err(to_ctx_err)?;
182        Ok(record.map(record_to_dto))
183    }
184
185    async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
186        let deleted = ContextStore::delete_by_id(self, id)
187            .await
188            .map_err(to_ctx_err)?;
189        Ok(DeleteRecordResponse {
190            deleted,
191            version: ContextStore::version(self),
192        })
193    }
194
195    async fn delete_by_external_id(
196        &mut self,
197        external_id: &str,
198    ) -> ContextResult<DeleteRecordResponse> {
199        let deleted = ContextStore::delete_by_external_id(self, external_id)
200            .await
201            .map_err(to_ctx_err)?;
202        Ok(DeleteRecordResponse {
203            deleted,
204            version: ContextStore::version(self),
205        })
206    }
207
208    async fn list(
209        &self,
210        limit: Option<usize>,
211        offset: Option<usize>,
212        filters: Option<Value>,
213        include_expired: bool,
214        include_retired: bool,
215    ) -> ContextResult<Vec<RecordDto>> {
216        let filters = filters
217            .map(RecordFilters::from_json_value)
218            .transpose()
219            .map_err(ContextError::InvalidRequest)?;
220        let options = LifecycleQueryOptions::new(include_expired, include_retired);
221        let records = ContextStore::list_filtered_with_options(
222            self,
223            limit,
224            offset,
225            filters.as_ref(),
226            options,
227        )
228        .await
229        .map_err(to_ctx_err)?;
230        Ok(records.into_iter().map(record_to_dto).collect())
231    }
232
233    async fn related(
234        &self,
235        target_id: &str,
236        relation: Option<&str>,
237        limit: Option<usize>,
238        include_expired: bool,
239        include_retired: bool,
240    ) -> ContextResult<Vec<RecordDto>> {
241        let options = LifecycleQueryOptions::new(include_expired, include_retired);
242        let records =
243            ContextStore::list_related_with_options(self, target_id, relation, limit, options)
244                .await
245                .map_err(to_ctx_err)?;
246        Ok(records.into_iter().map(record_to_dto).collect())
247    }
248
249    async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
250        let filters = request
251            .filters
252            .clone()
253            .map(RecordFilters::from_json_value)
254            .transpose()
255            .map_err(ContextError::InvalidRequest)?;
256        let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
257        let results = ContextStore::search_filtered_with_options(
258            self,
259            &request.query,
260            Some(request.limit),
261            filters.as_ref(),
262            options,
263        )
264        .await
265        .map_err(to_ctx_err)?;
266        Ok(results
267            .into_iter()
268            .map(|mut sr| {
269                if !request.include_relationships {
270                    sr.record.relationships.clear();
271                }
272                SearchResultDto {
273                    record: record_to_dto(sr.record),
274                    distance: sr.distance,
275                }
276            })
277            .collect())
278    }
279
280    async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
281        if request.fusion != "rrf" {
282            return Err(ContextError::InvalidRequest(
283                "retrieve fusion currently supports only 'rrf'".to_string(),
284            ));
285        }
286
287        let filters = request
288            .filters
289            .clone()
290            .map(RecordFilters::from_json_value)
291            .transpose()
292            .map_err(ContextError::InvalidRequest)?;
293        let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
294        let results = self
295            .retrieve_filtered_with_options(
296                request.text.as_deref(),
297                request.vector.as_deref(),
298                Some(request.limit),
299                filters.as_ref(),
300                options,
301            )
302            .await
303            .map_err(to_ctx_err)?;
304
305        Ok(results
306            .into_iter()
307            .map(|mut result| {
308                if !request.include_relationships {
309                    result.record.relationships.clear();
310                }
311                RetrieveResultDto {
312                    record: record_to_dto(result.record),
313                    score: result.score,
314                    vector_distance: result.vector_distance,
315                    text_score: result.text_score,
316                    matched_channels: result.matched_channels,
317                }
318            })
319            .collect())
320    }
321
322    fn version(&self) -> u64 {
323        ContextStore::version(self)
324    }
325
326    async fn checkout(&mut self, version: u64) -> ContextResult<()> {
327        ContextStore::checkout(self, version)
328            .await
329            .map_err(to_ctx_err)
330    }
331
332    async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
333        let config = options.map(|req| {
334            let mut c = CompactionConfig::default();
335            if let Some(v) = req.target_rows_per_fragment {
336                c.target_rows_per_fragment = v;
337            }
338            if let Some(v) = req.materialize_deletions {
339                c.materialize_deletions = v;
340            }
341            c
342        });
343
344        let metrics = ContextStore::compact(self, config)
345            .await
346            .map_err(to_ctx_err)?;
347        Ok(CompactResponse {
348            fragments_removed: metrics.fragments_removed,
349            fragments_added: metrics.fragments_added,
350            files_removed: metrics.files_removed,
351            files_added: metrics.files_added,
352        })
353    }
354
355    async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
356        let stats = ContextStore::compaction_stats(self)
357            .await
358            .map_err(to_ctx_err)?;
359        Ok(CompactStatsResponse {
360            total_fragments: stats.total_fragments,
361            is_compacting: stats.is_compacting,
362            last_compaction: stats.last_compaction,
363            last_error: stats.last_error,
364            total_compactions: stats.total_compactions,
365        })
366    }
367}
368
369fn dto_to_relationship(r: RelationshipDto) -> Relationship {
370    Relationship {
371        target_id: r.target_id,
372        relation: r.relation,
373        weight: r.weight,
374    }
375}
376
377fn relationship_to_dto(r: Relationship) -> RelationshipDto {
378    RelationshipDto {
379        target_id: r.target_id,
380        relation: r.relation,
381        weight: r.weight,
382    }
383}
384
385fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch {
386    RecordPatch {
387        bot_id: patch.bot_id.clone(),
388        session_id: patch.session_id.clone(),
389        tenant: patch.tenant.clone(),
390        source: patch.source.clone(),
391        state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata {
392            step: sm.step,
393            active_plan_id: sm.active_plan_id.clone(),
394            tokens_used: sm.tokens_used,
395            custom: sm.custom.clone(),
396        }),
397        metadata: patch.metadata.clone(),
398        relationships: patch.relationships.as_ref().map(|relationships| {
399            relationships
400                .iter()
401                .cloned()
402                .map(dto_to_relationship)
403                .collect()
404        }),
405        expires_at: patch.expires_at,
406        retention_policy: patch.retention_policy.clone(),
407        lifecycle_status: patch.lifecycle_status.clone(),
408        retired_at: patch.retired_at,
409        retired_reason: patch.retired_reason.clone(),
410        embedding: patch.embedding.clone(),
411    }
412}
413
414fn record_from_add_request(r: &AddRecordRequest, id: String, run_id: String) -> ContextRecord {
415    ContextRecord {
416        id,
417        external_id: r.external_id.clone(),
418        run_id,
419        bot_id: r.bot_id.clone(),
420        session_id: r.session_id.clone(),
421        tenant: r.tenant.clone(),
422        source: r.source.clone(),
423        created_at: Utc::now(),
424        role: r.role.clone(),
425        state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata {
426            step: sm.step,
427            active_plan_id: sm.active_plan_id.clone(),
428            tokens_used: sm.tokens_used,
429            custom: sm.custom.clone(),
430        }),
431        metadata: r.metadata.clone(),
432        relationships: r
433            .relationships
434            .iter()
435            .cloned()
436            .map(dto_to_relationship)
437            .collect(),
438        expires_at: r.expires_at,
439        retention_policy: r.retention_policy.clone(),
440        lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
441        retired_at: None,
442        retired_reason: None,
443        supersedes_id: r.supersedes_id.clone(),
444        superseded_by_id: None,
445        content_type: r.content_type.clone(),
446        text_payload: r.text_payload.clone(),
447        binary_payload: r.binary_payload.clone(),
448        embedding: r.embedding.clone(),
449    }
450}
451
452fn record_to_dto(r: ContextRecord) -> RecordDto {
453    RecordDto {
454        id: r.id,
455        external_id: r.external_id,
456        run_id: r.run_id,
457        bot_id: r.bot_id,
458        session_id: r.session_id,
459        tenant: r.tenant,
460        source: r.source,
461        created_at: r.created_at,
462        role: r.role,
463        content_type: r.content_type,
464        text_payload: r.text_payload,
465        binary_payload: r.binary_payload,
466        embedding: r.embedding,
467        state_metadata: r.state_metadata.map(|sm| StateMetadataDto {
468            step: sm.step,
469            active_plan_id: sm.active_plan_id,
470            tokens_used: sm.tokens_used,
471            custom: sm.custom,
472        }),
473        metadata: r.metadata,
474        relationships: r
475            .relationships
476            .into_iter()
477            .map(relationship_to_dto)
478            .collect(),
479        expires_at: r.expires_at,
480        retention_policy: r.retention_policy,
481        lifecycle_status: r.lifecycle_status,
482        retired_at: r.retired_at,
483        retired_reason: r.retired_reason,
484        supersedes_id: r.supersedes_id,
485        superseded_by_id: r.superseded_by_id,
486    }
487}
488
489fn to_ctx_err(err: lance::Error) -> ContextError {
490    let msg = err.to_string();
491    if msg.contains("already in progress") {
492        ContextError::CompactionInProgress
493    } else if msg.contains("not found") || msg.contains("DatasetNotFound") {
494        ContextError::NotFound(msg)
495    } else if msg.contains("Invalid") {
496        ContextError::InvalidRequest(msg)
497    } else {
498        ContextError::Internal(msg)
499    }
500}