use crate::error::{MagiError, ProviderError};
use crate::provider::{CompletionConfig, LlmProvider};
use crate::schema::{AgentName, Mode};
use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;
use crate::prompts;
/// The three modes (code review, design, analysis) that a factory-wide
/// custom prompt is fanned out to by `AgentFactory::with_custom_prompt`.
const ALL_MODES: [Mode; 3] = [
    Mode::CodeReview,
    Mode::Design,
    Mode::Analysis,
];
/// One named agent (Melchior, Balthasar or Caspar) bound to a mode, the
/// system prompt it sends, and the LLM provider that answers for it.
pub struct Agent {
    /// Which of the named agents this is.
    name: AgentName,
    /// The mode this agent was created for.
    mode: Mode,
    /// System prompt passed to the provider on every `execute` call.
    system_prompt: String,
    /// Shared completion backend used by `execute`.
    provider: Arc<dyn LlmProvider>,
}
impl Agent {
    /// Creates an agent that uses the built-in prompt for `name` in `mode`.
    pub fn new(name: AgentName, mode: Mode, provider: Arc<dyn LlmProvider>) -> Self {
        let builtin = match name {
            AgentName::Melchior => prompts::melchior::prompt_for_mode(&mode),
            AgentName::Balthasar => prompts::balthasar::prompt_for_mode(&mode),
            AgentName::Caspar => prompts::caspar::prompt_for_mode(&mode),
        };
        Self::with_custom_prompt(name, mode, provider, builtin.to_string())
    }

    /// Creates an agent whose system prompt is exactly `prompt`.
    pub fn with_custom_prompt(
        name: AgentName,
        mode: Mode,
        provider: Arc<dyn LlmProvider>,
        prompt: String,
    ) -> Self {
        Self {
            name,
            mode,
            system_prompt: prompt,
            provider,
        }
    }

    /// Creates an agent whose system prompt is the full contents of the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from reading `path`, converted into [`MagiError`]
    /// (e.g. when the file does not exist).
    pub fn from_file(
        name: AgentName,
        mode: Mode,
        provider: Arc<dyn LlmProvider>,
        path: &Path,
    ) -> Result<Self, MagiError> {
        let contents = std::fs::read_to_string(path)?;
        Ok(Self::with_custom_prompt(name, mode, provider, contents))
    }

    /// Sends `user_prompt` to this agent's provider, paired with the agent's system prompt.
    ///
    /// # Errors
    ///
    /// Propagates whatever [`ProviderError`] the provider returns.
    pub async fn execute(
        &self,
        user_prompt: &str,
        config: &CompletionConfig,
    ) -> Result<String, ProviderError> {
        let provider = &self.provider;
        provider
            .complete(self.system_prompt(), user_prompt, config)
            .await
    }

    /// The agent's identity.
    pub fn name(&self) -> AgentName {
        self.name
    }

    /// The mode the agent was built for.
    pub fn mode(&self) -> Mode {
        self.mode
    }

    /// The system prompt sent with every request.
    pub fn system_prompt(&self) -> &str {
        &self.system_prompt
    }

    /// Name reported by the underlying provider.
    pub fn provider_name(&self) -> &str {
        self.provider.name()
    }

    /// Model reported by the underlying provider.
    pub fn provider_model(&self) -> &str {
        self.provider.model()
    }

    /// Human-readable agent name, delegated to [`AgentName`].
    pub fn display_name(&self) -> &str {
        self.name.display_name()
    }

    /// The agent's title, delegated to [`AgentName`].
    pub fn title(&self) -> &str {
        self.name.title()
    }
}
/// Builds the three agents for a requested mode, layering per-agent provider
/// overrides and per-(agent, mode) prompt overrides on top of the defaults.
pub struct AgentFactory {
    /// Provider handed to every agent that has no entry in `agent_providers`.
    default_provider: Arc<dyn LlmProvider>,
    /// Provider overrides, one per agent at most.
    agent_providers: BTreeMap<AgentName, Arc<dyn LlmProvider>>,
    /// Prompt overrides keyed by (agent, mode); missing pairs use the built-in prompt.
    custom_prompts: BTreeMap<(AgentName, Mode), String>,
}
impl AgentFactory {
    /// Creates a factory in which every agent uses `default_provider` and its built-in prompt.
    pub fn new(default_provider: Arc<dyn LlmProvider>) -> Self {
        Self {
            default_provider,
            agent_providers: BTreeMap::new(),
            custom_prompts: BTreeMap::new(),
        }
    }

