use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use manifoldb_core::{CollectionId, EntityId, Value};
use manifoldb_vector::{Embedding, SearchResult, VectorData, VectorError};
/// Abstraction over a set of named vector indexes that can be queried.
///
/// Implementors must be `Send + Sync`, so a provider can be shared behind an `Arc`
/// (which is how `ExecutionContext` stores it).
pub trait VectorIndexProvider: Send + Sync {
    /// Runs a search against the index called `index_name` using `query`.
    ///
    /// NOTE(review): `k` is presumably the maximum number of results to return and
    /// `ef_search` an optional per-query search-breadth override — confirm against the
    /// implementations, this trait does not define their semantics.
    ///
    /// # Errors
    ///
    /// Returns a [`VectorError`] when the implementation fails to perform the search.
    fn search(
        &self,
        index_name: &str,
        query: &Embedding,
        k: usize,
        ef_search: Option<usize>,
    ) -> Result<Vec<SearchResult>, VectorError>;

    /// Reports whether the provider has an index called `index_name`.
    fn has_index(&self, index_name: &str) -> bool;

    /// Returns the dimension associated with `index_name`, if the provider has one.
    ///
    /// NOTE(review): `None` presumably means the index is unknown — confirm.
    fn dimension(&self, index_name: &str) -> Option<usize>;
}
pub trait CollectionVectorProvider: Send + Sync {
fn upsert_vector(
&self,
collection_id: CollectionId,
entity_id: EntityId,
collection_name: &str,
vector_name: &str,
data: &VectorData,
) -> Result<(), VectorError>;
fn delete_vector(
&self,
collection_id: CollectionId,
entity_id: EntityId,
collection_name: &str,
vector_name: &str,
) -> Result<bool, VectorError>;
fn delete_entity_vectors(
&self,
collection_id: CollectionId,
entity_id: EntityId,
collection_name: &str,
) -> Result<usize, VectorError>;
fn get_vector(
&self,
collection_id: CollectionId,
entity_id: EntityId,
vector_name: &str,
) -> Result<Option<VectorData>, VectorError>;
fn get_all_vectors(
&self,
collection_id: CollectionId,
entity_id: EntityId,
) -> Result<std::collections::HashMap<String, VectorData>, VectorError>;
fn search(
&self,
collection_name: &str,
vector_name: &str,
query: &Embedding,
k: usize,
ef_search: Option<usize>,
) -> Result<Vec<SearchResult>, VectorError>;
}
use super::graph_accessor::{GraphAccessor, NullGraphAccessor};
/// State carried alongside a query while it executes.
///
/// Bundles the bound parameters, a cancellation flag, row counters, configuration,
/// a graph accessor and the optional vector providers.
pub struct ExecutionContext {
    /// Parameter values keyed by their `u32` index.
    parameters: HashMap<u32, Value>,
    /// Cancellation flag; atomic so it can be flipped through a shared reference.
    cancelled: AtomicBool,
    /// Row counters and the start timestamp for this context.
    stats: ExecutionStats,
    /// Tunable execution settings.
    config: ExecutionConfig,
    /// Graph accessor; the constructors install a `NullGraphAccessor`.
    graph: Arc<dyn GraphAccessor>,
    /// Optional provider of named vector indexes; `None` until one is installed.
    vector_index_provider: Option<Arc<dyn VectorIndexProvider>>,
    /// Optional provider of per-collection vector storage; `None` until one is installed.
    collection_vector_provider: Option<Arc<dyn CollectionVectorProvider>>,
}
impl ExecutionContext {
    /// Creates a context with no parameters.
    ///
    /// Everything else is set up exactly as in [`ExecutionContext::with_parameters`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_parameters(HashMap::new())
    }

    /// Creates a context holding the given `parameters`.
    ///
    /// The cancellation flag starts cleared, statistics are fresh, the configuration is
    /// the default one, the graph is a `NullGraphAccessor`, and no vector provider is set.
    #[must_use]
    pub fn with_parameters(parameters: HashMap<u32, Value>) -> Self {
        Self {
            parameters,
            cancelled: AtomicBool::new(false),
            stats: ExecutionStats::new(),
            config: ExecutionConfig::default(),
            graph: Arc::new(NullGraphAccessor),
            vector_index_provider: None,
            collection_vector_provider: None,
        }
    }

    /// Builder: returns the context with its graph accessor replaced by `graph`.
    #[must_use]
    pub fn with_graph(self, graph: Arc<dyn GraphAccessor>) -> Self {
        Self { graph, ..self }
    }

    /// Borrows the graph accessor.
    #[inline]
    #[must_use]
    pub fn graph(&self) -> &dyn GraphAccessor {
        &*self.graph
    }

    /// Returns a new `Arc` handle to the graph accessor (reference-count bump only).
    #[inline]
    #[must_use]
    pub fn graph_arc(&self) -> Arc<dyn GraphAccessor> {
        self.graph.clone()
    }

    /// Builder: returns the context with `provider` installed as vector index provider.
    #[must_use]
    pub fn with_vector_index_provider(mut self, provider: Arc<dyn VectorIndexProvider>) -> Self {
        self.set_vector_index_provider(provider);
        self
    }

    /// Installs `provider` as the vector index provider, replacing any previous one.
    pub fn set_vector_index_provider(&mut self, provider: Arc<dyn VectorIndexProvider>) {
        self.vector_index_provider = Some(provider);
    }

    /// Borrows the vector index provider, or `None` when none is installed.
    #[must_use]
    pub fn vector_index_provider(&self) -> Option<&dyn VectorIndexProvider> {
        self.vector_index_provider.as_deref()
    }

    /// Returns a new `Arc` handle to the vector index provider, if one is installed.
    #[must_use]
    pub fn vector_index_provider_arc(&self) -> Option<Arc<dyn VectorIndexProvider>> {
        self.vector_index_provider.as_ref().map(Arc::clone)
    }

    /// Builder: returns the context with `provider` installed as collection vector provider.
    #[must_use]
    pub fn with_collection_vector_provider(
        mut self,
        provider: Arc<dyn CollectionVectorProvider>,
    ) -> Self {
        self.set_collection_vector_provider(provider);
        self
    }

    /// Installs `provider` as the collection vector provider, replacing any previous one.
    pub fn set_collection_vector_provider(&mut self, provider: Arc<dyn CollectionVectorProvider>) {
        self.collection_vector_provider = Some(provider);
    }

    /// Borrows the collection vector provider, or `None` when none is installed.
    #[must_use]
    pub fn collection_vector_provider(&self) -> Option<&dyn CollectionVectorProvider> {
        self.collection_vector_provider.as_deref()
    }

    /// Returns a new `Arc` handle to the collection vector provider, if one is installed.
    #[must_use]
    pub fn collection_vector_provider_arc(&self) -> Option<Arc<dyn CollectionVectorProvider>> {
        self.collection_vector_provider.as_ref().map(Arc::clone)
    }

    /// Binds `value` to parameter `index`, overwriting any value already stored there.
    pub fn set_parameter(&mut self, index: u32, value: Value) {
        let params = &mut self.parameters;
        params.insert(index, value);
    }

    /// Looks up the value bound to parameter `index`; `None` when nothing is bound.
    #[inline]
    #[must_use]
    pub fn get_parameter(&self, index: u32) -> Option<&Value> {
        let params = &self.parameters;
        params.get(&index)
    }

    /// Borrows the full parameter map.
    #[inline]
    #[must_use]
    pub fn parameters(&self) -> &HashMap<u32, Value> {
        &self.parameters
    }

    /// Raises the cancellation flag.
    ///
    /// Works through `&self` because the flag is atomic; nothing in this type ever
    /// clears it again.
    #[inline]
    pub fn cancel(&self) {
        let flag = &self.cancelled;
        flag.store(true, Ordering::SeqCst);
    }

    /// Reports whether [`ExecutionContext::cancel`] has been called on this context.
    #[inline]
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }

    /// Borrows the execution statistics.
    #[inline]
    #[must_use]
    pub fn stats(&self) -> &ExecutionStats {
        &self.stats
    }

    /// Borrows the execution configuration.
    #[inline]
    #[must_use]
    pub fn config(&self) -> &ExecutionConfig {
        &self.config
    }

    /// Mutably borrows the execution configuration so it can be adjusted in place.
    pub fn config_mut(&mut self) -> &mut ExecutionConfig {
        &mut self.config
    }

    /// Adds `count` to the rows-read counter.
    #[inline]
    pub fn record_rows_read(&self, count: u64) {
        let counter = &self.stats.rows_read;
        counter.fetch_add(count, Ordering::Relaxed);
    }

    /// Adds `count` to the rows-produced counter.
    #[inline]
    pub fn record_rows_produced(&self, count: u64) {
        let counter = &self.stats.rows_produced;
        counter.fetch_add(count, Ordering::Relaxed);
    }

    /// Adds `count` to the rows-filtered counter.
    #[inline]
    pub fn record_rows_filtered(&self, count: u64) {
        let counter = &self.stats.rows_filtered;
        counter.fetch_add(count, Ordering::Relaxed);
    }

    /// Builder: returns the context with its configuration replaced by `config`.
    #[must_use]
    pub fn with_config(self, config: ExecutionConfig) -> Self {
        Self { config, ..self }
    }

    /// Shortcut for the configured `max_rows_in_memory` value.
    #[inline]
    #[must_use]
    pub fn max_rows_in_memory(&self) -> usize {
        self.config().max_rows_in_memory
    }
}
impl Default for ExecutionContext {
fn default() -> Self {
Self::new()
}
}
/// Hand-written `Debug` output for [`ExecutionContext`].
///
/// The graph accessor and the two providers are stored as trait objects, so they are
/// summarised instead of printed: the graph as a fixed placeholder string and each
/// provider as a `bool` telling whether one is installed.
impl std::fmt::Debug for ExecutionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExecutionContext")
            .field("parameters", &self.parameters)
            .field("cancelled", &self.cancelled)
            .field("stats", &self.stats)
            .field("config", &self.config)
            .field("graph", &"<GraphAccessor>")
            .field("vector_index_provider", &self.vector_index_provider.is_some())
            // Reported the same way as `vector_index_provider`, so both optional
            // providers are visible when debugging a context.
            .field("collection_vector_provider", &self.collection_vector_provider.is_some())
            // Kept non-exhaustive: the trait-object fields are only summarised above.
            .finish_non_exhaustive()
    }
}
/// Row counters plus a start timestamp for one execution.
///
/// The counters are atomics, so they can be updated and read through a shared reference.
#[derive(Debug)]
pub struct ExecutionStats {
    /// Moment at which this value was created; basis for [`ExecutionStats::elapsed`].
    start_time: Instant,
    /// Running total reported by [`ExecutionStats::rows_read`].
    rows_read: AtomicU64,
    /// Running total reported by [`ExecutionStats::rows_produced`].
    rows_produced: AtomicU64,
    /// Running total reported by [`ExecutionStats::rows_filtered`].
    rows_filtered: AtomicU64,
}

