use std::sync::Arc;
use futures_core::Stream;
use crate::llm::{
ChatGptClient, FunctionDefinition, LlmError, Message, ToolDefinition,
};
use crate::pob_parser::{PobParser, PobQuery};
/// Maximum number of `chat_with_tools` requests made while the model keeps
/// asking for tool calls. When the limit is reached, `ToolAgent::respond`
/// stops offering tools and falls through to streaming a final answer.
const MAX_TOOL_ROUNDS: usize = 10;
/// System prompt sent as the first message of every conversation.
///
/// It tells the model which inspection tool to use for which kind of question.
/// The tool names it mentions must stay in sync with `tool_definitions` and
/// the dispatch in `execute_tool`.
const SYSTEM_PROMPT: &str = "\
You are a Path of Exile 2 build analysis assistant. The user has uploaded \
their Path of Building export.\n\
\n\
You have tools to inspect the build data. Use them to answer the user's \
questions accurately — do NOT guess at numbers.\n\
\n\
Start by calling get_build_stats to get an overview of the build's offense, \
defense, and resources. Then use get_skill_list or get_config if needed \
to answer the user's specific question.\n\
\n\
Use get_item to inspect a specific equipment slot when the user asks about \
their gear, an item's mods, or how a particular slot could be upgraded. \
Do not call get_item unless the question is about specific equipment.\n\
\n\
Use get_passive_tree when the user asks about their passive tree, allocated \
nodes, keystones, notables, ascendancy choices, masteries, or jewel sockets. \
It returns all allocated nodes categorized by type.\n\
\n\
Use get_jewel to inspect a jewel socketed in a passive tree socket. First call \
get_passive_tree to get the jewel_sockets list with node IDs, then call \
get_jewel with the node_id to see the jewel's name, base, rarity, and mods.\n\
\n\
Be specific and reference actual numbers from the build data when relevant. \
If the data doesn't contain enough information to answer, say so.";
/// A single earlier turn of the conversation, supplied by the caller as history.
///
/// `ToolAgent::respond` forwards turns whose `role` is `"user"` or
/// `"assistant"` to the model and skips every other role.
#[derive(Clone, Debug)]
pub struct ChatMessage {
    /// Who produced this turn, e.g. `"user"` or `"assistant"`.
    pub role: String,
    /// The text of the turn.
    pub content: String,
}
/// An item produced by the stream returned from `ToolAgent::respond`.
///
/// Derives the common traits so callers can log events, clone them, and
/// compare them (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentEvent {
    /// The model requested the tool named `name`. One event is emitted per
    /// requested tool, before any of that round's tools are executed.
    ToolCall { name: String },
    /// A piece of the assistant's answer text.
    Token(String),
}
/// LLM-backed assistant that answers questions about a Path of Building export.
///
/// The model is allowed to call the build-inspection tools from
/// `tool_definitions` before it writes its answer.
pub struct ToolAgent {
    /// Client used for both the tool-enabled rounds and the final streamed answer.
    llm: ChatGptClient,
    /// Shared parser that runs the tool queries against the build XML.
    parser: Arc<PobParser>,
}

impl ToolAgent {
    /// Creates an agent from an LLM client and a shared PoB parser.
    pub fn new(llm: ChatGptClient, parser: Arc<PobParser>) -> Self {
        Self { llm, parser }
    }

