function_call/
function_call.rs1use openai_rst::{
2 chat_completion::{
3 ChatCompletionMessage, ChatCompletionRequest, Content, FinishReason, Function,
4 FunctionParameters, JSONSchemaDefine, JSONSchemaType, Tool, ToolChoiceType, ToolType,
5 },
6 client::Client,
7 common::MessageRole,
8 models::{Model, GPT3},
9};
10use serde::{Deserialize, Serialize};
11use std::{collections::HashMap, vec};
12
/// Looks up a fixed placeholder price for a cryptocurrency.
///
/// `coin` may be either a ticker or a full name and is compared after
/// lowercasing, so `"BTC"`, `"Bitcoin"` and `"bitcoin"` are all accepted.
/// Recognised coins are Bitcoin (`"btc"`/`"bitcoin"`) and Ethereum
/// (`"eth"`/`"ethereum"`); any other input yields `0.0`.
fn get_coin_price(coin: &str) -> f64 {
    // Each entry pairs the accepted lowercase aliases with their price.
    const PRICE_TABLE: [(&[&str], f64); 2] = [
        (&["btc", "bitcoin"], 10000.0),
        (&["eth", "ethereum"], 1000.0),
    ];

    let needle = coin.to_lowercase();
    PRICE_TABLE
        .iter()
        .find(|(aliases, _)| aliases.contains(&needle.as_str()))
        .map_or(0.0, |&(_, price)| price)
}
21
22#[tokio::main]
23async fn main() -> Result<(), Box<dyn std::error::Error>> {
24 let client = Client::from_env().unwrap();
25
26 let mut properties = HashMap::new();
27 properties.insert(
28 "coin".to_string(),
29 Box::new(JSONSchemaDefine {
30 schema_type: Some(JSONSchemaType::String),
31 description: Some("The cryptocurrency to get the price of".to_string()),
32 ..Default::default()
33 }),
34 );
35
36 let req = ChatCompletionRequest::new_multi(
37 Model::GPT3(GPT3::GPT35Turbo),
38 vec![ChatCompletionMessage {
39 role: MessageRole::User,
40 content: Content::Text(String::from("What is the price of Ethereum?")),
41 name: None,
42 }],
43 )
44 .tools(vec![Tool {
45 r#type: ToolType::Function,
46 function: Function {
47 name: String::from("get_coin_price"),
48 description: Some(String::from("Get the price of a cryptocurrency")),
49 parameters: FunctionParameters {
50 schema_type: JSONSchemaType::Object,
51 properties: Some(properties),
52 required: Some(vec![String::from("coin")]),
53 },
54 },
55 }])
56 .tool_choice(ToolChoiceType::Auto);
57
58 let result = client.chat_completion(req).await?;
59
60 match result.choices[0].finish_reason {
61 None => {
62 println!("No finish_reason");
63 println!("{:?}", result.choices[0].message.content);
64 }
65 Some(FinishReason::stop) => {
66 println!("Stop");
67 println!("{:?}", result.choices[0].message.content);
68 }
69 Some(FinishReason::length) => {
70 println!("Length");
71 }
72 Some(FinishReason::tool_calls) => {
73 println!("ToolCalls");
74 #[derive(Deserialize, Serialize)]
75 struct Currency {
76 coin: String,
77 }
78 let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
79 for tool_call in tool_calls {
80 let name = tool_call.function.name.clone().unwrap();
81 let arguments = tool_call.function.arguments.clone().unwrap();
82 let c: Currency = serde_json::from_str(&arguments)?;
83 let coin = c.coin;
84 if name == "get_coin_price" {
85 let price = get_coin_price(&coin);
86 println!("{} price: {}", coin, price);
87 }
88 }
89 }
90 Some(FinishReason::content_filter) => {
91 println!("ContentFilter");
92 }
93 Some(FinishReason::null) => {
94 println!("Null");
95 }
96 }
97 Ok(())
98}