impl ExecutionStats {
    /// Creates statistics with every counter at zero and the clock started now.
    #[must_use]
    pub fn new() -> Self {
        let start_time = Instant::now();
        Self {
            start_time,
            rows_read: AtomicU64::new(0),
            rows_produced: AtomicU64::new(0),
            rows_filtered: AtomicU64::new(0),
        }
    }

    /// Current value of the rows-read counter (relaxed load).
    #[inline]
    #[must_use]
    pub fn rows_read(&self) -> u64 {
        let counter = &self.rows_read;
        counter.load(Ordering::Relaxed)
    }

    /// Current value of the rows-produced counter (relaxed load).
    #[inline]
    #[must_use]
    pub fn rows_produced(&self) -> u64 {
        let counter = &self.rows_produced;
        counter.load(Ordering::Relaxed)
    }

    /// Current value of the rows-filtered counter (relaxed load).
    #[inline]
    #[must_use]
    pub fn rows_filtered(&self) -> u64 {
        let counter = &self.rows_filtered;
        counter.load(Ordering::Relaxed)
    }

    /// Time that has passed since this value was created.
    #[inline]
    #[must_use]
    pub fn elapsed(&self) -> std::time::Duration {
        let started = &self.start_time;
        started.elapsed()
    }
}

/// The default statistics are the ones produced by [`ExecutionStats::new`].
impl Default for ExecutionStats {
    fn default() -> Self {
        ExecutionStats::new()
    }
}
/// Value [`ExecutionConfig::new`] assigns to `max_rows_in_memory`.
pub const DEFAULT_MAX_ROWS_IN_MEMORY: usize = 1_000_000;

