use std::borrow::Cow;
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Output format a template asks the model to produce.
///
/// Serialized in lowercase (`"json"`, `"yaml"`, `"markdown"`, `"xml"`, `"plain"`);
/// [`FormatType::Plain`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FormatType {
    /// JSON output; `Template::validate_output` parses it with `serde_json`.
    Json,
    /// YAML output; `Template::validate_output` parses it with `serde_yaml`.
    Yaml,
    /// Markdown output; not validated.
    Markdown,
    /// XML output; only a shallow `<`…`>` check is applied.
    Xml,
    /// Free-form text; not validated and no format section is generated.
    #[default]
    Plain,
}
/// How forcefully generated prompt instructions are worded.
///
/// Serialized in lowercase; [`PromptTone::Balanced`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PromptTone {
    /// Suggestive wording; the model is told to include rather than exclude when uncertain.
    Inclusive,
    /// Neutral wording; the model is told to use its best judgment.
    #[default]
    Balanced,
    /// Imperative wording; the model is told to exclude rather than include when uncertain.
    Restrictive,
}

impl PromptTone {
    /// Returns the phrase fragments that give generated instructions this tone.
    #[inline]
    pub fn modifiers(&self) -> ToneModifiers {
        match *self {
            PromptTone::Inclusive => ToneModifiers {
                requirement_prefix: "Consider including",
                possibility_prefix: "You may also",
                constraint_prefix: "Ideally",
                format_intro: "A suggested format is",
                output_verb: "could be",
                uncertainty_guidance: "When uncertain, include rather than exclude.",
            },
            PromptTone::Balanced => ToneModifiers {
                requirement_prefix: "Include",
                possibility_prefix: "You can",
                constraint_prefix: "Please",
                format_intro: "Use the following format",
                output_verb: "should be",
                uncertainty_guidance: "Use your best judgment when uncertain.",
            },
            PromptTone::Restrictive => ToneModifiers {
                requirement_prefix: "You MUST include",
                possibility_prefix: "Only include",
                constraint_prefix: "Required:",
                format_intro: "Output EXACTLY in this format",
                output_verb: "must be",
                uncertainty_guidance: "When uncertain, exclude rather than include.",
            },
        }
    }

    /// Threshold value paired with this tone: 0.6 (inclusive), 0.8 (balanced),
    /// 0.9 (restrictive).
    ///
    /// NOTE(review): how the threshold is consumed is not visible in this file.
    #[inline]
    pub fn default_threshold(&self) -> f64 {
        match *self {
            PromptTone::Inclusive => 0.6,
            PromptTone::Balanced => 0.8,
            PromptTone::Restrictive => 0.9,
        }
    }

    /// `true` only for [`PromptTone::Inclusive`].
    #[inline]
    pub fn favors_recall(&self) -> bool {
        *self == PromptTone::Inclusive
    }

    /// `true` only for [`PromptTone::Restrictive`].
    #[inline]
    pub fn favors_precision(&self) -> bool {
        *self == PromptTone::Restrictive
    }
}
/// Phrase fragments that give generated instructions a particular [`PromptTone`].
///
/// Instances come from [`PromptTone::modifiers`].
#[derive(Clone, Copy, Debug)]
pub struct ToneModifiers {
    /// Lead-in placed before "these fields: …" when listing required schema keys.
    pub requirement_prefix: &'static str,
    /// Lead-in for optional additions (e.g. "You can").
    pub possibility_prefix: &'static str,
    /// Lead-in for constraints (e.g. "Please").
    pub constraint_prefix: &'static str,
    /// Sentence start introducing the expected output format.
    pub format_intro: &'static str,
    /// Modal verb phrase for describing the output (e.g. "should be").
    pub output_verb: &'static str,
    /// Full sentence telling the model how to act when uncertain.
    pub uncertainty_guidance: &'static str,
}
/// Minimal JSON-Schema-like description of the expected structured output.
///
/// `Template::validate_output` enforces `required` keys only when `type` is
/// `"object"`. `properties` is stored but not inspected in this file.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct JsonSchema {
    /// Value of the schema's `type` key; empty when the key is absent.
    #[serde(rename = "type", default)]
    pub schema_type: String,
    /// Keys the output must contain; also listed in generated format instructions.
    #[serde(default)]
    pub required: Vec<String>,
    /// Raw `properties` object, kept verbatim.
    #[serde(default)]
    pub properties: serde_json::Map<String, serde_json::Value>,
}
/// Output-format settings of a template (the `format:` frontmatter key).
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct FormatSpec {
    /// Expected output format, read from the `type` key; defaults to `FormatType::Plain`.
    #[serde(rename = "type", default)]
    pub format_type: FormatType,
    /// Optional schema used for prompt hints and output validation.
    #[serde(default)]
    pub schema: Option<JsonSchema>,
}
/// Behavioural switches for a template (the `options:` frontmatter key).
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct TemplateOptions {
    /// Strict-mode flag; defaults to `false`. Set via `Template::strict`.
    #[serde(default)]
    pub strict: bool,
    /// Whether `Template::assemble_prompt` emits an "Output Format" section; defaults to `true`.
    #[serde(default = "default_true")]
    pub include_in_prompt: bool,
    /// Wording style for generated instructions; defaults to `PromptTone::Balanced`.
    #[serde(default)]
    pub tone: PromptTone,
}

