use std::io::{self, Write};
use std::sync::Arc;
use agent_base::{
AgentBuilder, AgentError, AgentEvent, AgentResult, ApprovalDecision, ApprovalHandler,
ApprovalRequest, OpenAiClient, RiskLevel, Tool, ToolContext, ToolControlFlow, ToolOutput,
ToolPolicy,
};
use async_trait::async_trait;
use dotenvy::dotenv;
use serde_json::{json, Value};
struct AddTool;
#[async_trait]
impl Tool for AddTool {
fn name(&self) -> &'static str {
"add"
}
fn definition(&self) -> Value {
json!({
"type": "function",
"function": {
"name": "add",
"description": "Calculate the sum of two integers",
"parameters": {
"type": "object",
"properties": {
"a": { "type": "integer", "description": "First addend" },
"b": { "type": "integer", "description": "Second addend" }
},
"required": ["a", "b"]
}
}
})
}
async fn call(&self, args: &Value, _ctx: &ToolContext) -> AgentResult<ToolOutput> {
let a = args["a"].as_i64().unwrap_or(0);
let b = args["b"].as_i64().unwrap_or(0);
let result = a + b;
Ok(ToolOutput {
summary: format!("{} + {} = {}", a, b, result),
raw: Some(json!({ "result": result })),
control_flow: ToolControlFlow::Break,
truncated: false,
})
}
}
struct SubtractTool;
#[async_trait]
impl Tool for SubtractTool {
fn name(&self) -> &'static str {
"subtract"
}
fn definition(&self) -> Value {
json!({
"type": "function",
"function": {
"name": "subtract",
"description": "Calculate the difference of two integers(a - b)",
"parameters": {
"type": "object",
"properties": {
"a": { "type": "integer", "description": "Minuend" },
"b": { "type": "integer", "description": "Subtrahend" }
},
"required": ["a", "b"]
}
}
})
}
async fn call(&self, args: &Value, _ctx: &ToolContext) -> AgentResult<ToolOutput> {
let a = args["a"].as_i64().unwrap_or(0);
let b = args["b"].as_i64().unwrap_or(0);
let result = a - b;
Ok(ToolOutput {
summary: format!("{} - {} = {}", a, b, result),
raw: Some(json!({ "result": result })),
control_flow: ToolControlFlow::Break,
truncated: false,
})
}
}
struct MultiplyTool;
#[async_trait]
impl Tool for MultiplyTool {
fn name(&self) -> &'static str {
"multiply"
}
fn definition(&self) -> Value {
json!({
"type": "function",
"function": {
"name": "multiply",
"description": "Calculate the product of two integers",
"parameters": {
"type": "object",
"properties": {
"a": { "type": "integer", "description": "Multiplier" },
"b": { "type": "integer", "description": "Multiplier" }
},
"required": ["a", "b"]
}
}
})
}
async fn call(&self, args: &Value, _ctx: &ToolContext) -> AgentResult<ToolOutput> {
let a = args["a"].as_i64().unwrap_or(0);
let b = args["b"].as_i64().unwrap_or(0);
let result = a * b;
Ok(ToolOutput {
summary: format!("{} × {} = {}", a, b, result),
raw: Some(json!({ "result": result })),
control_flow: ToolControlFlow::Break,
truncated: false,
})
}
}
struct DivideTool;
#[async_trait]
impl Tool for DivideTool {
fn name(&self) -> &'static str {
"divide"
}
fn definition(&self) -> Value {
json!({
"type": "function",
"function": {
"name": "divide",
"description": "Calculate the quotient of two integers(a ÷ b),Returns quotient and remainder",
"parameters": {
"type": "object",
"properties": {
"a": { "type": "integer", "description": "Dividend" },
"b": { "type": "integer", "description": "Divisor" }
},
"required": ["a", "b"]
}
}
})
}
async fn call(&self, args: &Value, _ctx: &ToolContext) -> AgentResult<ToolOutput> {
let a = args["a"].as_i64().unwrap_or(0);
let b = args["b"].as_i64().unwrap_or(0);
if b == 0 {
return Ok(ToolOutput {
summary: "Error:Divisor cannot be zero".to_string(),
raw: Some(json!({ "error": "division by zero" })),
control_flow: ToolControlFlow::Break,
truncated: false,
});
}
let quotient = a / b;
let remainder = a % b;
Ok(ToolOutput {
summary: format!("{} ÷ {} = {}(remainder {})", a, b, quotient, remainder),
raw: Some(json!({ "quotient": quotient, "remainder": remainder })),
control_flow: ToolControlFlow::Break,
truncated: false,
})
}
}
/// Approval handler that asks the operator interactively on stdout/stdin.
#[derive(Clone, Debug, Default)]
struct CliApprovalHandler;

