use crate::ascii::{
AsciiResult, compare_ascii_primary_non_ignorable, fill_codepoints_and_compare_ascii,
};
use crate::cea::{LazyPrimaryResult, compare_primary_streaming, compare_primary_streaming_utf8};
use crate::consts::{ARABIC_INTERLEAVED, ARABIC_SCRIPT, CLDR_ROOT, DUCET, LOW_CLDR, LOW_DUCET};
use crate::first_weight::try_initial;
use crate::normalize::make_nfd;
use crate::prefix::{find_byte_prefix, find_prefix_shifted};
use crate::sort_key::compare_incremental;
use crate::tables::CollationTable;
use crate::{Locale, Tailoring};
use bstr::{B, ByteSlice};
use std::cmp::Ordering;
/// Minimum combined byte length of the two inputs (measured after any shared
/// byte prefix has been trimmed) at which `Collator::collate` attempts the
/// streaming primary comparison that works directly on the UTF-8 bytes.
const LAZY_UTF8_PRIMARY_MIN_COMBINED_BYTES: usize = 64;
#[cfg(feature = "pipeline-stats")]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// Counters recording which stage of `Collator::collate` handled each comparison.
///
/// Only compiled when the `pipeline-stats` feature is enabled. Counters
/// accumulate across calls until `Collator::clear_stats` resets them.
pub struct PipelineStats {
/// Number of `collate` calls.
pub comparisons: u64,
/// Calls that returned `Equal` because the inputs compared equal up front.
pub equal_early: u64,
/// Calls in which a non-empty shared byte prefix was trimmed.
pub byte_prefix_trimmed: u64,
/// Total number of bytes removed by byte-prefix trimming.
pub byte_prefix_bytes_trimmed: u64,
/// Calls resolved by the ASCII primary comparison used when shifting is off.
pub ascii_primary_resolved: u64,
/// Calls that attempted the streaming primary comparison over UTF-8 bytes.
pub lazy_utf8_primary_attempts: u64,
/// Attempts of the UTF-8 streaming comparison that produced a result.
pub lazy_utf8_primary_resolved: u64,
/// Attempts of the UTF-8 streaming comparison that fell back to the full path.
pub lazy_utf8_full_fallback: u64,
/// Calls resolved by the ASCII comparison performed while decoding code points.
pub fill_ascii_resolved: u64,
/// Number of inputs that were passed through NFD normalization.
pub nfd_normalizations: u64,
/// Total code points decoded into the scratch buffers (both inputs combined).
pub codepoints_decoded: u64,
/// Calls in which a non-empty shared code-point prefix was trimmed.
pub codepoint_prefix_trimmed: u64,
/// Total number of code points removed by code-point prefix trimming.
pub codepoint_prefix_codepoints_trimmed: u64,
/// Calls resolved by the initial-weight check (`try_initial`).
pub initial_primary_resolved: u64,
/// Calls resolved by the streaming primary comparison over code points.
pub streaming_primary_resolved: u64,
// NOTE(review): not updated in this file; presumably maintained by
// `compare_primary_streaming`, which receives `&mut PipelineStats` — confirm.
pub codepoints_consumed_primary: u64,
/// Calls that proceeded past the primary stages to `compare_incremental`.
pub later_levels_reached: u64,
/// Calls for which `compare_incremental` returned a non-equal ordering.
pub later_levels_resolved: u64,
/// Calls decided by the final `a.cmp(b)` tiebreak.
pub tiebreak_resolved: u64,
}
#[cfg(feature = "pipeline-stats")]
impl PipelineStats {
/// All-zero counter set, usable in `const` contexts (unlike `Default::default()`).
/// Used to initialize and to reset a collator's statistics.
const ZEROED: Self = Self {
comparisons: 0,
equal_early: 0,
byte_prefix_trimmed: 0,
byte_prefix_bytes_trimmed: 0,
ascii_primary_resolved: 0,
lazy_utf8_primary_attempts: 0,
lazy_utf8_primary_resolved: 0,
lazy_utf8_full_fallback: 0,
fill_ascii_resolved: 0,
nfd_normalizations: 0,
codepoints_decoded: 0,
codepoint_prefix_trimmed: 0,
codepoint_prefix_codepoints_trimmed: 0,
initial_primary_resolved: 0,
streaming_primary_resolved: 0,
codepoints_consumed_primary: 0,
later_levels_reached: 0,
later_levels_resolved: 0,
tiebreak_resolved: 0,
};
}
/// Reusable comparator holding collation options plus scratch buffers that are
/// recycled across `collate` calls.
///
/// NOTE(review): the derived `PartialEq`/`Ord`/`Hash` impls cover every field,
/// including the scratch buffers, so two collators with identical options can
/// compare unequal after they have been used on different inputs.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
pub struct Collator {
/// Tailoring that selects the collation table and low-code-point table.
pub tailoring: Tailoring,
/// When `true`, `collate` trims a shared code-point prefix and skips the
/// non-ignorable ASCII fast path; the flag is also forwarded to the lower-level
/// comparison routines. Presumably selects the "shifted" variable-weighting
/// mode — confirm against those routines.
pub shifting: bool,
/// When `true`, inputs that are equal at every collation level are ordered by
/// their own `Ord` implementation (`a.cmp(b)`) instead of returning `Equal`.
pub tiebreak: bool,
// Scratch buffers for the decoded (and possibly NFD-normalized) code points of
// each input; cleared at the start of the decode stage of every comparison.
a_chars: Vec<u32>,
b_chars: Vec<u32>,
// Scratch buffers handed to the streaming comparison routines and read back by
// `compare_incremental`; start out as 64 zeroed slots.
a_cea: Vec<u32>,
b_cea: Vec<u32>,
#[cfg(feature = "pipeline-stats")]
// Per-collator stage counters; only present with the `pipeline-stats` feature.
stats: PipelineStats,
}
impl Default for Collator {
    /// Returns a collator using the default tailoring with both `shifting` and
    /// `tiebreak` enabled.
    fn default() -> Self {
        let tailoring = Tailoring::default();
        let (shifting, tiebreak) = (true, true);
        Self::new(tailoring, shifting, tiebreak)
    }
}
impl Collator {
    /// Builds a collator for `tailoring` with the given `shifting` and `tiebreak`
    /// options.
    ///
    /// The code-point scratch buffers start empty and the two collation-element
    /// buffers start with 64 zeroed slots; all of them are reused by every
    /// [`Collator::collate`] call.
    #[must_use]
    pub fn new(tailoring: Tailoring, shifting: bool, tiebreak: bool) -> Self {
        Self {
            tailoring,
            shifting,
            tiebreak,
            a_chars: Vec::new(),
            b_chars: Vec::new(),
            a_cea: vec![0; 64],
            b_cea: vec![0; 64],
            #[cfg(feature = "pipeline-stats")]
            stats: PipelineStats::ZEROED,
        }
    }

