use super::dialogue::ExecutionModel;
use super::{Agent, AgentError, Payload, RelatedParticipant, participant_relation};
use crate::ToPrompt;
use crate::agent::payload_message::format_messages_with_relation;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
// Visual presentation data that can be attached to a `Persona`: a required
// icon plus an optional tagline and color. The prompt template declared for
// this type is the empty string.
#[derive(ToPrompt, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[prompt(template = "")] pub struct VisualIdentity {
// Icon text; `Persona::display_name` places it in front of the persona name.
pub icon: String,
// Optional tagline; omitted from serialized output while it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tagline: Option<String>,
// Optional color value; omitted from serialized output while it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub color: Option<String>,
}
impl VisualIdentity {
    /// Builds an identity that carries only an icon; tagline and color start unset.
    pub fn new(icon: impl Into<String>) -> Self {
        let icon = icon.into();
        Self {
            icon,
            tagline: None,
            color: None,
        }
    }

    /// Returns the identity with its tagline set to `tagline`, replacing any previous one.
    pub fn with_tagline(self, tagline: impl Into<String>) -> Self {
        Self {
            tagline: Some(tagline.into()),
            ..self
        }
    }

    /// Returns the identity with its color set to `color`, replacing any previous one.
    pub fn with_color(self, color: impl Into<String>) -> Self {
        Self {
            color: Some(color.into()),
            ..self
        }
    }
}
// Description of a persona an agent should adopt.
//
// The prompt template below renders a "Persona Profile" document: the icon
// (when a visual identity is present) precedes the header, followed by name,
// role, an optional tagline line, the background and communication-style
// sections, and a bulleted "Capabilities" section when capabilities are set.
#[derive(ToPrompt, Serialize, Deserialize, Clone, Debug)]
#[prompt(
template = "{% if visual_identity %}{{ visual_identity.icon }} {% endif %}# Persona Profile
**Name**: {{ name }}
**Role**: {{ role }}
{% if visual_identity and visual_identity.tagline %}**Tagline**: {{ visual_identity.tagline }}
{% endif %}
## Background
{{ background }}
## Communication Style
{{ communication_style }}
{% if capabilities %}
## Capabilities
{% for cap in capabilities %}
- {{ cap }}
{% endfor %}
{% endif %}"
)]
pub struct Persona {
// Persona name; also passed as the "self" name when participants and
// messages are formatted by `PersonaAgent`.
pub name: String,
// Role text; `PersonaAgent::expertise` returns a reference to this field.
pub role: String,
// Free-form background text rendered under "## Background".
pub background: String,
// Free-form text rendered under "## Communication Style".
pub communication_style: String,
// Optional icon/tagline/color; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub visual_identity: Option<VisualIdentity>,
// Optional capability list; defaults to `None` when absent on input and is
// omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub capabilities: Option<Vec<super::Capability>>,
}
impl Persona {
    /// Creates a persona with the given name and role.
    ///
    /// Background and communication style start as empty strings; visual
    /// identity and capabilities start unset.
    pub fn new(name: impl Into<String>, role: impl Into<String>) -> Self {
        let name = name.into();
        let role = role.into();
        Self {
            name,
            role,
            background: String::new(),
            communication_style: String::new(),
            visual_identity: None,
            capabilities: None,
        }
    }

    /// Returns the persona with its background replaced by `background`.
    pub fn with_background(self, background: impl Into<String>) -> Self {
        Self {
            background: background.into(),
            ..self
        }
    }

    /// Returns the persona with its communication style replaced by `style`.
    pub fn with_communication_style(self, style: impl Into<String>) -> Self {
        Self {
            communication_style: style.into(),
            ..self
        }
    }

    /// Returns the persona with `identity` as its visual identity.
    pub fn with_visual_identity(self, identity: VisualIdentity) -> Self {
        Self {
            visual_identity: Some(identity),
            ..self
        }
    }

    /// Returns the persona with a fresh icon-only visual identity.
    ///
    /// Any previously set visual identity (including tagline and color) is replaced.
    pub fn with_icon(self, icon: impl Into<String>) -> Self {
        self.with_visual_identity(VisualIdentity::new(icon))
    }

    /// Borrows the icon, or `None` when no visual identity is set.
    pub fn icon(&self) -> Option<&str> {
        let identity = self.visual_identity.as_ref()?;
        Some(identity.icon.as_str())
    }

    /// Borrows the tagline, or `None` when no visual identity or no tagline is set.
    pub fn tagline(&self) -> Option<&str> {
        self.visual_identity.as_ref()?.tagline.as_deref()
    }

    /// Returns the persona with `capabilities` as its capability list.
    pub fn with_capabilities(self, capabilities: Vec<super::Capability>) -> Self {
        Self {
            capabilities: Some(capabilities),
            ..self
        }
    }

