// lance_context_api/lib.rs

use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::future::Future;

// ---------------------------------------------------------------------------
// Unified error
// ---------------------------------------------------------------------------

11#[derive(Debug, thiserror::Error)]
12pub enum ContextError {
13    #[error("{0}")]
14    NotFound(String),
15    #[error("{0}")]
16    AlreadyExists(String),
17    #[error("{0}")]
18    InvalidRequest(String),
19    #[error("{0}")]
20    Internal(String),
21    #[error("Compaction already in progress")]
22    CompactionInProgress,
23}
24
25pub type ContextResult<T> = Result<T, ContextError>;
26
// ---------------------------------------------------------------------------
// Unified trait
// ---------------------------------------------------------------------------

31pub trait ContextStoreApi {
32    fn add(
33        &mut self,
34        records: &[AddRecordRequest],
35    ) -> impl Future<Output = ContextResult<AddRecordsResponse>> + Send;
36
37    fn upsert(
38        &mut self,
39        request: &UpsertRecordRequest,
40    ) -> impl Future<Output = ContextResult<UpsertRecordResponse>> + Send;
41
42    fn update(
43        &mut self,
44        request: &UpdateRecordRequest,
45    ) -> impl Future<Output = ContextResult<UpdateRecordResponse>> + Send;
46
47    fn get(&self, id: &str) -> impl Future<Output = ContextResult<Option<RecordDto>>> + Send;
48
49    fn get_by_external_id(
50        &self,
51        external_id: &str,
52    ) -> impl Future<Output = ContextResult<Option<RecordDto>>> + Send;
53
54    fn delete_by_id(
55        &mut self,
56        id: &str,
57    ) -> impl Future<Output = ContextResult<DeleteRecordResponse>> + Send;
58
59    fn delete_by_external_id(
60        &mut self,
61        external_id: &str,
62    ) -> impl Future<Output = ContextResult<DeleteRecordResponse>> + Send;
63
64    fn list(
65        &self,
66        limit: Option<usize>,
67        offset: Option<usize>,
68        filters: Option<Value>,
69        include_expired: bool,
70        include_retired: bool,
71    ) -> impl Future<Output = ContextResult<Vec<RecordDto>>> + Send;
72
73    fn related(
74        &self,
75        target_id: &str,
76        relation: Option<&str>,
77        limit: Option<usize>,
78        include_expired: bool,
79        include_retired: bool,
80    ) -> impl Future<Output = ContextResult<Vec<RecordDto>>> + Send;
81
82    fn search(
83        &self,
84        request: &SearchRequest,
85    ) -> impl Future<Output = ContextResult<Vec<SearchResultDto>>> + Send;
86
87    fn retrieve(
88        &self,
89        request: &RetrieveRequest,
90    ) -> impl Future<Output = ContextResult<Vec<RetrieveResultDto>>> + Send;
91
92    fn version(&self) -> u64;
93
94    fn checkout(&mut self, version: u64) -> impl Future<Output = ContextResult<()>> + Send;
95
96    fn compact(
97        &mut self,
98        options: Option<CompactRequest>,
99    ) -> impl Future<Output = ContextResult<CompactResponse>> + Send;
100
101    fn compaction_stats(&self) -> impl Future<Output = ContextResult<CompactStatsResponse>> + Send;
102}
103
// ---------------------------------------------------------------------------
// Context lifecycle
// ---------------------------------------------------------------------------

108#[derive(Debug, Serialize, Deserialize)]
109pub struct CreateContextRequest {
110    pub name: String,
111    #[serde(default)]
112    pub storage_options: Option<std::collections::HashMap<String, String>>,
113    #[serde(default)]
114    pub id_index_type: Option<String>,
115    #[serde(default)]
116    pub blob_columns: Option<Vec<String>>,
117    #[serde(default)]
118    pub embedding_dim: Option<i32>,
119    #[serde(default)]
120    pub distance_metric: Option<String>,
121}
122
123#[derive(Debug, Serialize, Deserialize)]
124pub struct ContextInfo {
125    pub name: String,
126    pub uri: String,
127    pub version: u64,
128}
129
130#[derive(Debug, Serialize, Deserialize)]
131pub struct ListContextsResponse {
132    pub contexts: Vec<ContextInfo>,
133}
134
// ---------------------------------------------------------------------------
// Records
// ---------------------------------------------------------------------------

