// function_call/
function_call.rs1use openai_api_rs::v1::api::OpenAIClient;
2use openai_api_rs::v1::chat_completion::{
3 chat_completion::ChatCompletionRequest, ChatCompletionMessage,
4};
5use openai_api_rs::v1::chat_completion::{
6 Content, FinishReason, MessageRole, Tool, ToolChoiceType, ToolType,
7};
8use openai_api_rs::v1::common::GPT4_O;
9use openai_api_rs::v1::types;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::{env, vec};
13
/// Looks up a hard-coded mock price for a cryptocurrency.
///
/// Matching is case-insensitive (the input is lowercased first) and accepts
/// either the ticker or the full name (`"btc"`/`"bitcoin"`,
/// `"eth"`/`"ethereum"`). Any other input yields `0.0`.
fn get_coin_price(coin: &str) -> f64 {
    // Each entry: accepted lowercase aliases, then the fixed price returned for them.
    const PRICE_TABLE: [(&[&str], f64); 2] = [
        (&["btc", "bitcoin"], 10000.0),
        (&["eth", "ethereum"], 1000.0),
    ];

    let normalized = coin.to_lowercase();
    PRICE_TABLE
        .iter()
        .find(|(aliases, _)| aliases.contains(&normalized.as_str()))
        .map_or(0.0, |&(_, price)| price)
}

23#[tokio::main]
24async fn main() -> Result<(), Box<dyn std::error::Error>> {
25 let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
26 let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;
27
28 let mut properties = HashMap::new();
29 properties.insert(
30 "coin".to_string(),
31 Box::new(types::JSONSchemaDefine {
32 schema_type: Some(types::JSONSchemaType::String),
33 description: Some("The cryptocurrency to get the price of".to_string()),
34 ..Default::default()
35 }),
36 );
37
38 let req = ChatCompletionRequest::new(
39 GPT4_O.to_string(),
40 vec![ChatCompletionMessage {
41 role: MessageRole::user,
42 content: Content::Text(String::from("What is the price of Ethereum?")),
43 name: None,
44 tool_calls: None,
45 tool_call_id: None,
46 }],
47 )
48 .tools(vec![Tool {
49 r#type: ToolType::Function,
50 function: types::Function {
51 name: String::from("get_coin_price"),
52 description: Some(String::from("Get the price of a cryptocurrency")),
53 parameters: types::FunctionParameters {
54 schema_type: types::JSONSchemaType::Object,
55 properties: Some(properties),
56 required: Some(vec![String::from("coin")]),
57 },
58 },
59 }])
60 .tool_choice(ToolChoiceType::Auto);
61
62 let result = client.chat_completion(req).await?;
67
68 match result.choices[0].finish_reason {
69 None => {
70 println!("No finish_reason");
71 println!("{:?}", result.choices[0].message.content);
72 }
73 Some(FinishReason::stop) => {
74 println!("Stop");
75 println!("{:?}", result.choices[0].message.content);
76 }
77 Some(FinishReason::length) => {
78 println!("Length");
79 }
80 Some(FinishReason::tool_calls) => {
81 println!("ToolCalls");
82 #[derive(Deserialize, Serialize)]
83 struct Currency {
84 coin: String,
85 }
86 let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
87 for tool_call in tool_calls {
88 let name = tool_call.function.name.clone().unwrap();
89 let arguments = tool_call.function.arguments.clone().unwrap();
90 let c: Currency = serde_json::from_str(&arguments)?;
91 let coin = c.coin;
92 if name == "get_coin_price" {
93 let price = get_coin_price(&coin);
94 println!("{coin} price: {price}");
95 }
96 }
97 }
98 Some(FinishReason::content_filter) => {
99 println!("ContentFilter");
100 }
101 Some(FinishReason::null) => {
102 println!("Null");
103 }
104 }
105 Ok(())
106}
107
108