Skip to main content

perspt_agent/
agent.rs

1//! Agent Trait and Implementations
2//!
3//! Defines the interface for all agent implementations and provides
4//! LLM-integrated implementations for Architect, Actuator, and Verifier roles.
5
6use crate::types::{AgentContext, AgentMessage, ModelTier, SRBNNode};
7use anyhow::Result;
8use async_trait::async_trait;
9use perspt_core::llm_provider::GenAIProvider;
10use std::fs;
11use std::path::Path;
12use std::sync::Arc;
13
14/// The Agent trait defines the interface for SRBN agents.
15///
16/// Each agent role (Architect, Actuator, Verifier, Speculator) implements
17/// this trait to provide specialized behavior.
18#[async_trait]
19pub trait Agent: Send + Sync {
20    /// Process a task and return a message
21    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage>;
22
23    /// Get the agent's display name
24    fn name(&self) -> &str;
25
26    /// Check if this agent can handle the given node
27    fn can_handle(&self, node: &SRBNNode) -> bool;
28
29    /// Get the model name used by this agent (for logging)
30    fn model(&self) -> &str;
31
32    /// Build the prompt for this agent (for logging)
33    fn build_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String;
34}
35
36/// Architect agent - handles planning and DAG construction
37pub struct ArchitectAgent {
38    model: String,
39    provider: Arc<GenAIProvider>,
40}
41
42impl ArchitectAgent {
43    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
44        Self {
45            model: model.unwrap_or_else(|| ModelTier::Architect.default_model().to_string()),
46            provider,
47        }
48    }
49
50    pub fn build_planning_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
51        let project_context = format!(
52            "Context Files: {:?}\nOutput Targets: {:?}",
53            node.context_files, node.output_targets
54        );
55        crate::prompts::render_architect(
56            crate::prompts::ARCHITECT_EXISTING,
57            &node.goal,
58            &ctx.working_dir,
59            &project_context,
60            "",
61            "",
62            &ctx.active_plugins,
63        )
64    }
65}
66
67#[async_trait]
68impl Agent for ArchitectAgent {
69    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
70        log::info!(
71            "[Architect] Processing node: {} with model {}",
72            node.node_id,
73            self.model
74        );
75
76        let prompt = self.build_planning_prompt(node, ctx);
77
78        let response = self
79            .provider
80            .generate_response_simple(&self.model, &prompt)
81            .await?
82            .text;
83
84        Ok(AgentMessage::new(ModelTier::Architect, response))
85    }
86
87    fn name(&self) -> &str {
88        "Architect"
89    }
90
91    fn can_handle(&self, node: &SRBNNode) -> bool {
92        matches!(node.tier, ModelTier::Architect)
93    }
94
95    fn model(&self) -> &str {
96        &self.model
97    }
98
99    fn build_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
100        self.build_planning_prompt(node, ctx)
101    }
102}
103
104/// Actuator agent - handles code generation
105pub struct ActuatorAgent {
106    model: String,
107    provider: Arc<GenAIProvider>,
108}
109
110impl ActuatorAgent {
111    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
112        Self {
113            model: model.unwrap_or_else(|| ModelTier::Actuator.default_model().to_string()),
114            provider,
115        }
116    }
117
118    pub fn build_coding_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
119        let contract = &node.contract;
120        let allowed_output_paths: Vec<String> = node
121            .output_targets
122            .iter()
123            .map(|path| path.to_string_lossy().to_string())
124            .collect();
125        let workspace_import_hints = Self::workspace_import_hints(&ctx.working_dir);
126
127        // Determine target file from output_targets or generate default
128        let target_file = node
129            .output_targets
130            .first()
131            .map(|p| p.to_string_lossy().to_string())
132            .unwrap_or_else(|| "main.py".to_string());
133
134        // PSP-5: Determine output format based on execution mode and plugin
135        let is_project_mode = ctx.execution_mode == perspt_core::types::ExecutionMode::Project;
136        let has_multiple_outputs = node.output_targets.len() > 1;
137
138        crate::prompts::render_actuator(
139            &node.goal,
140            &contract.interface_signature,
141            &format!("{:?}", contract.invariants),
142            &format!("{:?}", contract.forbidden_patterns),
143            &format!("{:?}", ctx.working_dir),
144            &format!("{:?}", node.context_files),
145            &target_file,
146            &format!("{:?}", allowed_output_paths),
147            &format!("{:?}", workspace_import_hints),
148            is_project_mode || has_multiple_outputs,
149        )
150    }
151
152    fn workspace_import_hints(working_dir: &Path) -> Vec<String> {
153        let mut hints = Vec::new();
154
155        // Rust: detect workspace members OR single-crate name
156        let rust_hints = Self::detect_rust_workspace_crates(working_dir);
157        if !rust_hints.is_empty() {
158            hints.extend(rust_hints);
159        }
160
161        if let Some(package_name) = Self::detect_python_package_name(working_dir) {
162            hints.push(format!(
163                "Python package import root: {}. Tests and entry points must import `{}` and never `src.{}`.",
164                package_name, package_name, package_name
165            ));
166        }
167
168        hints
169    }
170
171    /// Detect Rust crate names for import hints.
172    ///
173    /// Handles both:
174    /// - Single-crate projects: `[package]` with a `name`
175    /// - Workspace projects: `[workspace]` with `members`, enumerating each member's crate name
176    fn detect_rust_workspace_crates(working_dir: &Path) -> Vec<String> {
177        let cargo_toml = match fs::read_to_string(working_dir.join("Cargo.toml")) {
178            Ok(content) => content,
179            Err(_) => return Vec::new(),
180        };
181
182        // Check if this is a workspace manifest
183        let mut in_workspace = false;
184        let mut in_package = false;
185        let mut members: Vec<String> = Vec::new();
186        let mut single_crate_name: Option<String> = None;
187        let mut is_workspace = false;
188
189        for raw_line in cargo_toml.lines() {
190            let line = raw_line.trim();
191            if line.starts_with('[') {
192                in_workspace = line == "[workspace]";
193                in_package = line == "[package]";
194                if in_workspace {
195                    is_workspace = true;
196                }
197                continue;
198            }
199
200            // Parse [package] name for single-crate projects
201            if in_package && line.starts_with("name") {
202                if let Some((_, value)) = line.split_once('=') {
203                    single_crate_name = Some(value.trim().trim_matches('"').to_string());
204                }
205            }
206
207            // Parse [workspace] members
208            if in_workspace && line.starts_with("members") {
209                if let Some((_, value)) = line.split_once('=') {
210                    let raw = value.trim();
211                    // Parse inline array: members = ["crates/foo", "crates/bar"]
212                    if raw.starts_with('[') {
213                        let inner = raw.trim_start_matches('[').trim_end_matches(']');
214                        for item in inner.split(',') {
215                            let member = item.trim().trim_matches('"').trim_matches('\'');
216                            if !member.is_empty() {
217                                members.push(member.to_string());
218                            }
219                        }
220                    }
221                }
222            }
223        }
224
225        if is_workspace && !members.is_empty() {
226            // Enumerate each member crate's name
227            let mut hints = Vec::new();
228            let mut crate_names = Vec::new();
229
230            for member in &members {
231                let member_cargo = working_dir.join(member).join("Cargo.toml");
232                if let Ok(content) = fs::read_to_string(&member_cargo) {
233                    let mut in_pkg = false;
234                    for raw_line in content.lines() {
235                        let line = raw_line.trim();
236                        if line.starts_with('[') {
237                            in_pkg = line == "[package]";
238                            continue;
239                        }
240                        if in_pkg && line.starts_with("name") {
241                            if let Some((_, value)) = line.split_once('=') {
242                                let name = value.trim().trim_matches('"').to_string();
243                                crate_names.push(name);
244                            }
245                            break;
246                        }
247                    }
248                }
249            }
250
251            if !crate_names.is_empty() {
252                hints.push(format!(
253                    "Rust workspace with {} crate(s): {}. \
254                     Cross-crate imports use `use <crate_name>::...;`. \
255                     Add dependencies between workspace crates via `<name>.workspace = true` \
256                     or `<name> = {{ path = \"../other\" }}`.",
257                    crate_names.len(),
258                    crate_names.join(", ")
259                ));
260            }
261
262            hints
263        } else if let Some(name) = single_crate_name {
264            vec![format!(
265                "Rust crate name: {}. Integration tests and external modules must import via `{}`.",
266                name, name
267            )]
268        } else {
269            Vec::new()
270        }
271    }
272
273    fn detect_python_package_name(working_dir: &Path) -> Option<String> {
274        let src_dir = working_dir.join("src");
275        if let Ok(entries) = fs::read_dir(&src_dir) {
276            for entry in entries.flatten() {
277                if entry.file_type().ok()?.is_dir() {
278                    let name = entry.file_name().to_string_lossy().to_string();
279                    if !name.starts_with('.') {
280                        return Some(name);
281                    }
282                }
283            }
284        }
285
286        let pyproject = fs::read_to_string(working_dir.join("pyproject.toml")).ok()?;
287        let mut in_project = false;
288        for raw_line in pyproject.lines() {
289            let line = raw_line.trim();
290            if line.starts_with('[') {
291                in_project = line == "[project]";
292                continue;
293            }
294
295            if in_project && line.starts_with("name") {
296                let (_, value) = line.split_once('=')?;
297                return Some(value.trim().trim_matches('"').replace('-', "_"));
298            }
299        }
300
301        None
302    }
303}
304
305#[async_trait]
306impl Agent for ActuatorAgent {
307    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
308        log::info!(
309            "[Actuator] Processing node: {} with model {}",
310            node.node_id,
311            self.model
312        );
313
314        let prompt = self.build_coding_prompt(node, ctx);
315
316        let response = self
317            .provider
318            .generate_response_simple(&self.model, &prompt)
319            .await?
320            .text;
321
322        Ok(AgentMessage::new(ModelTier::Actuator, response))
323    }
324
325    fn name(&self) -> &str {
326        "Actuator"
327    }
328
329    fn can_handle(&self, node: &SRBNNode) -> bool {
330        matches!(node.tier, ModelTier::Actuator)
331    }
332
333    fn model(&self) -> &str {
334        &self.model
335    }
336
337    fn build_prompt(&self, node: &SRBNNode, ctx: &AgentContext) -> String {
338        self.build_coding_prompt(node, ctx)
339    }
340}
341
342/// Verifier agent - handles stability verification and contract checking
343pub struct VerifierAgent {
344    model: String,
345    provider: Arc<GenAIProvider>,
346}
347
348impl VerifierAgent {
349    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
350        Self {
351            model: model.unwrap_or_else(|| ModelTier::Verifier.default_model().to_string()),
352            provider,
353        }
354    }
355
356    pub fn build_verification_prompt(&self, node: &SRBNNode, implementation: &str) -> String {
357        let contract = &node.contract;
358        crate::prompts::render_verifier(
359            &contract.interface_signature,
360            &format!("{:?}", contract.invariants),
361            &format!("{:?}", contract.forbidden_patterns),
362            &format!("{:?}", contract.weighted_tests),
363            implementation,
364        )
365    }
366}
367
368#[async_trait]
369impl Agent for VerifierAgent {
370    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
371        log::info!(
372            "[Verifier] Processing node: {} with model {}",
373            node.node_id,
374            self.model
375        );
376
377        // In a real implementation, we would get the actual implementation from the context
378        let implementation = ctx
379            .history
380            .last()
381            .map(|m| m.content.as_str())
382            .unwrap_or("No implementation provided");
383
384        let prompt = self.build_verification_prompt(node, implementation);
385
386        let response = self
387            .provider
388            .generate_response_simple(&self.model, &prompt)
389            .await?
390            .text;
391
392        Ok(AgentMessage::new(ModelTier::Verifier, response))
393    }
394
395    fn name(&self) -> &str {
396        "Verifier"
397    }
398
399    fn can_handle(&self, node: &SRBNNode) -> bool {
400        matches!(node.tier, ModelTier::Verifier)
401    }
402
403    fn model(&self) -> &str {
404        &self.model
405    }
406
407    fn build_prompt(&self, node: &SRBNNode, _ctx: &AgentContext) -> String {
408        // Verifier needs implementation context, use a placeholder
409        self.build_verification_prompt(node, "<implementation>")
410    }
411}
412
413/// Speculator agent - handles fast lookahead for exploration
414pub struct SpeculatorAgent {
415    model: String,
416    provider: Arc<GenAIProvider>,
417}
418
419impl SpeculatorAgent {
420    pub fn new(provider: Arc<GenAIProvider>, model: Option<String>) -> Self {
421        Self {
422            model: model.unwrap_or_else(|| ModelTier::Speculator.default_model().to_string()),
423            provider,
424        }
425    }
426}
427
428#[async_trait]
429impl Agent for SpeculatorAgent {
430    async fn process(&self, node: &SRBNNode, ctx: &AgentContext) -> Result<AgentMessage> {
431        log::info!(
432            "[Speculator] Processing node: {} with model {}",
433            node.node_id,
434            self.model
435        );
436
437        let prompt = self.build_prompt(node, ctx);
438
439        let response = self
440            .provider
441            .generate_response_simple(&self.model, &prompt)
442            .await?
443            .text;
444
445        Ok(AgentMessage::new(ModelTier::Speculator, response))
446    }
447
448    fn name(&self) -> &str {
449        "Speculator"
450    }
451
452    fn can_handle(&self, node: &SRBNNode) -> bool {
453        matches!(node.tier, ModelTier::Speculator)
454    }
455
456    fn model(&self) -> &str {
457        &self.model
458    }
459
460    fn build_prompt(&self, node: &SRBNNode, _ctx: &AgentContext) -> String {
461        crate::prompts::SPECULATOR_BASIC.replace("{goal}", &node.goal)
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds the Actuator coding prompt for a node with a single output
    /// target, with the agent context rooted at `working_dir`.
    fn coding_prompt_for(working_dir: &Path, output_target: &str) -> String {
        let provider = Arc::new(GenAIProvider::new().unwrap());
        let agent = ActuatorAgent::new(provider, Some("test-model".into()));

        let mut node = SRBNNode::new("n1".into(), "goal".into(), ModelTier::Actuator);
        node.output_targets.push(output_target.into());

        let ctx = AgentContext {
            working_dir: working_dir.to_path_buf(),
            ..Default::default()
        };
        agent.build_coding_prompt(&node, &ctx)
    }

    #[test]
    fn build_coding_prompt_includes_rust_crate_hint() {
        let dir = tempdir().unwrap();
        fs::write(
            dir.path().join("Cargo.toml"),
            "[package]\nname = \"validator_lib\"\nversion = \"0.1.0\"\n",
        )
        .unwrap();

        let prompt = coding_prompt_for(dir.path(), "tests/integration.rs");
        assert!(
            prompt.contains("Rust crate name: validator_lib"),
            "{prompt}"
        );
    }

    #[test]
    fn build_coding_prompt_includes_python_package_hint() {
        let dir = tempdir().unwrap();
        fs::create_dir_all(dir.path().join("src/psp5_python_verify")).unwrap();
        fs::write(
            dir.path().join("pyproject.toml"),
            "[project]\nname = \"psp5-python-verify\"\nversion = \"0.1.0\"\n",
        )
        .unwrap();

        let prompt = coding_prompt_for(dir.path(), "tests/test_main.py");
        assert!(
            prompt.contains("Python package import root: psp5_python_verify"),
            "{prompt}"
        );
    }
}