139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct StateMetadataDto {
141    #[serde(default, skip_serializing_if = "Option::is_none")]
142    pub step: Option<i32>,
143    #[serde(default, skip_serializing_if = "Option::is_none")]
144    pub active_plan_id: Option<String>,
145    #[serde(default, skip_serializing_if = "Option::is_none")]
146    pub tokens_used: Option<i32>,
147    #[serde(default, skip_serializing_if = "Option::is_none")]
148    pub custom: Option<String>,
149}
150
151#[derive(Debug, Clone, Default, Serialize, Deserialize)]
152pub struct RelationshipDto {
153    pub target_id: String,
154    pub relation: String,
155    #[serde(default, skip_serializing_if = "Option::is_none")]
156    pub weight: Option<f32>,
157}
158
159#[derive(Debug, Clone, Default, Serialize, Deserialize)]
160pub struct AddRecordRequest {
161    #[serde(default = "default_role")]
162    pub role: String,
163    #[serde(default = "default_content_type")]
164    pub content_type: String,
165    #[serde(default, skip_serializing_if = "Option::is_none")]
166    pub text_payload: Option<String>,
167    #[serde(
168        default,
169        skip_serializing_if = "Option::is_none",
170        serialize_with = "serialize_base64_opt",
171        deserialize_with = "deserialize_base64_opt"
172    )]
173    pub binary_payload: Option<Vec<u8>>,
174    #[serde(default, skip_serializing_if = "Option::is_none")]
175    pub embedding: Option<Vec<f32>>,
176    #[serde(default, skip_serializing_if = "Option::is_none")]
177    pub bot_id: Option<String>,
178    #[serde(default, skip_serializing_if = "Option::is_none")]
179    pub session_id: Option<String>,
180    #[serde(default, skip_serializing_if = "Option::is_none")]
181    pub tenant: Option<String>,
182    #[serde(default, skip_serializing_if = "Option::is_none")]
183    pub source: Option<String>,
184    #[serde(default, skip_serializing_if = "Option::is_none")]
185    pub external_id: Option<String>,
186    #[serde(default, skip_serializing_if = "Option::is_none")]
187    pub state_metadata: Option<StateMetadataDto>,
188    #[serde(default, skip_serializing_if = "Option::is_none")]
189    pub metadata: Option<Value>,
190    #[serde(default, skip_serializing_if = "Vec::is_empty")]
191    pub relationships: Vec<RelationshipDto>,
192    #[serde(default, skip_serializing_if = "Option::is_none")]
193    pub expires_at: Option<DateTime<Utc>>,
194    #[serde(default, skip_serializing_if = "Option::is_none")]
195    pub retention_policy: Option<String>,
196    #[serde(default, skip_serializing_if = "Option::is_none")]
197    pub supersedes_id: Option<String>,
198}
199
200#[derive(Debug, Serialize, Deserialize)]
201pub struct AddRecordsRequest {
202    pub records: Vec<AddRecordRequest>,
203}
204
205#[derive(Debug, Serialize, Deserialize)]
206pub struct AddRecordsResponse {
207    pub version: u64,
208    pub ids: Vec<String>,
209    pub count: usize,
210}
211
212#[derive(Debug, Serialize, Deserialize)]
213pub struct UpsertRecordRequest {
214    pub record: AddRecordRequest,
215    #[serde(default = "default_upsert_key")]
216    pub key: String,
217}
218
219#[derive(Debug, Serialize, Deserialize)]
220pub struct UpsertRecordResponse {
221    pub version: u64,
222    pub inserted: bool,
223    #[serde(default, skip_serializing_if = "Option::is_none")]
224    pub replaced_id: Option<String>,
225    pub record: RecordDto,
226}
227
228#[derive(Debug, Clone, Default, Serialize, Deserialize)]
229pub struct RecordPatchDto {
230    #[serde(default, skip_serializing_if = "Option::is_none")]
231    pub bot_id: Option<String>,
232    #[serde(default, skip_serializing_if = "Option::is_none")]
233    pub session_id: Option<String>,
234    #[serde(default, skip_serializing_if = "Option::is_none")]
235    pub tenant: Option<String>,
236    #[serde(default, skip_serializing_if = "Option::is_none")]
237    pub source: Option<String>,
238    #[serde(default, skip_serializing_if = "Option::is_none")]
239    pub state_metadata: Option<StateMetadataDto>,
240    #[serde(default, skip_serializing_if = "Option::is_none")]
241    pub metadata: Option<Value>,
242    #[serde(default, skip_serializing_if = "Option::is_none")]
243    pub relationships: Option<Vec<RelationshipDto>>,
244    #[serde(default, skip_serializing_if = "Option::is_none")]
245    pub expires_at: Option<DateTime<Utc>>,
246    #[serde(default, skip_serializing_if = "Option::is_none")]
247    pub retention_policy: Option<String>,
248    #[serde(default, skip_serializing_if = "Option::is_none")]
249    pub lifecycle_status: Option<String>,
250    #[serde(default, skip_serializing_if = "Option::is_none")]
251    pub retired_at: Option<DateTime<Utc>>,
252    #[serde(default, skip_serializing_if = "Option::is_none")]
253    pub retired_reason: Option<String>,
254    #[serde(default, skip_serializing_if = "Option::is_none")]
255    pub embedding: Option<Vec<f32>>,
256}
257
258impl RecordPatchDto {
259    #[must_use]
260    pub fn is_empty(&self) -> bool {
261        self.bot_id.is_none()
262            && self.session_id.is_none()
263            && self.tenant.is_none()
264            && self.source.is_none()
265            && self.state_metadata.is_none()
266            && self.metadata.is_none()
267            && self.relationships.is_none()
268            && self.expires_at.is_none()
269            && self.retention_policy.is_none()
270            && self.lifecycle_status.is_none()
271            && self.retired_at.is_none()
272            && self.retired_reason.is_none()
273            && self.embedding.is_none()
274    }
275}
276
277#[derive(Debug, Serialize, Deserialize)]
278pub struct UpdateRecordRequest {
279    #[serde(default, skip_serializing_if = "Option::is_none")]
280    pub id: Option<String>,
281    #[serde(default, skip_serializing_if = "Option::is_none")]
282    pub external_id: Option<String>,
283    #[serde(default)]
284    pub patch: RecordPatchDto,
285}
286
287#[derive(Debug, Serialize, Deserialize)]
288pub struct UpdateRecordResponse {
289    pub version: u64,
290    pub updated: bool,
291    #[serde(default, skip_serializing_if = "Option::is_none")]
292    pub replaced_id: Option<String>,
293    #[serde(default, skip_serializing_if = "Option::is_none")]
294    pub record: Option<RecordDto>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct RecordDto {
299    pub id: String,
300    #[serde(default, skip_serializing_if = "Option::is_none")]
301    pub external_id: Option<String>,
302    pub run_id: String,
303    #[serde(default, skip_serializing_if = "Option::is_none")]
304    pub bot_id: Option<String>,
305    #[serde(default, skip_serializing_if = "Option::is_none")]
306    pub session_id: Option<String>,
307    #[serde(default, skip_serializing_if = "Option::is_none")]
308    pub tenant: Option<String>,
309    #[serde(default, skip_serializing_if = "Option::is_none")]
310    pub source: Option<String>,
311    pub created_at: DateTime<Utc>,
312    pub role: String,
313    pub content_type: String,
314    #[serde(default, skip_serializing_if = "Option::is_none")]
315    pub text_payload: Option<String>,
316    #[serde(
317        default,
318        skip_serializing_if = "Option::is_none",
319        serialize_with = "serialize_base64_opt",
320        deserialize_with = "deserialize_base64_opt"
321    )]
322    pub binary_payload: Option<Vec<u8>>,
323    #[serde(default, skip_serializing_if = "Option::is_none")]
324    pub embedding: Option<Vec<f32>>,
325    #[serde(default, skip_serializing_if = "Option::is_none")]
326    pub state_metadata: Option<StateMetadataDto>,
327    #[serde(default, skip_serializing_if = "Option::is_none")]
328    pub metadata: Option<Value>,
329    #[serde(default, skip_serializing_if = "Vec::is_empty")]
330    pub relationships: Vec<RelationshipDto>,
331    #[serde(default, skip_serializing_if = "Option::is_none")]
332    pub expires_at: Option<DateTime<Utc>>,
333    #[serde(default, skip_serializing_if = "Option::is_none")]
334    pub retention_policy: Option<String>,
335    pub lifecycle_status: String,
336    #[serde(default, skip_serializing_if = "Option::is_none")]
337    pub retired_at: Option<DateTime<Utc>>,
338    #[serde(default, skip_serializing_if = "Option::is_none")]
339    pub retired_reason: Option<String>,
340    #[serde(default, skip_serializing_if = "Option::is_none")]
341    pub supersedes_id: Option<String>,
342    #[serde(default, skip_serializing_if = "Option::is_none")]
343    pub superseded_by_id: Option<String>,
344}
345
346#[derive(Debug, Serialize, Deserialize)]
347pub struct ListRecordsResponse {
348    pub records: Vec<RecordDto>,
349}
350
// ---------------------------------------------------------------------------
// Single record lookup
// ---------------------------------------------------------------------------

