1use unicode_segmentation::UnicodeSegmentation;
2
/// ASCII characters that word navigation treats as punctuation.
///
/// `is_ascii_punctuation` tests membership in this list. Note that `_` and
/// whitespace are deliberately absent from it.
pub const PUNCTUATION_CHARS: &[char] = &[
    // Sentence punctuation.
    '.', ',', ';', ':', '!', '?',
    // Brackets.
    '(', ')', '[', ']', '{', '}', '<', '>',
    // Quotes.
    '\'', '"',
    // Operators, separators and other symbols.
    '+', '-', '*', '/', '\\', '|', '&', '%', '$', '#', '@', '~', '`', '^', '=',
];
8
9#[derive(Default)]
11#[allow(clippy::type_complexity)]
12pub struct WordNavigationOptions<'a> {
13 pub segment: Option<&'a dyn Fn(&str) -> Vec<WordSegment>>,
16 pub is_atomic_segment: Option<&'a dyn Fn(&str) -> bool>,
19}
20
/// One contiguous piece of text produced by segmentation.
#[derive(Debug)]
#[derive(Clone)]
pub struct WordSegment {
    /// The exact text covered by this segment.
    pub text: String,
    /// `true` when the segment is word-like; `false` for whitespace,
    /// punctuation and other non-word text.
    pub is_word: bool,
}
27
28impl WordSegment {
29 pub fn len(&self) -> usize {
30 self.text.len()
31 }
32
33 pub fn is_empty(&self) -> bool {
34 self.text.is_empty()
35 }
36}
37
38fn default_segment(text: &str) -> Vec<WordSegment> {
41 let mut segments: Vec<WordSegment> = Vec::new();
42
43 for grapheme in text.graphemes(true) {
44 let is_word_char = is_word_char(grapheme);
45
46 if let Some(last) = segments.last_mut()
47 && last.is_word == is_word_char
48 && !is_single_punctuation(grapheme)
49 {
50 last.text.push_str(grapheme);
51 continue;
52 }
53
54 segments.push(WordSegment {
55 text: grapheme.to_string(),
56 is_word: is_word_char,
57 });
58 }
59
60 let mut merged: Vec<WordSegment> = Vec::new();
62 for seg in segments {
63 if let Some(last) = merged.last_mut()
64 && last.is_word == seg.is_word
65 {
66 last.text.push_str(&seg.text);
67 continue;
68 }
69 merged.push(seg);
70 }
71
72 merged
73}
74
75fn get_segments<'a>(text: &'a str, options: &WordNavigationOptions<'a>) -> Vec<WordSegment> {
76 if let Some(segment_fn) = options.segment {
77 segment_fn(text)
78 } else {
79 default_segment(text)
80 }
81}
82
83fn is_atomic(segment: &str, options: &WordNavigationOptions) -> bool {
84 options
85 .is_atomic_segment
86 .is_some_and(|is_atomic| is_atomic(segment))
87}
88
89pub fn find_word_backward(text: &str, cursor: usize) -> usize {
96 find_word_backward_with(text, cursor, &WordNavigationOptions::default())
97}
98
99pub fn find_word_backward_with(
102 text: &str,
103 cursor: usize,
104 options: &WordNavigationOptions,
105) -> usize {
106 if cursor == 0 {
107 return 0;
108 }
109
110 let cursor = cursor.min(text.len());
111 let segments = get_segments(&text[..cursor], options);
112
113 if segments.is_empty() {
114 return 0;
115 }
116
117 let mut pos = cursor;
118
119 let mut i = segments.len();
121 while i > 0 {
122 i -= 1;
123 let seg = &segments[i];
124 if !is_atomic(&seg.text, options) && is_whitespace_segment(seg) {
125 pos -= seg.len();
126 } else {
127 break;
128 }
129 }
130
131 if i == 0 && !segments.is_empty() && is_whitespace_segment(&segments[0]) {
132 return pos;
133 }
134
135 if i >= segments.len() {
136 return pos;
137 }
138
139 let last = &segments[i];
140
141 if is_atomic(&last.text, options) {
142 pos -= last.text.len();
144 } else if last.is_word {
145 if let Some(punct_pos) = last.text.rfind(is_ascii_punctuation) {
147 let after_punct: String = last.text[punct_pos..].graphemes(true).take(1).collect();
148 pos -= last.text.len() - (punct_pos + after_punct.len());
149 } else {
150 pos -= last.text.len();
151 }
152 } else {
153 pos -= last.text.len();
155 while i > 0 {
156 i -= 1;
157 let seg = &segments[i];
158 if is_atomic(&seg.text, options) || seg.is_word || is_whitespace_segment(seg) {
159 break;
160 }
161 pos -= seg.text.len();
162 }
163 }
164
165 pos
166}
167
168pub fn find_word_forward(text: &str, cursor: usize) -> usize {
175 find_word_forward_with(text, cursor, &WordNavigationOptions::default())
176}
177
178pub fn find_word_forward_with(text: &str, cursor: usize, options: &WordNavigationOptions) -> usize {
180 if cursor >= text.len() {
181 return text.len();
182 }
183
184 let segments = get_segments(&text[cursor..], options);
185
186 let mut pos = cursor;
187 let mut i = 0;
188
189 while i < segments.len()
191 && !is_atomic(&segments[i].text, options)
192 && is_whitespace_segment(&segments[i])
193 {
194 pos += segments[i].text.len();
195 i += 1;
196 }
197
198 if i >= segments.len() {
199 return pos;
200 }
201
202 let first = &segments[i];
203
204 if is_atomic(&first.text, options) {
205 pos += first.text.len();
207 } else if first.is_word {
208 if let Some(punct_pos) = first.text.find(is_ascii_punctuation) {
210 let up_to_punct: String = first.text[..=punct_pos].graphemes(true).collect();
211 pos += up_to_punct.len();
212 } else {
213 pos += first.text.len();
214 }
215 } else {
216 while i < segments.len()
218 && !is_atomic(&segments[i].text, options)
219 && !segments[i].is_word
220 && !is_whitespace_segment(&segments[i])
221 {
222 pos += segments[i].text.len();
223 i += 1;
224 }
225 }
226
227 pos
228}
229
230fn is_whitespace_segment(seg: &WordSegment) -> bool {
231 !seg.is_word && seg.text.trim().is_empty()
232}
233
234fn is_word_char(grapheme: &str) -> bool {
235 grapheme.chars().any(|c| c.is_alphanumeric() || is_cjk(c))
236}
237
/// Returns `true` when `c` lies in one of four code point ranges:
/// U+3040–U+309F (Hiragana), U+30A0–U+30FF (Katakana),
/// U+4E00–U+9FFF (CJK Unified Ideographs) or U+AC00–U+D7AF (Hangul Syllables).
fn is_cjk(c: char) -> bool {
    matches!(
        u32::from(c),
        0x3040..=0x309F | 0x30A0..=0x30FF | 0x4E00..=0x9FFF | 0xAC00..=0xD7AF
    )
}
245
246fn is_ascii_punctuation(c: char) -> bool {
247 PUNCTUATION_CHARS.contains(&c)
248}
249
250fn is_single_punctuation(grapheme: &str) -> bool {
251 grapheme.len() == 1 && grapheme.chars().next().is_some_and(is_ascii_punctuation)
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Segmenter for the atomic-marker tests: emits the literal marker
    /// `[paste #1]` as one non-word segment and hands everything around it to
    /// the built-in segmenter.
    fn segment_with_paste_marker(text: &str) -> Vec<WordSegment> {
        const MARKER: &str = "[paste #1]";
        match text.find(MARKER) {
            Some(at) => {
                let mut segments = default_segment(&text[..at]);
                segments.push(WordSegment {
                    text: MARKER.to_string(),
                    is_word: false,
                });
                segments.extend(default_segment(&text[at + MARKER.len()..]));
                segments
            }
            None => default_segment(text),
        }
    }

    #[test]
    fn test_find_word_backward_basic() {
        let text = "hello world";
        assert_eq!(find_word_backward(text, 11), 6);
        assert_eq!(find_word_backward(text, 6), 0);
    }

    #[test]
    fn test_find_word_backward_dotted() {
        let text = "foo.bar";
        assert_eq!(find_word_backward(text, 7), 4);
        assert_eq!(find_word_backward(text, 4), 3);
        assert_eq!(find_word_backward(text, 3), 0);
    }

    #[test]
    fn test_find_word_backward_cursor_at_zero() {
        assert_eq!(find_word_backward("hello", 0), 0);
    }

    #[test]
    fn test_find_word_backward_punctuation_run() {
        let text = "foo...bar";
        assert_eq!(find_word_backward(text, 9), 6);
        assert_eq!(find_word_backward(text, 6), 3);
        assert_eq!(find_word_backward(text, 3), 0);
    }

    #[test]
    fn test_find_word_forward_basic() {
        let text = "hello world";
        assert_eq!(find_word_forward(text, 0), 5);
        assert_eq!(find_word_forward(text, 5), 11);
    }

    #[test]
    fn test_find_word_forward_dotted() {
        let text = "foo.bar";
        assert_eq!(find_word_forward(text, 0), 3);
        assert_eq!(find_word_forward(text, 3), 4);
        assert_eq!(find_word_forward(text, 4), 7);
    }

    #[test]
    fn test_find_word_forward_cursor_at_end() {
        assert_eq!(find_word_forward("hello", 5), 5);
    }

    #[test]
    fn test_find_word_forward_punctuation_run() {
        let text = "foo...bar";
        assert_eq!(find_word_forward(text, 0), 3);
        assert_eq!(find_word_forward(text, 3), 6);
        assert_eq!(find_word_forward(text, 6), 9);
    }

    // The built-in segmenter never produces a segment starting with "[paste",
    // so this test only checks that the cursor moves when a predicate is set.
    // The atomic branch itself is covered by
    // `test_find_word_backward_consumes_atomic_segment_whole`.
    #[test]
    fn test_find_word_backward_with_atomic_segment() {
        let options = WordNavigationOptions {
            segment: None,
            is_atomic_segment: Some(&|s: &str| s.starts_with("[paste")),
        };
        let text = "hello [paste #1] world";
        let cursor = text.len();
        let result = find_word_backward_with(text, cursor, &options);
        assert!(result < cursor, "Should have moved backward");
    }

    // Same caveat as above; the atomic branch is covered by
    // `test_find_word_forward_consumes_atomic_segment_whole`.
    #[test]
    fn test_find_word_forward_with_atomic_segment() {
        let options = WordNavigationOptions {
            segment: None,
            is_atomic_segment: Some(&|s: &str| s.starts_with("[paste")),
        };
        let text = "hello [paste #1] world";
        let cursor = 6;
        let result = find_word_forward_with(text, cursor, &options);
        assert!(result > cursor, "Should have moved forward past marker");
    }

    #[test]
    fn test_find_word_backward_consumes_atomic_segment_whole() {
        let options = WordNavigationOptions {
            segment: Some(&segment_with_paste_marker),
            is_atomic_segment: Some(&|s: &str| s.starts_with("[paste")),
        };
        let text = "hello [paste #1] world";
        // Cursor directly after the marker: the whole marker is stepped over.
        assert_eq!(find_word_backward_with(text, 16, &options), 6);
        // Cursor after the following space: whitespace, then the marker.
        assert_eq!(find_word_backward_with(text, 17, &options), 6);
    }

    #[test]
    fn test_find_word_forward_consumes_atomic_segment_whole() {
        let options = WordNavigationOptions {
            segment: Some(&segment_with_paste_marker),
            is_atomic_segment: Some(&|s: &str| s.starts_with("[paste")),
        };
        let text = "hello [paste #1] world";
        // Cursor directly before the marker: the whole marker is stepped over.
        assert_eq!(find_word_forward_with(text, 6, &options), 16);
        // Cursor before the preceding space: whitespace, then the marker.
        assert_eq!(find_word_forward_with(text, 5, &options), 16);
    }

    #[test]
    fn test_punctuation_regex_matches() {
        for c in ['.', ',', ';', ':', '!', '?'] {
            assert!(is_ascii_punctuation(c), "{c:?} should be punctuation");
        }
        assert!(!is_ascii_punctuation('a'));
        assert!(!is_ascii_punctuation(' '));
    }

    #[test]
    fn test_word_segment_empty() {
        let ws = WordSegment {
            text: String::new(),
            is_word: false,
        };
        assert!(ws.is_empty());
        assert_eq!(ws.len(), 0);
    }
}