1use std::cmp::Ordering;
7use std::env;
8use std::ffi::CString;
9use std::sync::OnceLock;
10
/// Process-wide collation settings, populated on the first call to
/// `LocaleConfig::get` and reused by every later call.
static LOCALE_CONFIG: OnceLock<LocaleConfig> = OnceLock::new();
13
/// Collation settings derived from the locale environment variables.
///
/// Values are compared field by field, so two configurations are equal when
/// they name the same locale and carry the same flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocaleConfig {
    /// `true` when locale-aware collation should be used, i.e. the selected
    /// locale is something other than `C` or `POSIX`.
    pub enabled: bool,
    /// Locale name as read from the environment, or `"C"` when the
    /// environment does not select one.
    pub locale_name: String,
    /// `true` when the locale name mentions a UTF-8 codeset.
    pub is_utf8: bool,
}
24
25impl LocaleConfig {
26 pub fn init() -> Self {
28 let locale = env::var("LC_COLLATE")
30 .or_else(|_| env::var("LC_ALL"))
31 .or_else(|_| env::var("LANG"))
32 .unwrap_or_else(|_| "C".to_string());
33
34 let enabled = !locale.is_empty() && locale != "C" && locale != "POSIX";
36 let is_utf8 = locale.contains("UTF-8") || locale.contains("utf8");
37
38 if enabled {
40 unsafe {
41 let locale_cstr =
42 CString::new(locale.clone()).unwrap_or_else(|_| CString::new("C").unwrap());
43 libc::setlocale(libc::LC_COLLATE, locale_cstr.as_ptr());
44 }
45 }
46
47 Self {
48 enabled,
49 locale_name: locale,
50 is_utf8,
51 }
52 }
53
54 pub fn get() -> &'static LocaleConfig {
56 LOCALE_CONFIG.get_or_init(Self::init)
57 }
58
59 pub fn is_enabled() -> bool {
61 Self::get().enabled
62 }
63}
64
65pub fn strcoll_compare(a: &[u8], b: &[u8]) -> Ordering {
67 if a == b {
69 return Ordering::Equal;
70 }
71
72 let a_str = match std::str::from_utf8(a) {
75 Ok(s) => s,
76 Err(_) => {
77 return a.cmp(b);
79 }
80 };
81
82 let b_str = match std::str::from_utf8(b) {
83 Ok(s) => s,
84 Err(_) => {
85 return a.cmp(b);
87 }
88 };
89
90 let a_cstr = match CString::new(a_str) {
92 Ok(s) => s,
93 Err(_) => {
94 return a.cmp(b);
96 }
97 };
98
99 let b_cstr = match CString::new(b_str) {
100 Ok(s) => s,
101 Err(_) => {
102 return a.cmp(b);
104 }
105 };
106
107 unsafe {
109 let result = libc::strcoll(a_cstr.as_ptr(), b_cstr.as_ptr());
110 match result {
111 x if x < 0 => Ordering::Less,
112 x if x > 0 => Ordering::Greater,
113 _ => Ordering::Equal,
114 }
115 }
116}
117
118pub fn strcasecoll_compare(a: &[u8], b: &[u8]) -> Ordering {
121 if a == b {
123 return Ordering::Equal;
124 }
125
126 let a_str = match std::str::from_utf8(a) {
128 Ok(s) => s,
129 Err(_) => return case_insensitive_byte_compare(a, b),
130 };
131
132 let b_str = match std::str::from_utf8(b) {
133 Ok(s) => s,
134 Err(_) => return case_insensitive_byte_compare(a, b),
135 };
136
137 let a_lower = a_str.to_lowercase();
139 let b_lower = b_str.to_lowercase();
140
141 strcoll_compare(a_lower.as_bytes(), b_lower.as_bytes())
143}
144
/// Compares two byte strings lexicographically after folding ASCII letters to
/// lowercase.
///
/// Bytes outside the ASCII letter range are compared unchanged. When one input
/// is a prefix of the other (after folding), the shorter one orders first.
fn case_insensitive_byte_compare(a: &[u8], b: &[u8]) -> Ordering {
    let folded_a = a.iter().map(u8::to_ascii_lowercase);
    let folded_b = b.iter().map(u8::to_ascii_lowercase);

    // Lexicographic iterator comparison: first differing element decides,
    // otherwise the shorter sequence is the smaller one.
    folded_a.cmp(folded_b)
}
160
161pub fn smart_compare(a: &[u8], b: &[u8], ignore_case: bool) -> Ordering {
163 if LocaleConfig::is_enabled() {
164 if ignore_case {
165 strcasecoll_compare(a, b)
166 } else {
167 strcoll_compare(a, b)
168 }
169 } else {
170 if ignore_case {
172 case_insensitive_byte_compare(a, b)
173 } else {
174 a.cmp(b)
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test in this module.
    ///
    /// Environment variables and the C collation locale are process-global,
    /// while the test harness runs tests on parallel threads; without this
    /// lock one test could observe another test's settings.
    static PROCESS_STATE: Mutex<()> = Mutex::new(());

    /// Takes the module-wide lock, ignoring poisoning so that one failed test
    /// does not cascade into failures of the others.
    fn lock_process_state() -> MutexGuard<'static, ()> {
        PROCESS_STATE
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Pins `LC_COLLATE` for the lifetime of the guard and puts the previous
    /// environment and the `"C"` collation back on drop, including when the
    /// test panics.
    struct CollateEnv {
        saved: Vec<(&'static str, Option<OsString>)>,
        // Declared last so the lock is released only after `drop` has
        // finished restoring the process state.
        _lock: MutexGuard<'static, ()>,
    }

    impl CollateEnv {
        fn set(value: &str) -> Self {
            let lock = lock_process_state();
            let saved = ["LC_ALL", "LC_COLLATE"]
                .iter()
                .map(|&key| (key, env::var_os(key)))
                .collect();

            // Cleared so an ambient LC_ALL cannot influence which locale the
            // code under test resolves.
            env::remove_var("LC_ALL");
            env::set_var("LC_COLLATE", value);

            Self { saved, _lock: lock }
        }
    }

    impl Drop for CollateEnv {
        fn drop(&mut self) {
            for (key, value) in self.saved.drain(..) {
                match value {
                    Some(previous) => env::set_var(key, previous),
                    None => env::remove_var(key),
                }
            }

            // `LocaleConfig::init` may have switched the collation locale;
            // return to the start-up default so later tests are unaffected.
            // SAFETY: the argument is a static NUL-terminated string, and the
            // module lock is still held, so no other test in this module is
            // calling into the C locale functions concurrently.
            unsafe {
                libc::setlocale(libc::LC_COLLATE, b"C\0".as_ptr().cast());
            }
        }
    }

    #[test]
    fn test_c_locale() {
        let _env = CollateEnv::set("C");

        let config = LocaleConfig::init();
        assert!(!config.enabled);
        assert_eq!(config.locale_name, "C");
    }

    #[test]
    fn test_utf8_locale() {
        let _env = CollateEnv::set("en_US.UTF-8");

        let config = LocaleConfig::init();
        assert!(config.enabled);
        assert!(config.is_utf8);
        assert_eq!(config.locale_name, "en_US.UTF-8");
    }

    #[test]
    fn test_strcoll_basic() {
        // `strcoll` reads the collation locale that other tests may change.
        let _guard = lock_process_state();

        let a = b"apple";
        let b = b"banana";
        assert_eq!(strcoll_compare(a, b), Ordering::Less);
        assert_eq!(strcoll_compare(b, a), Ordering::Greater);
        assert_eq!(strcoll_compare(a, a), Ordering::Equal);
    }

    #[test]
    fn test_case_insensitive() {
        // Goes through `strcoll` as well, so it shares the same lock.
        let _guard = lock_process_state();

        let a = b"Apple";
        let b = b"apple";
        assert_eq!(strcasecoll_compare(a, b), Ordering::Equal);

        let a = b"ZEBRA";
        let b = b"aardvark";
        assert_eq!(strcasecoll_compare(a, b), Ordering::Greater);
    }
}