use chrono::Utc;
use serde_json::Value;
use uuid::Uuid;
use lance_context_api::{
AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RecordPatchDto,
RelationshipDto, RetrieveRequest, RetrieveResultDto, SearchRequest, SearchResultDto,
StateMetadataDto, UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest,
UpsertRecordResponse, UpsertRecordsRequest, UpsertRecordsResponse, UpsertResultDto,
};
use crate::record::{
ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, StateMetadata,
LIFECYCLE_ACTIVE,
};
use crate::store::{CompactionConfig, ContextStore};
impl ContextStoreApi for ContextStore {
    /// Inserts a batch of new records.
    ///
    /// Each record gets a freshly generated UUID id. The whole batch shares one
    /// generated `run_id`. The response carries the version returned by the store
    /// write, the generated ids in request order, and the number of records written.
    async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult<AddRecordsResponse> {
        let run_id = Uuid::new_v4().to_string();
        let mut ids = Vec::with_capacity(records.len());
        let mut core_records = Vec::with_capacity(records.len());
        for r in records {
            let id = Uuid::new_v4().to_string();
            ids.push(id.clone());
            core_records.push(record_from_add_request(r, id, run_id.clone()));
        }
        let count = core_records.len();
        // Name the inherent store method explicitly, as the other methods here do,
        // instead of relying on inherent-over-trait resolution of `self.add`.
        let version = ContextStore::add(self, &core_records)
            .await
            .map_err(to_ctx_err)?;
        Ok(AddRecordsResponse {
            version,
            ids,
            count,
        })
    }

    /// Inserts or replaces a single record keyed by `external_id`.
    ///
    /// # Errors
    /// Returns [`ContextError::InvalidRequest`] in two cases: `request.key` is not
    /// `"external_id"`, or the record has no non-empty `external_id`.
    async fn upsert(
        &mut self,
        request: &UpsertRecordRequest,
    ) -> ContextResult<UpsertRecordResponse> {
        ensure_external_id_key(&request.key)?;
        if !has_external_id(&request.record) {
            return Err(ContextError::InvalidRequest(
                "upsert requires record.external_id".to_string(),
            ));
        }
        let record = record_from_add_request(
            &request.record,
            Uuid::new_v4().to_string(),
            Uuid::new_v4().to_string(),
        );
        let result = ContextStore::upsert_by_external_id(self, record)
            .await
            .map_err(to_ctx_err)?;
        Ok(UpsertRecordResponse {
            version: result.version,
            inserted: result.inserted,
            replaced_id: result.replaced_id,
            record: record_to_dto(result.record),
        })
    }

    /// Inserts or replaces a batch of records keyed by `external_id`.
    ///
    /// Like [`add`](ContextStoreApi::add), every record in one call shares a single
    /// generated `run_id`. The reported version is the last result's version.
    ///
    /// # Errors
    /// Returns [`ContextError::InvalidRequest`] in three cases:
    /// - the key is not `"external_id"`;
    /// - the batch is empty;
    /// - any record lacks a non-empty `external_id`. The message names the first
    ///   offending index.
    async fn upsert_many(
        &mut self,
        request: &UpsertRecordsRequest,
    ) -> ContextResult<UpsertRecordsResponse> {
        ensure_external_id_key(&request.key)?;
        if request.records.is_empty() {
            return Err(ContextError::InvalidRequest(
                "upsert_many requires at least one record".to_string(),
            ));
        }
        if let Some(index) = request.records.iter().position(|r| !has_external_id(r)) {
            return Err(ContextError::InvalidRequest(format!(
                "upsert_many requires record.external_id (records[{index}])"
            )));
        }
        // One run id per call, matching `add`'s batch semantics (previously each
        // record got its own run id, which made a single batch look like N runs).
        let run_id = Uuid::new_v4().to_string();
        let core_records: Vec<ContextRecord> = request
            .records
            .iter()
            .map(|r| record_from_add_request(r, Uuid::new_v4().to_string(), run_id.clone()))
            .collect();
        let results = ContextStore::upsert_many_by_external_id(self, core_records)
            .await
            .map_err(to_ctx_err)?;
        let version = results
            .last()
            .map(|r| r.version)
            .unwrap_or_else(|| ContextStore::version(self));
        Ok(UpsertRecordsResponse {
            version,
            results: results
                .into_iter()
                .map(|r| UpsertResultDto {
                    inserted: r.inserted,
                    replaced_id: r.replaced_id,
                    record: record_to_dto(r.record),
                })
                .collect(),
        })
    }