#[async_trait]
impl ApprovalHandler for CliApprovalHandler {
    /// Prints the approval request and prompts until the operator answers `y`, `a` or `n`.
    ///
    /// If stdin reaches end-of-file, the request is denied (fail closed). Previously an empty
    /// read was treated as invalid input and re-prompted forever.
    ///
    /// # Errors
    /// Returns an internal error if stdout cannot be flushed or stdin cannot be read.
    async fn approve(&self, request: ApprovalRequest) -> AgentResult<ApprovalDecision> {
        println!();
        println!("[Approval request] {}", request.title);
        println!(" Risk level: {:?}", request.risk_level);
        println!(" Content: {}", request.message);

        // NOTE(review): this is a blocking stdin read inside an async fn. That is acceptable
        // for a single-user REPL, but it blocks the executor thread while waiting.
        let mut input = String::new();
        loop {
            print!(" Choice [y=Allow / a=AlwaysAllow / n=Deny]: ");
            io::stdout()
                .flush()
                .map_err(|e| AgentError::internal(format!("Failed to flush stdout: {e}")))?;

            input.clear();
            let bytes_read = io::stdin()
                .read_line(&mut input)
                .map_err(|e| AgentError::internal(format!("Failed to read input: {e}")))?;
            if bytes_read == 0 {
                // EOF: nobody can answer any more, so deny rather than loop forever.
                println!();
                return Ok(ApprovalDecision::Deny);
            }

            match input.trim().to_ascii_lowercase().as_str() {
                "y" | "yes" => return Ok(ApprovalDecision::AllowOnce),
                "a" | "always" => return Ok(ApprovalDecision::AllowAlways),
                "n" | "no" => return Ok(ApprovalDecision::Deny),
                _ => println!(" Invalid input,Please enter y / a / n"),
            }
        }
    }
}
struct EventPrinter;
impl EventPrinter {
fn handle(event: AgentEvent) -> AgentResult<()> {
match event {
AgentEvent::TextDelta { text, .. } => {
print!("{}", text);
io::stdout().flush().unwrap();
}
AgentEvent::ThoughtDelta { text, .. } => {
print!("[Thinking]:\x1b[90m{} \x1b[0m", text);
println!();
io::stdout().flush().unwrap();
}
AgentEvent::ToolCallStarted {
tool_name, args_json, ..
} => {
println!();
println!("[Tool call] {} (with args: {})", tool_name, args_json);
}
AgentEvent::ToolCallFinished {
tool_name, summary, ..
} => {
println!("[Tool finished] {} -> {}", tool_name, summary);
}
AgentEvent::AwaitingApproval { request, .. } => {
println!(
"[Waiting for approval] {} (Risk: {:?})",
request.title, request.risk_level
);
}
AgentEvent::RunFinished { .. } => {
println!();
println!("[Run finished]");
}
AgentEvent::Custom { payload, .. } => {
println!("[Custom event] {}", payload);
}
AgentEvent::Checkpoint { .. } => {}
}
Ok(())
}
}
/// Tool policy for the arithmetic demo: calls to `divide` require an approval request;
/// all other tools get none.
struct ArithmeticToolPolicy;

impl ToolPolicy for ArithmeticToolPolicy {
    /// Returns an approval request for `divide` calls and `None` for every other tool
    /// (presumably meaning no approval is needed — confirm against the runtime).
    ///
    /// The raw JSON arguments are included in the message so the approver can see which
    /// operands they are authorising.
    fn evaluate_approval(
        &self,
        tool_name: &str,
        _args: &Value,
        args_json: &str,
    ) -> Option<ApprovalRequest> {
        if tool_name != "divide" {
            return None;
        }
        Some(ApprovalRequest {
            title: "Division operation".to_string(),
            message: format!("Allow division execution? Arguments: {args_json}"),
            // NOTE(review): presumably the key the runtime remembers for AllowAlways — confirm.
            action_key: Some("divide".to_string()),
            risk_level: RiskLevel::Safe,
            raw: None,
        })
    }

    /// No-op hook before a tool call.
    fn on_pre_call(&self, _tool_name: &str, _args: &Value, _ctx: &ToolContext) {}