    /// Returns the stage counters accumulated since construction or the last
    /// [`Collator::clear_stats`] call.
    #[cfg(feature = "pipeline-stats")]
    #[must_use]
    pub const fn stats(&self) -> &PipelineStats {
        &self.stats
    }

    /// Resets every stage counter to zero.
    #[cfg(feature = "pipeline-stats")]
    pub const fn clear_stats(&mut self) {
        self.stats = PipelineStats::ZEROED;
    }

    /// Compares `a` and `b` and returns their relative order.
    ///
    /// The comparison is a pipeline of progressively more expensive stages; the
    /// first stage that produces an answer returns it:
    ///
    /// 1. Inputs that are equal under `T`'s own equality are `Equal`.
    /// 2. If the first bytes match, a shared byte prefix is trimmed.
    /// 3. With `shifting` disabled, an ASCII primary comparison is tried.
    /// 4. If the trimmed inputs total at least
    ///    `LAZY_UTF8_PRIMARY_MIN_COMBINED_BYTES` bytes, a streaming primary
    ///    comparison directly over the bytes is tried.
    /// 5. Both inputs are decoded to code points (with an ASCII comparison on the
    ///    fly), NFD-normalized where flagged, and, with `shifting` enabled,
    ///    trimmed of a shared code-point prefix.
    /// 6. The initial-weight check, the streaming primary comparison and finally
    ///    the incremental comparison of the remaining levels are run in turn.
    /// 7. If the inputs are still equal and `tiebreak` is set, the result is
    ///    `a.cmp(b)` using `T`'s own ordering.
    ///
    /// Takes `&mut self` because the internal scratch buffers (and, with the
    /// `pipeline-stats` feature, the counters) are updated on every call.
    // The stages share scratch buffers and a lazily built context, so the
    // pipeline is kept in a single function.
    #[allow(clippy::too_many_lines)]
    pub fn collate<T: AsRef<[u8]> + Ord + ?Sized>(&mut self, a: &T, b: &T) -> Ordering {
        #[cfg(feature = "pipeline-stats")]
        {
            self.stats.comparisons += 1;
        }

        // Cheapest possible exit: identical inputs.
        if a == b {
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.equal_early += 1;
            }
            return Ordering::Equal;
        }