    /// Routes agent `name` to `provider` instead of the default provider.
    ///
    /// Calling this again for the same agent replaces the earlier override.
    pub fn with_provider(mut self, name: AgentName, provider: Arc<dyn LlmProvider>) -> Self {
        self.agent_providers.insert(name, provider);
        self
    }

    /// Uses `prompt` as agent `name`'s system prompt in the code review, design and
    /// analysis modes.
    ///
    /// Replaces any prompt previously registered for those (agent, mode) pairs,
    /// including ones loaded by [`AgentFactory::from_directory`].
    pub fn with_custom_prompt(mut self, name: AgentName, prompt: String) -> Self {
        for mode in ALL_MODES {
            self.custom_prompts.insert((name, mode), prompt.clone());
        }
        self
    }

    /// Loads prompt overrides from files named `{agent}_{mode}.md` inside `dir`.
    ///
    /// The agent part is `melchior`, `balthasar` or `caspar`. The mode part is
    /// `code_review`, `design` or `analysis`, e.g. `caspar_code_review.md`.
    /// Each file found replaces the prompt for that (agent, mode) pair. Missing
    /// files are skipped, and other files in `dir` are ignored.
    ///
    /// # Errors
    ///
    /// Returns an I/O error (as [`MagiError`]) when `dir` cannot be listed, for
    /// example because it does not exist or is not a directory. Also returns one
    /// when an existing prompt file cannot be read.
    pub fn from_directory(mut self, dir: &Path) -> Result<Self, MagiError> {
        // Fail fast when `dir` is missing, unreadable, or not a directory.
        std::fs::read_dir(dir)?;

        // Typed tables instead of string round-tripping: no `unreachable!()` arms to maintain.
        let agents = [
            (AgentName::Melchior, "melchior"),
            (AgentName::Balthasar, "balthasar"),
            (AgentName::Caspar, "caspar"),
        ];
        let modes = [
            (Mode::CodeReview, "code_review"),
            (Mode::Design, "design"),
            (Mode::Analysis, "analysis"),
        ];

        for (agent_name, agent_str) in agents {
            for (mode, mode_str) in modes {
                let path = dir.join(format!("{agent_str}_{mode_str}.md"));
                // Read directly rather than `path.exists()` followed by a read.
                // That avoids a race between the check and the read. It also stops
                // real failures (e.g. permission denied) from being hidden as
                // "file absent", which is what `exists()` does.
                match std::fs::read_to_string(&path) {
                    Ok(content) => {
                        self.custom_prompts.insert((agent_name, mode), content);
                    }
                    Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
                    Err(err) => return Err(err.into()),
                }
            }
        }
        Ok(self)
    }

