use std::path::PathBuf;
use rust_agent::{run_agent, OpenAIChatModel, McpClient, SimpleMcpClient, McpTool, McpAgent, SimpleMemory, BaseMemory, CompositeMemory};
use std::sync::Arc;
use std::collections::HashMap;
use chrono;
use serde_json::{Value, json};
use anyhow::Error;
use log::LevelFilter;
use env_logger;
use log::{info, error};
#[tokio::main]
async fn main() {
env_logger::Builder::new()
.filter_level(LevelFilter::Info)
.init();
info!("=== Rust Agent 使用示例 ===");
let memory_type = std::env::var("MEMORY_TYPE").unwrap_or_else(|_| "composite".to_string());
let summary_threshold = std::env::var("SUMMARY_THRESHOLD")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(200);
let recent_messages_count = std::env::var("RECENT_MESSAGES_COUNT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(10);
info!("使用记忆类型: {}", memory_type);
info!("摘要阈值: {}", summary_threshold);
info!("保留最近消息数: {}", recent_messages_count);
let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
let base_url = std::env::var("OPENAI_API_URL").ok();
let mcp_url = std::env::var("MCP_URL").unwrap_or("http://localhost:8000/mcp".to_string());
let model = OpenAIChatModel::new(api_key.clone(), base_url)
.with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
.with_temperature(0.7)
.with_max_tokens(8*1024);
let mut mcp_client = SimpleMcpClient::new(mcp_url.clone());
mcp_client.clear_tools();
mcp_client.add_tools(vec![
McpTool {
name: "get_weather".to_string(),
description: format!(
"Get weather information for a specified city. For example: 'What's the weather like in Beijing?'.
The parameter request body you should extract is: '\"parameters\": {{ \"city\": \"{}\" }}'",
"city".to_string()),
},
McpTool {
name: "simple_calculate".to_string(),
description: format!(
"Execute simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'.
The parameter request body you should extract is: '\"parameters\": {{ \"expression\": \"{}\" }}'",
"expression".to_string()),
},
]);
mcp_client.register_tool_handler("get_weather".to_string(), |params: HashMap<String, Value>| async move {
let default_city = Value::String("Shanghai".to_string());
let city_value = params.get("city").unwrap_or(&default_city);
let city = city_value.as_str().unwrap_or("Shanghai");
Ok(json!({
"city": city,
"temperature": "25°C",
"weather": "Sunny",
"humidity": "40%",
"updated_at": chrono::Utc::now().to_rfc3339()
}))
});
mcp_client.register_tool_handler("simple_calculate".to_string(), |params: HashMap<String, Value>| async move {
let expression_value = params.get("expression").ok_or_else(|| Error::msg("Missing calculation expression"))?;
let expression = expression_value.as_str().ok_or_else(|| Error::msg("Expression format error"))?;
let result = parse_and_calculate(expression)?;
Ok(json!({
"expression": expression,
"result": result,
"calculated_at": chrono::Utc::now().to_rfc3339()
}))
});
info!("Using local tools only, not connecting to MCP server...");
info!("Using model: {}", model.model_name().map_or("Model not specified", |v| v));
info!("Using API URL: {}", model.base_url());
info!("----------------------------------------");
let client_arc: Arc<dyn McpClient> = Arc::new(mcp_client);
let memory: Box<dyn BaseMemory> = match memory_type.as_str() {
"simple" => {
info!("使用SimpleMemory (仅内存记忆)");
Box::new(SimpleMemory::new())
},
"composite" => {
info!("使用CompositeMemory (组合记忆 - 支持中长期记忆和摘要记忆)");
let memory = CompositeMemory::with_basic_params(
PathBuf::from("./data/memory"),
summary_threshold,
recent_messages_count,
).await.expect("Failed to create composite memory");
Box::new(memory)
},
_ => {
error!("未知的记忆类型: {}, 使用默认的SimpleMemory", memory_type);
Box::new(SimpleMemory::new())
}
};
let user_system_prompt = "You are an AI assistant that can use tools to answer user questions. Please decide whether to use tools based on the user's needs.".to_string();
let mut agent = McpAgent::with_openai_model_and_memory(
client_arc.clone(),
user_system_prompt,
model.clone(),
memory
);
if let Err(e) = agent.auto_add_tools().await {
error!("Failed to auto add tools from MCP server: {}", e);
}
println!("基于MCP的AI Agent聊天机器人已启动!");
println!("记忆类型: {}", memory_type);
if memory_type == "composite" {
println!("摘要功能: 已启用 (阈值: {} 条消息)", summary_threshold);
println!("中长期记忆: 已启用");
}
println!("输入'退出'结束对话");
println!("----------------------------------------");
println!("Using tools example:");
let tools = client_arc.get_tools().await.unwrap_or_else(|e| {
error!("Failed to get tools from MCP server: {}", e);
vec![
McpTool {
name: "get_weather".to_string(),
description: "Get the weather information for a specified city. For example: 'What's the weather like in Beijing?'".to_string(),
},
McpTool {
name: "simple_calculate".to_string(),
description: "Perform simple mathematical calculations. For example: 'What is 9.11 plus 9.8?'".to_string(),
},
]
});
let mut index = 0;
for tool in &tools {
index += 1;
println!("{index}. {}: {}", tool.name, tool.description);
}
println!("----------------------------------------");
loop {
let mut user_input = String::new();
println!("你: ");
std::io::stdin().read_line(&mut user_input).expect("读取输入失败");
println!("");
let user_input = user_input.trim();
if user_input.to_lowercase() == "退出" || user_input.to_lowercase() == "exit" {
println!("再见!");
break;
}
let mut inputs = HashMap::new();
inputs.insert("input".to_string(), serde_json::Value::String(user_input.to_string()));
match run_agent(&agent, user_input.to_string()).await {
Ok(response) => {
if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&response) {
if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
println!("助手: {}", content);
} else {
println!("助手: {}", response);
}
} else {
println!("助手: {}", response);
}
},
Err(e) => {
println!("助手: 抱歉,处理您的请求时出现错误: {}", e);
},
}
info!("----------------------------------------");
}
info!("对话历史:");
if let Some(memory) = agent.get_memory() {
match memory.load_memory_variables(&HashMap::new()).await {
Ok(memories) => {
if let Some(chat_history) = memories.get("chat_history") {
if let serde_json::Value::Array(messages) = chat_history {
info!("总消息数: {}", messages.len());
for (i, message) in messages.iter().enumerate() {
if let serde_json::Value::Object(msg) = message {
let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("unknown");
let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
let display_content = if content.len() > 100 {
format!("{}...", &content[..100])
} else {
content.to_string()
};
info!("{}. {}: {}", i + 1, role, display_content);
}
}
}
}
if let Some(summary) = memories.get("summary") {
if let serde_json::Value::String(summary_text) = summary {
info!("对话摘要: {}", summary_text);
}
}
},
Err(e) => {
info!("Failed to load memory variables: {}", e);
}
}
} else {
info!("No memory available");
}
if let Err(e) = client_arc.disconnect().await {
error!("Failed to disconnect MCP client: {}", e);
}
}
fn parse_and_calculate(expression: &str) -> Result<f64, Error> {
let expression = expression.replace(" ", "");
for op_char in ["+", "-", "*", "/"].iter() {
if let Some(pos) = expression.find(op_char) {
let left_str = &expression[0..pos];
let right_str = &expression[pos + 1..];
let left = left_str.parse::<f64>().map_err(|e|
Error::msg(format!("Left operand format error: {}", e)))?;
let right = right_str.parse::<f64>().map_err(|e|
Error::msg(format!("Right operand format error: {}", e)))?;
let result = match *op_char {
"+" => left + right,
"-" => left - right,
"*" => left * right,
"/" => {
if right == 0.0 {
return Err(Error::msg("除数不能为零"));
}
left / right
},
_ => unreachable!()
};
return Ok(result);
}
}
if let Ok(number) = expression.parse::<f64>() {
return Ok(number);
}
Err(Error::msg(format!("Failed to parse expression: {}", expression)))
}