Skip to main content

atomcode_core/
vision_preprocessor.rs

//! VL-model image preprocessor.
//!
//! When the active main provider does not accept images and the user submits
//! an image, this module routes the image (plus the current-turn caption only)
//! through a configurable vision-language (VL) provider. It returns a textual
//! description that callers splice into the user message before forwarding
//! that message to the main provider as plain text.
//!
//! Key invariant: the VL call NEVER sees the main conversation history. The
//! `Vec<Message>` passed to the VL provider is constructed locally from
//! `caption + images` and contains exactly one user turn.
12
13use crate::config::Config;
14use crate::conversation::message::{ImagePart, Message, MessageContent, Role};
15use crate::provider::{create_provider, model_name_suggests_vision, LlmProvider};
16use futures::StreamExt;
17
/// Outcome of a preprocessing attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PreprocessOutcome {
    /// Preprocessing did not run. This happens when no images are attached,
    /// when the main provider's model already appears to accept images, or
    /// when the feature is disabled. Caller must use the original
    /// `(caption, images)` tuple unchanged.
    Skipped,
    /// VL call succeeded.
    ///
    /// - `text` is the VL output with surrounding whitespace trimmed, and
    ///   nothing else wrapped around it.
    /// - `vl_key` is the `config.providers` key of the VL provider that was
    ///   used. It is a config key, not necessarily the model name, and lets
    ///   the caller attribute the description in the splice wrapper.
    ///
    /// Caller is responsible for splicing both into the user message and
    /// clearing the images vec. Recommended shape, where the bracketed label
    /// reads "image content (recognized by {vl_key})":
    /// `format!("{caption}\n\n[图片内容(由 {vl_key} 识别)]\n{text}")`
    Replaced { text: String, vl_key: String },
    /// VL call failed. Causes include a missing provider key, a provider
    /// build error, a stream or network error, an idle timeout, or an empty
    /// response. `reason` is intended for `AgentEvent::Warning`.
    ///
    /// Caller should append `"\n\n[图片识别失败]"` ("image recognition
    /// failed") to the user message and clear images, so the turn proceeds
    /// with a useful placeholder.
    Failed { reason: String },
}
38
39/// Decide whether and how to preprocess images before a main-provider turn.
40///
41/// Short-circuit order (each → `Skipped`, except the last):
42/// 1. `images` is empty.
43/// 2. The active provider's model name passes the `model_name_suggests_vision`
44///    heuristic (it can handle the image natively).
45/// 3. `config.vision_preprocessor_provider` is `None` or `Some("")`.
46/// 4. The configured key is missing from `config.providers` → `Failed` (this
47///    is a configuration mistake worth surfacing, not a silent skip).
48pub async fn maybe_preprocess(
49    config: &Config,
50    active_provider: &dyn LlmProvider,
51    caption: &str,
52    images: &[ImagePart],
53) -> PreprocessOutcome {
54    if images.is_empty() {
55        return PreprocessOutcome::Skipped;
56    }
57    if model_name_suggests_vision(active_provider.model_name()) {
58        return PreprocessOutcome::Skipped;
59    }
60    let vl_key = match config.vision_preprocessor_provider.as_deref() {
61        Some(k) if !k.is_empty() => k,
62        _ => return PreprocessOutcome::Skipped,
63    };
64    let vl_cfg = match config.providers.get(vl_key) {
65        Some(c) => c.clone(),
66        None => {
67            return PreprocessOutcome::Failed {
68                reason: format!("VL provider '{vl_key}' not found in config.providers"),
69            };
70        }
71    };
72
73    // Build a one-off VL provider. `create_provider` handles auth-token
74    // loading (api_key=None) for the AtomGit gateway case.
75    let vl_provider = match create_provider(&vl_cfg) {
76        Ok(p) => p,
77        Err(e) => {
78            return PreprocessOutcome::Failed {
79                reason: format!("VL provider build failed: {e:#}"),
80            };
81        }
82    };
83
84    let prompt = if caption.trim().is_empty() {
85        "请详细描述这张图片的内容。如果是代码、报错截图或终端输出,请逐字转录文本。"
86            .to_string()
87    } else {
88        format!(
89            "用户的当前请求:{caption}\n\n请详细描述这张图片的内容。如果是代码、\
90             报错截图或终端输出,请逐字转录文本。",
91        )
92    };
93
94    // Local one-shot conversation — explicitly NOT linked to the main
95    // `agent.conversation.messages`. This is the structural guarantee that
96    // VL only sees the current image + caption, never history.
97    let messages = vec![Message {
98        role: Role::User,
99        content: MessageContent::MultiPart {
100            text: Some(prompt),
101            images: images.to_vec(),
102        },
103    }];
104
105    // Idle (no-progress) timeout, NOT wall-clock. A VL call can take any
106    // total duration as long as the stream keeps producing chunks — we only
107    // abort when no event has arrived for `IDLE_TIMEOUT`. The previous 30s
108    // wall-clock killed perfectly healthy slow gateways: a Qwen3-VL cold
109    // start can spend 10-15s on TTFT, then another 10-20s OCR-ing a dense
110    // screenshot, easily clearing 30s end-to-end while streaming the whole
111    // way through. Idle-timeout still catches genuinely stuck sockets
112    // (gateway accepted the request, holds the connection, never produces
113    // tokens) — that's the failure mode worth aborting on.
114    const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
115
116    let mut stream = match vl_provider.chat_stream(&messages, None) {
117        Ok(s) => s,
118        Err(e) => {
119            return PreprocessOutcome::Failed {
120                reason: format!("provider '{vl_key}' stream init failed: {e:#}"),
121            };
122        }
123    };
124
125    let mut buf = String::new();
126    loop {
127        let next = match tokio::time::timeout(IDLE_TIMEOUT, stream.next()).await {
128            Ok(n) => n,
129            Err(_) => {
130                return PreprocessOutcome::Failed {
131                    reason: format!(
132                        "provider '{vl_key}' no progress for {}s",
133                        IDLE_TIMEOUT.as_secs(),
134                    ),
135                };
136            }
137        };
138        let event = match next {
139            None => break,
140            Some(Ok(ev)) => ev,
141            Some(Err(e)) => {
142                return PreprocessOutcome::Failed {
143                    reason: format!("provider '{vl_key}' call error: {e:#}"),
144                };
145            }
146        };
147        match event {
148            crate::stream::StreamEvent::Delta(s) => buf.push_str(&s),
149            crate::stream::StreamEvent::Reasoning(_) => {}
150            crate::stream::StreamEvent::Done { .. } => break,
151            crate::stream::StreamEvent::Error(e) => {
152                return PreprocessOutcome::Failed {
153                    reason: format!("provider '{vl_key}' call error: {e}"),
154                };
155            }
156            // VL is a one-shot OCR call — Warnings (e.g., proxy truncation
157            // heuristics) and Usage stats are not actionable for the user
158            // here; tool-call variants don't apply because we pass `None`
159            // for tools. Drop them.
160            crate::stream::StreamEvent::Warning(_)
161            | crate::stream::StreamEvent::Usage(_)
162            | crate::stream::StreamEvent::ThinkingBlock { .. }
163            | crate::stream::StreamEvent::ToolCallStart { .. }
164            | crate::stream::StreamEvent::ToolCallDelta(_)
165            | crate::stream::StreamEvent::ToolCallDone(_) => {}
166        }
167    }
168
169    let trimmed = buf.trim();
170    if trimmed.is_empty() {
171        PreprocessOutcome::Failed {
172            reason: format!("provider '{vl_key}' returned empty response"),
173        }
174    } else {
175        PreprocessOutcome::Replaced {
176            text: trimmed.to_string(),
177            vl_key: vl_key.to_string(),
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::ProviderConfig;
    use crate::stream::StreamEvent;
    use crate::tool::ToolDef;
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::Stream;
    use std::collections::HashMap;
    use std::pin::Pin;
    use wiremock::matchers::{method, path};
    use wiremock::{Match, Mock, MockServer, ResponseTemplate};

    fn blank_config() -> Config {
        // Mirrors `coding_plan::setup::tests::blank_config`, but is kept local
        // so this test module does not reach into another module's private
        // test helpers. If new mandatory fields are added to Config, update
        // both.
        Config {
            default_provider: String::new(),
            default_workdir: None,
            providers: HashMap::new(),
            datalog: Default::default(),
            auto_update: true,
            notifications: Default::default(),
            telemetry: Default::default(),
            lsp: Default::default(),
            auto_commit: false,
            subagent: Default::default(),
            vision_preprocessor_provider: None,
            language: None,
            ui: Default::default(),
            plugin: Default::default(),
        }
    }

    fn sample_image() -> ImagePart {
        ImagePart {
            media_type: "image/png".into(),
            data: "iVBORw0KGgoAAAANSUhEUg==".into(),
        }
    }

    /// Stub `LlmProvider` that only carries a model name. `chat_stream` is
    /// never called in short-circuit tests, but the trait requires the impl.
    struct StubProvider {
        model: &'static str,
    }

    #[async_trait]
    impl LlmProvider for StubProvider {
        fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: Option<&[ToolDef]>,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
            anyhow::bail!("stub never streams");
        }
        fn model_name(&self) -> &str {
            self.model
        }
    }

    #[tokio::test]
    async fn skipped_when_no_images() {
        let cfg = blank_config();
        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result = maybe_preprocess(&cfg, &provider, "any caption", &[]).await;
        assert!(matches!(result, PreprocessOutcome::Skipped));
    }

    #[tokio::test]
    async fn skipped_when_main_provider_accepts_images() {
        // Point the preprocessor at a key that is NOT registered in
        // `providers`. If the vision heuristic did not short-circuit, the
        // call would return `Failed { .. "not found" }` instead of `Skipped`.
        // So this assertion actually exercises the heuristic branch rather
        // than the "feature disabled" branch.
        let mut cfg = blank_config();
        cfg.vision_preprocessor_provider = Some("vl-not-registered".into());
        let provider = StubProvider { model: "claude-sonnet-4-5" };
        let result =
            maybe_preprocess(&cfg, &provider, "describe", &[sample_image()]).await;
        assert!(
            matches!(result, PreprocessOutcome::Skipped),
            "vision-capable main model must skip preprocessing, got {result:?}",
        );
    }

    #[tokio::test]
    async fn skipped_when_config_field_unset() {
        let cfg = blank_config();
        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result =
            maybe_preprocess(&cfg, &provider, "describe", &[sample_image()]).await;
        assert!(matches!(result, PreprocessOutcome::Skipped));
    }

    #[tokio::test]
    async fn skipped_when_config_field_empty_string() {
        let mut cfg = blank_config();
        cfg.vision_preprocessor_provider = Some(String::new());
        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result =
            maybe_preprocess(&cfg, &provider, "describe", &[sample_image()]).await;
        assert!(matches!(result, PreprocessOutcome::Skipped));
    }

    #[tokio::test]
    async fn failed_when_configured_key_missing_from_providers() {
        let mut cfg = blank_config();
        cfg.vision_preprocessor_provider = Some("AtomGit-NoSuchModel".into());
        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result =
            maybe_preprocess(&cfg, &provider, "describe", &[sample_image()]).await;
        match result {
            PreprocessOutcome::Failed { reason } => {
                assert!(
                    reason.contains("AtomGit-NoSuchModel") && reason.contains("not found"),
                    "expected 'not found' for missing key, got: {reason}",
                );
            }
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    /// Minimal SSE chunk fixture for an OpenAI-compatible /chat/completions
    /// endpoint. It returns one `delta.content` token, then a stop chunk,
    /// then `[DONE]`. Mirrors the wire shape `OpenAiProvider` consumes.
    fn sse_one_token(text: &str) -> String {
        let chunk = serde_json::json!({
            "choices": [{
                "delta": { "content": text },
                "finish_reason": null,
            }],
        });
        let done = serde_json::json!({
            "choices": [{
                "delta": {},
                "finish_reason": "stop",
            }],
        });
        format!("data: {}\n\ndata: {}\n\ndata: [DONE]\n\n", chunk, done)
    }

    fn vl_provider_cfg(base_url: &str) -> ProviderConfig {
        ProviderConfig {
            provider_type: "openai".into(),
            api_key: Some("sk-test".into()),
            model: "Qwen/Qwen3-VL-32B-Instruct".into(),
            base_url: Some(base_url.to_string()),
            system_prompt: None,
            user_agent: None,
            context_window: 8000,
            max_tokens: None,
            thinking_type: None,
            thinking_keep: None,
            reasoning_history: None,
            thinking_enabled: None,
            thinking_budget: None,
            skip_tls_verify: false,
            ephemeral: false,
        }
    }

    #[tokio::test]
    async fn replaced_when_vl_returns_text() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_one_token(
                        "Python stack trace showing ZeroDivisionError on line 42",
                    )),
            )
            .expect(1)
            .mount(&server)
            .await;

        let mut cfg = blank_config();
        cfg.providers.insert(
            "vl".into(),
            vl_provider_cfg(&server.uri()),
        );
        cfg.vision_preprocessor_provider = Some("vl".into());

        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result =
            maybe_preprocess(&cfg, &provider, "explain this", &[sample_image()]).await;

        match result {
            PreprocessOutcome::Replaced { text, vl_key } => {
                assert_eq!(
                    text,
                    "Python stack trace showing ZeroDivisionError on line 42"
                );
                assert_eq!(vl_key, "vl", "Replaced must carry the configured key");
            }
            other => panic!("expected Replaced, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn failed_when_vl_returns_500() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500).set_body_string("upstream error"))
            // The existing OpenAI provider may retry per its
            // retry::RetryPolicy. Don't pin .expect(N); just assert the
            // eventual outcome.
            .mount(&server)
            .await;

        let mut cfg = blank_config();
        cfg.providers.insert(
            "vl".into(),
            vl_provider_cfg(&format!("{}/", server.uri())),
        );
        cfg.vision_preprocessor_provider = Some("vl".into());

        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result =
            maybe_preprocess(&cfg, &provider, "x", &[sample_image()]).await;

        match result {
            PreprocessOutcome::Failed { reason } => {
                // `maybe_preprocess` reports stream failures as
                // "provider '<key>' call error: ...". The underlying error
                // may additionally mention the HTTP status.
                assert!(
                    reason.contains("call error") || reason.contains("500"),
                    "expected error reason mentioning failure, got: {reason}",
                );
            }
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn failed_when_vl_returns_empty_string() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_one_token("")), // empty token then [DONE]
            )
            .mount(&server)
            .await;

        let mut cfg = blank_config();
        cfg.providers.insert(
            "vl".into(),
            vl_provider_cfg(&format!("{}/", server.uri())),
        );
        cfg.vision_preprocessor_provider = Some("vl".into());

        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result =
            maybe_preprocess(&cfg, &provider, "x", &[sample_image()]).await;

        match result {
            PreprocessOutcome::Failed { reason } => {
                assert!(
                    reason.contains("empty"),
                    "expected 'empty' in reason, got: {reason}",
                );
            }
            other => panic!("expected Failed for empty response, got {other:?}"),
        }
    }

    /// Custom matcher for request body containing a substring.
    struct BodyContains(String);
    impl Match for BodyContains {
        fn matches(&self, req: &wiremock::Request) -> bool {
            String::from_utf8_lossy(&req.body).contains(&self.0)
        }
    }

    /// Inverse of `BodyContains`: matches when the request body does NOT
    /// include the substring. Pairs with `BodyContains` to assert that one
    /// prompt template was selected and the other was not.
    struct BodyNotContains(String);
    impl Match for BodyNotContains {
        fn matches(&self, request: &wiremock::Request) -> bool {
            !String::from_utf8_lossy(&request.body).contains(&self.0)
        }
    }

    #[tokio::test]
    async fn caption_is_included_in_vl_prompt() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(BodyContains("用户的当前请求:解释这段代码".into()))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_one_token("ok")),
            )
            .expect(1)
            .mount(&server)
            .await;

        let mut cfg = blank_config();
        cfg.providers.insert(
            "vl".into(),
            vl_provider_cfg(&format!("{}/", server.uri())),
        );
        cfg.vision_preprocessor_provider = Some("vl".into());

        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result = maybe_preprocess(
            &cfg,
            &provider,
            "解释这段代码",
            &[sample_image()],
        )
        .await;

        // Replaced confirms the body matched the caption pattern. Otherwise
        // wiremock would reject the request and the call would fail.
        assert!(matches!(result, PreprocessOutcome::Replaced { .. }));
    }

    #[tokio::test]
    async fn empty_caption_uses_pure_describe_prompt() {
        let server = MockServer::start().await;
        // Pure describe prompt: must NOT contain the "用户的当前请求:" prefix.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(BodyContains("请详细描述这张图片的内容".into()))
            .and(BodyNotContains("用户的当前请求:".into()))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_one_token("ok")),
            )
            .expect(1)
            .mount(&server)
            .await;

        let mut cfg = blank_config();
        cfg.providers.insert(
            "vl".into(),
            vl_provider_cfg(&format!("{}/", server.uri())),
        );
        cfg.vision_preprocessor_provider = Some("vl".into());

        let provider = StubProvider { model: "deepseek-v4-flash" };
        let result = maybe_preprocess(&cfg, &provider, "  ", &[sample_image()]).await;

        assert!(matches!(result, PreprocessOutcome::Replaced { .. }));
    }
}