-- Module table; populated below and returned at the end of the file.
local M = {}

-- Stack of per-run LLM connection settings. M.run pushes an entry on entry
-- and removes it before returning.
local _AGENT_LLM_CTX = {}

--- Return the most recently pushed LLM context, or nil when the stack is empty.
-- @treturn table|nil
function M._llm_ctx_top()
  local depth = #_AGENT_LLM_CTX
  return _AGENT_LLM_CTX[depth]
end
-- Lower-cased spellings accepted as "enabled" for boolean env flags.
local TRUTHY_ENV_VALUES = { ["1"] = true, ["true"] = true, ["yes"] = true, ["on"] = true }

--- Interpret an environment variable as a boolean flag.
-- "1", "true", "yes" and "on" (any case) are true; any other value,
-- or an unset variable, is false.
-- @tparam string name environment variable name
-- @treturn boolean
local function env_true(name)
  local raw = std.env.get(name)
  if raw == nil or raw == false then
    return false
  end
  return TRUTHY_ENV_VALUES[string.lower(tostring(raw))] == true
end
-- Recognised AGENT_BLOCK_LLM_DUMP spellings and the mode each maps to.
local DUMP_MODE_ALIASES = { off = "off", none = "off", meta = "meta", full = "full" }

--- Normalise a raw dump-mode setting.
-- @param v raw value (usually an env string)
-- @return nil when v is unset or empty; otherwise "off", "meta" or "full".
--   Unrecognised values fall back to "off".
local function normalize_dump_mode(v)
  if v == nil or v == false or v == "" then
    return nil
  end
  return DUMP_MODE_ALIASES[string.lower(tostring(v))] or "off"
end
--- Decide how much LLM traffic to dump to the log for the current call.
-- An explicit AGENT_BLOCK_LLM_DUMP wins. Otherwise RUST_LOG containing
-- "trace" or "debug" enables "meta", and anything else means "off".
-- "full" is downgraded to "meta" when AGENT_BLOCK_ENV is prod/production,
-- unless AGENT_BLOCK_LLM_DUMP_ALLOW_PROD is truthy.
-- @treturn string "off" | "meta" | "full"
local function resolve_dump_mode()
  local mode = normalize_dump_mode(std.env.get("AGENT_BLOCK_LLM_DUMP"))
  if mode == nil then
    local rust_log = string.lower(std.env.get_or("RUST_LOG", ""))
    local verbose = rust_log:find("trace", 1, true) ~= nil
      or rust_log:find("debug", 1, true) ~= nil
    mode = verbose and "meta" or "off"
  end

  if mode ~= "full" then
    return mode
  end

  local env_name = string.lower(std.env.get_or("AGENT_BLOCK_ENV", ""))
  local in_prod = (env_name == "prod") or (env_name == "production")
  if in_prod and not env_true("AGENT_BLOCK_LLM_DUMP_ALLOW_PROD") then
    log.warn("agent: AGENT_BLOCK_LLM_DUMP=full blocked in production env; downgraded to meta")
    return "meta"
  end
  return "full"
end
-- Lower-cased header names whose values must never be written to dump logs.
-- This covers the credentials this module sends (x-api-key, authorization),
-- the Azure-style "api-key", proxy credentials, and session cookies that may
-- appear in response headers.
local SENSITIVE_DUMP_HEADERS = {
  ["x-api-key"] = true,
  ["authorization"] = true,
  ["proxy-authorization"] = true,
  ["api-key"] = true,
  ["cookie"] = true,
  ["set-cookie"] = true,
}

--- Return a shallow copy of `headers` with sensitive values redacted.
-- Header names are matched case-insensitively; the original keys are kept.
-- @tparam[opt] table headers header name -> value map (nil is treated as empty)
-- @treturn table new map safe to log
local function sanitize_headers_for_dump(headers)
  local out = {}
  for k, v in pairs(headers or {}) do
    if SENSITIVE_DUMP_HEADERS[string.lower(tostring(k))] then
      out[k] = "***REDACTED***"
    else
      out[k] = v
    end
  end
  return out
end
--- Emit one dump line at info level unless dumping is switched off.
-- @tparam string mode dump mode ("off" suppresses output)
-- @tparam string msg preformatted key=value payload
local function llm_dump(mode, msg)
  if mode == "off" then
    return
  end
  log.info("agent.llm_dump " .. msg)
end
-- Value of the leading "prefix" field on every LLM dump line.
local LLM_DUMP_PREFIX = "ab.obs"

--- Render a value for a space-separated key=value log line.
-- nil becomes "nil", booleans and numbers use tostring, and an empty string
-- becomes "". Strings containing whitespace or "=" are JSON-encoded (quoted)
-- so they stay one token; all other strings are emitted as-is.
-- @param v any value
-- @treturn string
local function kv_escape(v)
  local kind = type(v)
  if kind == "nil" then
    return "nil"
  elseif kind == "boolean" or kind == "number" then
    return tostring(v)
  end

  local s = tostring(v)
  if s == "" then
    return '""'
  elseif s:find("[%s=]") ~= nil then
    return std.json.encode(s)
  end
  return s
