Skip to main content

reddb_server/runtime/ai/
citation_parser.rs

1//! `CitationParser` — pure text-to-citations extractor.
2//!
3//! Issue #393 (PRD #391): scan an LLM-produced answer for inline
4//! `[^N]` markers and emit a structured `Vec<Citation>` plus
5//! `Vec<CitationWarning>` for anomalies. The module is pure — no I/O,
6//! no allocations beyond the result vectors, no panics on adversarial
7//! input — so it can be unit-tested in isolation and reused by every
8//! transport.
9//!
10//! ## Grammar
11//!
12//! ```text
13//! marker     = "[^" digits "]"
14//! digits     = '1'..='9' ('0'..='9')*     # N ≥ 1, no leading zero
15//! escape     = "\\[^"                       # literal `\[^…]` is NOT a marker
16//! code-fence = "```"                        # inside fences, markers are ignored
17//! ```
18//!
19//! Only ASCII digits count. `N` is parsed as `u32`; values that
20//! overflow `u32::MAX` produce a `WarningKind::Malformed` and are
21//! dropped (we don't truncate silently — a runaway value is almost
22//! certainly an LLM hallucination).
23//!
24//! `source_index` is `N - 1` (markers are 1-indexed for humans, the
25//! sources array is 0-indexed). Out-of-range indices still produce a
26//! `Citation` entry — callers decide whether to surface them — and
27//! also produce a `WarningKind::OutOfRange` for the validator path.
28//!
29//! ## Code fences
30//!
31//! Toggled on a line whose first non-whitespace bytes are ```` ``` ````.
32//! Inside a fence we skip every byte until the closing fence. Inline
33//! single-backtick spans are NOT honoured because the LLM occasionally
34//! cites things like `` `result_field` [^1] `` and we still want the
35//! citation parsed.
36//!
37//! ## Escape
38//!
39//! A backslash directly before `[` suppresses parsing: `\[^1]` is
40//! treated as literal text. We do NOT consume the backslash from the
41//! span — the parser only emits citation spans, not rewritten text.
42
43use std::ops::Range;
44
/// A parsed `[^N]` citation marker.
///
/// One value is produced for every well-formed marker that
/// `parse_citations` encounters, whether or not `N` points at an
/// existing source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Citation {
    /// The number `N` as it appeared in the marker (1-indexed).
    pub marker: u32,
    /// Byte span of the marker inside the original text, including
    /// both brackets.
    pub span: Range<usize>,
    /// `marker - 1`, intended to index into the flat sources array.
    /// Note: this can equal or exceed the actual source count; check
    /// `warnings` for `OutOfRange` entries before dereferencing.
    pub source_index: u32,
}
58
/// A non-fatal problem encountered while scanning.
///
/// Warnings are collected alongside the citations; they never stop
/// the scan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CitationWarning {
    /// Category of the problem.
    pub kind: CitationWarningKind,
    /// Byte span, inside the original text, of the marker (or
    /// marker-like text) that triggered the warning.
    pub span: Range<usize>,
    /// Human-readable description of what was wrong.
    pub detail: String,
}
66
/// Category of a `CitationWarning`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CitationWarningKind {
    /// Saw `[^` but the body wasn't a positive decimal terminated by `]`.
    Malformed,
    /// `N - 1 >= sources_count`. Always emitted in addition to the
    /// `Citation` entry so callers can choose to suppress.
    OutOfRange,
}
75
76/// Parse `[^N]` citation markers out of `text`.
77///
78/// `sources_count` is used only to flag `OutOfRange` warnings; the
79/// citations themselves are returned regardless of bounds.
80pub fn parse_citations(text: &str, sources_count: usize) -> CitationParseResult {
81    let bytes = text.as_bytes();
82    let mut citations: Vec<Citation> = Vec::new();
83    let mut warnings: Vec<CitationWarning> = Vec::new();
84
85    let mut i = 0usize;
86    let mut in_fence = false;
87
88    while i < bytes.len() {
89        // Code-fence toggle: a `` ``` `` at the start of a line (after
90        // optional whitespace) flips the fence state.
91        if is_line_start(bytes, i) {
92            let line_first = first_non_ws_on_line(bytes, i);
93            if line_first + 2 < bytes.len()
94                && bytes[line_first] == b'`'
95                && bytes[line_first + 1] == b'`'
96                && bytes[line_first + 2] == b'`'
97            {
98                in_fence = !in_fence;
99                // skip past the fence marker; don't try to parse the
100                // info-string. Advance to end of line.
101                i = advance_to_newline(bytes, line_first + 3);
102                continue;
103            }
104        }
105
106        if in_fence {
107            i += 1;
108            continue;
109        }
110
111        if bytes[i] == b'[' {
112            // Escape check: preceding char is an unescaped backslash.
113            if i > 0 && bytes[i - 1] == b'\\' {
114                // Must not be `\\[` (i.e. an escaped backslash before
115                // the bracket); count backslashes.
116                let backslashes = count_preceding_backslashes(bytes, i);
117                if backslashes % 2 == 1 {
118                    i += 1;
119                    continue;
120                }
121            }
122
123            if i + 1 < bytes.len() && bytes[i + 1] == b'^' {
124                // Attempt to consume `[^digits]`.
125                match read_marker(bytes, i) {
126                    MarkerScan::Ok { marker, end } => {
127                        let span = i..end;
128                        let source_index = marker.saturating_sub(1);
129                        if (source_index as usize) >= sources_count {
130                            warnings.push(CitationWarning {
131                                kind: CitationWarningKind::OutOfRange,
132                                span: span.clone(),
133                                detail: format!(
134                                    "marker [^{marker}] references source #{} but only {} sources available",
135                                    source_index + 1,
136                                    sources_count
137                                ),
138                            });
139                        }
140                        citations.push(Citation {
141                            marker,
142                            span,
143                            source_index,
144                        });
145                        i = end;
146                        continue;
147                    }
148                    MarkerScan::Malformed { end, reason } => {
149                        warnings.push(CitationWarning {
150                            kind: CitationWarningKind::Malformed,
151                            span: i..end,
152                            detail: reason,
153                        });
154                        i = end;
155                        continue;
156                    }
157                    MarkerScan::NotAMarker => {
158                        // `[^` followed by something that can't start
159                        // a marker (e.g. `[^abc]`, `[^]`). Advance 1 so
160                        // we re-scan from the next byte.
161                        i += 1;
162                        continue;
163                    }
164                }
165            }
166        }
167
168        i += 1;
169    }
170
171    CitationParseResult {
172        citations,
173        warnings,
174    }
175}
176
/// Outcome of `parse_citations`.
///
/// Both vectors are filled in scan order, i.e. by increasing byte
/// offset of the marker they refer to.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CitationParseResult {
    /// Every well-formed marker found, including out-of-range ones.
    pub citations: Vec<Citation>,
    /// Anomalies (`Malformed` / `OutOfRange`) noticed during the scan.
    pub warnings: Vec<CitationWarning>,
}
183
/// Result of trying to read a single `[^…]` marker at a fixed offset.
enum MarkerScan {
    /// A well-formed marker. `marker` is the parsed `N` (always ≥ 1)
    /// and `end` is the byte offset just past the closing `]`.
    Ok { marker: u32, end: usize },
    /// Bracketed text that looks like a marker but has an invalid body.
    /// `end` is the byte offset just past the closing `]`; `reason` is
    /// a human-readable explanation.
    Malformed { end: usize, reason: String },
    /// The text is not marker-shaped at all (never closed, broken by a
    /// newline, or too long to be worth reporting).
    NotAMarker,
}

