1use std::ops::Range;
2
3use rassa_core::RassaResult;
4use rassa_unibreak::{BreakAnalysis, LineBreakOpportunity, WordBreakOpportunity, analyze_breaks};
5use unicode_bidi::{BidiClass, BidiInfo};
6
/// Base direction reported for a piece of analysed text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BidiDirection {
    /// No strong directional character was found; this is also the default value.
    #[default]
    Neutral,
    /// The first strong character has bidi class `L`.
    LeftToRight,
    /// The first strong character has bidi class `R` or `AL`.
    RightToLeft,
    /// NOTE(review): never produced by the code in this file; presumably a
    /// left-to-right guess based on weak evidence. Confirm against callers.
    WeakLeftToRight,
    /// NOTE(review): never produced by the code in this file; presumably a
    /// right-to-left guess based on weak evidence. Confirm against callers.
    WeakRightToLeft,
}
16
17#[derive(Clone, Debug, Default, PartialEq, Eq)]
18pub struct BidiAnalysis {
19 pub direction: BidiDirection,
20 pub visual_text: String,
21 pub logical_to_visual: Vec<usize>,
22 pub visual_to_logical: Vec<usize>,
23 pub embedding_levels: Vec<u8>,
24}
25
26#[derive(Clone, Debug, Default, PartialEq, Eq)]
27pub struct TextSegment {
28 pub text: String,
29 pub byte_range: Range<usize>,
30 pub char_range: Range<usize>,
31 pub line_breaks: Vec<LineBreakOpportunity>,
32 pub word_breaks: Vec<WordBreakOpportunity>,
33}
34
35#[derive(Clone, Debug, Default, PartialEq, Eq)]
36pub struct UnicodeAnalysis {
37 pub text: String,
38 pub break_analysis: BreakAnalysis,
39 pub bidi_analysis: BidiAnalysis,
40 pub segments: Vec<TextSegment>,
41}
42
43#[derive(Default)]
44pub struct UnicodePipeline;
45
46impl UnicodePipeline {
47 pub fn analyze_text(&self, text: &str, language: Option<&str>) -> RassaResult<UnicodeAnalysis> {
48 let break_analysis = analyze_breaks(text, language)?;
49 let bidi_analysis = analyze_bidi(text)?;
50 let segments = segment_by_mandatory_breaks(text, &break_analysis);
51
52 Ok(UnicodeAnalysis {
53 text: text.to_string(),
54 break_analysis,
55 bidi_analysis,
56 segments,
57 })
58 }
59
60 pub fn segment_text(
61 &self,
62 text: &str,
63 language: Option<&str>,
64 ) -> RassaResult<Vec<TextSegment>> {
65 Ok(self.analyze_text(text, language)?.segments)
66 }
67}
68
69pub fn analyze_bidi(text: &str) -> RassaResult<BidiAnalysis> {
70 if text.is_empty() {
71 return Ok(BidiAnalysis::default());
72 }
73
74 Ok(analyze_bidi_with_unicode_bidi(text))
75}
76
77fn analyze_bidi_with_unicode_bidi(text: &str) -> BidiAnalysis {
78 let bidi_info = BidiInfo::new(text, None);
79 let Some(paragraph) = bidi_info.paragraphs.first() else {
80 return BidiAnalysis::default();
81 };
82
83 let levels = bidi_info.reordered_levels_per_char(paragraph, paragraph.range.clone());
84 let visual_to_logical = BidiInfo::reorder_visual(&levels);
85 let mut logical_to_visual = vec![0; visual_to_logical.len()];
86 for (visual_index, logical_index) in visual_to_logical.iter().copied().enumerate() {
87 if let Some(slot) = logical_to_visual.get_mut(logical_index) {
88 *slot = visual_index;
89 }
90 }
91
92 BidiAnalysis {
93 direction: first_strong_direction(&bidi_info),
94 visual_text: bidi_info
95 .reorder_line(paragraph, paragraph.range.clone())
96 .into_owned(),
97 logical_to_visual,
98 visual_to_logical,
99 embedding_levels: levels.iter().map(|level| level.number()).collect(),
100 }
101}
102
103fn first_strong_direction(bidi_info: &BidiInfo<'_>) -> BidiDirection {
104 bidi_info
105 .original_classes
106 .iter()
107 .find_map(|class| match class {
108 BidiClass::L => Some(BidiDirection::LeftToRight),
109 BidiClass::R | BidiClass::AL => Some(BidiDirection::RightToLeft),
110 _ => None,
111 })
112 .unwrap_or(BidiDirection::Neutral)
113}
114
115fn segment_by_mandatory_breaks(text: &str, analysis: &BreakAnalysis) -> Vec<TextSegment> {
116 let mut segments = Vec::new();
117 let mut byte_start = 0;
118 let mut char_start = 0;
119 let chars = text.char_indices().collect::<Vec<_>>();
120
121 for (index, (byte_index, character)) in chars.iter().copied().enumerate() {
122 let should_break = matches!(
123 analysis.line_breaks.get(index),
124 Some(LineBreakOpportunity::Mandatory)
125 );
126 if should_break {
127 let end_byte = byte_index + character.len_utf8();
128 segments.push(build_segment(
129 text,
130 analysis,
131 byte_start,
132 end_byte,
133 char_start,
134 index + 1,
135 ));
136 byte_start = end_byte;
137 char_start = index + 1;
138 }
139 }
140
141 if char_start < chars.len() || text.is_empty() {
142 segments.push(build_segment(
143 text,
144 analysis,
145 byte_start,
146 text.len(),
147 char_start,
148 chars.len(),
149 ));
150 }
151
152 segments
153 .into_iter()
154 .filter(|segment| !segment.text.is_empty() || text.is_empty())
155 .collect()
156}
157
158fn build_segment(
159 text: &str,
160 analysis: &BreakAnalysis,
161 byte_start: usize,
162 byte_end: usize,
163 char_start: usize,
164 char_end: usize,
165) -> TextSegment {
166 TextSegment {
167 text: text[byte_start..byte_end].to_string(),
168 byte_range: byte_start..byte_end,
169 char_range: char_start..char_end,
170 line_breaks: analysis.line_breaks[char_start..char_end].to_vec(),
171 word_breaks: analysis.word_breaks[char_start..char_end].to_vec(),
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// The newline stays attached to the segment it terminates.
    #[test]
    fn splits_text_on_mandatory_breaks() {
        let segments = UnicodePipeline
            .segment_text("alpha\nbeta", Some("en"))
            .expect("unicode segmentation should succeed");

        let texts: Vec<&str> = segments
            .iter()
            .map(|segment| segment.text.as_str())
            .collect();
        assert_eq!(texts, ["alpha\n", "beta"]);
    }

    /// Every per-character field has one entry per input character.
    #[test]
    fn bidi_analysis_returns_shape_metadata() {
        let analysis = analyze_bidi("abc").expect("bidi analysis should succeed");
        let expected_len = 3;

        assert_eq!(analysis.visual_text.chars().count(), expected_len);
        assert_eq!(analysis.logical_to_visual.len(), expected_len);
        assert_eq!(analysis.visual_to_logical.len(), expected_len);
        assert_eq!(analysis.embedding_levels.len(), expected_len);
    }

    /// An RTL run inside LTR text is reversed for display.
    #[test]
    fn bidi_fallback_reorders_rtl_runs() {
        let analysis = analyze_bidi_with_unicode_bidi("abc אבג");
        let identity: Vec<usize> = (0..7).collect();

        assert_eq!(analysis.direction, BidiDirection::LeftToRight);
        assert_eq!(analysis.visual_text, "abc גבא");
        assert_ne!(analysis.logical_to_visual, identity);
        assert!(analysis.embedding_levels.iter().any(|&level| level > 0));
    }

    /// Text whose first strong character is RTL is reported as right-to-left.
    #[test]
    fn bidi_fallback_detects_rtl_paragraph_direction() {
        let analysis = analyze_bidi_with_unicode_bidi("אבג abc");
        let has_odd_level = analysis
            .embedding_levels
            .iter()
            .any(|&level| level % 2 == 1);

        assert_eq!(analysis.direction, BidiDirection::RightToLeft);
        assert_ne!(analysis.visual_text, "אבג abc");
        assert!(has_odd_level);
    }
}