openai_func_enums/
lib.rs

1use async_openai::error::OpenAIError;
2use async_openai::types::{
3    ChatCompletionTool, ChatCompletionToolArgs, ChatCompletionToolType, FunctionObject,
4    FunctionObjectArgs,
5};
6use async_trait::async_trait;
7pub use openai_func_embeddings::*;
8pub use openai_func_enums_macros::*;
9use serde_json::Value;
10use std::error::Error;
11use std::fmt::{self, Debug};
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
/// A trait to provide a descriptor for an enumeration.
///
/// Implementations expose static metadata about the enum: its name and an argument
/// description, each paired with a precomputed token count.
pub trait EnumDescriptor {
    /// Returns the name of the enum and the count of tokens in its name.
    ///
    /// # Returns
    ///
    /// A reference to a `'static` tuple whose first element is the enum's name as a
    /// `&'static str` and whose second element is a `usize` holding the number of tokens in
    /// that name.
    fn name_with_token_count() -> &'static (&'static str, usize);

    /// Returns the argument description for the enum and the count of tokens in it.
    ///
    /// # Returns
    ///
    /// A reference to a `'static` tuple whose first element is the description text as a
    /// `&'static str` and whose second element is a `usize` holding the number of tokens in
    /// that text.
    fn arg_description_with_token_count() -> &'static (&'static str, usize);
}
28
/// Marker trait for tool sets.
///
/// It declares no methods, so any type opts in with an empty `impl ToolSet for T {}`.
pub trait ToolSet {}
30
/// A trait to provide descriptors for the variants of an enumeration.
///
/// This includes the names of the variants and the count of tokens in their names.
pub trait VariantDescriptors {
    /// Returns the names of all variants of the enum together with their token counts.
    ///
    /// # Returns
    ///
    /// A reference to a `'static` 4-tuple:
    ///
    /// * `.0` — the variant names, as a `'static` slice of `&'static str`;
    /// * `.1` — token counts as a `'static` slice of `usize` (presumably one entry per name,
    ///   in the same order — NOTE(review): confirm against the derive macro);
    /// * `.2` and `.3` — two aggregate `usize` values whose exact meaning (e.g. a total token
    ///   count and/or a variant count) is defined by the generating macro — NOTE(review):
    ///   confirm.
    fn variant_names_with_token_counts(
    ) -> &'static (&'static [&'static str], &'static [usize], usize, usize);

    /// Returns the name of this variant and the count of tokens in its name.
    ///
    /// # Returns
    ///
    /// A tuple whose first element is the variant's name as a `&'static str` and whose
    /// second element is a `usize` holding the number of tokens in that name.
    fn variant_name_with_token_count(&self) -> (&'static str, usize);
}
51
/// Selects how tool calls are executed; it is passed to `RunCommand::run`.
///
/// This is a fieldless enum, so it is `Copy` and can be compared and hashed directly
/// instead of being cloned or matched.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolCallExecutionStrategy {
    /// Presumably runs tool calls in parallel — confirm in `RunCommand` implementations.
    Parallel,
    /// Presumably runs tool calls as concurrent async tasks — confirm in implementations.
    Async,
    /// Presumably runs tool calls one after another — confirm in implementations.
    Synchronous,
}
58
/// Error raised while running a command; it carries a human-readable message.
///
/// `Display` writes the message verbatim. The error has no underlying `source`.
#[derive(Debug)]
pub struct CommandError {
    /// Message text rendered by `Display`.
    details: String,
}

impl CommandError {
    /// Creates a `CommandError` holding an owned copy of `msg`.
    pub fn new(msg: &str) -> CommandError {
        let details = String::from(msg);
        Self { details }
    }
}

impl fmt::Display for CommandError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Emit the stored message as-is; width/fill flags are not applied.
        f.write_str(&self.details)
    }
}

