use crate::context::{Context, ContextKey, Fact};
use std::collections::HashSet;
use std::fmt::Write;
/// Output format accepted by [`AgentPrompt::serialize`].
///
/// `Edn` is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PromptFormat {
    /// Line-oriented text, produced by `AgentPrompt::to_plain`.
    Plain,
    /// Compact EDN map, produced by `AgentPrompt::to_edn`.
    #[default]
    Edn,
}
/// The part an agent plays; rendered under `:r` in the EDN prompt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentRole {
    Proposer,
    Validator,
    Synthesizer,
    Analyzer,
}
impl AgentRole {
    /// EDN keyword naming this role, e.g. `:proposer` for [`AgentRole::Proposer`].
    const fn to_keyword(self) -> &'static str {
        match self {
            AgentRole::Analyzer => ":analyzer",
            AgentRole::Proposer => ":proposer",
            AgentRole::Synthesizer => ":synthesizer",
            AgentRole::Validator => ":validator",
        }
    }
}
/// A rule attached to a prompt; rendered inside the `:k` set of the EDN prompt.
///
/// Variant order defines the derived `Ord`, which fixes the order in which
/// constraints are serialized.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Constraint {
    NoInvent,
    NoContradict,
    NoHallucinate,
    CiteSources,
}
impl Constraint {
    /// EDN keyword for this constraint, e.g. `:no-invent` for [`Constraint::NoInvent`].
    const fn to_keyword(self) -> &'static str {
        match self {
            Constraint::CiteSources => ":cite-sources",
            Constraint::NoContradict => ":no-contradict",
            Constraint::NoHallucinate => ":no-hallucinate",
            Constraint::NoInvent => ":no-invent",
        }
    }
}
/// Describes what the agent must produce and where it goes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OutputContract {
    /// Name of the emitted item; rendered after `:emit :` in the EDN prompt.
    pub emit: String,
    /// Context key the output targets.
    pub key: ContextKey,
    /// Optional format name; rendered as `:format :<name>` in EDN when set.
    pub format: Option<String>,
}
impl OutputContract {
    /// Creates a contract for `emit` targeting `key`, with no format.
    #[must_use]
    pub fn new(emit: impl Into<String>, key: ContextKey) -> Self {
        let emit = emit.into();
        Self {
            format: None,
            key,
            emit,
        }
    }

    /// Returns the contract with its format set to `format`, replacing any previous one.
    #[must_use]
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            format: Some(format.into()),
            ..self
        }
    }
}
/// A complete prompt for one agent: role, objective, supporting facts,
/// constraints and output contract.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AgentPrompt {
    /// Role the agent plays.
    pub role: AgentRole,
    /// Free-text objective.
    pub objective: String,
    /// Facts grouped by context key.
    pub context: PromptContext,
    /// Constraints as a set: duplicates collapse and iteration order is unspecified.
    pub constraints: HashSet<Constraint>,
    /// What the agent must emit and where.
    pub output_contract: OutputContract,
}
/// Facts grouped by context key, in the order they were added.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PromptContext {
    /// One `(key, facts)` entry per key.
    ///
    /// When built via [`PromptContext::add_facts`] / [`PromptContext::from_context`],
    /// each key appears at most once and no fact list is empty.
    pub facts: Vec<(ContextKey, Vec<Fact>)>,
}
impl PromptContext {
    /// Creates an empty context.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `facts` under `key`.
    ///
    /// Empty lists are ignored. If `key` already has an entry, the facts are
    /// appended to it instead of creating a second entry. A repeated key would
    /// otherwise serialize to an EDN map with a duplicated key, which EDN
    /// does not allow.
    pub fn add_facts(&mut self, key: ContextKey, facts: Vec<Fact>) {
        if facts.is_empty() {
            return;
        }
        if let Some(pos) = self.facts.iter().position(|(existing, _)| *existing == key) {
            self.facts[pos].1.extend(facts);
        } else {
            self.facts.push((key, facts));
        }
    }