/// Tunable settings for query execution.
///
/// Start from [`ExecutionConfig::new`] and refine with the `with_*` builders; all of
/// them are `const fn`, so a configuration can be built in a constant.
#[derive(Debug, Clone)]
pub struct ExecutionConfig {
    /// Maximum batch size; `new` sets it to 1024.
    pub max_batch_size: usize,
    /// Statistics-collection switch; `new` sets it to `false`.
    pub collect_stats: bool,
    /// Memory limit; `new` sets it to 0.
    /// NOTE(review): 0 presumably means "no limit", unit presumably bytes — confirm.
    pub memory_limit: usize,
    /// Row limit for in-memory data; `new` sets it to [`DEFAULT_MAX_ROWS_IN_MEMORY`].
    pub max_rows_in_memory: usize,
}

impl ExecutionConfig {
    /// Creates the baseline configuration (batch size 1024, stats off, memory limit 0,
    /// row limit [`DEFAULT_MAX_ROWS_IN_MEMORY`]).
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_rows_in_memory: DEFAULT_MAX_ROWS_IN_MEMORY,
            memory_limit: 0,
            collect_stats: false,
            max_batch_size: 1024,
        }
    }

    /// Builder: same configuration with `max_batch_size` set to `size`.
    #[must_use]
    pub const fn with_batch_size(self, size: usize) -> Self {
        Self { max_batch_size: size, ..self }
    }

    /// Builder: same configuration with `collect_stats` switched on.
    #[must_use]
    pub const fn with_stats(self) -> Self {
        Self { collect_stats: true, ..self }
    }

    /// Builder: same configuration with `memory_limit` set to `limit`.
    #[must_use]
    pub const fn with_memory_limit(self, limit: usize) -> Self {
        Self { memory_limit: limit, ..self }
    }

    /// Builder: same configuration with `max_rows_in_memory` set to `limit`.
    #[must_use]
    pub const fn with_max_rows_in_memory(self, limit: usize) -> Self {
        Self { max_rows_in_memory: limit, ..self }
    }
}

