1use agent_block_types::obs;
4use mlua::prelude::*;
5
6pub fn register(lua: &Lua) -> LuaResult<()> {
11 let script_name: String = lua
12 .globals()
13 .get::<Option<String>>("_SCRIPT_NAME")?
14 .unwrap_or_else(|| "unknown".to_string());
15
16 let registry = lua.create_table()?;
18 lua.globals().set("_TOOL_REGISTRY", registry)?;
19
20 let tool_tbl = lua.create_table()?;
21
22 let script_name_register = script_name.clone();
25 tool_tbl.set(
26 "register",
27 lua.create_function(
28 move |lua,
29 (name, schema, handler, meta): (
30 String,
31 LuaValue,
32 LuaFunction,
33 Option<LuaTable>,
34 )| {
35 let registry: LuaTable = lua.globals().get("_TOOL_REGISTRY")?;
36 let entry = lua.create_table()?;
37 entry.set("name", name.clone())?;
38 entry.set("schema", schema)?;
39 entry.set("handler", handler)?;
40 if let Some(ref m) = meta {
41 if let Ok(group) = m.get::<String>("group") {
42 entry.set("group", group)?;
43 }
44 }
45 registry.set(name.clone(), entry)?;
46 tracing::info!(
47 target: "lua",
48 script = %script_name_register,
49 "{}",
50 obs::obs_line(
51 "tool",
52 "tool_register",
53 &obs::obs_context(None),
54 &[("tool", name.as_str())],
55 ),
56 );
57 Ok(())
58 },
59 )?,
60 )?;
61
62 let script_name_call = script_name.clone();
68 tool_tbl.set(
69 "call",
70 lua.create_async_function(move |lua, (name, input): (String, LuaValue)| {
71 let script_name = script_name_call.clone();
72 async move {
73 tracing::info!(
74 target: "lua",
75 script = %script_name,
76 "{}",
77 obs::obs_line(
78 "tool",
79 "tool_call",
80 &obs::obs_context(None),
81 &[("tool", name.as_str())],
82 ),
83 );
84 let registry: LuaTable = lua.globals().get("_TOOL_REGISTRY")?;
85 let entry: Option<LuaTable> = registry.get(name.clone())?;
86 match entry {
87 None => {
88 tracing::warn!(
89 target: "lua",
90 script = %script_name,
91 "{}",
92 obs::obs_line(
93 "tool",
94 "tool_result",
95 &obs::obs_context(None),
96 &[("tool", name.as_str()), ("ok", "false")],
97 ),
98 );
99 Err(LuaError::external(format!("tool not found: {name}")))
100 }
101 Some(e) => {
102 let handler: LuaFunction = e.get("handler")?;
103 match handler.call_async::<LuaValue>(input).await {
104 Ok(v) => {
105 tracing::info!(
106 target: "lua",
107 script = %script_name,
108 "{}",
109 obs::obs_line(
110 "tool",
111 "tool_result",
112 &obs::obs_context(None),
113 &[("tool", name.as_str()), ("ok", "true")],
114 ),
115 );
116 Ok(v)
117 }
118 Err(e) => {
119 tracing::warn!(
120 target: "lua",
121 script = %script_name,
122 "{}",
123 obs::obs_line(
124 "tool",
125 "tool_result",
126 &obs::obs_context(None),
127 &[("tool", name.as_str()), ("ok", "false")],
128 ),
129 );
130 Err(e)
131 }
132 }
133 }
134 }
135 }
136 })?,
137 )?;
138
139 tool_tbl.set(
141 "list",
142 lua.create_function(|lua, ()| {
143 let registry: LuaTable = lua.globals().get("_TOOL_REGISTRY")?;
144 let names = lua.create_table()?;
145 for (idx, pair) in (1..).zip(registry.pairs::<String, LuaTable>()) {
146 let (name, _) = pair?;
147 names.set(idx, name)?;
148 }
149 Ok(names)
150 })?,
151 )?;
152
153 tool_tbl.set(
156 "schema",
157 lua.create_function(|lua, ()| {
158 let registry: LuaTable = lua.globals().get("_TOOL_REGISTRY")?;
159 let arr = lua.create_table()?;
160 for (idx, pair) in (1..).zip(registry.pairs::<String, LuaTable>()) {
161 let (_, entry) = pair?;
162 let name: String = entry.get("name")?;
163 let schema: LuaTable = entry.get("schema")?;
164 let description: String = schema.get("description")?;
165 let input_schema: LuaValue = schema.get("input_schema")?;
166
167 let tool_def = lua.create_table()?;
168 tool_def.set("name", name)?;
169 tool_def.set("description", description)?;
170 tool_def.set("input_schema", input_schema)?;
171 if let Ok(group) = entry.get::<String>("group") {
172 tool_def.set("group", group)?;
173 }
174 arr.set(idx, tool_def)?;
175 }
176 Ok(arr)
177 })?,
178 )?;
179
180 lua.globals().set("tool", tool_tbl)?;
181 Ok(())
182}