355#[derive(Debug, Serialize, Deserialize)]
356pub struct GetRecordResponse {
357    pub record: Option<RecordDto>,
358}
359
360#[derive(Debug, Serialize, Deserialize)]
361pub struct DeleteRecordResponse {
362    pub deleted: bool,
363    pub version: u64,
364}
365
// ---------------------------------------------------------------------------
// Search
// ---------------------------------------------------------------------------

370#[derive(Debug, Serialize, Deserialize)]
371pub struct SearchRequest {
372    pub query: Vec<f32>,
373    #[serde(default = "default_search_limit")]
374    pub limit: usize,
375    #[serde(default, skip_serializing_if = "Option::is_none")]
376    pub filters: Option<Value>,
377    #[serde(default)]
378    pub include_expired: bool,
379    #[serde(default)]
380    pub include_retired: bool,
381    #[serde(default)]
382    pub include_relationships: bool,
383}
384
385#[derive(Debug, Serialize, Deserialize)]
386pub struct SearchResultDto {
387    pub record: RecordDto,
388    pub distance: f32,
389}
390
391#[derive(Debug, Serialize, Deserialize)]
392pub struct SearchResponse {
393    pub results: Vec<SearchResultDto>,
394}
395
// ---------------------------------------------------------------------------
// Hybrid retrieval
// ---------------------------------------------------------------------------

