Skip to main content

bamboo_server/
message_hooks.rs

//! Message preflight hooks.
//!
//! These hooks run before requests are forwarded upstream (proxy endpoints) and before
//! the agent loop starts. They operate on the internal `bamboo_agent_core::Message` type,
//! so the same behavior applies across OpenAI-compatible, Anthropic, Gemini, and agent routes.
6
7use crate::app_state::AppState;
8use bamboo_agent_core::Message;
9use bamboo_infrastructure::Config;
10
11#[derive(Debug, thiserror::Error)]
12pub enum HookError {
13    #[error("Invalid hook configuration: {0}")]
14    InvalidConfig(String),
15    #[error("Request not supported: {0}")]
16    Unsupported(String),
17}
18
19/// Apply all configured preflight hooks.
20pub async fn apply_message_preflight_hooks(
21    state: Option<&AppState>,
22    config: &Config,
23    _model: &str,
24    messages: &mut [Message],
25) -> Result<(), HookError> {
26    apply_image_fallback_hook(state, config, messages).await
27}
28
29async fn apply_image_fallback_hook(
30    state: Option<&AppState>,
31    config: &Config,
32    messages: &mut [Message],
33) -> Result<(), HookError> {
34    let hook_cfg = &config.hooks.image_fallback;
35    if !hook_cfg.enabled {
36        return Ok(());
37    }
38
39    let mode = hook_cfg.mode.trim().to_ascii_lowercase();
40    let fallback_mode = match mode.as_str() {
41        "placeholder" => bamboo_engine::ImageFallbackMode::Placeholder,
42        "error" => bamboo_engine::ImageFallbackMode::Error,
43        "ocr" => bamboo_engine::ImageFallbackMode::Ocr,
44        _ => {
45            return Err(HookError::InvalidConfig(format!(
46                "hooks.image_fallback.mode must be 'placeholder', 'error', or 'ocr' (got '{mode}')"
47            )));
48        }
49    };
50
51    let fallback = bamboo_engine::ImageFallbackConfig {
52        mode: fallback_mode,
53        vision_model: None,
54    };
55
56    let attachment_reader: Option<&dyn bamboo_agent_core::storage::AttachmentReader> = state
57        .map(|s| s.session_store.as_ref() as &dyn bamboo_agent_core::storage::AttachmentReader);
58
59    bamboo_engine::runtime::runner::image_fallback::apply_image_fallback_to_llm_messages(
60        messages,
61        fallback,
62        attachment_reader,
63        None,
64    )
65    .await
66    .map_err(|e| HookError::Unsupported(e.to_string()))
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_infrastructure::models::{ContentPart, ImageUrl};
    use tempfile::TempDir;

    /// Inline PNG data URL used by the image tests.
    const PNG_DATA_URL: &str = "data:image/png;base64,AAAABBBBCCCC";
    /// Base64 payload of [`PNG_DATA_URL`]; it must never leak into rewritten text.
    const PNG_PAYLOAD: &str = "AAAABBBBCCCC";

    /// Builds a config rooted in a fresh temp dir with the image-fallback hook enabled
    /// in `mode`.
    ///
    /// The [`TempDir`] guard is returned alongside the config and must be bound for the
    /// whole test (e.g. `let (cfg, _dir) = …`). Dropping the guard deletes the directory,
    /// so dropping it inside this helper would leave `cfg` pointing at a removed path.
    fn base_config(mode: &str) -> (Config, TempDir) {
        let dir = TempDir::new().expect("tempdir");
        let mut cfg = Config::from_data_dir(Some(dir.path().to_path_buf()));
        cfg.hooks.image_fallback.enabled = true;
        cfg.hooks.image_fallback.mode = mode.to_string();
        (cfg, dir)
    }

    /// A plain text content part.
    fn text_part(text: &str) -> ContentPart {
        ContentPart::Text {
            text: text.to_string(),
        }
    }

    /// An image content part referencing `url` (data URL or remote URL).
    fn image_part(url: &str) -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: url.to_string(),
                detail: None,
            },
        }
    }

    /// A single user message with `content` as its text and `parts` as its content parts.
    fn user_message(content: &str, parts: Vec<ContentPart>) -> Message {
        Message::user_with_parts(content, parts.into_iter().map(Into::into).collect())
    }

    #[tokio::test]
    async fn image_fallback_placeholder_rewrites_images_to_text_without_leaking_data() {
        let (cfg, _dir) = base_config("placeholder");
        let mut messages = vec![user_message(
            "What is in this image?",
            vec![text_part("What is in this image?"), image_part(PNG_DATA_URL)],
        )];

        apply_message_preflight_hooks(None, &cfg, "m", &mut messages)
            .await
            .expect("hook ok");

        assert!(messages[0].content.contains("Image omitted: image/png"));
        assert!(!messages[0].content.contains(PNG_PAYLOAD));
        assert!(messages[0].content_parts.is_none());
    }

    #[tokio::test]
    async fn image_fallback_mode_is_trimmed_and_case_insensitive() {
        let (cfg, _dir) = base_config("  Placeholder ");
        let mut messages = vec![user_message(
            "hi",
            vec![text_part("hi"), image_part(PNG_DATA_URL)],
        )];

        apply_message_preflight_hooks(None, &cfg, "m", &mut messages)
            .await
            .expect("normalised mode should be accepted");

        assert!(messages[0].content.contains("Image omitted: image/png"));
        assert!(messages[0].content_parts.is_none());
    }

    #[tokio::test]
    async fn image_fallback_error_rejects_requests_with_images() {
        let (cfg, _dir) = base_config("error");
        let mut messages = vec![user_message(
            "",
            vec![image_part("https://example.com/image.png")],
        )];

        let err = apply_message_preflight_hooks(None, &cfg, "m", &mut messages)
            .await
            .expect_err("should err");
        assert!(err
            .to_string()
            .contains("does not currently support image inputs"));
    }

    #[tokio::test]
    async fn image_fallback_invalid_mode_errors() {
        let (cfg, _dir) = base_config("wat");
        let mut messages: Vec<Message> = Vec::new();
        let err = apply_message_preflight_hooks(None, &cfg, "m", &mut messages)
            .await
            .expect_err("should err");
        assert!(matches!(err, HookError::InvalidConfig(_)));
    }

    #[tokio::test]
    async fn image_fallback_disabled_is_noop_even_with_invalid_mode() {
        let (mut cfg, _dir) = base_config("wat");
        cfg.hooks.image_fallback.enabled = false;
        let mut messages = vec![user_message(
            "hi",
            vec![text_part("hi"), image_part(PNG_DATA_URL)],
        )];

        // The disabled check runs before mode validation, so an invalid mode is ignored.
        apply_message_preflight_hooks(None, &cfg, "m", &mut messages)
            .await
            .expect("disabled hook must not validate or rewrite");

        assert!(messages[0].content_parts.is_some());
        assert!(messages[0].content.contains("hi"));
    }

    #[cfg(not(windows))]
    #[tokio::test]
    async fn image_fallback_ocr_non_windows_leaves_images_intact() {
        let (cfg, _dir) = base_config("ocr");
        let mut messages = vec![user_message(
            "hi",
            vec![text_part("hi"), image_part(PNG_DATA_URL)],
        )];

        apply_message_preflight_hooks(None, &cfg, "m", &mut messages)
            .await
            .expect("hook ok");

        assert!(messages[0].content_parts.is_some());
        assert!(messages[0].content.contains("hi"));
    }
}