    /// Answers `message` about the build in `build_xml`, taking prior `history` into account.
    ///
    /// History entries whose role is neither `"user"` nor `"assistant"` are ignored.
    ///
    /// The returned stream behaves as follows:
    ///
    /// - While the model keeps requesting tools (at most `MAX_TOOL_ROUNDS`
    ///   requests), it yields one [`AgentEvent::ToolCall`] per requested tool.
    ///   It then runs those tools and feeds their results back to the model.
    /// - If the model answers without ever calling a tool, its reply is
    ///   yielded as a single [`AgentEvent::Token`].
    /// - If any tool was called, the answer is requested again through
    ///   `chat_stream` and forwarded token by token.
    ///
    /// # Errors
    ///
    /// Errors from `chat_with_tools` or from the `chat_stream` tokens are
    /// propagated with `?`. Inside `try_stream!` this yields them as an `Err`
    /// item and ends the stream.
    ///
    /// Tool failures are not stream errors. They are sent back to the model as
    /// a `{"error": ...}` JSON tool result.
    pub fn respond(
        &self,
        build_xml: &[u8],
        message: &str,
        history: Vec<ChatMessage>,
    ) -> impl Stream<Item = Result<AgentEvent, LlmError>> + Send {
        // Take owned copies so the returned stream borrows neither `self` nor the inputs.
        let llm = self.llm.clone();
        let parser = Arc::clone(&self.parser);
        let build_xml = build_xml.to_vec();
        let message = message.to_owned();

        async_stream::try_stream! {
            let tools = tool_definitions();

            let mut messages = vec![Message::system(SYSTEM_PROMPT)];
            for msg in history {
                match msg.role.as_str() {
                    "user" => messages.push(Message::user(&msg.content)),
                    "assistant" => messages.push(Message::assistant(&msg.content)),
                    // Any other role is silently dropped.
                    _ => {}
                }
            }
            messages.push(Message::user(message));

            let mut tools_were_called = false;
            for _ in 0..MAX_TOOL_ROUNDS {
                let (assistant_msg, finish_reason) = llm
                    .chat_with_tools(messages.clone(), Some(&tools))
                    .await?;
                // A missing finish reason is treated as a normal stop.
                let reason = finish_reason.as_deref().unwrap_or("stop");

                if reason == "tool_calls" {
                    if let Some(ref tool_calls) = assistant_msg.tool_calls {
                        tools_were_called = true;
                        // Announce every requested tool before running any of them.
                        for tc in tool_calls {
                            yield AgentEvent::ToolCall {
                                name: tc.function.name.clone(),
                            };
                        }
                        // The assistant turn carrying the tool calls must precede
                        // the tool results that answer it.
                        messages.push(assistant_msg.clone());
                        for tc in tool_calls {
                            let result = execute_tool(
                                &parser,
                                &build_xml,
                                &tc.function.name,
                                &tc.function.arguments,
                            )
                            .await;
                            // Tool failures go back to the model as data instead of
                            // aborting the stream. Serializing with serde_json keeps
                            // the payload valid JSON even when the message contains
                            // quotes, backslashes or newlines.
                            let content = match result {
                                Ok(val) => val.to_string(),
                                Err(e) => serde_json::json!({ "error": e }).to_string(),
                            };
                            messages.push(Message::tool_result(&tc.id, content));
                        }
                        continue;
                    }
                    // "tool_calls" with no calls attached falls through and is
                    // handled like a normal stop.
                }

                if !tools_were_called {
                    // No tools were needed: forward the model's reply unchanged.
                    if let Some(text) = assistant_msg.content {
                        yield AgentEvent::Token(text);
                    }
                    return;
                }
                break;
            }

            // Tools were called (possibly until the round limit was hit). Request
            // the answer again through `chat_stream`, which takes no tool list,
            // and forward it token by token.
            let stream = llm.chat_stream(messages);
            tokio::pin!(stream);
            while let Some(token_result) = futures_lite::StreamExt::next(&mut stream).await {
                yield AgentEvent::Token(token_result?);
            }
        }
    }
}
/// Describes the build-inspection tools offered to the model.
///
/// Each entry is a `"function"` tool whose parameters form a closed JSON
/// Schema object (`additionalProperties: false`). The names must match the
/// dispatch in `execute_tool`, which actually runs the tools.
fn tool_definitions() -> Vec<ToolDefinition> {
    /// Equipment slots accepted by `get_item`.
    const GEAR_SLOTS: &[&str] = &[
        "Weapon 1", "Weapon 2", "Helmet", "Body Armour",
        "Gloves", "Boots", "Amulet", "Ring 1", "Ring 2", "Ring 3",
        "Belt", "Charm 1", "Charm 2", "Charm 3",
        "Flask 1", "Flask 2",
    ];

    // Wraps a name, description and parameter schema into a function-type tool.
    fn function_tool(
        name: &str,
        description: &str,
        parameters: serde_json::Value,
    ) -> ToolDefinition {
        ToolDefinition {
            tool_type: "function".to_owned(),
            function: FunctionDefinition {
                name: name.to_owned(),
                description: description.to_owned(),
                parameters,
            },
        }
    }

    // Builds a closed object schema with the given properties and required keys.
    fn object_schema(properties: serde_json::Value, required: &[&str]) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": false
        })
    }

    let no_params = || object_schema(serde_json::json!({}), &[]);

    vec![
        function_tool(
            "get_build_stats",
            "Get extended build statistics including offense, defense, \
             resources, speed, and charges. Returns ~40 fields grouped by category.",
            no_params(),
        ),
        function_tool(
            "get_skill_list",
            "Get the list of skills with their DPS values, trigger info, \
             and gem links (socket groups with gems, levels, and quality).",
            no_params(),
        ),
        function_tool(
            "get_config",
            "Get the build's configuration flags (enemy settings, \
             charge generation, conditions, etc.).",
            no_params(),
        ),
        function_tool(
            "get_item",
            "Retrieve the item equipped in a specific gear slot, including \
             its name, base type, rarity, and all mod lines (implicit, explicit, \
             enchant, rune).",
            object_schema(
                serde_json::json!({
                    "slot": {
                        "type": "string",
                        "enum": GEAR_SLOTS,
                        "description": "The equipment slot to inspect"
                    }
                }),
                &["slot"],
            ),
        ),
        function_tool(
            "get_jewel",
            "Retrieve a jewel socketed in a passive tree socket, including \
             its name, base type, rarity, and all mod lines. Use socket node IDs \
             from get_passive_tree's jewel_sockets array.",
            object_schema(
                serde_json::json!({
                    "node_id": {
                        "type": "integer",
                        "description": "The passive tree socket node ID (from get_passive_tree jewel_sockets)"
                    }
                }),
                &["node_id"],
            ),
        ),
        function_tool(
            "get_passive_tree",
            "Get the allocated passive tree nodes, grouped by type: \
             keystones, notables, ascendancy nodes, masteries, and jewel sockets. \
             Also returns class, ascendancy, and total allocated node count.",
            no_params(),
        ),
    ]
}
/// Runs one tool call requested by the model against the build XML.
///
/// `tool_args` is the raw JSON argument string from the model. It is parsed
/// only for the tools that take parameters (`get_item`, `get_jewel`); the
/// parameterless tools ignore it entirely.
///
/// # Errors
///
/// Returns a human-readable message in any of these cases:
///
/// - the tool name is unknown;
/// - the arguments are not valid JSON;
/// - the required parameter is missing or has the wrong type;
/// - the parser query fails.
async fn execute_tool(
    parser: &PobParser,
    build_xml: &[u8],
    tool_name: &str,
    tool_args: &str,
) -> Result<serde_json::Value, String> {
    // Parses the model-supplied argument string into a JSON value.
    fn parse_args(raw: &str) -> Result<serde_json::Value, String> {
        serde_json::from_str(raw).map_err(|err| format!("invalid arguments: {err}"))
    }

    let query = match tool_name {
        "get_build_stats" => PobQuery::BuildStats,
        "get_skill_list" => PobQuery::SkillList,
        "get_config" => PobQuery::Config,
        "get_passive_tree" => PobQuery::PassiveTree,
        "get_item" => {
            let args = parse_args(tool_args)?;
            // Indexing a Value never panics: a missing key, a non-object body
            // or a non-string value all end up in the `None` arm.
            let slot = match args["slot"].as_str() {
                Some(slot) => slot.to_owned(),
                None => return Err("missing required parameter: slot".to_owned()),
            };
            PobQuery::Item(slot)
        }
        "get_jewel" => {
            let args = parse_args(tool_args)?;
            let node_id = match args["node_id"].as_i64() {
                Some(id) => id,
                None => return Err("missing required parameter: node_id".to_owned()),
            };
            PobQuery::Jewel(node_id)
        }
        unknown => return Err(format!("unknown tool: {unknown}")),
    };

    let outcome = parser.query(build_xml, query).await;
    outcome.map_err(|err| err.to_string())
}