400#[derive(Debug, Serialize, Deserialize)]
401pub struct RetrieveRequest {
402    #[serde(default, skip_serializing_if = "Option::is_none")]
403    pub text: Option<String>,
404    #[serde(default, skip_serializing_if = "Option::is_none")]
405    pub vector: Option<Vec<f32>>,
406    #[serde(default = "default_search_limit")]
407    pub limit: usize,
408    #[serde(default, skip_serializing_if = "Option::is_none")]
409    pub filters: Option<Value>,
410    #[serde(default)]
411    pub include_expired: bool,
412    #[serde(default)]
413    pub include_retired: bool,
414    #[serde(default)]
415    pub include_relationships: bool,
416    #[serde(default = "default_retrieve_fusion")]
417    pub fusion: String,
418}
419
420#[derive(Debug, Serialize, Deserialize)]
421pub struct RetrieveResultDto {
422    pub record: RecordDto,
423    pub score: f32,
424    #[serde(default, skip_serializing_if = "Option::is_none")]
425    pub vector_distance: Option<f32>,
426    #[serde(default, skip_serializing_if = "Option::is_none")]
427    pub text_score: Option<f32>,
428    #[serde(default, skip_serializing_if = "Vec::is_empty")]
429    pub matched_channels: Vec<String>,
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub struct RetrieveResponse {
434    pub results: Vec<RetrieveResultDto>,
435}
436
// ---------------------------------------------------------------------------
// Versioning
// ---------------------------------------------------------------------------

441#[derive(Debug, Serialize, Deserialize)]
442pub struct VersionResponse {
443    pub version: u64,
444}
445
446#[derive(Debug, Serialize, Deserialize)]
447pub struct CheckoutRequest {
448    pub version: u64,
449}
450
// ---------------------------------------------------------------------------
// Compaction
// ---------------------------------------------------------------------------

455#[derive(Debug, Default, Serialize, Deserialize)]
456pub struct CompactRequest {
457    #[serde(default, skip_serializing_if = "Option::is_none")]
458    pub target_rows_per_fragment: Option<usize>,
459    #[serde(default, skip_serializing_if = "Option::is_none")]
460    pub materialize_deletions: Option<bool>,
461}
462
463#[derive(Debug, Serialize, Deserialize)]
464pub struct CompactResponse {
465    pub fragments_removed: usize,
466    pub fragments_added: usize,
467    pub files_removed: usize,
468    pub files_added: usize,
469}
470
471#[derive(Debug, Serialize, Deserialize)]
472pub struct CompactStatsResponse {
473    pub total_fragments: usize,
474    pub is_compacting: bool,
475    #[serde(default, skip_serializing_if = "Option::is_none")]
476    pub last_compaction: Option<DateTime<Utc>>,
477    #[serde(default, skip_serializing_if = "Option::is_none")]
478    pub last_error: Option<String>,
479    pub total_compactions: u64,
480}
481
// ---------------------------------------------------------------------------
// Error
// ---------------------------------------------------------------------------

486#[derive(Debug, Serialize, Deserialize)]
487pub struct ErrorBody {
488    pub code: String,
489    pub message: String,
490}
491
492#[derive(Debug, Serialize, Deserialize)]
493pub struct ErrorResponse {
494    pub error: ErrorBody,
495}
496
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Serde default for `AddRecordRequest::content_type`: `"text/plain"`.
fn default_content_type() -> String {
    "text/plain".to_string()
}

/// Serde default for `AddRecordRequest::role`: `"user"`.
fn default_role() -> String {
    "user".to_string()
}

/// Serde default for `UpsertRecordRequest::key`: `"external_id"`.
fn default_upsert_key() -> String {
    "external_id".to_string()
}

/// Serde default for the `limit` field of `SearchRequest` and
/// `RetrieveRequest`: `10`.
fn default_search_limit() -> usize {
    10
}

/// Serde default for `RetrieveRequest::fusion`: `"rrf"`.
fn default_retrieve_fusion() -> String {
    "rrf".to_string()
}

521fn serialize_base64_opt<S>(data: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
522where
523    S: serde::Serializer,
524{
525    match data {
526        Some(bytes) => serializer.serialize_some(&BASE64.encode(bytes)),
527        None => serializer.serialize_none(),
528    }
529}
530
531fn deserialize_base64_opt<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
532where
533    D: serde::Deserializer<'de>,
534{
535    let opt: Option<String> = Option::deserialize(deserializer)?;
536    match opt {
537        Some(s) => BASE64
538            .decode(&s)
539            .map(Some)
540            .map_err(serde::de::Error::custom),
541        None => Ok(None),
542    }
543}
544
/// Deserialization tests for [`SearchRequest`] covering defaulted and
/// explicitly supplied fields.
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload with only `query` and `limit` leaves `filters` unset and all
    /// three boolean flags `false`.
    #[test]
    fn search_request_legacy_payload_defaults_filters_and_lifecycle() {
        // Clients written against the pre-#89 shape send only query/limit.
        let req: SearchRequest =
            serde_json::from_str(r#"{"query": [0.1, 0.2], "limit": 5}"#).unwrap();
        assert_eq!(req.query, vec![0.1, 0.2]);
        assert_eq!(req.limit, 5);
        assert!(req.filters.is_none());
        assert!(!req.include_expired);
        assert!(!req.include_retired);
        assert!(!req.include_relationships);
    }

    /// Omitting `limit` yields the value of `default_search_limit()`.
    #[test]
    fn search_request_defaults_limit_when_omitted() {
        let req: SearchRequest = serde_json::from_str(r#"{"query": [1.0]}"#).unwrap();
        assert_eq!(req.limit, default_search_limit());
    }

    /// Explicit `filters` and lifecycle flags are carried through unchanged.
    #[test]
    fn search_request_parses_filters_and_lifecycle() {
        let req: SearchRequest = serde_json::from_str(
            r#"{"query": [1.0], "filters": {"tenant": "acme"}, "include_expired": true, "include_retired": true}"#,
        )
        .unwrap();
        assert_eq!(req.filters, Some(serde_json::json!({"tenant": "acme"})));
        assert!(req.include_expired);
        assert!(req.include_retired);
    }
}