use crate::index::{
Document, FieldType, IndexArc, Shard, hash64, object_values_to_string_vec_recursive,
};
use crate::min_heap::{self, MinHeap};
use aho_corasick::{AhoCorasick, MatchKind};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use utoipa::ToSchema;
/// Per-field highlighting request: which stored field to excerpt, how many
/// fragments to return, how long they may be, and which markup to wrap
/// around matched query terms.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Highlight {
    /// Name of the document field to generate highlights from.
    pub field: String,
    /// Optional alias under which the highlight is returned; when empty it is
    /// omitted from the serialized output.
    #[serde(default)]
    #[serde(skip_serializing_if = "String::is_empty")]
    pub name: String,
    /// Maximum number of fragments to return; `0` disables sentence
    /// fragmentation (the whole field value is treated as one fragment).
    #[serde(default)]
    pub fragment_number: usize,
    /// Maximum fragment length in bytes; fragments are trimmed to the nearest
    /// char boundary / space when they exceed it.
    #[serde(default)]
    pub fragment_size: usize,
    /// When true, matched query terms are wrapped in `pre_tags`/`post_tags`.
    #[serde(default)]
    pub highlight_markup: bool,
    /// Markup inserted before each matched term (default `<b>`).
    #[serde(default = "default_pre_tag")]
    pub pre_tags: String,
    /// Markup inserted after each matched term (default `</b>`).
    #[serde(default = "default_post_tag")]
    pub post_tags: String,
}
impl Default for Highlight {
    /// Defaults: no field selected, one fragment, unbounded fragment size,
    /// markup enabled with `<b>…</b>` tags.
    fn default() -> Self {
        Self {
            field: String::new(),
            name: String::new(),
            fragment_number: 1,
            fragment_size: usize::MAX,
            highlight_markup: true,
            pre_tags: default_pre_tag(),
            post_tags: default_post_tag(),
        }
    }
}
/// Serde default for [`Highlight::pre_tags`]: opening bold tag.
fn default_pre_tag() -> String {
    String::from("<b>")
}
/// Serde default for [`Highlight::post_tags`]: closing bold tag.
fn default_post_tag() -> String {
    String::from("</b>")
}
/// Compiled highlighting state: the requested per-field highlights plus a
/// case-insensitive Aho-Corasick automaton over all query terms (including
/// synonym expansions), reused for both fragment scoring and markup insertion.
#[derive(Debug)]
pub struct Highlighter {
    pub(crate) highlights: Vec<Highlight>,
    pub(crate) query_terms_ac: AhoCorasick,
}
/// Builds a [`Highlighter`] from the requested highlights and the query terms.
///
/// If the index has a synonym map, each query term (hashed lowercase) is
/// expanded with its synonyms so highlighting also marks synonym matches.
/// The combined term list is compiled into a case-insensitive, leftmost-longest
/// Aho-Corasick automaton.
///
/// # Panics
/// Panics if the Aho-Corasick automaton cannot be built from the terms.
pub async fn highlighter(
    index_arc: &IndexArc,
    highlights: Vec<Highlight>,
    query_terms_vec: Vec<String>,
) -> Highlighter {
    let index_ref = index_arc.read().await;
    // Take ownership instead of cloning the whole vector: synonym expansions
    // are collected separately and appended afterwards, so the iteration over
    // the original terms is not invalidated.
    let mut query_terms = query_terms_vec;
    if !index_ref.synonyms_map.is_empty() {
        let mut expansions: Vec<String> = Vec::new();
        for query_term in query_terms.iter() {
            let term_hash = hash64(query_term.to_lowercase().as_bytes());
            if let Some(synonyms) = index_ref.synonyms_map.get(&term_hash) {
                expansions.extend(synonyms.iter().map(|synonym| synonym.0.clone()));
            }
        }
        query_terms.extend(expansions);
    }
    let query_terms_ac = AhoCorasick::builder()
        .ascii_case_insensitive(true)
        .match_kind(MatchKind::LeftmostLongest)
        .build(query_terms)
        .expect("failed to build Aho-Corasick automaton from query terms");
    Highlighter {
        highlights,
        query_terms_ac,
    }
}
/// Scores a candidate fragment against the query-term automaton, trims it to
/// `fragment_size`, and registers it in the top-k heap / fragment list.
///
/// Scoring (skipped when `no_score_no_highlight`, where every fragment gets a
/// flat 1.0): each automaton match contributes
/// * `sequence_length * 5.0` when it continues a run of consecutive patterns
///   (pattern id exactly one higher than the previous match and starting one
///   byte after its end — i.e. adjacent terms separated by a single byte),
/// * `1.0` for the first occurrence of a pattern in this fragment,
/// * `0.3` for repeat occurrences.
///
/// `fragments` accumulates every retained fragment; `topk_candidates` stores
/// `(score, index into fragments)` pairs, keeping the best `fragment_number`.
pub(crate) fn add_fragment<'a>(
    no_score_no_highlight: bool,
    mut fragment: Fragment<'a>,
    query_terms_ac: &AhoCorasick,
    fragments: &mut Vec<Fragment<'a>>,
    topk_candidates: &mut MinHeap,
    fragment_number: usize,
    fragment_size: usize,
) {
    let mut score = 0.0;
    // Sentinels: no previous match yet.
    let mut expected_pattern = usize::MAX;
    let mut expected_index = usize::MAX;
    // Byte offset where the first match ends; 0 means "no match seen yet".
    let mut first_end = 0;
    // One slot per automaton pattern: 1 once the pattern matched in this fragment.
    let mut set = vec![0; query_terms_ac.patterns_len()];
    let mut sequence_length = 1;
    if no_score_no_highlight {
        score = 1.0;
    } else {
        for mat in query_terms_ac.find_iter(fragment.text) {
            if first_end == 0 {
                first_end = mat.end();
            }
            let id = mat.pattern().as_usize();
            score += if id == expected_pattern && expected_index == mat.start() {
                // Consecutive query terms in order: reward the growing run.
                sequence_length += 1;
                set[id] = 1;
                sequence_length as f32 * 5.0
            } else if set[id] == 0 {
                // First time this pattern matches in the fragment.
                sequence_length = 1;
                set[id] = 1;
                1.0
            } else {
                // Repeated occurrence of an already-seen pattern.
                sequence_length = 1;
                0.3
            };
            // Next match continues the run only if it is the next pattern id,
            // starting one byte (e.g. a space) after this match ends.
            expected_pattern = id + 1;
            expected_index = mat.end() + 1;
        }
    }
    if first_end > fragment_size {
        // First match lies beyond the size limit: keep the trailing
        // `fragment_size` bytes so the match stays visible, snapping the cut
        // back to a char boundary and forward to the next space.
        let mut idx = fragment.text.len() - fragment_size;
        while !fragment.text.is_char_boundary(idx) {
            idx -= 1;
        }
        match fragment.text[idx..].find(' ') {
            // No space found: fall back to the whole fragment text.
            None => idx = 0,
            Some(value) => idx += value,
        }
        let adjusted_fragment = &fragment.text[idx..];
        fragment.text = adjusted_fragment;
        fragment.trim_left = true;
    } else if fragment.text.len() > fragment_size {
        // Fragment too long but first match fits: cut on the right, extending
        // to the next space past the limit (or keeping everything if none).
        let mut idx = fragment_size;
        while !fragment.text.is_char_boundary(idx) {
            idx -= 1;
        }
        match fragment.text[idx..].find(' ') {
            None => idx = fragment.text.len(),
            Some(value) => idx += value,
        }
        let adjusted_fragment = &fragment.text[..idx];
        fragment.text = adjusted_fragment;
        fragment.trim_right = true;
    }
    // The heap refers to fragments by their index in `fragments`.
    let section_index = fragments.len();
    let mut added = false;
    if score > 0.0 {
        added = topk_candidates.add_topk(
            min_heap::Result {
                doc_id: section_index,
                score,
                ..Default::default()
            },
            fragment_number,
        );
    }
    // Keep the fragment if the heap accepted it, or unconditionally keep the
    // very first fragment as a fallback when nothing scores.
    if fragments.is_empty() || added {
        fragments.push(fragment);
    }
}
/// Characters treated as sentence boundaries when splitting a field value into
/// candidate fragments (Latin, Spanish inverted, and CJK punctuation).
/// NOTE(review): '。' (U+3002) appears to be listed twice — confirm whether one
/// entry was meant to be a different codepoint (e.g. halfwidth '｡').
const SENTENCE_BOUNDARY_CHARS: [char; 11] =
    ['!', '?', '.', '¿', '¡', '。', '、', '!', '?', '︒', '。'];