impl Error for CommandError {}
79
80impl From<OpenAIError> for CommandError {
81    fn from(error: OpenAIError) -> Self {
82        CommandError::new(&format!("OpenAI Error: {}", error))
83    }
84}
85
86pub struct Logger {
87    pub sender: mpsc::Sender<String>,
88}
89
90impl Logger {
91    pub async fn log(&self, message: String) {
92        let _ = self.sender.send(message).await;
93    }
94}
95
96pub async fn logger_task(mut receiver: mpsc::Receiver<String>) {
97    while let Some(message) = receiver.recv().await {
98        println!("{}", message);
99    }
100}
101
102// There is a better way than to keep adding return types.
103// Trying to determine which road to go down on other issues first.
104#[async_trait]
105pub trait RunCommand: Sync + Send {
106    async fn run(
107        &self,
108        execution_strategy: ToolCallExecutionStrategy,
109        arguments: Option<Vec<String>>,
110        logger: Arc<Logger>,
111        system_message: Option<(String, usize)>,
112    ) -> Result<
113        (Option<String>, Option<Vec<String>>),
114        Box<dyn std::error::Error + Send + Sync + 'static>,
115    >;
116}
117
/// A macro to parse a function call's JSON arguments into a specified type.
///
/// The macro evaluates to `Some(value)` on success. If parsing fails, the `serde_json` error
/// is printed to standard error (not standard output, so program output is not polluted) and
/// the macro evaluates to `None`.
///
/// The calling crate must depend on `serde_json`. The path is written as `::serde_json` so
/// that a local item named `serde_json` at the call site cannot shadow the crate.
///
/// # Arguments
///
/// * `$func_call` - An expression with an `arguments` field supporting `.as_str()` (e.g. a
///   `String`) that holds the JSON-encoded arguments.
/// * `$type` - The target type to deserialize the arguments into.
///
/// A trailing comma after `$type` is accepted.
#[macro_export]
macro_rules! parse_function_call {
    ($func_call:expr, $type:ty $(,)?) => {
        match ::serde_json::from_str::<$type>($func_call.arguments.as_str()) {
            Ok(response) => Some(response),
            Err(e) => {
                eprintln!("Failed to parse function call: {}", e);
                None
            }
        }
    };
}
137
138/// A function to get the chat completion arguments for a function.
139///
140/// # Arguments
141///
142/// * `func` - A function that returns a JSON representation of a function and the count of tokens in the representation.
143///
144/// # Returns
145///
146/// * A `Result` which is `Ok` if the function chat completion arguments were successfully obtained, and `Err` otherwise.
147///   The `Ok` variant contains a tuple where the first element is a `ChatCompletionFunctions` representing the chat completion arguments for the function,
148///   and the second element is a `usize` representing the total count of tokens in the function's JSON representation.
149pub fn get_function_chat_completion_args(
150    func: impl Fn() -> (Value, usize),
151) -> Result<(Vec<FunctionObject>, usize), OpenAIError> {
152    let (func_json, total_tokens) = func();
153
154    let mut chat_completion_functions_vec = Vec::new();
155
156    let values = match func_json {
157        Value::Object(_) => vec![func_json],
158        Value::Array(arr) => arr,
159        _ => {
160            return Err(OpenAIError::InvalidArgument(String::from(
161                "Something went wrong parsing the json",
162            )))
163        }
164    };
165
166    for value in values {
167        let parameters = value.get("parameters").cloned();
168
169        let description = value
170            .get("description")
171            .and_then(|v| v.as_str())
172            .map(|s| s.to_string());
173
174        let name = value.get("name").unwrap().as_str().unwrap().to_string();
175        let chat_completion_args = match description {
176            Some(desc) => FunctionObjectArgs::default()
177                .name(name)
178                .description(desc)
179                .parameters(parameters)
180                .build()?,
181            None => FunctionObjectArgs::default()
182                .name(name)
183                .parameters(parameters)
184                .build()?,
185        };
186        chat_completion_functions_vec.push(chat_completion_args);
187    }
188
189    Ok((chat_completion_functions_vec, total_tokens))
190}
191
192/// A function to get the chat completion arguments for a tool.
193///
194/// # Arguments
195///
196/// * `tool_func` - A function that returns a JSON representation of a tool and the count of tokens in the representation.
197///
198/// # Returns
199///
200/// * A `Result` which is `Ok` if the tool chat completion arguments were successfully obtained, and `Err` otherwise.
201///   The `Ok` variant contains a tuple where the first element is a `ChatCompletionTool` representing the chat completion arguments for the tool,
202///   and the second element is a `usize` representing the total count of tokens in the tool's JSON representation.
203pub fn get_tool_chat_completion_args(
204    tool_func: impl Fn() -> (Value, usize),
205) -> Result<(Vec<ChatCompletionTool>, usize), OpenAIError> {
206    let (tool_json, total_tokens) = tool_func();
207
208    let mut chat_completion_tool_vec = Vec::new();
209
210    let values = match tool_json {
211        Value::Object(_) => vec![tool_json],
212        Value::Array(arr) => arr,
213        _ => {
214            return Err(OpenAIError::InvalidArgument(String::from(
215                "Something went wrong parsing the json",
216            )))
217        }
218    };
219
220    for value in values {
221        let parameters = value.get("parameters").cloned();
222
223        let description = value
224            .get("description")
225            .and_then(|v| v.as_str())
226            .map(|s| s.to_string());
227
228        let name = value.get("name").unwrap().as_str().unwrap().to_string();
229
230        if name != "GPT" {
231            let chat_completion_functions_args = match description {
232                Some(desc) => FunctionObjectArgs::default()
233                    .name(name)
234                    .description(desc)
235                    .parameters(parameters)
236                    .build()?,
237                None => FunctionObjectArgs::default()
238                    .name(name)
239                    .parameters(parameters)
240                    .build()?,
241            };
242
243            let chat_completion_tool = ChatCompletionToolArgs::default()
244                .r#type(ChatCompletionToolType::Function)
245                .function(chat_completion_functions_args)
246                .build()?;
247
248            chat_completion_tool_vec.push(chat_completion_tool);
249        }
250    }
251
252    Ok((chat_completion_tool_vec, total_tokens))
253}
254
255/// This function will get called if an "allowed_functions" argument is passed to the
256/// run function. If it is passed, then the presense or absence of the function_filtering
257/// feature flag will dictate what happens. If function_filtering is on, then the required
258/// functions (if some) will get included, then your ranked functions will get added until the
259/// token limit is reached. Without function_filtering feature enabled, all functions listed in
260/// allowed_func_names and required_func_names will get sent.
261
262/// Performs selective inclusion of tools based on the provided `allowed_func_names` and the state
263/// of the `function_filtering` feature flag. When `function_filtering` is enabled and `required_func_names`
264/// is specified, required functions are prioritized, followed by ranked functions until a token limit is reached.
265/// Without the `function_filtering` feature, all functions in `allowed_func_names` and `required_func_names`
266/// are included, irrespective of the token limit.
267///
268/// # Arguments
269/// - `tool_func`: A function that takes allowed and required function names, and returns tool JSON and total token count.
270/// - `allowed_func_names`: A list of function names allowed for inclusion.
271/// - `required_func_names`: An optional list of function names required for inclusion.
272///
273/// # Returns
274/// A result containing a vector of `ChatCompletionTool` objects and the total token count, or an `OpenAIError` on failure.
275///
276/// # Errors
277/// Returns an `OpenAIError::InvalidArgument` if there's an issue parsing the tool JSON.
278pub fn get_tools_limited(
279    tool_func: impl Fn(Vec<String>, Option<Vec<String>>) -> (Value, usize),
280    allowed_func_names: Vec<String>,
281    required_func_names: Option<Vec<String>>,
282) -> Result<(Vec<ChatCompletionTool>, usize), OpenAIError> {
283    let (tool_json, total_tokens) = tool_func(allowed_func_names, required_func_names);
284
285    let mut chat_completion_tool_vec = Vec::new();
286
287    let values = match tool_json {
288        Value::Object(_) => vec![tool_json],
289        Value::Array(arr) => arr,
290        _ => {
291            return Err(OpenAIError::InvalidArgument(String::from(
292                "Something went wrong parsing the json",
293            )))
294        }
295    };
296
297    for value in values {
298        let parameters = value.get("parameters").cloned();
299
300        let description = value
301            .get("description")
302            .and_then(|v| v.as_str())
303            .map(|s| s.to_string());
304
305        let name = value.get("name").unwrap().as_str().unwrap().to_string();
306
307        if name != "GPT" {
308            let chat_completion_functions_args = match description {
309                Some(desc) => FunctionObjectArgs::default()
310                    .name(name)
311                    .description(desc)
312                    .parameters(parameters)
313                    .build()?,
314                None => FunctionObjectArgs::default()
315                    .name(name)
316                    .parameters(parameters)
317                    .build()?,
318            };
319
320            let chat_completion_tool = ChatCompletionToolArgs::default()
321                .r#type(ChatCompletionToolType::Function)
322                .function(chat_completion_functions_args)
323                .build()?;
324
325            chat_completion_tool_vec.push(chat_completion_tool);
326        }
327    }
328
329    Ok((chat_completion_tool_vec, total_tokens))
330}