#![allow(clippy::unnecessary_literal_bound)]
use std::cmp::Ordering;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use tracing::{debug, info};
/// A named collating sequence used to compare and order byte strings.
///
/// Implementations must be `Send + Sync`; they are handed out as shared
/// `Arc<dyn CollationFunction>` handles by `builtin_collation` and by
/// `CollationRegistry`. Because results are used for ordering, an
/// implementation should behave as a consistent total order (antisymmetric
/// and transitive).
pub trait CollationFunction: Send + Sync {
    /// Returns the collation's name. `CollationRegistry::register` stores the
    /// collation under the ASCII-uppercased form of this string.
    fn name(&self) -> &str;
    /// Compares `left` with `right` under this collation, returning how
    /// `left` orders relative to `right`.
    fn compare(&self, left: &[u8], right: &[u8]) -> Ordering;
}
/// Byte-wise collation: values are ordered by a plain lexicographic
/// comparison of their raw bytes (a strict prefix sorts first), with no
/// case or whitespace folding.
pub struct BinaryCollation;

impl CollationFunction for BinaryCollation {
    fn name(&self) -> &str {
        "BINARY"
    }

    fn compare(&self, left: &[u8], right: &[u8]) -> Ordering {
        // `[u8]`'s `Ord` is exactly unsigned byte-by-byte lexicographic order.
        Ord::cmp(left, right)
    }
}
/// ASCII case-insensitive collation.
///
/// Every byte is passed through `u8::to_ascii_uppercase` before a
/// lexicographic comparison, so only `a`-`z` are folded. Non-ASCII bytes,
/// including UTF-8 encoded letters such as `ä`, compare exactly.
///
/// NOTE(review): SQLite documents its built-in NOCASE as folding upper case
/// to *lower* case. Folding to upper case instead changes how the ASCII
/// punctuation between `Z` and `a` (`[`, `\`, `]`, `^`, `_`, `` ` ``) orders
/// against letters; for example, here `"["` sorts after `"a"`. The unit tests
/// pin the current behaviour, so confirm which one is intended.
pub struct NoCaseCollation;

impl CollationFunction for NoCaseCollation {
    fn name(&self) -> &str {
        "NOCASE"
    }

    fn compare(&self, left: &[u8], right: &[u8]) -> Ordering {
        for (&l, &r) in left.iter().zip(right) {
            match l.to_ascii_uppercase().cmp(&r.to_ascii_uppercase()) {
                Ordering::Equal => {}
                decided => return decided,
            }
        }
        // One input is a case-folded prefix of the other, so the shorter sorts first.
        left.len().cmp(&right.len())
    }
}
/// Binary collation that ignores trailing ASCII space (0x20) bytes.
///
/// Only the space byte is trimmed. Tabs, non-breaking spaces and other
/// whitespace remain significant.
pub struct RtrimCollation;

impl CollationFunction for RtrimCollation {
    fn name(&self) -> &str {
        "RTRIM"
    }

    fn compare(&self, left: &[u8], right: &[u8]) -> Ordering {
        Ord::cmp(strip_trailing_spaces(left), strip_trailing_spaces(right))
    }
}
/// Returns `s` with every trailing ASCII space (`b' '`) removed.
///
/// Only 0x20 is stripped. Tabs and other whitespace bytes are kept. The
/// result borrows from `s`, so nothing is allocated.
fn strip_trailing_spaces(s: &[u8]) -> &[u8] {
    let mut rest = s;
    while let [head @ .., b' '] = rest {
        rest = head;
    }
    rest
}
/// Returns the shared instance of a built-in collation, looked up by its
/// canonical (already upper-cased) name.
///
/// The three instances are created lazily, once per process, and reused, so
/// every call hands out a clone of the same `Arc`. Names are matched exactly
/// and case-sensitively; callers are expected to canonicalise first. Unknown
/// names yield `None`.
fn builtin_collation(name: &str) -> Option<Arc<dyn CollationFunction>> {
    struct Builtins {
        binary: Arc<dyn CollationFunction>,
        nocase: Arc<dyn CollationFunction>,
        rtrim: Arc<dyn CollationFunction>,
    }

    static SHARED: OnceLock<Builtins> = OnceLock::new();
    let builtins = SHARED.get_or_init(|| Builtins {
        binary: Arc::new(BinaryCollation),
        nocase: Arc::new(NoCaseCollation),
        rtrim: Arc::new(RtrimCollation),
    });

    let chosen = match name {
        "BINARY" => &builtins.binary,
        "NOCASE" => &builtins.nocase,
        "RTRIM" => &builtins.rtrim,
        _ => return None,
    };
    Some(Arc::clone(chosen))
}
/// Name-keyed registry of collation sequences.
///
/// The built-in `BINARY`, `NOCASE` and `RTRIM` collations are always
/// resolvable (through `builtin_collation`) without being stored here. The
/// map only holds collations added with [`CollationRegistry::register`], and
/// a custom collation registered under a built-in's name shadows that
/// built-in. Names are canonicalised by ASCII upper-casing, so all lookups
/// are ASCII-case-insensitive.
#[derive(Clone)]
pub struct CollationRegistry {
    /// Custom collations keyed by their ASCII-uppercased name. `register` is
    /// the only writer, so every key is already in canonical form.
    custom_collations: HashMap<String, Arc<dyn CollationFunction>>,
}

