Skip to main content

rs_adk/tools/
example_tool.rs

//! Example tool — adds few-shot examples to LLM requests.
//!
//! Mirrors ADK-Python's `ExampleTool`: it enriches LLM requests by
//! injecting example conversations into the request's system instruction.
5
use std::fmt::Write as _;

use crate::llm::LlmRequest;
7
/// A single few-shot example: one user turn paired with the model reply
/// the agent is expected to produce for it.
///
/// Besides `Debug`/`Clone`, this derives `PartialEq`, `Eq` and `Hash` so
/// examples can be compared, deduplicated, or stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Example {
    /// The user input.
    pub input: String,
    /// The expected model output.
    pub output: String,
}

impl Example {
    /// Create a new example from any string-like input and output.
    pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
        Self {
            input: input.into(),
            output: output.into(),
        }
    }
}
26
27/// Trait for providing examples dynamically.
28pub trait ExampleProvider: Send + Sync {
29    /// Get the examples to inject.
30    fn examples(&self) -> Vec<Example>;
31}
32
33/// Tool that adds few-shot examples to the LLM request.
34///
35/// This is not a callable tool — it modifies the LLM request to include
36/// example conversations that guide the model's behavior.
37#[derive(Debug, Clone)]
38pub struct ExampleTool {
39    examples: Vec<Example>,
40}
41
42impl ExampleTool {
43    /// Create a new example tool with static examples.
44    pub fn new(examples: Vec<Example>) -> Self {
45        Self { examples }
46    }
47
48    /// Create from an example provider.
49    pub fn from_provider(provider: &dyn ExampleProvider) -> Self {
50        Self {
51            examples: provider.examples(),
52        }
53    }
54
55    /// Add example instructions to the LLM request.
56    ///
57    /// Appends the examples to the system instruction as formatted
58    /// input/output pairs.
59    pub fn process_llm_request(&self, request: &mut LlmRequest) {
60        if self.examples.is_empty() {
61            return;
62        }
63
64        let mut example_text = String::from("\n\nHere are some examples of expected behavior:\n");
65        for (i, example) in self.examples.iter().enumerate() {
66            example_text.push_str(&format!(
67                "\nExample {}:\nUser: {}\nAssistant: {}\n",
68                i + 1,
69                example.input,
70                example.output
71            ));
72        }
73
74        if let Some(ref mut instruction) = request.system_instruction {
75            instruction.push_str(&example_text);
76        } else {
77            request.system_instruction = Some(example_text);
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn injects_examples() {
        let tool = ExampleTool::new(vec![
            Example::new("What is 2+2?", "4"),
            Example::new("What color is the sky?", "Blue"),
        ]);

        let mut request = LlmRequest::default();
        request.system_instruction = Some("You are helpful.".into());

        tool.process_llm_request(&mut request);
        let instruction = request.system_instruction.unwrap();

        // The caller's instruction is kept, and each example is rendered as
        // a complete User/Assistant pair.
        assert!(instruction.starts_with("You are helpful."));
        assert!(instruction.contains("User: What is 2+2?\nAssistant: 4\n"));
        assert!(instruction.contains("User: What color is the sky?\nAssistant: Blue\n"));

        // Examples are numbered in the order they were supplied.
        let first = instruction.find("Example 1:").expect("first example missing");
        let second = instruction.find("Example 2:").expect("second example missing");
        assert!(first < second);
    }

    #[test]
    fn empty_examples_noop() {
        let tool = ExampleTool::new(vec![]);

        let mut request = LlmRequest::default();
        request.system_instruction = Some("Original".into());
        tool.process_llm_request(&mut request);
        assert_eq!(request.system_instruction.unwrap(), "Original");

        // Whatever the default instruction is, an empty tool must not touch it.
        let mut untouched = LlmRequest::default();
        let before = untouched.system_instruction.clone();
        tool.process_llm_request(&mut untouched);
        assert_eq!(untouched.system_instruction, before);
    }

    #[test]
    fn creates_instruction_if_none() {
        let tool = ExampleTool::new(vec![Example::new("Hi", "Hello!")]);
        let mut request = LlmRequest::default();

        tool.process_llm_request(&mut request);
        let instruction = request
            .system_instruction
            .expect("an instruction should be created");
        assert!(instruction.contains("Example 1:\nUser: Hi\nAssistant: Hello!"));
    }

    struct StaticProvider;
    impl ExampleProvider for StaticProvider {
        fn examples(&self) -> Vec<Example> {
            vec![Example::new("test", "response")]
        }
    }

    #[test]
    fn from_provider() {
        let tool = ExampleTool::from_provider(&StaticProvider);
        assert_eq!(tool.examples.len(), 1);
        assert_eq!(tool.examples[0].input, "test");
        assert_eq!(tool.examples[0].output, "response");
    }

    #[test]
    fn provider_examples_are_injected() {
        let tool = ExampleTool::from_provider(&StaticProvider);
        let mut request = LlmRequest::default();

        tool.process_llm_request(&mut request);
        let instruction = request
            .system_instruction
            .expect("an instruction should be created");
        assert!(instruction.contains("User: test\nAssistant: response"));
    }
}