end
--- Join `{key, value}` pairs into a single "k1=v1 k2=v2" string.
-- Values are escaped with kv_escape; keys are passed through tostring.
-- @tparam table parts sequence of two-element {key, value} tables
-- @treturn string
local function format_kv(parts)
  local rendered = {}
  for _, pair in ipairs(parts) do
    local key_text = tostring(pair[1])
    rendered[#rendered + 1] = key_text .. "=" .. kv_escape(pair[2])
  end
  return table.concat(rendered, " ")
end
--- Emit a structured LLM observability event as one key=value log line.
-- Every line starts with prefix/event/component fields, followed by `fields`
-- in order.
-- @tparam string mode dump mode; "off" emits nothing
-- @tparam string event_name event identifier (e.g. "request", "summary")
-- @tparam[opt] table fields sequence of {key, value} pairs
local function llm_dump_event(mode, event_name, fields)
  if mode == "off" then
    return
  end
  -- Named `kv` rather than `pairs` to avoid shadowing the builtin iterator.
  local kv = {
    { "prefix", LLM_DUMP_PREFIX },
    { "event", event_name },
    { "component", "llm" },
  }
  for _, f in ipairs(fields or {}) do
    kv[#kv + 1] = f
  end
  llm_dump(mode, format_kv(kv))
end
--- Collect correlation identifiers for log and dump lines.
-- Each field prefers opts.log_meta, then an AGENT_BLOCK_* env var; agent_id
-- finally falls back to std.env.agent_id(). The legacy task_id /
-- AGENT_BLOCK_TASK_ID is used as trace_id only when no trace_id exists,
-- and a deprecation warning is logged when it is.
-- @tparam[opt] table opts run options (reads opts.log_meta)
-- @treturn table {trace_id, agent_id, agent_name, run_id}
local function build_log_meta(opts)
  local meta = (opts and opts.log_meta) or {}
  local env_get = std.env.get

  local trace_id = meta.trace_id or env_get("AGENT_BLOCK_TRACE_ID")
  if not trace_id then
    local legacy_id = meta.task_id or env_get("AGENT_BLOCK_TASK_ID")
    if legacy_id then
      log.warn("agent: log_meta.task_id / AGENT_BLOCK_TASK_ID is deprecated; use trace_id / AGENT_BLOCK_TRACE_ID")
    end
    trace_id = legacy_id
  end

  local agent_id = meta.agent_id or env_get("AGENT_BLOCK_AGENT_ID") or std.env.agent_id()
  local agent_name = meta.agent_name or env_get("AGENT_BLOCK_AGENT_NAME")
  local run_id = meta.run_id or env_get("AGENT_BLOCK_RUN_ID")
  return {
    trace_id = trace_id,
    agent_id = agent_id,
    agent_name = agent_name,
    run_id = run_id,
  }
end
--- Count the `tool_use` blocks in a response content list.
-- @tparam[opt] table content sequence of content blocks (nil means empty)
-- @treturn number
local function count_tool_use_blocks(content)
  local total = 0
  for _, blk in ipairs(content or {}) do
    total = total + ((blk.type == "tool_use") and 1 or 0)
  end
  return total
end
--- Total byte length of the text in all `text` blocks of a content list.
-- Text blocks whose text is nil/false are ignored.
-- @tparam[opt] table content sequence of content blocks
-- @treturn number
local function count_text_chars(content)
  local chars = 0
  for _, blk in ipairs(content or {}) do
    local has_text = blk.type == "text" and blk.text
    if has_text then
      chars = chars + string.len(tostring(blk.text))
    end
  end
  return chars
end
-- Default `context_management` payload that M.run sends to the Anthropic API
-- when the caller has not set context_management = false and has not supplied
-- context_management_config. llm_call_anthropic places it verbatim in the
-- request body and adds the context-management anthropic-beta header.
-- NOTE(review): per the field names, this presumably clears older tool uses
-- once input reaches ~80k tokens, keeps the 3 most recent, and clears at least
-- ~10k tokens each time. Confirm against Anthropic's context-management docs.
-- Shared table: it must not be mutated.
local DEFAULT_CONTEXT_MANAGEMENT = {
edits = {
{
type = "clear_tool_uses_20250919",
trigger = { type = "input_tokens", value = 80000 },
keep = { type = "tool_uses", value = 3 },
clear_at_least = { type = "input_tokens", value = 10000 },
},
},
}
--- Send one request to the Anthropic Messages API and decode the reply.
-- Builds the body, adding prompt-cache markers unless opts.cache_control is
-- false and context management when opts.context_management is set. Dump
-- events are emitted according to resolve_dump_mode().
-- @tparam table messages Anthropic-format conversation messages
-- @tparam table opts model, max_tokens, timeout, system, tools, cache_control,
--   context_management
-- @tparam[opt] table trace correlation fields copied into dump events
-- @return decoded response table on success, or nil plus an error string
local function llm_call_anthropic(messages, opts, trace)
  local api_key = std.env.get("ANTHROPIC_API_KEY")
  if not api_key then
    return nil, "ANTHROPIC_API_KEY not set"
  end
  local model = opts.model or std.env.get_or("ANTHROPIC_MODEL", "claude-haiku-4-5-20251001")
  local timeout = opts.timeout or 120
  -- Prompt caching is enabled unless the caller explicitly passes false.
  local cache_on = opts.cache_control ~= false
  local body = {
    model = model,
    max_tokens = opts.max_tokens or 4096,
    messages = messages,
  }
  if opts.system and opts.system ~= "" then
    if cache_on then
      -- The system prompt must be in block form to carry a cache_control marker.
      body.system = {
        {
          type = "text",
          text = opts.system,
          cache_control = { type = "ephemeral" },
        },
      }
    else
      body.system = opts.system
    end
  end
  if opts.tools and #opts.tools > 0 then
    if cache_on then
      -- Put the cache breakpoint on the last tool only. Shallow-copy that tool
      -- so the caller's definition is not mutated.
      local tools = {}
      for i = 1, #opts.tools - 1 do
        tools[i] = opts.tools[i]
      end
      local last = {}
      for k, v in pairs(opts.tools[#opts.tools]) do
        last[k] = v
      end
      last.cache_control = { type = "ephemeral" }
      tools[#opts.tools] = last
      body.tools = tools
    else
      body.tools = opts.tools
    end
  end
  local headers = {
    ["x-api-key"] = api_key,
    ["anthropic-version"] = "2023-06-01",
    ["content-type"] = "application/json",
  }
  if opts.context_management ~= nil then
    headers["anthropic-beta"] = "context-management-2025-06-27"
    body.context_management = opts.context_management
  end

  local dump_mode = resolve_dump_mode()
  local call_index = trace and trace.call_index or "?"
  local turn = trace and trace.turn or "?"
  local iteration = trace and trace.iteration or "?"
  llm_dump_event(dump_mode, "request", {
    { "call", call_index },
    { "turn", turn },
    { "iter", iteration },
    { "trace_id", trace and trace.trace_id or nil },
    { "run_id", trace and trace.run_id or nil },
    { "agent_id", trace and trace.agent_id or nil },
    { "agent_name", trace and trace.agent_name or nil },
    { "model", body.model },
    { "messages", #messages },
    { "tools", #(body.tools or {}) },
    { "max_tokens", tonumber(body.max_tokens) or 0 },
    { "timeout", tonumber(timeout) or 120 },
    { "context_mgmt", opts.context_management ~= nil },
  })
  if dump_mode == "full" then
    llm_dump_event(dump_mode, "request_headers", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "payload", std.json.encode(sanitize_headers_for_dump(headers)) },
    })
    llm_dump_event(dump_mode, "request_body", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "payload", std.json.encode(body) },
    })
  end

  local start_ts = std.time.now()
  local resp = http.request("https://api.anthropic.com/v1/messages", {
    method = "POST",
    headers = headers,
    body = std.json.encode(body),
    timeout = timeout,
  })
  local elapsed_ms = math.floor((std.time.now() - start_ts) * 1000)
  -- Defensive: without a response object the status/body reads below would fail.
  if not resp then
    return nil, "HTTP request failed: no response"
  end

  llm_dump_event(dump_mode, "response", {
    { "call", call_index },
    { "turn", turn },
    { "iter", iteration },
    { "trace_id", trace and trace.trace_id or nil },
    { "run_id", trace and trace.run_id or nil },
    { "agent_id", trace and trace.agent_id or nil },
    { "agent_name", trace and trace.agent_name or nil },
    { "status", resp.status },
    { "latency_ms", elapsed_ms },
  })
  if dump_mode == "full" then
    llm_dump_event(dump_mode, "response_headers", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      -- Response headers can carry cookies, so redact them like request headers.
      { "payload", std.json.encode(sanitize_headers_for_dump(resp.headers)) },
    })
    llm_dump_event(dump_mode, "response_body", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "payload", tostring(resp.body or "") },
    })
  end

  if resp.status ~= 200 then
    return nil, "API error " .. tostring(resp.status)
  end
  -- Protect the decode (as llm_call_openai does) so a malformed body becomes
  -- an error return instead of a raise.
  local ok_parse, decoded = pcall(std.json.decode, resp.body)
  if not ok_parse or type(decoded) ~= "table" then
    log.warn("agent: Anthropic response JSON decode failed: " .. tostring(decoded))
    return nil, "Anthropic response JSON decode failed"
  end

  if dump_mode ~= "off" then
    local usage = decoded.usage or {}
    local in_tok = tonumber(usage.input_tokens) or 0
    local out_tok = tonumber(usage.output_tokens) or 0
    local cache_create = tonumber(usage.cache_creation_input_tokens) or 0
    local cache_read = tonumber(usage.cache_read_input_tokens) or 0
    local stop_reason = tostring(decoded.stop_reason or "unknown")
    local content_blocks = #(decoded.content or {})
    local tool_uses = count_tool_use_blocks(decoded.content)
    local text_chars = count_text_chars(decoded.content)
    local cm_applied = 0
    if decoded.context_management and decoded.context_management.applied_edits then
      cm_applied = #decoded.context_management.applied_edits
    end
    llm_dump_event(dump_mode, "summary", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "trace_id", trace and trace.trace_id or nil },
      { "run_id", trace and trace.run_id or nil },
      { "agent_id", trace and trace.agent_id or nil },
      { "agent_name", trace and trace.agent_name or nil },
      { "stop_reason", stop_reason },
      { "blocks", content_blocks },
      { "tool_uses", tool_uses },
      { "text_chars", text_chars },
      { "usage_in", in_tok },
      { "usage_out", out_tok },
      { "usage_total", in_tok + out_tok },
      { "cache_create", cache_create },
      { "cache_read", cache_read },
      { "context_edits", cm_applied },
    })
  end
  return decoded, nil
