use crate::prompts::PromptTemplate;
use std::collections::HashMap;
/// A strategy that decides which few-shot examples go into a prompt.
///
/// Each example is a map from template variable name to value. The `Send + Sync` supertraits
/// make a boxed selector (`Box<dyn ExampleSelector>`) itself `Send + Sync`.
pub trait ExampleSelector: Send + Sync {
    /// All examples currently held by the selector.
    fn examples(&self) -> &[HashMap<String, String>];

    /// Hands `example` to the selector so it can be considered by later selections.
    fn add_example(&mut self, example: HashMap<String, String>);

    /// Picks the examples to render for the given input variables.
    ///
    /// The returned references borrow from the selector, not from `input`.
    fn select_examples(&self, input: &HashMap<String, String>) -> Vec<&HashMap<String, String>>;
}
/// Example selector that keeps examples, in their stored order, until a length budget is spent.
///
/// Lengths are measured with `str::len`, i.e. in UTF-8 bytes, not characters or tokens.
#[derive(Debug, Clone)]
pub struct LengthBasedExampleSelector {
    /// Candidate examples (variable name -> value), considered in this order.
    examples: Vec<HashMap<String, String>>,
    /// Length budget, in bytes, for the input text plus the selected examples.
    max_length: usize,
}
impl LengthBasedExampleSelector {
    /// Length budget, in bytes, used by [`LengthBasedExampleSelector::new`].
    const DEFAULT_MAX_LENGTH: usize = 2048;

    /// Fixed cost charged per example, on top of its rendered length, in
    /// [`LengthBasedExampleSelector::select_examples_by_length`].
    ///
    /// NOTE(review): presumably an allowance for the separator and framing between examples;
    /// it is a fixed heuristic, not derived from any actual separator.
    const EXAMPLE_OVERHEAD: usize = 10;

    /// Creates a selector over `examples` with the default budget of 2048 bytes.
    pub fn new(examples: Vec<HashMap<String, String>>) -> Self {
        Self {
            examples,
            max_length: Self::DEFAULT_MAX_LENGTH,
        }
    }

    /// Replaces the length budget (in bytes) and returns the updated selector.
    pub fn with_max_length(mut self, max: usize) -> Self {
        self.max_length = max;
        self
    }

    /// Byte length of `prefix`, every value of `example`, and `suffix` laid end to end.
    ///
    /// Variable names and any template text are not counted.
    fn format_example_length(
        &self,
        example: &HashMap<String, String>,
        prefix: &str,
        suffix: &str,
    ) -> usize {
        // Summing the parts equals the length of their concatenation, without allocating it.
        let values_len: usize = example.values().map(String::len).sum();
        prefix.len() + values_len + suffix.len()
    }