    /// Returns `"<icon> <name>"` when an icon is set, otherwise a copy of the name.
    pub fn display_name(&self) -> String {
        self.icon().map_or_else(
            || self.name.clone(),
            |icon| format!("{} {}", icon, self.name),
        )
    }
}
/// A named group of personas together with the scenario they belong to.
///
/// Serializable with serde; `PersonaTeam::load` and `PersonaTeam::save`
/// read and write it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonaTeam {
/// Name of the team.
pub team_name: String,
/// Text describing the scenario/context of the team.
pub context: String,
/// Team members, in insertion order.
pub personas: Vec<Persona>,
/// Optional execution model; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub execution_strategy: Option<ExecutionModel>,
}
// Prompt input asking a model to generate a `PersonaTeam` as JSON.
//
// The template always renders the scenario context and the task
// instructions; the "Required Roles" and "Team Structure" sections are
// rendered only when the corresponding field is non-empty.
#[derive(Serialize, ToPrompt)]
#[prompt(template = r##"
# Persona Team Generation Task
You are an expert in team composition and organizational dynamics. Your task is to generate a well-balanced team of personas for a specific scenario.
## Scenario Context
{{ context }}
{% if role_descriptions %}
## Required Roles
{{ role_descriptions }}
{% endif %}
{% if team_graph %}
### Team Structure
```mermaid
{{ team_graph }}
```
{% endif %}
---
## Your Task
Generate a PersonaTeam as a JSON object with the following structure:
```json
{
"team_name": "A descriptive name for this team (e.g., 'HR SaaS Development Team')",
"context": "Brief description of the scenario/context",
"execution_strategy": "sequential" or "broadcast" (choose based on the scenario),
"personas": [
{
"name": "A realistic name for this person",
"role": "Job title/function (e.g., 'Product Owner', 'UX Designer')",
"background": "2-3 sentences describing their relevant experience, expertise, and perspective. Be specific to make them feel real.",
"communication_style": "1-2 sentences describing how they communicate, make decisions, and collaborate with others."
}
]
}
```
**Guidelines:**
1. **Team Composition**: Analyze the scenario and create personas that cover all necessary perspectives and expertise areas
2. **Balanced Team**: Ensure diversity in perspectives, experience levels, and communication styles
3. **Realistic Personas**: Give each persona:
- A realistic name (vary cultural backgrounds naturally)
- Specific expertise and experience (not generic)
- Distinct communication style (some data-driven, some user-focused, etc.)
- Believable background (mention years of experience, past roles, specializations)
4. **Execution Strategy Selection**:
- Use "broadcast" for brainstorming, reviews, or when all voices should be heard simultaneously
- Use "sequential" for process-driven workflows (e.g., requirements → design → implementation → QA)
5. **Team Size**: Generally 3-6 personas is ideal. Too few lacks perspective; too many becomes unwieldy.
6. **Role Coverage**: For typical product development scenarios, consider including:
- Decision maker (Product Owner, Tech Lead)
- User advocate (UX Designer, Customer Success)
- Technical experts (Engineers, Architects)
- Quality/Risk (QA, Security, DevOps if relevant)
**Important:** Return ONLY the JSON object, no additional explanation or commentary.
"##)]
pub struct PersonaTeamGenerationRequest {
// Scenario description rendered under "## Scenario Context".
pub context: String,
// Text for the "Required Roles" section; skipped during serialization when
// empty, which leaves that section out of the rendered prompt.
#[serde(skip_serializing_if = "String::is_empty")]
pub role_descriptions: String,
// Text placed inside the mermaid code fence of the "Team Structure"
// section; skipped during serialization when empty.
#[serde(skip_serializing_if = "String::is_empty")]
pub team_graph: String,
}
impl PersonaTeamGenerationRequest {
    /// Creates a request for `context` with no role descriptions and no team graph.
    pub fn new(context: String) -> Self {
        let role_descriptions = String::new();
        let team_graph = String::new();
        Self {
            context,
            role_descriptions,
            team_graph,
        }
    }

    /// Returns the request with its role descriptions replaced by `descriptions`.
    pub fn with_role_descriptions(self, descriptions: String) -> Self {
        Self {
            role_descriptions: descriptions,
            ..self
        }
    }

    /// Returns the request with its team graph replaced by `graph`.
    pub fn with_team_graph(self, graph: String) -> Self {
        Self {
            team_graph: graph,
            ..self
        }
    }
}
impl PersonaTeam {
    /// Creates an empty team with the given name and context and no execution strategy.
    pub fn new(team_name: String, context: String) -> Self {
        let personas = Vec::new();
        Self {
            team_name,
            context,
            personas,
            execution_strategy: None,
        }
    }

    /// Appends `persona` to the team and returns `self` so calls can be chained.
    pub fn add_persona(&mut self, persona: Persona) -> &mut Self {
        self.personas.push(persona);
        self
    }

    /// Returns the team with `strategy` as its execution strategy.
    pub fn with_execution_strategy(self, strategy: ExecutionModel) -> Self {
        Self {
            execution_strategy: Some(strategy),
            ..self
        }
    }