/// Serde default for boolean options that are on unless explicitly disabled.
fn default_true() -> bool {
    true
}

impl Default for TemplateOptions {
    /// Uses the same per-field defaults as deserialization of an empty `options` map.
    fn default() -> Self {
        TemplateOptions {
            strict: bool::default(),
            include_in_prompt: default_true(),
            tone: PromptTone::default(),
        }
    }
}
/// One few-shot example: an input and the output expected for it.
///
/// Text is held as `Cow` so examples can borrow from their source.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct TemplateExample<'a> {
    /// Example input text.
    #[serde(borrow)]
    pub input: Cow<'a, str>,
    /// Expected output text for `input`.
    #[serde(borrow)]
    pub output: Cow<'a, str>,
}

impl<'a> TemplateExample<'a> {
    /// Builds an example from anything convertible into `Cow<str>` (`&str`, `String`, …).
    pub fn new(input: impl Into<Cow<'a, str>>, output: impl Into<Cow<'a, str>>) -> Self {
        let input = input.into();
        let output = output.into();
        TemplateExample { input, output }
    }

    /// Copies any borrowed text so the example no longer depends on its source.
    pub fn into_owned(self) -> TemplateExample<'static> {
        let TemplateExample { input, output } = self;
        TemplateExample {
            input: Cow::Owned(input.into_owned()),
            output: Cow::Owned(output.into_owned()),
        }
    }
}
/// YAML frontmatter of a template file, as read by the template parser.
///
/// Every key is optional; missing keys take their `Default` value (empty strings,
/// `FormatSpec::default()`, `TemplateOptions::default()`).
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
struct TemplateFrontmatter {
    /// Template name; the parser substitutes `"unnamed"` when empty.
    #[serde(default)]
    name: String,
    /// Free-form version string.
    #[serde(default)]
    version: String,
    /// Free-form signature string (e.g. `"question -> answer"`).
    #[serde(default)]
    signature: String,
    /// Output-format settings.
    #[serde(default)]
    format: FormatSpec,
    /// Behavioural options.
    #[serde(default)]
    options: TemplateOptions,
}
/// A prompt template: metadata, system prompt, output-format settings and few-shot examples.
///
/// Text fields are `Cow` so a template parsed from a string can borrow from it;
/// `Template::into_owned` detaches it into a `'static` value.
#[derive(Debug, Clone)]
pub struct Template<'a> {
    /// Template name.
    pub name: Cow<'a, str>,
    /// Free-form version string.
    pub version: Cow<'a, str>,
    /// Free-form signature string (e.g. `"input -> output"`).
    pub signature: Cow<'a, str>,
    /// Output-format settings used for prompt hints and validation.
    pub format: FormatSpec,
    /// Text placed at the top of assembled prompts.
    pub system_prompt: Cow<'a, str>,
    /// Explicit format instructions; when non-empty they replace the generated format section.
    pub format_instructions: Cow<'a, str>,
    /// Few-shot examples rendered into assembled prompts.
    pub examples: Vec<TemplateExample<'a>>,
    /// Behavioural switches.
    pub options: TemplateOptions,
}

