Skip to main content

llmsdk_provider/middleware/builtin/
add_tool_input_examples.rs

//! Append `input_examples` to each function tool's description so the
//! examples still reach the model as plain text when the provider does not
//! forward the structured field.
//!
//! Mirrors `@ai-sdk/ai/src/middleware/add-tool-input-examples-middleware.ts`.
// Rust guideline compliant 2026-02-21
6
7use async_trait::async_trait;
8
9use crate::error::Result;
10use crate::language_model::{CallOptions, FunctionTool, LanguageModel, Tool};
11use crate::middleware::language_model::{CallKind, LanguageModelMiddleware};
12
/// Middleware that serializes `tool.input_examples` (if any) and appends them
/// to the tool's `description` field.
///
/// Default layout mirrors `@ai-sdk/ai/src/middleware/add-tool-input-examples-middleware.ts`:
/// `"{description}\n\n{prefix}\n{example_1}\n{example_2}..."` where `prefix`
/// defaults to `"Input Examples:"` and each example is `JSON.stringify(example.input)`
/// (no enumeration prefix). When the tool has no description (or an empty
/// one), the description becomes just `"{prefix}\n{example_1}..."`.
///
/// Override with [`Self::with_prefix`] to customise the header line,
/// [`Self::with_formatter`] to control how each example line is rendered, or
/// [`Self::with_remove`] to keep the structured `input_examples` field.
pub struct AddToolInputExamplesMiddleware {
    // Header line written on its own line before the rendered examples.
    prefix: String,
    // Renders one example (plus its zero-based index) into one output line.
    formatter: ExampleFormatter,
    // When `true`, `input_examples` is set to `None` after being folded into
    // the description.
    remove: bool,
}
26
/// Boxed formatter invoked once per [`crate::language_model::ToolInputExample`],
/// receiving the example and its zero-based index, and returning the text
/// that represents that example in the description. Mirrors upstream
/// `(example, index) => string` signature
/// (`@ai-sdk/ai/src/middleware/add-tool-input-examples-middleware.ts:46`).
///
/// The `Send + Sync` bounds are part of the alias; presumably they let the
/// middleware be shared behind `Arc<dyn LanguageModelMiddleware>` — confirm
/// against the trait's own bounds.
type ExampleFormatter =
    Box<dyn Fn(&crate::language_model::ToolInputExample, usize) -> String + Send + Sync>;
