// Example: function_call/function_call.rs

use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use openai_api_rs::v1::common::GPT4_O;
use openai_api_rs::v1::types;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{env, vec};

/// Returns a fixed, hard-coded price for `coin`.
///
/// Matching ignores ASCII/Unicode case and surrounding whitespace, and accepts
/// either the ticker or the full name: `"btc"`/`"bitcoin"` and
/// `"eth"`/`"ethereum"`. Any other input yields `0.0`.
///
/// NOTE(review): `0.0` doubles as the "unknown coin" sentinel, so callers
/// cannot distinguish an unsupported coin from a genuinely zero price.
fn get_coin_price(coin: &str) -> f64 {
    // The argument originates from a model-generated tool call, so normalize
    // casing and stray whitespace before matching.
    match coin.trim().to_lowercase().as_str() {
        "btc" | "bitcoin" => 10000.0,
        "eth" | "ethereum" => 1000.0,
        _ => 0.0,
    }
}

18#[tokio::main]
19async fn main() -> Result<(), Box<dyn std::error::Error>> {
20    let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
21    let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;
22
23    let mut properties = HashMap::new();
24    properties.insert(
25        "coin".to_string(),
26        Box::new(types::JSONSchemaDefine {
27            schema_type: Some(types::JSONSchemaType::String),
28            description: Some("The cryptocurrency to get the price of".to_string()),
29            ..Default::default()
30        }),
31    );
32
33    let req = ChatCompletionRequest::new(
34        GPT4_O.to_string(),
35        vec![chat_completion::ChatCompletionMessage {
36            role: chat_completion::MessageRole::user,
37            content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")),
38            name: None,
39            tool_calls: None,
40            tool_call_id: None,
41        }],
42    )
43    .tools(vec![chat_completion::Tool {
44        r#type: chat_completion::ToolType::Function,
45        function: types::Function {
46            name: String::from("get_coin_price"),
47            description: Some(String::from("Get the price of a cryptocurrency")),
48            parameters: types::FunctionParameters {
49                schema_type: types::JSONSchemaType::Object,
50                properties: Some(properties),
51                required: Some(vec![String::from("coin")]),
52            },
53        },
54    }])
55    .tool_choice(chat_completion::ToolChoiceType::Auto);
56
57    // debug request json
58    // let serialized = serde_json::to_string(&req).unwrap();
59    // println!("{}", serialized);
60
61    let result = client.chat_completion(req).await?;
62
63    match result.choices[0].finish_reason {
64        None => {
65            println!("No finish_reason");
66            println!("{:?}", result.choices[0].message.content);
67        }
68        Some(chat_completion::FinishReason::stop) => {
69            println!("Stop");
70            println!("{:?}", result.choices[0].message.content);
71        }
72        Some(chat_completion::FinishReason::length) => {
73            println!("Length");
74        }
75        Some(chat_completion::FinishReason::tool_calls) => {
76            println!("ToolCalls");
77            #[derive(Deserialize, Serialize)]
78            struct Currency {
79                coin: String,
80            }
81            let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
82            for tool_call in tool_calls {
83                let name = tool_call.function.name.clone().unwrap();
84                let arguments = tool_call.function.arguments.clone().unwrap();
85                let c: Currency = serde_json::from_str(&arguments)?;
86                let coin = c.coin;
87                if name == "get_coin_price" {
88                    let price = get_coin_price(&coin);
89                    println!("{} price: {}", coin, price);
90                }
91            }
92        }
93        Some(chat_completion::FinishReason::content_filter) => {
94            println!("ContentFilter");
95        }
96        Some(chat_completion::FinishReason::null) => {
97            println!("Null");
98        }
99    }
100    Ok(())
101}
102
// Run with: OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call