llmsdk_provider/middleware/builtin/
add_tool_input_examples.rs1use async_trait::async_trait;
8
9use crate::error::Result;
10use crate::language_model::{CallOptions, FunctionTool, LanguageModel, Tool};
11use crate::middleware::language_model::{CallKind, LanguageModelMiddleware};
12
13pub struct AddToolInputExamplesMiddleware {
22 prefix: String,
23 formatter: ExampleFormatter,
24 remove: bool,
25}
26
27type ExampleFormatter =
32 Box<dyn Fn(&crate::language_model::ToolInputExample, usize) -> String + Send + Sync>;
33
34impl std::fmt::Debug for AddToolInputExamplesMiddleware {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("AddToolInputExamplesMiddleware")
39 .field("prefix", &self.prefix)
40 .field("remove", &self.remove)
41 .finish_non_exhaustive()
42 }
43}
44
45impl Default for AddToolInputExamplesMiddleware {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl AddToolInputExamplesMiddleware {
52 #[must_use]
59 pub fn new() -> Self {
60 Self {
61 prefix: "Input Examples:".to_owned(),
62 formatter: Box::new(default_formatter),
63 remove: true,
64 }
65 }
66
67 #[must_use]
70 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
71 self.prefix = prefix.into();
72 self
73 }
74
75 #[must_use]
80 pub fn with_formatter<F>(mut self, formatter: F) -> Self
81 where
82 F: Fn(&crate::language_model::ToolInputExample, usize) -> String + Send + Sync + 'static,
83 {
84 self.formatter = Box::new(formatter);
85 self
86 }
87
88 #[must_use]
96 pub fn with_remove(mut self, remove: bool) -> Self {
97 self.remove = remove;
98 self
99 }
100}
101
102fn default_formatter(example: &crate::language_model::ToolInputExample, _index: usize) -> String {
103 serde_json::to_string(&example.input).unwrap_or_else(|_| "<unserializable>".to_owned())
107}
108
109#[async_trait]
110impl LanguageModelMiddleware for AddToolInputExamplesMiddleware {
111 async fn transform_params(
112 &self,
113 _kind: CallKind,
114 mut params: CallOptions,
115 _inner: &dyn LanguageModel,
116 ) -> Result<CallOptions> {
117 let Some(tools) = params.tools.as_mut() else {
118 return Ok(params);
119 };
120 for tool in tools.iter_mut() {
121 if let Tool::Function(FunctionTool {
122 description,
123 input_examples,
124 ..
125 }) = tool
126 {
127 let Some(examples) = input_examples.as_ref() else {
128 continue;
129 };
130 if examples.is_empty() {
131 continue;
132 }
133 let formatted = examples
138 .iter()
139 .enumerate()
140 .map(|(i, ex)| (self.formatter)(ex, i))
141 .collect::<Vec<_>>()
142 .join("\n");
143 let examples_section = format!("{}\n{formatted}", self.prefix);
144 *description = Some(match description.take() {
145 Some(existing) if !existing.is_empty() => {
146 format!("{existing}\n\n{examples_section}")
147 }
148 _ => examples_section,
149 });
150 if self.remove {
156 *input_examples = None;
157 }
158 }
159 }
160 Ok(params)
161 }
162}
163
#[cfg(test)]
mod tests {
    // Behavioral tests for `AddToolInputExamplesMiddleware`, driven through
    // `wrap_language_model` with an inner model that records what it receives.

    use std::sync::Arc;

    use super::*;
    use crate::language_model::{GenerateResult, Prompt, StreamResult, ToolInputExample};
    use crate::middleware::wrap_language_model;
    use async_trait::async_trait;

    /// Slot holding the `CallOptions` most recently seen by `Recorder`.
    #[derive(Debug, Default)]
    struct LastParams(std::sync::Mutex<Option<CallOptions>>);

    /// Inner model that stores its call options and returns an empty result.
    #[derive(Debug)]
    struct Recorder(Arc<LastParams>);

