//! `function_call/function_call.rs`: example of tool (function) calling with `openai_rst`.

use openai_rst::{
    chat_completion::{
        ChatCompletionMessage, ChatCompletionRequest, Content, FinishReason, Function,
        FunctionParameters, JSONSchemaDefine, JSONSchemaType, Tool, ToolChoiceType, ToolType,
    },
    client::Client,
    common::MessageRole,
    models::{Model, GPT3},
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, vec};

/// Returns a hard-coded mock price for `coin`.
///
/// The lookup is case-insensitive and accepts either the ticker or the full
/// name: `"btc"`/`"bitcoin"` map to `10000.0` and `"eth"`/`"ethereum"` map to
/// `1000.0`. Any other input yields `0.0`.
fn get_coin_price(coin: &str) -> f64 {
    let normalized = coin.to_lowercase();
    let key = normalized.as_str();
    if matches!(key, "btc" | "bitcoin") {
        10000.0
    } else if matches!(key, "eth" | "ethereum") {
        1000.0
    } else {
        // Unknown coin: the caller receives 0.0 rather than an error.
        0.0
    }
}

22#[tokio::main]
23async fn main() -> Result<(), Box<dyn std::error::Error>> {
24    let client = Client::from_env().unwrap();
25
26    let mut properties = HashMap::new();
27    properties.insert(
28        "coin".to_string(),
29        Box::new(JSONSchemaDefine {
30            schema_type: Some(JSONSchemaType::String),
31            description: Some("The cryptocurrency to get the price of".to_string()),
32            ..Default::default()
33        }),
34    );
35
36    let req = ChatCompletionRequest::new_multi(
37        Model::GPT3(GPT3::GPT35Turbo),
38        vec![ChatCompletionMessage {
39            role: MessageRole::User,
40            content: Content::Text(String::from("What is the price of Ethereum?")),
41            name: None,
42        }],
43    )
44    .tools(vec![Tool {
45        r#type: ToolType::Function,
46        function: Function {
47            name: String::from("get_coin_price"),
48            description: Some(String::from("Get the price of a cryptocurrency")),
49            parameters: FunctionParameters {
50                schema_type: JSONSchemaType::Object,
51                properties: Some(properties),
52                required: Some(vec![String::from("coin")]),
53            },
54        },
55    }])
56    .tool_choice(ToolChoiceType::Auto);
57
58    let result = client.chat_completion(req).await?;
59
60    match result.choices[0].finish_reason {
61        None => {
62            println!("No finish_reason");
63            println!("{:?}", result.choices[0].message.content);
64        }
65        Some(FinishReason::stop) => {
66            println!("Stop");
67            println!("{:?}", result.choices[0].message.content);
68        }
69        Some(FinishReason::length) => {
70            println!("Length");
71        }
72        Some(FinishReason::tool_calls) => {
73            println!("ToolCalls");
74            #[derive(Deserialize, Serialize)]
75            struct Currency {
76                coin: String,
77            }
78            let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
79            for tool_call in tool_calls {
80                let name = tool_call.function.name.clone().unwrap();
81                let arguments = tool_call.function.arguments.clone().unwrap();
82                let c: Currency = serde_json::from_str(&arguments)?;
83                let coin = c.coin;
84                if name == "get_coin_price" {
85                    let price = get_coin_price(&coin);
86                    println!("{} price: {}", coin, price);
87                }
88            }
89        }
90        Some(FinishReason::content_filter) => {
91            println!("ContentFilter");
92        }
93        Some(FinishReason::null) => {
94            println!("Null");
95        }
96    }
97    Ok(())
98}