    /// Applies a partial patch to one record, addressed by exactly one of `id` or
    /// `external_id`.
    ///
    /// When the store finds no matching record, the response has `updated: false`
    /// and reports the current store version.
    ///
    /// # Errors
    /// Returns [`ContextError::InvalidRequest`] in two cases: the patch is empty, or
    /// not exactly one of `id` / `external_id` is given.
    async fn update(
        &mut self,
        request: &UpdateRecordRequest,
    ) -> ContextResult<UpdateRecordResponse> {
        if request.patch.is_empty() {
            return Err(ContextError::InvalidRequest(
                "update requires at least one patch field".to_string(),
            ));
        }
        let patch = patch_from_dto(&request.patch);
        let result = match (&request.id, &request.external_id) {
            (Some(id), None) => ContextStore::update_by_id(self, id, patch).await,
            (None, Some(external_id)) => {
                ContextStore::update_by_external_id(self, external_id, patch).await
            }
            (None, None) => {
                return Err(ContextError::InvalidRequest(
                    "update requires either id or external_id".to_string(),
                ));
            }
            (Some(_), Some(_)) => {
                return Err(ContextError::InvalidRequest(
                    "update accepts only one of id or external_id".to_string(),
                ));
            }
        }
        .map_err(to_ctx_err)?;
        Ok(match result {
            Some(result) => UpdateRecordResponse {
                version: result.version,
                updated: true,
                replaced_id: Some(result.replaced_id),
                record: Some(record_to_dto(result.record)),
            },
            None => UpdateRecordResponse {
                version: ContextStore::version(self),
                updated: false,
                replaced_id: None,
                record: None,
            },
        })
    }

