cognis_core/prompts/
few_shot.rs1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::Serialize;
7
8use crate::prompts::template::{render, scan_variables};
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::{CognisError, Result};
11
12#[derive(Debug, Clone)]
20pub struct FewShotTemplate<I = serde_json::Value, E = serde_json::Value> {
21 prefix: String,
22 example_template: String,
23 examples: Vec<E>,
24 suffix: String,
25 separator: String,
26 _input: PhantomData<fn() -> I>,
27}
28
29impl<I, E> FewShotTemplate<I, E>
30where
31 I: Serialize + Send + Sync + 'static,
32 E: Serialize + Send + Sync + Clone + 'static,
33{
34 pub fn new(
36 prefix: impl Into<String>,
37 example_template: impl Into<String>,
38 examples: Vec<E>,
39 suffix: impl Into<String>,
40 ) -> Self {
41 Self {
42 prefix: prefix.into(),
43 example_template: example_template.into(),
44 examples,
45 suffix: suffix.into(),
46 separator: "\n\n".into(),
47 _input: PhantomData,
48 }
49 }
50
51 pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
53 self.separator = sep.into();
54 self
55 }
56
57 pub fn render(&self, input: &I) -> Result<String> {
59 let input_ctx =
60 serde_json::to_value(input).map_err(|e| CognisError::Serialization(e.to_string()))?;
61 let mut rendered_examples = Vec::with_capacity(self.examples.len());
62 for ex in &self.examples {
63 let ex_ctx =
64 serde_json::to_value(ex).map_err(|e| CognisError::Serialization(e.to_string()))?;
65 rendered_examples.push(render(&self.example_template, &ex_ctx)?);
66 }
67 let prefix = render(&self.prefix, &input_ctx)?;
68 let suffix = render(&self.suffix, &input_ctx)?;
69 let body = rendered_examples.join(&self.separator);
70 Ok(format!(
71 "{prefix}{sep}{body}{sep}{suffix}",
72 sep = self.separator
73 ))
74 }
75
76 pub fn input_variables(&self) -> Vec<String> {
78 let mut out = scan_variables(&self.prefix);
79 for v in scan_variables(&self.suffix) {
80 if !out.contains(&v) {
81 out.push(v);
82 }
83 }
84 out
85 }
86}
87
88#[async_trait]
89impl<I, E> Runnable<I, String> for FewShotTemplate<I, E>
90where
91 I: Serialize + Send + Sync + 'static,
92 E: Serialize + Send + Sync + Clone + 'static,
93{
94 async fn invoke(&self, input: I, _: RunnableConfig) -> Result<String> {
95 self.render(&input)
96 }
97 fn name(&self) -> &str {
98 "FewShotTemplate"
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// The prefix and suffix are filled from the input, and every example is
    /// rendered from its own fields.
    #[test]
    fn renders_prefix_examples_suffix() {
        let template = FewShotTemplate::<Value, Value>::new(
            "Math problems for {topic}:",
            "Q: {q}\nA: {a}",
            vec![
                json!({"q": "2+2", "a": "4"}),
                json!({"q": "3+3", "a": "6"}),
            ],
            "Q: {question}\nA:",
        );
        let input = json!({"topic": "addition", "question": "5+5"});

        let rendered = template.render(&input).unwrap();

        assert!(rendered.starts_with("Math problems for addition:"));
        assert!(rendered.contains("Q: 2+2\nA: 4"));
        assert!(rendered.contains("Q: 3+3\nA: 6"));
        assert!(rendered.ends_with("Q: 5+5\nA:"));
    }

    /// A custom separator is used between every piece of the prompt.
    #[test]
    fn separator_override() {
        let examples = vec![json!({"x": "a"}), json!({"x": "b"})];
        let template =
            FewShotTemplate::<Value, Value>::new("P", "{x}", examples, "S").with_separator(" | ");

        let rendered = template.render(&json!({})).unwrap();

        assert_eq!(rendered, "P | a | b | S");
    }
}
133}