    /// Selects examples, in stored order, whose rendered text fits in the length budget.
    ///
    /// The budget left for examples is `max_length` minus the byte lengths of `prefix`,
    /// `suffix`, and every value in `input`, clamped at zero. Each example is rendered with
    /// `example_prompt` and costs its rendered byte length plus a fixed overhead of 10 bytes.
    /// Selection stops at the first example that would overflow the remaining budget.
    ///
    /// Edge cases:
    /// - The first example that renders successfully is always selected, even if it alone
    ///   exceeds the budget, so the result is non-empty whenever any example renders.
    /// - Examples that `example_prompt` fails to render are skipped silently.
    pub fn select_examples_by_length(
        &self,
        input: &HashMap<String, String>,
        example_prompt: &PromptTemplate,
        prefix: &str,
        suffix: &str,
    ) -> Vec<&HashMap<String, String>> {
        let input_text_len: usize = input.values().map(String::len).sum();
        let reserved = prefix.len() + suffix.len() + input_text_len;
        let available = self.max_length.saturating_sub(reserved);

        let mut selected = Vec::new();
        let mut used = 0usize;
        for example in &self.examples {
            let example_vars: HashMap<&str, &str> = example
                .iter()
                .map(|(k, v)| (k.as_str(), v.as_str()))
                .collect();
            // Best-effort: an example the template cannot render is skipped, not fatal.
            let formatted = match example_prompt.format(&example_vars) {
                Ok(text) => text,
                Err(_) => continue,
            };
            let cost = formatted.len() + Self::EXAMPLE_OVERHEAD;
            if used + cost <= available || selected.is_empty() {
                selected.push(example);
                used += cost;
            } else {
                break;
            }
        }
        selected
    }
}
impl ExampleSelector for LengthBasedExampleSelector {
    /// Selects examples, in stored order, until the `max_length` byte budget is spent.
    ///
    /// Unlike [`LengthBasedExampleSelector::select_examples_by_length`], this method has no
    /// example template, prefix or suffix, so lengths are approximated from raw values:
    /// - the input costs the summed byte length of its values;
    /// - each example costs the summed byte length of its values plus a fixed 10-byte overhead
    ///   (the same overhead `select_examples_by_length` charges).
    ///
    /// The first example is always selected, even if it alone exceeds the budget, matching
    /// `select_examples_by_length`. Selection stops at the first example that does not fit.
    fn select_examples(&self, input: &HashMap<String, String>) -> Vec<&HashMap<String, String>> {
        const PER_EXAMPLE_OVERHEAD: usize = 10;

        let input_len: usize = input.values().map(String::len).sum();
        let available = self.max_length.saturating_sub(input_len);

        let mut used = 0usize;
        let mut selected = Vec::new();
        for example in &self.examples {
            let cost = self
                .format_example_length(example, "", "")
                .saturating_add(PER_EXAMPLE_OVERHEAD);
            if used.saturating_add(cost) > available && !selected.is_empty() {
                break;
            }
            used = used.saturating_add(cost);
            selected.push(example);
        }
        selected
    }

    /// All stored examples, regardless of the length budget.
    fn examples(&self) -> &[HashMap<String, String>] {
        &self.examples
    }

    /// Appends `example` after the existing examples; it is considered last during selection.
    fn add_example(&mut self, example: HashMap<String, String>) {
        self.examples.push(example);
    }
}
/// A prompt assembled from a prefix, a run of rendered examples, and a suffix.
pub struct FewShotPromptTemplate {
    /// Text emitted verbatim before the examples; omitted when empty.
    prefix: String,
    /// Template used to render each example.
    example_prompt: PromptTemplate,
    /// Examples rendered when no `example_selector` is configured.
    examples: Vec<HashMap<String, String>>,
    /// Optional strategy that, when present, supplies the examples instead of `examples`.
    example_selector: Option<Box<dyn ExampleSelector>>,
    /// Text placed between consecutive rendered examples.
    example_separator: String,
    /// Text placed after the examples; its `{name}` placeholders are filled from the inputs.
    suffix: String,
    /// Variable names that callers must supply when formatting.
    input_variables: Vec<String>,
}
impl FewShotPromptTemplate {
    /// Creates a template with the default example separator (`"\n\n"`) and no selector.
    pub fn new(
        examples: Vec<HashMap<String, String>>,
        example_prompt: PromptTemplate,
        prefix: impl Into<String>,
        suffix: impl Into<String>,
        input_variables: Vec<String>,
    ) -> Self {
        Self {
            examples,
            example_prompt,
            prefix: prefix.into(),
            suffix: suffix.into(),
            example_separator: "\n\n".to_string(),
            input_variables,
            example_selector: None,
        }
    }

    /// Sets the text placed between consecutive rendered examples.
    pub fn with_example_separator(mut self, separator: impl Into<String>) -> Self {
        self.example_separator = separator.into();
        self
    }

    /// Installs a selector; when present, it supplies the examples and `examples` is ignored.
    pub fn with_example_selector(mut self, selector: Box<dyn ExampleSelector>) -> Self {
        self.example_selector = Some(selector);
        self
    }