    /// Builds a prompt context by copying the facts stored under each key in
    /// `dependencies` out of `ctx`.
    ///
    /// Keys are visited in `dependencies` order. A key listed more than once
    /// is only copied the first time, so its facts are not duplicated. Keys
    /// with no facts are skipped.
    #[must_use]
    pub fn from_context(ctx: &Context, dependencies: &[ContextKey]) -> Self {
        let mut prompt_ctx = Self::new();
        for (i, &key) in dependencies.iter().enumerate() {
            if dependencies[..i].contains(&key) {
                continue;
            }
            prompt_ctx.add_facts(key, ctx.get(key).to_vec());
        }
        prompt_ctx
    }
}
fn context_key_to_keyword(key: ContextKey) -> &'static str {
match key {
ContextKey::Seeds => ":seeds",
ContextKey::Hypotheses => ":hypotheses",
ContextKey::Strategies => ":strategies",
ContextKey::Constraints => ":constraints",
ContextKey::Signals => ":signals",
ContextKey::Competitors => ":competitors",
ContextKey::Evaluations => ":evaluations",
ContextKey::Proposals => ":proposals",
ContextKey::Diagnostic => ":diagnostic",
}
}
impl AgentPrompt {
#[must_use]
pub fn new(
role: AgentRole,
objective: impl Into<String>,
context: PromptContext,
output_contract: OutputContract,
) -> Self {
Self {
role,
objective: objective.into(),
context,
constraints: HashSet::new(),
output_contract,
}
}
#[must_use]
pub fn with_constraint(mut self, constraint: Constraint) -> Self {
self.constraints.insert(constraint);
self
}
#[must_use]
pub fn with_constraints(mut self, constraints: impl IntoIterator<Item = Constraint>) -> Self {
self.constraints.extend(constraints);
self
}
#[must_use]
pub fn to_edn(&self) -> String {
let mut s = String::new();
s.push_str("{:r ");
s.push_str(self.role.to_keyword());
s.push_str("\n :o :");
s.push_str(&self.objective.replace(' ', "-"));
s.push_str("\n :c {");
let mut first_key = true;
for (key, facts) in &self.context.facts {
if !first_key {
s.push(' ');
}
first_key = false;
s.push_str(context_key_to_keyword(*key));
s.push_str(" [{");
for (i, fact) in facts.iter().enumerate() {
if i > 0 {
s.push_str("} {");
}
s.push_str(":id \"");
s.push_str(&escape_string(&fact.id));
s.push_str("\" :c \"");
s.push_str(&escape_string(&fact.content));
s.push('"');
}
s.push_str("}]");
}
s.push_str("}\n :k #{");
let mut constraints: Vec<_> = self.constraints.iter().collect();
constraints.sort(); for (i, constraint) in constraints.iter().enumerate() {
if i > 0 {
s.push(' ');
}
s.push_str(constraint.to_keyword());
}
s.push_str("}\n :out {:emit :");
s.push_str(&self.output_contract.emit);
s.push_str(" :key ");
s.push_str(context_key_to_keyword(self.output_contract.key));
if let Some(ref format) = self.output_contract.format {
s.push_str(" :format :");
s.push_str(format);
}
s.push_str("}}");
s
}
#[must_use]
pub fn to_plain(&self) -> String {
let mut s = String::new();
writeln!(s, "Role: {:?}", self.role).unwrap();
writeln!(s, "Objective: {}", self.objective).unwrap();
writeln!(s, "\nContext:").unwrap();
for (key, facts) in &self.context.facts {
writeln!(s, "\n## {key:?}").unwrap();
for fact in facts {
writeln!(s, "- {}: {}", fact.id, fact.content).unwrap();
}
}
if !self.constraints.is_empty() {
writeln!(s, "\nConstraints:").unwrap();
for constraint in &self.constraints {
writeln!(s, "- {constraint:?}").unwrap();
}
}
writeln!(
s,
"\nOutput: {:?} -> {:?}",
self.output_contract.emit, self.output_contract.key
)
.unwrap();
s
}
#[must_use]
pub fn serialize(&self, format: PromptFormat) -> String {
match format {
PromptFormat::Edn => self.to_edn(),
PromptFormat::Plain => self.to_plain(),
}
}
}
/// Escapes `s` for embedding inside a double-quoted EDN string literal.
///
/// Backslash, double quote, newline, carriage return and tab become their
/// backslash escape sequences. Every other character is copied unchanged.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            other => escaped.push(other),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, Fact};

    /// Builds a fact stored under `key` with the given id and content.
    fn fact(key: ContextKey, id: &str, content: &str) -> Fact {
        Fact {
            key,
            id: id.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn test_edn_serialization() {
        let mut ctx = PromptContext::new();
        ctx.add_facts(
            ContextKey::Signals,
            vec![
                fact(ContextKey::Signals, "s1", "Revenue +15% Q3"),
                fact(ContextKey::Signals, "s2", "Market $2.3B"),
            ],
        );
        let edn = AgentPrompt::new(
            AgentRole::Proposer,
            "extract-competitors",
            ctx,
            OutputContract::new("proposed-fact", ContextKey::Competitors),
        )
        .with_constraints([Constraint::NoInvent, Constraint::NoContradict])
        .to_edn();
        let expected_fragments = [
            ":r :proposer",
            ":o :extract-competitors",
            ":signals",
            ":no-invent",
            ":no-contradict",
            ":competitors",
        ];
        for fragment in expected_fragments {
            assert!(edn.contains(fragment), "missing {fragment} in {edn}");
        }
    }

    #[test]
    fn test_context_building() {
        let mut context = Context::new();
        context
            .add_fact(fact(ContextKey::Seeds, "seed1", "Test seed"))
            .unwrap();
        let prompt_ctx = PromptContext::from_context(&context, &[ContextKey::Seeds]);
        assert_eq!(prompt_ctx.facts.len(), 1);
        let (key, facts) = &prompt_ctx.facts[0];
        assert_eq!(*key, ContextKey::Seeds);
        assert_eq!(facts.len(), 1);
    }

    #[test]
    fn test_escape_string() {
        let cases = [
            ("hello", "hello"),
            ("hello\"world", "hello\\\"world"),
            ("hello\nworld", "hello\\nworld"),
        ];
        for (input, expected) in cases {
            assert_eq!(escape_string(input), expected);
        }
    }

    #[test]
    fn test_token_efficiency() {
        let mut ctx = PromptContext::new();
        ctx.add_facts(
            ContextKey::Signals,
            vec![fact(ContextKey::Signals, "s1", "Revenue +15% Q3")],
        );
        let prompt = AgentPrompt::new(
            AgentRole::Proposer,
            "analyze",
            ctx,
            OutputContract::new("proposed-fact", ContextKey::Strategies),
        );
        let (edn, plain) = (prompt.to_edn(), prompt.to_plain());
        println!("EDN length: {}", edn.len());
        println!("Plain length: {}", plain.len());
        println!("EDN:\n{edn}");
        println!("Plain:\n{plain}");
        assert!(!edn.is_empty());
        assert!(!plain.is_empty());
    }
}