        let a_bytes = a.as_ref();
        let b_bytes = b.as_ref();

        // The context is built lazily so the earliest exits never pay for it.
        let mut ctx = None;

        // Only look for a trimmable byte prefix when the first bytes agree.
        let byte_offset = if has_byte_prefix(a_bytes, b_bytes) {
            let current_ctx =
                ctx.get_or_insert_with(|| CollationContext::new(self.shifting, self.tailoring));
            find_byte_prefix(a_bytes, b_bytes, current_ctx)
        } else {
            0
        };
        #[cfg(feature = "pipeline-stats")]
        if byte_offset > 0 {
            self.stats.byte_prefix_trimmed += 1;
            self.stats.byte_prefix_bytes_trimmed += u64::try_from(byte_offset).unwrap_or(u64::MAX);
        }
        let a_bytes = &a_bytes[byte_offset..];
        let b_bytes = &b_bytes[byte_offset..];

        // ASCII primary fast path; only used when shifting is disabled.
        if !self.shifting {
            let current_ctx =
                ctx.get_or_insert_with(|| CollationContext::new(self.shifting, self.tailoring));
            if let Some(comparison) =
                compare_ascii_primary_non_ignorable(a_bytes, b_bytes, current_ctx.low)
            {
                #[cfg(feature = "pipeline-stats")]
                {
                    self.stats.ascii_primary_resolved += 1;
                }
                return comparison;
            }
        }

