rlx_locateanything/
parse.rs1use std::collections::HashSet;
19
/// A bounding box described by two corner points.
///
/// When produced by [`parse_boxes`] or [`parse_ref_boxes`] the coordinates are
/// in pixels of the image whose dimensions were passed to the parser. The
/// parsers do not reorder or clamp the corners, so `x1 <= x2` / `y1 <= y2` is
/// not guaranteed.
///
/// The type is four plain `f32`s, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ParsedBox {
    /// Horizontal coordinate of the first corner.
    pub x1: f32,
    /// Vertical coordinate of the first corner.
    pub y1: f32,
    /// Horizontal coordinate of the second corner.
    pub x2: f32,
    /// Vertical coordinate of the second corner.
    pub y2: f32,
}
28
/// A single 2-D point.
///
/// When produced by [`parse_points`] the coordinates are in pixels of the
/// image whose dimensions were passed to the parser.
///
/// The type is two plain `f32`s, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ParsedPoint {
    /// Horizontal coordinate.
    pub x: f32,
    /// Vertical coordinate.
    pub y: f32,
}
35
/// Extent of the normalised coordinate grid used in the answer text: a parsed
/// value `n` is mapped to `n / COORD_SCALE * extent` pixels, where `extent` is
/// the image width (for x) or height (for y).
const COORD_SCALE: f32 = 1_000.0;
38
39pub fn parse_boxes(answer: &str, image_width: u32, image_height: u32) -> Vec<ParsedBox> {
41 let w = image_width as f32;
42 let h = image_height as f32;
43 let mut out = Vec::new();
44 let mut rest = answer;
45 while let Some(start) = rest.find("<box><") {
46 let after = &rest[start + 6..];
47 let Some(end) = after.find("></box>") else {
48 break;
49 };
50 let inner = &after[..end];
51 let nums: Vec<u32> = inner
52 .split("><")
53 .filter_map(|s| s.trim_matches(|c| c == '<' || c == '>').parse().ok())
54 .collect();
55 if nums.len() == 4 {
56 let (x1, y1, x2, y2) = (nums[0], nums[1], nums[2], nums[3]);
57 out.push(ParsedBox {
58 x1: x1 as f32 / COORD_SCALE * w,
59 y1: y1 as f32 / COORD_SCALE * h,
60 x2: x2 as f32 / COORD_SCALE * w,
61 y2: y2 as f32 / COORD_SCALE * h,
62 });
63 }
64 rest = &after[end + 7..];
65 }
66 out
67}
68
69pub fn parse_points(answer: &str, image_width: u32, image_height: u32) -> Vec<ParsedPoint> {
71 let w = image_width as f32;
72 let h = image_height as f32;
73 let mut out = Vec::new();
74 let mut rest = answer;
75 while let Some(start) = rest.find("<box><") {
76 let after = &rest[start + 6..];
77 let Some(end) = after.find("></box>") else {
78 break;
79 };
80 let inner = &after[..end];
81 let nums: Vec<u32> = inner
82 .split("><")
83 .filter_map(|s| s.trim_matches(|c| c == '<' || c == '>').parse().ok())
84 .collect();
85 if nums.len() == 2 {
86 out.push(ParsedPoint {
87 x: nums[0] as f32 / COORD_SCALE * w,
88 y: nums[1] as f32 / COORD_SCALE * h,
89 });
90 }
91 rest = &after[end + 7..];
92 }
93 out
94}
95
/// Collects the text of every `<ref>…</ref>` span in `answer`, in order.
///
/// The text between the tags is returned verbatim (no trimming). Scanning
/// stops at the first `<ref>` that has no matching `</ref>` after it.
pub fn parse_refs(answer: &str) -> Vec<String> {
    let mut labels = Vec::new();
    let mut tail = answer;
    while let Some((_, after_open)) = tail.split_once("<ref>") {
        let Some((label, remainder)) = after_open.split_once("</ref>") else {
            break;
        };
        labels.push(label.to_owned());
        tail = remainder;
    }
    labels
}
110
/// Collects the integers written as `<n>` tokens inside `s`, in order.
///
/// A token is the text between a `<` and the next `>`; it is whitespace-trimmed
/// and kept only if it parses as `u32` and is at most 1000. Every other token
/// is skipped silently. Scanning stops at a `<` with no `>` after it.
fn bracket_coord_nums(s: &str) -> Vec<u32> {
    let mut coords = Vec::new();
    let mut tail = s;
    while let Some((_, after_lt)) = tail.split_once('<') {
        let Some((token, remainder)) = after_lt.split_once('>') else {
            break;
        };
        match token.trim().parse::<u32>() {
            Ok(value) if value <= 1000 => coords.push(value),
            _ => {}
        }
        tail = remainder;
    }
    coords
}
129
130pub fn parse_ref_boxes(answer: &str, image_width: u32, image_height: u32) -> Vec<ParsedBox> {
132 let w = image_width as f32;
133 let h = image_height as f32;
134 let mut out = Vec::new();
135 let mut rest = answer;
136 while let Some(start) = rest.find("<box>") {
137 let after = &rest[start + 5..];
138 let end_tag = after.find("</box>").or_else(|| after.find("></box>"));
139 let Some(end) = end_tag else {
140 break;
141 };
142 let inner = &after[..end];
143 let nums = bracket_coord_nums(inner);
144 for chunk in nums.chunks(4).filter(|c| c.len() == 4) {
145 let (x1, y1, x2, y2) = (chunk[0], chunk[1], chunk[2], chunk[3]);
146 out.push(ParsedBox {
147 x1: x1 as f32 / COORD_SCALE * w,
148 y1: y1 as f32 / COORD_SCALE * h,
149 x2: x2 as f32 / COORD_SCALE * w,
150 y2: y2 as f32 / COORD_SCALE * h,
151 });
152 }
153 rest = &after[end..];
154 }
155 out
156}
157
158#[derive(Debug, Clone, Default)]
160pub struct GroundingParse {
161 pub text: String,
162 pub raw: String,
163 pub refs: Vec<String>,
164 pub boxes: Vec<ParsedBox>,
165 pub points: Vec<ParsedPoint>,
166 pub prompt_len: usize,
167 pub new_tokens: usize,
168}
169
170pub fn parse_grounding(answer: &str, image_width: u32, image_height: u32) -> GroundingParse {
172 let refs = parse_refs(answer);
173 let mut boxes: Vec<ParsedBox> = parse_boxes(answer, image_width, image_height);
174 boxes.extend(parse_ref_boxes(answer, image_width, image_height));
175 dedupe_boxes(&mut boxes);
176 let points = parse_points(answer, image_width, image_height);
177 GroundingParse {
178 text: answer.to_string(),
179 raw: String::new(),
180 refs,
181 boxes,
182 points,
183 ..Default::default()
184 }
185}
186
187fn dedupe_boxes(boxes: &mut Vec<ParsedBox>) {
188 let mut seen = HashSet::new();
189 boxes.retain(|b| {
190 let key = (
191 (b.x1 * 10.0).round() as i32,
192 (b.y1 * 10.0).round() as i32,
193 (b.x2 * 10.0).round() as i32,
194 (b.y2 * 10.0).round() as i32,
195 );
196 seen.insert(key)
197 });
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing rescaled pixel coordinates.
    const EPS: f32 = 1e-3;

    #[test]
    fn parse_box_and_point() {
        // One four-value group (a box) followed by one two-value group (a point).
        let answer = "<box><100><200><300><400></box> <box><500><600></box>";

        let found_boxes = parse_boxes(answer, 1000, 800);
        assert_eq!(found_boxes.len(), 1);
        assert!((found_boxes[0].x1 - 100.0).abs() < EPS);

        let found_points = parse_points(answer, 1000, 800);
        assert_eq!(found_points.len(), 1);
        assert!((found_points[0].x - 500.0).abs() < EPS);
    }

    #[test]
    fn parse_ref_and_ref_boxes() {
        let answer = "<ref>bus</ref><box><100><200><300><400></box>";
        assert_eq!(parse_refs(answer), vec!["bus"]);

        // The same box is seen by both box parsers and must be reported once.
        let grounding = parse_grounding(answer, 1000, 800);
        assert_eq!(grounding.refs, vec!["bus"]);
        assert_eq!(grounding.boxes.len(), 1);
    }
}