    /// No-op hook after a tool call.
    fn on_post_call(
        &self,
        _tool_name: &str,
        _args: &Value,
        _result: &ToolOutput,
        _ctx: &ToolContext,
    ) {
    }
}
/// System prompt given to the agent via `AgentBuilder::system_prompt`.
///
/// It lists the four registered tools and tells the model it may chain tool calls.
///
/// NOTE(review): the text contains full-width punctuation (`。`). It also describes `divide`
/// as returning only the quotient, while `DivideTool` also reports the remainder. Confirm
/// whether the prompt should be cleaned up.
const SYSTEM_PROMPT: &str = r#"You are an arithmetic assistant,You can help users perform various arithmetic operations。
The tools available to you include:
- add: Calculate the sum of two integers
- subtract: Calculate the difference of two integers
- multiply: Calculate the product of two integers
- divide: Calculate the quotient of two integers
Please select the appropriate tool based on the user's request. If the request involves multiple steps,
you may call tools step by step。After each calculation,explain the result to the user。"#;
/// Entry point: builds an agent from environment configuration and runs an interactive REPL.
///
/// `.env` is loaded first if present. Each setting is read from the `OPENAI_*` variable,
/// falling back to `DASHSCOPE_*`:
/// - `OPENAI_API_KEY` / `DASHSCOPE_API_KEY` — required.
/// - `OPENAI_MODEL` / `DASHSCOPE_MODEL` — default `gpt-4o-mini`.
/// - `OPENAI_BASE_URL` / `DASHSCOPE_BASE_URL` — default `https://api.openai.com/v1`.
///
/// REPL commands: `exit` or `quit` leaves, `reset` starts a new session, and any other
/// non-empty line is sent to the agent as one turn. End-of-input (Ctrl-D or a closed pipe)
/// also exits; previously it caused an infinite loop of empty reads.
///
/// # Errors
/// Fails if no API key is configured, or if stdin/stdout I/O fails. Errors from individual
/// turns are printed and the REPL continues.
#[tokio::main]
async fn main() -> AgentResult<()> {
    dotenv().ok();

    let api_key = std::env::var("OPENAI_API_KEY")
        .or_else(|_| std::env::var("DASHSCOPE_API_KEY"))
        .map_err(|_| {
            AgentError::internal("Please set OPENAI_API_KEY or DASHSCOPE_API_KEY environment variable")
        })?;
    let model = std::env::var("OPENAI_MODEL")
        .or_else(|_| std::env::var("DASHSCOPE_MODEL"))
        .unwrap_or_else(|_| "gpt-4o-mini".to_string());
    let base_url = std::env::var("OPENAI_BASE_URL")
        .or_else(|_| std::env::var("DASHSCOPE_BASE_URL"))
        .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());

    let llm_client = Arc::new(OpenAiClient::new(api_key, model.clone(), Some(base_url)));
    let mut runtime = AgentBuilder::new(llm_client)
        .system_prompt(SYSTEM_PROMPT)
        .enable_thought(false)
        .enable_thinking(false)
        .register_tool(AddTool)
        .register_tool(SubtractTool)
        .register_tool(MultiplyTool)
        .register_tool(DivideTool)
        .tool_policy(Arc::new(ArithmeticToolPolicy))
        .approval_handler(Arc::new(CliApprovalHandler))
        .build();
    let mut session_id = runtime.create_session();

    println!("=== agent-base REPL (arithmetic Demo) ===");
    println!("model: {}", model);
    println!("Input 'exit' quit, 'reset' recreate session");
    println!();

    // One line buffer, cleared and reused for every prompt.
    let mut line = String::new();
    loop {
        print!("> ");
        io::stdout()
            .flush()
            .map_err(|e| AgentError::internal(format!("Failed to flush stdout: {e}")))?;

        line.clear();
        let bytes_read = io::stdin()
            .read_line(&mut line)
            .map_err(|e| AgentError::internal(format!("Failed to read input: {e}")))?;
        if bytes_read == 0 {
            // EOF: stdin is closed, so no further commands can arrive.
            println!();
            break;
        }

        let input = line.trim();
        if input.is_empty() {
            continue;
        }
        if matches!(input, "exit" | "quit") {
            break;
        }
        if input == "reset" {
            session_id = runtime.create_session();
            println!("Created new session");
            continue;
        }

        match runtime
            .run_turn_with_handler(session_id.clone(), input, EventPrinter::handle)
            .await
        {
            Ok(_outcome) => {}
            Err(e) if e.is_cancelled() => println!("Cancelled"),
            Err(e) => println!("Error: {}", e),
        }
    }
    Ok(())
}