/// A candidate excerpt borrowed from the field text. `trim_left`/`trim_right`
/// record that the text was cut on that side so the assembler can insert "..."
/// ellipses there.
pub(crate) struct Fragment<'a> {
    text: &'a str,
    trim_left: bool,
    trim_right: bool,
}
/// Builds the highlighted excerpt string for one field of a document.
///
/// The field value is stringified according to its schema type, split into
/// sentence fragments on [`SENTENCE_BOUNDARY_CHARS`], scored via
/// [`add_fragment`], and the top `fragment_number` fragments are concatenated
/// in document order with "..." separators. When `highlight_markup` is set,
/// matched query terms are wrapped in the configured pre/post tags.
///
/// Returns `Ok("")` when the field is absent from the document or unknown to
/// the shard schema. The `Result` error variant is never produced here.
pub(crate) fn top_fragments_from_field(
    shard: &Shard,
    document: &Document,
    query_terms_ac: &AhoCorasick,
    highlight: &Highlight,
) -> Result<String, String> {
    match document.get(&highlight.field) {
        None => Ok("".to_string()),
        Some(value) => {
            // A single one-byte pattern needs no scoring or markup: any
            // fragment containing it is as good as any other.
            let no_score_no_highlight =
                query_terms_ac.patterns_len() == 1 && query_terms_ac.max_pattern_len() == 1;
            // fragment_number == 0 means "don't split into sentences"; still
            // select one fragment.
            let no_fragmentation = highlight.fragment_number == 0;
            let fragment_number = if no_fragmentation {
                1
            } else {
                highlight.fragment_number
            };
            let result_sort = Vec::new();
            let mut topk_candidates = MinHeap::new(fragment_number, shard, false, &result_sort);
            if let Some(schema_field) = shard.schema_map.get(&highlight.field) {
                // Stringify the field value according to its schema type.
                let text = match schema_field.field_type {
                    FieldType::Json => {
                        if matches!(value, Value::Object { .. }) {
                            // Flatten all nested string values into one text.
                            let mut strings_vec: Vec<String> = Vec::new();
                            object_values_to_string_vec_recursive(value, &mut strings_vec);
                            strings_vec.join(" ")
                        } else {
                            serde_json::from_value::<String>(value.clone())
                                .unwrap_or(value.to_string())
                        }
                    }
                    FieldType::Text | FieldType::String16 | FieldType::String32 => {
                        serde_json::from_value::<String>(value.clone()).unwrap_or(value.to_string())
                    }
                    _ => value.to_string(),
                };
                let mut fragments: Vec<Fragment> = Vec::new();
                // Byte offset one past the previous sentence boundary.
                let mut last = 0;
                if !no_fragmentation {
                    for (character_index, matched) in
                        text.match_indices(&SENTENCE_BOUNDARY_CHARS[..])
                    {
                        if last != character_index {
                            // Sentence including its terminating punctuation.
                            let section = Fragment {
                                text: &text[last..character_index + matched.len()],
                                trim_left: false,
                                trim_right: false,
                            };
                            add_fragment(
                                no_score_no_highlight,
                                section,
                                query_terms_ac,
                                &mut fragments,
                                &mut topk_candidates,
                                fragment_number,
                                highlight.fragment_size,
                            );
                            // With flat scoring, stop as soon as the heap is
                            // full — later fragments can't displace anything.
                            if no_score_no_highlight
                                && topk_candidates.current_heap_size == fragment_number
                            {
                                break;
                            }
                        }
                        last = character_index + matched.len();
                    }
                }
                // Trailing text after the last boundary (or the whole text
                // when fragmentation is off).
                // NOTE(review): `last + 1 < text.len()` skips a trailing
                // section of exactly one byte — confirm that is intentional.
                if last + 1 < text.len() {
                    let section = Fragment {
                        text: &text[last..],
                        trim_left: false,
                        trim_right: false,
                    };
                    add_fragment(
                        no_score_no_highlight,
                        section,
                        query_terms_ac,
                        &mut fragments,
                        &mut topk_candidates,
                        fragment_number,
                        highlight.fragment_size,
                    );
                }
                let mut combined_string = String::with_capacity(text.len());
                if !fragments.is_empty() {
                    if topk_candidates.current_heap_size > 0 {
                        // Drop unused preallocated heap slots before sorting.
                        if topk_candidates.current_heap_size < fragment_number {
                            topk_candidates
                                ._elements
                                .truncate(topk_candidates.current_heap_size);
                        }
                        // Re-sort winners into document order (doc_id is the
                        // fragment's index in `fragments`).
                        topk_candidates
                            ._elements
                            .sort_by(|a, b| a.doc_id.partial_cmp(&b.doc_id).unwrap());
                        let mut previous_docid = 0;
                        for candidate in topk_candidates._elements {
                            // Insert "..." between non-adjacent fragments and
                            // before fragments that were trimmed on the left,
                            // avoiding doubled ellipses.
                            if (!combined_string.is_empty()
                                && !combined_string.ends_with("...")
                                && candidate.doc_id != previous_docid + 1)
                                || (fragments[candidate.doc_id].trim_left
                                    && (combined_string.is_empty()
                                        || !combined_string.ends_with("...")))
                            {
                                combined_string.push_str("...")
                            };
                            combined_string.push_str(fragments[candidate.doc_id].text);
                            previous_docid = candidate.doc_id;
                            if fragments[candidate.doc_id].trim_right {
                                combined_string.push_str("...")
                            };
                        }
                    } else {
                        // Nothing scored: fall back to the first fragment.
                        combined_string.push_str(fragments[0].text);
                    }
                }
                if highlight.highlight_markup && !no_score_no_highlight {
                    highlight_terms(
                        &mut combined_string,
                        query_terms_ac,
                        &highlight.pre_tags,
                        &highlight.post_tags,
                    );
                }
                Ok(combined_string)
            } else {
                Ok("".to_string())
            }
        }
    }
}
/// Wraps every query-term match in `text` with `pre_tags`/`post_tags`,
/// replacing `text` in place.
///
/// Matches are found with the (case-insensitive, leftmost-longest) automaton,
/// so matched spans never overlap and the original casing of the matched text
/// is preserved.
pub(crate) fn highlight_terms(
    text: &mut String,
    query_terms_ac: &AhoCorasick,
    pre_tags: &str,
    post_tags: &str,
) {
    // The result is at least as long as the input; preallocate to avoid
    // repeated reallocation while appending.
    let mut result = String::with_capacity(text.len());
    let mut prev_end = 0;
    for mat in query_terms_ac.find_iter(text.as_str()) {
        result.push_str(&text[prev_end..mat.start()]);
        result.push_str(pre_tags);
        result.push_str(&text[mat.start()..mat.end()]);
        result.push_str(post_tags);
        prev_end = mat.end();
    }
    // Remaining tail after the last match (empty slice when prev_end == len).
    result.push_str(&text[prev_end..]);
    *text = result;
}