1use chrono::Utc;
2use serde_json::Value;
3use uuid::Uuid;
4
5use lance_context_api::{
6 AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
7 ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RecordPatchDto,
8 RelationshipDto, RetrieveRequest, RetrieveResultDto, SearchRequest, SearchResultDto,
9 StateMetadataDto, UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest,
10 UpsertRecordResponse, UpsertRecordsRequest, UpsertRecordsResponse, UpsertResultDto,
11};
12
13use crate::record::{
14 ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, StateMetadata,
15 LIFECYCLE_ACTIVE,
16};
17use crate::store::{CompactionConfig, ContextStore};
18
19impl ContextStoreApi for ContextStore {
20 async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
21 let run_id = Uuid::new_v4().to_string();
22 let mut ids = Vec::with_capacity(records.len());
23 let mut core_records = Vec::with_capacity(records.len());
24
25 for r in records {
26 let id = Uuid::new_v4().to_string();
27 ids.push(id.clone());
28 core_records.push(record_from_add_request(r, id, run_id.clone()));
29 }
30
31 let count = core_records.len();
32 let version = self.add(&core_records).await.map_err(to_ctx_err)?;
33 Ok(AddRecordsResponse {
34 version,
35 ids,
36 count,
37 })
38 }
39
40 async fn upsert(
41 &mut self,
42 request: &UpsertRecordRequest,
43 ) -> ContextResult<UpsertRecordResponse> {
44 if request.key != "external_id" {
45 return Err(ContextError::InvalidRequest(format!(
46 "upsert key '{}' is not supported; use 'external_id'",
47 request.key
48 )));
49 }
50 if request
51 .record
52 .external_id
53 .as_deref()
54 .is_none_or(str::is_empty)
55 {
56 return Err(ContextError::InvalidRequest(
57 "upsert requires record.external_id".to_string(),
58 ));
59 }
60
61 let record = record_from_add_request(
62 &request.record,
63 Uuid::new_v4().to_string(),
64 Uuid::new_v4().to_string(),
65 );
66 let result = ContextStore::upsert_by_external_id(self, record)
67 .await
68 .map_err(to_ctx_err)?;
69 Ok(UpsertRecordResponse {
70 version: result.version,
71 inserted: result.inserted,
72 replaced_id: result.replaced_id,
73 record: record_to_dto(result.record),
74 })
75 }
76
77 async fn upsert_many(
78 &mut self,
79 request: &UpsertRecordsRequest,
80 ) -> ContextResult<UpsertRecordsResponse> {
81 if request.key != "external_id" {
82 return Err(ContextError::InvalidRequest(format!(
83 "upsert key '{}' is not supported; use 'external_id'",
84 request.key
85 )));
86 }
87 if request.records.is_empty() {
88 return Err(ContextError::InvalidRequest(
89 "upsert_many requires at least one record".to_string(),
90 ));
91 }
92 for (index, record) in request.records.iter().enumerate() {
93 if record.external_id.as_deref().is_none_or(str::is_empty) {
94 return Err(ContextError::InvalidRequest(format!(
95 "upsert_many requires record.external_id (records[{index}])"
96 )));
97 }
98 }
99
100 let core_records: Vec<ContextRecord> = request
101 .records
102 .iter()
103 .map(|r| {
104 record_from_add_request(r, Uuid::new_v4().to_string(), Uuid::new_v4().to_string())
105 })
106 .collect();
107
108 let results = ContextStore::upsert_many_by_external_id(self, core_records)
109 .await
110 .map_err(to_ctx_err)?;
111 let version = results
112 .last()
113 .map(|r| r.version)
114 .unwrap_or_else(|| ContextStore::version(self));
115 Ok(UpsertRecordsResponse {
116 version,
117 results: results
118 .into_iter()
119 .map(|r| UpsertResultDto {
120 inserted: r.inserted,
121 replaced_id: r.replaced_id,
122 record: record_to_dto(r.record),
123 })
124 .collect(),
125 })
126 }
127
128 async fn update(
129 &mut self,
130 request: &UpdateRecordRequest,
131 ) -> ContextResult<UpdateRecordResponse> {
132 if request.patch.is_empty() {
133 return Err(ContextError::InvalidRequest(
134 "update requires at least one patch field".to_string(),
135 ));
136 }
137
138 let patch = patch_from_dto(&request.patch);
139 let result = match (&request.id, &request.external_id) {
140 (Some(id), None) => ContextStore::update_by_id(self, id, patch).await,
141 (None, Some(external_id)) => {
142 ContextStore::update_by_external_id(self, external_id, patch).await
143 }
144 (None, None) => {
145 return Err(ContextError::InvalidRequest(
146 "update requires either id or external_id".to_string(),
147 ));
148 }
149 (Some(_), Some(_)) => {
150 return Err(ContextError::InvalidRequest(
151 "update accepts only one of id or external_id".to_string(),
152 ));
153 }
154 }
155 .map_err(to_ctx_err)?;
156
157 Ok(match result {
158 Some(result) => UpdateRecordResponse {
159 version: result.version,
160 updated: true,
161 replaced_id: Some(result.replaced_id),
162 record: Some(record_to_dto(result.record)),
163 },
164 None => UpdateRecordResponse {
165 version: ContextStore::version(self),
166 updated: false,
167 replaced_id: None,
168 record: None,
169 },
170 })
171 }
172
173 async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
174 let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?;
175 Ok(record.map(record_to_dto))
176 }
177
178 async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
179 let record = ContextStore::get_by_external_id(self, external_id)
180 .await
181 .map_err(to_ctx_err)?;
182 Ok(record.map(record_to_dto))
183 }
184
185 async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
186 let deleted = ContextStore::delete_by_id(self, id)
187 .await
188 .map_err(to_ctx_err)?;
189 Ok(DeleteRecordResponse {
190 deleted,
191 version: ContextStore::version(self),
192 })
193 }
194
195 async fn delete_by_external_id(
196 &mut self,
197 external_id: &str,
198 ) -> ContextResult<DeleteRecordResponse> {
199 let deleted = ContextStore::delete_by_external_id(self, external_id)
200 .await
201 .map_err(to_ctx_err)?;
202 Ok(DeleteRecordResponse {
203 deleted,
204 version: ContextStore::version(self),
205 })
206 }
207
208 async fn list(
209 &self,
210 limit: Option<usize>,
211 offset: Option<usize>,
212 filters: Option<Value>,
213 include_expired: bool,
214 include_retired: bool,
215 ) -> ContextResult<Vec<RecordDto>> {
216 let filters = filters
217 .map(RecordFilters::from_json_value)
218 .transpose()
219 .map_err(ContextError::InvalidRequest)?;
220 let options = LifecycleQueryOptions::new(include_expired, include_retired);
221 let records = ContextStore::list_filtered_with_options(
222 self,
223 limit,
224 offset,
225 filters.as_ref(),
226 options,
227 )
228 .await
229 .map_err(to_ctx_err)?;
230 Ok(records.into_iter().map(record_to_dto).collect())
231 }
232
233 async fn related(
234 &self,
235 target_id: &str,
236 relation: Option<&str>,
237 limit: Option<usize>,
238 include_expired: bool,
239 include_retired: bool,
240 ) -> ContextResult<Vec<RecordDto>> {
241 let options = LifecycleQueryOptions::new(include_expired, include_retired);
242 let records =
243 ContextStore::list_related_with_options(self, target_id, relation, limit, options)
244 .await
245 .map_err(to_ctx_err)?;
246 Ok(records.into_iter().map(record_to_dto).collect())
247 }
248
249 async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
250 let filters = request
251 .filters
252 .clone()
253 .map(RecordFilters::from_json_value)
254 .transpose()
255 .map_err(ContextError::InvalidRequest)?;
256 let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
257 let results = ContextStore::search_filtered_with_options(
258 self,
259 &request.query,
260 Some(request.limit),
261 filters.as_ref(),
262 options,
263 )
264 .await
265 .map_err(to_ctx_err)?;
266 Ok(results
267 .into_iter()
268 .map(|mut sr| {
269 if !request.include_relationships {
270 sr.record.relationships.clear();
271 }
272 SearchResultDto {
273 record: record_to_dto(sr.record),
274 distance: sr.distance,
275 }
276 })
277 .collect())
278 }
279
280 async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
281 if request.fusion != "rrf" {
282 return Err(ContextError::InvalidRequest(
283 "retrieve fusion currently supports only 'rrf'".to_string(),
284 ));
285 }
286
287 let filters = request
288 .filters
289 .clone()
290 .map(RecordFilters::from_json_value)
291 .transpose()
292 .map_err(ContextError::InvalidRequest)?;
293 let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
294 let results = self
295 .retrieve_filtered_with_options(
296 request.text.as_deref(),
297 request.vector.as_deref(),
298 Some(request.limit),
299 filters.as_ref(),
300 options,
301 )
302 .await
303 .map_err(to_ctx_err)?;
304
305 Ok(results
306 .into_iter()
307 .map(|mut result| {
308 if !request.include_relationships {
309 result.record.relationships.clear();
310 }
311 RetrieveResultDto {
312 record: record_to_dto(result.record),
313 score: result.score,
314 vector_distance: result.vector_distance,
315 text_score: result.text_score,
316 matched_channels: result.matched_channels,
317 }
318 })
319 .collect())
320 }
321
322 fn version(&self) -> u64 {
323 ContextStore::version(self)
324 }
325
326 async fn checkout(&mut self, version: u64) -> ContextResult<()> {
327 ContextStore::checkout(self, version)
328 .await
329 .map_err(to_ctx_err)
330 }
331
332 async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
333 let config = options.map(|req| {
334 let mut c = CompactionConfig::default();
335 if let Some(v) = req.target_rows_per_fragment {
336 c.target_rows_per_fragment = v;
337 }
338 if let Some(v) = req.materialize_deletions {
339 c.materialize_deletions = v;
340 }
341 c
342 });
343
344 let metrics = ContextStore::compact(self, config)
345 .await
346 .map_err(to_ctx_err)?;
347 Ok(CompactResponse {
348 fragments_removed: metrics.fragments_removed,
349 fragments_added: metrics.fragments_added,
350 files_removed: metrics.files_removed,
351 files_added: metrics.files_added,
352 })
353 }
354
355 async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
356 let stats = ContextStore::compaction_stats(self)
357 .await
358 .map_err(to_ctx_err)?;
359 Ok(CompactStatsResponse {
360 total_fragments: stats.total_fragments,
361 is_compacting: stats.is_compacting,
362 last_compaction: stats.last_compaction,
363 last_error: stats.last_error,
364 total_compactions: stats.total_compactions,
365 })
366 }
367}
368
369fn dto_to_relationship(r: RelationshipDto) -> Relationship {
370 Relationship {
371 target_id: r.target_id,
372 relation: r.relation,
373 weight: r.weight,
374 }
375}
376
377fn relationship_to_dto(r: Relationship) -> RelationshipDto {
378 RelationshipDto {
379 target_id: r.target_id,
380 relation: r.relation,
381 weight: r.weight,
382 }
383}
384
385fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch {
386 RecordPatch {
387 bot_id: patch.bot_id.clone(),
388 session_id: patch.session_id.clone(),
389 tenant: patch.tenant.clone(),
390 source: patch.source.clone(),
391 state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata {
392 step: sm.step,
393 active_plan_id: sm.active_plan_id.clone(),
394 tokens_used: sm.tokens_used,
395 custom: sm.custom.clone(),
396 }),
397 metadata: patch.metadata.clone(),
398 relationships: patch.relationships.as_ref().map(|relationships| {
399 relationships
400 .iter()
401 .cloned()
402 .map(dto_to_relationship)
403 .collect()
404 }),
405 expires_at: patch.expires_at,
406 retention_policy: patch.retention_policy.clone(),
407 lifecycle_status: patch.lifecycle_status.clone(),
408 retired_at: patch.retired_at,
409 retired_reason: patch.retired_reason.clone(),
410 embedding: patch.embedding.clone(),
411 }
412}
413
414fn record_from_add_request(r: &AddRecordRequest, id: String, run_id: String) -> ContextRecord {
415 ContextRecord {
416 id,
417 external_id: r.external_id.clone(),
418 run_id,
419 bot_id: r.bot_id.clone(),
420 session_id: r.session_id.clone(),
421 tenant: r.tenant.clone(),
422 source: r.source.clone(),
423 created_at: Utc::now(),
424 role: r.role.clone(),
425 state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata {
426 step: sm.step,
427 active_plan_id: sm.active_plan_id.clone(),
428 tokens_used: sm.tokens_used,
429 custom: sm.custom.clone(),
430 }),
431 metadata: r.metadata.clone(),
432 relationships: r
433 .relationships
434 .iter()
435 .cloned()
436 .map(dto_to_relationship)
437 .collect(),
438 expires_at: r.expires_at,
439 retention_policy: r.retention_policy.clone(),
440 lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
441 retired_at: None,
442 retired_reason: None,
443 supersedes_id: r.supersedes_id.clone(),
444 superseded_by_id: None,
445 content_type: r.content_type.clone(),
446 text_payload: r.text_payload.clone(),
447 binary_payload: r.binary_payload.clone(),
448 embedding: r.embedding.clone(),
449 }
450}
451
452fn record_to_dto(r: ContextRecord) -> RecordDto {
453 RecordDto {
454 id: r.id,
455 external_id: r.external_id,
456 run_id: r.run_id,
457 bot_id: r.bot_id,
458 session_id: r.session_id,
459 tenant: r.tenant,
460 source: r.source,
461 created_at: r.created_at,
462 role: r.role,
463 content_type: r.content_type,
464 text_payload: r.text_payload,
465 binary_payload: r.binary_payload,
466 embedding: r.embedding,
467 state_metadata: r.state_metadata.map(|sm| StateMetadataDto {
468 step: sm.step,
469 active_plan_id: sm.active_plan_id,
470 tokens_used: sm.tokens_used,
471 custom: sm.custom,
472 }),
473 metadata: r.metadata,
474 relationships: r
475 .relationships
476 .into_iter()
477 .map(relationship_to_dto)
478 .collect(),
479 expires_at: r.expires_at,
480 retention_policy: r.retention_policy,
481 lifecycle_status: r.lifecycle_status,
482 retired_at: r.retired_at,
483 retired_reason: r.retired_reason,
484 supersedes_id: r.supersedes_id,
485 superseded_by_id: r.superseded_by_id,
486 }
487}
488
489fn to_ctx_err(err: lance::Error) -> ContextError {
490 let msg = err.to_string();
491 if msg.contains("already in progress") {
492 ContextError::CompactionInProgress
493 } else if msg.contains("not found") || msg.contains("DatasetNotFound") {
494 ContextError::NotFound(msg)
495 } else if msg.contains("Invalid") {
496 ContextError::InvalidRequest(msg)
497 } else {
498 ContextError::Internal(msg)
499 }
500}