    /// Reads a team from the JSON file at `path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from reading the file, or an
    /// `ErrorKind::InvalidData` error wrapping the JSON parse failure.
    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self, std::io::Error> {
        let raw = std::fs::read_to_string(path)?;
        match serde_json::from_str::<Self>(&raw) {
            Ok(team) => Ok(team),
            Err(parse_err) => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                parse_err,
            )),
        }
    }

    /// Writes the team to `path` as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::InvalidData` error wrapping a serialization
    /// failure, or the I/O error from writing the file.
    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), std::io::Error> {
        match serde_json::to_string_pretty(self) {
            Ok(json) => std::fs::write(path, json),
            Err(encode_err) => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                encode_err,
            )),
        }
    }
}
/// Pairs a clone of every participant with its relation to `self_name`.
///
/// The relation is computed by `participant_relation`; output order follows
/// the input order.
fn relate_participants<'a>(
    participants: impl IntoIterator<Item = &'a super::dialogue::ParticipantInfo>,
    self_name: &str,
) -> Vec<RelatedParticipant> {
    let mut related = Vec::new();
    for info in participants {
        let owned = info.clone();
        let relation = participant_relation(&owned, self_name);
        related.push(RelatedParticipant::new(owned, relation));
    }
    related
}
/// Renders one line per participant (via `format_line`) and joins them with newlines.
fn format_participants_with_relation(
    participants: &[super::dialogue::ParticipantInfo],
    self_name: &str,
) -> String {
    let mut lines = Vec::new();
    for related in relate_participants(participants.iter(), self_name) {
        lines.push(related.format_line());
    }
    lines.join("\n")
}
// Data for the final prompt that `PersonaAgent::execute` hands to the inner
// agent. The template renders the persona first, then each optional section
// only when its field is non-empty: participants (before or after the
// context), conversation context, current messages, and a trailing prompt.
#[derive(ToPrompt, Serialize)]
#[prompt(template = r#"YOU ARE A PERSONA-DRIVEN AI AGENT.
{{persona}}
{% if participants_before %}
# Participants
{{participants_before}}
{% endif %}{% if context %}
# Conversation Context (History)
{{context}}{% endif %}
{% if participants_after %}
# Participants
{{participants_after}}
{% endif %}{% if current_content %}
# Current Messages
{{current_content}}
{% endif %}{% if trailing_prompt %}
{{trailing_prompt}}
{% endif %}"#)]
struct PersonaAgentPrompt {
// Persona rendered at the `{{persona}}` placeholder.
persona: Persona,
// Participant list shown before the context section; empty when the list
// is placed after the context or there are no participants.
#[serde(skip_serializing_if = "String::is_empty")]
participants_before: String,
// Text for the "Conversation Context (History)" section.
#[serde(skip_serializing_if = "String::is_empty")]
context: String,
// Participant list shown after the context section; empty when the list
// is placed before the context or there are no participants.
#[serde(skip_serializing_if = "String::is_empty")]
participants_after: String,
// Text for the "Current Messages" section.
#[serde(skip_serializing_if = "String::is_empty")]
current_content: String,
// Final line of the prompt; `execute` sets it to "YOU (<name>):" when enabled.
#[serde(skip_serializing_if = "String::is_empty")]
trailing_prompt: String,
}
/// Settings that control how `PersonaAgent` lays out its prompt.
#[derive(Debug, Clone)]
pub struct ContextConfig {
/// When payload contexts are present, a conversation counts as "long" once
/// the summed `content.len()` of all messages reaches this value.
pub long_conversation_threshold: usize,
/// In a long conversation, how many of the most recent messages are kept in
/// the "Current Messages" section; earlier ones move into the context section.
pub recent_messages_count: usize,
/// When `true`, the participants section is rendered after the context
/// section instead of before it.
pub participants_after_context: bool,
/// When `true`, the prompt ends with a `YOU (<persona name>):` line.
pub include_trailing_prompt: bool,
}
impl Default for ContextConfig {
    /// Default layout: threshold 5000, 10 recent messages, participants before
    /// the context, and no trailing prompt.
    fn default() -> Self {
        const DEFAULT_LONG_CONVERSATION_THRESHOLD: usize = 5000;
        const DEFAULT_RECENT_MESSAGES_COUNT: usize = 10;

        Self {
            long_conversation_threshold: DEFAULT_LONG_CONVERSATION_THRESHOLD,
            recent_messages_count: DEFAULT_RECENT_MESSAGES_COUNT,
            participants_after_context: false,
            include_trailing_prompt: false,
        }
    }
}
/// Agent wrapper that rewrites each payload's text into a persona prompt and
/// then delegates execution to the wrapped agent.
pub struct PersonaAgent<T: Agent> {
/// Agent that receives the rewritten payload.
inner_agent: T,
/// Persona rendered into every prompt.
persona: Persona,
/// Prompt layout settings; `ContextConfig::default()` unless overridden.
context_config: ContextConfig,
}
impl<T: Agent> PersonaAgent<T> {
    /// Wraps `inner_agent` with `persona`, using the default context configuration.
    pub fn new(inner_agent: T, persona: Persona) -> Self {
        let context_config = ContextConfig::default();
        Self {
            inner_agent,
            persona,
            context_config,
        }
    }

