use lellm_agent::schemars::JsonSchema;
use lellm_agent::serde::Deserialize;
use lellm_agent::{AgentBuilder, ResolvedModel, ToolUseLoop};
use lellm_core::{
ChatRequest, ChatResponse, ContentBlock, LlmError, Message, TokenUsage, ToolCall,
};
use lellm_graph::{GraphBuilder, GraphExecutor, NodeKind, TaskNode};
use lellm_macros::Tool;
use schemars;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
// Arguments of the `add` tool; its handler (registered in
// `create_calculator_agent`) returns `a + b` as a JSON number.
// Plain `//` comments on purpose: `///` doc comments would presumably be
// picked up by the `JsonSchema` derive as schema descriptions and change
// the tool schema — confirm before converting.
#[derive(Deserialize, JsonSchema, Tool)]
#[tool(name = "add", description = "Add two numbers")]
struct AddArgs {
    // First operand.
    a: f64,
    // Second operand.
    b: f64,
}
// Arguments of the `multiply` tool; its handler (registered in
// `create_calculator_agent`) returns `a * b` as a JSON number.
// Plain `//` comments on purpose: `///` doc comments would presumably be
// picked up by the `JsonSchema` derive as schema descriptions and change
// the tool schema — confirm before converting.
#[derive(Deserialize, JsonSchema, Tool)]
#[tool(name = "multiply", description = "Multiply two numbers")]
struct MultiplyArgs {
    // First factor.
    a: f64,
    // Second factor.
    b: f64,
}
// Arguments of the `divide` tool; its handler (registered in
// `create_calculator_agent`) returns `a / b` as a JSON number, or an
// invalid-input `ToolError` ("Division by zero") when `b == 0.0`.
// Plain `//` comments on purpose: `///` doc comments would presumably be
// picked up by the `JsonSchema` derive as schema descriptions and change
// the tool schema — confirm before converting.
#[derive(Deserialize, JsonSchema, Tool)]
#[tool(name = "divide", description = "Divide two numbers")]
struct DivideArgs {
    // Dividend.
    a: f64,
    // Divisor; zero is rejected by the tool handler.
    b: f64,
}
/// Scripted stand-in for a real LLM backend.
///
/// Every `call` hands out the next response from a fixed script, ignoring the
/// request. Once the script is exhausted, each further call answers with a
/// text-only "计算完成。" response.
struct CalculatorMockProvider {
    /// Index of the next scripted response. It sits behind a `Mutex` because
    /// `call` only receives `&self`.
    current_round: Mutex<usize>,
    /// Responses returned in order, one per `call`.
    round_responses: Vec<ChatResponse>,
}

impl CalculatorMockProvider {
    /// Creates a provider that replays `responses` starting from the first one.
    fn new(responses: Vec<ChatResponse>) -> Self {
        Self {
            current_round: Mutex::new(0),
            round_responses: responses,
        }
    }
}

#[::async_trait::async_trait]
impl lellm_provider::LlmProvider for CalculatorMockProvider {
    /// Returns the scripted response for the current round and advances the
    /// round counter.
    ///
    /// After the script runs out, this returns a "计算完成。" text block with
    /// default token usage. It never returns `Err`.
    async fn call(&self, _request: &ChatRequest) -> Result<ChatResponse, LlmError> {
        // Claim a round number. The guard is dropped at the end of this block,
        // so the lock is never held while building the response.
        let round = {
            let mut counter = self.current_round.lock().unwrap();
            let claimed = *counter;
            *counter = claimed + 1;
            claimed
        };

        let response = match self.round_responses.get(round) {
            Some(scripted) => scripted.clone(),
            None => ChatResponse::new(
                vec![ContentBlock::text("计算完成。".to_string())],
                TokenUsage::default(),
                serde_json::json!(null),
            ),
        };
        Ok(response)
    }

    /// Streaming is not supported by this mock; panics via `unimplemented!`.
    async fn stream(
        &self,
        _request: &ChatRequest,
    ) -> Result<lellm_provider::ProviderStream, LlmError> {
        unimplemented!("stream not needed for this example")
    }