/// Try to read one `[^N]` marker that begins at `bytes[start]`.
///
/// The caller guarantees `bytes[start] == b'['` and
/// `bytes[start + 1] == b'^'`.
///
/// Returns:
/// * `MarkerScan::Ok` when the body is a decimal number in `1..=u32::MAX`
///   with no leading zero.
/// * `MarkerScan::Malformed` when a closing `]` is found but the body is
///   empty, has a leading zero, overflows `u32`, or contains a
///   non-digit. In the non-digit case the `]` must appear on the same
///   line and within `MALFORMED_SCAN_CAP` body bytes.
/// * `MarkerScan::NotAMarker` otherwise (input ends first, a newline
///   interrupts a non-digit body, or the non-digit body is too long).
fn read_marker(bytes: &[u8], start: usize) -> MarkerScan {
    /// Maximum number of body bytes inspected when looking for the
    /// closing `]` of a non-digit body, so a hostile input cannot make
    /// every `[^` scan to EOF.
    const MALFORMED_SCAN_CAP: usize = 16;

    let body_start = start + 2;
    if body_start >= bytes.len() {
        return MarkerScan::NotAMarker;
    }

    // Consume the leading run of ASCII digits.
    let digit_count = bytes[body_start..]
        .iter()
        .take_while(|b| b.is_ascii_digit())
        .count();
    let body_end = body_start + digit_count;

    match bytes.get(body_end).copied() {
        // Input ended before any closing bracket.
        None => return MarkerScan::NotAMarker,
        // All-digit (possibly empty) body; validated below.
        Some(b']') => {}
        // A non-digit byte inside the body. Look for the closing `]`
        // within a bounded window so a precise warning can be emitted.
        // Every byte of the window, the very first one included, is
        // checked for a newline: a marker never spans lines.
        Some(_) => {
            let window_end = bytes.len().min(body_start + MALFORMED_SCAN_CAP + 1);
            let stop = bytes[body_start..window_end]
                .iter()
                .position(|&b| b == b']' || b == b'\n')
                .map(|offset| body_start + offset);
            return match stop {
                Some(close) if bytes[close] == b']' => MarkerScan::Malformed {
                    end: close + 1,
                    reason: format!(
                        "expected digits inside [^…], got `{}`",
                        String::from_utf8_lossy(&bytes[body_start..close])
                    ),
                },
                _ => MarkerScan::NotAMarker,
            };
        }
    }

    let end = body_end + 1;

    // Empty body `[^]`.
    if digit_count == 0 {
        return MarkerScan::Malformed {
            end,
            reason: "empty marker body".to_string(),
        };
    }

    let digits = &bytes[body_start..body_end];

    // A leading zero covers both the lone `0` (N must be ≥ 1) and
    // non-canonical forms such as `01`.
    if digits[0] == b'0' {
        return MarkerScan::Malformed {
            end,
            reason: format!(
                "marker must be a positive integer with no leading zero, got `{}`",
                String::from_utf8_lossy(digits)
            ),
        };
    }

    // Checked arithmetic: `None` means the value does not fit in `u32`.
    // An LLM emitting `[^99999999999]` is almost certainly
    // hallucinating, so the value is reported rather than truncated.
    let parsed = digits.iter().try_fold(0u32, |acc, &d| {
        acc.checked_mul(10)?.checked_add(u32::from(d - b'0'))
    });

    match parsed {
        Some(marker) if marker >= 1 => MarkerScan::Ok { marker, end },
        // Defensive — a zero value should have been caught by the
        // leading-zero check.
        Some(_) => MarkerScan::Malformed {
            end,
            reason: "marker must be ≥ 1".to_string(),
        },
        None => MarkerScan::Malformed {
            end,
            reason: format!(
                "marker value `{}` exceeds u32::MAX",
                String::from_utf8_lossy(digits)
            ),
        },
    }
}
281
/// Reports whether offset `i` is the first byte of a line: either the
/// very start of the input or the position right after a `\n`.
fn is_line_start(bytes: &[u8], i: usize) -> bool {
    match i.checked_sub(1) {
        None => true,
        Some(prev) => bytes[prev] == b'\n',
    }
}
285
/// Starting at `i`, skips spaces and tabs and returns the offset of the
/// first byte that is neither (or the input length if none remains).
/// Newlines are not skipped, so the result stays on the same line.
fn first_non_ws_on_line(bytes: &[u8], i: usize) -> usize {
    let blanks = bytes
        .iter()
        .skip(i)
        .take_while(|&&b| b == b' ' || b == b'\t')
        .count();
    i + blanks
}
293
/// Returns the offset of the line following the one that contains `i`:
/// one past the next `\n` at or after `i`. When no newline remains the
/// result is the input length (or `i` itself if it already lies beyond
/// the input).
fn advance_to_newline(bytes: &[u8], i: usize) -> usize {
    match bytes.iter().skip(i).position(|&b| b == b'\n') {
        // Land just past the newline.
        Some(offset) => i + offset + 1,
        None => i.max(bytes.len()),
    }
}
306
/// Counts the unbroken run of `\` bytes that ends immediately before
/// offset `i`. An odd count means the byte at `i` is escaped.
fn count_preceding_backslashes(bytes: &[u8], i: usize) -> usize {
    bytes[..i]
        .iter()
        .rev()
        .take_while(|&&b| b == b'\\')
        .count()
}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `parse_citations`.
    fn parse(text: &str, n_sources: usize) -> CitationParseResult {
        parse_citations(text, n_sources)
    }

    /// Marker numbers of all citations, in emission order.
    fn marker_numbers(result: &CitationParseResult) -> Vec<u32> {
        result.citations.iter().map(|c| c.marker).collect()
    }

    /// Asserts the result holds no citation and exactly one
    /// `Malformed` warning.
    fn assert_single_malformed(result: &CitationParseResult) {
        assert!(result.citations.is_empty());
        assert_eq!(result.warnings.len(), 1);
        assert!(matches!(
            result.warnings[0].kind,
            CitationWarningKind::Malformed
        ));
    }

    // ---- well-formed markers -------------------------------------

    #[test]
    fn well_formed_single_marker() {
        let text = "Churn was driven by pricing[^1].";
        let r = parse(text, 1);
        assert_eq!(r.citations.len(), 1);
        assert!(r.warnings.is_empty());

        let only = &r.citations[0];
        assert_eq!(only.marker, 1);
        assert_eq!(only.source_index, 0);
        // The span must cover exactly the four bytes of `[^1]`.
        assert_eq!(&text[only.span.clone()], "[^1]");
    }

    #[test]
    fn well_formed_multi_digit_marker() {
        let r = parse("see [^42] and [^1234]", 1300);
        assert_eq!(marker_numbers(&r), vec![42, 1234]);
        assert!(r.warnings.is_empty());
    }

    #[test]
    fn repeated_markers_are_each_emitted() {
        let r = parse("a[^1] b[^1] c[^2]", 2);
        assert_eq!(r.citations.len(), 3);
        assert_eq!(marker_numbers(&r), vec![1, 1, 2]);
        assert!(r.warnings.is_empty());
    }

    // ---- malformed bodies ----------------------------------------

    #[test]
    fn empty_marker_body_is_malformed() {
        assert_single_malformed(&parse("a[^] b", 0));
    }

    #[test]
    fn non_digit_marker_is_malformed() {
        assert_single_malformed(&parse("see [^abc] for context", 0));
    }

    #[test]
    fn negative_looking_marker_is_malformed() {
        // `-` is not a digit, so the body is rejected.
        assert_single_malformed(&parse("nope[^-1]nope", 0));
    }

    #[test]
    fn leading_zero_marker_is_malformed() {
        assert_single_malformed(&parse("nope[^01]nope", 5));
    }

    #[test]
    fn lone_zero_marker_is_malformed() {
        let r = parse("nope[^0]nope", 5);
        assert!(r.citations.is_empty());
        assert_eq!(r.warnings.len(), 1);
    }

    // ---- numeric limits ------------------------------------------

    #[test]
    fn very_large_marker_within_u32() {
        let r = parse("see [^4294967295]", 1);
        assert_eq!(r.citations.len(), 1);
        assert_eq!(r.citations[0].marker, u32::MAX);
        // With a single source available the marker is out of range.
        assert_eq!(r.warnings.len(), 1);
        assert!(matches!(
            r.warnings[0].kind,
            CitationWarningKind::OutOfRange
        ));
    }

    #[test]
    fn marker_over_u32_is_malformed() {
        assert_single_malformed(&parse("see [^9999999999999]", 0));
    }

    // ---- escapes -------------------------------------------------

    #[test]
    fn escaped_marker_is_not_parsed() {
        let r = parse(r"literal \[^1\] in text", 1);
        assert!(r.citations.is_empty());
        assert!(r.warnings.is_empty());
    }

    #[test]
    fn double_backslash_does_not_escape() {
        // In `\\[^1]` the backslash next to `[` is itself escaped, so
        // the bracket is live and the marker parses.
        let r = parse(r"path\\[^1] continues", 1);
        assert_eq!(r.citations.len(), 1);
    }

    // ---- code fences ---------------------------------------------

    #[test]
    fn marker_inside_code_fence_is_ignored() {
        let text = "before[^1]\n```\nthe code uses [^2] internally\n```\nafter[^3]";
        let r = parse(text, 3);
        assert_eq!(marker_numbers(&r), vec![1, 3]);
        assert!(r.warnings.is_empty());
    }

    #[test]
    fn fenced_with_info_string_still_ignored() {
        let text = "head[^1]\n```rust\nlet x = [^99];\n```\ntail[^2]";
        let r = parse(text, 2);
        assert_eq!(marker_numbers(&r), vec![1, 2]);
    }

    // ---- spans and bounds ----------------------------------------

    #[test]
    fn unicode_neighbors_are_safe() {
        let text = "感谢[^1]谢谢";
        let r = parse(text, 1);
        assert_eq!(r.citations.len(), 1);
        let span = r.citations[0].span.clone();
        assert_eq!(&text[span], "[^1]");
    }

    #[test]
    fn out_of_range_emits_citation_and_warning() {
        let r = parse("see [^5] and [^1]", 2);
        assert_eq!(r.citations.len(), 2);
        assert_eq!(r.warnings.len(), 1);
        assert_eq!(r.warnings[0].kind, CitationWarningKind::OutOfRange);
        // The out-of-range citation is still returned so the caller
        // can render it as a soft error.
        let first = &r.citations[0];
        assert_eq!(first.marker, 5);
        assert_eq!(first.source_index, 4);
    }

    // ---- degenerate and adversarial input ------------------------

    #[test]
    fn empty_text_yields_empty_result() {
        let r = parse("", 0);
        assert!(r.citations.is_empty());
        assert!(r.warnings.is_empty());
    }

    #[test]
    fn no_panics_on_truncated_markers() {
        // Inputs that start like a marker but never close properly.
        // The only requirement is that parsing neither panics nor
        // allocates without bound.
        let truncated = ["[", "[^", "[^1", "[^123", "[^abc", "[^\n1]", "[^99"];
        for bad in truncated {
            let _ = parse(bad, 0);
        }
    }

    #[test]
    fn malformed_with_newline_inside_body() {
        let r = parse("see [^12\n] here", 0);
        // A newline aborts the scan, so nothing at all is emitted.
        assert!(r.citations.is_empty());
        assert!(r.warnings.is_empty());
    }

    #[test]
    fn back_to_back_markers() {
        let r = parse("[^1][^2][^3]", 3);
        assert_eq!(marker_numbers(&r), vec![1, 2, 3]);
        assert!(r.warnings.is_empty());
    }
}