use std::time::Duration;
use anyhow::Result;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::Resource;
use opentelemetry_sdk::trace::SdkTracerProvider;
use rig::prelude::*;
use rig::{
completion::{Prompt, ToolDefinition},
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::Level;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
/// Arguments shared by the binary arithmetic tools (`add`, `subtract`).
///
/// Deserialized from the JSON arguments the model supplies for a tool call.
#[derive(Debug, Deserialize)]
struct OperationArgs {
    /// Left-hand operand.
    x: i32,
    /// Right-hand operand.
    y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required": ["x", "y"],
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Adding {} and {}", args.x, args.y);
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
},
"required": ["x", "y"],
},
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Subtracting {} from {}", args.y, args.x);
let result = args.x - args.y;
tokio::time::sleep(Duration::from_micros(1)).await;
Ok(result)
}
}
/// Entry point: routes `tracing` output to a pretty console layer and an OTLP
/// exporter, then asks a GPT-4o calculator agent equipped with the `add` and
/// `subtract` tools to compute `2 - 5`.
///
/// The tracer provider is shut down even when the prompt fails, so spans from a
/// failed run are still flushed.
///
/// # Errors
/// Returns an error if the OTLP exporter cannot be built or the prompt fails.
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    const QUESTION: &str = "Calculate 2 - 5";

    // OTLP over HTTP with binary (protobuf) payloads.
    let exporter = opentelemetry_otlp::SpanExporter::builder()
        .with_http()
        .with_protocol(opentelemetry_otlp::Protocol::HttpBinary)
        .build()?;
    let provider = SdkTracerProvider::builder()
        .with_batch_exporter(exporter)
        .with_resource(Resource::builder().with_service_name("rig-demo").build())
        .build();
    let tracer = provider.tracer("readme_example");
    let otel_layer = tracing_opentelemetry::layer().with_tracer(tracer);

    // INFO by default; overridable through the environment (`RUST_LOG`).
    let filter_layer = tracing_subscriber::filter::EnvFilter::builder()
        .with_default_directive(Level::INFO.into())
        .from_env_lossy();
    let fmt_layer = tracing_subscriber::fmt::layer().pretty();
    tracing_subscriber::registry()
        .with(filter_layer)
        .with(fmt_layer)
        .with(otel_layer)
        .init();

    let openai_client = providers::openai::Client::from_env();
    let calculator_agent = openai_client
        .agent(providers::openai::GPT_4O)
        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
        .max_tokens(1024)
        .tool(Adder)
        .tool(Subtract)
        .build();

    println!("{QUESTION}");
    // Don't `?` here: an early return would skip the shutdown below and drop the
    // buffered spans of exactly the run that failed.
    let outcome = calculator_agent.prompt(QUESTION).await;
    if let Ok(answer) = &outcome {
        println!("OpenAI Calculator Agent: {answer}");
    }

    // Flush batched spans before exit. A failed shutdown is reported but not fatal.
    if let Err(err) = provider.shutdown() {
        eprintln!("failed to shut down tracer provider: {err}");
    }

    outcome?;
    Ok(())
}