    /// Fixed identifier for this mock provider.
    fn provider_id(&self) -> &str {
        "calculator-mock"
    }
}
fn create_calculator_agent() -> ToolUseLoop {
let add_call = ChatResponse::new(
vec![ContentBlock::ToolCall(ToolCall {
id: "call_add_001".to_string(),
name: "add".to_string(),
arguments: serde_json::json!({ "a": 3.0, "b": 4.0 }),
})],
TokenUsage::default(),
serde_json::json!(null),
);
let multiply_call = ChatResponse::new(
vec![ContentBlock::ToolCall(ToolCall {
id: "call_mul_002".to_string(),
name: "multiply".to_string(),
arguments: serde_json::json!({ "a": 7.0, "b": 2.0 }),
})],
TokenUsage::default(),
serde_json::json!(null),
);
let final_answer = ChatResponse::new(
vec![ContentBlock::text(
"3 + 4 = 7,然后 7 × 2 = 14。最终答案是 14。".to_string(),
)],
TokenUsage {
prompt_tokens: 300,
completion_tokens: 40,
total_tokens: 340,
},
serde_json::json!(null),
);
let provider = Arc::new(CalculatorMockProvider::new(vec![
add_call,
multiply_call,
final_answer,
]));
let model = ResolvedModel {
context_window: None,
provider,
model: "claude-sonnet-4-5".to_string(),
};
let tools = vec![
AddArgs::safe(|args| async move {
let result = args.a + args.b;
Ok(serde_json::json!(result))
}),
MultiplyArgs::safe(|args| async move {
let result = args.a * args.b;
Ok(serde_json::json!(result))
}),
DivideArgs::safe(|args| async move {
if args.b == 0.0 {
Err(lellm_agent::ToolError::invalid_input("Division by zero"))
} else {
let result = args.a / args.b;
Ok(serde_json::json!(result))
}
}),
];
AgentBuilder::new(model)
.system_prompt("你是一个数学助手,负责对数字执行算术运算。".to_string())
.tools(tools)
.max_iterations(10)
.build()
}
#[tokio::main]
async fn main() {
let agent = create_calculator_agent();
let mut g = GraphBuilder::new("calculator");
let _ = g.start("init");
let _ = g.node(
"init",
NodeKind::Task(TaskNode::new("init", |state| {
state.insert(
"calc.messages".into(),
serde_json::json!(vec![Message::User {
content: lellm_core::text_block("3加4等于多少,然后再乘以2。".to_string(),),
}]),
);
Ok(())
})),
);
let _ = g.node(
"agent",
NodeKind::Agent(Box::new(
lellm_graph::AgentNode::new("agent", agent).with_prefix("calc"), )),
);
let _ = g.node(
"summary",
NodeKind::Task(TaskNode::new("summary", |state| {
println!("\n=== Graph 执行结果 ===");
let output = state
.get("calc.output")
.and_then(|v| v.as_str())
.unwrap_or("(no output)");
println!("最终输出: {}", output);
let iterations = state
.get("calc.iterations")
.and_then(|v| v.as_u64())
.unwrap_or(0);
println!("LLM 调用轮次: {}", iterations);
let tool_calls = state
.get("calc.tool_calls")
.and_then(|v| v.as_u64())
.unwrap_or(0);
println!("工具调用次数: {}", tool_calls);
let stop_reason = state
.get("calc.stop_reason")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
println!("停止原因: {}", stop_reason);
if let Some(msgs) = state.get("calc.messages") {
let count = if let Some(arr) = msgs.as_array() {
arr.len()
} else {
0
};
println!("对话消息数: {}", count);
}
Ok(())
})),
);
let _ = g.edge("init", "agent");
let _ = g.edge("agent", "summary");
let _ = g.end("summary");
let graph = g.build().expect("Graph 构建失败");
println!("=== LeLLM Calculator Graph ===\n");
println!("Graph 节点: {:?}", graph.node_names());
println!("起始节点: {}", graph.start_node());
println!();
let result = GraphExecutor::default()
.execute(std::sync::Arc::new(graph), HashMap::new())
.await
.expect("Graph 执行失败");
println!("\n=== 执行日志 ===");
for (i, entry) in result.execution_log.iter().enumerate() {
let status = if entry.success { "✅" } else { "❌" };
println!(
" [{}] {} {} {}ms",
i + 1,
entry.node_name,
status,
entry.elapsed().as_millis(),
);
}
println!("总耗时: {}ms", result.duration.as_millis());
println!("\n=== 最终状态 ===");
for (key, value) in &result.state {
println!(" {}: {}", key, value);
}
}