use async_trait::async_trait;
use crate::error::Result;
use crate::language_model::{CallOptions, FunctionTool, LanguageModel, Tool};
use crate::middleware::language_model::{CallKind, LanguageModelMiddleware};
/// Middleware that folds each function tool's input examples into its description text.
///
/// For every function tool carrying a non-empty example list, the examples are rendered
/// with `formatter`, joined with newlines underneath `prefix`, and appended to the tool's
/// description before the call options are passed on.
pub struct AddToolInputExamplesMiddleware {
/// Header line written immediately before the rendered examples.
prefix: String,
/// Callback that renders one example, given the example and its zero-based position.
formatter: ExampleFormatter,
/// When `true`, a tool's `input_examples` are cleared once they have been appended.
remove: bool,
}
/// Boxed `Send + Sync` callback that turns one tool input example, together with its
/// zero-based index within the tool's example list, into the text written for it.
type ExampleFormatter =
Box<dyn Fn(&crate::language_model::ToolInputExample, usize) -> String + Send + Sync>;
impl std::fmt::Debug for AddToolInputExamplesMiddleware {
    /// Prints `prefix` and `remove`; the boxed formatter closure has no `Debug`
    /// representation, so the output is marked as non-exhaustive instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring makes the compiler flag this impl whenever a field is added.
        let Self {
            prefix,
            formatter: _,
            remove,
        } = self;

