use anyhow::Result;
use rigs::{agent::Agent, llm_provider::LLMProvider, rig_agent::RigAgent, tool};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[tokio::main]
async fn main() -> Result<()> {
dotenv::dotenv().ok();
let subscriber = tracing_subscriber::fmt::Subscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_line_number(true)
.with_file(true)
.finish();
tracing::subscriber::set_global_default(subscriber)?;
let agent = RigAgent::deepseek_builder()
.provider(LLMProvider::deepseek("deepseek-chat"))?
.system_prompt("You need to select the right tool to answer the question.")
.agent_name("TestToolAgent")
.user_name("M4n5ter")
.enable_autosave()
.max_loops(1)
.save_state_dir("./temp")
.tool(SubTool)?
.tool(Add)? .tool(MultiplyTool)?
.tool(Exec)? ;
let agent = agent.build()?;
let mut result = agent.run("10 - 5".into()).await.unwrap();
println!("{result}");
result = agent.run(format!("{result} + 5")).await.unwrap();
println!("{result}");
result = agent.run(format!("{result} * 5")).await.unwrap();
println!("{result}");
result = agent
.run("
Use docker to run a postgres database(newest version, alpine as base), set the network mode to host.
Then get current system's release.
Finally, curl something to get the IP address of current machine.
".to_owned()
)
.await
.unwrap();
println!("{result}");
Ok(())
}
// Subtraction tool (x - y). The attribute carries the tool description and the
// per-argument descriptions; the body just logs the call and computes the result.
// NOTE(review): `main` registers this as `SubTool`, presumably a type generated by
// the `#[tool]` attribute — confirm against the rigs macro docs.
#[tool(
    description = "Subtract y from x (i.e.: x - y)",
    arg(x, description = "The number to subtract from"),
    arg(y, description = "The number to subtract")
)]
fn sub(x: f64, y: f64) -> Result<f64, CalcError> {
    tracing::info!("Sub tool is called");
    let difference = x - y;
    Ok(difference)
}
// Addition tool (x + y), declared with the bare `#[tool]` form: no explicit name,
// description or argument docs.
// NOTE(review): `main` registers this as `Add`, presumably derived by the macro from
// the function name — confirm against the rigs macro docs.
#[tool]
fn add(x: f64, y: f64) -> Result<f64, CalcError> {
    tracing::info!("Add tool is called");
    let sum = x + y;
    Ok(sum)
}
// Multiplication tool (x * y), exposed under the explicit tool name "Multiply".
// NOTE(review): `main` registers this as `MultiplyTool`, presumably generated from
// the `name` argument — confirm against the rigs macro docs.
#[tool(name = "Multiply", description = "Multiply x and y (i.e.: x * y)")]
fn mul(x: f64, y: f64) -> Result<f64, CalcError> {
    tracing::info!("Mul tool is called");
    let product = x * y;
    Ok(product)
}
// "Shell" tool that does NOT execute anything: it serializes the arguments the
// model supplied to JSON and returns that string. Nothing model-chosen ever runs on
// this machine; the output only shows how the struct parameter was filled in.
#[tool(description = "
Execute the shell command, can execute multiple commands at once.
")]
fn exec(x: ExecShell) -> Result<String, CalcError> {
    tracing::info!("exec tool is called");
    // `CalcError` has no variants, so a serialization failure could not be returned
    // anyway. `ExecShell` holds only a Vec<String>, a bool and a String, which
    // serde_json always serializes successfully, so this is an invariant, not I/O.
    let results = serde_json::to_string(&x)
        .expect("ExecShell holds only strings and a bool; JSON serialization cannot fail");
    Ok(results)
}
// Argument struct for the `exec` tool. The field names are deliberately opaque.
// NOTE(review): presumably this checks whether the model relies on the schema
// descriptions rather than field names — confirm intent.
// Only the first field has a `#[doc]` attribute (its schema description). Plain `//`
// comments are used in this struct so no extra descriptions are added. Field names
// and order define the serialized/schema shape; do not rename or reorder them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct ExecShell {
    #[doc = "The commands to execute, can execute multiple commands at once."]
    don_t_tell_you_what_it_means_1: Vec<String>,
    // No description: the model must infer the meaning of this bool.
    don_t_tell_you_what_it_means_2: bool,
    // No description: the model must infer the meaning of this string.
    don_t_tell_you_what_it_means_3: String,
}
// Example of a nested struct parameter. Only `first_field` has a `#[doc]`
// description, and `third_field` is a list of `ThirdField` values.
// NOTE(review): not referenced by any tool or function in the visible code.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct ExampleStructParameterToHaveADescription {
    #[doc = "The first field"]
    first_field: String,
    second_field: String,
    third_field: Vec<ThirdField>,
}
// Element type of `ExampleStructParameterToHaveADescription::third_field`; neither
// field carries a `#[doc]` description.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct ThirdField {
    first_field: String,
    second_field: String,
}
#[derive(Debug, Error)]
pub enum CalcError {}