    #[async_trait]
    impl LanguageModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }
        fn model_id(&self) -> &'static str {
            "rec"
        }
        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            *self.0.0.lock().expect("mutex") = Some(options);
            Ok(GenerateResult {
                content: vec![],
                finish_reason: crate::language_model::FinishReason::new(
                    crate::language_model::FinishReasonKind::Stop,
                ),
                usage: crate::language_model::Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            unimplemented!()
        }
    }

    /// Builds a `get_weather` function tool with the given description and
    /// example inputs. Each example value must be a JSON object.
    fn weather_tool(
        description: Option<&str>,
        examples: Option<Vec<serde_json::Value>>,
    ) -> FunctionTool {
        FunctionTool {
            name: "get_weather".into(),
            description: description.map(Into::into),
            input_schema: serde_json::from_value(serde_json::json!({"type": "object"}))
                .unwrap(),
            input_examples: examples.map(|values| {
                values
                    .into_iter()
                    .map(|value| ToolInputExample {
                        input: value
                            .as_object()
                            .cloned()
                            .expect("example must be a JSON object"),
                    })
                    .collect()
            }),
            strict: None,
            provider_options: None,
        }
    }

    /// Sends `tool` through `middleware` and returns the tool exactly as the
    /// inner model received it.
    async fn run(middleware: AddToolInputExamplesMiddleware, tool: FunctionTool) -> FunctionTool {
        let last = Arc::new(LastParams::default());
        let inner: Arc<dyn LanguageModel> = Arc::new(Recorder(Arc::clone(&last)));
        let wrapped = wrap_language_model(
            inner,
            [Arc::new(middleware) as Arc<dyn LanguageModelMiddleware>],
        );

        wrapped
            .do_generate(CallOptions {
                prompt: Prompt::default(),
                tools: Some(vec![Tool::Function(tool)]),
                ..Default::default()
            })
            .await
            .expect("generate");

        let captured = last.0.lock().expect("mutex").take().expect("params");
        let Some(Tool::Function(f)) = captured.tools.and_then(|t| t.into_iter().next()) else {
            panic!("expected function tool");
        };
        f
    }

    #[tokio::test]
    async fn appends_examples_to_description() {
        let f = run(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(
                Some("Get weather"),
                Some(vec![serde_json::json!({"city": "Tokyo"})]),
            ),
        )
        .await;

        // Original text, blank line, prefix, then one compact-JSON example per line.
        assert_eq!(
            f.description.as_deref(),
            Some("Get weather\n\nInput Examples:\n{\"city\":\"Tokyo\"}"),
        );
        assert!(
            f.input_examples.is_none(),
            "default remove=true strips input_examples",
        );
    }

    #[tokio::test]
    async fn with_remove_false_keeps_input_examples() {
        let f = run(
            AddToolInputExamplesMiddleware::new().with_remove(false),
            weather_tool(
                Some("Get weather"),
                Some(vec![serde_json::json!({"city": "Paris"})]),
            ),
        )
        .await;

        assert!(
            f.input_examples.as_ref().is_some_and(|v| v.len() == 1),
            "with_remove(false) preserves input_examples",
        );
        assert!(
            f.description
                .as_deref()
                .is_some_and(|d| d.ends_with("{\"city\":\"Paris\"}")),
            "examples are still rendered into the description",
        );
    }

    #[tokio::test]
    async fn missing_or_empty_description_becomes_examples_section() {
        for description in [None, Some("")] {
            let f = run(
                AddToolInputExamplesMiddleware::new(),
                weather_tool(description, Some(vec![serde_json::json!({"city": "Oslo"})])),
            )
            .await;

            assert_eq!(
                f.description.as_deref(),
                Some("Input Examples:\n{\"city\":\"Oslo\"}"),
                "no leading blank line when there is no original description",
            );
        }
    }

    #[tokio::test]
    async fn missing_or_empty_examples_leave_tool_untouched() {
        for (examples, expected_len) in [(None, None), (Some(vec![]), Some(0))] {
            let f = run(
                AddToolInputExamplesMiddleware::new(),
                weather_tool(Some("Get weather"), examples),
            )
            .await;

            assert_eq!(f.description.as_deref(), Some("Get weather"));
            assert_eq!(
                f.input_examples.map(|v| v.len()),
                expected_len,
                "tools without examples are passed through as-is",
            );
        }
    }

    #[tokio::test]
    async fn custom_prefix_and_formatter_see_every_example_in_order() {
        let middleware = AddToolInputExamplesMiddleware::new()
            .with_prefix("Examples")
            .with_formatter(|example, index| {
                let city = example
                    .input
                    .get("city")
                    .and_then(serde_json::Value::as_str)
                    .unwrap_or("?");
                format!("{index}:{city}")
            });

        let f = run(
            middleware,
            weather_tool(
                Some("Get weather"),
                Some(vec![
                    serde_json::json!({"city": "Tokyo"}),
                    serde_json::json!({"city": "Paris"}),
                ]),
            ),
        )
        .await;

        assert_eq!(
            f.description.as_deref(),
            Some("Get weather\n\nExamples\n0:Tokyo\n1:Paris"),
        );
    }
}