    /// Returns the agent with its context configuration replaced by `config`.
    pub fn with_context_config(self, config: ContextConfig) -> Self {
        Self {
            context_config: config,
            ..self
        }
    }
}
#[async_trait]
impl<T> Agent for PersonaAgent<T>
where
    T: Agent + Send + Sync,
    T::Output: Send,
{
    type Output = T::Output;
    type Expertise = String;

    /// Reports the persona's role as this agent's expertise.
    fn expertise(&self) -> &String {
        &self.persona.role
    }

    /// Returns a copy of the persona's capability list, if any.
    fn capabilities(&self) -> Option<Vec<super::Capability>> {
        self.persona.capabilities.clone()
    }

    /// Builds the persona prompt from `intent` and runs the inner agent on it.
    ///
    /// The prompt is assembled from:
    /// - the persona profile;
    /// - the participant list (each entry related to this persona's name);
    /// - a context section made of the payload text, the system messages and,
    ///   when the payload carries contexts, a "Basic Context" block;
    /// - the formatted messages.
    ///
    /// When contexts are present and the summed message content length reaches
    /// `long_conversation_threshold`, only the last `recent_messages_count`
    /// messages stay in the "Current Messages" section; earlier ones are moved
    /// into the context section ahead of the "Basic Context" block.
    ///
    /// The payload forwarded to the inner agent is `intent` with its text
    /// replaced by the generated prompt.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the inner agent returns.
    #[crate::tracing::instrument(
        name = "persona_agent.execute",
        skip(self, intent),
        fields(
            agent.name = %self.persona.name,
            agent.role = %self.persona.role,
            has_participants = intent.participants().is_some(),
            message_count = intent.to_messages().len(),
        )
    )]
    async fn execute(&self, intent: Payload) -> Result<Self::Output, AgentError> {
        let participants_text = intent
            .participants()
            .map(|participants| format_participants_with_relation(participants, &self.persona.name))
            .unwrap_or_default();

        let contexts = intent.contexts();
        let context_string = (!contexts.is_empty()).then(|| contexts.join("\n\n"));

        // Materialise the messages once; they feed both the system-message
        // extraction and the message formatting below.
        let messages = intent.to_messages();
        let total_content_count = intent.total_content_count();

        // Payload text first, then all system messages, separated by a blank line.
        let mut context_text = String::new();
        context_text.push_str(&intent.to_text());
        let system_messages: Vec<&str> = messages
            .iter()
            .filter(|msg| matches!(msg.speaker, super::dialogue::Speaker::System))
            .map(|msg| msg.content.as_str())
            .collect();
        if !system_messages.is_empty() {
            if !context_text.is_empty() {
                context_text.push_str("\n\n");
            }
            context_text.push_str(&system_messages.join("\n\n"));
        }

        let (context_with_basic, current_messages_text) = match context_string {
            // No payload contexts: every message goes to "Current Messages".
            None => {
                let current_messages_text = format_messages_with_relation(
                    &messages,
                    &self.persona.name,
                    total_content_count,
                );
                (context_text, current_messages_text)
            }
            Some(ctx_str) => {
                let basic_context_section =
                    format!("\n\n---\n\n# Basic Context\n\n{}\n\n---\n\n", ctx_str);
                let total_message_length: usize = messages.iter().map(|m| m.content.len()).sum();
                let is_long_conversation =
                    total_message_length >= self.context_config.long_conversation_threshold;

                let mut combined_context = context_text;
                if is_long_conversation {
                    // Older messages become history; the basic context is
                    // placed right before the recent messages.
                    let split_point = messages
                        .len()
                        .saturating_sub(self.context_config.recent_messages_count);
                    let (old_messages, recent_messages) = messages.split_at(split_point);
                    let old_messages_text = if old_messages.is_empty() {
                        String::new()
                    } else {
                        format_messages_with_relation(
                            old_messages,
                            &self.persona.name,
                            total_content_count,
                        )
                    };
                    let recent_messages_text = if recent_messages.is_empty() {
                        String::new()
                    } else {
                        format_messages_with_relation(
                            recent_messages,
                            &self.persona.name,
                            total_content_count,
                        )
                    };
                    if !old_messages_text.is_empty() {
                        if !combined_context.is_empty() {
                            combined_context.push_str("\n\n");
                        }
                        combined_context.push_str(&old_messages_text);
                    }
                    combined_context.push_str(&basic_context_section);
                    (combined_context, recent_messages_text)
                } else {
                    combined_context.push_str(&basic_context_section);
                    let all_messages_text = format_messages_with_relation(
                        &messages,
                        &self.persona.name,
                        total_content_count,
                    );
                    (combined_context, all_messages_text)
                }
            }
        };

        let (participants_before, participants_after) =
            if self.context_config.participants_after_context {
                (String::new(), participants_text)
            } else {
                (participants_text, String::new())
            };
        let trailing_prompt = if self.context_config.include_trailing_prompt {
            format!("YOU ({}):", self.persona.name)
        } else {
            String::new()
        };

        let prompt_struct = PersonaAgentPrompt {
            persona: self.persona.clone(),
            participants_before,
            context: context_with_basic,
            participants_after,
            current_content: current_messages_text,
            trailing_prompt,
        };
        let prompt_text = prompt_struct.to_prompt();
        crate::tracing::debug!(
            target: "llm_toolkit::agent::persona",
            persona_name = %self.persona.name,
            prompt_length = prompt_text.len(),
            "Generated persona prompt"
        );
        // The full prompt is available through this trace event, so no
        // separate test-only stderr dump is emitted.
        crate::tracing::trace!(
            target: "llm_toolkit::agent::persona",
            "\n========== PERSONA PROMPT ==========\n{}\n====================================",
            prompt_text
        );

        // `intent` is not needed afterwards, so it is consumed instead of cloned.
        let final_payload = intent.set_text(prompt_text);
        self.inner_agent.execute(final_payload).await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::dialogue::Speaker;
use crate::agent::{Agent, AgentError, Payload, PayloadMessage};
use crate::attachment::Attachment;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Test double that stores every payload it is executed with and always
/// answers with a clone of a fixed response.
#[derive(Clone)]
struct RecordingAgent<T: Clone + Serialize + DeserializeOwned + Send + Sync + 'static> {
    calls: Arc<Mutex<Vec<Payload>>>,
    response: T,
}

