use std::{rc::Rc, sync::Arc};
use selene_core::{BindingTableId, CancellationChecker, DbString};
use selene_graph::{
CANDIDATE_STATE_PROVIDER_TAG, CompactionReport, CompactionStats, GraphResult, IndexProvider,
Mutator, ProviderError, ProviderTag, SeleneGraph, SharedGraph, VectorCandidateSet,
VectorCandidateStateInfo, VectorIndexMaintenancePolicy, VectorIndexRebuildReport,
};
use crate::{BindingTable, BindingTableRegistry, ImplDefinedCaps, ProcedureTier};
/// Read-only context handed to procedures running at the graph tier.
///
/// Borrows a graph snapshot, the implementation-defined caps and the slice of
/// registered index providers. It also owns a copy of the cancellation
/// checker and an `Rc` handle to the binding-table registry.
pub struct GraphContext<'a> {
    /// Snapshot that procedures read from.
    snapshot: &'a SeleneGraph,
    /// Implementation-defined caps exposed through `impl_defined_caps`.
    caps: &'a ImplDefinedCaps,
    /// Index providers, searched by tag in slice order.
    providers: &'a [Arc<dyn IndexProvider>],
    /// Cancellation checker, returned by value from `cancellation_checker`.
    cancellation: CancellationChecker<'a>,
    /// Registry that receives tables passed to `register_binding_table`.
    binding_tables: Rc<BindingTableRegistry>,
}

impl<'a> GraphContext<'a> {
    /// Assembles a graph-tier context from its parts (crate-internal only).
    pub(crate) fn new(
        snapshot: &'a SeleneGraph,
        caps: &'a ImplDefinedCaps,
        providers: &'a [Arc<dyn IndexProvider>],
        cancellation: CancellationChecker<'a>,
        binding_tables: Rc<BindingTableRegistry>,
    ) -> Self {
        Self {
            snapshot,
            caps,
            providers,
            cancellation,
            binding_tables,
        }
    }

    /// Returns the borrowed graph snapshot this context reads from.
    #[must_use]
    pub const fn snapshot(&self) -> &'a SeleneGraph {
        self.snapshot
    }

    /// Returns the implementation-defined caps this context was built with.
    #[must_use]
    pub const fn impl_defined_caps(&self) -> &'a ImplDefinedCaps {
        self.caps
    }

    /// Returns a new `Arc` handle to the first provider whose `provider_tag()`
    /// equals `tag`, or `None` if no provider carries that tag.
    #[must_use]
    pub fn index_provider_by_tag(&self, tag: ProviderTag) -> Option<Arc<dyn IndexProvider>> {
        self.providers
            .iter()
            .find(|candidate| candidate.provider_tag() == tag)
            .map(Arc::clone)
    }

    /// Looks up the provider registered under [`CANDIDATE_STATE_PROVIDER_TAG`].
    fn candidate_state_provider(&self) -> Option<Arc<dyn IndexProvider>> {
        self.index_provider_by_tag(ProviderTag(CANDIDATE_STATE_PROVIDER_TAG))
    }

    /// Fetches the vector candidate set named `name`, queried at this
    /// snapshot's `meta.generation`.
    ///
    /// When no candidate-state provider is registered this returns `Ok(None)`
    /// rather than an error.
    ///
    /// # Errors
    ///
    /// Propagates any [`ProviderError`] returned by the candidate-state provider.
    pub fn vector_candidate_set(
        &self,
        name: &DbString,
    ) -> Result<Option<VectorCandidateSet>, ProviderError> {
        match self.candidate_state_provider() {
            Some(provider) => {
                provider.vector_candidate_set(name, self.snapshot.meta.generation)
            }
            None => Ok(None),
        }
    }

    /// Lists the candidate-state infos reported at this snapshot's
    /// `meta.generation`.
    ///
    /// When no candidate-state provider is registered this returns an empty
    /// `Vec` rather than an error.
    ///
    /// # Errors
    ///
    /// Propagates any [`ProviderError`] returned by the candidate-state provider.
    pub fn vector_candidate_state_infos(
        &self,
    ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
        match self.candidate_state_provider() {
            Some(provider) => {
                provider.vector_candidate_state_infos(self.snapshot.meta.generation)
            }
            None => Ok(Vec::new()),
        }
    }

    /// Returns a copy of this context's cancellation checker.
    #[must_use]
    pub const fn cancellation_checker(&self) -> CancellationChecker<'a> {
        self.cancellation
    }

    /// Registers `table` in the shared binding-table registry and returns the
    /// `BindingTableId` the registry hands back.
    pub fn register_binding_table(&self, table: Arc<BindingTable>) -> BindingTableId {
        self.binding_tables.register(table)
    }
}
/// Context handed to procedures running at the mutation tier.
///
/// Owns the `Mutator` used to read and write the graph. It also carries the
/// implementation-defined caps, a cancellation checker and an `Rc` handle to
/// the binding-table registry.
pub struct MutationContext<'a, 'g> {
    /// Mutator through which the graph is read (`snapshot`) and modified.
    mutator: Mutator<'a, 'g>,
    /// Implementation-defined caps exposed through `impl_defined_caps`.
    caps: &'a ImplDefinedCaps,
    /// Cancellation checker, returned by value from `cancellation_checker`.
    cancellation: CancellationChecker<'a>,
    /// Registry that receives tables passed to `register_binding_table`.
    binding_tables: Rc<BindingTableRegistry>,
}