        let mut builder = f.debug_struct("AddToolInputExamplesMiddleware");
        builder.field("prefix", prefix);
        builder.field("remove", remove);
        builder.finish_non_exhaustive()
    }
}
impl Default for AddToolInputExamplesMiddleware {
/// Equivalent to [`AddToolInputExamplesMiddleware::new`].
fn default() -> Self {
Self::new()
}
}
impl AddToolInputExamplesMiddleware {
#[must_use]
pub fn new() -> Self {
Self {
prefix: "Input Examples:".to_owned(),
formatter: Box::new(default_formatter),
remove: true,
}
}
#[must_use]
pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
self.prefix = prefix.into();
self
}
#[must_use]
pub fn with_formatter<F>(mut self, formatter: F) -> Self
where
F: Fn(&crate::language_model::ToolInputExample, usize) -> String + Send + Sync + 'static,
{
self.formatter = Box::new(formatter);
self
}
#[must_use]
pub fn with_remove(mut self, remove: bool) -> Self {
self.remove = remove;
self
}
}
/// Built-in example formatter: serializes the example's `input` as a JSON string.
///
/// The index is ignored. If serialization fails, the placeholder
/// `<unserializable>` is returned instead of propagating the error.
fn default_formatter(example: &crate::language_model::ToolInputExample, _index: usize) -> String {
    match serde_json::to_string(&example.input) {
        Ok(json) => json,
        Err(_) => String::from("<unserializable>"),
    }
}
#[async_trait]
impl LanguageModelMiddleware for AddToolInputExamplesMiddleware {
    /// Appends each function tool's input examples to that tool's description.
    ///
    /// Call options without tools are returned untouched. Non-function tools, and
    /// function tools whose example list is absent or empty, are skipped entirely
    /// (their `input_examples` are left as they were). For every other function tool:
    ///
    /// 1. each example is rendered with the configured formatter and the results
    ///    are joined with newlines under the configured prefix;
    /// 2. that section is appended to a non-empty description after a blank line,
    ///    or becomes the whole description when there was none (or it was empty);
    /// 3. `input_examples` is cleared when `remove` is enabled.
    ///
    /// This implementation always returns `Ok`.
    async fn transform_params(
        &self,
        _kind: CallKind,
        mut params: CallOptions,
        _inner: &dyn LanguageModel,
    ) -> Result<CallOptions> {
        if let Some(tool_list) = params.tools.as_mut() {
            for entry in tool_list.iter_mut() {
                let Tool::Function(function_tool) = entry else {
                    continue;
                };

                // Render first; this only borrows the examples, so the tool can be
                // mutated freely afterwards.
                let rendered: Vec<String> = match function_tool.input_examples.as_ref() {
                    Some(samples) if !samples.is_empty() => samples
                        .iter()
                        .enumerate()
                        .map(|(position, sample)| (self.formatter)(sample, position))
                        .collect(),
                    _ => continue,
                };

                let section = format!("{}\n{}", self.prefix, rendered.join("\n"));
                let combined = match function_tool.description.take() {
                    Some(current) if !current.is_empty() => {
                        format!("{current}\n\n{section}")
                    }
                    _ => section,
                };
                function_tool.description = Some(combined);

                if self.remove {
                    function_tool.input_examples = None;
                }
            }
        }

        Ok(params)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::language_model::{GenerateResult, Prompt, StreamResult, ToolInputExample};
    use crate::middleware::wrap_language_model;
    use async_trait::async_trait;

    /// Slot that stores the call options most recently seen by [`Recorder`].
    #[derive(Debug, Default)]
    struct LastParams(std::sync::Mutex<Option<CallOptions>>);

    /// Fake model that records the options it is called with and returns an
    /// empty, successful generation result.
    #[derive(Debug)]
    struct Recorder(Arc<LastParams>);

    #[async_trait]
    impl LanguageModel for Recorder {
        fn provider(&self) -> &'static str {
            "rec"
        }

        fn model_id(&self) -> &'static str {
            "rec"
        }

        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            *self.0.0.lock().expect("mutex") = Some(options);
            Ok(GenerateResult {
                content: vec![],
                finish_reason: crate::language_model::FinishReason::new(
                    crate::language_model::FinishReasonKind::Stop,
                ),
                usage: crate::language_model::Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }

        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            unimplemented!("streaming is not exercised by these tests")
        }
    }

    /// Builds a `get_weather` function tool with one `{"city": ...}` input
    /// example per entry in `cities`.
    fn weather_tool(description: Option<&str>, cities: &[&str]) -> Tool {
        Tool::Function(FunctionTool {
            name: "get_weather".into(),
            description: description.map(Into::into),
            input_schema: serde_json::from_value(serde_json::json!({"type": "object"}))
                .unwrap(),
            input_examples: Some(
                cities
                    .iter()
                    .map(|&city| ToolInputExample {
                        input: serde_json::json!({"city": city})
                            .as_object()
                            .cloned()
                            .unwrap(),
                    })
                    .collect(),
            ),
            strict: None,
            provider_options: None,
        })
    }

    /// Sends `tool` through `middleware` wrapped around a [`Recorder`] and
    /// returns the function tool exactly as the inner model received it.
    async fn transformed_tool(
        middleware: AddToolInputExamplesMiddleware,
        tool: Tool,
    ) -> FunctionTool {
        let last = Arc::new(LastParams::default());
        let inner: Arc<dyn LanguageModel> = Arc::new(Recorder(Arc::clone(&last)));
        let wrapped = wrap_language_model(
            inner,
            [Arc::new(middleware) as Arc<dyn LanguageModelMiddleware>],
        );
        wrapped
            .do_generate(CallOptions {
                prompt: Prompt::default(),
                tools: Some(vec![tool]),
                ..Default::default()
            })
            .await
            .expect("generate");

        // Bind the clone to a local so the mutex guard is released before returning.
        let captured = last.0.lock().expect("mutex").clone().expect("params");
        let forwarded = captured
            .tools
            .expect("tools forwarded")
            .into_iter()
            .next()
            .expect("one tool forwarded");
        let Tool::Function(function_tool) = forwarded else {
            panic!("expected function tool");
        };
        function_tool
    }

    #[tokio::test]
    async fn appends_examples_to_description() {
        let tool = transformed_tool(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(Some("Get weather"), &["Tokyo"]),
        )
        .await;

        // Exact match pins the prefix, the separators and the JSON rendering.
        assert_eq!(
            tool.description.as_deref(),
            Some("Get weather\n\nInput Examples:\n{\"city\":\"Tokyo\"}"),
            "appends the prefixed examples after the original description",
        );
        assert!(
            tool.input_examples.is_none(),
            "default remove=true strips input_examples",
        );
    }

    #[tokio::test]
    async fn with_remove_false_keeps_input_examples() {
        let tool = transformed_tool(
            AddToolInputExamplesMiddleware::new().with_remove(false),
            weather_tool(Some("Get weather"), &["Paris"]),
        )
        .await;

        assert!(
            tool.input_examples.as_ref().is_some_and(|v| v.len() == 1),
            "with_remove(false) preserves input_examples",
        );
        assert_eq!(
            tool.description.as_deref(),
            Some("Get weather\n\nInput Examples:\n{\"city\":\"Paris\"}"),
            "description is still augmented when examples are kept",
        );
    }

    #[tokio::test]
    async fn missing_or_empty_description_becomes_examples_section() {
        let expected = Some("Input Examples:\n{\"city\":\"Tokyo\"}");

        let without_description = transformed_tool(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(None, &["Tokyo"]),
        )
        .await;
        assert_eq!(without_description.description.as_deref(), expected);

        let with_empty_description = transformed_tool(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(Some(""), &["Tokyo"]),
        )
        .await;
        assert_eq!(with_empty_description.description.as_deref(), expected);
    }

    #[tokio::test]
    async fn custom_prefix_and_formatter_are_applied_per_example() {
        let middleware = AddToolInputExamplesMiddleware::new()
            .with_prefix("Samples:")
            .with_formatter(|_example, index| format!("example {index}"));

        let tool = transformed_tool(
            middleware,
            weather_tool(Some("Get weather"), &["Tokyo", "Paris"]),
        )
        .await;

        assert_eq!(
            tool.description.as_deref(),
            Some("Get weather\n\nSamples:\nexample 0\nexample 1"),
        );
    }

    #[tokio::test]
    async fn empty_example_list_leaves_description_untouched() {
        let tool = transformed_tool(
            AddToolInputExamplesMiddleware::new(),
            weather_tool(Some("Get weather"), &[]),
        )
        .await;

        assert_eq!(tool.description.as_deref(), Some("Get weather"));
    }
}