impl<T: Clone + Serialize + DeserializeOwned + Send + Sync + 'static> RecordingAgent<T> {
    /// Creates a recorder with an empty call log that replies with `response`.
    fn new(response: T) -> Self {
        let calls = Arc::new(Mutex::new(Vec::new()));
        Self { calls, response }
    }

    /// Returns a clone of the most recently recorded payload, if any.
    async fn last_call(&self) -> Option<Payload> {
        let recorded = self.calls.lock().await;
        recorded.last().cloned()
    }
}

#[async_trait]
impl<T> Agent for RecordingAgent<T>
where
    T: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
{
    type Output = T;
    type Expertise = &'static str;

    fn expertise(&self) -> &&'static str {
        &"Test agent"
    }

    /// Appends `intent` to the call log and returns the canned response.
    async fn execute(&self, intent: Payload) -> Result<Self::Output, AgentError> {
        {
            let mut recorded = self.calls.lock().await;
            recorded.push(intent);
        }
        Ok(self.response.clone())
    }
}
#[tokio::test]
async fn persona_agent_preserves_attachments() {
let persona = Persona {
name: "Tester".to_string(),
role: "Attachment Checker".to_string(),
background: "Validates payload handling.".to_string(),
communication_style: "Direct and concise.".to_string(),
visual_identity: None,
capabilities: None,
};
let base_agent = RecordingAgent::new(String::from("ok"));
let persona_agent = PersonaAgent::new(base_agent.clone(), persona);
let attachment = Attachment::in_memory(vec![1, 2, 3]);
let payload = Payload::text("Please inspect the data").with_attachment(attachment.clone());
let _ = persona_agent.execute(payload).await.unwrap();
let recorded_payload = base_agent.last_call().await.expect("call recorded");
assert!(
recorded_payload.has_attachments(),
"attachments should be preserved"
);
let attachments = recorded_payload.attachments();
assert_eq!(attachments.len(), 1);
assert_eq!(attachments[0], &attachment);
}
#[tokio::test]
async fn persona_agent_works() {
let persona = Persona {
name: "TestBot".to_string(),
role: "Test Assistant".to_string(),
background: "A helpful test bot for unit testing".to_string(),
communication_style: "Direct and clear".to_string(),
visual_identity: None,
capabilities: None,
};
let base_agent = RecordingAgent::new(String::from("response"));
let persona_agent = PersonaAgent::new(base_agent.clone(), persona);
let result = persona_agent
.execute(Payload::text("Initial conversation").with_message(
Speaker::User {
name: "User1".to_string(),
role: "User".to_string(),
},
"additional context here".to_string(),
))
.await
.unwrap();
assert_eq!(result, "response");
let call = base_agent.last_call().await.expect("call recorded");
let call_text = call.to_text();
assert!(call_text.contains("Persona Profile"));
assert!(call_text.contains("TestBot"));
assert!(call_text.contains("Test Assistant"));
}
// `Persona::to_prompt` must expand every placeholder and keep the section labels.
#[test]
fn persona_to_prompt_template_expansion() {
    use crate::ToPrompt;

    let persona = Persona::new("Alice", "Engineer")
        .with_background("Senior software engineer with 10 years of experience")
        .with_communication_style("Direct and clear");
    let prompt = persona.to_prompt();

    // No raw placeholder may survive rendering.
    for placeholder in [
        "{{ name }}",
        "{{ role }}",
        "{{ background }}",
        "{{ communication_style }}",
    ] {
        assert!(
            !prompt.contains(placeholder),
            "Template variables should be expanded, not left as placeholders"
        );
    }

    // Field values first, then the structural labels, in the original order.
    let expectations = [
        ("Alice", "Name should be in prompt"),
        ("Engineer", "Role should be in prompt"),
        ("Senior software engineer", "Background should be in prompt"),
        ("Direct and clear", "Communication style should be in prompt"),
        ("# Persona Profile", "Should have header"),
        ("**Name**:", "Should have Name label"),
        ("**Role**:", "Should have Role label"),
        ("## Background", "Should have Background section"),
        (
            "## Communication Style",
            "Should have Communication Style section",
        ),
    ];
    for (needle, message) in expectations {
        assert!(prompt.contains(needle), "{}", message);
    }

    println!("Generated prompt:\n{}", prompt);
}
// The nested persona must be rendered through its own template (markdown),
// not as a JSON dump of the struct.
#[test]
fn persona_agent_prompt_nested_template_expansion() {
    use crate::ToPrompt;

    let persona = Persona::new("Alice", "Engineer")
        .with_background("Senior software engineer")
        .with_communication_style("Direct and clear");
    let prompt = PersonaAgentPrompt {
        persona,
        participants_before: "- Bob (Developer)\n- Charlie (Designer)".to_string(),
        context: "additional context here".to_string(),
        participants_after: String::new(),
        current_content: "Please review the code".to_string(),
        trailing_prompt: String::new(),
    }
    .to_prompt();
    eprintln!(
        "=== Generated PersonaAgentPrompt ===\n{}\n=== End ===",
        prompt
    );

    let expected = r#"YOU ARE A PERSONA-DRIVEN AI AGENT.
# Persona Profile
**Name**: Alice
**Role**: Engineer
## Background
Senior software engineer
## Communication Style
Direct and clear
# Participants
- Bob (Developer)
- Charlie (Designer)
# Conversation Context (History)
additional context here
# Current Messages
Please review the code
"#;
    assert_eq!(prompt, expected);

    // JSON keys in the output would mean the persona was serialized verbatim.
    let is_json_serialized = [r#""name""#, r#""role""#]
        .iter()
        .any(|key| prompt.contains(key));
    if is_json_serialized {
        println!(
            "ISSUE CONFIRMED: Persona is being JSON serialized instead of using its ToPrompt template"
        );
        println!("Expected: Persona's formatted template with markdown");
        println!("Actual: JSON representation of Persona struct");
    }
    assert!(
        prompt.contains("# Persona Profile"),
        "Should contain Persona's template header (not JSON)"
    );
    assert!(
        prompt.contains("**Name**: Alice"),
        "Should use Persona's template format (not JSON)"
    );
    assert!(
        !is_json_serialized,
        "Persona should use its ToPrompt template, not JSON serialization"
    );
}
// A team survives a JSON round trip with its name and member order intact.
#[test]
fn persona_team_serialization() {
    let alice = Persona::new("Alice", "Developer")
        .with_background("Senior engineer")
        .with_communication_style("Technical");
    let bob = Persona::new("Bob", "Designer")
        .with_background("UX specialist")
        .with_communication_style("User-focused");
    let mut team = PersonaTeam::new("Test Team".to_string(), "Testing scenario".to_string());
    team.add_persona(alice).add_persona(bob);

    let json = serde_json::to_string_pretty(&team).unwrap();
    for expected in ["Test Team", "Alice", "Bob"] {
        assert!(json.contains(expected));
    }

    let restored: PersonaTeam = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.team_name, "Test Team");
    assert_eq!(restored.personas.len(), 2);
    assert_eq!(restored.personas[0].name, "Alice");
    assert_eq!(restored.personas[1].name, "Bob");
}
// `save` followed by `load` on a temporary file yields the same team data.
#[test]
fn persona_team_load_save() {
    use tempfile::NamedTempFile;

    let charlie = Persona::new("Charlie", "Tech Lead")
        .with_background("10 years experience")
        .with_communication_style("Strategic");
    let mut team = PersonaTeam::new("Dev Team".to_string(), "Software development".to_string());
    team.add_persona(charlie);

    let temp_file = NamedTempFile::new().unwrap();
    team.save(temp_file.path()).unwrap();
    let reloaded = PersonaTeam::load(temp_file.path()).unwrap();

    assert_eq!(reloaded.team_name, "Dev Team");
    assert_eq!(reloaded.personas.len(), 1);
    assert_eq!(reloaded.personas[0].name, "Charlie");
}
// The generation prompt contains the task header, the context and the role descriptions.
#[test]
fn persona_team_generation_request_prompt() {
    use crate::ToPrompt;

    let context = "Product development meeting for HR SaaS".to_string();
    let roles = "PO, Designer, Engineer".to_string();
    let prompt = PersonaTeamGenerationRequest::new(context)
        .with_role_descriptions(roles)
        .to_prompt();

    for expected in [
        "Persona Team Generation Task",
        "Product development meeting",
        "PO, Designer, Engineer",
    ] {
        assert!(prompt.contains(expected));
    }
}
#[tokio::test]
async fn persona_agent_formats_participants_with_self_as_you() {
use crate::agent::dialogue::ParticipantInfo;
let persona = Persona {
name: "Alice".to_string(),
role: "PM".to_string(),
background: "Product manager".to_string(),
communication_style: "Strategic".to_string(),
visual_identity: None,
capabilities: None,
};
let participants = vec![
ParticipantInfo::new(
"Alice".to_string(),
"PM".to_string(),
"Product manager".to_string(),
),
ParticipantInfo::new(
"Bob".to_string(),
"Engineer".to_string(),
"Backend developer".to_string(),
),
];
let base_agent = RecordingAgent::new("response".to_string());
let persona_agent = PersonaAgent::new(base_agent.clone(), persona);
let payload = Payload::text("Task").with_participants(participants);
let _ = persona_agent.execute(payload).await.unwrap();
let call = base_agent.last_call().await.unwrap();
let call_text = call.to_text();
assert!(call_text.contains("Alice (YOU)"));
assert!(call_text.contains("Bob"));
assert!(!call_text.contains("Bob (YOU)"));
}
#[tokio::test]
async fn persona_agent_formats_messages() {
let persona = Persona {
name: "Agent".to_string(),
role: "Assistant".to_string(),
background: "Helper".to_string(),
communication_style: "Friendly".to_string(),
visual_identity: None,
capabilities: None,
};
let base_agent = RecordingAgent::new("response".to_string());
let persona_agent = PersonaAgent::new(base_agent.clone(), persona);
let payload = Payload::from_messages(vec![
PayloadMessage::system("System instruction"),
PayloadMessage::user("Alice", "PM", "User message"),
]);
let _ = persona_agent.execute(payload).await.unwrap();
let call = base_agent.last_call().await.unwrap();
let call_text = call.to_text();
assert!(call_text.contains("[System]: System instruction"));
assert!(call_text.contains("[Alice]: User message"));
}
#[tokio::test]
async fn persona_agent_preserves_messages_structure() {
use crate::agent::dialogue::Speaker;
let persona = Persona {
name: "Agent".to_string(),
role: "Assistant".to_string(),
background: "Helper".to_string(),
communication_style: "Friendly".to_string(),
visual_identity: None,
capabilities: None,
};
let base_agent = RecordingAgent::new("response".to_string());
let persona_agent = PersonaAgent::new(base_agent.clone(), persona);
let original_messages = vec![
PayloadMessage::system("System msg"),
PayloadMessage::user("Alice", "PM", "User msg"),
];
let payload = Payload::from_messages(original_messages.clone());
let _ = persona_agent.execute(payload).await.unwrap();
let call = base_agent.last_call().await.unwrap();
let received_messages = call.to_messages();
assert_eq!(received_messages.len(), original_messages.len());
assert_eq!(received_messages[0].speaker, Speaker::System);
assert_eq!(received_messages[0].content, "System msg");
assert_eq!(received_messages[1].speaker, Speaker::user("Alice", "PM"));
assert_eq!(received_messages[1].content, "User msg");
}
#[tokio::test]
async fn persona_agent_full_integration() {
use crate::agent::dialogue::{ParticipantInfo, Speaker};
let persona = Persona {
name: "Alice".to_string(),
role: "PM".to_string(),
background: "Product manager with 5 years experience".to_string(),
communication_style: "Strategic and data-driven".to_string(),
visual_identity: None,
capabilities: None,
};
let participants = vec![
ParticipantInfo::new(
"Alice".to_string(),
"PM".to_string(),
"Product manager".to_string(),
),
ParticipantInfo::new(
"Bob".to_string(),
"Engineer".to_string(),
"Backend developer".to_string(),
),
];
let messages = vec![
PayloadMessage::system("Discuss feature priorities"),
PayloadMessage::user("Bob", "Engineer", "I suggest we focus on performance"),
];
let base_agent = RecordingAgent::new("Good idea".to_string());
let persona_agent = PersonaAgent::new(base_agent.clone(), persona);
let payload = Payload::from_messages(messages.clone()).with_participants(participants);
let result = persona_agent.execute(payload).await.unwrap();
assert_eq!(result, "Good idea");
let call = base_agent.last_call().await.unwrap();
let call_text = call.to_text();
println!("=== Actual call_text ===\n{}\n=== End ===", call_text);
assert!(call_text.contains("# Persona Profile"));
assert!(call_text.contains("**Name**: Alice"));
assert!(call_text.contains("# Participants"));
assert!(call_text.contains("**Alice (YOU)**"));
assert!(call_text.contains("**Bob (ALLY)**"));
assert!(call_text.contains("# Current Messages"));
assert!(call_text.contains("[System]: Discuss feature priorities"));
assert!(call_text.contains("[Bob]: I suggest we focus on performance"));
let received_messages = call.to_messages();
assert_eq!(received_messages.len(), 2);
assert_eq!(received_messages[0].speaker, Speaker::System);
assert_eq!(received_messages[0].content, "Discuss feature priorities");
assert_eq!(
received_messages[1].speaker,
Speaker::user("Bob", "Engineer")
);
assert_eq!(
received_messages[1].content,
"I suggest we focus on performance"
);
}
// Capabilities render as one "- ..." bullet per entry under "## Capabilities".
#[test]
fn persona_capabilities_formatting() {
    use crate::ToPrompt;
    use crate::agent::Capability;

    let capabilities = vec![
        Capability::new("file:read"),
        Capability::new("file:write").with_description("Write content to a file"),
        Capability::new("api:call").with_description("Make HTTP API calls"),
    ];
    let persona = Persona::new("DevBot", "Development Assistant")
        .with_background("Helps with software development tasks")
        .with_communication_style("Technical and precise")
        .with_capabilities(capabilities);
    let prompt = persona.to_prompt();

    let expectations = [
        ("## Capabilities", "Capabilities section should be present"),
        (
            "- file:read",
            "First capability should be properly formatted",
        ),
        (
            "- file:write: Write content to a file",
            "Second capability with description should be properly formatted",
        ),
        (
            "- api:call: Make HTTP API calls",
            "Third capability with description should be properly formatted",
        ),
    ];
    for (needle, message) in expectations {
        assert!(prompt.contains(needle), "{}", message);
    }

    // The first three non-blank lines after the heading must be full bullet lines.
    let lines: Vec<&str> = prompt.lines().collect();
    let heading_index = lines
        .iter()
        .position(|line| line.contains("## Capabilities"))
        .expect("Should have Capabilities section");
    let cap_lines: Vec<&str> = lines[heading_index + 1..]
        .iter()
        .copied()
        .filter(|line| !line.trim().is_empty())
        .take(3)
        .collect();
    assert_eq!(cap_lines.len(), 3, "Should have 3 capability lines");
    for cap_line in &cap_lines {
        assert!(
            cap_line.len() > 5,
            "Capability line should be a full line, not single characters. Got: '{}'",
            cap_line
        );
        assert!(
            cap_line.starts_with('-'),
            "Capability line should start with '-'. Got: '{}'",
            cap_line
        );
    }
}
/// Mock agent that ignores its payload and replies with canned responses in
/// order, falling back to the first response once the list is exhausted.
#[derive(Clone)]
struct LocalMockAgent {
    responses: Vec<String>,
    call_count: std::sync::Arc<std::sync::Mutex<usize>>,
}