impl std::fmt::Debug for CollationRegistry {
    /// Renders the registry as the list of resolvable collation names (see
    /// [`CollationRegistry::names`]). `CollationFunction` does not require
    /// `Debug`, so the collation objects themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CollationRegistry")
            .field("collations", &self.names())
            .finish()
    }
}

impl Default for CollationRegistry {
    /// Equivalent to [`CollationRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl CollationRegistry {
    /// Creates a registry with no custom collations. The built-ins are still
    /// resolvable through [`find`](Self::find) and [`contains`](Self::contains).
    #[must_use]
    pub fn new() -> Self {
        Self {
            custom_collations: HashMap::new(),
        }
    }

    /// Registers `collation` under the ASCII-uppercased form of its
    /// [`CollationFunction::name`]. Any custom collation already registered
    /// under that name is replaced.
    ///
    /// Returns the collation that previously answered to this name:
    /// - the replaced custom collation, if there was one;
    /// - otherwise the built-in of the same name, if any;
    /// - otherwise `None`.
    pub fn register<C: CollationFunction + 'static>(
        &mut self,
        collation: C,
    ) -> Option<Arc<dyn CollationFunction>> {
        let name = collation.name().to_ascii_uppercase();
        info!(collation_name = %name, deterministic = true, "custom collation registration");
        self.custom_collations
            .insert(name.clone(), Arc::new(collation))
            .or_else(|| builtin_collation(&name))
    }

    /// Looks up a collation by name, ignoring ASCII case.
    ///
    /// Custom collations take precedence over built-ins of the same name.
    /// Each lookup is traced at `debug` level together with whether it hit.
    #[must_use]
    pub fn find(&self, name: &str) -> Option<Arc<dyn CollationFunction>> {
        let canon = name.to_ascii_uppercase();
        let result = self
            .custom_collations
            .get(&canon)
            .cloned()
            .or_else(|| builtin_collation(&canon));
        debug!(
            collation = %canon,
            hit = result.is_some(),
            "collation registry lookup"
        );
        result
    }

    /// Returns `true` if `name` (ASCII-case-insensitive) resolves to a custom
    /// or built-in collation.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        let canon = name.to_ascii_uppercase();
        self.custom_collations.contains_key(&canon) || builtin_collation(&canon).is_some()
    }

    /// Lists every resolvable collation name.
    ///
    /// The three built-ins come first, in `BINARY`, `NOCASE`, `RTRIM` order,
    /// followed by the remaining custom names in ascending byte order. A
    /// custom collation that shadows a built-in is listed only once, in the
    /// built-in's slot.
    #[must_use]
    pub fn names(&self) -> Vec<String> {
        let mut names = vec!["BINARY".to_owned(), "NOCASE".to_owned(), "RTRIM".to_owned()];
        let mut custom: Vec<String> = self
            .custom_collations
            .keys()
            .filter(|name| !matches!(name.as_str(), "BINARY" | "NOCASE" | "RTRIM"))
            .cloned()
            .collect();
        // `register` already ASCII-uppercases every key, so a plain sort yields
        // the same order as sorting by an uppercased key, without allocating a
        // fresh key `String` on every comparison.
        custom.sort_unstable();
        names.extend(custom);
        names
    }
}
/// Origin of an operand's collation. `resolve_collation` uses it to decide
/// which side's collation governs a comparison.
///
/// The precedence there is `Explicit`, then `Schema`, then `Default`. On a
/// tie the left operand wins.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollationSource {
    /// The collation was named explicitly. This presumably means a `COLLATE`
    /// clause; confirm against the parser.
    Explicit,
    /// The collation comes from the schema. This is presumably a column's
    /// declared collation; confirm against callers.
    Schema,
    /// No specific collation was attached to the operand (lowest precedence).
    Default,
}
/// A collation name paired with where it came from. These are the per-operand
/// inputs to `resolve_collation`.
#[derive(Debug, Clone)]
pub struct CollationAnnotation {
    /// The collation name as supplied. `resolve_collation` returns it
    /// verbatim and does not canonicalise its case.
    pub name: String,
    /// Where the collation came from; this determines its precedence.
    pub source: CollationSource,
}
#[must_use]
pub fn resolve_collation(lhs: &CollationAnnotation, rhs: &CollationAnnotation) -> String {
let result = match (lhs.source, rhs.source) {
(_, CollationSource::Explicit) if lhs.source != CollationSource::Explicit => &rhs.name,
(CollationSource::Default, CollationSource::Schema) => &rhs.name,
_ => &lhs.name,
};
debug!(
collation = %result,
lhs_source = ?lhs.source,
rhs_source = ?rhs.source,
context = "COMPARE",
"collation selection"
);
result.clone()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `coll` orders `left` against `right` as `expected`.
    fn check(coll: &dyn CollationFunction, left: &[u8], right: &[u8], expected: Ordering) {
        assert_eq!(coll.compare(left, right), expected);
    }

    #[test]
    fn test_collation_binary_memcmp() {
        let binary = BinaryCollation;
        let cases: [(&[u8], &[u8], Ordering); 4] = [
            (b"abc", b"abc", Ordering::Equal),
            (b"abc", b"abd", Ordering::Less),
            (b"abd", b"abc", Ordering::Greater),
            (b"ABC", b"abc", Ordering::Less),
        ];
        for &(left, right, expected) in &cases {
            check(&binary, left, right, expected);
        }
        check(&binary, "café".as_bytes(), "café".as_bytes(), Ordering::Equal);
        assert_ne!(binary.compare("über".as_bytes(), b"uber"), Ordering::Equal);
    }

    #[test]
    fn test_collation_binary_basic() {
        let binary = BinaryCollation;
        check(&binary, b"ABC", b"abc", Ordering::Less);
        check(&binary, b"\x00", b"\x01", Ordering::Less);
        check(&binary, b"\xff", b"\x00", Ordering::Greater);
    }

    #[test]
    fn test_collation_nocase_ascii() {
        let nocase = NoCaseCollation;
        check(&nocase, b"ABC", b"abc", Ordering::Equal);
        check(&nocase, b"Alice", b"alice", Ordering::Equal);
        check(&nocase, b"[", b"a", Ordering::Greater);
    }

    #[test]
    fn test_collation_nocase_ascii_only() {
        let nocase = NoCaseCollation;
        let non_ascii = nocase.compare("Ä".as_bytes(), "ä".as_bytes());
        assert_ne!(non_ascii, Ordering::Equal, "NOCASE must NOT fold non-ASCII");
        check(&nocase, b"Z", b"z", Ordering::Equal);
        check(&nocase, b"[", b"[", Ordering::Equal);
        assert_ne!(nocase.compare(b"[", b"{"), Ordering::Equal);
    }

    #[test]
    fn test_collation_rtrim() {
        let rtrim = RtrimCollation;
        let equal_pairs: [(&[u8], &[u8]); 3] = [
            (b"hello ", b"hello"),
            (b"hello", b"hello "),
            (b"hello ", b"hello "),
        ];
        for &(left, right) in &equal_pairs {
            check(&rtrim, left, right, Ordering::Equal);
        }
        let unequal_pairs: [(&[u8], &[u8]); 2] = [(b"hello!", b"hello"), (b"hello ", b"hello!")];
        for &(left, right) in &unequal_pairs {
            assert_ne!(rtrim.compare(left, right), Ordering::Equal);
        }
    }

    #[test]
    fn test_collation_rtrim_tabs_not_stripped() {
        let rtrim = RtrimCollation;
        let tab_result = rtrim.compare(b"hello\t", b"hello");
        assert_ne!(tab_result, Ordering::Equal, "RTRIM must NOT strip tabs");
        let nbsp_result = rtrim.compare(b"hello\xc2\xa0", b"hello");
        assert_ne!(
            nbsp_result,
            Ordering::Equal,
            "RTRIM must NOT strip non-breaking spaces"
        );
    }

    #[test]
    fn test_collation_properties_antisymmetric() {
        let collations: [&dyn CollationFunction; 3] =
            [&BinaryCollation, &NoCaseCollation, &RtrimCollation];
        let pairs: [(&[u8], &[u8]); 4] = [
            (b"abc", b"def"),
            (b"hello", b"world"),
            (b"ABC", b"abc"),
            (b"hello ", b"hello"),
        ];
        for coll in &collations {
            for &(a, b) in &pairs {
                let (forward, reverse) = (coll.compare(a, b), coll.compare(b, a));
                assert_eq!(
                    forward,
                    reverse.reverse(),
                    "{}: compare({:?}, {:?}) = {forward:?}, but reverse = {reverse:?}",
                    coll.name(),
                    std::str::from_utf8(a).unwrap_or("?"),
                    std::str::from_utf8(b).unwrap_or("?"),
                );
            }
        }
    }

    #[test]
    fn test_collation_properties_transitive() {
        let binary = BinaryCollation;
        let (apple, banana, cherry): (&[u8], &[u8], &[u8]) = (b"apple", b"banana", b"cherry");
        for &(lo, hi) in &[(apple, banana), (banana, cherry), (apple, cherry)] {
            check(&binary, lo, hi, Ordering::Less);
        }
    }

    #[test]
    fn test_collation_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<BinaryCollation>();
        require_send_sync::<NoCaseCollation>();
        require_send_sync::<RtrimCollation>();
    }

    #[test]
    fn test_registry_preloaded_builtins() {
        let reg = CollationRegistry::new();
        for name in ["BINARY", "NOCASE", "RTRIM"] {
            assert!(reg.contains(name));
        }
        let probes: [(&str, &[u8], &[u8], Ordering); 3] = [
            ("BINARY", b"a", b"b", Ordering::Less),
            ("NOCASE", b"ABC", b"abc", Ordering::Equal),
            ("RTRIM", b"x ", b"x", Ordering::Equal),
        ];
        for &(name, left, right, expected) in &probes {
            let coll = reg
                .find(name)
                .unwrap_or_else(|| panic!("{name} must be pre-registered"));
            assert_eq!(coll.compare(left, right), expected);
        }
    }

    /// Test collation that inverts byte order.
    struct ReverseCollation;

    impl CollationFunction for ReverseCollation {
        fn name(&self) -> &str {
            "REVERSE"
        }

        fn compare(&self, left: &[u8], right: &[u8]) -> Ordering {
            left.cmp(right).reverse()
        }
    }

    #[test]
    fn test_registry_custom_collation_registration() {
        let mut reg = CollationRegistry::new();
        assert!(
            reg.register(ReverseCollation).is_none(),
            "no prior REVERSE collation"
        );
        assert!(reg.contains("REVERSE"));
        let reversed = reg.find("reverse").expect("case-insensitive lookup");
        assert_eq!(reversed.compare(b"a", b"z"), Ordering::Greater);
    }

    /// Test collation that claims the built-in `BINARY` name but treats
    /// everything as equal.
    struct AlwaysEqualCollation;

    impl CollationFunction for AlwaysEqualCollation {
        fn name(&self) -> &str {
            "BINARY"
        }

        fn compare(&self, _: &[u8], _: &[u8]) -> Ordering {
            Ordering::Equal
        }
    }

    #[test]
    fn test_registry_overwrite_builtin() {
        let mut reg = CollationRegistry::new();
        let displaced = reg.register(AlwaysEqualCollation);
        assert!(displaced.is_some(), "should return previous BINARY collation");
        let coll = reg.find("BINARY").unwrap();
        assert_eq!(
            coll.compare(b"a", b"z"),
            Ordering::Equal,
            "custom overwrite must take effect"
        );
    }

    #[test]
    fn test_registry_unregistered_returns_none() {
        let reg = CollationRegistry::new();
        assert!(reg.find("NONEXISTENT").is_none());
        assert!(!reg.contains("NONEXISTENT"));
    }

    #[test]
    fn test_registry_name_case_insensitive() {
        let reg = CollationRegistry::new();
        for spelling in ["BINARY", "binary", "Binary", "bInArY"] {
            assert!(reg.find(spelling).is_some());
        }
        for spelling in ["nocase", "NOCASE", "NoCase"] {
            assert!(reg.contains(spelling));
        }
    }

    /// Builds an annotation from a name and source.
    fn ann(name: &str, source: CollationSource) -> CollationAnnotation {
        CollationAnnotation {
            source,
            name: String::from(name),
        }
    }

    /// Resolves the collation for a `(name, source)` pair on each side.
    fn select(lhs: (&str, CollationSource), rhs: (&str, CollationSource)) -> String {
        let (lhs_name, lhs_source) = lhs;
        let (rhs_name, rhs_source) = rhs;
        resolve_collation(&ann(lhs_name, lhs_source), &ann(rhs_name, rhs_source))
    }

    #[test]
    fn test_collation_selection_explicit_wins() {
        let chosen = select(
            ("NOCASE", CollationSource::Explicit),
            ("BINARY", CollationSource::Default),
        );
        assert_eq!(chosen, "NOCASE");
    }

    #[test]
    fn test_collation_selection_explicit_rhs_wins_over_default() {
        let chosen = select(
            ("BINARY", CollationSource::Default),
            ("RTRIM", CollationSource::Explicit),
        );
        assert_eq!(chosen, "RTRIM");
    }

    #[test]
    fn test_collation_selection_leftmost_explicit_wins() {
        let chosen = select(
            ("NOCASE", CollationSource::Explicit),
            ("RTRIM", CollationSource::Explicit),
        );
        assert_eq!(chosen, "NOCASE");
    }

    #[test]
    fn test_collation_selection_schema_over_default() {
        let chosen = select(
            ("NOCASE", CollationSource::Schema),
            ("BINARY", CollationSource::Default),
        );
        assert_eq!(chosen, "NOCASE");
    }

    #[test]
    fn test_collation_selection_schema_rhs_over_default() {
        let chosen = select(
            ("BINARY", CollationSource::Default),
            ("NOCASE", CollationSource::Schema),
        );
        assert_eq!(chosen, "NOCASE");
    }

    #[test]
    fn test_collation_selection_explicit_over_schema() {
        let chosen = select(
            ("RTRIM", CollationSource::Explicit),
            ("NOCASE", CollationSource::Schema),
        );
        assert_eq!(chosen, "RTRIM");
    }

    #[test]
    fn test_collation_selection_default_binary() {
        let chosen = select(
            ("BINARY", CollationSource::Default),
            ("BINARY", CollationSource::Default),
        );
        assert_eq!(chosen, "BINARY");
    }

    #[test]
    fn test_min_respects_collation() {
        let binary_min = match BinaryCollation.compare(b"ABC", b"abc") {
            Ordering::Less => "ABC",
            _ => "abc",
        };
        assert_eq!(binary_min, "ABC");
        assert_eq!(NoCaseCollation.compare(b"ABC", b"abc"), Ordering::Equal);
    }

    #[test]
    fn test_max_respects_collation() {
        let binary_max = match BinaryCollation.compare(b"abc", b"ABC") {
            Ordering::Greater => "abc",
            _ => "ABC",
        };
        assert_eq!(binary_max, "abc");
    }

    #[test]
    fn test_collation_aware_sort() {
        let nocase = NoCaseCollation;
        let mut data: Vec<&[u8]> = vec![b"Banana", b"apple", b"Cherry", b"date"];
        data.sort_by(|a, b| nocase.compare(a, b));
        let expected: [&[u8]; 4] = [b"apple", b"Banana", b"Cherry", b"date"];
        assert_eq!(data, expected);
    }

    #[test]
    fn test_collation_aware_group_by() {
        let nocase = NoCaseCollation;
        let mut sorted: Vec<&[u8]> = vec![b"ABC", b"abc", b"Abc", b"def", b"DEF"];
        sorted.sort_by(|a, b| nocase.compare(a, b));
        let mut groups: Vec<Vec<&[u8]>> = Vec::new();
        for item in sorted {
            let joins_previous = groups
                .last()
                .and_then(|group| group.last())
                .is_some_and(|prev| nocase.compare(prev, item) == Ordering::Equal);
            if joins_previous {
                if let Some(group) = groups.last_mut() {
                    group.push(item);
                }
            } else {
                groups.push(vec![item]);
            }
        }
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();
        assert_eq!(sizes, [3, 2]);
    }

    #[test]
    fn test_collation_aware_distinct() {
        let nocase = NoCaseCollation;
        let items: Vec<&[u8]> = vec![b"ABC", b"abc", b"Abc", b"def", b"DEF"];
        let distinct = items.iter().fold(Vec::<&[u8]>::new(), |mut seen, &item| {
            if !seen
                .iter()
                .any(|s| nocase.compare(s, item) == Ordering::Equal)
            {
                seen.push(item);
            }
            seen
        });
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn test_registry_default_impl() {
        let reg = CollationRegistry::default();
        for name in ["BINARY", "NOCASE", "RTRIM"] {
            assert!(reg.contains(name));
        }
    }

    #[test]
    fn test_collation_annotation_debug() {
        let rendered = format!("{:?}", ann("NOCASE", CollationSource::Explicit));
        assert!(rendered.contains("NOCASE"));
        assert!(rendered.contains("Explicit"));
    }

    #[test]
    fn test_collation_source_equality() {
        assert_eq!(CollationSource::Explicit, CollationSource::Explicit);
        assert_ne!(CollationSource::Explicit, CollationSource::Schema);
        assert_ne!(CollationSource::Schema, CollationSource::Default);
    }
}