use crate::llm::LlmRequest;
/// A single input/output pair demonstrating expected assistant behavior.
///
/// When injected by [`ExampleTool`], `input` is rendered after `User:` and
/// `output` after `Assistant:`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Example {
    /// Text shown as the user's turn of the example.
    pub input: String,
    /// Text shown as the assistant's turn of the example.
    pub output: String,
}

impl Example {
    /// Creates an example from anything convertible into a `String`
    /// (e.g. `&str` or `String`).
    pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
        Self {
            input: input.into(),
            output: output.into(),
        }
    }
}
/// A source of [`Example`]s, for instance to build an [`ExampleTool`] via
/// [`ExampleTool::from_provider`].
///
/// Implementors must be both `Send` and `Sync`.
pub trait ExampleProvider
where
    Self: Send + Sync,
{
    /// Returns the examples supplied by this provider, as an owned list.
    fn examples(&self) -> Vec<Example>;
}
/// Injects few-shot [`Example`]s into an LLM request's system instruction.
///
/// `ExampleTool::default()` holds no examples, in which case processing a
/// request leaves it unchanged.
#[derive(Debug, Clone, Default)]
pub struct ExampleTool {
    /// Examples in the order they will be rendered (numbered from 1).
    examples: Vec<Example>,
}
impl ExampleTool {
    /// Creates a tool that injects `examples`, in the given order.
    pub fn new(examples: Vec<Example>) -> Self {
        Self { examples }
    }

    /// Creates a tool from the examples returned by `provider`.
    ///
    /// `provider.examples()` is called once, here; the result is stored, so
    /// later changes in the provider are not reflected in this tool.
    pub fn from_provider(provider: &dyn ExampleProvider) -> Self {
        Self {
            examples: provider.examples(),
        }
    }

    /// Appends the configured examples to `request.system_instruction`.
    ///
    /// Does nothing when the tool has no examples. Otherwise the examples are
    /// rendered as numbered `User:` / `Assistant:` pairs under a short header.
    /// If the request already has a non-empty instruction, the block is
    /// appended after a blank line. If not, the instruction is created (or
    /// filled) without leading blank lines.
    pub fn process_llm_request(&self, request: &mut LlmRequest) {
        use std::fmt::Write as _;

        if self.examples.is_empty() {
            return;
        }

        let instruction = request.system_instruction.get_or_insert_with(String::new);
        // Only separate from pre-existing text; a fresh instruction should not
        // start with stray newlines.
        if !instruction.is_empty() {
            instruction.push_str("\n\n");
        }
        instruction.push_str("Here are some examples of expected behavior:\n");

        for (i, example) in self.examples.iter().enumerate() {
            // Format directly into the buffer instead of allocating a temporary
            // String per example.
            write!(
                instruction,
                "\nExample {}:\nUser: {}\nAssistant: {}\n",
                i + 1,
                example.input,
                example.output
            )
            .expect("writing to a String never fails");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn injects_examples() {
        let tool = ExampleTool::new(vec![
            Example::new("What is 2+2?", "4"),
            Example::new("What color is the sky?", "Blue"),
        ]);
        let mut request = LlmRequest::default();
        request.system_instruction = Some("You are helpful.".into());
        tool.process_llm_request(&mut request);
        let instruction = request.system_instruction.unwrap();
        // The pre-existing instruction is kept and comes first.
        assert!(instruction.starts_with("You are helpful."));
        assert!(instruction.contains("Example 1:\nUser: What is 2+2?\nAssistant: 4\n"));
        assert!(instruction.contains("Example 2:\nUser: What color is the sky?\nAssistant: Blue\n"));
        // Examples are rendered in the order they were supplied.
        let first = instruction.find("Example 1:").expect("first example missing");
        let second = instruction.find("Example 2:").expect("second example missing");
        assert!(first < second);
    }

    #[test]
    fn empty_examples_noop() {
        let tool = ExampleTool::new(vec![]);
        let mut request = LlmRequest::default();
        request.system_instruction = Some("Original".into());
        tool.process_llm_request(&mut request);
        assert_eq!(request.system_instruction.unwrap(), "Original");
    }

    #[test]
    fn creates_instruction_if_none() {
        let tool = ExampleTool::new(vec![Example::new("Hi", "Hello!")]);
        let mut request = LlmRequest::default();
        tool.process_llm_request(&mut request);
        let instruction = request
            .system_instruction
            .expect("an instruction should be created");
        assert!(instruction.contains("Example 1:\nUser: Hi\nAssistant: Hello!\n"));
    }

    struct StaticProvider;

    impl ExampleProvider for StaticProvider {
        fn examples(&self) -> Vec<Example> {
            vec![Example::new("test", "response")]
        }
    }

    #[test]
    fn from_provider() {
        let tool = ExampleTool::from_provider(&StaticProvider);
        assert_eq!(tool.examples.len(), 1);
        assert_eq!(tool.examples[0].input, "test");
        assert_eq!(tool.examples[0].output, "response");
    }
}