function_call/
function_call.rs1use openai_api_rs::v1::api::OpenAIClient;
2use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
3use openai_api_rs::v1::common::GPT4_O;
4use openai_api_rs::v1::types;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::{env, vec};
8
/// Returns a fixed demo price for the given cryptocurrency.
///
/// `coin` is lower-cased before the lookup, so either the ticker symbol or
/// the full name is accepted in any letter case. Coins that are not in the
/// table yield `0.0`.
fn get_coin_price(coin: &str) -> f64 {
    // (ticker, full name, price) rows of the hard-coded price table.
    const PRICES: [(&str, &str, f64); 2] = [
        ("btc", "bitcoin", 10000.0),
        ("eth", "ethereum", 1000.0),
    ];

    let wanted = coin.to_lowercase();
    PRICES
        .iter()
        .find(|(ticker, full_name, _)| wanted == *ticker || wanted == *full_name)
        .map_or(0.0, |&(_, _, price)| price)
}
17
18#[tokio::main]
19async fn main() -> Result<(), Box<dyn std::error::Error>> {
20 let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
21 let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;
22
23 let mut properties = HashMap::new();
24 properties.insert(
25 "coin".to_string(),
26 Box::new(types::JSONSchemaDefine {
27 schema_type: Some(types::JSONSchemaType::String),
28 description: Some("The cryptocurrency to get the price of".to_string()),
29 ..Default::default()
30 }),
31 );
32
33 let req = ChatCompletionRequest::new(
34 GPT4_O.to_string(),
35 vec![chat_completion::ChatCompletionMessage {
36 role: chat_completion::MessageRole::user,
37 content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")),
38 name: None,
39 tool_calls: None,
40 tool_call_id: None,
41 }],
42 )
43 .tools(vec![chat_completion::Tool {
44 r#type: chat_completion::ToolType::Function,
45 function: types::Function {
46 name: String::from("get_coin_price"),
47 description: Some(String::from("Get the price of a cryptocurrency")),
48 parameters: types::FunctionParameters {
49 schema_type: types::JSONSchemaType::Object,
50 properties: Some(properties),
51 required: Some(vec![String::from("coin")]),
52 },
53 },
54 }])
55 .tool_choice(chat_completion::ToolChoiceType::Auto);
56
57 let result = client.chat_completion(req).await?;
62
63 match result.choices[0].finish_reason {
64 None => {
65 println!("No finish_reason");
66 println!("{:?}", result.choices[0].message.content);
67 }
68 Some(chat_completion::FinishReason::stop) => {
69 println!("Stop");
70 println!("{:?}", result.choices[0].message.content);
71 }
72 Some(chat_completion::FinishReason::length) => {
73 println!("Length");
74 }
75 Some(chat_completion::FinishReason::tool_calls) => {
76 println!("ToolCalls");
77 #[derive(Deserialize, Serialize)]
78 struct Currency {
79 coin: String,
80 }
81 let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
82 for tool_call in tool_calls {
83 let name = tool_call.function.name.clone().unwrap();
84 let arguments = tool_call.function.arguments.clone().unwrap();
85 let c: Currency = serde_json::from_str(&arguments)?;
86 let coin = c.coin;
87 if name == "get_coin_price" {
88 let price = get_coin_price(&coin);
89 println!("{} price: {}", coin, price);
90 }
91 }
92 }
93 Some(chat_completion::FinishReason::content_filter) => {
94 println!("ContentFilter");
95 }
96 Some(chat_completion::FinishReason::null) => {
97 println!("Null");
98 }
99 }
100 Ok(())
101}
102
103