end
-- OpenAI finish_reason values that have a direct Anthropic stop_reason equivalent.
local FINISH_REASON_TO_STOP_REASON = {
  stop = "end_turn",
  tool_calls = "tool_use",
  length = "max_tokens",
}

--- Translate an OpenAI finish_reason into an Anthropic-style stop_reason.
-- Unmapped values pass through via tostring; nil becomes "end_turn".
-- @param finish_reason OpenAI finish_reason (may be nil)
-- @treturn string
local function map_finish_reason(finish_reason)
  local mapped = FINISH_REASON_TO_STOP_REASON[finish_reason]
  if mapped ~= nil then
    return mapped
  end
  return tostring(finish_reason or "end_turn")
end
--- Convert a decoded OpenAI Chat Completions response into the
-- Anthropic-style shape the agent loop consumes:
-- {content, stop_reason, usage, context_management}.
-- Only the first choice is used. Tool-call arguments that cannot be parsed
-- become an empty input tagged with is_error_hint = "arguments_parse_failed".
-- @tparam table raw decoded JSON response
-- @return normalized table, or nil plus an error string
local function normalize_openai_response(raw)
  if type(raw) ~= "table" or type(raw.choices) ~= "table" or #raw.choices == 0 then
    return nil, "invalid OpenAI response: missing choices"
  end
  local choice = raw.choices[1]
  local message = type(choice) == "table" and choice.message or nil
  if type(message) ~= "table" then
    return nil, "invalid OpenAI response: missing choices[0].message"
  end

  local content = {}
  -- content may be JSON null when the model only calls tools; only a
  -- non-empty string becomes a text block.
  local text = message.content
  if type(text) == "string" and text ~= "" then
    content[#content + 1] = { type = "text", text = text }
  end

  local tool_calls = message.tool_calls
  if type(tool_calls) ~= "table" then
    tool_calls = {}
  end
  for _, tc in ipairs(tool_calls) do
    local fn = tc["function"] or {}
    local args = fn.arguments
    local input, parse_failed = nil, false
    if type(args) == "table" then
      -- Some OpenAI-compatible servers send arguments already decoded.
      input = args
    else
      -- A missing or empty arguments string means "no arguments".
      if args == nil or args == "" then
        args = "{}"
      end
      local ok, parsed = pcall(std.json.decode, args)
      if ok and type(parsed) == "table" then
        input = parsed
      else
        parse_failed = true
      end
    end

    local block = {
      type = "tool_use",
      id = tc.id,
      name = fn.name or "",
      input = input or {},
    }
    if parse_failed then
      log.warn("agent: OpenAI tool_call arguments JSON parse failed for tool '" .. tostring(fn.name) .. "'; using empty input")
      block.is_error_hint = "arguments_parse_failed"
    end
    content[#content + 1] = block
  end

  local usage_raw = raw.usage or {}
  local decoded = {
    content = content,
    stop_reason = map_finish_reason(choice.finish_reason),
    usage = {
      input_tokens = tonumber(usage_raw.prompt_tokens) or 0,
      output_tokens = tonumber(usage_raw.completion_tokens) or 0,
      cache_creation_input_tokens = 0,
      cache_read_input_tokens = 0,
    },
    context_management = nil,
  }
  return decoded, nil
end
--- Flatten an Anthropic tool_result `content` value into a string.
-- Strings pass through. A list of text blocks is joined with newlines; any
-- other table is JSON-encoded. nil/false become "".
local function openai_tool_result_text(content)
  if not content then
    return ""
  end
  if type(content) ~= "table" then
    return tostring(content)
  end
  local texts = {}
  local only_text = true
  for _, part in ipairs(content) do
    if type(part) == "table" and part.type == "text" then
      texts[#texts + 1] = tostring(part.text or "")
    else
      only_text = false
    end
  end
  if only_text then
    return table.concat(texts, "\n")
  end
  return std.json.encode(content)
end

--- Build one OpenAI assistant message from Anthropic assistant content blocks.
local function openai_assistant_message(blocks)
  local text_parts = {}
  local tool_calls = {}
  for _, block in ipairs(blocks) do
    if block.type == "text" then
      text_parts[#text_parts + 1] = block.text or ""
    elseif block.type == "tool_use" then
      tool_calls[#tool_calls + 1] = {
        id = block.id,
        type = "function",
        ["function"] = {
          name = block.name,
          arguments = std.json.encode(block.input or {}),
        },
      }
    end
  end
  local msg = { role = "assistant" }
  if #text_parts > 0 then
    msg.content = table.concat(text_parts, "\n")
  end
  if #tool_calls > 0 then
    msg.tool_calls = tool_calls
  elseif msg.content == nil then
    -- Assistant messages without tool_calls need a content field.
    msg.content = ""
  end
  return msg
end

--- Append OpenAI messages for Anthropic user content blocks to `out`.
-- Each tool_result becomes a role="tool" message. Text blocks are joined into
-- a user message emitted after the tool messages, so the tool messages stay
-- directly after the assistant tool_calls.
local function append_openai_user_blocks(out, blocks)
  local text_parts = {}
  local has_tool_result = false
  for _, block in ipairs(blocks) do
    if block.type == "tool_result" then
      has_tool_result = true
      out[#out + 1] = {
        role = "tool",
        tool_call_id = block.tool_use_id,
        content = openai_tool_result_text(block.content),
      }
    elseif block.type == "text" then
      text_parts[#text_parts + 1] = block.text or ""
    end
  end
  -- Plain user turns always produce a message (possibly empty). Tool-result
  -- turns only add one when there is accompanying text.
  if not has_tool_result or #text_parts > 0 then
    out[#out + 1] = { role = "user", content = table.concat(text_parts, "\n") }
  end
end

--- Convert Anthropic-format messages (plus optional system prompt) to
-- OpenAI Chat Completions messages.
-- String content is copied as-is. Assistant and user block lists are
-- translated. Block lists for other roles are passed through unchanged.
-- Messages whose content is neither a string nor a table are skipped.
-- @tparam table messages Anthropic-style message sequence
-- @tparam[opt] string system system prompt, emitted first when non-empty
-- @treturn table OpenAI message sequence
local function convert_messages_to_openai(messages, system)
  local out = {}
  if system and system ~= "" then
    out[#out + 1] = { role = "system", content = system }
  end
  for _, msg in ipairs(messages) do
    local content = msg.content
    if type(content) == "string" then
      out[#out + 1] = { role = msg.role, content = content }
    elseif type(content) == "table" then
      if msg.role == "assistant" then
        out[#out + 1] = openai_assistant_message(content)
      elseif msg.role == "user" then
        append_openai_user_blocks(out, content)
      else
        out[#out + 1] = { role = msg.role, content = content }
      end
    end
  end
  return out
end
--- Send one request to an OpenAI-compatible Chat Completions endpoint and
-- normalize the reply into the Anthropic-style shape used by the agent loop.
-- Anthropic-only options (cache_control, context_management,
-- context_management_config) are ignored, with a warning when set.
-- @tparam table messages Anthropic-format conversation messages
-- @tparam table opts model, max_tokens, timeout, system, tools, api_key,
--   api_key_env, base_url, extra_body
-- @tparam[opt] table trace correlation fields copied into dump events
-- @return normalized response table, or nil plus an error string
local function llm_call_openai(messages, opts, trace)
  if opts.cache_control ~= nil then
    log.warn("agent: cache_control is anthropic-only; ignored for provider=openai")
  end
  if opts.context_management ~= nil then
    log.warn("agent: context_management is anthropic-only; ignored for provider=openai")
  end
  if opts.context_management_config ~= nil then
    log.warn("agent: context_management_config is anthropic-only; ignored for provider=openai")
  end
  local api_key = opts.api_key
  if not api_key then
    local key_env = opts.api_key_env or "OPENAI_API_KEY"
    api_key = std.env.get(key_env)
    if not api_key then
      return nil, "API key not set: env=" .. key_env
    end
  end
  local model = opts.model or std.env.get_or("OPENAI_MODEL", "gpt-4o-mini")
  -- Strip trailing slashes so "https://host/v1/" does not yield "//chat/completions".
  local base_url = (string.gsub(opts.base_url or "https://api.openai.com/v1", "/+$", ""))
  local endpoint = base_url .. "/chat/completions"
  local timeout = opts.timeout or 120
  local oai_messages = convert_messages_to_openai(messages, opts.system)
  local oai_tools = nil
  if opts.tools and #opts.tools > 0 then
    oai_tools = {}
    for _, t in ipairs(opts.tools) do
      local fn_def = {
        name = t.name,
        description = t.description or "",
        parameters = t.input_schema or { type = "object", properties = {} },
      }
      table.insert(oai_tools, { type = "function", ["function"] = fn_def })
    end
  end
  local body = {
    model = model,
    messages = oai_messages,
    max_tokens = opts.max_tokens or 4096,
  }
  if oai_tools and #oai_tools > 0 then
    body.tools = oai_tools
  end
  -- extra_body is merged last and may override any field above.
  if opts.extra_body and type(opts.extra_body) == "table" then
    for k, v in pairs(opts.extra_body) do
      body[k] = v
    end
  end
  local headers = {
    ["Authorization"] = "Bearer " .. api_key,
    ["Content-Type"] = "application/json",
  }

  local dump_mode = resolve_dump_mode()
  local call_index = trace and trace.call_index or "?"
  local turn = trace and trace.turn or "?"
  local iteration = trace and trace.iteration or "?"
  llm_dump_event(dump_mode, "request", {
    { "call", call_index },
    { "turn", turn },
    { "iter", iteration },
    { "trace_id", trace and trace.trace_id or nil },
    { "run_id", trace and trace.run_id or nil },
    { "agent_id", trace and trace.agent_id or nil },
    { "agent_name", trace and trace.agent_name or nil },
    { "model", body.model },
    { "messages", #messages },
    { "tools", #(body.tools or {}) },
    { "max_tokens", tonumber(body.max_tokens) or 0 },
    { "timeout", tonumber(timeout) or 120 },
    { "context_mgmt", false },
    { "provider", "openai" },
  })
  if dump_mode == "full" then
    llm_dump_event(dump_mode, "request_headers", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "payload", std.json.encode(sanitize_headers_for_dump(headers)) },
    })
    llm_dump_event(dump_mode, "request_body", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "payload", std.json.encode(body) },
    })
  end

  local start_ts = std.time.now()
  local resp = http.request(endpoint, {
    method = "POST",
    headers = headers,
    body = std.json.encode(body),
    timeout = timeout,
  })
  local elapsed_ms = math.floor((std.time.now() - start_ts) * 1000)
  -- Defensive: without a response object the status/body reads below would fail.
  if not resp then
    return nil, "HTTP request failed: no response"
  end

  llm_dump_event(dump_mode, "response", {
    { "call", call_index },
    { "turn", turn },
    { "iter", iteration },
    { "trace_id", trace and trace.trace_id or nil },
    { "run_id", trace and trace.run_id or nil },
    { "agent_id", trace and trace.agent_id or nil },
    { "agent_name", trace and trace.agent_name or nil },
    { "status", resp.status },
    { "latency_ms", elapsed_ms },
    { "provider", "openai" },
  })
  if dump_mode == "full" then
    llm_dump_event(dump_mode, "response_headers", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      -- Response headers can carry cookies, so redact them like request headers.
      { "payload", std.json.encode(sanitize_headers_for_dump(resp.headers)) },
    })
    llm_dump_event(dump_mode, "response_body", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "payload", tostring(resp.body or "") },
    })
  end

  if resp.status ~= 200 then
    return nil, "API error " .. tostring(resp.status)
  end
  local ok_parse, raw = pcall(std.json.decode, resp.body)
  if not ok_parse then
    log.warn("agent: OpenAI response JSON decode failed: " .. tostring(raw))
    return nil, "OpenAI response JSON decode failed"
  end
  local decoded, norm_err = normalize_openai_response(raw)
  if not decoded then
    log.warn("agent: OpenAI response normalization failed: " .. tostring(norm_err))
    return nil, norm_err
  end

  if dump_mode ~= "off" then
    local usage = decoded.usage or {}
    local in_tok = tonumber(usage.input_tokens) or 0
    local out_tok = tonumber(usage.output_tokens) or 0
    local stop_reason = tostring(decoded.stop_reason or "unknown")
    local content_blocks = #(decoded.content or {})
    local tool_uses = count_tool_use_blocks(decoded.content)
    local text_chars = count_text_chars(decoded.content)
    llm_dump_event(dump_mode, "summary", {
      { "call", call_index },
      { "turn", turn },
      { "iter", iteration },
      { "trace_id", trace and trace.trace_id or nil },
      { "run_id", trace and trace.run_id or nil },
      { "agent_id", trace and trace.agent_id or nil },
      { "agent_name", trace and trace.agent_name or nil },
      { "stop_reason", stop_reason },
      { "blocks", content_blocks },
      { "tool_uses", tool_uses },
      { "text_chars", text_chars },
      { "usage_in", in_tok },
      { "usage_out", out_tok },
      { "usage_total", in_tok + out_tok },
      { "cache_create", 0 },
      { "cache_read", 0 },
      { "context_edits", 0 },
      { "provider", "openai" },
    })
  end
  return decoded, nil
end
--- Dispatch an LLM call to the provider selected by opts.provider.
-- "openai" routes to llm_call_openai; anything else (including nil) routes
-- to llm_call_anthropic.
-- @return whatever the selected provider function returns
local function llm_call(messages, opts, trace)
  local provider = opts.provider or "anthropic"
  local impl = (provider == "openai") and llm_call_openai or llm_call_anthropic
  return impl(messages, opts, trace)
end
--- Create a token-usage accumulator with an optional total-token limit.
-- Methods are stored directly on the returned table and called with `:`.
-- @tparam[opt] number max_tokens_budget limit on input+output tokens; nil = unlimited
-- @treturn table tracker with add(usage), exceeded() and summary()
local function new_budget_tracker(max_tokens_budget)
  local t = {
    input_tokens = 0,
    output_tokens = 0,
    total_tokens = 0,
    limit = max_tokens_budget,
  }

  -- Accumulate a usage table ({input_tokens, output_tokens}); nil is ignored.
  t.add = function(self, usage)
    if not usage then
      return
    end
    self.input_tokens = self.input_tokens + (usage.input_tokens or 0)
    self.output_tokens = self.output_tokens + (usage.output_tokens or 0)
    self.total_tokens = self.input_tokens + self.output_tokens
  end

  -- True once total tokens reach the limit; always false when there is no limit.
  t.exceeded = function(self)
    return (self.limit and self.total_tokens >= self.limit) or false
  end

  -- Fresh snapshot of the counters (without the limit).
  t.summary = function(self)
    return {
      input_tokens = self.input_tokens,
      output_tokens = self.output_tokens,
      total_tokens = self.total_tokens,
    }
  end

  return t
end
--- Route MCP progress notifications for server `name` to opts.on_progress
-- (errors in the callback are logged), or to the info log when
-- opts.progress_to_log is set. Does nothing otherwise.
local function attach_mcp_progress_handler(name, opts)
  if opts.on_progress then
    local user_cb = opts.on_progress
    mcp.on_progress(name, function(ev)
      local ok, cb_err = pcall(user_cb, ev)
      if not ok then
        log.warn("agent: on_progress callback error: " .. tostring(cb_err))
      end
    end)
  elseif opts.progress_to_log then
    mcp.on_progress(name, function(ev)
      local msg = "mcp progress: server=" .. tostring(ev.server)
        .. " token=" .. tostring(ev.token)
        .. " p=" .. tostring(ev.progress) .. "/" .. tostring(ev.total or "")
      if ev.message and ev.message ~= "" then
        msg = msg .. " msg=" .. tostring(ev.message)
      end
      log.info(msg)
    end)
  end
end

--- Register <sn>__mcp_list_resources and <sn>__mcp_read_resource tools that
-- proxy to the MCP resource API of server `sn`.
local function register_mcp_resource_tools(sn)
  tool.register(sn .. "__mcp_list_resources", {
    description = "List available resources on MCP server '" .. sn .. "'",
    input_schema = { type = "object", properties = {} },
  }, function(_input)
    local r = mcp.list_resources(sn)
    if not r.ok then return std.json.encode({ error = r.error }) end
    return std.json.encode(r.resources)
  end)
  tool.register(sn .. "__mcp_read_resource", {
    description = "Read a resource by URI from MCP server '" .. sn .. "'",
    input_schema = {
      type = "object",
      properties = { uri = { type = "string" } },
      required = { "uri" },
    },
  }, function(input)
    local r = mcp.read_resource(sn, input.uri)
    if not r.ok then return std.json.encode({ error = r.error }) end
    return std.json.encode(r.contents)
  end)
end

--- Register <sn>__mcp_list_prompts and <sn>__mcp_get_prompt tools that proxy
-- to the MCP prompt API of server `sn`.
local function register_mcp_prompt_tools(sn)
  tool.register(sn .. "__mcp_list_prompts", {
    description = "List available prompts on MCP server '" .. sn .. "'",
    input_schema = { type = "object", properties = {} },
  }, function(_input)
    local r = mcp.list_prompts(sn)
    if not r.ok then return std.json.encode({ error = r.error }) end
    return std.json.encode(r.prompts)
  end)
  tool.register(sn .. "__mcp_get_prompt", {
    description = "Get a prompt by name from MCP server '" .. sn .. "'",
    input_schema = {
      type = "object",
      properties = {
        name = { type = "string" },
        args = { type = "object" },
      },
      required = { "name" },
    },
  }, function(input)
    local r = mcp.get_prompt(sn, input.name, input.args or {})
    if not r.ok then return std.json.encode({ error = r.error }) end
    return std.json.encode(r.messages)
  end)
end

--- Route MCP log notifications for server `sn` to opts.on_log (errors in the
-- callback are logged) or, otherwise, to the local log at a matching level.
local function attach_mcp_log_handler(sn, opts)
  if opts.on_log then
    local user_cb = opts.on_log
    mcp.on_log(sn, function(ev)
      local ok, cb_err = pcall(user_cb, ev)
      if not ok then
        log.warn("agent: on_log callback error: " .. tostring(cb_err))
      end
    end)
    return
  end
  mcp.on_log(sn, function(ev)
    local msg = "mcp log: server=" .. tostring(ev.server)
      .. " logger=" .. tostring(ev.logger)
      .. " data=" .. tostring(ev.data)
    if ev.level == "debug" then
      log.debug(msg)
    elseif ev.level == "warning" then
      log.warn(msg)
    elseif ev.level == "error" then
      log.error(msg)
    else
      log.info(msg)
    end
  end)
end

--- Enable resource/prompt proxy tools and log routing for server `name`
-- according to opts, but only for capabilities the server advertises.
-- A server_info failure is logged and the setup is skipped; it is not fatal.
local function setup_mcp_capabilities(name, opts)
  local wants_log = opts.on_log or opts.log_to_stderr
  if not (opts.enable_resources or opts.enable_prompts or wants_log) then
    return
  end
  local si_ok, si_result = pcall(mcp.server_info, name)
  if not si_ok then
    log.warn("agent: mcp.server_info failed for '" .. name .. "': " .. tostring(si_result))
    return
  end
  if not (si_result and si_result.ok) then
    log.warn("agent: mcp.server_info failed for '" .. name .. "': " .. tostring(si_result and si_result.error))
    return
  end
  local caps = (si_result.server_info and si_result.server_info.capabilities) or {}
  if opts.enable_resources then
    if caps.resources ~= nil then
      register_mcp_resource_tools(name)
    else
      log.info("agent: server '" .. name .. "' has no resources capability; skipping register")
    end
  end
  if opts.enable_prompts then
    if caps.prompts ~= nil then
      register_mcp_prompt_tools(name)
    else
      log.info("agent: server '" .. name .. "' has no prompts capability; skipping register")
    end
  end
  if wants_log then
    if caps.logging ~= nil then
      attach_mcp_log_handler(name, opts)
    else
      log.info("agent: server '" .. name .. "' has no logging capability; on_log/log_to_stderr skipped")
    end
  end
end

--- Connect to each configured MCP server and collect its tools.
-- Servers with `url` use mcp.connect_http; others use mcp.connect with
-- command/args. Tools are namespaced as "<server>__<tool>".
-- Every failure is returned, never raised, so the caller can always
-- disconnect whatever was opened.
-- @tparam table servers sequence of server configs ({name, url|command, ...})
-- @tparam[opt] table opts sampling, on_progress, progress_to_log,
--   enable_resources, enable_prompts, on_log, log_to_stderr
-- @return tool_map, nil, connected_names on success; or
--   nil, error_string, connected_names on failure
local function connect_mcp_servers(servers, opts)
  local mcp_tool_map = {}
  local connected = {}
  opts = opts or {}
  for _, srv in ipairs(servers) do
    local name = srv.name
    local label = tostring(name)
    local ok, err
    if srv.url then
      local transport_opts = {}
      for k, v in pairs(srv.transport_opts or {}) do transport_opts[k] = v end
      if transport_opts.trace_context == nil then
        transport_opts.trace_context = not not srv.trace_context
      end
      ok, err = pcall(mcp.connect_http, name, srv.url, transport_opts)
    else
      local command = srv.command
      local args = srv.args or {}
      local connect_opts = { trace_context = not not srv.trace_context }
      ok, err = pcall(mcp.connect, name, command, args, connect_opts)
    end
    if not ok then
      return nil, "mcp connect failed for '" .. label .. "': " .. tostring(err), connected
    end
    table.insert(connected, name)

    if opts.sampling then
      local sampling_ok, sampling_err = pcall(mcp.set_sampling_handler, name, opts.sampling)
      if not sampling_ok then
        log.warn("agent: mcp set_sampling_handler failed for '" .. label .. "': " .. tostring(sampling_err))
      end
    end

    -- Protected like the other mcp calls: a raise here would otherwise escape
    -- before the caller could disconnect the servers opened so far.
    local list_ok, list_result = pcall(mcp.list_tools, name)
    if not list_ok then
      return nil, "mcp list_tools failed for '" .. label .. "': " .. tostring(list_result), connected
    end
    if not (list_result and list_result.ok) then
      return nil, "mcp list_tools failed for '" .. label .. "': " .. tostring(list_result and list_result.error), connected
    end
    for _, t in ipairs(list_result.tools or {}) do
      local ns_name = name .. "__" .. t.name
      local input_schema = t.inputSchema or t.input_schema or { type = "object", properties = {} }
      mcp_tool_map[ns_name] = {
        server = name,
        tool = t.name,
        def = {
          name = ns_name,
          description = t.description or "",
          input_schema = input_schema,
        },
      }
    end

    local setup_ok, setup_err = pcall(function()
      attach_mcp_progress_handler(name, opts)
      setup_mcp_capabilities(name, opts)
    end)
    if not setup_ok then
      return nil, "mcp setup failed for '" .. label .. "': " .. tostring(setup_err), connected
    end
  end
  return mcp_tool_map, nil, connected
end
--- Disconnect each named MCP server, best-effort.
-- A failing disconnect is logged as a warning and does not stop the
-- remaining servers from being disconnected.
-- @tparam table server_names sequence of server names
local function disconnect_mcp_servers(server_names)
  for _, server in ipairs(server_names) do
    local closed, why = pcall(mcp.disconnect, server)
    if not closed then
      local reason = tostring(why)
      log.warn("agent: mcp disconnect error for '" .. server .. "': " .. reason)
    end
  end
end
--- Assemble the tool definitions sent to the LLM, de-duplicated by name
-- (first occurrence wins).
-- Order: registered tools (tool.schema()), then MCP tools sorted by
-- namespaced name, then extra_tools. Sorting the MCP tools keeps the list
-- identical across runs. Otherwise `pairs` order would vary, and the
-- prompt-cache breakpoint that llm_call_anthropic places on the last tool
-- would not be stable.
-- @tparam table mcp_tool_map namespaced name -> {def = tool definition, ...}
-- @tparam[opt] table extra_tools entries with {name, schema, handler} or plain
--   tool definitions
-- @treturn table sequence of {name, description, input_schema}
local function build_tools(mcp_tool_map, extra_tools)
  local tools = {}
  local seen = {}
  local function add_unique(t)
    if seen[t.name] then
      return
    end
    seen[t.name] = true
    tools[#tools + 1] = t
  end

  for _, t in ipairs(tool.schema()) do
    add_unique(t)
  end

  local mcp_names = {}
  for ns_name in pairs(mcp_tool_map) do
    mcp_names[#mcp_names + 1] = ns_name
  end
  table.sort(mcp_names, function(a, b) return tostring(a) < tostring(b) end)
  for _, ns_name in ipairs(mcp_names) do
    add_unique(mcp_tool_map[ns_name].def)
  end

  if extra_tools then
    for _, t in ipairs(extra_tools) do
      if t.schema and t.handler then
        add_unique({
          name = t.name,
          description = t.schema.description,
          input_schema = t.schema.input_schema,
        })
      else
        add_unique(t)
      end
    end
  end
  return tools
end
--- Render a local tool's return value as tool_result text.
-- Tables are JSON-encoded; everything else goes through tostring.
local function stringify_tool_output(res)
  if type(res) == "table" then
    return std.json.encode(res)
  end
  return tostring(res)
end

--- Execute a tool call and return its result text plus an error flag.
-- Lookup order: MCP tools, then extra_tools handlers, then the global tool
-- registry (tool.call). Every path is protected, so a failing tool yields an
-- error result for the model instead of aborting the agent loop.
-- @tparam string name tool name as requested by the model
-- @param input tool input table
-- @tparam table mcp_tool_map namespaced name -> {server, tool}
-- @tparam[opt] table extra_tools_map name -> {handler = fn}
-- @return content string, is_error boolean
local function dispatch_tool(name, input, mcp_tool_map, extra_tools_map)
  local entry = mcp_tool_map[name]
  if entry then
    local call_ok, call_result = pcall(mcp.call, entry.server, entry.tool, input)
    if not call_ok then
      return "tool error: " .. tostring(call_result), true
    end
    if not (call_result and call_result.ok) then
      return tostring(call_result and call_result.error or "mcp.call failed"), true
    end
    local is_error = call_result.is_error == true
    if is_error then
      log.warn(string.format("mcp tool '%s.%s' returned isError=true", entry.server, entry.tool))
    end
    local content_blocks = call_result.content or {}
    if #content_blocks == 0 then
      return "", is_error
    end
    if #content_blocks == 1 and content_blocks[1].type == "text" then
      -- Guard against a text block without text so the result is always a string.
      return tostring(content_blocks[1].text or ""), is_error
    end
    return std.json.encode(content_blocks), is_error
  end

  local extra = extra_tools_map and extra_tools_map[name]
  if extra then
    local ok, res = pcall(extra.handler, input)
    if not ok then
      return "tool error: " .. tostring(res), true
    end
    return stringify_tool_output(res), false
  end

  local ok, res = pcall(tool.call, name, input)
  if not ok then
    return "tool error: " .. tostring(res), true
  end
  return stringify_tool_output(res), false
end
--- Join the text of every `text` block in a content list with newlines.
-- Blocks without text, and non-text blocks, are skipped.
-- @tparam[opt] table content sequence of content blocks
-- @treturn string ("" when there is no text)
local function extract_text(content)
  local chunks = {}
  for _, blk in ipairs(content or {}) do
    local txt = (blk.type == "text") and blk.text or nil
    if txt then
      chunks[#chunks + 1] = txt
    end
  end
  return table.concat(chunks, "\n")
end
--- Run an agent loop: call the LLM, execute requested tools, feed results
-- back, and repeat until the model stops calling tools. The loop also stops
-- on stop_reason end_turn/max_tokens, the iteration limit, or the token budget.
-- Pushes an LLM context onto _AGENT_LLM_CTX for the duration of the run and
-- always pops it, and disconnects MCP servers, before returning.
-- @tparam table opts prompt (required), provider, model, system, max_tokens,
--   timeout, max_iterations, max_tokens_budget, mcp_servers, extra_tools,
--   context_management, context_management_config, cache_control, on_turn, ...
-- @treturn table {ok, content?, error?, usage, num_turns, messages}
function M.run(opts)
  opts = opts or {}
  if not opts.prompt or opts.prompt == "" then
    return { ok = false, error = "prompt is required", usage = { input_tokens = 0, output_tokens = 0, total_tokens = 0 }, num_turns = 0, messages = {} }
  end
  local provider = opts.provider or "anthropic"
  table.insert(_AGENT_LLM_CTX, {
    provider = opts.provider,
    base_url = opts.base_url,
    api_key = opts.api_key,
    api_key_env = opts.api_key_env,
    model = opts.model,
  })
  local budget = new_budget_tracker(opts.max_tokens_budget)
  local max_iter = opts.max_iterations or 20

  -- Failure result shape shared by every error path below.
  local function failure(err, turns, msgs)
    return {
      ok = false,
      error = err,
      usage = budget:summary(),
      num_turns = turns,
      messages = msgs,
    }
  end

  local mcp_tool_map = {}
  local connected_servers = {}
  if opts.mcp_servers and #opts.mcp_servers > 0 then
    local connect_ok, tool_map, err, partial_connected =
      pcall(connect_mcp_servers, opts.mcp_servers, opts)
    if not connect_ok then
      -- connect_mcp_servers raised, so the set of opened servers is unknown.
      -- Try to close every configured server; errors are only logged.
      local names = {}
      for _, srv in ipairs(opts.mcp_servers) do
        names[#names + 1] = srv.name
      end
      disconnect_mcp_servers(names)
      table.remove(_AGENT_LLM_CTX)
      return failure(tostring(tool_map), 0, {})
    end
    if err then
      disconnect_mcp_servers(partial_connected or {})
      table.remove(_AGENT_LLM_CTX)
      return failure(err, 0, {})
    end
    mcp_tool_map = tool_map
    connected_servers = partial_connected
  end

  local messages = {
    { role = "user", content = opts.prompt },
  }
  local num_turns = 0
  local llm_call_index = 0
  local final_content = ""
  local loop_error = nil

  -- Setup and loop both run under pcall, so a raise anywhere (tool schema,
  -- env lookups, provider calls) still reaches the cleanup below.
  local loop_ok, loop_err = pcall(function()
    local extra_tools_map = {}
    for _, t in ipairs(opts.extra_tools or {}) do
      if t.name and t.handler then
        extra_tools_map[t.name] = t
      end
    end
    local tools = build_tools(mcp_tool_map, opts.extra_tools)

    -- Context management is an Anthropic beta feature. Leave it unset for
    -- OpenAI so llm_call_openai does not warn on every call unless the caller
    -- asked for it explicitly.
    local cm_final = nil
    if provider ~= "openai" and opts.context_management ~= false then
      cm_final = opts.context_management_config or DEFAULT_CONTEXT_MANAGEMENT
    end
    local call_opts = {
      model = opts.model,
      max_tokens = opts.max_tokens or 4096,
      timeout = opts.timeout or 120,
      system = opts.system,
      tools = tools,
      context_management = cm_final,
      provider = opts.provider,
      base_url = opts.base_url,
      api_key = opts.api_key,
      api_key_env = opts.api_key_env,
      cache_control = opts.cache_control,
      context_management_config = opts.context_management_config,
    }
    local log_meta = build_log_meta(opts)

    local iter = 0
    while true do
      llm_call_index = llm_call_index + 1
      local response, api_err = llm_call(messages, call_opts, {
        call_index = llm_call_index,
        turn = num_turns + 1,
        iteration = iter + 1,
        trace_id = log_meta.trace_id,
        agent_id = log_meta.agent_id,
        agent_name = log_meta.agent_name,
        run_id = log_meta.run_id,
      })
      if not response then
        loop_error = api_err
        return
      end
      table.insert(messages, {
        role = "assistant",
        content = response.content,
      })
      budget:add(response.usage)
      num_turns = num_turns + 1

      local tool_calls = {}
      for _, block in ipairs(response.content or {}) do
        if block.type == "tool_use" then
          table.insert(tool_calls, block)
        end
      end
      final_content = extract_text(response.content)

      if opts.on_turn then
        local cb_ok, cb_err = pcall(opts.on_turn, {
          turn_number = num_turns,
          content = response.content,
          tool_calls = tool_calls,
          usage = response.usage,
          context_management = response.context_management,
        })
        if not cb_ok then
          log.warn("agent: on_turn callback error: " .. tostring(cb_err))
        end
      end

      if #tool_calls == 0 then
        break
      end
      local stop_reason = response.stop_reason
      if stop_reason == "end_turn" or stop_reason == "max_tokens" then
        break
      end
      iter = iter + 1
      if iter >= max_iter then
        log.warn("agent: max iterations (" .. max_iter .. ") reached")
        break
      end
      if budget:exceeded() then
        log.warn("agent: token budget exceeded (" .. budget.total_tokens .. "/" .. budget.limit .. ")")
        break
      end

      local tool_results = {}
      for _, tc in ipairs(tool_calls) do
        local content_str, is_error = dispatch_tool(tc.name, tc.input, mcp_tool_map, extra_tools_map)
        table.insert(tool_results, {
          type = "tool_result",
          tool_use_id = tc.id,
          content = content_str,
          is_error = is_error or nil,
        })
      end
      table.insert(messages, {
        role = "user",
        content = tool_results,
      })
    end
  end)

  table.remove(_AGENT_LLM_CTX)
  disconnect_mcp_servers(connected_servers)
  if not loop_ok then
    return failure(tostring(loop_err), num_turns, messages)
  end
  if loop_error then
    return failure(loop_error, num_turns, messages)
  end
  return {
    ok = true,
    content = final_content,
    usage = budget:summary(),
    num_turns = num_turns,
    messages = messages,
  }
end
-- Expose the internal tool-list builder. The leading underscore marks it as
-- non-public. NOTE(review): presumably exported for tests; confirm with callers.
M._build_tools = build_tools
return M