    /// Looks up a record by its internal id.
    async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
        let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?;
        Ok(record.map(record_to_dto))
    }

    /// Looks up a record by its caller-supplied `external_id`.
    async fn get_by_external_id(&self, external_id: &str) -> ContextResult<Option<RecordDto>> {
        let record = ContextStore::get_by_external_id(self, external_id)
            .await
            .map_err(to_ctx_err)?;
        Ok(record.map(record_to_dto))
    }

    /// Deletes a record by internal id and reports the store version afterwards.
    async fn delete_by_id(&mut self, id: &str) -> ContextResult<DeleteRecordResponse> {
        let deleted = ContextStore::delete_by_id(self, id)
            .await
            .map_err(to_ctx_err)?;
        Ok(DeleteRecordResponse {
            deleted,
            version: ContextStore::version(self),
        })
    }

    /// Deletes a record by `external_id` and reports the store version afterwards.
    async fn delete_by_external_id(
        &mut self,
        external_id: &str,
    ) -> ContextResult<DeleteRecordResponse> {
        let deleted = ContextStore::delete_by_external_id(self, external_id)
            .await
            .map_err(to_ctx_err)?;
        Ok(DeleteRecordResponse {
            deleted,
            version: ContextStore::version(self),
        })
    }

    /// Lists records with optional paging, JSON filters and lifecycle visibility.
    ///
    /// # Errors
    /// Returns [`ContextError::InvalidRequest`] if `filters` cannot be parsed.
    async fn list(
        &self,
        limit: Option<usize>,
        offset: Option<usize>,
        filters: Option<Value>,
        include_expired: bool,
        include_retired: bool,
    ) -> ContextResult<Vec<RecordDto>> {
        let filters = parse_filters(filters)?;
        let options = LifecycleQueryOptions::new(include_expired, include_retired);
        let records = ContextStore::list_filtered_with_options(
            self,
            limit,
            offset,
            filters.as_ref(),
            options,
        )
        .await
        .map_err(to_ctx_err)?;
        Ok(records.into_iter().map(record_to_dto).collect())
    }

    /// Lists records related to `target_id`, optionally restricted to one relation.
    async fn related(
        &self,
        target_id: &str,
        relation: Option<&str>,
        limit: Option<usize>,
        include_expired: bool,
        include_retired: bool,
    ) -> ContextResult<Vec<RecordDto>> {
        let options = LifecycleQueryOptions::new(include_expired, include_retired);
        let records =
            ContextStore::list_related_with_options(self, target_id, relation, limit, options)
                .await
                .map_err(to_ctx_err)?;
        Ok(records.into_iter().map(record_to_dto).collect())
    }

    /// Runs the store's filtered search for `request.query`.
    ///
    /// Relationships are stripped from each hit unless `include_relationships` is set.
    ///
    /// # Errors
    /// Returns [`ContextError::InvalidRequest`] if the filters cannot be parsed.
    async fn search(&self, request: &SearchRequest) -> ContextResult<Vec<SearchResultDto>> {
        let filters = parse_filters(request.filters.clone())?;
        let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
        let results = ContextStore::search_filtered_with_options(
            self,
            &request.query,
            Some(request.limit),
            filters.as_ref(),
            options,
        )
        .await
        .map_err(to_ctx_err)?;
        Ok(results
            .into_iter()
            .map(|mut sr| {
                if !request.include_relationships {
                    sr.record.relationships.clear();
                }
                SearchResultDto {
                    record: record_to_dto(sr.record),
                    distance: sr.distance,
                }
            })
            .collect())
    }

    /// Runs the store's retrieval over the optional `text` and `vector` inputs.
    ///
    /// Relationships are stripped from each hit unless `include_relationships` is set.
    ///
    /// # Errors
    /// Returns [`ContextError::InvalidRequest`] in two cases: `fusion` is not
    /// `"rrf"`, or the filters cannot be parsed.
    async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
        if request.fusion != "rrf" {
            return Err(ContextError::InvalidRequest(
                "retrieve fusion currently supports only 'rrf'".to_string(),
            ));
        }
        let filters = parse_filters(request.filters.clone())?;
        let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired);
        let results = ContextStore::retrieve_filtered_with_options(
            self,
            request.text.as_deref(),
            request.vector.as_deref(),
            Some(request.limit),
            filters.as_ref(),
            options,
        )
        .await
        .map_err(to_ctx_err)?;
        Ok(results
            .into_iter()
            .map(|mut result| {
                if !request.include_relationships {
                    result.record.relationships.clear();
                }
                RetrieveResultDto {
                    record: record_to_dto(result.record),
                    score: result.score,
                    vector_distance: result.vector_distance,
                    text_score: result.text_score,
                    matched_channels: result.matched_channels,
                }
            })
            .collect())
    }

    /// Current version of the underlying store.
    fn version(&self) -> u64 {
        ContextStore::version(self)
    }

    /// Checks out the given store version.
    async fn checkout(&mut self, version: u64) -> ContextResult<()> {
        ContextStore::checkout(self, version)
            .await
            .map_err(to_ctx_err)
    }

    /// Runs compaction on the store.
    ///
    /// If `options` is given, the fields it sets override
    /// [`CompactionConfig::default`]. If `options` is `None`, no config is passed
    /// and the store chooses.
    async fn compact(&mut self, options: Option<CompactRequest>) -> ContextResult<CompactResponse> {
        let config = options.map(|req| {
            let mut c = CompactionConfig::default();
            if let Some(v) = req.target_rows_per_fragment {
                c.target_rows_per_fragment = v;
            }
            if let Some(v) = req.materialize_deletions {
                c.materialize_deletions = v;
            }
            c
        });
        let metrics = ContextStore::compact(self, config)
            .await
            .map_err(to_ctx_err)?;
        Ok(CompactResponse {
            fragments_removed: metrics.fragments_removed,
            fragments_added: metrics.fragments_added,
            files_removed: metrics.files_removed,
            files_added: metrics.files_added,
        })
    }

    /// Reports the store's compaction statistics.
    async fn compaction_stats(&self) -> ContextResult<CompactStatsResponse> {
        let stats = ContextStore::compaction_stats(self)
            .await
            .map_err(to_ctx_err)?;
        Ok(CompactStatsResponse {
            total_fragments: stats.total_fragments,
            is_compacting: stats.is_compacting,
            last_compaction: stats.last_compaction,
            last_error: stats.last_error,
            total_compactions: stats.total_compactions,
        })
    }
}

/// Parses optional JSON filters.
///
/// A parse failure becomes [`ContextError::InvalidRequest`] carrying the parser's
/// message.
fn parse_filters(filters: Option<Value>) -> ContextResult<Option<RecordFilters>> {
    filters
        .map(RecordFilters::from_json_value)
        .transpose()
        .map_err(ContextError::InvalidRequest)
}

/// Rejects any upsert key other than `"external_id"`, the only supported key.
fn ensure_external_id_key(key: &str) -> ContextResult<()> {
    if key == "external_id" {
        Ok(())
    } else {
        Err(ContextError::InvalidRequest(format!(
            "upsert key '{key}' is not supported; use 'external_id'"
        )))
    }
}