impl Default for Template<'_> {
    /// A template named `"default"`, version `"1.0"`, signature `"input -> output"`,
    /// with empty prompts, no examples and default format/options.
    fn default() -> Self {
        Template {
            name: "default".into(),
            version: "1.0".into(),
            signature: "input -> output".into(),
            format: FormatSpec::default(),
            system_prompt: "".into(),
            format_instructions: "".into(),
            examples: vec![],
            options: TemplateOptions::default(),
        }
    }
}
impl<'a> Template<'a> {
pub fn new(name: impl Into<Cow<'a, str>>) -> Self {
Self {
name: name.into(),
..Default::default()
}
}
pub fn simple(prompt: impl Into<Cow<'a, str>>) -> Self {
Self {
name: Cow::Borrowed("simple"),
system_prompt: prompt.into(),
..Default::default()
}
}
pub fn render(&self, input: &str) -> String {
let mut output = String::with_capacity(self.system_prompt.len() + input.len() + 100);
if !self.system_prompt.is_empty() {
output.push_str(&self.system_prompt);
output.push_str("\n\n");
}
output.push_str("Input: ");
output.push_str(input);
output
}
pub fn from_str(content: &'a str) -> Result<Self> {
let content = content.trim();
if !content.starts_with("---") {
return Ok(Self {
name: Cow::Borrowed("inline"),
system_prompt: Cow::Borrowed(content),
..Default::default()
});
}
let after_first = &content[3..];
let frontmatter_end = after_first
.find("\n---")
.ok_or_else(|| Error::parse("Missing closing --- for frontmatter"))?;
let frontmatter_str = &after_first[..frontmatter_end].trim();
let body_start = frontmatter_end + 4; let body = if body_start < after_first.len() {
after_first[body_start..].trim()
} else {
""
};
let frontmatter: TemplateFrontmatter = serde_yaml::from_str(frontmatter_str)
.map_err(|e| Error::parse(format!("Invalid YAML frontmatter: {}", e)))?;
let (system_prompt, examples_section) = if let Some(idx) = body.find("---examples---") {
let prompt = body[..idx].trim();
let examples_str = body[idx + 14..].trim(); (prompt, Some(examples_str))
} else {
(body, None)
};
let format_instructions = extract_format_instructions(system_prompt);
let examples = if let Some(examples_str) = examples_section {
parse_examples(examples_str)?
} else {
Vec::new()
};
Ok(Self {
name: Cow::Borrowed(if frontmatter.name.is_empty() {
"unnamed"
} else {
return Ok(Self {
name: Cow::Owned(frontmatter.name),
version: Cow::Owned(frontmatter.version),
signature: Cow::Owned(frontmatter.signature),
format: frontmatter.format,
system_prompt: Cow::Borrowed(system_prompt),
format_instructions: Cow::Owned(format_instructions),
examples,
options: frontmatter.options,
});
}),
version: Cow::Owned(frontmatter.version),
signature: Cow::Owned(frontmatter.signature),
format: frontmatter.format,
system_prompt: Cow::Borrowed(system_prompt),
format_instructions: Cow::Owned(format_instructions),
examples,
options: frontmatter.options,
})
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Template<'static>> {
let content = std::fs::read_to_string(path.as_ref())
.map_err(|e| Error::io(format!("Failed to read template file: {}", e)))?;
Self::parse_owned(&content)
}
fn parse_owned(content: &str) -> Result<Template<'static>> {
let content = content.trim();
if !content.starts_with("---") {
return Ok(Template {
name: Cow::Owned("inline".to_string()),
system_prompt: Cow::Owned(content.to_string()),
..Default::default()
});
}
let after_first = &content[3..];
let frontmatter_end = after_first
.find("\n---")
.ok_or_else(|| Error::parse("Missing closing --- for frontmatter"))?;
let frontmatter_str = &after_first[..frontmatter_end].trim();
let body_start = frontmatter_end + 4; let body = if body_start < after_first.len() {
after_first[body_start..].trim()
} else {
""
};
let frontmatter: TemplateFrontmatter = serde_yaml::from_str(frontmatter_str)
.map_err(|e| Error::parse(format!("Invalid YAML frontmatter: {}", e)))?;
let (system_prompt, examples_section) = if let Some(idx) = body.find("---examples---") {
let prompt = body[..idx].trim();
let examples_str = body[idx + 14..].trim(); (prompt, Some(examples_str))
} else {
(body, None)
};
let format_instructions = extract_format_instructions(system_prompt);
let examples = if let Some(examples_str) = examples_section {
parse_examples(examples_str)?
} else {
Vec::new()
};
Ok(Template {
name: Cow::Owned(if frontmatter.name.is_empty() {
"unnamed".to_string()
} else {
frontmatter.name
}),
version: Cow::Owned(frontmatter.version),
signature: Cow::Owned(frontmatter.signature),
format: frontmatter.format,
system_prompt: Cow::Owned(system_prompt.to_string()),
format_instructions: Cow::Owned(format_instructions),
examples,
options: frontmatter.options,
})
}
pub fn into_owned(self) -> Template<'static> {
Template {
name: Cow::Owned(self.name.into_owned()),
version: Cow::Owned(self.version.into_owned()),
signature: Cow::Owned(self.signature.into_owned()),
format: self.format,
system_prompt: Cow::Owned(self.system_prompt.into_owned()),
format_instructions: Cow::Owned(self.format_instructions.into_owned()),
examples: self.examples.into_iter().map(|e| e.into_owned()).collect(),
options: self.options,
}
}
pub fn with_system_prompt(mut self, prompt: impl Into<Cow<'a, str>>) -> Self {
self.system_prompt = prompt.into();
self
}
pub fn with_format(mut self, format_type: FormatType) -> Self {
self.format.format_type = format_type;
self
}
pub fn with_example(mut self, example: TemplateExample<'a>) -> Self {
self.examples.push(example);
self
}
pub fn strict(mut self, strict: bool) -> Self {
self.options.strict = strict;
self
}
pub fn with_tone(mut self, tone: PromptTone) -> Self {
self.options.tone = tone;
self
}
pub fn assemble_prompt(
&self,
question: &str,
iteration: u32,
feedback: Option<&str>,
) -> String {
let mut prompt = String::with_capacity(4096);
let tone = self.options.tone.modifiers();
if !self.system_prompt.is_empty() {
prompt.push_str(&self.system_prompt);
prompt.push_str("\n\n");
}
if self.options.include_in_prompt && !self.format_instructions.is_empty() {
prompt.push_str("## Output Format\n\n");
prompt.push_str(&self.format_instructions);
prompt.push_str("\n\n");
} else if self.options.include_in_prompt {
match self.format.format_type {
FormatType::Json => {
prompt.push_str("## Output Format\n\n");
prompt.push_str(tone.format_intro);
prompt.push_str(": valid JSON.\n");
if let Some(ref schema) = self.format.schema {
if !schema.required.is_empty() {
prompt.push_str(tone.requirement_prefix);
prompt.push_str(" these fields: ");
prompt.push_str(&schema.required.join(", "));
prompt.push('\n');
}
}
prompt.push_str(tone.uncertainty_guidance);
prompt.push_str("\n\n");
}
FormatType::Yaml => {
prompt.push_str("## Output Format\n\n");
prompt.push_str(tone.format_intro);
prompt.push_str(": valid YAML.\n");
if let Some(ref schema) = self.format.schema {
if !schema.required.is_empty() {
prompt.push_str(tone.requirement_prefix);
prompt.push_str(" these fields: ");
prompt.push_str(&schema.required.join(", "));
prompt.push('\n');
}
}
prompt.push_str(tone.uncertainty_guidance);
prompt.push_str("\n\n");
}
FormatType::Xml => {
prompt.push_str("## Output Format\n\n");
prompt.push_str(tone.format_intro);
prompt.push_str(": valid XML.\n");
prompt.push_str(tone.uncertainty_guidance);
prompt.push_str("\n\n");
}
FormatType::Markdown => {
prompt.push_str("## Output Format\n\n");
prompt.push_str(tone.format_intro);
prompt.push_str(": Markdown.\n");
prompt.push_str(tone.uncertainty_guidance);
prompt.push_str("\n\n");
}
FormatType::Plain => {}
}
}
if !self.examples.is_empty() {
prompt.push_str("## Examples\n\n");
for (i, example) in self.examples.iter().enumerate() {
prompt.push_str(&format!("### Example {}\n\n", i + 1));
prompt.push_str("**Input:** ");
prompt.push_str(&example.input);
prompt.push_str("\n\n**Output:**\n");
prompt.push_str(&example.output);
prompt.push_str("\n\n");
}
}
prompt.push_str("## Your Task\n\n");
prompt.push_str("**Input:** ");
prompt.push_str(question);
prompt.push('\n');
if iteration > 0 {
if let Some(fb) = feedback {
prompt.push_str("\n**Feedback from previous attempt:**\n");
prompt.push_str(fb);
prompt.push('\n');
}
}
prompt
}
pub fn validate_output(&self, output: &str) -> Result<()> {
match self.format.format_type {
FormatType::Json => {
let json_str = extract_json_from_output(output);
let value: serde_json::Value = serde_json::from_str(json_str)
.map_err(|e| Error::validation(format!("Invalid JSON: {}", e)))?;
if let Some(ref schema) = self.format.schema {
if schema.schema_type == "object" {
if let serde_json::Value::Object(obj) = &value {
for field in &schema.required {
if !obj.contains_key(field) {
return Err(Error::validation(format!(
"Missing required field: {}",
field
)));
}
}
} else {
return Err(Error::validation("Expected JSON object"));
}
}
}
Ok(())
}
FormatType::Yaml => {
let yaml_str = extract_yaml_from_output(output);
let value: serde_yaml::Value = serde_yaml::from_str(yaml_str)
.map_err(|e| Error::validation(format!("Invalid YAML: {}", e)))?;
if let Some(ref schema) = self.format.schema {
if schema.schema_type == "object" {
if let serde_yaml::Value::Mapping(map) = &value {
for field in &schema.required {
let key = serde_yaml::Value::String(field.clone());
if !map.contains_key(&key) {
return Err(Error::validation(format!(
"Missing required field: {}",
field
)));
}
}
} else {
return Err(Error::validation("Expected YAML mapping"));
}
}
}
Ok(())
}
FormatType::Xml => {
if !output.trim().starts_with('<') || !output.trim().ends_with('>') {
return Err(Error::validation(
"Invalid XML: must start with < and end with >",
));
}
Ok(())
}
FormatType::Markdown | FormatType::Plain => {
Ok(())
}
}
}
pub fn parse_output<T: serde::de::DeserializeOwned>(&self, output: &str) -> Result<T> {
match self.format.format_type {
FormatType::Json => {
let json_str = extract_json_from_output(output);
serde_json::from_str(json_str)
.map_err(|e| Error::parse(format!("Failed to parse JSON output: {}", e)))
}
FormatType::Yaml => {
let yaml_str = extract_yaml_from_output(output);
serde_yaml::from_str(yaml_str)
.map_err(|e| Error::parse(format!("Failed to parse YAML output: {}", e)))
}
_ => Err(Error::parse(
"parse_output only supports JSON and YAML formats",
)),
}
}
pub fn get_format_instructions(&self) -> &str {
&self.format_instructions
}
}
/// Extracts the body of the first "Output Format"/"Format" heading section in `markdown`.
///
/// The marker strings are matched ASCII-case-insensitively, in priority order:
/// `## output format`, `### output format`, `## format`, `### format`.
///
/// The result is trimmed. It runs from just after the matched marker to the next
/// `\n## ` (or, when there is none, the next `\n### `), or to the end of `markdown`.
///
/// Returns an empty string when no marker is present.
fn extract_format_instructions(markdown: &str) -> String {
    // ASCII-only lowercasing keeps every byte offset identical to `markdown`, so an
    // index found in `lower` is a valid index into `markdown`. Full Unicode
    // lowercasing can change UTF-8 lengths (e.g. 'İ' -> "i\u{307}"), which shifted
    // the slice and could panic on a non-char boundary.
    let lower = markdown.to_ascii_lowercase();
    const MARKERS: [&str; 4] = [
        "## output format",
        "### output format",
        "## format",
        "### format",
    ];
    MARKERS
        .iter()
        .find_map(|marker| {
            let start = lower.find(marker)?;
            let rest = &markdown[start + marker.len()..];
            let end = rest
                .find("\n## ")
                .or_else(|| rest.find("\n### "))
                .unwrap_or(rest.len());
            Some(rest[..end].trim().to_string())
        })
        .unwrap_or_default()
}
/// Parses the `---examples---` section of a template into input/output pairs.
///
/// The text is divided into sections at Markdown ATX heading lines, i.e. 1–6 `#`
/// followed by whitespace or end of line, such as `## Example 1`. Heading lines are
/// not part of any section. Lines inside fenced code blocks are never treated as
/// headings, and a `#` elsewhere in a line (e.g. `C#`) is ordinary text.
///
/// Each section is searched for an `input` and an `output` field with
/// `extract_field`. Sections where either field is missing or empty are skipped.
///
/// # Errors
///
/// Propagates any error returned by `extract_field`.
fn parse_examples(content: &str) -> Result<Vec<TemplateExample<'static>>> {
    let mut examples = Vec::new();
    for section in split_example_sections(content) {
        let section = section.trim();
        if section.is_empty() {
            continue;
        }
        let input = extract_field(section, "input")?;
        let output = extract_field(section, "output")?;
        if !input.is_empty() && !output.is_empty() {
            examples.push(TemplateExample {
                input: Cow::Owned(input),
                output: Cow::Owned(output),
            });
        }
    }
    Ok(examples)
}