    /// Creates Melchior, Balthasar and Caspar, in that order, for `mode`.
    ///
    /// Each agent gets its registered provider override, or the default provider.
    /// It also gets its registered prompt for `mode`, or the built-in prompt.
    pub fn create_agents(&self, mode: Mode) -> Vec<Agent> {
        let names = [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar];
        names
            .iter()
            .map(|&name| {
                let provider = self
                    .agent_providers
                    .get(&name)
                    .cloned()
                    .unwrap_or_else(|| Arc::clone(&self.default_provider));
                match self.custom_prompts.get(&(name, mode)) {
                    Some(prompt) => {
                        Agent::with_custom_prompt(name, mode, provider, prompt.clone())
                    }
                    None => Agent::new(name, mode, provider),
                }
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double for [`LlmProvider`]: always returns `response` and counts calls.
    struct MockProvider {
        name: String,
        model: String,
        response: String,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &str, model: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                model: model.to_string(),
                response: response.to_string(),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of `complete` calls received so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(self.response.clone())
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn model(&self) -> &str {
            &self.model
        }
    }

    /// Creates an empty scratch directory under the system temp dir.
    ///
    /// The process id keeps concurrent test runs apart; `label` keeps tests within
    /// one run apart.
    fn scratch_dir(label: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("magi_agent_{label}_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("failed to create scratch dir");
        dir
    }

    #[tokio::test]
    async fn test_each_agent_uses_its_own_provider() {
        let p1 = Arc::new(MockProvider::new("p1", "m1", "r1"));
        let p2 = Arc::new(MockProvider::new("p2", "m2", "r2"));
        let p3 = Arc::new(MockProvider::new("p3", "m3", "r3"));
        let factory = AgentFactory::new(p1.clone() as Arc<dyn LlmProvider>)
            .with_provider(AgentName::Melchior, p1.clone() as Arc<dyn LlmProvider>)
            .with_provider(AgentName::Balthasar, p2.clone() as Arc<dyn LlmProvider>)
            .with_provider(AgentName::Caspar, p3.clone() as Arc<dyn LlmProvider>);
        let agents = factory.create_agents(Mode::CodeReview);
        let config = CompletionConfig::default();
        for agent in &agents {
            let _ = agent.execute("test input", &config).await;
        }
        assert_eq!(p1.calls(), 1, "p1 should receive exactly 1 call");
        assert_eq!(p2.calls(), 1, "p2 should receive exactly 1 call");
        assert_eq!(p3.calls(), 1, "p3 should receive exactly 1 call");
    }

    #[tokio::test]
    async fn test_factory_default_and_override_providers() {
        let default = Arc::new(MockProvider::new("default", "m1", "r1"));
        let caspar_override = Arc::new(MockProvider::new("caspar-special", "m2", "r2"));
        let factory = AgentFactory::new(default.clone() as Arc<dyn LlmProvider>).with_provider(
            AgentName::Caspar,
            caspar_override.clone() as Arc<dyn LlmProvider>,
        );
        let agents = factory.create_agents(Mode::CodeReview);
        let melchior = agents
            .iter()
            .find(|a| a.name() == AgentName::Melchior)
            .unwrap();
        let balthasar = agents
            .iter()
            .find(|a| a.name() == AgentName::Balthasar)
            .unwrap();
        let caspar = agents
            .iter()
            .find(|a| a.name() == AgentName::Caspar)
            .unwrap();
        assert_eq!(melchior.provider_name(), "default");
        assert_eq!(balthasar.provider_name(), "default");
        assert_eq!(caspar.provider_name(), "caspar-special");
    }

    #[test]
    fn test_different_modes_produce_distinct_prompts() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let cr = Agent::new(AgentName::Melchior, Mode::CodeReview, provider.clone());
        let design = Agent::new(AgentName::Melchior, Mode::Design, provider.clone());
        let analysis = Agent::new(AgentName::Melchior, Mode::Analysis, provider.clone());
        assert_ne!(cr.system_prompt(), design.system_prompt());
        assert_ne!(cr.system_prompt(), analysis.system_prompt());
        assert_ne!(design.system_prompt(), analysis.system_prompt());
    }

    #[test]
    fn test_from_directory_returns_io_error_for_nonexistent_path() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        let result = factory.from_directory(Path::new("/nonexistent/path"));
        assert!(matches!(result, Err(MagiError::Io(_))));
    }

    /// A path to a regular file is rejected just like a missing directory.
    #[test]
    fn test_from_directory_returns_io_error_for_file_path() {
        let dir = scratch_dir("from_directory_file_path");
        let file = dir.join("not_a_dir.md");
        std::fs::write(&file, "x").unwrap();
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let result = AgentFactory::new(provider).from_directory(&file);
        let _ = std::fs::remove_dir_all(&dir);
        assert!(matches!(result, Err(MagiError::Io(_))));
    }

    /// Files named `{agent}_{mode}.md` override exactly that pair; everything else
    /// keeps its built-in prompt.
    #[test]
    fn test_from_directory_loads_matching_prompt_files() {
        let dir = scratch_dir("from_directory_loads");
        std::fs::write(dir.join("melchior_design.md"), "Melchior design prompt").unwrap();
        std::fs::write(dir.join("caspar_code_review.md"), "Caspar review prompt").unwrap();
        // Files outside the naming convention must be ignored.
        std::fs::write(dir.join("notes.txt"), "irrelevant").unwrap();

        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let result = AgentFactory::new(provider).from_directory(&dir);
        let _ = std::fs::remove_dir_all(&dir);
        let factory = result.unwrap_or_else(|_| panic!("from_directory should succeed"));

        // Agents come back as Melchior, Balthasar, Caspar.
        let design = factory.create_agents(Mode::Design);
        assert_eq!(design[0].system_prompt(), "Melchior design prompt");
        assert!(!design[1].system_prompt().is_empty());
        assert_ne!(design[2].system_prompt(), "Caspar review prompt");

        let review = factory.create_agents(Mode::CodeReview);
        assert_ne!(review[0].system_prompt(), "Melchior design prompt");
        assert_eq!(review[2].system_prompt(), "Caspar review prompt");
    }

    #[test]
    fn test_agent_new_generates_system_prompt() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let agent = Agent::new(AgentName::Melchior, Mode::CodeReview, provider);
        assert!(!agent.system_prompt().is_empty());
    }

    #[test]
    fn test_agent_with_custom_prompt_uses_provided_prompt() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let agent = Agent::with_custom_prompt(
            AgentName::Melchior,
            Mode::CodeReview,
            provider,
            "Custom prompt".to_string(),
        );
        assert_eq!(agent.system_prompt(), "Custom prompt");
    }