/// True when the request carries a present, non-empty `external_id`.
fn has_external_id(record: &AddRecordRequest) -> bool {
    record
        .external_id
        .as_deref()
        .is_some_and(|id| !id.is_empty())
}
fn dto_to_relationship(r: RelationshipDto) -> Relationship {
Relationship {
target_id: r.target_id,
relation: r.relation,
weight: r.weight,
}
}
fn relationship_to_dto(r: Relationship) -> RelationshipDto {
RelationshipDto {
target_id: r.target_id,
relation: r.relation,
weight: r.weight,
}
}
fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch {
RecordPatch {
bot_id: patch.bot_id.clone(),
session_id: patch.session_id.clone(),
tenant: patch.tenant.clone(),
source: patch.source.clone(),
state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata {
step: sm.step,
active_plan_id: sm.active_plan_id.clone(),
tokens_used: sm.tokens_used,
custom: sm.custom.clone(),
}),
metadata: patch.metadata.clone(),
relationships: patch.relationships.as_ref().map(|relationships| {
relationships
.iter()
.cloned()
.map(dto_to_relationship)
.collect()
}),
expires_at: patch.expires_at,
retention_policy: patch.retention_policy.clone(),
lifecycle_status: patch.lifecycle_status.clone(),
retired_at: patch.retired_at,
retired_reason: patch.retired_reason.clone(),
embedding: patch.embedding.clone(),
payload_uri: patch.payload_uri.clone(),
payload_size: patch.payload_size,
payload_checksum: patch.payload_checksum.clone(),
}
}
fn record_from_add_request(r: &AddRecordRequest, id: String, run_id: String) -> ContextRecord {
ContextRecord {
id,
external_id: r.external_id.clone(),
run_id,
bot_id: r.bot_id.clone(),
session_id: r.session_id.clone(),
tenant: r.tenant.clone(),
source: r.source.clone(),
created_at: Utc::now(),
role: r.role.clone(),
state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata {
step: sm.step,
active_plan_id: sm.active_plan_id.clone(),
tokens_used: sm.tokens_used,
custom: sm.custom.clone(),
}),
metadata: r.metadata.clone(),
relationships: r
.relationships
.iter()
.cloned()
.map(dto_to_relationship)
.collect(),
expires_at: r.expires_at,
retention_policy: r.retention_policy.clone(),
lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
retired_at: None,
retired_reason: None,
supersedes_id: r.supersedes_id.clone(),
superseded_by_id: None,
content_type: r.content_type.clone(),
text_payload: r.text_payload.clone(),
binary_payload: r.binary_payload.clone(),
payload_uri: r.payload_uri.clone(),
payload_size: r.payload_size,
payload_checksum: r.payload_checksum.clone(),
embedding: r.embedding.clone(),
}
}
fn record_to_dto(r: ContextRecord) -> RecordDto {
RecordDto {
id: r.id,
external_id: r.external_id,
run_id: r.run_id,
bot_id: r.bot_id,
session_id: r.session_id,
tenant: r.tenant,
source: r.source,
created_at: r.created_at,
role: r.role,
content_type: r.content_type,
text_payload: r.text_payload,
binary_payload: r.binary_payload,
payload_uri: r.payload_uri,
payload_size: r.payload_size,
payload_checksum: r.payload_checksum,
embedding: r.embedding,
state_metadata: r.state_metadata.map(|sm| StateMetadataDto {
step: sm.step,
active_plan_id: sm.active_plan_id,
tokens_used: sm.tokens_used,
custom: sm.custom,
}),
metadata: r.metadata,
relationships: r
.relationships
.into_iter()
.map(relationship_to_dto)
.collect(),
expires_at: r.expires_at,
retention_policy: r.retention_policy,
lifecycle_status: r.lifecycle_status,
retired_at: r.retired_at,
retired_reason: r.retired_reason,
supersedes_id: r.supersedes_id,
superseded_by_id: r.superseded_by_id,
}
}
/// Maps a Lance error onto the API's [`ContextError`] categories.
///
/// Classification inspects the error's `Display` text, checked in this order:
/// - contains `"already in progress"` → [`ContextError::CompactionInProgress`]
/// - contains `"not found"` (any letter case) or `"DatasetNotFound"` →
///   [`ContextError::NotFound`]
/// - contains `"Invalid"` → [`ContextError::InvalidRequest`]
/// - otherwise → [`ContextError::Internal`]
///
/// Every variant except `CompactionInProgress` keeps the full message.
fn to_ctx_err(err: lance::Error) -> ContextError {
    let msg = err.to_string();
    // Lance renders `Error::NotFound` as "Not found: ..." (capital N), while other
    // variants read "... not found ...". A case-sensitive check sent the former to
    // `Internal`, so this comparison ignores ASCII case.
    let is_not_found =
        msg.to_ascii_lowercase().contains("not found") || msg.contains("DatasetNotFound");
    if msg.contains("already in progress") {
        ContextError::CompactionInProgress
    } else if is_not_found {
        ContextError::NotFound(msg)
    } else if msg.contains("Invalid") {
        ContextError::InvalidRequest(msg)
    } else {
        ContextError::Internal(msg)
    }
}