33
34impl std::fmt::Debug for AddToolInputExamplesMiddleware {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        // `formatter` is a boxed closure with no useful Debug representation;
37        // mark non-exhaustive instead of dumping a function pointer address.
38        f.debug_struct("AddToolInputExamplesMiddleware")
39            .field("prefix", &self.prefix)
40            .field("remove", &self.remove)
41            .finish_non_exhaustive()
42    }
43}
44
impl Default for AddToolInputExamplesMiddleware {
    /// Delegates to [`AddToolInputExamplesMiddleware::new`], so the default
    /// value carries the same prefix, formatter and `remove` setting.
    fn default() -> Self {
        Self::new()
    }
}
50
51impl AddToolInputExamplesMiddleware {
52    /// Build with the upstream-aligned defaults.
53    ///
54    /// The default `prefix` is `"Input Examples:"` and the default `remove` is
55    /// `true`, matching upstream `add-tool-input-examples-middleware.ts` so the
56    /// rewritten tool no longer carries the now-redundant `input_examples`
57    /// field on the wire.
58    #[must_use]
59    pub fn new() -> Self {
60        Self {
61            prefix: "Input Examples:".to_owned(),
62            formatter: Box::new(default_formatter),
63            remove: true,
64        }
65    }
66
67    /// Override the header line prepended before the serialized examples.
68    /// Mirrors upstream `prefix` option (default `"Input Examples:"`).
69    #[must_use]
70    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
71        self.prefix = prefix.into();
72        self
73    }
74
75    /// Override how each example is rendered. The formatter receives the
76    /// example and its zero-based index, mirroring upstream
77    /// `(example, index) => string`
78    /// (`@ai-sdk/ai/src/middleware/add-tool-input-examples-middleware.ts:46`).
79    #[must_use]
80    pub fn with_formatter<F>(mut self, formatter: F) -> Self
81    where
82        F: Fn(&crate::language_model::ToolInputExample, usize) -> String + Send + Sync + 'static,
83    {
84        self.formatter = Box::new(formatter);
85        self
86    }
87
88    /// Toggle whether `input_examples` is cleared after being appended.
89    ///
90    /// Mirrors upstream `remove?: boolean` option (default `true`). When
91    /// `true`, the rewritten function tool drops its `input_examples` so the
92    /// downstream provider does not re-serialize them on the wire after they
93    /// have already been folded into `description`. Set to `false` to keep
94    /// the structured field alongside the textual description.
95    #[must_use]
96    pub fn with_remove(mut self, remove: bool) -> Self {
97        self.remove = remove;
98        self
99    }
100}
101
102fn default_formatter(example: &crate::language_model::ToolInputExample, _index: usize) -> String {
103    // Mirrors upstream `defaultFormatExample = (example) => JSON.stringify(example.input)`
104    // (`add-tool-input-examples-middleware.ts:1-3`). Index is unused for the
105    // default but exposed so custom formatters can prepend ordinals.
106    serde_json::to_string(&example.input).unwrap_or_else(|_| "<unserializable>".to_owned())
107}
108
109#[async_trait]
110impl LanguageModelMiddleware for AddToolInputExamplesMiddleware {
111    async fn transform_params(
112        &self,
113        _kind: CallKind,
114        mut params: CallOptions,
115        _inner: &dyn LanguageModel,
116    ) -> Result<CallOptions> {
117        let Some(tools) = params.tools.as_mut() else {
118            return Ok(params);
119        };
120        for tool in tools.iter_mut() {
121            if let Tool::Function(FunctionTool {
122                description,
123                input_examples,
124                ..
125            }) = tool
126            {
127                let Some(examples) = input_examples.as_ref() else {
128                    continue;
129                };
130                if examples.is_empty() {
131                    continue;
132                }
133                // Mirrors upstream `add-tool-input-examples-middleware.ts:67-72`:
134                //   formattedExamples = examples.map((ex, i) => format(ex, i)).join('\n')
135                //   examplesSection   = `${prefix}\n${formattedExamples}`
136                //   description       = description ? `${description}\n\n${examplesSection}` : examplesSection
137                let formatted = examples
138                    .iter()
139                    .enumerate()
140                    .map(|(i, ex)| (self.formatter)(ex, i))
141                    .collect::<Vec<_>>()
142                    .join("\n");
143                let examples_section = format!("{}\n{formatted}", self.prefix);
144                *description = Some(match description.take() {
145                    Some(existing) if !existing.is_empty() => {
146                        format!("{existing}\n\n{examples_section}")
147                    }
148                    _ => examples_section,
149                });
150                // Mirrors upstream `add-tool-input-examples-middleware.ts:80`:
151                //   `inputExamples: remove ? undefined : tool.inputExamples`.
152                // Default `remove = true` strips the structured field so the
153                // downstream provider does not re-serialize examples that are
154                // already embedded in the textual description.
155                if self.remove {
156                    *input_examples = None;
157                }
158            }
159        }
160        Ok(params)
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::language_model::{GenerateResult, Prompt, StreamResult, ToolInputExample};
    use crate::middleware::wrap_language_model;
    use async_trait::async_trait;

    /// Shared slot holding the last `CallOptions` the recorder model received.
    #[derive(Debug, Default)]
    struct LastParams(std::sync::Mutex<Option<CallOptions>>);

    /// Stub model that records the (already transformed) call options and
    /// returns an empty successful result.
    #[derive(Debug)]
    struct Recorder(Arc<LastParams>);

    #[async_trait]
    impl LanguageModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }
        fn model_id(&self) -> &'static str {
            "rec"
        }
        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            *self.0.0.lock().expect("mutex") = Some(options);
            Ok(GenerateResult {
                content: vec![],
                finish_reason: crate::language_model::FinishReason::new(
                    crate::language_model::FinishReasonKind::Stop,
                ),
                usage: crate::language_model::Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            // Streaming is not exercised by these tests.
            unimplemented!()
        }
    }

    /// Build the `get_weather` function tool used by every test, carrying a
    /// single input example `{"city": <city>}`.
    fn weather_tool(description: Option<&str>, city: &str) -> Tool {
        Tool::Function(FunctionTool {
            name: "get_weather".into(),
            description: description.map(str::to_owned),
            input_schema: serde_json::from_value(serde_json::json!({"type": "object"}))
                .unwrap(),
            input_examples: Some(vec![ToolInputExample {
                input: serde_json::json!({"city": city})
                    .as_object()
                    .cloned()
                    .unwrap(),
            }]),
            strict: None,
            provider_options: None,
        })
    }

    /// Send `tool` through `middleware` wrapped around a [`Recorder`] and
    /// return the function tool exactly as the inner model received it.
    async fn transformed_tool(
        middleware: AddToolInputExamplesMiddleware,
        tool: Tool,
    ) -> FunctionTool {
        let last = Arc::new(LastParams::default());
        let inner: Arc<dyn LanguageModel> = Arc::new(Recorder(Arc::clone(&last)));
        let wrapped = wrap_language_model(
            inner,
            [Arc::new(middleware) as Arc<dyn LanguageModelMiddleware>],
        );

        wrapped
            .do_generate(CallOptions {
                prompt: Prompt::default(),
                tools: Some(vec![tool]),
                ..Default::default()
            })
            .await
            .expect("generate");

        let captured = last.0.lock().expect("mutex").clone().expect("params");
        let tools = captured.tools.expect("tools are forwarded to the model");
        let Some(Tool::Function(function_tool)) = tools.into_iter().next() else {
            panic!("expected function tool");
        };
        function_tool
    }

    #[tokio::test]
    async fn appends_examples_to_description() {
        let tool = transformed_tool(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(Some("Get weather"), "Tokyo"),
        )
        .await;

        // Pin the full documented layout rather than loose substrings, so a
        // regression in the prefix or separators is caught.
        assert_eq!(
            tool.description.as_deref(),
            Some("Get weather\n\nInput Examples:\n{\"city\":\"Tokyo\"}"),
            "description follows `{{description}}\\n\\n{{prefix}}\\n{{examples}}`",
        );
        // Mirrors upstream default `remove = true` — the structured field is
        // dropped after being folded into `description`.
        assert!(
            tool.input_examples.is_none(),
            "default remove=true strips input_examples",
        );
    }

    #[tokio::test]
    async fn with_remove_false_keeps_input_examples() {
        let tool = transformed_tool(
            AddToolInputExamplesMiddleware::new().with_remove(false),
            weather_tool(Some("Get weather"), "Paris"),
        )
        .await;

        assert!(
            tool.input_examples.as_ref().is_some_and(|v| v.len() == 1),
            "with_remove(false) preserves input_examples",
        );
        // Keeping the structured field must not suppress the textual copy.
        assert_eq!(
            tool.description.as_deref(),
            Some("Get weather\n\nInput Examples:\n{\"city\":\"Paris\"}"),
            "examples are still appended when remove=false",
        );
    }

    #[tokio::test]
    async fn uses_examples_section_alone_when_description_missing() {
        let tool = transformed_tool(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(None, "Tokyo"),
        )
        .await;

        assert_eq!(
            tool.description.as_deref(),
            Some("Input Examples:\n{\"city\":\"Tokyo\"}"),
            "no leading blank lines when there is no original description",
        );
    }

    #[tokio::test]
    async fn honours_custom_prefix_and_formatter() {
        let middleware = AddToolInputExamplesMiddleware::new()
            .with_prefix("Samples:")
            .with_formatter(|example, index| {
                let json = serde_json::to_string(&example.input).expect("serializable example");
                format!("{}. {json}", index + 1)
            });

        let tool = transformed_tool(middleware, weather_tool(Some("Get weather"), "Tokyo")).await;

        assert_eq!(
            tool.description.as_deref(),
            Some("Get weather\n\nSamples:\n1. {\"city\":\"Tokyo\"}"),
            "custom prefix and formatter (with zero-based index) are applied",
        );
    }
}