Skip to main content

cognis_core/prompts/
few_shot.rs

1//! Few-shot string prompt: prefix + per-example template + suffix.
2
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::Serialize;
7
8use crate::prompts::template::{render, scan_variables};
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::{CognisError, Result};
11
12/// Few-shot string prompt with a static list of examples.
13///
14/// Renders as:
15/// `<prefix><sep><ex1><sep><ex2><sep>...<sep><suffix>`
16///
17/// The prefix and suffix are rendered against the call input. Each example
18/// is rendered against itself (separate `Serialize` type).
19#[derive(Debug, Clone)]
20pub struct FewShotTemplate<I = serde_json::Value, E = serde_json::Value> {
21    prefix: String,
22    example_template: String,
23    examples: Vec<E>,
24    suffix: String,
25    separator: String,
26    _input: PhantomData<fn() -> I>,
27}
28
29impl<I, E> FewShotTemplate<I, E>
30where
31    I: Serialize + Send + Sync + 'static,
32    E: Serialize + Send + Sync + Clone + 'static,
33{
34    /// Build a few-shot template.
35    pub fn new(
36        prefix: impl Into<String>,
37        example_template: impl Into<String>,
38        examples: Vec<E>,
39        suffix: impl Into<String>,
40    ) -> Self {
41        Self {
42            prefix: prefix.into(),
43            example_template: example_template.into(),
44            examples,
45            suffix: suffix.into(),
46            separator: "\n\n".into(),
47            _input: PhantomData,
48        }
49    }
50
51    /// Override the separator used between rendered examples (default `\n\n`).
52    pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
53        self.separator = sep.into();
54        self
55    }
56
57    /// Render the full prompt against the input.
58    pub fn render(&self, input: &I) -> Result<String> {
59        let input_ctx =
60            serde_json::to_value(input).map_err(|e| CognisError::Serialization(e.to_string()))?;
61        let mut rendered_examples = Vec::with_capacity(self.examples.len());
62        for ex in &self.examples {
63            let ex_ctx =
64                serde_json::to_value(ex).map_err(|e| CognisError::Serialization(e.to_string()))?;
65            rendered_examples.push(render(&self.example_template, &ex_ctx)?);
66        }
67        let prefix = render(&self.prefix, &input_ctx)?;
68        let suffix = render(&self.suffix, &input_ctx)?;
69        let body = rendered_examples.join(&self.separator);
70        Ok(format!(
71            "{prefix}{sep}{body}{sep}{suffix}",
72            sep = self.separator
73        ))
74    }
75
76    /// Variables referenced by the prefix and suffix.
77    pub fn input_variables(&self) -> Vec<String> {
78        let mut out = scan_variables(&self.prefix);
79        for v in scan_variables(&self.suffix) {
80            if !out.contains(&v) {
81                out.push(v);
82            }
83        }
84        out
85    }
86}
87
88#[async_trait]
89impl<I, E> Runnable<I, String> for FewShotTemplate<I, E>
90where
91    I: Serialize + Send + Sync + 'static,
92    E: Serialize + Send + Sync + Clone + 'static,
93{
94    async fn invoke(&self, input: I, _: RunnableConfig) -> Result<String> {
95        self.render(&input)
96    }
97    fn name(&self) -> &str {
98        "FewShotTemplate"
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Prefix and suffix see the call input; each example sees only itself.
    #[test]
    fn renders_prefix_examples_suffix() {
        let template: FewShotTemplate<Value, Value> = FewShotTemplate::new(
            "Math problems for {topic}:",
            "Q: {q}\nA: {a}",
            vec![json!({"q": "2+2", "a": "4"}), json!({"q": "3+3", "a": "6"})],
            "Q: {question}\nA:",
        );
        let call_input = json!({"topic": "addition", "question": "5+5"});
        let rendered = template.render(&call_input).unwrap();

        assert!(rendered.starts_with("Math problems for addition:"));
        for expected in ["Q: 2+2\nA: 4", "Q: 3+3\nA: 6"] {
            assert!(
                rendered.contains(expected),
                "missing example {expected:?} in {rendered:?}"
            );
        }
        assert!(rendered.ends_with("Q: 5+5\nA:"));
    }

    /// A custom separator is used between every piece of the prompt.
    #[test]
    fn separator_override() {
        let examples: Vec<Value> = ["a", "b"].iter().map(|x| json!({ "x": x })).collect();
        let template: FewShotTemplate<Value, Value> =
            FewShotTemplate::new("P", "{x}", examples, "S").with_separator(" | ");
        assert_eq!(template.render(&json!({})).unwrap(), "P | a | b | S");
    }
}
133}