use super::Agent;
use crate::agent::dialogue::joining_strategy::JoiningStrategy;
use crate::agent::history::HistoryAwareAgent;
use crate::agent::persona::{Persona, PersonaAgent};
pub struct Chat<A: Agent> {
agent: A,
with_history: bool,
identity: Option<(String, String)>, joining_strategy: Option<JoiningStrategy>,
}
impl<A: Agent> Chat<A> {
pub fn new(agent: A) -> Self {
Self {
agent,
with_history: true,
identity: None,
joining_strategy: None,
}
}
pub fn with_persona(self, persona: Persona) -> Chat<PersonaAgent<A>>
where
A::Output: Send,
{
let identity = Some((persona.name.clone(), persona.role.clone()));
Chat {
agent: PersonaAgent::new(self.agent, persona),
with_history: self.with_history,
identity,
joining_strategy: self.joining_strategy,
}
}
pub fn with_history(mut self, enabled: bool) -> Self {
self.with_history = enabled;
self
}
pub fn with_joining_strategy(mut self, joining_strategy: Option<JoiningStrategy>) -> Self {
self.joining_strategy = joining_strategy;
self
}
pub fn build(self) -> Box<crate::agent::AnyAgent<A::Output>>
where
A: 'static,
A::Output: 'static + Send,
{
if self.with_history {
match self.identity {
Some((name, role)) => crate::agent::AnyAgent::boxed(
HistoryAwareAgent::new_with_identity(self.agent, name, role),
),
None => crate::agent::AnyAgent::boxed(HistoryAwareAgent::new(self.agent)),
}
} else {
crate::agent::AnyAgent::boxed(self.agent)
}
}
}
impl<A: Agent> Chat<PersonaAgent<A>>
where
    A::Output: Send,
{
    /// Applies `config` to the wrapped [`PersonaAgent`] via its `with_context_config`.
    ///
    /// Only available when the builder wraps a `PersonaAgent`, e.g. after `Chat::with_persona`.
    /// The history flag, identity and joining strategy are left untouched.
    #[must_use]
    pub fn with_context_config(mut self, config: crate::agent::persona::ContextConfig) -> Self {
        self.agent = self.agent.with_context_config(config);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentError, Payload, PayloadMessage};
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Stub agent that records the text of every payload it receives and always replies
    /// with a fixed response.
    #[derive(Clone)]
    struct TestAgent {
        /// Shared between clones, so a test can keep one clone to inspect calls made
        /// through the clone handed to `Chat`.
        calls: Arc<Mutex<Vec<String>>>,
        response: String,
    }

    impl TestAgent {
        fn new(response: &str) -> Self {
            Self {
                calls: Arc::new(Mutex::new(Vec::new())),
                response: response.to_string(),
            }
        }

        async fn get_calls(&self) -> Vec<String> {
            self.calls.lock().await.clone()
        }
    }

    #[async_trait]
    impl Agent for TestAgent {
        type Output = String;
        type Expertise = &'static str;

        fn expertise(&self) -> &&'static str {
            const EXPERTISE: &str = "Test agent for Chat builder";
            &EXPERTISE
        }

        async fn execute(&self, intent: Payload) -> Result<Self::Output, AgentError> {
            self.calls.lock().await.push(intent.to_text());
            Ok(self.response.clone())
        }
    }

    /// Builds a test persona with no visual identity or capabilities.
    fn test_persona(name: &str, role: &str, background: &str, communication_style: &str) -> Persona {
        Persona {
            name: name.to_string(),
            role: role.to_string(),
            background: background.to_string(),
            communication_style: communication_style.to_string(),
            visual_identity: None,
            capabilities: None,
        }
    }

    #[tokio::test]
    async fn test_chat_builder_with_history() {
        let test_agent = TestAgent::new("response");
        let chat = Chat::new(test_agent.clone()).build();

        let result1 = chat
            .execute(Payload::from_messages(vec![PayloadMessage::user(
                "User", "User", "Hello",
            )]))
            .await
            .unwrap();
        assert_eq!(result1, "response");

        let result2 = chat
            .execute(Payload::from_messages(vec![PayloadMessage::user(
                "User",
                "User",
                "How are you?",
            )]))
            .await
            .unwrap();
        assert_eq!(result2, "response");

        let calls = test_agent.get_calls().await;
        assert_eq!(calls.len(), 2);
        assert!(calls[1].contains("Previous conversation"));
        assert!(calls[1].contains("Hello"));
    }

    #[tokio::test]
    async fn test_chat_builder_without_history() {
        let test_agent = TestAgent::new("response");
        let chat = Chat::new(test_agent.clone()).with_history(false).build();

        let result1 = chat.execute(Payload::text("Hello")).await.unwrap();
        assert_eq!(result1, "response");
        let result2 = chat.execute(Payload::text("How are you?")).await.unwrap();
        assert_eq!(result2, "response");

        let calls = test_agent.get_calls().await;
        assert_eq!(calls.len(), 2);
        assert!(!calls[1].contains("Previous conversation"));
        assert_eq!(calls[1], "How are you?");
    }

    #[tokio::test]
    async fn test_chat_builder_with_persona() {
        let test_agent = TestAgent::new("response");
        let persona = test_persona("TestBot", "Test Assistant", "A helpful test bot", "Direct and clear");
        let chat = Chat::new(test_agent.clone())
            .with_persona(persona)
            .with_history(false)
            .build();

        let result = chat.execute(Payload::text("Hello")).await.unwrap();
        assert_eq!(result, "response");

        let calls = test_agent.get_calls().await;
        assert_eq!(calls.len(), 1);
        assert!(calls[0].contains("Persona Profile"));
        assert!(calls[0].contains("TestBot"));
        assert!(calls[0].contains("Test Assistant"));
    }

    #[tokio::test]
    async fn test_chat_builder_with_persona_and_history() {
        let test_agent = TestAgent::new("response");
        let persona = test_persona("Alice", "Assistant", "Helpful AI", "Friendly");
        let chat = Chat::new(test_agent.clone()).with_persona(persona).build();

        let _ = chat
            .execute(Payload::from_messages(vec![PayloadMessage::user(
                "User", "User", "Hi",
            )]))
            .await
            .unwrap();
        let _ = chat
            .execute(Payload::from_messages(vec![PayloadMessage::user(
                "User", "User", "Bye",
            )]))
            .await
            .unwrap();

        let calls = test_agent.get_calls().await;
        assert_eq!(calls.len(), 2);
        assert!(calls[1].contains("Previous conversation"));
        assert!(calls[1].contains("Persona Profile"));
        assert!(calls[1].contains("Alice"));
    }

    #[tokio::test]
    async fn test_chat_builder_expertise_delegation() {
        let test_agent = TestAgent::new("response");
        let chat = Chat::new(test_agent).build();
        assert_eq!(chat.expertise(), "Test agent for Chat builder");
    }

    #[tokio::test]
    async fn test_chat_builder_expertise_with_persona() {
        let test_agent = TestAgent::new("response");
        let persona = test_persona("Bob", "Expert Coder", "Senior developer", "Technical");
        let chat = Chat::new(test_agent)
            .with_persona(persona)
            .with_history(false)
            .build();
        assert_eq!(chat.expertise(), "Expert Coder");
    }

    #[tokio::test]
    async fn test_chat_builder_with_context_config() {
        use crate::agent::persona::ContextConfig;

        let test_agent = TestAgent::new("response");
        let persona = test_persona("Alice", "Assistant", "Helpful assistant", "Friendly");
        let config = ContextConfig {
            long_conversation_threshold: 100,
            recent_messages_count: 5,
            participants_after_context: true,
            include_trailing_prompt: true,
        };
        let chat = Chat::new(test_agent.clone())
            .with_persona(persona)
            .with_context_config(config)
            .build();

        let result = chat.execute(Payload::text("Test message")).await.unwrap();
        assert_eq!(result, "response");

        // A single execute through the configured chain reaches the inner agent exactly once.
        let calls = test_agent.get_calls().await;
        assert_eq!(calls.len(), 1);
    }
}