use std::collections::BTreeSet;
use std::ops::ControlFlow;
use roaring::RoaringBitmap;
use super::interner::{Interned, MappedInterner};
use super::logger::SearchLogger;
use super::query_graph::QueryNode;
use super::ranking_rule_graph::{
ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, FidGraph, PositionGraph, ProximityGraph,
RankingRuleGraph, RankingRuleGraphTrait, TypoGraph, WordsGraph,
};
use super::small_bitmap::SmallBitmap;
use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext};
use crate::score_details::Rank;
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::ranking_rule_graph::PathVisitor;
use crate::{Result, TermsMatchingStrategy};
/// Graph-based ranking rule instantiated over [`WordsGraph`].
pub type Words = GraphBasedRankingRule<WordsGraph>;
impl GraphBasedRankingRule<WordsGraph> {
    /// Creates the rule under the id `"words"`.
    ///
    /// Unlike the other graph-based rules, the strategy is mandatory here and
    /// is always stored as `Some(terms_matching_strategy)`.
    pub fn new(terms_matching_strategy: TermsMatchingStrategy) -> Self {
        let id = String::from("words");
        Self::new_with_id(id, Some(terms_matching_strategy))
    }
}
/// Graph-based ranking rule instantiated over [`ProximityGraph`].
pub type Proximity = GraphBasedRankingRule<ProximityGraph>;
impl GraphBasedRankingRule<ProximityGraph> {
    /// Creates the rule under the id `"proximity"`, forwarding the optional
    /// terms matching strategy unchanged.
    pub fn new(terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
        let id = String::from("proximity");
        Self::new_with_id(id, terms_matching_strategy)
    }
}
/// Graph-based ranking rule instantiated over [`FidGraph`].
pub type Fid = GraphBasedRankingRule<FidGraph>;
impl GraphBasedRankingRule<FidGraph> {
    /// Creates the rule under the id `"fid"`, forwarding the optional
    /// terms matching strategy unchanged.
    pub fn new(terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
        let id = String::from("fid");
        Self::new_with_id(id, terms_matching_strategy)
    }
}
/// Graph-based ranking rule instantiated over [`PositionGraph`].
pub type Position = GraphBasedRankingRule<PositionGraph>;
impl GraphBasedRankingRule<PositionGraph> {
    /// Creates the rule under the id `"position"`, forwarding the optional
    /// terms matching strategy unchanged.
    pub fn new(terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
        let id = String::from("position");
        Self::new_with_id(id, terms_matching_strategy)
    }
}
/// Graph-based ranking rule instantiated over [`TypoGraph`].
pub type Typo = GraphBasedRankingRule<TypoGraph>;
impl GraphBasedRankingRule<TypoGraph> {
    /// Creates the rule under the id `"typo"`, forwarding the optional
    /// terms matching strategy unchanged.
    pub fn new(terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
        let id = String::from("typo");
        Self::new_with_id(id, terms_matching_strategy)
    }
}
/// Graph-based ranking rule instantiated over [`ExactnessGraph`].
pub type Exactness = GraphBasedRankingRule<ExactnessGraph>;
impl GraphBasedRankingRule<ExactnessGraph> {
    /// Creates the rule under the id `"exactness"`.
    ///
    /// This rule is always built without a terms matching strategy.
    pub fn new() -> Self {
        let id = String::from("exactness");
        Self::new_with_id(id, None)
    }
}
/// A ranking rule whose buckets are computed by walking a [`RankingRuleGraph`]
/// of kind `G`.
///
/// The rule itself only stores its configuration; the per-iteration working
/// data lives in `state`, which is populated by `start_iteration` and cleared
/// by `end_iteration`.
pub struct GraphBasedRankingRule<G: RankingRuleGraphTrait> {
/// Identifier returned (as a clone) by `RankingRule::id`.
id: String,
/// Optional strategy; when present, `start_iteration` uses it to derive the
/// per-node removal costs passed to `RankingRuleGraph::build`.
terms_matching_strategy: Option<TermsMatchingStrategy>,
/// `Some` between `start_iteration` and `end_iteration` (or until
/// `next_bucket` runs out of costs), `None` otherwise.
state: Option<GraphBasedRankingRuleState<G>>,
}
impl<G: RankingRuleGraphTrait> GraphBasedRankingRule<G> {
pub fn new_with_id(id: String, terms_matching_strategy: Option<TermsMatchingStrategy>) -> Self {
Self { id, terms_matching_strategy, state: None }
}
}
/// Working data of a [`GraphBasedRankingRule`] for one iteration, created in
/// `start_iteration` and consumed/updated by each call to `next_bucket`.
pub struct GraphBasedRankingRuleState<G: RankingRuleGraphTrait> {
/// Graph produced by `RankingRuleGraph::build`. While buckets are computed,
/// edges whose condition resolves to no document are removed from it.
graph: RankingRuleGraph<G>,
/// Cache queried for the computed docids of each condition and for the
/// query term subsets a condition uses.
conditions_cache: ConditionDocIdsCache<G>,
/// Records conditions (optionally only after a given path prefix) that were
/// found to lead to no documents.
dead_ends_cache: DeadEndsCache<G::Condition>,
/// Result of `graph.find_all_costs_to_end()`: a list of costs per query
/// node. `next_bucket` reads the root node's list to pick the next cost, and
/// refreshes this map after edges have been removed from the graph.
all_costs: MappedInterner<QueryNode, Vec<u64>>,
/// Smallest cost still eligible for the next bucket. Starts at 0 and is set
/// to `cost + 1` once the bucket for `cost` has been selected.
cur_cost: u64,
/// Upper bound computed in `start_iteration`; used as `max_rank`, with the
/// rank of a bucket being `next_max_cost - cost`.
next_max_cost: u64,
}
impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBasedRankingRule<G> {
    /// Returns a copy of the identifier this rule was constructed with.
    fn id(&self) -> String {
        self.id.clone()
    }

    /// Builds the ranking rule graph for `query_graph` and stores a fresh
    /// iteration state in `self.state`.
    ///
    /// When a terms matching strategy is configured, a removal cost of `1` is
    /// assigned to every node listed in the strategy's removal order, together
    /// with the set of nodes that appeared in earlier removal groups. Nodes that
    /// are not part of any removal group (and every node when there is no
    /// strategy or the strategy is `All`) get no removal cost.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing the frequency-based removal order and
    /// from `RankingRuleGraph::build`.
    #[tracing::instrument(level = "trace", skip_all, target = "search::graph_based")]
    fn start_iteration(
        &mut self,
        ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<QueryGraph>,
        _universe: &RoaringBitmap,
        query_graph: &QueryGraph,
    ) -> Result<()> {
        // `next_max_cost` ends up as:
        //   1
        //   + (number of words in phrases - 1, saturating; only with a strategy)
        //   + the highest cost listed for the root node.
        let mut next_max_cost = 1;

        // Select the removal order once; the cost assignment below is shared by
        // every strategy that has one.
        let removal_order = match self.terms_matching_strategy {
            Some(strategy) => {
                next_max_cost += query_graph.words_in_phrases_count(ctx).saturating_sub(1) as u64;
                match strategy {
                    TermsMatchingStrategy::Last => {
                        Some(query_graph.removal_order_for_terms_matching_strategy_last(ctx))
                    }
                    TermsMatchingStrategy::Frequency => Some(
                        query_graph.removal_order_for_terms_matching_strategy_frequency(ctx)?,
                    ),
                    TermsMatchingStrategy::All => None,
                }
            }
            None => None,
        };

        let mut removal_cost = query_graph.nodes.map(|_| None);
        if let Some(removal_order) = removal_order {
            let mut forbidden_nodes = SmallBitmap::for_interned_values_in(&query_graph.nodes);
            for group in removal_order {
                for node in group.iter() {
                    // Each node remembers the nodes of all *previous* groups.
                    *removal_cost.get_mut(node) = Some((1, forbidden_nodes.clone()));
                }
                forbidden_nodes.union(&group);
            }
        }

        let graph = RankingRuleGraph::build(ctx, query_graph.clone(), removal_cost)?;
        let conditions_cache = ConditionDocIdsCache::default();
        let dead_ends_cache = DeadEndsCache::new(&graph.conditions_interner);
        let all_costs = graph.find_all_costs_to_end();

        let highest_root_cost =
            all_costs.get(graph.query_graph.root_node).iter().copied().max().unwrap_or(0);
        next_max_cost += highest_root_cost;

        self.state = Some(GraphBasedRankingRuleState {
            graph,
            conditions_cache,
            dead_ends_cache,
            all_costs,
            cur_cost: 0,
            next_max_cost,
        });
        Ok(())
    }

    /// Computes the next bucket of documents within `universe`.
    ///
    /// The first cost listed for the root node that is `>= cur_cost` is
    /// selected, every path of that cost is visited, and the docids of the
    /// paths that resolve to at least one document are unioned into the bucket.
    /// Returns `Ok(None)` (and leaves the state cleared) when no such cost
    /// exists.
    ///
    /// The returned query graph is rebuilt from the paths that contributed
    /// documents to the bucket.
    ///
    /// # Panics
    ///
    /// Panics if there is no iteration state, i.e. `self.state` is `None`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the path visitor and from condition resolution.
    #[tracing::instrument(level = "trace", skip_all, target = "search::graph_based")]
    fn next_bucket(
        &mut self,
        ctx: &mut SearchContext<'ctx>,
        logger: &mut dyn SearchLogger<QueryGraph>,
        universe: &RoaringBitmap,
    ) -> Result<Option<RankingRuleOutput<QueryGraph>>> {
        let mut state = self.state.take().unwrap();

        let root_costs = state.all_costs.get(state.graph.query_graph.root_node);
        let cost = match root_costs.iter().copied().find(|&c| c >= state.cur_cost) {
            Some(cost) => cost,
            None => {
                self.state = None;
                return Ok(None);
            }
        };
        state.cur_cost = cost + 1;

        let GraphBasedRankingRuleState {
            graph,
            conditions_cache,
            dead_ends_cache,
            all_costs,
            next_max_cost,
            ..
        } = &mut state;

        let max_rank = *next_max_cost;
        let score =
            G::rank_to_score(Rank { rank: (max_rank - cost) as u32, max_rank: max_rank as u32 });

        let mut bucket = RoaringBitmap::new();
        // Shrinks as documents are assigned to the bucket.
        let mut universe = universe.clone();
        // Paths that contributed at least one document to the bucket.
        let mut good_paths = vec![];
        // Docids of every prefix of the path visited last; entry `i` holds the
        // intersection of the docids of conditions `0..=i`. Reused between
        // consecutive paths that share a prefix.
        let mut subpaths_docids: Vec<(Interned<G::Condition>, RoaringBitmap)> = vec![];
        // Source nodes of edges removed because their condition matched nothing.
        let mut nodes_with_removed_outgoing_conditions = BTreeSet::new();

        let visitor = PathVisitor::new(cost, graph, all_costs, dead_ends_cache);
        visitor.visit_paths(&mut |path, graph, dead_ends_cache| {
            // Nothing left to assign: stop exploring.
            if universe.is_empty() {
                return Ok(ControlFlow::Break(()));
            }

            // Keep only the cached prefix shared with the current path.
            let shared_prefix_len = path
                .iter()
                .copied()
                .zip(subpaths_docids.iter().map(|(condition, _)| *condition))
                .take_while(|(in_path, in_cache)| in_path == in_cache)
                .count();
            subpaths_docids.truncate(shared_prefix_len);

            // Resolve the remaining conditions; give up on this path as soon as
            // one of them leaves no document.
            for &latest_condition in &path[shared_prefix_len..] {
                let resolved = visit_path_condition(
                    ctx,
                    graph,
                    &universe,
                    dead_ends_cache,
                    conditions_cache,
                    &mut subpaths_docids,
                    &mut nodes_with_removed_outgoing_conditions,
                    latest_condition,
                )?;
                if !resolved {
                    return Ok(ControlFlow::Continue(()));
                }
            }
            assert!(subpaths_docids.iter().map(|x| x.0).eq(path.iter().copied()));

            // An empty path has no cached entry and matches the whole universe.
            let path_docids = match subpaths_docids.pop() {
                Some((_, docids)) => docids,
                None => universe.clone(),
            };
            assert!(!path_docids.is_empty());

            good_paths.push(path.to_vec());
            bucket |= &path_docids;
            universe -= &path_docids;
            // The cached prefixes must not offer already-assigned documents.
            for (_, prefix_docids) in &mut subpaths_docids {
                *prefix_docids -= &path_docids;
            }

            Ok(if universe.is_empty() {
                ControlFlow::Break(())
            } else {
                ControlFlow::Continue(())
            })
        })?;

        logger.log_internal_state(graph);
        logger.log_internal_state(&good_paths);

        let paths: Vec<Vec<(Option<LocatedQueryTermSubset>, LocatedQueryTermSubset)>> = good_paths
            .into_iter()
            .map(|path| {
                path.into_iter()
                    .map(|condition| {
                        let (from, to) = conditions_cache.get_subsets_used_by_condition(condition);
                        (from.clone(), to.clone())
                    })
                    .collect()
            })
            .collect();
        let next_query_graph = QueryGraph::build_from_paths(paths);

        // Refresh the costs after edge removals: update in place when a single
        // node is affected, recompute everything otherwise.
        match nodes_with_removed_outgoing_conditions.len() {
            0 => {}
            1 => {
                graph.update_all_costs_before_node(
                    *nodes_with_removed_outgoing_conditions.first().unwrap(),
                    all_costs,
                );
            }
            _ => {
                *all_costs = graph.find_all_costs_to_end();
            }
        }

        self.state = Some(state);
        Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket, score }))
    }

    /// Drops the iteration state.
    #[tracing::instrument(level = "trace", skip_all, target = "search::graph_based")]
    fn end_iteration(
        &mut self,
        _ctx: &mut SearchContext<'ctx>,
        _logger: &mut dyn SearchLogger<QueryGraph>,
    ) {
        self.state = None;
    }
}
/// Extends `subpath` with `latest_condition`, if doing so still matches documents.
///
/// Returns `Ok(true)` when the condition was pushed onto `subpath` together
/// with the (non-empty) intersection of its docids and the docids of the
/// previous subpath entry. Returns `Ok(false)` otherwise, after recording why
/// in `dead_ends_cache`:
///
/// - if the condition alone has no docids, it is forbidden outright, its
///   edges are removed from `graph`, and the source nodes of those edges are
///   added to `nodes_with_removed_outgoing_conditions`;
/// - if only the intersection is empty, the condition is forbidden after the
///   current subpath and after every shorter prefix whose docids are disjoint
///   from the condition's docids.
///
/// # Errors
///
/// Propagates errors from `ConditionDocIdsCache::get_computed_condition`.
#[allow(clippy::too_many_arguments)]
fn visit_path_condition<G: RankingRuleGraphTrait>(
    ctx: &mut SearchContext<'_>,
    graph: &mut RankingRuleGraph<G>,
    universe: &RoaringBitmap,
    dead_ends_cache: &mut DeadEndsCache<G::Condition>,
    condition_docids_cache: &mut ConditionDocIdsCache<G>,
    subpath: &mut Vec<(Interned<G::Condition>, RoaringBitmap)>,
    nodes_with_removed_outgoing_conditions: &mut BTreeSet<Interned<QueryNode>>,
    latest_condition: Interned<G::Condition>,
) -> Result<bool> {
    let condition_docids = &condition_docids_cache
        .get_computed_condition(ctx, latest_condition, graph, universe)?
        .docids;

    // The condition matches nothing on its own: forbid it everywhere and drop
    // its edges from the graph.
    if condition_docids.is_empty() {
        dead_ends_cache.forbid_condition(latest_condition);
        let source_nodes = graph.remove_edges_with_condition(latest_condition);
        nodes_with_removed_outgoing_conditions.extend(source_nodes);
        return Ok(false);
    }

    let latest_path_docids = match subpath.last() {
        Some((_, prev_docids)) => prev_docids & condition_docids,
        None => condition_docids.clone(),
    };
    if !latest_path_docids.is_empty() {
        subpath.push((latest_condition, latest_path_docids));
        return Ok(true);
    }

    // The condition is a dead end after the whole current subpath.
    dead_ends_cache.forbid_condition_after_prefix(subpath.iter().map(|x| x.0), latest_condition);

    // The last entry was just handled; inspect every strictly shorter prefix.
    // With zero or one entry there is no shorter prefix to inspect.
    let Some((_, earlier)) = subpath.split_last() else {
        return Ok(false);
    };
    for (end, (_, prefix_docids)) in earlier.iter().enumerate() {
        if condition_docids.is_disjoint(prefix_docids) {
            dead_ends_cache.forbid_condition_after_prefix(
                earlier[..=end].iter().map(|x| x.0),
                latest_condition,
            );
        }
    }
    Ok(false)
}