impl LocalMockAgent {
    /// Creates a mock that answers with `responses`, starting at the first entry.
    fn new(responses: Vec<&str>) -> Self {
        let responses = responses.into_iter().map(String::from).collect();
        let call_count = std::sync::Arc::new(std::sync::Mutex::new(0));
        Self {
            responses,
            call_count,
        }
    }
}

#[async_trait]
impl Agent for LocalMockAgent {
    type Output = String;
    type Expertise = &'static str;

    fn expertise(&self) -> &&'static str {
        &"Mock agent for testing"
    }

    /// Returns the response for the current call index and advances the counter.
    async fn execute(&self, _payload: Payload) -> Result<Self::Output, AgentError> {
        let mut calls_so_far = self.call_count.lock().unwrap();
        // The fallback is looked up unconditionally, so an empty response list panics.
        let fallback = &self.responses[0];
        let response = self.responses.get(*calls_so_far).unwrap_or(fallback).clone();
        *calls_so_far += 1;
        Ok(response)
    }
}
// A short conversation with a context executes successfully.
#[tokio::test]
async fn test_persona_agent_context_placement_short_conversation() {
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("TestBot", "Test persona"),
    );
    let messages = vec![
        PayloadMessage::user("User1", "Role1", "Hello"),
        PayloadMessage::system("System response"),
    ];
    let payload = Payload::from_messages(messages).with_context("Important context");

    let outcome = agent.execute(payload).await;
    assert!(outcome.is_ok());
}
// A conversation exceeding a small threshold (long-conversation path) executes successfully.
#[tokio::test]
async fn test_persona_agent_context_placement_long_conversation() {
    let config = ContextConfig {
        long_conversation_threshold: 50,
        recent_messages_count: 2,
        participants_after_context: false,
        include_trailing_prompt: false,
    };
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("TestBot", "Test persona"),
    )
    .with_context_config(config);

    // Ten user/system exchanges, interleaved.
    let messages: Vec<_> = (0..10)
        .flat_map(|i| {
            [
                PayloadMessage::user("User1", "Role1", format!("Message {}", i)),
                PayloadMessage::system(format!("Response {}", i)),
            ]
        })
        .collect();
    let payload = Payload::from_messages(messages).with_context("Strategic context");

    let outcome = agent.execute(payload).await;
    assert!(outcome.is_ok());
}
// A payload carrying several contexts executes successfully.
#[tokio::test]
async fn test_persona_agent_multiple_contexts() {
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("TestBot", "Test persona"),
    );
    let payload = Payload::text("Question")
        .with_context("Context 1")
        .with_context("Context 2")
        .with_context("Context 3");

    let outcome = agent.execute(payload).await;
    assert!(outcome.is_ok());
}
// Without any context the inner agent's response is returned as-is.
#[tokio::test]
async fn test_persona_agent_no_context() {
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("TestBot", "Test persona"),
    );

    let outcome = agent.execute(Payload::text("Question without context")).await;
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap(), "Response");
}
// `with_context_config` stores the supplied configuration on the agent.
#[tokio::test]
async fn test_context_config_customization() {
    let custom = ContextConfig {
        long_conversation_threshold: 10000,
        recent_messages_count: 20,
        participants_after_context: false,
        include_trailing_prompt: false,
    };
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("TestBot", "Test persona"),
    )
    .with_context_config(custom);

    let stored = &agent.context_config;
    assert_eq!(stored.long_conversation_threshold, 10000);
    assert_eq!(stored.recent_messages_count, 20);
}
// Execution succeeds when participants are configured to follow the context.
#[tokio::test]
async fn test_participants_after_context_strategy() {
    let config = ContextConfig {
        long_conversation_threshold: 5000,
        recent_messages_count: 10,
        participants_after_context: true,
        include_trailing_prompt: false,
    };
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("TestBot", "Test persona"),
    )
    .with_context_config(config);

    let payload = Payload::text("Test message").with_context("Important context");
    let outcome = agent.execute(payload).await;
    assert!(outcome.is_ok());
}
#[tokio::test]
async fn test_trailing_prompt_strategy() {
let inner_agent = LocalMockAgent::new(vec!["Response"]);
let persona = Persona::new("Alice", "Engineer");
let config = ContextConfig {
long_conversation_threshold: 5000,
recent_messages_count: 10,
participants_after_context: false,
include_trailing_prompt: true,
};
let agent = PersonaAgent::new(inner_agent, persona).with_context_config(config);
let payload = Payload::text("What should we do?");
let result = agent.execute(payload).await;
assert!(result.is_ok());
}
// Long-conversation handling, participants-after-context and the trailing
// prompt can all be enabled together without the execution failing.
#[tokio::test]
async fn test_combined_strategies() {
    let config = ContextConfig {
        long_conversation_threshold: 100,
        recent_messages_count: 2,
        participants_after_context: true,
        include_trailing_prompt: true,
    };
    let agent = PersonaAgent::new(
        LocalMockAgent::new(vec!["Response"]),
        Persona::new("Bob", "Designer"),
    )
    .with_context_config(config);

    let messages: Vec<_> = (0..10)
        .map(|i| PayloadMessage::user("User1", "Role1", format!("Message {}", i)))
        .collect();
    let payload = Payload::from_messages(messages).with_context("Design context");

    let outcome = agent.execute(payload).await;
    assert!(outcome.is_ok());
}
}