/// Splits `content` into the text between Markdown heading lines (headings excluded),
/// ignoring heading-like lines inside ```` ``` ```` fenced code blocks.
fn split_example_sections(content: &str) -> Vec<&str> {
    let mut sections = Vec::new();
    let mut section_start = 0;
    let mut line_start = 0;
    let mut in_fence = false;
    for line in content.split_inclusive('\n') {
        let line_end = line_start + line.len();
        let trimmed = line.trim_start();
        if trimmed.starts_with("```") {
            in_fence = !in_fence;
        } else if !in_fence && is_markdown_heading(trimmed) {
            sections.push(&content[section_start..line_start]);
            section_start = line_end;
        }
        line_start = line_end;
    }
    sections.push(&content[section_start..]);
    sections
}

/// Returns `true` when the left-trimmed `line` is an ATX heading: 1–6 `#` characters
/// followed by whitespace or the end of the line.
fn is_markdown_heading(line: &str) -> bool {
    let rest = line.trim_start_matches('#');
    let level = line.len() - rest.len();
    (1..=6).contains(&level) && rest.chars().next().map_or(true, char::is_whitespace)
}
/// Extracts the value of a labelled field (e.g. `input` or `output`) from one example section.
///
/// Labels are matched ASCII-case-insensitively, in priority order: `**name:**`,
/// `**name**:`, then bare `name:`.
///
/// The value starts right after the label and ends at the first terminator found,
/// checked in this order: `**output`, then `**input` (both case-insensitive), then
/// `\n## `, then `\n### `. If none is found, the value runs to the end of
/// `content`. The value is trimmed.
///
/// Returns an empty string when no label is found.
///
/// # Errors
///
/// Currently never fails; the `Result` return type is kept for callers.
fn extract_field(content: &str, field_name: &str) -> Result<String> {
    // ASCII lowercasing preserves byte offsets, so positions found in `lower` are
    // valid in `content`. Full Unicode lowercasing can change UTF-8 lengths
    // (e.g. 'İ'), which shifted the extracted value or panicked on a char boundary.
    let lower = content.to_ascii_lowercase();
    let markers = [
        format!("**{}:**", field_name),
        format!("**{}**:", field_name),
        format!("{}:", field_name),
    ];
    for marker in &markers {
        let marker_lower = marker.to_ascii_lowercase();
        if let Some(start) = lower.find(marker_lower.as_str()) {
            let value_start = start + marker_lower.len();
            let rest = &content[value_start..];
            // Same bytes as `rest`, lowercased once for the case-insensitive terminators.
            let rest_lower = &lower[value_start..];
            let end = rest_lower
                .find("**output")
                .or_else(|| rest_lower.find("**input"))
                .or_else(|| rest.find("\n## "))
                .or_else(|| rest.find("\n### "))
                .unwrap_or(rest.len());
            return Ok(rest[..end].trim().to_string());
        }
    }
    Ok(String::new())
}
/// Pulls the JSON payload out of a model response.
///
/// The payload is chosen in this order:
/// 1. the body of the first ```` ```json ```` fence, if that fence is closed;
/// 2. otherwise the body of the first ```` ``` ```` fence (the rest of its opening
///    line, e.g. a language tag, is skipped), if that fence is closed;
/// 3. otherwise the whole response.
///
/// The returned slice is always trimmed.
fn extract_json_from_output(output: &str) -> &str {
    let text = output.trim();

    let tagged = text.find("```json").and_then(|open| {
        let body = &text[open + "```json".len()..];
        body.find("```").map(|close| body[..close].trim())
    });
    if let Some(json) = tagged {
        return json;
    }

    let untagged = text.find("```").and_then(|open| {
        let after_open = &text[open + "```".len()..];
        let body = after_open
            .find('\n')
            .map_or(after_open, |newline| &after_open[newline + 1..]);
        body.find("```").map(|close| body[..close].trim())
    });

    untagged.unwrap_or(text)
}
/// Pulls the YAML payload out of a model response.
///
/// The payload is chosen in this order:
/// 1. the body of the first closed ```` ```yaml ```` fence;
/// 2. the body of the first closed ```` ```yml ```` fence;
/// 3. the body of the first closed ```` ``` ```` fence (the rest of its opening
///    line is skipped);
/// 4. otherwise the whole response.
///
/// The returned slice is always trimmed.
fn extract_yaml_from_output(output: &str) -> &str {
    let text = output.trim();

    for tag in &["```yaml", "```yml"] {
        if let Some(open) = text.find(*tag) {
            let body = &text[open + tag.len()..];
            if let Some(close) = body.find("```") {
                return body[..close].trim();
            }
        }
    }

    if let Some(open) = text.find("```") {
        let after_open = &text[open + 3..];
        let body = match after_open.find('\n') {
            Some(newline) => &after_open[newline + 1..],
            None => after_open,
        };
        if let Some(close) = body.find("```") {
            return body[..close].trim();
        }
    }

    text
}
// Unit tests covering template parsing, prompt assembly, output extraction and
// validation, and tone handling. Raw-string fixtures are kept verbatim.
#[cfg(test)]
mod tests {
    use super::*;
    // Frontmatter fields and the prompt body are read from a full template.
    #[test]
    fn test_parse_simple_template() {
        let content = r#"---
name: test_template
version: "1.0"
signature: "question -> answer"
format:
type: json
options:
strict: true
---
You are a helpful assistant.
Answer questions concisely.
"#;
        let template = Template::from_str(content).unwrap();
        assert_eq!(template.name, "test_template");
        assert_eq!(template.version, "1.0");
        assert_eq!(template.signature, "question -> answer");
        assert_eq!(template.format.format_type, FormatType::Json);
        assert!(template.options.strict);
        assert!(template.system_prompt.contains("helpful assistant"));
    }
    // Two `## Example N` sections after `---examples---` yield two examples.
    #[test]
    fn test_parse_template_with_examples() {
        let content = r#"---
name: qa
format:
type: json
---
Answer questions.
---examples---
## Example 1
**Input:** What is 2+2?
**Output:**
```json
{"answer": "4"}
```
## Example 2
**Input:** What color is the sky?
**Output:**
```json
{"answer": "blue"}
```
"#;
        let template = Template::from_str(content).unwrap();
        assert_eq!(template.examples.len(), 2);
        assert!(template.examples[0].input.contains("2+2"));
        assert!(template.examples[0].output.contains("4"));
    }
    // The `required` list of a nested schema is deserialized.
    #[test]
    fn test_parse_template_with_schema() {
        let content = r#"---
name: code_gen
format:
type: json
schema:
type: object
required:
- code
- explanation
properties:
code:
type: string
explanation:
type: string
---
Generate code.
"#;
        let template = Template::from_str(content).unwrap();
        let schema = template.format.schema.as_ref().unwrap();
        assert_eq!(schema.required, vec!["code", "explanation"]);
    }
    // JSON validation accepts valid JSON and rejects non-JSON text.
    #[test]
    fn test_validate_json_output() {
        let template = Template::new("test").with_format(FormatType::Json);
        assert!(template.validate_output(r#"{"key": "value"}"#).is_ok());
        assert!(template.validate_output("not json").is_err());
    }
    // Required schema keys are enforced for JSON objects.
    #[test]
    fn test_validate_json_with_schema() {
        let content = r#"---
format:
type: json
schema:
type: object
required:
- answer
---
Test
"#;
        let template = Template::from_str(content).unwrap();
        assert!(template.validate_output(r#"{"answer": "yes"}"#).is_ok());
        assert!(template.validate_output(r#"{"other": "no"}"#).is_err());
    }
    // The assembled prompt contains the system prompt, the task input and the examples.
    #[test]
    fn test_assemble_prompt() {
        let content = r#"---
name: qa
format:
type: json
---
You are a helpful assistant.
## Output Format
Return JSON with an "answer" field.
---examples---
## Example 1
**Input:** What is 1+1?
**Output:**
```json
{"answer": "2"}
```
"#;
        let template = Template::from_str(content).unwrap();
        let prompt = template.assemble_prompt("What is 2+2?", 0, None);
        assert!(prompt.contains("helpful assistant"));
        assert!(prompt.contains("What is 2+2?"));
        assert!(prompt.contains("Example 1"));
    }
    // Feedback is rendered on a retry (iteration 1).
    #[test]
    fn test_assemble_prompt_with_feedback() {
        let template = Template::new("test").with_system_prompt("Answer questions.");
        let prompt = template.assemble_prompt("What is 2+2?", 1, Some("Previous answer was wrong"));
        assert!(prompt.contains("Feedback from previous attempt"));
        assert!(prompt.contains("Previous answer was wrong"));
    }
    // A ```json fence is unwrapped to its trimmed body.
    #[test]
    fn test_extract_json_from_code_block() {
        let output = r#"Here is the answer:
```json
{"answer": "42"}
```
"#;
        let json = extract_json_from_output(output);
        assert_eq!(json, r#"{"answer": "42"}"#);
    }
    // A ```yaml fence is unwrapped to its trimmed body.
    #[test]
    fn test_extract_yaml_from_code_block() {
        let output = r#"Here is the configuration:
```yaml
name: test
version: 1.0
enabled: true
```
"#;
        let yaml = extract_yaml_from_output(output);
        assert_eq!(yaml, "name: test\nversion: 1.0\nenabled: true");
    }
    // A ```yml fence is unwrapped as well.
    #[test]
    fn test_extract_yaml_from_yml_code_block() {
        let output = r#"Config file:
```yml
database:
host: localhost
port: 5432
```
"#;
        let yaml = extract_yaml_from_output(output);
        assert_eq!(yaml, "database:\n host: localhost\n port: 5432");
    }
    // Unfenced YAML is returned unchanged.
    #[test]
    fn test_extract_yaml_plain() {
        let output = "name: test\nvalue: 42";
        let yaml = extract_yaml_from_output(output);
        assert_eq!(yaml, "name: test\nvalue: 42");
    }
    // YAML validation accepts plain and fenced YAML and rejects malformed input.
    #[test]
    fn test_validate_yaml_output() {
        let template = Template::new("test").with_format(FormatType::Yaml);
        assert!(template.validate_output("name: test\nvalue: 42").is_ok());
        let with_block = "```yaml\nname: test\n```";
        assert!(template.validate_output(with_block).is_ok());
        assert!(template.validate_output("name: [unclosed").is_err());
    }
    // Required schema keys are enforced for YAML mappings.
    #[test]
    fn test_validate_yaml_with_schema() {
        let content = r#"---
name: yaml_test
format:
type: yaml
schema:
type: object
required:
- name
- value
---
Generate YAML output.
"#;
        let template = Template::from_str(content).unwrap();
        let valid = "name: test\nvalue: 42";
        assert!(template.validate_output(valid).is_ok());
        let missing = "name: test";
        assert!(template.validate_output(missing).is_err());
    }
    // YAML output deserializes into a typed struct, with and without a code fence.
    #[test]
    fn test_parse_yaml_output() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Config {
            name: String,
            port: u16,
        }
        let template = Template::new("test").with_format(FormatType::Yaml);
        let output = "name: myapp\nport: 8080";
        let config: Config = template.parse_output(output).unwrap();
        assert_eq!(config.name, "myapp");
        assert_eq!(config.port, 8080);
        let with_block = "```yaml\nname: other\nport: 3000\n```";
        let config2: Config = template.parse_output(with_block).unwrap();
        assert_eq!(config2.name, "other");
        assert_eq!(config2.port, 3000);
    }
    // The default format type is plain text.
    #[test]
    fn test_format_type_default() {
        let format = FormatSpec::default();
        assert_eq!(format.format_type, FormatType::Plain);
    }
    // Converting a parsed template to 'static keeps its name.
    #[test]
    fn test_template_into_owned() {
        let content = r#"---
name: test
---
Prompt
"#;
        let template = Template::from_str(content).unwrap();
        let owned: Template<'static> = template.into_owned();
        assert_eq!(owned.name, "test");
    }
    // The "## Output Format" section body is extracted from Markdown.
    #[test]
    fn test_extract_format_instructions() {
        let markdown = r#"
# Main Prompt
Do something.
## Output Format
Return JSON with the following fields:
- answer: string
- confidence: number
## Other Section
More content.
"#;
        let instructions = extract_format_instructions(markdown);
        assert!(instructions.contains("Return JSON"));
        assert!(instructions.contains("answer: string"));
    }
    // JSON output deserializes into a typed struct.
    #[test]
    fn test_parse_output_json() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Answer {
            answer: String,
        }
        let template = Template::new("test").with_format(FormatType::Json);
        let result: Answer = template.parse_output(r#"{"answer": "42"}"#).unwrap();
        assert_eq!(result.answer, "42");
    }
    // Example text survives conversion to 'static.
    #[test]
    fn test_template_example_into_owned() {
        let example = TemplateExample::new("input", "output");
        let owned = example.into_owned();
        assert_eq!(owned.input, "input");
        assert_eq!(owned.output, "output");
    }
    // Each tone supplies its characteristic wording.
    #[test]
    fn test_prompt_tone_modifiers() {
        let inclusive = PromptTone::Inclusive.modifiers();
        assert!(inclusive.requirement_prefix.contains("Consider"));
        assert!(inclusive
            .uncertainty_guidance
            .contains("include rather than exclude"));
        let balanced = PromptTone::Balanced.modifiers();
        assert!(balanced.requirement_prefix.contains("Include"));
        assert!(!balanced.requirement_prefix.contains("MUST"));
        let restrictive = PromptTone::Restrictive.modifiers();
        assert!(restrictive.requirement_prefix.contains("MUST"));
        assert!(restrictive
            .uncertainty_guidance
            .contains("exclude rather than include"));
    }
    // Threshold values per tone (compared with a float tolerance).
    #[test]
    fn test_prompt_tone_default_thresholds() {
        assert!((PromptTone::Inclusive.default_threshold() - 0.6).abs() < 0.001);
        assert!((PromptTone::Balanced.default_threshold() - 0.8).abs() < 0.001);
        assert!((PromptTone::Restrictive.default_threshold() - 0.9).abs() < 0.001);
    }
    // Only Inclusive favors recall; only Restrictive favors precision.
    #[test]
    fn test_prompt_tone_favors() {
        assert!(PromptTone::Inclusive.favors_recall());
        assert!(!PromptTone::Inclusive.favors_precision());
        assert!(!PromptTone::Balanced.favors_recall());
        assert!(!PromptTone::Balanced.favors_precision());
        assert!(!PromptTone::Restrictive.favors_recall());
        assert!(PromptTone::Restrictive.favors_precision());
    }
    // A restrictive tone changes the generated JSON format section.
    #[test]
    fn test_template_with_tone() {
        let template = Template::new("test")
            .with_format(FormatType::Json)
            .with_tone(PromptTone::Restrictive);
        assert_eq!(template.options.tone, PromptTone::Restrictive);
        let prompt = template.assemble_prompt("What is 2+2?", 0, None);
        assert!(prompt.contains("EXACTLY"));
        assert!(prompt.contains("exclude rather than include"));
    }
    // An inclusive tone changes the generated JSON format section.
    #[test]
    fn test_template_inclusive_tone_prompt() {
        let template = Template::new("test")
            .with_format(FormatType::Json)
            .with_tone(PromptTone::Inclusive);
        let prompt = template.assemble_prompt("What is 2+2?", 0, None);
        assert!(prompt.contains("suggested format"));
        assert!(prompt.contains("include rather than exclude"));
    }
    // Balanced is the default tone.
    #[test]
    fn test_prompt_tone_default() {
        assert_eq!(PromptTone::default(), PromptTone::Balanced);
    }
    // Default options use the balanced tone.
    #[test]
    fn test_template_options_default_tone() {
        let options = TemplateOptions::default();
        assert_eq!(options.tone, PromptTone::Balanced);
    }
}