1use std::fs;
53use std::path::PathBuf;
54
55use claude_agent_sdk_rs::{
56 ClaudeAgentOptions, ClaudeClient, ContentBlock, Message, PermissionMode, SystemPrompt,
57 SystemPromptPreset, ToolResultContent, Tools, query,
58};
59use futures::StreamExt;
60use tracing::{debug, info, instrument, trace};
61
62use gba_pm::PromptManager;
63
64use crate::config::{EngineConfig, TaskConfig};
65use crate::error::{EngineError, Result};
66use crate::event::EventHandler;
67use crate::session::{Session, SessionBuilder};
68use crate::task::{Task, TaskKind, TaskResult, TaskStats};
69
/// Runs tasks by loading their configuration, rendering their prompt
/// templates, and executing them through the Claude agent SDK.
///
/// Task configuration is read from `<workdir>/tasks/<kind>/config.yml`, and
/// prompts are rendered from the `<kind>/system` and `<kind>/user` templates
/// held by the prompt manager.
pub struct Engine<'a> {
    /// Root directory containing the `tasks/` tree; also used as the agent's
    /// working directory when the base options do not set one.
    workdir: PathBuf,

    /// Template store used to render each task's system and user prompts.
    prompts: PromptManager<'a>,

    /// Optional engine-wide agent options. One-shot runs inherit only
    /// `model`, `permission_mode`, `max_turns` and `cwd` from them, while
    /// sessions receive a full clone.
    base_options: Option<ClaudeAgentOptions>,
}
114
115impl std::fmt::Debug for Engine<'_> {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.debug_struct("Engine")
118 .field("workdir", &self.workdir)
119 .field("prompts", &self.prompts)
120 .field("base_options", &"<ClaudeAgentOptions>")
121 .finish()
122 }
123}
124
125impl<'a> Engine<'a> {
126 pub fn new(config: EngineConfig<'a>) -> Result<Self> {
136 debug!(workdir = %config.workdir.display(), "creating engine");
137
138 Ok(Self {
139 workdir: config.workdir,
140 prompts: config.prompts,
141 base_options: config.agent_options,
142 })
143 }
144
145 #[instrument(skip(self, task), fields(task_kind = %task.kind))]
164 pub async fn run(&self, task: Task) -> Result<TaskResult> {
165 info!(task_kind = %task.kind, "running task");
166
167 let task_config = self.load_task_config(&task.kind)?;
169 debug!(?task_config, "loaded task configuration");
170
171 let system_prompt = self.render_system_prompt(&task, &task_config)?;
173 let user_prompt = self.render_user_prompt(&task)?;
174 debug!("rendered prompts");
175
176 let options = self.build_agent_options(&task_config, system_prompt);
178
179 info!("executing Claude agent query");
181 let messages = query(&user_prompt, Some(options)).await?;
182
183 let result = self.process_messages(messages)?;
185 info!(
186 success = result.success,
187 turns = result.stats.turns,
188 "task completed"
189 );
190
191 Ok(result)
192 }
193
194 #[instrument(skip(self, task, handler), fields(task_kind = %task.kind))]
237 pub async fn run_stream(
238 &self,
239 task: Task,
240 handler: &mut impl EventHandler,
241 ) -> Result<TaskResult> {
242 info!(task_kind = %task.kind, "running task with streaming");
243
244 let task_config = self.load_task_config(&task.kind)?;
246 debug!(?task_config, "loaded task configuration");
247
248 let system_prompt = self.render_system_prompt(&task, &task_config)?;
250 let user_prompt = self.render_user_prompt(&task)?;
251 debug!("rendered prompts");
252
253 let options = self.build_agent_options(&task_config, system_prompt);
255
256 let mut client = ClaudeClient::new(options);
258 client.connect().await?;
259
260 info!("sending query to Claude");
262 client.query(&user_prompt).await?;
263
264 let mut output = String::new();
266 let mut stats = TaskStats::default();
267 let mut success = true;
268
269 let mut stream = client.receive_response();
270 while let Some(result) = stream.next().await {
271 match result {
272 Ok(msg) => {
273 self.process_streaming_message(
274 &msg,
275 &mut output,
276 &mut stats,
277 &mut success,
278 handler,
279 )?;
280 }
281 Err(e) => {
282 let error_msg = e.to_string();
283 handler.on_error(&error_msg);
284 drop(stream);
285 client.disconnect().await?;
286 return Err(e.into());
287 }
288 }
289 }
290 drop(stream);
291
292 handler.on_complete();
293
294 client.disconnect().await?;
296
297 let result = TaskResult {
298 success,
299 output,
300 artifacts: Vec::new(),
301 stats,
302 };
303
304 info!(
305 success = result.success,
306 turns = result.stats.turns,
307 "streaming task completed"
308 );
309
310 Ok(result)
311 }
312
313 pub fn session(&self, session_id: Option<String>) -> Result<Session> {
356 debug!(session_id = ?session_id, "creating session from engine");
357
358 let mut builder = SessionBuilder::new(self.workdir.clone());
359
360 if let Some(ref base) = self.base_options {
361 builder = builder.with_base_options(base.clone());
362 }
363
364 if let Some(id) = session_id {
365 builder = builder.with_session_id(id);
366 }
367
368 builder.build()
369 }
370
371 pub fn session_with_task(
387 &self,
388 task_kind: &TaskKind,
389 context: &serde_json::Value,
390 session_id: Option<String>,
391 ) -> Result<Session> {
392 debug!(task_kind = %task_kind, session_id = ?session_id, "creating session with task config");
393
394 let task_config = self.load_task_config(task_kind)?;
395
396 let temp_task = Task::new(task_kind.clone(), context.clone());
398 let system_prompt = self.render_system_prompt(&temp_task, &task_config)?;
399
400 let mut builder = SessionBuilder::new(self.workdir.clone()).with_task_config(task_config);
401
402 if let Some(ref base) = self.base_options {
403 builder = builder.with_base_options(base.clone());
404 }
405
406 if let Some(prompt) = system_prompt {
407 builder = builder.with_system_prompt(prompt);
408 }
409
410 if let Some(id) = session_id {
411 builder = builder.with_session_id(id);
412 }
413
414 builder.build()
415 }
416
417 fn load_task_config(&self, kind: &TaskKind) -> Result<TaskConfig> {
419 let config_path = self
420 .workdir
421 .join("tasks")
422 .join(kind.dir_name())
423 .join("config.yml");
424
425 if !config_path.exists() {
426 return Err(EngineError::TaskConfigNotFound(kind.to_string()));
427 }
428
429 let content =
430 fs::read_to_string(&config_path).map_err(|e| EngineError::io_error(&config_path, e))?;
431
432 let config: TaskConfig =
433 serde_yaml::from_str(&content).map_err(|e| EngineError::yaml_error(&config_path, e))?;
434
435 Ok(config)
436 }
437
438 fn render_system_prompt(
440 &self,
441 task: &Task,
442 config: &TaskConfig,
443 ) -> Result<Option<SystemPrompt>> {
444 if let Some(ref override_prompt) = task.system_prompt {
446 return Ok(Some(SystemPrompt::Text(override_prompt.clone())));
447 }
448
449 let template_name = format!("{}/system", task.kind.dir_name());
451 let rendered = match self.prompts.render(&template_name, &task.context) {
452 Ok(content) => content,
453 Err(gba_pm::PromptError::TemplateNotFound(_)) => {
454 if config.preset {
456 return Ok(Some(SystemPrompt::Preset(SystemPromptPreset::new(
457 "claude_code",
458 ))));
459 }
460 return Ok(None);
461 }
462 Err(e) => return Err(e.into()),
463 };
464
465 if config.preset {
467 Ok(Some(SystemPrompt::Preset(SystemPromptPreset::with_append(
468 "claude_code",
469 rendered,
470 ))))
471 } else {
472 Ok(Some(SystemPrompt::Text(rendered)))
473 }
474 }
475
476 fn render_user_prompt(&self, task: &Task) -> Result<String> {
478 let template_name = format!("{}/user", task.kind.dir_name());
479 let rendered = self.prompts.render(&template_name, &task.context)?;
480 Ok(rendered)
481 }
482
483 fn build_agent_options(
485 &self,
486 config: &TaskConfig,
487 system_prompt: Option<SystemPrompt>,
488 ) -> ClaudeAgentOptions {
489 let mut options = ClaudeAgentOptions::default();
491
492 if let Some(ref base) = self.base_options {
494 if base.model.is_some() {
495 options.model = base.model.clone();
496 }
497 if base.permission_mode.is_some() {
498 options.permission_mode = base.permission_mode;
499 }
500 if base.max_turns.is_some() {
501 options.max_turns = base.max_turns;
502 }
503 if base.cwd.is_some() {
504 options.cwd = base.cwd.clone();
505 }
506 }
507
508 if options.cwd.is_none() {
510 options.cwd = Some(self.workdir.clone());
511 }
512
513 if system_prompt.is_some() {
515 options.system_prompt = system_prompt;
516 }
517
518 if !config.tools.is_empty() {
520 options.tools = Some(Tools::from(config.tools.clone()));
521 }
522
523 if !config.disallowed_tools.is_empty() {
524 options.disallowed_tools = config.disallowed_tools.clone();
525 }
526
527 if let Some(mode) = config.permission_mode {
529 options.permission_mode = Some(mode.into());
530 }
531
532 if options.permission_mode.is_none() {
534 options.permission_mode = Some(PermissionMode::BypassPermissions);
535 }
536
537 options.skip_version_check = true;
539
540 options
541 }
542
543 fn process_messages(&self, messages: Vec<Message>) -> Result<TaskResult> {
545 let mut output = String::new();
546 let mut stats = TaskStats::default();
547 let mut success = true;
548
549 for message in messages {
550 match message {
551 Message::Assistant(msg) => {
552 for block in &msg.message.content {
553 if let ContentBlock::Text(text) = block {
554 if !output.is_empty() {
555 output.push('\n');
556 }
557 output.push_str(&text.text);
558 }
559 }
560 }
561 Message::Result(result) => {
562 stats.turns = result.num_turns;
563 stats.cost_usd = result.total_cost_usd.unwrap_or(0.0);
564
565 if let Some(usage) = result.usage {
567 if let Some(input) = usage.get("input_tokens").and_then(|v| v.as_u64()) {
568 stats.input_tokens = input;
569 }
570 if let Some(output_tokens) =
571 usage.get("output_tokens").and_then(|v| v.as_u64())
572 {
573 stats.output_tokens = output_tokens;
574 }
575 }
576
577 success = !result.is_error;
578 }
579 _ => {}
580 }
581 }
582
583 Ok(TaskResult {
584 success,
585 output,
586 artifacts: Vec::new(), stats,
588 })
589 }
590
591 fn process_streaming_message(
593 &self,
594 msg: &Message,
595 output: &mut String,
596 stats: &mut TaskStats,
597 success: &mut bool,
598 handler: &mut impl EventHandler,
599 ) -> Result<()> {
600 match msg {
601 Message::Assistant(assistant_msg) => {
602 for block in &assistant_msg.message.content {
603 match block {
604 ContentBlock::Text(text) => {
605 output.push_str(&text.text);
606 handler.on_text(&text.text);
607 }
608 ContentBlock::ToolUse(tool_use) => {
609 handler.on_tool_use(&tool_use.name, &tool_use.input);
610 }
611 _ => {}
612 }
613 }
614 }
615 Message::User(user_msg) => {
616 if let Some(ref content) = user_msg.content {
618 for block in content {
619 if let ContentBlock::ToolResult(tool_result) = block {
620 let result_str = match &tool_result.content {
621 Some(ToolResultContent::Text(s)) => s.as_str(),
622 Some(ToolResultContent::Blocks(_)) => "[structured content]",
623 None => "",
624 };
625 handler.on_tool_result(result_str);
626 }
627 }
628 }
629 }
630 Message::Result(result_msg) => {
631 stats.turns = result_msg.num_turns;
632 stats.cost_usd = result_msg.total_cost_usd.unwrap_or(0.0);
633
634 if let Some(ref usage) = result_msg.usage {
635 if let Some(input) = usage.get("input_tokens").and_then(|v| v.as_u64()) {
636 stats.input_tokens = input;
637 }
638 if let Some(output_tokens) = usage.get("output_tokens").and_then(|v| v.as_u64())
639 {
640 stats.output_tokens = output_tokens;
641 }
642 }
643
644 *success = !result_msg.is_error;
645
646 if result_msg.is_error {
647 handler.on_error("Claude reported an error");
648 }
649
650 trace!(
651 turns = result_msg.num_turns,
652 cost = result_msg.total_cost_usd,
653 "result message processed"
654 );
655 }
656 Message::System(_) | Message::StreamEvent(_) | Message::ControlCancelRequest(_) => {
657 }
659 }
660
661 Ok(())
662 }
663}
664
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::path::Path;
    use tempfile::TempDir;

    /// Writes `<workdir>/tasks/<dir_name>/` containing the given
    /// `config.yml`, `system.j2` and `user.j2` contents.
    fn write_task(workdir: &Path, dir_name: &str, config: &str, system: &str, user: &str) {
        let task_dir = workdir.join("tasks").join(dir_name);
        fs::create_dir_all(&task_dir).unwrap();

        fs::write(task_dir.join("config.yml"), config).unwrap();
        fs::write(task_dir.join("system.j2"), system).unwrap();
        fs::write(task_dir.join("user.j2"), user).unwrap();
    }

    /// Creates the `init` task fixture (preset enabled, no tool restrictions)
    /// and returns the temp dir path to use as the engine workdir.
    fn create_test_task_dir(temp_dir: &TempDir) -> PathBuf {
        write_task(
            temp_dir.path(),
            "init",
            r#"
preset: true
tools: []
disallowedTools: []
"#,
            "You are GBA. Working directory: {{ repo_path }}",
            "Initialize the repository.",
        );

        temp_dir.path().to_path_buf()
    }

    /// Builds an engine whose prompt templates are loaded from `<workdir>/tasks`.
    fn engine_for(workdir: &Path) -> Engine<'static> {
        let mut prompts = PromptManager::new();
        prompts.load_dir(workdir.join("tasks")).unwrap();

        let config = EngineConfig::builder()
            .workdir(workdir)
            .prompts(prompts)
            .build();

        Engine::new(config).unwrap()
    }

    #[test]
    fn test_should_load_task_config() {
        let temp_dir = TempDir::new().unwrap();
        let workdir = create_test_task_dir(&temp_dir);
        let engine = engine_for(&workdir);

        let task_config = engine.load_task_config(&TaskKind::Init).unwrap();

        assert!(task_config.preset);
        assert!(task_config.tools.is_empty());
        assert!(task_config.disallowed_tools.is_empty());
    }

    #[test]
    fn test_should_render_user_prompt() {
        let temp_dir = TempDir::new().unwrap();
        let workdir = create_test_task_dir(&temp_dir);
        let engine = engine_for(&workdir);
        let task = Task::new(TaskKind::Init, json!({"repo_path": "/test"}));

        let user_prompt = engine.render_user_prompt(&task).unwrap();

        assert_eq!(user_prompt, "Initialize the repository.");
    }

    #[test]
    fn test_should_render_system_prompt_with_preset() {
        let temp_dir = TempDir::new().unwrap();
        let workdir = create_test_task_dir(&temp_dir);
        let engine = engine_for(&workdir);
        let task = Task::new(TaskKind::Init, json!({"repo_path": "/test"}));
        let task_config = engine.load_task_config(&task.kind).unwrap();

        let system_prompt = engine.render_system_prompt(&task, &task_config).unwrap();

        match system_prompt {
            Some(SystemPrompt::Preset(preset)) => {
                assert_eq!(preset.preset, "claude_code");
                assert!(preset.append.is_some());
                assert!(preset.append.unwrap().contains("/test"));
            }
            _ => panic!("Expected preset system prompt"),
        }
    }

    #[test]
    fn test_should_render_text_system_prompt_without_preset() {
        let temp_dir = TempDir::new().unwrap();
        write_task(
            temp_dir.path(),
            "review",
            r#"
preset: false
tools: []
disallowedTools: []
"#,
            "Review mode.",
            "Review the code.",
        );
        let engine = engine_for(temp_dir.path());
        let task = Task::new(TaskKind::Review, json!({}));
        let task_config = engine.load_task_config(&task.kind).unwrap();

        let system_prompt = engine.render_system_prompt(&task, &task_config).unwrap();

        match system_prompt {
            Some(SystemPrompt::Text(text)) => assert_eq!(text, "Review mode."),
            _ => panic!("Expected text system prompt"),
        }
    }

    #[test]
    fn test_should_use_custom_system_prompt_override() {
        let temp_dir = TempDir::new().unwrap();
        let workdir = create_test_task_dir(&temp_dir);
        let engine = engine_for(&workdir);
        let task = Task::new(TaskKind::Init, json!({})).with_system_prompt("Custom override");
        let task_config = engine.load_task_config(&task.kind).unwrap();

        let system_prompt = engine.render_system_prompt(&task, &task_config).unwrap();

        match system_prompt {
            Some(SystemPrompt::Text(text)) => {
                assert_eq!(text, "Custom override");
            }
            _ => panic!("Expected text system prompt"),
        }
    }

    #[test]
    fn test_should_return_error_for_missing_task_config() {
        let temp_dir = TempDir::new().unwrap();
        fs::create_dir_all(temp_dir.path().join("tasks")).unwrap();

        let prompts = PromptManager::new();
        let config = EngineConfig::builder()
            .workdir(temp_dir.path())
            .prompts(prompts)
            .build();

        let engine = Engine::new(config).unwrap();
        let result = engine.load_task_config(&TaskKind::Custom("nonexistent".to_string()));

        assert!(matches!(result, Err(EngineError::TaskConfigNotFound(_))));
    }

    #[test]
    fn test_should_build_agent_options_with_disallowed_tools() {
        let temp_dir = TempDir::new().unwrap();
        write_task(
            temp_dir.path(),
            "review",
            r#"
preset: true
tools: []
disallowedTools:
  - Write
  - Edit
"#,
            "Review mode.",
            "Review the code.",
        );
        let engine = engine_for(temp_dir.path());
        let task_config = engine.load_task_config(&TaskKind::Review).unwrap();

        let options = engine.build_agent_options(&task_config, None);

        assert_eq!(options.disallowed_tools, vec!["Write", "Edit"]);
    }

    #[test]
    fn test_should_default_cwd_and_permission_mode_without_base_options() {
        let temp_dir = TempDir::new().unwrap();
        let workdir = create_test_task_dir(&temp_dir);
        let engine = engine_for(&workdir);
        let task_config = engine.load_task_config(&TaskKind::Init).unwrap();

        let options = engine.build_agent_options(&task_config, None);

        assert_eq!(options.cwd.as_ref(), Some(&engine.workdir));
        assert!(matches!(
            options.permission_mode,
            Some(PermissionMode::BypassPermissions)
        ));
        assert!(options.skip_version_check);
    }
}