    /// Replaces the prefix text.
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Replaces the suffix text.
    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = suffix.into();
        self
    }

    /// Renders the full prompt for `variables`.
    ///
    /// Output layout:
    /// - Non-empty sections are joined with a fixed `"\n\n"`: prefix (verbatim), examples, and
    ///   suffix.
    /// - Examples are joined with `example_separator`.
    /// - In the suffix, each `{name}` whose `name` is a key of `variables` is replaced by its value.
    ///
    /// Examples come from the selector when one is installed, otherwise from `examples`.
    ///
    /// # Errors
    /// - Returns an error naming the first entry of `input_variables` that is missing from
    ///   `variables`.
    /// - Propagates any error from rendering an example with `example_prompt`.
    pub fn format(&self, variables: &HashMap<&str, &str>) -> Result<String, String> {
        if let Some(missing) = self
            .input_variables
            .iter()
            .find(|var| !variables.contains_key(var.as_str()))
        {
            return Err(format!("缺少输入变量: {}", missing));
        }

        let selected_examples: Vec<&HashMap<String, String>> = match &self.example_selector {
            Some(selector) => {
                // Owned copy only built when a selector actually needs it; the returned
                // references borrow from the selector, so this map can drop here.
                let input_map: HashMap<String, String> = variables
                    .iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect();
                selector.select_examples(&input_map)
            }
            None => self.examples.iter().collect(),
        };

        let example_texts = selected_examples
            .iter()
            .map(|example| {
                let example_vars: HashMap<&str, &str> = example
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();
                self.example_prompt.format(&example_vars)
            })
            .collect::<Result<Vec<String>, String>>()?;
        let examples_str = example_texts.join(self.example_separator.as_str());

        let suffix_formatted = Self::substitute_placeholders(&self.suffix, variables);

        let sections: Vec<&str> = [
            self.prefix.as_str(),
            examples_str.as_str(),
            suffix_formatted.as_str(),
        ]
        .iter()
        .copied()
        .filter(|section| !section.is_empty())
        .collect();
        Ok(sections.join("\n\n"))
    }

    /// Replaces each `{name}` in `template` whose `name` is a key of `variables`, in one pass.
    ///
    /// Behavior details:
    /// - `{name}` pairs with unknown names are copied through unchanged.
    /// - A `{` with no closing `}` before the next `{` is copied through unchanged.
    /// - Substituted values are never rescanned. A value containing `{other}` is therefore
    ///   emitted literally, rather than being expanded or not depending on `HashMap`
    ///   iteration order.
    fn substitute_placeholders(template: &str, variables: &HashMap<&str, &str>) -> String {
        let mut output = String::with_capacity(template.len());
        let mut rest = template;
        while let Some(open) = rest.find('{') {
            output.push_str(&rest[..open]);
            let after_open = &rest[open + 1..];
            match after_open.find('}') {
                // A `{...}` pair with no nested `{`: a candidate placeholder.
                Some(close) if !after_open[..close].contains('{') => {
                    let name = &after_open[..close];
                    match variables.get(name) {
                        Some(value) => output.push_str(value),
                        // Unknown name: keep the original `{name}` text.
                        None => output.push_str(&rest[open..open + close + 2]),
                    }
                    rest = &after_open[close + 1..];
                }
                // Unmatched `{`: emit it literally and continue scanning after it.
                _ => {
                    output.push('{');
                    rest = after_open;
                }
            }
        }
        output.push_str(rest);
        output
    }

    /// Variable names that `format` requires.
    pub fn input_variables(&self) -> &[String] {
        &self.input_variables
    }

    /// Examples stored directly on the template (not those held by a selector).
    pub fn examples(&self) -> &[HashMap<String, String>] {
        &self.examples
    }

    /// Appends an example to the template's own list.
    ///
    /// This has no effect on output while a selector is installed, because `format` then ignores
    /// this list.
    pub fn add_example(&mut self, example: HashMap<String, String>) {
        self.examples.push(example);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an example with `input` and `output` variables.
    fn make_example(input: &str, output: &str) -> HashMap<String, String> {
        let mut map = HashMap::new();
        map.insert("input".to_string(), input.to_string());
        map.insert("output".to_string(), output.to_string());
        map
    }

    #[test]
    fn test_few_shot_basic() {
        let examples = vec![make_example("苹果", "水果"), make_example("玫瑰", "花")];
        let example_prompt = PromptTemplate::new("输入: {input} -> 输出: {output}");
        let few_shot = FewShotPromptTemplate::new(
            examples,
            example_prompt,
            "请分类以下词语:",
            "输入: {input} ->",
            vec!["input".to_string()],
        );
        let mut vars = HashMap::new();
        vars.insert("input", "太阳");
        let result = few_shot.format(&vars).unwrap();
        assert!(result.contains("请分类以下词语"));
        assert!(result.contains("苹果"));
        assert!(result.contains("水果"));
        assert!(result.contains("太阳"));
    }

    #[test]
    fn test_few_shot_missing_variable() {
        let few_shot = FewShotPromptTemplate::new(
            vec![],
            PromptTemplate::new("示例: {input} -> {output}"),
            "",
            "输入: {input}",
            vec!["input".to_string(), "extra".to_string()],
        );
        let mut vars = HashMap::new();
        vars.insert("input", "test");
        let result = few_shot.format(&vars);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("extra"));
    }

    #[test]
    fn test_few_shot_empty_examples() {
        let few_shot = FewShotPromptTemplate::new(
            vec![],
            PromptTemplate::new("{input} -> {output}"),
            "Prefix",
            "Suffix: {input}",
            vec!["input".to_string()],
        );
        let mut vars = HashMap::new();
        vars.insert("input", "hello");
        let result = few_shot.format(&vars).unwrap();
        assert!(result.contains("Prefix"));
        assert!(result.contains("hello"));
        // No examples were rendered, so the example template's arrow must not appear.
        assert!(!result.contains("->"));
    }

    #[test]
    fn test_few_shot_custom_separator() {
        let examples = vec![make_example("a", "1"), make_example("b", "2")];
        let few_shot = FewShotPromptTemplate::new(
            examples,
            PromptTemplate::new("{input}={output}"),
            "",
            "",
            vec![],
        )
        .with_example_separator(" | ");
        let vars = HashMap::new();
        let result = few_shot.format(&vars).unwrap();
        assert_eq!(result, "a=1 | b=2");
    }

    #[test]
    fn test_few_shot_add_example() {
        let mut few_shot = FewShotPromptTemplate::new(
            vec![make_example("old", "value")],
            PromptTemplate::new("{input}={output}"),
            "",
            "",
            vec![],
        );
        assert_eq!(few_shot.examples().len(), 1);
        few_shot.add_example(make_example("new", "value2"));
        assert_eq!(few_shot.examples().len(), 2);
    }

    #[test]
    fn test_length_based_selector() {
        let examples = vec![make_example("long text here", "short")];
        let selector = LengthBasedExampleSelector::new(examples).with_max_length(100);
        let input_vars = HashMap::new();
        let selected = selector.select_examples(&input_vars);
        assert_eq!(selected.len(), 1);
    }

    #[test]
    fn test_select_examples_by_length_respects_budget() {
        let examples = vec![make_example("a", "1"), make_example("b", "2")];
        let prompt = PromptTemplate::new("{input}={output}");
        let input = HashMap::new();

        // Each rendered example ("a=1") costs 3 bytes plus the fixed 10-byte overhead.
        let roomy = LengthBasedExampleSelector::new(examples.clone()).with_max_length(100);
        assert_eq!(roomy.select_examples_by_length(&input, &prompt, "", "").len(), 2);

        let tight = LengthBasedExampleSelector::new(examples.clone()).with_max_length(20);
        let selected = tight.select_examples_by_length(&input, &prompt, "", "");
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0].get("input").map(String::as_str), Some("a"));

        // The first example is kept even when nothing fits.
        let empty_budget = LengthBasedExampleSelector::new(examples).with_max_length(0);
        assert_eq!(empty_budget.select_examples_by_length(&input, &prompt, "", "").len(), 1);
    }

    #[test]
    fn test_few_shot_with_selector() {
        let examples = vec![make_example("a", "1"), make_example("b", "2")];
        let selector = Box::new(LengthBasedExampleSelector::new(examples.clone()));
        let few_shot = FewShotPromptTemplate::new(
            examples,
            PromptTemplate::new("{input}={output}"),
            "Prefix",
            "{input}",
            vec!["input".to_string()],
        )
        .with_example_selector(selector);
        let mut vars = HashMap::new();
        vars.insert("input", "test");
        let result = few_shot.format(&vars).unwrap();
        assert_eq!(result, "Prefix\n\na=1\n\nb=2\n\ntest");
    }
}