gnu_sort/
locale.rs

1//! Locale-aware string comparison support for LC_COLLATE
2//!
3//! This module provides locale-aware string comparison using the system's
4//! strcoll function, respecting the LC_COLLATE environment variable.
5
6use std::cmp::Ordering;
7use std::env;
8use std::ffi::CString;
9use std::sync::OnceLock;
10
/// Process-wide locale configuration, filled in on first access through
/// [`LocaleConfig::get`].
static LOCALE_CONFIG: OnceLock<LocaleConfig> = OnceLock::new();

/// Collation settings resolved from the process environment.
///
/// `PartialEq`/`Eq` are derived so that configurations can be compared
/// directly, for example in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocaleConfig {
    /// Whether locale-aware comparison is enabled (the locale is neither
    /// `"C"` nor `"POSIX"`).
    pub enabled: bool,
    /// Name of the locale that was selected from the environment.
    pub locale_name: String,
    /// Whether the locale name indicates a UTF-8 codeset.
    pub is_utf8: bool,
}
24
25impl LocaleConfig {
26    /// Initialize locale configuration from environment
27    pub fn init() -> Self {
28        // Get LC_COLLATE or LC_ALL or LANG
29        let locale = env::var("LC_COLLATE")
30            .or_else(|_| env::var("LC_ALL"))
31            .or_else(|_| env::var("LANG"))
32            .unwrap_or_else(|_| "C".to_string());
33
34        // Check if locale is C or POSIX (byte comparison)
35        let enabled = !locale.is_empty() && locale != "C" && locale != "POSIX";
36        let is_utf8 = locale.contains("UTF-8") || locale.contains("utf8");
37
38        // Set locale for strcoll
39        if enabled {
40            unsafe {
41                let locale_cstr =
42                    CString::new(locale.clone()).unwrap_or_else(|_| CString::new("C").unwrap());
43                libc::setlocale(libc::LC_COLLATE, locale_cstr.as_ptr());
44            }
45        }
46
47        Self {
48            enabled,
49            locale_name: locale,
50            is_utf8,
51        }
52    }
53
54    /// Get the global locale configuration
55    pub fn get() -> &'static LocaleConfig {
56        LOCALE_CONFIG.get_or_init(Self::init)
57    }
58
59    /// Check if locale-aware comparison is enabled
60    pub fn is_enabled() -> bool {
61        Self::get().enabled
62    }
63}
64
65/// Locale-aware string comparison using strcoll
66pub fn strcoll_compare(a: &[u8], b: &[u8]) -> Ordering {
67    // Fast path for identical strings
68    if a == b {
69        return Ordering::Equal;
70    }
71
72    // Convert to null-terminated C strings
73    // For non-UTF8 locales, we need to handle invalid sequences
74    let a_str = match std::str::from_utf8(a) {
75        Ok(s) => s,
76        Err(_) => {
77            // Fallback to byte comparison for invalid UTF-8
78            return a.cmp(b);
79        }
80    };
81
82    let b_str = match std::str::from_utf8(b) {
83        Ok(s) => s,
84        Err(_) => {
85            // Fallback to byte comparison for invalid UTF-8
86            return a.cmp(b);
87        }
88    };
89
90    // Create C strings
91    let a_cstr = match CString::new(a_str) {
92        Ok(s) => s,
93        Err(_) => {
94            // String contains null bytes, fallback to byte comparison
95            return a.cmp(b);
96        }
97    };
98
99    let b_cstr = match CString::new(b_str) {
100        Ok(s) => s,
101        Err(_) => {
102            // String contains null bytes, fallback to byte comparison
103            return a.cmp(b);
104        }
105    };
106
107    // Call strcoll for locale-aware comparison
108    unsafe {
109        let result = libc::strcoll(a_cstr.as_ptr(), b_cstr.as_ptr());
110        match result {
111            x if x < 0 => Ordering::Less,
112            x if x > 0 => Ordering::Greater,
113            _ => Ordering::Equal,
114        }
115    }
116}
117
118/// Case-insensitive locale-aware comparison using strcasecoll (if available)
119/// Falls back to lowercasing + strcoll if strcasecoll is not available
120pub fn strcasecoll_compare(a: &[u8], b: &[u8]) -> Ordering {
121    // Fast path for identical strings
122    if a == b {
123        return Ordering::Equal;
124    }
125
126    // Convert to strings
127    let a_str = match std::str::from_utf8(a) {
128        Ok(s) => s,
129        Err(_) => return case_insensitive_byte_compare(a, b),
130    };
131
132    let b_str = match std::str::from_utf8(b) {
133        Ok(s) => s,
134        Err(_) => return case_insensitive_byte_compare(a, b),
135    };
136
137    // Convert to lowercase for case-insensitive comparison
138    let a_lower = a_str.to_lowercase();
139    let b_lower = b_str.to_lowercase();
140
141    // Use strcoll on lowercased strings
142    strcoll_compare(a_lower.as_bytes(), b_lower.as_bytes())
143}
144
/// Byte-wise comparison that ignores ASCII case.
///
/// Both inputs are compared lexicographically after mapping every byte
/// through `u8::to_ascii_lowercase`; when one input is a prefix of the other
/// under that mapping, the shorter one sorts first. Non-ASCII bytes are
/// compared as-is.
fn case_insensitive_byte_compare(a: &[u8], b: &[u8]) -> Ordering {
    let folded_a = a.iter().map(u8::to_ascii_lowercase);
    let folded_b = b.iter().map(u8::to_ascii_lowercase);
    folded_a.cmp(folded_b)
}
160
161/// Smart comparison that chooses between locale-aware and byte comparison
162pub fn smart_compare(a: &[u8], b: &[u8], ignore_case: bool) -> Ordering {
163    if LocaleConfig::is_enabled() {
164        if ignore_case {
165            strcasecoll_compare(a, b)
166        } else {
167            strcoll_compare(a, b)
168        }
169    } else {
170        // Fast path: byte comparison for C/POSIX locale
171        if ignore_case {
172            case_insensitive_byte_compare(a, b)
173        } else {
174            a.cmp(b)
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch process-global locale state (environment
    /// variables and the C library locale). The test harness runs tests on
    /// several threads, and these globals must not be changed concurrently.
    static LOCALE_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`LOCALE_LOCK`], ignoring poisoning left behind by a test
    /// that panicked while holding it.
    fn lock_locale() -> MutexGuard<'static, ()> {
        LOCALE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets `LC_COLLATE` for the duration of a test and restores the previous
    /// environment on drop, including when the test panics.
    ///
    /// `LC_ALL` is cleared while the guard is alive so that the outcome does
    /// not depend on the environment the tests are launched from.
    struct CollateEnv {
        saved: Vec<(&'static str, Option<OsString>)>,
        // Held until after `Drop::drop` has restored the environment.
        _lock: MutexGuard<'static, ()>,
    }

    impl CollateEnv {
        fn set(value: &str) -> Self {
            let lock = lock_locale();
            let saved = ["LC_ALL", "LC_COLLATE"]
                .iter()
                .map(|&name| (name, env::var_os(name)))
                .collect();
            env::remove_var("LC_ALL");
            env::set_var("LC_COLLATE", value);
            Self { saved, _lock: lock }
        }
    }

    impl Drop for CollateEnv {
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                match value {
                    Some(previous) => env::set_var(name, previous),
                    None => env::remove_var(name),
                }
            }
        }
    }

    #[test]
    fn test_c_locale() {
        let _env = CollateEnv::set("C");

        let config = LocaleConfig::init();
        assert!(!config.enabled);
        assert_eq!(config.locale_name, "C");
    }

    #[test]
    fn test_utf8_locale() {
        let _env = CollateEnv::set("en_US.UTF-8");

        let config = LocaleConfig::init();
        assert!(config.enabled);
        assert!(config.is_utf8);
        assert_eq!(config.locale_name, "en_US.UTF-8");
    }

    #[test]
    fn test_strcoll_basic() {
        // `strcoll` reads the C locale, which other tests may be changing.
        let _lock = lock_locale();

        // Basic ASCII comparison
        let a = b"apple";
        let b = b"banana";
        assert_eq!(strcoll_compare(a, b), Ordering::Less);
        assert_eq!(strcoll_compare(b, a), Ordering::Greater);
        assert_eq!(strcoll_compare(a, a), Ordering::Equal);
    }

    #[test]
    fn test_case_insensitive() {
        // `strcoll` reads the C locale, which other tests may be changing.
        let _lock = lock_locale();

        let a = b"Apple";
        let b = b"apple";
        assert_eq!(strcasecoll_compare(a, b), Ordering::Equal);

        let a = b"ZEBRA";
        let b = b"aardvark";
        assert_eq!(strcasecoll_compare(a, b), Ordering::Greater);
    }
}