Skip to main content

rosetta_date/
nlp.rs

1//! Natural language date parsing: tokenizer, relative & absolute time parsing.
2
3use crate::datetime::RosettaDateTime;
4use crate::error::{Result, RosettaError};
5use crate::i18n::{self, Direction, LanguageData, SpecialWord, TimeUnit};
6
7// ── Public API ────────────────────────────────────────────────────────
8
9/// Attempt to parse a natural-language date/time expression.
10///
11/// `base` is the reference time (usually "now") against which relative
12/// expressions like "2 hours ago" are resolved.
13pub fn parse_nlp(
14    input: &str,
15    base: &RosettaDateTime,
16    languages: Option<&[&LanguageData]>,
17) -> Result<RosettaDateTime> {
18    let input_clean = input.trim();
19    let input_clean_lower = input_clean.to_lowercase();
20
21    // Determine which languages to try
22    let available = i18n::available_languages();
23    let langs: Vec<&LanguageData> = match languages {
24        Some(l) => l.to_vec(),
25        None => {
26            // Auto-detect or try all
27            let detected = i18n::detect_language(&input_clean_lower);
28            if !detected.is_empty() {
29                detected
30            } else {
31                available.to_vec()
32            }
33        }
34    };
35
36    for lang in &langs {
37        // Try special standalone words first (yesterday, today, tomorrow, etc.)
38        if let Some(dt) = try_special_words(&input_clean_lower, base, lang) {
39            return Ok(dt);
40        }
41        // Try relative time expressions (2 hours ago, 3天前, in 5 days)
42        if let Some(dt) = try_relative_parse(&input_clean_lower, base, lang) {
43            return Ok(dt);
44        }
45        // Try modifier + unit (last week, next month, 上周, 下个月)
46        if let Some(dt) = try_modifier_unit(&input_clean_lower, base, lang) {
47            return Ok(dt);
48        }
49        // Try absolute month/day parsing (January 1st 2023, 1月15日)
50        if let Some(dt) = try_absolute_nlp(&input_clean_lower, base, lang) {
51            return Ok(dt);
52        }
53    }
54
55    Err(RosettaError::ParseError(format!(
56        "Could not parse natural language date: '{}'",
57        input
58    )))
59}
60
61// ── Special Words ─────────────────────────────────────────────────────
62
63fn try_special_words(
64    input: &str,
65    base: &RosettaDateTime,
66    lang: &LanguageData,
67) -> Option<RosettaDateTime> {
68    let input = input.trim();
69
70    for &(word_type, words) in lang.special_words {
71        for &w in words {
72            if input == w {
73                let dt = match word_type {
74                    SpecialWord::Now => base.clone(),
75                    SpecialWord::Yesterday => base.clone().add_days(-1),
76                    SpecialWord::Today => base.clone(),
77                    SpecialWord::Tomorrow => base.clone().add_days(1),
78                    SpecialWord::DayBeforeYesterday => base.clone().add_days(-2),
79                    SpecialWord::DayAfterTomorrow => base.clone().add_days(2),
80                };
81                return Some(dt);
82            }
83        }
84    }
85    None
86}
87
88// ── Relative Time ─────────────────────────────────────────────────────
89
90pub fn try_relative_parse(
91    input: &str,
92    base: &RosettaDateTime,
93    lang: &LanguageData,
94) -> Option<RosettaDateTime> {
95    // Pattern 1: "<number> <unit> ago/前"  (e.g. "2 hours ago", "3天前")
96    // Pattern 2: "in <number> <unit>"      (e.g. "in 5 days")
97    // Pattern 3: "<number> <unit> later/后" (e.g. "5 days later", "3天后")
98
99    // First check for direction keywords to know if past/future
100    let mut direction: Option<Direction> = None;
101
102    // Check for "ago" / "前" style suffixes → Past
103    for &ago in lang.ago_words {
104        if input.ends_with(&ago) || input.contains(ago) {
105            direction = Some(Direction::Past);
106            break;
107        }
108    }
109
110    // Check for "from now" / "later" / "后" style suffixes → Future
111    if direction.is_none() {
112        for &fw in lang.future_words {
113            if input.ends_with(&fw) || input.contains(fw) {
114                direction = Some(Direction::Future);
115                break;
116            }
117        }
118    }
119
120    // Check for "in" prefix → Future
121    if direction.is_none() {
122        for &prefix in lang.future_prefix {
123            let p = prefix;
124            if input.starts_with(p) && input.len() > p.len() {
125                let next_char = input[p.len()..].chars().next();
126                if next_char == Some(' ') || next_char.is_some_and(|c| !c.is_ascii_alphabetic()) {
127                    direction = Some(Direction::Future);
128                    break;
129                }
130            }
131        }
132    }
133
134    // Check for "hace"/"il y a"/"vor" prefix → Past
135    if direction.is_none() {
136        for &prefix in lang.past_prefix {
137            let p = prefix;
138            if input.starts_with(p) && input.len() > p.len() {
139                let next_char = input[p.len()..].chars().next();
140                if next_char == Some(' ') || next_char.is_some_and(|c| !c.is_ascii_alphabetic()) {
141                    direction = Some(Direction::Past);
142                    break;
143                }
144            }
145        }
146    }
147
148    let direction = direction?;
149
150    // Now find a number and a time unit in the string
151    let (number, unit) = extract_number_and_unit(input, lang)?;
152
153    let seconds = unit_to_seconds(unit, number);
154    let multiplier = match direction {
155        Direction::Past => -1i64,
156        Direction::Future => 1i64,
157    };
158
159    Some(base.clone().add_seconds(seconds * multiplier))
160}
161
162// ── Modifier + Unit ───────────────────────────────────────────────────
163
164pub fn try_modifier_unit(
165    input: &str,
166    base: &RosettaDateTime,
167    lang: &LanguageData,
168) -> Option<RosettaDateTime> {
169    // "last week" / "上周" → base - 1 week
170    // "next month" / "下个月" → base + 1 month
171    // "this year" / "今年" → base (no change)
172
173    let mut modifier: Option<i64> = None;
174
175    for &w in lang.last_words {
176        if input.contains(w) {
177            modifier = Some(-1);
178            break;
179        }
180    }
181    if modifier.is_none() {
182        for &w in lang.next_words {
183            if input.contains(w) {
184                modifier = Some(1);
185                break;
186            }
187        }
188    }
189    if modifier.is_none() {
190        for &w in lang.this_words {
191            if input.contains(w) {
192                modifier = Some(0);
193                break;
194            }
195        }
196    }
197
198    let modifier = modifier?;
199
200    // Find the unit
201    for &(unit, keywords) in lang.time_units {
202        for &kw in keywords {
203            if input.contains(kw) {
204                let secs = unit_to_seconds(unit, modifier.unsigned_abs() as i64);
205                return Some(base.clone().add_seconds(secs * modifier.signum()));
206            }
207        }
208    }
209
210    None
211}
212
213// ── Absolute NLP ──────────────────────────────────────────────────────
214
215pub fn try_absolute_nlp(
216    input: &str,
217    base: &RosettaDateTime,
218    lang: &LanguageData,
219) -> Option<RosettaDateTime> {
220    // Try to find a month name in the string
221    let mut found_month: Option<u8> = None;
222    let mut month_pos: usize = 0;
223    let mut month_len: usize = 0;
224
225    // Check long month names first (more specific)
226    for (i, &name) in lang.months_long.iter().enumerate() {
227        let name_lower = name;
228        if let Some(pos) = find_word(input, name_lower)
229            && !is_direction_word(name_lower, lang)
230        {
231            found_month = Some((i + 1) as u8);
232            month_pos = pos;
233            month_len = name.len();
234            break;
235        }
236    }
237
238    // Fall back to short month names
239    if found_month.is_none() {
240        for (i, &name) in lang.months_short.iter().enumerate() {
241            let name_lower = name;
242            if !name_lower.is_empty()
243                && let Some(pos) = find_word(input, name_lower)
244                && !is_direction_word(name_lower, lang)
245            {
246                found_month = Some((i + 1) as u8);
247                month_pos = pos;
248                month_len = name.len();
249                break;
250            }
251        }
252    }
253
254    let month = found_month?;
255
256    // Extract time pattern (HH:MM or HH:MM:SS) from the full string
257    let (time_hour, time_minute, time_second, time_range) = extract_time_pattern(input, lang);
258
259    // Extract numbers from the string, excluding the time pattern range
260    let before_month = &input[..month_pos];
261    let after_month = &input[month_pos + month_len..];
262
263    let numbers_before = extract_all_numbers(before_month);
264    let numbers_after =
265        extract_all_numbers_excluding(after_month, &time_range, month_pos + month_len);
266
267    let all_numbers: Vec<i64> = numbers_before
268        .iter()
269        .chain(numbers_after.iter())
270        .copied()
271        .collect();
272
273    // Determine day and year from context
274    let (day, year) = match all_numbers.len() {
275        0 => (1u8, base.year()),
276        1 => {
277            let n = all_numbers[0];
278            if n > 31 {
279                (1u8, n as i32) // it's a year
280            } else {
281                (n as u8, base.year()) // it's a day
282            }
283        }
284        _ => {
285            // Two or more numbers: figure out which is day and which is year
286            let (mut day, mut year) = (1u8, base.year());
287            for &n in &all_numbers {
288                if n > 31 {
289                    year = n as i32;
290                } else if day == 1 {
291                    day = n as u8;
292                }
293            }
294            (day, year)
295        }
296    };
297
298    let hour = time_hour.unwrap_or(0);
299    let minute = time_minute.unwrap_or(0);
300    let second = time_second.unwrap_or(0);
301
302    RosettaDateTime::from_components(year, month, day, hour, minute, second, base.offset()).ok()
303}
304
305/// Extract a time pattern (HH:MM or HH:MM:SS) from a string.
306/// Also handles AM/PM. Returns (hour, minute, second, byte_range_in_original).
307fn extract_time_pattern(
308    input: &str,
309    lang: &LanguageData,
310) -> (
311    Option<u8>,
312    Option<u8>,
313    Option<u8>,
314    Option<std::ops::Range<usize>>,
315) {
316    // Scan for digit:digit pattern
317    let bytes = input.as_bytes();
318    let len = bytes.len();
319    let mut i = 0;
320
321    while i < len {
322        // Find a ':' that has digits on both sides
323        if bytes[i] == b':' && i > 0 && i + 1 < len {
324            // Find hour digits before ':'
325            let mut h_start = i;
326            while h_start > 0 && bytes[h_start - 1].is_ascii_digit() {
327                h_start -= 1;
328            }
329            if h_start == i {
330                i += 1;
331                continue;
332            }
333
334            // Find minute digits after ':'
335            let mut m_end = i + 1;
336            while m_end < len && bytes[m_end].is_ascii_digit() {
337                m_end += 1;
338            }
339            if m_end == i + 1 {
340                i += 1;
341                continue;
342            }
343
344            let h_str = &input[h_start..i];
345            let m_str = &input[i + 1..m_end];
346
347            if let (Ok(h), Ok(m)) = (h_str.parse::<u8>(), m_str.parse::<u8>())
348                && h <= 23
349                && m <= 59
350            {
351                // Check for seconds
352                let (s, range_end) = if m_end < len && bytes[m_end] == b':' {
353                    let mut s_end = m_end + 1;
354                    while s_end < len && bytes[s_end].is_ascii_digit() {
355                        s_end += 1;
356                    }
357                    let s_str = &input[m_end + 1..s_end];
358                    if let Ok(s) = s_str.parse::<u8>() {
359                        (Some(s), s_end)
360                    } else {
361                        (None, m_end)
362                    }
363                } else {
364                    (None, m_end)
365                };
366
367                // Check for AM/PM after the time
368                let after_time = &input[range_end..].trim_start();
369                let mut hour = h;
370
371                for &pm in lang.pm_indicators {
372                    if after_time.starts_with(pm) {
373                        if hour < 12 {
374                            hour += 12;
375                        }
376                        break;
377                    }
378                }
379                for &am in lang.am_indicators {
380                    if after_time.starts_with(am) {
381                        if hour == 12 {
382                            hour = 0;
383                        }
384                        break;
385                    }
386                }
387
388                return (Some(hour), Some(m), s, Some(h_start..range_end));
389            }
390        }
391        i += 1;
392    }
393
394    (None, None, None, None)
395}
396
/// Extract every run of ASCII digits in `s` as an `i64`, in order.
///
/// A digit is skipped when its absolute byte position lies inside
/// `exclude_range`; an excluded digit also terminates the run in progress.
/// The absolute position is `offset` plus the position within `s`, where
/// `offset` is the byte offset of `s` inside the string that `exclude_range`
/// refers to. Runs too large for `i64` are dropped.
fn extract_all_numbers_excluding(
    s: &str,
    exclude_range: &Option<std::ops::Range<usize>>,
    offset: usize,
) -> Vec<i64> {
    let excluded = |abs_pos: usize| {
        exclude_range
            .as_ref()
            .is_some_and(|r| r.contains(&abs_pos))
    };

    let mut numbers = Vec::new();
    let mut run = String::new();
    let mut finish_run = |run: &mut String| {
        if let Ok(n) = run.parse::<i64>() {
            numbers.push(n);
        }
        run.clear();
    };

    for (idx, ch) in s.char_indices() {
        if ch.is_ascii_digit() && !excluded(offset + idx) {
            run.push(ch);
        } else {
            finish_run(&mut run);
        }
    }
    finish_run(&mut run);

    numbers
}
443
444// ── Helpers ───────────────────────────────────────────────────────────
445
446/// Extract a number and a time unit from the input string.
447fn extract_number_and_unit(input: &str, lang: &LanguageData) -> Option<(i64, TimeUnit)> {
448    // Try to find a time unit keyword
449    for &(unit, keywords) in lang.time_units {
450        // Sort keywords by length desc so we match longest first
451        let mut kw_sorted: Vec<&str> = keywords.to_vec();
452        kw_sorted.sort_by_key(|b| std::cmp::Reverse(b.len()));
453
454        for kw in kw_sorted {
455            let kw_lower = kw;
456            if let Some(pos) = input.find(kw_lower) {
457                // Look for a number before the unit
458                let before = &input[..pos];
459                if let Some(n) = extract_number_from_end(before.trim_end(), lang) {
460                    return Some((n, unit));
461                }
462                // Look for a number after the unit
463                let after = &input[pos + kw.len()..];
464                if let Some(n) = extract_number_from_start(after.trim_start(), lang) {
465                    return Some((n, unit));
466                }
467            }
468        }
469    }
470    None
471}
472
473/// Extract a number from the end of a string.
474fn extract_number_from_end(s: &str, lang: &LanguageData) -> Option<i64> {
475    if s.is_empty() {
476        return None;
477    }
478
479    // Try Arabic digits at the end
480    let digit_str: String = s
481        .chars()
482        .rev()
483        .take_while(|c| c.is_ascii_digit())
484        .collect::<String>()
485        .chars()
486        .rev()
487        .collect();
488    if !digit_str.is_empty() {
489        return digit_str.parse().ok();
490    }
491
492    // Try number words (check from longest match)
493    try_number_word(s, lang)
494}
495
496/// Extract a number from the start of a string.
497fn extract_number_from_start(s: &str, lang: &LanguageData) -> Option<i64> {
498    if s.is_empty() {
499        return None;
500    }
501
502    // Try Arabic digits at the start
503    let digit_str: String = s.chars().take_while(|c| c.is_ascii_digit()).collect();
504    if !digit_str.is_empty() {
505        return digit_str.parse().ok();
506    }
507
508    // Try number words
509    try_number_word(s, lang)
510}
511
512/// Try matching number words from the language data.
513fn try_number_word(s: &str, lang: &LanguageData) -> Option<i64> {
514    let trimmed = s.trim();
515
516    // Sort by length desc for longest-match-first
517    let mut words: Vec<(&str, i64)> = lang.number_words.to_vec();
518    words.sort_by_key(|a| std::cmp::Reverse(a.0.len()));
519
520    for (word, val) in words {
521        let w = word;
522        if trimmed == w {
523            return Some(val);
524        }
525        if let Some(stripped) = trimmed.strip_prefix(w) {
526            if let Some(next_char) = stripped.chars().next() {
527                if !next_char.is_ascii_alphabetic() {
528                    return Some(val);
529                }
530            } else {
531                return Some(val);
532            }
533        }
534        if trimmed.ends_with(&w) {
535            let prev_str = &trimmed[..trimmed.len() - w.len()];
536            if let Some(prev_char) = prev_str.chars().last() {
537                if !prev_char.is_ascii_alphabetic() {
538                    return Some(val);
539                }
540            } else {
541                return Some(val);
542            }
543        }
544    }
545
546    // Handle Chinese composite numbers like "十二" (12), "二十" (20), "二十三" (23)
547    if let Some(n) = try_chinese_number(trimmed, lang) {
548        return Some(n);
549    }
550
551    None
552}
553
554/// Parse simple Chinese composite numbers (e.g. 十二=12, 二十三=23).
555fn try_chinese_number(s: &str, lang: &LanguageData) -> Option<i64> {
556    let chars: Vec<char> = s.chars().collect();
557    if chars.is_empty() {
558        return None;
559    }
560
561    // Build a lookup from character to value
562    let lookup = |c: char| -> Option<i64> {
563        for &(word, val) in lang.number_words {
564            let wc: Vec<char> = word.chars().collect();
565            if wc.len() == 1 && wc[0] == c {
566                return Some(val);
567            }
568        }
569        None
570    };
571
572    if chars.len() == 1 {
573        return lookup(chars[0]);
574    }
575
576    // Two-char patterns:
577    // "十X" = 10 + X,  "X十" = X * 10
578    if chars.len() == 2 {
579        let a = lookup(chars[0]);
580        let b = lookup(chars[1]);
581        match (a, b) {
582            (Some(10), Some(bv)) if bv < 10 => return Some(10 + bv), // 十二 = 12
583            (Some(av), Some(10)) if av < 10 => return Some(av * 10), // 二十 = 20
584            _ => {}
585        }
586    }
587
588    // Three-char: "X十Y" = X * 10 + Y
589    if chars.len() == 3 {
590        let a = lookup(chars[0]);
591        let b = lookup(chars[1]);
592        let c = lookup(chars[2]);
593        if let (Some(av), Some(10), Some(cv)) = (a, b, c)
594            && av < 10
595            && cv < 10
596        {
597            return Some(av * 10 + cv); // 二十三 = 23
598        }
599    }
600
601    None
602}
603
/// Extract every run of ASCII digits in `s` as an `i64`, in order of
/// appearance. Runs too large for `i64` are dropped.
fn extract_all_numbers(s: &str) -> Vec<i64> {
    s.split(|c: char| !c.is_ascii_digit())
        .filter(|run| !run.is_empty())
        .filter_map(|run| run.parse().ok())
        .collect()
}
628
629/// Convert a time unit + count to total seconds.
630fn unit_to_seconds(unit: TimeUnit, count: i64) -> i64 {
631    let count = if count == -1 { 1 } else { count }; // -1 is "half", approximate as 1
632    match unit {
633        TimeUnit::Second => count,
634        TimeUnit::Minute => count * 60,
635        TimeUnit::Hour => count * 3600,
636        TimeUnit::Day => count * 86400,
637        TimeUnit::Week => count * 7 * 86400,
638        TimeUnit::Month => count * 30 * 86400,
639        TimeUnit::Year => count * 365 * 86400,
640    }
641}
642
643/// Returns true if `word` is also used as a past/future direction indicator
644/// in this language (e.g. Portuguese "ago" = August AND past indicator).
645/// Such words must not be treated as month names in absolute NLP parsing.
646fn is_direction_word(word: &str, lang: &LanguageData) -> bool {
647    let w = word;
648    for &ago in lang.ago_words {
649        if w == ago {
650            return true;
651        }
652    }
653    for &fw in lang.future_words {
654        if w == fw {
655            return true;
656        }
657    }
658    for &pp in lang.past_prefix {
659        if w == pp {
660            return true;
661        }
662    }
663    for &fp in lang.future_prefix {
664        if w == fp {
665            return true;
666        }
667    }
668    false
669}
670
/// Find `needle` in `haystack` at a word boundary.
///
/// A boundary means the match is neither preceded nor followed by an ASCII
/// letter; the start and end of the string also count. Every byte offset is
/// tested, so overlapping candidates are considered. Returns the byte position
/// of the first qualifying match, or `None`. An empty `needle` always returns
/// `None`.
fn find_word(haystack: &str, needle: &str) -> Option<usize> {
    if needle.is_empty() {
        return None;
    }

    let hay = haystack.as_bytes();
    let pattern = needle.as_bytes();
    let letter_at = |idx: usize| hay.get(idx).is_some_and(u8::is_ascii_alphabetic);

    hay.windows(pattern.len())
        .enumerate()
        .find(|&(start, window)| {
            window == pattern
                && (start == 0 || !letter_at(start - 1))
                && !letter_at(start + pattern.len())
        })
        .map(|(start, _)| start)
}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed reference instant shared by all tests: 2023-10-15 12:00:00 UTC.
    fn base_time() -> RosettaDateTime {
        RosettaDateTime::from_components(2023, 10, 15, 12, 0, 0, crate::timezone::TzOffset::UTC)
            .unwrap()
    }

    /// Parse `input` relative to [`base_time`] with automatic language
    /// detection, failing the test if parsing does not succeed.
    #[cfg(any(feature = "lang-en", feature = "lang-zh"))]
    fn parse(input: &str) -> RosettaDateTime {
        parse_nlp(input, &base_time(), None)
            .unwrap_or_else(|err| panic!("failed to parse {input:?}: {err:?}"))
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_yesterday() {
        let dt = parse("yesterday");
        assert_eq!(dt.month(), 10);
        assert_eq!(dt.day(), 14);
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_tomorrow() {
        assert_eq!(parse("tomorrow").day(), 16);
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_hours_ago() {
        assert_eq!(parse("2 hours ago").hour(), 10);
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_in_5_days() {
        assert_eq!(parse("in 5 days").day(), 20);
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_3_weeks_ago() {
        // Oct 15 minus 21 days lands on Sept 24.
        let dt = parse("3 weeks ago");
        assert_eq!(dt.month(), 9);
        assert_eq!(dt.day(), 24);
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_month_day() {
        let dt = parse("January 15");
        assert_eq!(dt.month(), 1);
        assert_eq!(dt.day(), 15);
    }

    #[cfg(feature = "lang-en")]
    #[test]
    fn test_english_month_day_year() {
        let dt = parse("March 5 2025");
        assert_eq!(dt.month(), 3);
        assert_eq!(dt.day(), 5);
        assert_eq!(dt.year(), 2025);
    }

    #[cfg(feature = "lang-zh")]
    #[test]
    fn test_chinese_yesterday() {
        assert_eq!(parse("昨天").day(), 14);
    }

    #[cfg(feature = "lang-zh")]
    #[test]
    fn test_chinese_3_hours_ago() {
        assert_eq!(parse("3小时前").hour(), 9);
    }

    #[cfg(feature = "lang-zh")]
    #[test]
    fn test_chinese_day_before_yesterday() {
        assert_eq!(parse("前天").day(), 13);
    }
}