    #[tokio::test]
    async fn test_agent_execute_delegates_to_provider() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "response text"));
        let provider_arc = provider.clone() as Arc<dyn LlmProvider>;
        let agent = Agent::new(AgentName::Melchior, Mode::CodeReview, provider_arc);
        let config = CompletionConfig::default();
        let result = agent.execute("user input", &config).await;
        assert_eq!(result.unwrap(), "response text");
        assert_eq!(provider.calls(), 1);
    }

    #[test]
    fn test_agent_accessors() {
        let provider = Arc::new(MockProvider::new("test-provider", "test-model", "r"));
        let provider_arc = provider.clone() as Arc<dyn LlmProvider>;
        let agent = Agent::new(AgentName::Balthasar, Mode::Design, provider_arc);
        assert_eq!(agent.name(), AgentName::Balthasar);
        assert_eq!(agent.mode(), Mode::Design);
        assert_eq!(agent.provider_name(), "test-provider");
        assert_eq!(agent.provider_model(), "test-model");
        assert_eq!(agent.display_name(), "Balthasar");
        assert_eq!(agent.title(), "Pragmatist");
    }

    #[test]
    fn test_agent_factory_creates_three_agents() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        let agents = factory.create_agents(Mode::CodeReview);
        assert_eq!(agents.len(), 3);
        let names: Vec<AgentName> = agents.iter().map(|a| a.name()).collect();
        assert!(names.contains(&AgentName::Melchior));
        assert!(names.contains(&AgentName::Balthasar));
        assert!(names.contains(&AgentName::Caspar));
    }

    #[test]
    fn test_agent_factory_creates_agents_in_order() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        let agents = factory.create_agents(Mode::CodeReview);
        assert_eq!(agents[0].name(), AgentName::Melchior);
        assert_eq!(agents[1].name(), AgentName::Balthasar);
        assert_eq!(agents[2].name(), AgentName::Caspar);
    }

    #[test]
    fn test_agent_factory_with_provider_overrides_specific_agent() {
        let default = Arc::new(MockProvider::new("default", "m1", "r1")) as Arc<dyn LlmProvider>;
        let override_p =
            Arc::new(MockProvider::new("override", "m2", "r2")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(default).with_provider(AgentName::Caspar, override_p);
        let agents = factory.create_agents(Mode::CodeReview);
        let caspar = agents
            .iter()
            .find(|a| a.name() == AgentName::Caspar)
            .unwrap();
        assert_eq!(caspar.provider_name(), "override");
        let melchior = agents
            .iter()
            .find(|a| a.name() == AgentName::Melchior)
            .unwrap();
        assert_eq!(melchior.provider_name(), "default");
    }

    #[test]
    fn test_agent_factory_with_custom_prompt_overrides_prompt() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider)
            .with_custom_prompt(AgentName::Melchior, "My custom prompt".to_string());
        let agents = factory.create_agents(Mode::CodeReview);
        let melchior = agents
            .iter()
            .find(|a| a.name() == AgentName::Melchior)
            .unwrap();
        assert_eq!(melchior.system_prompt(), "My custom prompt");
        let balthasar = agents
            .iter()
            .find(|a| a.name() == AgentName::Balthasar)
            .unwrap();
        assert_ne!(balthasar.system_prompt(), "My custom prompt");
        assert!(!balthasar.system_prompt().is_empty());
    }

    #[test]
    fn test_agent_factory_creates_three_agents_for_all_modes() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
            let agents = factory.create_agents(mode);
            assert_eq!(agents.len(), 3, "Expected 3 agents for mode {mode}");
        }
    }

    #[test]
    fn test_default_prompts_contain_json_and_english_constraints() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        for name in [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar] {
            for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
                let agent = Agent::new(name, mode, provider.clone());
                let prompt = agent.system_prompt();
                assert!(
                    prompt.contains("JSON"),
                    "{name:?}/{mode:?} prompt should mention JSON"
                );
                assert!(
                    prompt.contains("English"),
                    "{name:?}/{mode:?} prompt should mention English"
                );
            }
        }
    }

    #[test]
    fn test_from_file_returns_io_error_for_nonexistent_path() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let result = Agent::from_file(
            AgentName::Melchior,
            Mode::CodeReview,
            provider,
            Path::new("/nonexistent/prompt.md"),
        );
        assert!(matches!(result, Err(MagiError::Io(_))));
    }
}