impl<'a, 'g> MutationContext<'a, 'g> {
    /// Assembles a mutation-tier context from its parts (crate-internal only).
    pub(crate) fn new(
        mutator: Mutator<'a, 'g>,
        caps: &'a ImplDefinedCaps,
        cancellation: CancellationChecker<'a>,
        binding_tables: Rc<BindingTableRegistry>,
    ) -> Self {
        Self {
            mutator,
            caps,
            cancellation,
            binding_tables,
        }
    }

    /// Test-only constructor.
    ///
    /// Uses `CancellationChecker::disabled()` and a fresh, empty
    /// `BindingTableRegistry`.
    #[cfg(any(test, feature = "test-harness"))]
    #[must_use]
    pub fn for_test(mutator: Mutator<'a, 'g>, caps: &'a ImplDefinedCaps) -> Self {
        let cancellation = CancellationChecker::disabled();
        let binding_tables = Rc::new(BindingTableRegistry::new());
        Self::new(mutator, caps, cancellation, binding_tables)
    }

    /// Returns the graph as currently seen through the mutator (`Mutator::read`).
    #[must_use]
    pub fn snapshot(&self) -> &SeleneGraph {
        self.mutator.read()
    }

    /// Gives mutable access to the underlying mutator so procedures can write.
    pub fn mutator(&mut self) -> &mut Mutator<'a, 'g> {
        &mut self.mutator
    }

    /// Resolves an index provider by `tag` by delegating to the mutator's lookup.
    #[must_use]
    pub fn index_provider_by_tag(&self, tag: ProviderTag) -> Option<Arc<dyn IndexProvider>> {
        self.mutator.index_provider_by_tag(tag)
    }

    /// Returns the implementation-defined caps this context was built with.
    #[must_use]
    pub const fn impl_defined_caps(&self) -> &'a ImplDefinedCaps {
        self.caps
    }

    /// Returns a copy of this context's cancellation checker.
    #[must_use]
    pub const fn cancellation_checker(&self) -> CancellationChecker<'a> {
        self.cancellation
    }

    /// Registers `table` in the shared binding-table registry and returns the
    /// `BindingTableId` the registry hands back.
    pub fn register_binding_table(&self, table: Arc<BindingTable>) -> BindingTableId {
        self.binding_tables.register(table)
    }
}
/// Context handed to procedures running at the maintenance tier.
///
/// Exposes graph-wide maintenance operations (vector-index rebuild and
/// maintenance, compaction), each delegated to the borrowed `SharedGraph`.
pub struct MaintenanceContext<'a, 'g> {
    /// Shared graph on which the maintenance operations are invoked.
    graph: &'g SharedGraph,
    /// Implementation-defined caps exposed through `impl_defined_caps`.
    caps: &'a ImplDefinedCaps,
    /// Cancellation checker, returned by value from `cancellation_checker`.
    // NOTE(review): none of the delegating methods below consult this checker
    // themselves; confirm whether `SharedGraph` handles cancellation.
    cancellation: CancellationChecker<'a>,
    /// Registry that receives tables passed to `register_binding_table`.
    binding_tables: Rc<BindingTableRegistry>,
}

