1#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct ImageRef {
11 pub source: String,
13 pub kind: ImageSourceKind,
15 pub start: usize,
17 pub end: usize,
19}
20
/// Where an image marker's source points, decided from the source's prefix.
///
/// This is a fieldless enum, so it derives `Copy` (callers can pass and store it
/// by value without `.clone()`) and `Hash` (usable as a `HashSet`/`HashMap` key,
/// e.g. to group references by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ImageSourceKind {
    /// Anything not recognised as a data URI or HTTP(S) URL; treated as a local path.
    LocalFile,
    /// An inline `data:` URI.
    DataUri,
    /// An `http://` or `https://` URL.
    RemoteUrl,
}
30
31pub fn parse_image_markers(text: &str) -> Vec<ImageRef> {
33 let mut refs = Vec::new();
34 let marker_prefix = "[IMAGE:";
35 let mut search_from = 0;
36
37 while search_from < text.len() {
38 let start = match text[search_from..].find(marker_prefix) {
39 Some(pos) => search_from + pos,
40 None => break,
41 };
42
43 let content_start = start + marker_prefix.len();
44 let end = match text[content_start..].find(']') {
45 Some(pos) => content_start + pos + 1,
46 None => break, };
48
49 let source = text[content_start..end - 1].trim().to_string();
50 if !source.is_empty() {
51 let kind = classify_source(&source);
52 refs.push(ImageRef {
53 source,
54 kind,
55 start,
56 end,
57 });
58 }
59
60 search_from = end;
61 }
62
63 refs
64}
65
66pub fn strip_image_markers(text: &str) -> String {
68 let refs = parse_image_markers(text);
69 if refs.is_empty() {
70 return text.to_string();
71 }
72
73 let mut result = String::with_capacity(text.len());
74 let mut last_end = 0;
75
76 for r in &refs {
77 result.push_str(&text[last_end..r.start]);
78 last_end = r.end;
79 }
80 result.push_str(&text[last_end..]);
81
82 let cleaned: Vec<&str> = result.split_whitespace().collect();
84 cleaned.join(" ")
85}
86
87fn classify_source(source: &str) -> ImageSourceKind {
88 if source.starts_with("data:") {
89 ImageSourceKind::DataUri
90 } else if source.starts_with("http://") || source.starts_with("https://") {
91 ImageSourceKind::RemoteUrl
92 } else {
93 ImageSourceKind::LocalFile
94 }
95}
96
97pub fn validate_image_refs(
99 refs: &[ImageRef],
100 max_images: usize,
101 allow_remote_fetch: bool,
102) -> Result<(), String> {
103 if refs.len() > max_images {
104 return Err(format!(
105 "Too many images: {} (max {})",
106 refs.len(),
107 max_images
108 ));
109 }
110
111 for r in refs {
112 if r.kind == ImageSourceKind::RemoteUrl && !allow_remote_fetch {
113 return Err(format!("Remote image fetch is disabled: {}", r.source));
114 }
115 }
116
117 Ok(())
118}
119
120pub fn check_vision_support(
126 image_refs: &[ImageRef],
127 vision_support: Option<bool>,
128) -> Result<(), String> {
129 if image_refs.is_empty() {
130 return Ok(());
131 }
132
133 match vision_support {
134 Some(false) => Err(format!(
135 "Provider does not support vision, but message contains {} image(s). \
136 Set model_support_vision = true in config or remove images.",
137 image_refs.len()
138 )),
139 _ => Ok(()),
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ImageRef` by hand for the validation and vision tests.
    fn image(source: &str, kind: ImageSourceKind, start: usize, end: usize) -> ImageRef {
        ImageRef {
            source: source.into(),
            kind,
            start,
            end,
        }
    }

    /// A single local-file reference.
    fn sample_image_ref() -> ImageRef {
        image("/tmp/a.png", ImageSourceKind::LocalFile, 0, 10)
    }

    /// Parses `text` and keeps only the classified kinds, in marker order.
    fn kinds_in(text: &str) -> Vec<ImageSourceKind> {
        parse_image_markers(text)
            .into_iter()
            .map(|found| found.kind)
            .collect()
    }

    // ---- parsing ----

    #[test]
    fn parse_local_file() {
        let refs = parse_image_markers("Look at this [IMAGE:/tmp/screenshot.png] image");
        assert_eq!(refs.len(), 1);
        let first = &refs[0];
        assert_eq!(first.source, "/tmp/screenshot.png");
        assert_eq!(first.kind, ImageSourceKind::LocalFile);
    }

    #[test]
    fn parse_data_uri() {
        assert_eq!(
            kinds_in("[IMAGE:data:image/png;base64,iVBOR...]"),
            vec![ImageSourceKind::DataUri]
        );
    }

    #[test]
    fn parse_remote_url() {
        let refs = parse_image_markers("[IMAGE:https://example.com/photo.jpg]");
        assert_eq!(refs.len(), 1);
        let first = &refs[0];
        assert_eq!(first.source, "https://example.com/photo.jpg");
        assert_eq!(first.kind, ImageSourceKind::RemoteUrl);
    }

    #[test]
    fn parse_multiple_markers() {
        let text = "Compare [IMAGE:/a.png] with [IMAGE:/b.png] and [IMAGE:https://c.jpg]";
        assert_eq!(
            kinds_in(text),
            vec![
                ImageSourceKind::LocalFile,
                ImageSourceKind::LocalFile,
                ImageSourceKind::RemoteUrl,
            ]
        );
    }

    #[test]
    fn parse_no_markers() {
        assert!(parse_image_markers("Just a normal message").is_empty());
    }

    #[test]
    fn parse_unclosed_marker_skipped() {
        assert!(parse_image_markers("[IMAGE:/broken").is_empty());
    }

    #[test]
    fn parse_empty_marker_skipped() {
        assert!(parse_image_markers("[IMAGE:]").is_empty());
    }

    // ---- stripping ----

    #[test]
    fn strip_markers_removes_images() {
        assert_eq!(
            strip_image_markers("Look at [IMAGE:/tmp/a.png] this image"),
            "Look at this image"
        );
    }

    #[test]
    fn strip_no_markers_unchanged() {
        assert_eq!(strip_image_markers("hello world"), "hello world");
    }

    // ---- validation ----

    #[test]
    fn validate_too_many_images() {
        let refs = [
            image("/a.png", ImageSourceKind::LocalFile, 0, 10),
            image("/b.png", ImageSourceKind::LocalFile, 20, 30),
        ];
        assert!(validate_image_refs(&refs, 1, true).is_err());
        assert!(validate_image_refs(&refs, 2, true).is_ok());
    }

    #[test]
    fn validate_remote_fetch_disabled() {
        let refs = [image(
            "https://example.com/a.png",
            ImageSourceKind::RemoteUrl,
            0,
            30,
        )];
        assert!(validate_image_refs(&refs, 4, false).is_err());
        assert!(validate_image_refs(&refs, 4, true).is_ok());
    }

    #[test]
    fn validate_local_files_always_ok() {
        let refs = [sample_image_ref()];
        assert!(validate_image_refs(&refs, 4, false).is_ok());
    }

    // ---- vision capability ----

    #[test]
    fn vision_check_no_images_always_ok() {
        for support in [Some(false), Some(true), None] {
            assert!(check_vision_support(&[], support).is_ok());
        }
    }

    #[test]
    fn vision_check_explicit_false_rejects_images() {
        let refs = [sample_image_ref()];
        let err = check_vision_support(&refs, Some(false)).unwrap_err();
        assert!(err.contains("does not support vision"));
        assert!(err.contains("1 image(s)"));
    }

    #[test]
    fn vision_check_explicit_true_allows_images() {
        let refs = [sample_image_ref()];
        assert!(check_vision_support(&refs, Some(true)).is_ok());
    }

    #[test]
    fn vision_check_none_allows_images() {
        let refs = [sample_image_ref()];
        assert!(check_vision_support(&refs, None).is_ok());
    }
}