        // For sufficiently long inputs, try to decide on primary weights while
        // streaming over the raw bytes, before decoding everything up front.
        if a_bytes.len() + b_bytes.len() >= LAZY_UTF8_PRIMARY_MIN_COMBINED_BYTES {
            let current_ctx =
                ctx.get_or_insert_with(|| CollationContext::new(self.shifting, self.tailoring));
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.lazy_utf8_primary_attempts += 1;
            }
            match compare_primary_streaming_utf8(
                &mut self.a_cea,
                &mut self.b_cea,
                a_bytes,
                b_bytes,
                current_ctx,
            ) {
                LazyPrimaryResult::Decided(comparison) => {
                    #[cfg(feature = "pipeline-stats")]
                    {
                        self.stats.lazy_utf8_primary_resolved += 1;
                    }
                    return comparison;
                }
                LazyPrimaryResult::NeedsFullFallback => {
                    #[cfg(feature = "pipeline-stats")]
                    {
                        self.stats.lazy_utf8_full_fallback += 1;
                    }
                }
            }
        }

        // Full path: decode both inputs into the reusable code-point buffers,
        // comparing ASCII on the fly.
        let mut a_iter = B(a_bytes).chars().map(|c| c as u32);
        let mut b_iter = B(b_bytes).chars().map(|c| c as u32);
        self.a_chars.clear();
        self.b_chars.clear();
        let ascii_result = fill_codepoints_and_compare_ascii(
            &mut a_iter,
            &mut b_iter,
            &mut self.a_chars,
            &mut self.b_chars,
        );
        #[cfg(feature = "pipeline-stats")]
        {
            self.stats.codepoints_decoded +=
                u64::try_from(self.a_chars.len() + self.b_chars.len()).unwrap_or(u64::MAX);
        }
        let (a_needs_nfd, b_needs_nfd) = match ascii_result {
            AsciiResult::Done(o) => {
                #[cfg(feature = "pipeline-stats")]
                {
                    self.stats.fill_ascii_resolved += 1;
                }
                return o;
            }
            AsciiResult::Continue {
                a_needs_nfd,
                b_needs_nfd,
            } => (a_needs_nfd, b_needs_nfd),
        };

        // Normalize only the inputs the decode stage flagged.
        if a_needs_nfd {
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.nfd_normalizations += 1;
            }
            make_nfd(&mut self.a_chars);
        }
        if b_needs_nfd {
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.nfd_normalizations += 1;
            }
            make_nfd(&mut self.b_chars);
        }

        let ctx = ctx.get_or_insert_with(|| CollationContext::new(self.shifting, self.tailoring));

        // Code-point prefix trimming is only performed when shifting is enabled.
        let offset = if self.shifting {
            find_prefix_shifted(&self.a_chars, &self.b_chars, ctx)
        } else {
            0
        };
        #[cfg(feature = "pipeline-stats")]
        if offset > 0 {
            self.stats.codepoint_prefix_trimmed += 1;
            self.stats.codepoint_prefix_codepoints_trimmed +=
                u64::try_from(offset).unwrap_or(u64::MAX);
        }

        // Trimming may reveal that one input is a prefix of the other, in which
        // case the shorter one sorts first.
        if self.a_chars[offset..].is_empty() || self.b_chars[offset..].is_empty() {
            let by_length = self.a_chars.len().cmp(&self.b_chars.len());
            // Equal lengths here mean both remainders are empty, i.e. the inputs
            // became identical after normalization. Honour `tiebreak` exactly as
            // the final stage does, so the result does not depend on which stage
            // detected the equality.
            if by_length == Ordering::Equal && self.tiebreak {
                #[cfg(feature = "pipeline-stats")]
                {
                    self.stats.tiebreak_resolved += 1;
                }
                return a.cmp(b);
            }
            return by_length;
        }

        if let Some(o) = try_initial(ctx, &self.a_chars[offset..], &self.b_chars[offset..]) {
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.initial_primary_resolved += 1;
            }
            return o;
        }

        if let Some(comparison) = compare_primary_streaming(
            &mut self.a_cea,
            &mut self.b_cea,
            &mut self.a_chars,
            &mut self.b_chars,
            ctx,
            offset,
            #[cfg(feature = "pipeline-stats")]
            &mut self.stats,
        ) {
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.streaming_primary_resolved += 1;
            }
            return comparison;
        }

        // Primary weights could not decide; compare the remaining levels using
        // the collation-element buffers filled by the streaming stage.
        #[cfg(feature = "pipeline-stats")]
        {
            self.stats.later_levels_reached += 1;
        }
        let comparison = compare_incremental(&self.a_cea, &self.b_cea, ctx.shifting);

        // Last resort: order otherwise-equal inputs by `T`'s own ordering.
        if comparison == Ordering::Equal && self.tiebreak {
            #[cfg(feature = "pipeline-stats")]
            {
                self.stats.tiebreak_resolved += 1;
            }
            return a.cmp(b);
        }
        #[cfg(feature = "pipeline-stats")]
        if comparison != Ordering::Equal {
            self.stats.later_levels_resolved += 1;
        }
        comparison
    }
}
/// Reports whether both slices are non-empty and start with the same byte,
/// i.e. whether looking for a longer shared byte prefix is worthwhile.
fn has_byte_prefix(a: &[u8], b: &[u8]) -> bool {
    match (a.first(), b.first()) {
        (Some(lead_a), Some(lead_b)) => lead_a == lead_b,
        _ => false,
    }
}
/// Per-comparison bundle of the options and static tables that the lower-level
/// comparison routines need.
pub struct CollationContext {
/// Copy of the collator's `shifting` flag.
pub shifting: bool,
/// `true` for every tailoring other than `Tailoring::Ducet`.
pub cldr: bool,
/// Collation table selected for the tailoring.
pub table: &'static CollationTable,
/// Low-code-point table: `LOW_CLDR` when `cldr` is set, otherwise `LOW_DUCET`.
pub low: &'static [u32],
}
impl CollationContext {
    /// Resolves the static tables for `tailoring` and bundles them with the
    /// `shifting` flag.
    fn new(shifting: bool, tailoring: Tailoring) -> Self {
        // Every tailoring except DUCET uses the CLDR low table.
        let cldr = tailoring != Tailoring::Ducet;
        let table = get_collation_table(tailoring);
        let low: &'static [u32] = if cldr { &LOW_CLDR } else { &LOW_DUCET };

        Self {
            shifting,
            cldr,
            table,
            low,
        }
    }
}
/// Maps a tailoring to the static collation table that implements it.
fn get_collation_table(tailoring: Tailoring) -> &'static CollationTable {
    match tailoring {
        Tailoring::Ducet => &DUCET,
        Tailoring::Cldr(locale) => match locale {
            Locale::Root => &CLDR_ROOT,
            Locale::ArabicScript => &ARABIC_SCRIPT,
            Locale::ArabicInterleaved => &ARABIC_INTERLEAVED,
        },
    }
}