impl<'a, 'g> MaintenanceContext<'a, 'g> {
    /// Assembles a maintenance-tier context from its parts (crate-internal only).
    pub(crate) fn new(
        graph: &'g SharedGraph,
        caps: &'a ImplDefinedCaps,
        cancellation: CancellationChecker<'a>,
        binding_tables: Rc<BindingTableRegistry>,
    ) -> Self {
        Self {
            graph,
            caps,
            cancellation,
            binding_tables,
        }
    }

    /// Returns the implementation-defined caps this context was built with.
    #[must_use]
    pub const fn impl_defined_caps(&self) -> &'a ImplDefinedCaps {
        self.caps
    }

    /// Returns a copy of this context's cancellation checker.
    #[must_use]
    pub const fn cancellation_checker(&self) -> CancellationChecker<'a> {
        self.cancellation
    }

    /// Rebuilds vector indexes via `SharedGraph::rebuild_vector_indexes`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the shared graph reports.
    pub fn rebuild_vector_indexes(&self) -> GraphResult<VectorIndexRebuildReport> {
        self.graph.rebuild_vector_indexes()
    }

    /// Rebuilds the vector indexes the shared graph recommends rebuilding, via
    /// `SharedGraph::rebuild_recommended_vector_indexes`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the shared graph reports.
    pub fn rebuild_recommended_vector_indexes(&self) -> GraphResult<VectorIndexRebuildReport> {
        self.graph.rebuild_recommended_vector_indexes()
    }

    /// Runs vector-index maintenance under `policy` via
    /// `SharedGraph::maintain_vector_indexes`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the shared graph reports.
    pub fn maintain_vector_indexes(
        &self,
        policy: VectorIndexMaintenancePolicy,
    ) -> GraphResult<VectorIndexRebuildReport> {
        self.graph.maintain_vector_indexes(policy)
    }

    /// Returns the shared graph's current compaction statistics.
    #[must_use]
    pub fn compaction_stats(&self) -> CompactionStats {
        self.graph.compaction_stats()
    }

    /// Compacts the shared graph via `SharedGraph::compact`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the shared graph reports.
    pub fn compact(&self) -> GraphResult<CompactionReport> {
        self.graph.compact()
    }

    /// Registers `table` in the shared binding-table registry and returns the
    /// `BindingTableId` the registry hands back.
    pub fn register_binding_table(&self, table: Arc<BindingTable>) -> BindingTableId {
        self.binding_tables.register(table)
    }
}
/// Execution context passed to a procedure, with one variant per
/// [`ProcedureTier`].
#[non_exhaustive]
pub enum ProcedureContext<'a, 'g> {
    /// Read-only graph-tier context.
    Graph(GraphContext<'a>),
    /// Mutation-tier context that owns a `Mutator`.
    Mutation(MutationContext<'a, 'g>),
    /// Maintenance-tier context that borrows the `SharedGraph`.
    Maintenance(MaintenanceContext<'a, 'g>),
}

impl ProcedureContext<'_, '_> {
    /// Reports the [`ProcedureTier`] matching the active variant.
    #[must_use]
    pub const fn tier(&self) -> ProcedureTier {
        match self {
            Self::Graph(..) => ProcedureTier::Graph,
            Self::Mutation(..) => ProcedureTier::Mutation,
            Self::Maintenance(..) => ProcedureTier::Maintenance,
        }
    }

    /// Registers `table` through whichever context is active and returns the
    /// resulting `BindingTableId`.
    pub fn register_binding_table(&self, table: Arc<BindingTable>) -> BindingTableId {
        match self {
            Self::Graph(graph_ctx) => graph_ctx.register_binding_table(table),
            Self::Mutation(mutation_ctx) => mutation_ctx.register_binding_table(table),
            Self::Maintenance(maintenance_ctx) => maintenance_ctx.register_binding_table(table),
        }
    }
}