use lellm_agent::schemars::JsonSchema;
use lellm_agent::serde::Deserialize;
use lellm_agent::{AgentBuilder, AgentFlowNode, ResolvedModel};
use lellm_core::{Message, text_block};
use lellm_derive::Tool;
use lellm_graph::{GraphBuilder, GraphExecutor, NodeKind, StateDelta, TaskNode};
use lellm_provider::providers::base::CodecProvider;
use lellm_provider::providers::openai_compat::OpenAICompatCodec;
use std::collections::HashMap;
use std::sync::Arc;
// Arguments for the `add` tool handed to the agent in `build_graph`
// (tool description: "adds two numbers").
// `Deserialize` parses the model's tool-call arguments and `JsonSchema` derives
// the parameter schema. The `Tool` derive (from `lellm_derive`) presumably
// generates the tool metadata and the `AddArgs::safe` constructor used in
// `build_graph`; TODO confirm against the derive's docs.
// NOTE: plain `//` comments are used on purpose. `///` doc comments would be
// picked up by the `JsonSchema` derive as schema descriptions and would change
// the schema the model sees.
#[derive(Deserialize, JsonSchema, Tool)]
#[tool(name = "add", description = "将两个数字相加")]
struct AddArgs {
    // First addend.
    a: f64,
    // Second addend.
    b: f64,
}
// Arguments for the `multiply` tool handed to the agent in `build_graph`
// (tool description: "multiplies two numbers").
// `Deserialize` parses the model's tool-call arguments and `JsonSchema` derives
// the parameter schema. The `Tool` derive presumably generates the tool
// metadata and the `MultiplyArgs::safe` constructor used in `build_graph`;
// TODO confirm against the derive's docs.
// NOTE: plain `//` comments are used on purpose. `///` doc comments would
// become schema descriptions via the `JsonSchema` derive.
#[derive(Deserialize, JsonSchema, Tool)]
#[tool(name = "multiply", description = "将两个数字相乘")]
struct MultiplyArgs {
    // First factor.
    a: f64,
    // Second factor.
    b: f64,
}
fn build_graph() -> lellm_graph::Graph {
let provider = CodecProvider::load(OpenAICompatCodec::llama())
.expect("请设置 OPENAI_API_KEY 环境变量");
let model = ResolvedModel {
provider: Arc::new(provider),
model: "llama3.2".into(),
context_window: Some(8192),
};
let agent = AgentBuilder::new(model)
.system_prompt("你是一个数学助手。当用户问数学问题时,使用工具计算。".into())
.tools([
AddArgs::safe(|args| async move { Ok(serde_json::json!(args.a + args.b)) }),
MultiplyArgs::safe(|args| async move { Ok(serde_json::json!(args.a * args.b)) }),
])
.max_iterations(10)
.build();
let mut g = GraphBuilder::new("calculator");
g.start("init");
g.node(
"init",
NodeKind::Task(TaskNode::new("init", |_| {
Ok(vec![StateDelta::put(
"messages",
serde_json::json!([Message::User {
content: text_block("3加4等于多少,然后再乘以2。".into()),
}]),
)])
})),
);
g.node(
"agent",
NodeKind::External(Arc::new(AgentFlowNode::new("agent", agent))),
);
g.node(
"summary",
NodeKind::Task(TaskNode::new("summary", |state| {
println!("\n=== 结果 ===");
if let Some(msgs) = state.get("messages") {
let count = msgs.as_array().map_or(0, |a| a.len());
println!("消息数: {count}");
}
let reason = state
.get("agent_stop_reason")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let iters = state
.get("agent_iterations")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let calls = state
.get("agent_tool_calls")
.and_then(|v| v.as_u64())
.unwrap_or(0);
println!("停止原因: {reason}");
println!("迭代次数: {iters}");
println!("工具调用: {calls}");
Ok(vec![])
})),
);
g.edge("init", "agent");
g.edge("agent", "summary");
g.end("summary");
g.build().expect("Graph 构建失败")
}
/// Entry point for the calculator example.
///
/// Installs a tracing subscriber, builds the graph with `build_graph`, and runs
/// it from an empty initial state. It then prints the node list, the per-node
/// execution log, the total duration and the final state. The final state is
/// printed in key order, so the output is reproducible across runs.
///
/// # Panics
///
/// Panics if graph construction fails (see `build_graph`) or if graph
/// execution returns an error.
#[tokio::main]
async fn main() {
    // Honour RUST_LOG when it parses; otherwise trace the agent/provider crates
    // and log everything else at `info`. `try_init` returns an error, rather
    // than panicking, if a global subscriber is already set. That error is
    // deliberately ignored.
    let _ = tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "lellm_agent=trace,lellm_provider=trace,info".into()),
        )
        .try_init();

    let graph = build_graph();
    println!("=== Calculator Graph (LLaMA) ===\n");
    println!("节点: {:?}", graph.node_names());

    let result = GraphExecutor::default()
        .execute(Arc::new(graph), HashMap::new())
        .await
        .expect("执行失败");

    println!("\n=== 执行日志 ===");
    for (i, e) in result.execution_log.iter().enumerate() {
        let icon = if e.success { "✅" } else { "❌" };
        println!(" [{}] {} {icon} {}ms", i + 1, e.node_name, e.elapsed().as_millis());
    }
    println!("总耗时: {}ms", result.duration.as_millis());

    println!("\n=== 最终状态 ===");
    // The state is presumably a `HashMap` (one seeds it above), and its
    // iteration order changes from run to run. Collect the borrowed entries
    // into a `BTreeMap` so the dump comes out sorted by key and is reproducible.
    let sorted_state: std::collections::BTreeMap<_, _> =
        IntoIterator::into_iter(&result.state).collect();
    for (k, v) in sorted_state {
        println!(" {k}: {v}");
    }
}