1mod mutation;
7mod pob;
8mod trade;
9
10use std::collections::HashMap;
11
12use async_trait::async_trait;
13
14use crate::llm::ToolDefinition;
15use crate::pob_parser::PobParser;
16use crate::trade::TradeClient;
17
18pub struct ToolContext<'a> {
20 pub parser: &'a PobParser,
21 pub build_xml: &'a [u8],
22 pub trade: Option<&'a TradeClient>,
23}
24
/// A modified build produced by a tool, paired with a short description of the change.
///
/// Derives the common value traits so mutations can be logged (`Debug`), duplicated
/// (`Clone`), and compared (`PartialEq`/`Eq`), e.g. in tests or undo history.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BuildMutation {
    /// Serialized build XML resulting from the mutation.
    pub xml: String,
    /// Human-readable label for the change (presumably shown to the user — confirm).
    pub label: String,
}
32
33pub struct ToolResult {
35 pub response: serde_json::Value,
37 pub mutation: Option<BuildMutation>,
39}
40
/// A callable tool described by a [`ToolDefinition`] and dispatched by name through
/// [`ToolRegistry`].
///
/// The `Send + Sync` supertraits allow a `Box<dyn Tool>` (as stored by
/// [`ToolRegistry`]) to be sent and shared across threads.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Returns the definition describing this tool.
    ///
    /// Its `name` is the key [`ToolRegistry`] uses to route calls to this tool.
    fn definition(&self) -> ToolDefinition;

    /// Runs the tool against the build in `ctx`.
    ///
    /// `args` is the raw argument string supplied for the call; implementations
    /// presumably decode it as JSON via `parse_args` — confirm per tool.
    ///
    /// # Errors
    ///
    /// Returns a human-readable error message when the call fails.
    async fn execute(&self, ctx: &ToolContext<'_>, args: &str) -> Result<ToolResult, String>;
}
50
51pub struct ToolRegistry {
53 tools: Vec<Box<dyn Tool>>,
54 index: HashMap<String, usize>,
55}
56
57impl ToolRegistry {
58 pub fn new(has_trade: bool) -> Self {
62 let mut tools: Vec<Box<dyn Tool>> = Vec::new();
63
64 pob::register(&mut tools);
66
67 mutation::register(&mut tools);
69
70 if has_trade {
72 trade::register(&mut tools);
73 }
74
75 let index = tools
76 .iter()
77 .enumerate()
78 .map(|(i, t)| (t.definition().name.clone(), i))
79 .collect();
80
81 Self { tools, index }
82 }
83
84 pub fn definitions(&self) -> Vec<ToolDefinition> {
86 self.tools.iter().map(|t| t.definition()).collect()
87 }
88
89 pub async fn execute(
91 &self,
92 ctx: &ToolContext<'_>,
93 tool_name: &str,
94 args: &str,
95 ) -> Result<ToolResult, String> {
96 let idx = self
97 .index
98 .get(tool_name)
99 .ok_or_else(|| format!("unknown tool: {tool_name}"))?;
100 self.tools[*idx].execute(ctx, args).await
101 }
102}
103
104fn parse_args(args: &str) -> Result<serde_json::Value, String> {
108 serde_json::from_str(args).map_err(|e| format!("invalid arguments: {e}"))
109}
110
111async fn pob_query(
113 ctx: &ToolContext<'_>,
114 query: crate::pob_parser::PobQuery,
115) -> Result<ToolResult, String> {
116 ctx.parser
117 .query(ctx.build_xml, query)
118 .await
119 .map(|v| ToolResult {
120 response: v,
121 mutation: None,
122 })
123 .map_err(|e| e.to_string())
124}