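//! function_call/function_call.rs
//!
//! Example of tool (function) calling with openai-api-rs: the model is offered
//! a local `get_coin_price` tool and any calls it requests are executed here.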

use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::chat_completion::{
    chat_completion::ChatCompletionRequest, ChatCompletionMessage,
};
use openai_api_rs::v1::chat_completion::{
    Content, FinishReason, MessageRole, Tool, ToolChoiceType, ToolType,
};
use openai_api_rs::v1::common::GPT4_O;
use openai_api_rs::v1::types;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{env, vec};
/// Looks up a hard-coded (mock) price for a cryptocurrency.
///
/// `coin` may be a ticker or a full name (`"btc"` / `"bitcoin"`,
/// `"eth"` / `"ethereum"`). Matching is case-insensitive and ignores
/// surrounding whitespace, because the value arrives as free-form text taken
/// from the model's tool-call arguments.
///
/// Returns `0.0` for any coin that is not recognised.
fn get_coin_price(coin: &str) -> f64 {
    let coin = coin.trim().to_lowercase();
    match coin.as_str() {
        "btc" | "bitcoin" => 10000.0,
        "eth" | "ethereum" => 1000.0,
        _ => 0.0,
    }
}

23#[tokio::main]
24async fn main() -> Result<(), Box<dyn std::error::Error>> {
25    let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
26    let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;
27
28    let mut properties = HashMap::new();
29    properties.insert(
30        "coin".to_string(),
31        Box::new(types::JSONSchemaDefine {
32            schema_type: Some(types::JSONSchemaType::String),
33            description: Some("The cryptocurrency to get the price of".to_string()),
34            ..Default::default()
35        }),
36    );
37
38    let req = ChatCompletionRequest::new(
39        GPT4_O.to_string(),
40        vec![ChatCompletionMessage {
41            role: MessageRole::user,
42            content: Content::Text(String::from("What is the price of Ethereum?")),
43            name: None,
44            tool_calls: None,
45            tool_call_id: None,
46        }],
47    )
48    .tools(vec![Tool {
49        r#type: ToolType::Function,
50        function: types::Function {
51            name: String::from("get_coin_price"),
52            description: Some(String::from("Get the price of a cryptocurrency")),
53            parameters: types::FunctionParameters {
54                schema_type: types::JSONSchemaType::Object,
55                properties: Some(properties),
56                required: Some(vec![String::from("coin")]),
57            },
58        },
59    }])
60    .tool_choice(ToolChoiceType::Auto);
61
62    // debug request json
63    // let serialized = serde_json::to_string(&req).unwrap();
64    // println!("{}", serialized);
65
66    let result = client.chat_completion(req).await?;
67
68    match result.choices[0].finish_reason {
69        None => {
70            println!("No finish_reason");
71            println!("{:?}", result.choices[0].message.content);
72        }
73        Some(FinishReason::stop) => {
74            println!("Stop");
75            println!("{:?}", result.choices[0].message.content);
76        }
77        Some(FinishReason::length) => {
78            println!("Length");
79        }
80        Some(FinishReason::tool_calls) => {
81            println!("ToolCalls");
82            #[derive(Deserialize, Serialize)]
83            struct Currency {
84                coin: String,
85            }
86            let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
87            for tool_call in tool_calls {
88                let name = tool_call.function.name.clone().unwrap();
89                let arguments = tool_call.function.arguments.clone().unwrap();
90                let c: Currency = serde_json::from_str(&arguments)?;
91                let coin = c.coin;
92                if name == "get_coin_price" {
93                    let price = get_coin_price(&coin);
94                    println!("{coin} price: {price}");
95                }
96            }
97        }
98        Some(FinishReason::content_filter) => {
99            println!("ContentFilter");
100        }
101        Some(FinishReason::null) => {
102            println!("Null");
103        }
104    }
105    Ok(())
106}
107
// Run with: OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call