use clap::Parser;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use oaapi::chat::ChatModel;
use oaapi::chat::CompletionsRequestBody;
use oaapi::chat::Function;
use oaapi::chat::SystemMessage;
use oaapi::chat::TooChoiceOption;
use oaapi::chat::UserMessage;
use oaapi::Client;
/// Send a message to the chat model with a `get_current_weather` tool
/// available, and print the arguments of the tool call it makes.
#[derive(Parser)]
struct Arguments {
    /// User message to send to the chat model.
    #[arg(short, long)]
    message: String,
}
/// Arguments of a `get_current_weather` tool call, as produced by the model.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
struct GetCurrentWeather {
    /// Free-form place name, e.g. "San Francisco, CA".
    location: String,
    /// Optional temperature unit; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    unit: Option<Unit>,
}
/// Renders as `Location: <location>`, followed by `, Unit: <unit>` only
/// when a unit is present.
impl Display for GetCurrentWeather {
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        f.write_str("Location: ")?;
        f.write_str(&self.location)?;
        match &self.unit {
            | Some(unit) => write!(f, ", Unit: {}", unit),
            | None => Ok(()),
        }
    }
}
/// Temperature unit accepted by the `get_current_weather` tool.
///
/// The enum carries no data, so it is `Copy` and fully comparable/hashable;
/// the extra derives are additive and do not affect existing callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Unit {
    /// Degrees Celsius (the string `"celsius"` in the tool schema).
    Celsius,
    /// Degrees Fahrenheit (the string `"fahrenheit"` in the tool schema).
    Fahrenheit,
}
/// Formats the unit as its lowercase token (`celsius` / `fahrenheit`).
impl Display for Unit {
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        let token = match self {
            | Unit::Celsius => "celsius",
            | Unit::Fahrenheit => "fahrenheit",
        };
        // Written directly, so any width/fill on the outer formatter is
        // ignored, exactly as with a nested `write!`.
        f.write_str(token)
    }
}
/// Serializes the unit as the same lowercase string listed in the tool
/// schema's `enum`.
impl Serialize for Unit {
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        let name = match self {
            | Unit::Fahrenheit => "fahrenheit",
            | Unit::Celsius => "celsius",
        };
        serializer.serialize_str(name)
    }
}
impl<'de> Deserialize<'de> for Unit {
fn deserialize<D: serde::Deserializer<'de>>(
deserializer: D
) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
match s.as_str() {
| "celsius" => Ok(Unit::Celsius),
| "fahrenheit" => Ok(Unit::Fahrenheit),
| _ => Err(serde::de::Error::custom(format!(
"invalid unit: {}",
s
))),
}
}
}
/// JSON Schema for the parameters of the `get_current_weather` tool.
///
/// Parsed with `serde_json::from_str` in `main` and sent as the function's
/// `parameters`. Keep it in sync with `GetCurrentWeather`: `location` is the
/// only required property, and the `unit` enum values must match the strings
/// `Unit` (de)serializes to.
const GET_CURRENT_WEATHER_SCHEMA: &str = r#"
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}"#;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let arguments = Arguments::parse();
let client = Client::from_env()?;
let prompt = r#"Respond for user message by using some tool."#;
let function = Function {
description: Some(
"Get the current weather in a given location".to_string(),
),
name: "get_current_weather".to_string(),
parameters: Some(serde_json::from_str(
GET_CURRENT_WEATHER_SCHEMA,
)?),
};
let tools = Some(vec![function.into()]); let tool_choice = Some(TooChoiceOption::Auto.into());
let request_body = CompletionsRequestBody {
messages: vec![
SystemMessage::new(prompt, None).into(),
UserMessage::new(arguments.message.into(), None).into(),
],
model: ChatModel::Gpt35Turbo1106,
tools, tool_choice, ..Default::default()
};
let response = client
.chat_complete(request_body)
.await?;
let arguments = response
.choices
.first()
.unwrap()
.message
.tool_calls
.as_ref()
.unwrap()
.first()
.unwrap()
.function
.arguments
.as_ref();
let called_function =
serde_json::from_str::<GetCurrentWeather>(&arguments)?;
println!("Result:\n{}", called_function);
Ok(())
}