/// The default configuration is the one produced by [`ExecutionConfig::new`].
impl Default for ExecutionConfig {
    fn default() -> Self {
        ExecutionConfig::new()
    }
}
/// Cloneable handle to a shared cancellation flag.
///
/// The flag lives behind an `Arc`, so every clone of a token observes and controls the
/// same flag.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    /// Shared flag; `true` once any handle has called [`CancellationToken::cancel`].
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a token whose flag is not set.
    #[must_use]
    pub fn new() -> Self {
        let cancelled = Arc::new(AtomicBool::new(false));
        Self { cancelled }
    }

    /// Sets the shared flag; visible through every clone of this token.
    #[inline]
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::SeqCst);
    }

    /// Reports whether the shared flag has been set.
    #[inline]
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }
}

/// The default token is the one produced by [`CancellationToken::new`].
impl Default for CancellationToken {
    fn default() -> Self {
        CancellationToken::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bound parameters are returned by index; an unbound index yields `None`.
    #[test]
    fn context_parameters() {
        let mut ctx = ExecutionContext::new();
        ctx.set_parameter(1, Value::Int(42));
        ctx.set_parameter(2, Value::from("hello"));

        let expected_int = Value::Int(42);
        let expected_text = Value::from("hello");
        assert_eq!(ctx.get_parameter(1), Some(&expected_int));
        assert_eq!(ctx.get_parameter(2), Some(&expected_text));
        assert_eq!(ctx.get_parameter(3), None);
    }

    /// A fresh context is not cancelled; `cancel` flips the flag.
    #[test]
    fn context_cancellation() {
        let ctx = ExecutionContext::new();
        let before = ctx.is_cancelled();
        assert!(!before);

        ctx.cancel();

        let after = ctx.is_cancelled();
        assert!(after);
    }

    /// Each `record_*` call shows up in the matching statistics getter.
    #[test]
    fn context_stats() {
        let ctx = ExecutionContext::new();
        ctx.record_rows_read(100);
        ctx.record_rows_produced(50);
        ctx.record_rows_filtered(50);

        let stats = ctx.stats();
        assert_eq!(stats.rows_read(), 100);
        assert_eq!(stats.rows_produced(), 50);
        assert_eq!(stats.rows_filtered(), 50);
    }

    /// Cancelling a token is visible through its clone.
    #[test]
    fn cancellation_token() {
        let original = CancellationToken::new();
        assert!(!original.is_cancelled());

        let shared = original.clone();
        original.cancel();

        assert!(original.is_cancelled());
        assert!(shared.is_cancelled());
    }
}