use log::info;
use zai_rs::model::{chat::data::ChatCompletion, chat_base_response::ChatCompletionResponse, *};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init();

    // Pretty-printer that is only ever invoked inside `info!`, so the formatting
    // work is skipped entirely when the info level is disabled.
    let pretty = |value: &serde_json::Value| -> String {
        serde_json::to_string_pretty(value).expect("Failed to format JSON")
    };

    // JSON schema for the single `get_weather` tool: one required string field `city`.
    let weather_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "city": {"type": "string"}
        },
        "required": ["city"],
        "additionalProperties": false
    });
    let tools = Tools::Function {
        function: Function::new("get_weather", "Get current weather for a city", weather_schema),
    };

    let key = get_key()?;
    let user_text = "你是谁,帮为查找深圳今天的天气";
    let mut client = ChatCompletion::new(GLM4_5_flash {}, TextMessage::user(user_text), key)
        .with_thinking(ThinkingType::disabled())
        .with_temperature(0.7)
        .with_top_p(0.9)
        .with_max_tokens(512)
        .add_tool(tools);

    // First round: let the model decide whether to call the tool.
    let first: ChatCompletionResponse = client.send().await?;
    let first_json = serde_json::to_value(&first).expect("Failed to serialize response to JSON");
    info!("{}", pretty(&first_json));

    // Without a tool call there is nothing more to do.
    let Some((call_id, name, arguments)) = parse_first_tool_call(&first_json) else {
        info!("未发现 tool_calls");
        return Ok(());
    };
    info!("提取到的 tool_call -> name: {name}, arguments: {arguments}");

    // Answer the tool call locally with the mock handler.
    let result = handle_tool_call(&name, &arguments)
        .unwrap_or_else(|| serde_json::json!({"ok": false, "error": "no_result"}));
    info!("模拟函数返回结果: {}", pretty(&result));

    // Second round: feed the tool output back, tagged with the originating call id.
    // NOTE(review): only the tool message is appended here; the assistant message that
    // carried `tool_calls` is not re-sent. Confirm whether the API expects it in history.
    let payload = serde_json::to_string(&result).expect("Failed to serialize tool result");
    client = client
        .add_messages(TextMessage::tool_with_id(payload, call_id))
        .with_max_tokens(512);
    let second: ChatCompletionResponse = client.send().await?;
    let second_json = serde_json::to_value(&second).expect("Failed to serialize response to JSON");
    info!("继续对话返回: {}", pretty(&second_json));

    Ok(())
}
/// Returns the Zhipu API key.
///
/// The key is read from the `ZHIPU_API_KEY` environment variable when that variable is
/// set to a non-blank value. Otherwise the user is prompted for it on stdin. Surrounding
/// whitespace is trimmed on both paths, so the two sources behave consistently.
///
/// # Errors
///
/// Returns an error if reading from stdin fails. Also returns an error if the prompted
/// input is blank (e.g. the user just pressed Enter, or stdin reached EOF), instead of
/// handing an empty key back to the caller.
fn get_key() -> Result<String, Box<dyn std::error::Error>> {
    // A set-but-blank variable is treated the same as an unset one.
    if let Ok(raw) = std::env::var("ZHIPU_API_KEY") {
        let key = raw.trim();
        if !key.is_empty() {
            return Ok(key.to_string());
        }
    }

    let mut input = String::new();
    println!("请输入 ZHIPU_API_KEY: ");
    std::io::stdin().read_line(&mut input)?;

    let key = input.trim();
    if key.is_empty() {
        return Err("ZHIPU_API_KEY is empty: set the environment variable or enter a key".into());
    }
    Ok(key.to_string())
}
fn parse_first_tool_call(v: &serde_json::Value) -> Option<(String, String, String)> {
let tool_calls = v.pointer("/choices/0/message/tool_calls")?.as_array()?;
let tc0 = tool_calls.first()?;
let id = tc0.get("id")?.as_str()?.to_string();
let func = tc0.get("function")?;
let name = func.get("name")?.as_str()?.to_string();
let arguments = func.get("arguments")?.as_str()?.to_string();
Some((id, name, arguments))
}
/// Runs a mock implementation of the named tool and returns its JSON result.
///
/// * `name` - tool name taken from the model's tool call.
/// * `arguments` - the raw JSON-encoded argument string from the tool call.
///
/// Only `get_weather` is implemented, and it returns canned ("mock") weather data for
/// the requested city. The function always returns `Some` with one of three payloads:
/// - an `unknown_tool` error for any other name;
/// - an `invalid_arguments` error when `arguments` is not valid JSON;
/// - the mock weather result otherwise.
fn handle_tool_call(name: &str, arguments: &str) -> Option<serde_json::Value> {
    // Guard: anything other than `get_weather` is reported back as unsupported.
    if name != "get_weather" {
        return Some(serde_json::json!({
            "ok": false,
            "error": "unknown_tool",
            "name": name,
            "raw_arguments": arguments,
        }));
    }

    let parsed = match serde_json::from_str::<serde_json::Value>(arguments) {
        Ok(value) => value,
        Err(err) => {
            // Log and echo back the raw text so the failure is visible on both sides.
            log::warn!("解析 arguments 失败: {} | 原始: {}", err, arguments);
            return Some(serde_json::json!({
                "ok": false,
                "error": "invalid_arguments",
                "raw": arguments,
            }));
        },
    };

    // Falls back to a placeholder when `city` is missing or not a string.
    let city = parsed
        .get("city")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("未知城市");
    let tips = format!("{} 现在户外紫外线较强,注意防晒。", city);

    Some(serde_json::json!({
        "ok": true,
        "name": name,
        "request": { "city": city },
        "result": {
            "city": city,
            "condition": "晴",
            "temperature_c": 28,
            "humidity": 0.65,
            "tips": tips,
        },
        "source": "mock",
    }))
}