openai_func_enums_macros/
lib.rs

1use proc_macro::{TokenStream, TokenTree};
2use quote::{format_ident, quote, ToTokens};
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, Ident, Lit, Meta};
4
5#[cfg(any(
6    feature = "compile_embeddings_all",
7    feature = "compile_embeddings_update"
8))]
9use async_openai::{types::CreateEmbeddingRequestArgs, Client};
10
11#[cfg(any(
12    feature = "compile_embeddings_all",
13    feature = "compile_embeddings_update"
14))]
15use std::io::Write;
16
17/// The `arg_description` attribute is a procedural macro used to provide additional description for an enum.
18///
19/// This attribute does not modify the code it annotates but instead attaches metadata in the form of a description.
20/// This isn't required. It will be passed as a description if present, which can result in better
21/// argument selection.
22///
23/// # Usage
24///
25/// ```rust
26/// #[arg_description(description = "This is a sample enum.", tokens = 5)]
27/// #[derive(EnumDescriptor)]
28/// pub enum SampleEnum {
29///     Variant1,
30///     Variant2,
31/// }
32/// ```
33///
34/// Note: The actual usage of the description and tokens provided through this attribute happens
35/// in the `EnumDescriptor` derive macro and is retrieved in the `enum_descriptor_derive` function.
36///
37/// The `arg_description` attribute takes one argument, `description`, which is a string literal.
38#[proc_macro_attribute]
39pub fn arg_description(_args: TokenStream, input: TokenStream) -> TokenStream {
40    input
41}
42
43/// A derive procedural macro for the `EnumDescriptor` trait.
44///
45/// The `EnumDescriptor` trait should have a function `name_with_token_count`
46/// that returns a tuple with the name of the enum type as a string and the
47/// token count for the name as an `usize`.
48///
49/// This procedural macro generates an implementation of `EnumDescriptor` for
50/// the type on which it's applied. The `name_with_token_count` function, in the
51/// generated implementation, returns the name of the type and its token count.
52///
53/// # Usage
54///
55/// Use the `#[derive(EnumDescriptor)]` attribute on an enum to derive the
56/// `EnumDescriptor` trait for it.
57///
58/// ```
59/// #[derive(EnumDescriptor)]
60/// enum MyEnum {
61///     Variant1,
62///     Variant2,
63/// }
64/// ```
65///
66/// This will generate:
67///
68/// ```
69/// impl EnumDescriptor for MyEnum {
70///     fn name_with_token_count() -> (String, usize) {
71///         (String::from("MyEnum"), /* token count of "MyEnum" */)
72///     }
73/// }
74/// ```
75///
76/// The actual token count is computed during compile time using the
77/// `calculate_token_count` function.
78#[proc_macro_derive(EnumDescriptor, attributes(arg_description))]
79pub fn enum_descriptor_derive(input: TokenStream) -> TokenStream {
80    let DeriveInput { ident, attrs, .. } = parse_macro_input!(input as DeriveInput);
81
82    let name_str = ident.to_string();
83    let name_token_count = calculate_token_count(&name_str);
84
85    let mut description: &'static str = "";
86    let mut desc_tokens = 0_usize;
87
88    for attr in &attrs {
89        if attr.path().is_ident("arg_description") {
90            let _result = attr.parse_nested_meta(|meta| {
91                let content = meta.input;
92
93                if !content.is_empty() {
94                    if meta.path.is_ident("description") {
95                        let value = meta.value()?;
96                        if let Ok(Lit::Str(value)) = value.parse() {
97                            description = Box::leak(value.value().into_boxed_str());
98                            desc_tokens = calculate_token_count(description);
99                        }
100                    }
101                    return Ok(());
102                }
103
104                Err(meta.error("unrecognized my_attribute"))
105            });
106
107            if _result.is_err() {
108                println!("Error parsing attribute:   {:#?}", _result);
109            }
110        }
111    }
112
113    let expanded = quote! {
114        impl openai_func_enums::EnumDescriptor for #ident {
115            fn name_with_token_count() -> &'static (&'static str, usize) {
116                static NAME_DATA: (&'static str, usize) = (stringify!(#ident), #name_token_count);
117                &NAME_DATA
118            }
119
120            fn arg_description_with_token_count() -> &'static (&'static str, usize) {
121                static DESC_DATA: (&'static str, usize) = (#description, #desc_tokens);
122                &DESC_DATA
123            }
124        }
125    };
126
127    TokenStream::from(expanded)
128}
129
130/// A derive procedural macro for the `VariantDescriptors` trait.
131///
132/// This macro generates an implementation of the `VariantDescriptors` trait for
133/// an enum. The trait provides two methods:
134///
135/// 1. `variant_names_with_token_counts`: Returns a `Vec` containing tuples,
136/// each with a string representation of a variant's name and its token count.
137///
138/// 2. `variant_name_with_token_count`: Takes an enum variant as input and
139/// returns a tuple with the variant's name as a string and its token count.
140///
141/// Note: This macro will panic if it is used on anything other than an enum.
142///
143/// # Usage
144///
145/// ```
146/// #[derive(VariantDescriptors)]
147/// enum MyEnum {
148///     Variant1,
149///     Variant2,
150/// }
151/// ```
152///
153/// This will generate the following:
154///
155/// ```
156/// impl VariantDescriptors for MyEnum {
157///     fn variant_names_with_token_counts() -> Vec<(String, usize)> {
158///         vec![
159///             (String::from("Variant1"), /* token count of "Variant1" */),
160///             (String::from("Variant2"), /* token count of "Variant2" */),
161///         ]
162///     }
163///
164///     fn variant_name_with_token_count(&self) -> (String, usize) {
165///         match self {
166///             Self::Variant1 => (String::from("Variant1"), /* token count of "Variant1" */),
167///             Self::Variant2 => (String::from("Variant2"), /* token count of "Variant2" */),
168///         }
169///     }
170/// }
171/// ```
172///
173/// The actual token count is computed during compile time using the
174/// `calculate_token_count` function.
175#[proc_macro_derive(VariantDescriptors)]
176pub fn variant_descriptors_derive(input: TokenStream) -> TokenStream {
177    let ast = parse_macro_input!(input as DeriveInput);
178
179    let enum_name = &ast.ident;
180
181    let variants = if let syn::Data::Enum(ref e) = ast.data {
182        e.variants
183            .iter()
184            .map(|v| {
185                let variant_name = &v.ident;
186                let token_count = calculate_token_count(&variant_name.to_string());
187
188                (variant_name, token_count)
189            })
190            .collect::<Vec<_>>()
191    } else {
192        panic!("VariantDescriptors can only be used with enums");
193    };
194
195    let variant_name_with_token_count: Vec<_> = variants
196        .iter()
197        .map(|(variant_name, token_count)| {
198            quote! { Self::#variant_name => (stringify!(#variant_name), #token_count) }
199        })
200        .collect();
201
202    let variant_names: Vec<_> = variants
203        .iter()
204        .map(|(variant_name, _)| quote! { stringify!(#variant_name) })
205        .collect();
206
207    let variant_name_additional_tokens = variant_names.len() * 3;
208
209    let token_counts: Vec<_> = variants
210        .iter()
211        .map(|(_, token_count)| quote! { #token_count })
212        .collect();
213
214    let total_token_count = variants
215        .iter()
216        .map(|(_, token_count)| *token_count)
217        .sum::<usize>();
218
219    let expanded = quote! {
220        impl VariantDescriptors for #enum_name {
221            fn variant_names_with_token_counts() -> &'static (&'static [&'static str], &'static [usize], usize, usize) {
222                static VARIANT_DATA: (&'static [&'static str], &'static [usize], usize, usize) = (
223                    &[#(#variant_names),*],
224                    &[#(#token_counts),*],
225                    #total_token_count,
226                    #variant_name_additional_tokens
227                );
228                &VARIANT_DATA
229            }
230
231            fn variant_name_with_token_count(&self) -> (&'static str, usize) {
232                match self {
233                    #(#variant_name_with_token_count,)*
234                }
235            }
236        }
237    };
238
239    TokenStream::from(expanded)
240}
241
242/// A procedural macro to generate JSON information about an enum, including its name,
243/// variant names, and descriptions, along with a total token count.
244///
245/// This macro leverages the `EnumDescriptor` and `VariantDescriptors` traits to extract
246/// details about an enum. It compiles these details into a JSON format and calculates
247/// an associated token count based on the structure of the generated JSON. The token count
248/// is an estimation of how many tokens are needed to represent the enum information in a
249/// serialized format, considering the syntax and spacing of JSON.
250///
251/// The macro returns a tuple containing the generated JSON object and the estimated total
252/// token count.
253///
254/// # Usage
255///
256/// When applied to an enum, the macro generates code similar to the following example:
257///
258/// ```rust
259/// {
260///     use serde_json::Value;
261///     let mut token_count = 0;
262///
263///     // Description and token count for the enum's argument (if applicable)
264///     let (arg_desc, arg_tokens) = <MyEnum as EnumDescriptor>::arg_description_with_token_count();
265///     token_count += 6; // Base tokens for argument declaration
266///     token_count += arg_tokens; // Tokens for the argument description
267///
268///     // Enum name and its token count
269///     let enum_name = <MyEnum as EnumDescriptor>::name_with_token_count();
270///     token_count += 6; // Base tokens for enum name declaration
271///     token_count += enum_name.1; // Tokens for the enum name
272///
273///     // Base tokens for enum and type declarations
274///     token_count += 7; // Enum declaration
275///     token_count += 7; // Type declaration
276///
277///     // Variant names and their token counts
278///     let enum_variants = <MyEnum as VariantDescriptors>::variant_names_with_token_counts();
279///     // Adding 3 tokens for each variant for proper JSON formatting
280///     token_count += enum_variants.iter().map(|(_, token_count_i)| *token_count_i + 3).sum::<usize>();
281///
282///     // Constructing the JSON object with enum details
283///     let json_enum = serde_json::json!({
284///         enum_name.0: {
285///             "type": "string",
286///             "enum": enum_variants.iter().map(|(name, _)| name.clone()).collect::<Vec<_>>(),
287///             "description": arg_desc,
288///         }
289///     });
290///
291///     (json_enum, token_count)
292/// }
293/// ```
294///
295/// ## Token Count Estimation Details
296///
297/// The estimation of tokens for the generated JSON includes:
298/// - **Base tokens for argument and enum declarations**: A fixed count to account for the JSON structure around the enum and its arguments.
299/// - **Dynamic tokens for the enum name and argument descriptions**: Calculated based on the length and structure of the enum name and argument descriptions.
300/// - **Tokens for each enum variant**: Includes a fixed addition for JSON formatting alongside the variant names.
301///
302/// This approach ensures a precise estimation of the token count required to represent the enum information in JSON, facilitating accurate serialization.
303///
304/// Note: The enum must implement the `EnumDescriptor` and `VariantDescriptors` traits for the macro to function correctly. The actual token count is computed at compile time using these traits' methods.
305#[proc_macro]
306pub fn generate_enum_info(input: TokenStream) -> TokenStream {
307    let enum_ident = parse_macro_input!(input as Ident);
308
309    let output = quote! {
310        {
311            let ARG_DESC_AND_TOKENS: &'static (&'static str, usize) = <#enum_ident as openai_func_enums::EnumDescriptor>::arg_description_with_token_count();
312            let ENUM_NAME_AND_TOKENS: &'static (&'static str, usize) = <#enum_ident as openai_func_enums::EnumDescriptor>::name_with_token_count();
313            let ENUM_VARIANTS_INFO: &'static (&'static [&'static str], &'static [usize], usize, usize) = <#enum_ident as openai_func_enums::VariantDescriptors>::variant_names_with_token_counts();
314
315            let token_count = 6 + ARG_DESC_AND_TOKENS.1 + 6 + ENUM_NAME_AND_TOKENS.1 + 7 + 7 + ENUM_VARIANTS_INFO.2 + ENUM_VARIANTS_INFO.3;
316
317            let json_enum: Value = serde_json::json!({
318                ENUM_NAME_AND_TOKENS.0: {
319                    "type": "string",
320                    "enum": ENUM_VARIANTS_INFO.0.iter().map(|name| *name).collect::<Vec<_>>(),
321                    "description": ARG_DESC_AND_TOKENS.0,
322                }
323            });
324
325            (json_enum, token_count)
326        }
327    };
328
329    output.into()
330}
331
332#[proc_macro]
333pub fn generate_value_arg_info(input: TokenStream) -> TokenStream {
334    let mut type_and_name_values = Vec::new();
335
336    let tokens = input.into_iter().collect::<Vec<TokenTree>>();
337    for token in tokens {
338        if let TokenTree::Ident(ident) = &token {
339            type_and_name_values.push(ident.to_string());
340        }
341    }
342
343    let output = if type_and_name_values.len() == 2 {
344        let name = &type_and_name_values[1];
345        let type_name = &type_and_name_values[0];
346
347        let name_tokens = calculate_token_count(name);
348        let type_name_tokens = calculate_token_count(type_name);
349        let mut total_tokens = name_tokens + type_name_tokens;
350
351        let json_string = if type_name == "array" {
352            total_tokens += 22;
353            format!(
354                r#"{{"{}": {{"type": "array", "items": {{"type": "string"}}}}}}"#,
355                name
356            )
357        } else {
358            total_tokens += 11;
359            format!(r#"{{"{}": {{"type": "{}"}}}}"#, name, type_name)
360        };
361
362        quote! {
363            {
364                static JSON_STR: &str = #json_string;
365                let json_enum: serde_json::Value = serde_json::from_str(JSON_STR).unwrap();
366                (json_enum, #total_tokens)
367            }
368        }
369    } else {
370        quote! {}
371    };
372
373    output.into()
374}
375
376/// This procedural macro attribute is used to specify a description for an enum variant.
377///
378/// The `func_description` attribute does not modify the input it is given.
379/// It's only used to attach metadata (i.e., a description) to enum variants.
380///
381/// # Usage
382///
383/// ```rust
384/// enum MyEnum {
385///     #[func_description(description="This function does a thing.")]
386///     DoAThing,
387///     #[func_description(description="This function does another thing.")]
388///     DoAnotherThing,
389/// }
390/// ```
391///
392/// Note: The actual usage of the description provided through this attribute happens
393/// in the `FunctionCallResponse` derive macro and is retrieved in the `impl_function_call_response` function.
394#[deprecated(since = "0.3.0", note = "Use a doc string --> '///'.")]
395#[proc_macro_attribute]
396pub fn func_description(_args: TokenStream, input: TokenStream) -> TokenStream {
397    input
398}
399
400/// The `ToolSet` procedural macro is used to derive a structure
401/// which encapsulates various chat completion commands.
402///
403/// This macro should be applied to an enum. It generates various supporting
404/// structures and methods, including structures representing the command arguments,
405/// methods for converting between the argument structures and the original enum,
406/// JSON conversion methods, and an implementation of the original enum that provides
407/// methods for executing the commands and dealing with the responses.
408///
409/// Each variant of the original enum will be converted into a corresponding structure,
410/// and each field in the variant will become a field in the generated structure.
411/// The generated structures will derive `serde::Deserialize` and `Debug` automatically.
412///
413/// This macro also generates methods for calculating the token count of a string and
414/// for executing commands based on function calls received from the chat API.
415///
416/// The types of fields in the enum variants determine how the corresponding fields in the
417/// generated structures are treated. For example, fields of type `String` or `&str` are
418/// converted to JSON value arguments with type `"string"`, while fields of type `u8`, `u16`,
419/// `u32`, `u64`, `usize`, `i8`, `i16`, `i32`, `i64`, `isize`, `f32` or `f64` are converted
420/// to JSON value arguments with type `"integer"` or `"number"` respectively.
421/// For fields with a tuple type, currently this macro simply prints that the field is of a tuple type.
422/// For fields with an array type, they are converted to JSON value arguments with type `"array"`.
423///
424/// When running the chat command, a custom system message can be optionally provided.
425/// If provided, this message will be used as the system message in the chat request.
426/// If not provided, a default system message will be used.
427///
428/// If the total token count of the request exceeds a specified limit, an error will be returned.
429///
430/// The `derive_subcommand_gpt` function consumes a `TokenStream` representing the enum
431/// to which the macro is applied and produces a `TokenStream` representing the generated code.
432///
433/// # Panics
434/// This macro will panic (only at compile time) if it is applied to a non-enum item.
435#[proc_macro_derive(ToolSet)]
436pub fn derive_subcommand_gpt(input: TokenStream) -> TokenStream {
437    let input = parse_macro_input!(input as DeriveInput);
438
439    let name = input.ident;
440
441    let data = match input.data {
442        Data::Enum(data) => data,
443        _ => panic!("ToolSet can only be implemented for enums"),
444    };
445
446    let mut generated_structs = Vec::new();
447    let mut json_generator_functions = Vec::new();
448
449    let mut generated_clap_gpt_enum: Vec<proc_macro2::TokenStream> = Vec::new();
450    let mut generated_struct_names = Vec::new();
451
452    #[cfg(any(
453        feature = "compile_embeddings_all",
454        feature = "compile_embeddings_update"
455    ))]
456    let rt = tokio::runtime::Runtime::new().unwrap();
457
458    #[cfg(any(
459        feature = "compile_embeddings_all",
460        feature = "compile_embeddings_update",
461        feature = "function_filtering"
462    ))]
463    let embed_path = std::env::var("FUNC_ENUMS_EMBED_PATH")
464        .expect("Functionality for embeddings requires environment variable FUNC_ENUMS_EMBED_PATH to be set.");
465
466    #[cfg(not(any(
467        feature = "compile_embeddings_all",
468        feature = "compile_embeddings_update",
469        feature = "function_filtering"
470    )))]
471    let embed_path = "";
472
473    #[cfg(any(
474        feature = "compile_embeddings_all",
475        feature = "compile_embeddings_update",
476        feature = "function_filtering"
477    ))]
478    let embed_model = std::env::var("FUNC_ENUMS_EMBED_MODEL")
479        .expect("Functionality for embeddings requires environment variable FUNC_ENUMS_EMBED_MODEL to be set.");
480
481    // We can set const values, we can have feature flags, but
482    // the compiler will not allow us to maybe set a const behind
483    // a feature flag.
484    #[cfg(not(any(
485        feature = "compile_embeddings_all",
486        feature = "compile_embeddings_update",
487        feature = "function_filtering"
488    )))]
489    let embed_model = "";
490
491    let max_response_tokens: u16 = std::env::var("FUNC_ENUMS_MAX_RESPONSE_TOKENS")
492        .expect("Environment variable FUNC_ENUMS_MAX_RESPONSE_TOKENS is required. See build.rs files in the examples.")
493        .parse()
494        .expect("Failed to parse u16 value from FUNC_ENUMS_MAX_RESPONSE_TOKENS");
495
496    let max_request_tokens: usize = std::env::var("FUNC_ENUMS_MAX_REQUEST_TOKENS")
497        .expect("Environment variable FUNC_ENUMS_MAX_REQUEST_TOKENS is required. See build.rs files in the examples.")
498        .parse()
499        .expect("Failed to parse usize value from FUNC_ENUMS_MAX_REQUEST_TOKENS");
500
501    let max_func_tokens: u16 =std::env::var("FUNC_ENUMS_MAX_FUNC_TOKENS")
502        .expect("Environment variable FUNC_ENUMS_MAX_FUNC_TOKENS is required. See build.rs files in the examples.")
503        .parse()
504        .expect("Failed to parse u16 value from FUNC_ENUMS_MAX_FUNC_TOKENS");
505
506    let max_single_arg_tokens: u16 = std::env::var("FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS") 
507        .expect("Environment variable FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS is required. See build.rs files in the examples.")
508        .parse()
509        .expect("Failed to parse u16 value from FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS");
510
511    #[cfg(any(
512        feature = "compile_embeddings_all",
513        feature = "compile_embeddings_update"
514    ))]
515    let mut embeddings: Vec<openai_func_embeddings::FuncEmbedding> = Vec::new();
516
517    #[cfg(feature = "compile_embeddings_update")]
518    {
519        if Path::new(&embed_path).exists() {
520            let mut file = std::fs::File::open(&embed_path).unwrap();
521            let mut bytes = Vec::new();
522            file.read_to_end(&mut bytes).unwrap();
523            let archived_data = rkyv::check_archived_root::<Vec<FuncEmbedding>>(&bytes).unwrap();
524            embeddings = archived_data.deserialize(&mut rkyv::Infallible).unwrap();
525        }
526    }
527
528    let mut has_gpt_variant = false;
529    // TODO: make this setable:
530    let gpt_variant_name = "GPT";
531    for variant in data.variants.iter() {
532        let variant_name = &variant.ident;
533        if variant_name.to_string() == gpt_variant_name {
534            has_gpt_variant = true;
535        }
536
537        let struct_name = format_ident!("{}", variant_name);
538        let struct_name_tokens = calculate_token_count(struct_name.to_string().as_str());
539        generated_struct_names.push(struct_name.clone());
540        let mut variant_desc = String::new();
541        let mut variant_desc_tokens = 0_usize;
542
543        for variant_attrs in &variant.attrs {
544            let description = get_comment_from_attr(variant_attrs);
545            if let Some(description) = description {
546                variant_desc = description;
547                variant_desc_tokens = calculate_token_count(variant_desc.as_str());
548
549                // TODO: Do a default, show a helpful error message, do something, you will forget
550                #[cfg(feature = "compile_embeddings_all")]
551                {
552                    println!("Writing embeddings");
553                    let mut name_and_desc = variant_name.to_string();
554                    name_and_desc.push(':');
555                    name_and_desc.push_str(&variant_desc);
556
557                    rt.block_on(async {
558                        let embedding = get_single_embedding(&name_and_desc, &embed_model).await;
559                        if let Ok(embedding) = embedding {
560                            let data = openai_func_embeddings::FuncEmbedding {
561                                name: variant_name.to_string(),
562                                description: variant_desc.clone(),
563                                embedding,
564                            };
565
566                            embeddings.push(data);
567                        }
568                    });
569                }
570
571                #[cfg(feature = "compile_embeddings_update")]
572                {
573                    let mut name_and_desc = variant_name.to_string();
574                    name_and_desc.push(':');
575                    name_and_desc.push_str(&variant_desc);
576
577                    rt.block_on(async {
578                        let mut existing = embeddings.iter().find(|x| x.name == name);
579
580                        if let Some(existing) = existing {
581                            if existing.description != variant_desc {
582                                let embedding =
583                                    get_single_embedding(&name_and_desc, &embed_model).await;
584
585                                if let Ok(embedding) = embedding {
586                                    existing.description = variant_desc.clone();
587                                    existing.embedding = embedding;
588                                }
589                            }
590                        } else {
591                            let embedding =
592                                get_single_embedding(&name_and_desc, &embed_model).await;
593                            if let Ok(embedding) = embedding {
594                                let data = FuncEmbedding {
595                                    name: variant_name.to_string(),
596                                    description: variant_desc.clone(),
597                                    embedding,
598                                };
599
600                                embeddings.push(data);
601                            }
602                        }
603                    });
604                }
605            }
606        }
607
608        #[cfg(any(
609            feature = "compile_embeddings_all",
610            feature = "compile_embeddings_update"
611        ))]
612        {
613            let serialized_data = rkyv::to_bytes::<_, 256>(&embeddings).unwrap();
614            let mut file = std::fs::File::create(&embed_path).unwrap();
615            file.write_all(&serialized_data).unwrap();
616        }
617
618        let fields: Vec<_> = variant
619            .fields
620            .iter()
621            .map(|f| {
622                // If the field has an identifier (i.e., it is a named field),
623                // use it. Otherwise, use the type as the name.
624                let field_name = if let Some(ident) = &f.ident {
625                    format_ident!("{}", ident)
626                } else {
627                    format_ident!("{}", to_snake_case(&f.ty.to_token_stream().to_string()))
628                };
629                let field_type = &f.ty;
630                quote! {
631                    pub #field_name: #field_type,
632                }
633            })
634            .collect();
635
636        let execute_command_parameters: Vec<_> = variant
637            .fields
638            .iter()
639            .map(|field| {
640                let field_name = &field.ident;
641                quote! { #field_name: self.#field_name.clone() }
642            })
643            .collect();
644
645        let number_type = "number";
646        let number_ident = format_ident!("{}", number_type);
647        let integer_type = "integer";
648        let integer_ident = format_ident!("{}", integer_type);
649        let string_type = "string";
650        let string_ident = format_ident!("{}", string_type);
651        let array_type = "array";
652        let array_ident = format_ident!("{}", array_type);
653
654        let field_info: Vec<_> = variant
655            .fields
656            .iter()
657            .map(|f| {
658                let field_name = if let Some(ident) = &f.ident {
659                    format_ident!("{}", ident)
660                } else {
661                    format_ident!("{}", to_snake_case(&f.ty.to_token_stream().to_string()))
662                };
663                let field_type = &f.ty;
664
665                match field_type {
666                    syn::Type::Path(typepath) if typepath.qself.is_none() => {
667                        let type_ident = &typepath.path.segments.last().unwrap().ident;
668
669                        match type_ident.to_string().as_str() {
670                            "f32" | "f64" => {
671                                return quote! {
672                                    generate_value_arg_info!(#number_ident, #field_name)
673                                };
674                            }
675                            "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16"
676                            | "i32" | "i64" | "i128" | "isize" => {
677                                return quote! {
678                                    generate_value_arg_info!(#integer_ident, #field_name)
679                                };
680                            }
681                            "String" | "&str" => {
682                                return quote! {
683                                    generate_value_arg_info!(#string_ident, #field_name)
684                                };
685                            }
686                            "Vec" => {
687                                return quote! {
688                                    generate_value_arg_info!(#array_ident, #field_name)
689                                };
690                            }
691                            _ => {
692                                return quote! {
693                                    openai_func_enums::generate_enum_info!(#field_type)
694                                };
695                            }
696                        }
697                    }
698                    syn::Type::Tuple(_) => {
699                        println!("Field {} is of tuple type", field_name);
700                    }
701                    syn::Type::Array(_) => {
702                        println!("Field {} is of array type", field_name);
703                        return quote! {
704                            generate_value_arg_info!(#array_ident, #field_name)
705                        };
706                    }
707                    _ => {
708                        println!("Field {} is of another type.", field_name);
709                    }
710                }
711                quote! {}
712            })
713            .collect();
714
715        json_generator_functions.push(quote! {
716            impl #struct_name {
717                pub fn name() -> String {
718                    stringify!(#struct_name).to_string()
719                }
720
721                pub fn to_function_call() -> ChatCompletionFunctionCall {
722                    ChatCompletionFunctionCall::Function {
723                        name: stringify!(#struct_name).to_string(),
724                    }
725                }
726
727                pub fn to_tool_choice() -> ChatCompletionToolChoiceOption {
728                    ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice {
729                        r#type: ChatCompletionToolType::Function,
730                        function: FunctionName { name: stringify!(#struct_name).to_string() }
731                    })
732                }
733
734                pub fn execute_command(&self) -> #name {
735                    #name::#variant_name {
736                        #(#execute_command_parameters),*
737                    }
738                }
739
740                // Bake this in. Can be much faster.
741                pub fn get_function_json() -> (serde_json::Value, usize) {
742                    let mut parameters = serde_json::Map::new();
743                    let mut total_tokens = 0;
744
745                    for (arg_json, arg_tokens) in vec![#(#field_info),*] {
746                        total_tokens += arg_tokens;
747                        total_tokens += 3;
748
749                        parameters.insert(
750                            arg_json.as_object().unwrap().keys().next().unwrap().clone(),
751                            arg_json
752                                .as_object()
753                                .unwrap()
754                                .values()
755                                .next()
756                                .unwrap()
757                                .clone(),
758                        );
759                    }
760
761                    let function_json = serde_json::json!({
762                        "name": stringify!(#struct_name),
763                        "description": #variant_desc,
764                        "parameters": {
765                            "type": "object",
766                            "properties": parameters,
767                            "required": parameters.keys().collect::<Vec<_>>()
768                        }
769                    });
770
771                    total_tokens += 43;
772                    total_tokens += #struct_name_tokens;
773                    total_tokens += #variant_desc_tokens;
774
775                    (function_json, total_tokens)
776                }
777            }
778        });
779
780        generated_structs.push(quote! {
781            #[derive(Clone, serde::Deserialize, Debug)]
782            pub struct #struct_name {
783                #(#fields)*
784            }
785        });
786    }
787
788    if !has_gpt_variant {
789        panic!("Enums that derive ToolSet must define a variant called 'GPT'.")
790    }
791
792    let all_function_calls = quote! {
793        pub fn all_function_jsons() -> (serde_json::Value, usize) {
794            let results = vec![#(#generated_struct_names::get_function_json(),)*];
795            let combined_json = serde_json::Value::Array(results.iter().map(|(json, _)| json.clone()).collect());
796            let total_tokens = results.iter().map(|(_, tokens)| tokens).sum();
797            (combined_json, total_tokens)
798        }
799
800        pub fn function_jsons_under_limit(ranked_func_names: Vec<String>) -> (serde_json::Value, usize) {
801            let results = vec![#(#generated_struct_names::get_function_json(),)*];
802
803            let limit = #max_func_tokens as usize;
804            let (functions_to_present, total_tokens) = results.into_iter().fold(
805                (vec![], 0_usize),
806                |(mut acc, token_count), (json, tokens)| {
807                    if token_count + tokens <= limit {
808                        acc.push((json.clone(), tokens));
809                        (acc, token_count + tokens)
810                    } else {
811                        (acc, token_count)
812                    }
813                },
814            );
815
816            let combined_json = serde_json::Value::Array(functions_to_present.iter().map(|(json, _)| json.clone()).collect());
817            (combined_json, total_tokens)
818        }
819
820        pub fn function_jsons_allowed_with_required(
821            allowed_func_names: Vec<String>,
822            required_func_names: Option<Vec<String>>
823        ) -> (serde_json::Value, usize) {
824            let results = vec![#(#generated_struct_names::get_function_json(),)*];
825            let required_func_names = required_func_names.unwrap_or_default();
826
827            // Take the vector of what has to be there just for it to function and add the ranked
828            // functions to it, skipping ranked ones if it is already in the required list.
829            let updated_func_names = required_func_names.iter()
830                .chain(allowed_func_names.iter().filter(|name| !required_func_names.contains(name)))
831                .cloned()
832                .collect::<Vec<String>>();
833
834            let (functions_to_present, total_tokens) = updated_func_names.iter()
835                .filter_map(|name| results.iter().find(|(json, _)| json["name"] == *name))
836                .fold((vec![], 0_usize), |(mut acc, token_count), (json, tokens)| {
837                    acc.push((json.clone(), tokens));
838                    (acc, token_count + tokens)
839                });
840
841            let combined_json = serde_json::Value::Array(functions_to_present.iter().map(|(json, _)| json.clone()).collect());
842            (combined_json, total_tokens)
843        }
844
845        pub fn function_jsons_with_required_under_limit(
846            ranked_func_names: Vec<String>,
847            required_func_names: Option<Vec<String>>
848        ) -> (serde_json::Value, usize) {
849            let results = vec![#(#generated_struct_names::get_function_json(),)*];
850            let required_func_names = required_func_names.unwrap_or_default();
851
852            // Take the vector of what has to be there just for it to function and add the ranked
853            // functions to it, skipping ranked ones if it is already in the required list.
854            let updated_func_names = required_func_names.iter()
855                .chain(ranked_func_names.iter().filter(|name| !required_func_names.contains(name)))
856                .cloned()
857                .collect::<Vec<String>>();
858
859            let limit = #max_func_tokens as usize;
860
861            let (functions_to_present, total_tokens) = updated_func_names.iter()
862                .filter_map(|name| results.iter().find(|(json, _)| json["name"] == *name))
863                .fold((vec![], 0_usize), |(mut acc, token_count), (json, tokens)| {
864                    if token_count + tokens <= limit {
865                        acc.push((json.clone(), tokens));
866                        (acc, token_count + tokens)
867                    } else {
868                        (acc, token_count)
869                    }
870                });
871
872            let combined_json = serde_json::Value::Array(functions_to_present.iter().map(|(json, _)| json.clone()).collect());
873            (combined_json, total_tokens)
874        }
875    };
876
877    {
878        generated_clap_gpt_enum.push(quote! {
879            pub enum CommandsGPT {
880                GPT { a: String },
881            }
882        });
883    }
884
885    let struct_names: Vec<String> = generated_struct_names
886        .iter()
887        .map(|name| format!("{}", name))
888        .collect();
889
890    let match_arms: Vec<_> = generated_struct_names
891        .iter()
892        .map(|struct_name| {
893            let response_name = format_ident!("{}", struct_name);
894
895            quote! {
896                Ok(FunctionResponse::#response_name(response)) => {
897                    let result = response.execute_command();
898                    let command_clone = command.clone();
899                    let custom_system_message_clone = custom_system_message.clone();
900                    let logger_clone = logger.clone();
901                    let command_lock = command_clone.lock().await;
902                    let command_inner_value = command_lock.as_ref().cloned();
903                    drop(command_lock);
904
905                    let run_result = result.run(execution_strategy_clone, command_inner_value, logger_clone, custom_system_message_clone).await;
906                    match run_result {
907                        Ok(run_result) => {
908                            {
909                                let prior_result_clone = prior_result.clone();
910                                let mut prior_result_lock = prior_result_clone.lock().await;
911                                *prior_result_lock = run_result.0;
912
913                                let command_clone = command.clone();
914                                let mut command_lock = command_clone.lock().await;
915                                *command_lock = run_result.1;
916
917                                let custom_system_message_clone = custom_system_message.clone();
918                            }
919                            return Ok(());
920                        }
921                        Err(e) => {
922                            println!("{:#?}", e);
923                        }
924                    }
925                }
926            }
927        })
928        .collect();
929
930    // TODO: reload this shit into your head.
931    let match_arms_no_return: Vec<_> = generated_struct_names
932        .iter()
933        .map(|struct_name| {
934            let response_name = format_ident!("{}", struct_name);
935
936            quote! {
937                Ok(FunctionResponse::#response_name(response)) => {
938                    let result = response.execute_command();
939                    let run_result = result.run(execution_strategy_clone, None, logger_clone, custom_system_message_clone).await;
940                    match run_result {
941                        Ok(run_result) => {
942                            {
943                                // Feels like this is a dead lock.
944                                // Update: isn't.
945                                let mut prior_result_lock = prior_result_clone.lock().await;
946                                *prior_result_lock = run_result.0;
947
948                                let mut command_lock = command_clone.lock().await;
949                                *command_lock = run_result.1;
950                            }
951                        }
952                        Err(e) => {
953                            println!("{:#?}", e);
954                        }
955                    }
956                }
957            }
958        })
959        .collect();
960
961    #[cfg(feature = "function_filtering")]
962    let filtering_delegate = quote! {
963        openai_func_enums::get_tools_limited(CommandsGPT::function_jsons_with_required_under_limit, allowed_functions, required_functions)?
964    };
965
966    #[cfg(not(feature = "function_filtering"))]
967    let filtering_delegate = quote! {
968        openai_func_enums::get_tools_limited(CommandsGPT::function_jsons_allowed_with_required, allowed_functions, required_functions)?
969    };
970
971    let commands_gpt_impl = quote! {
972        #[derive(Clone, Debug, serde::Deserialize)]
973        pub enum FunctionResponse {
974            #(
975                #generated_struct_names(#generated_struct_names),
976            )*
977        }
978
979        impl CommandsGPT {
980            #all_function_calls
981
982            fn to_snake_case(camel_case: &str) -> String {
983                let mut snake_case = String::new();
984                for (i, ch) in camel_case.char_indices() {
985                    if i > 0 && ch.is_uppercase() {
986                        snake_case.push('_');
987                    }
988                    snake_case.extend(ch.to_lowercase());
989                }
990                snake_case
991            }
992
993            pub fn parse_gpt_function_call(function_call: &FunctionCall) -> Result<FunctionResponse, Box<dyn std::error::Error + Send + Sync + 'static>> {
994                match function_call.name.as_str() {
995                    #(
996                    #struct_names => {
997                        match serde_json::from_str::<#generated_struct_names>(&function_call.arguments) {
998                            Ok(arguments) => Ok(FunctionResponse::#generated_struct_names(arguments)),
999                            Err(_) => {
1000                                let snake_case_args = function_call.arguments
1001                                    .as_str()
1002                                    .split(',')
1003                                    .map(|s| {
1004                                        let mut parts = s.split(':');
1005                                        match (parts.next(), parts.next()) {
1006                                            (Some(key), Some(value)) => {
1007                                                let key_trimmed = key.trim_matches(|c: char| !c.is_alphanumeric()).trim();
1008                                                let key_snake_case = Self::to_snake_case(key_trimmed);
1009                                                format!("\"{}\":{}", key_snake_case, value)
1010                                            },
1011                                            _ => s.to_string()
1012                                        }
1013                                    })
1014                                    .collect::<Vec<String>>()
1015                                    .join(",");
1016
1017                                let snake_case_args = format!("{{{}", snake_case_args);
1018
1019                                match serde_json::from_str::<#generated_struct_names>(&snake_case_args) {
1020                                    Ok(arguments) => {
1021                                        Ok(FunctionResponse::#generated_struct_names(arguments))
1022                                    }
1023                                    Err(e) => {
1024                                        Err(Box::new(openai_func_enums::CommandError::new("There was an issue deserializing function arguments.")))
1025                                    }
1026                                }
1027                            }
1028                        }
1029                    },
1030                    )*
1031                    _ => {
1032                        println!("{:#?}", function_call);
1033                        Err(Box::new(openai_func_enums::CommandError::new("Unknown function name")))
1034                    }
1035                }
1036            }
1037
1038            fn calculate_token_count(text: &str) -> usize {
1039                let bpe = tiktoken_rs::cl100k_base().unwrap();
1040                bpe.encode_ordinary(&text).len()
1041            }
1042
1043            #[allow(clippy::too_many_arguments)]
1044            pub async fn run(
1045                prompt: &String,
1046                model_name: &str,
1047                request_token_limit: Option<usize>,
1048                max_response_tokens: Option<u16>,
1049                custom_system_message: Option<(String, usize)>,
1050                prior_result: std::sync::Arc<tokio::sync::Mutex<Option<String>>>,
1051                execution_strategy: ToolCallExecutionStrategy,
1052                command: std::sync::Arc<tokio::sync::Mutex<Option<Vec<String>>>>,
1053                allowed_functions: Option<Vec<String>>,
1054                required_functions: Option<Vec<String>>,
1055                logger: std::sync::Arc<openai_func_enums::Logger>,
1056            ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
1057
1058                let tool_args: (Vec<async_openai::types::ChatCompletionTool>, usize) = if let Some(allowed_functions) = allowed_functions {
1059                    if !allowed_functions.is_empty() {
1060                        #filtering_delegate
1061                    } else {
1062                        get_tool_chat_completion_args(CommandsGPT::all_function_jsons)?
1063                    }
1064
1065                } else {
1066                    get_tool_chat_completion_args(CommandsGPT::all_function_jsons)?
1067                };
1068
1069                let custom_system_message_clone = custom_system_message.clone();
1070                let (this_system_message, system_message_tokens) = match custom_system_message_clone {
1071                    Some((message, tokens)) => {
1072                        (message.clone(), tokens)
1073                    }
1074                    None => (String::from("You are a helpful function calling bot."), 7)
1075                };
1076
1077                let word_count = prompt.split_whitespace().count();
1078
1079                let request_token_total = tool_args.1 + system_message_tokens + if word_count < 200 {
1080                    ((word_count as f64 / 0.75).round() as usize)
1081                } else {
1082                    Self::calculate_token_count(prompt.as_str())
1083                };
1084
1085                if request_token_total > request_token_limit.unwrap_or(FUNC_ENUMS_MAX_REQUEST_TOKENS)  {
1086                    return Err(Box::new(openai_func_enums::CommandError::new("Request token count is too high")));
1087                }
1088
1089                let this_system_message_clone = this_system_message.clone();
1090
1091                let request = CreateChatCompletionRequestArgs::default()
1092                    .max_tokens(max_response_tokens.unwrap_or(FUNC_ENUMS_MAX_RESPONSE_TOKENS))
1093                    .model(model_name)
1094                    .temperature(0.0)
1095                    .messages([ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessageArgs::default()
1096                        .content(this_system_message_clone)
1097                        .build()?),
1098                    ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessageArgs::default()
1099                        .content(prompt.to_string())
1100                        .build()?)])
1101                    .tools(tool_args.0)
1102                    .tool_choice("auto")
1103                    .build()?;
1104
1105                let client = Client::new();
1106                let response_message = client
1107                    .chat()
1108                    .create(request)
1109                    .await?
1110                    .choices
1111                    .get(0)
1112                    .unwrap()
1113                    .message
1114                    .clone();
1115
1116                if let Some(tool_calls) = response_message.tool_calls {
1117                    if tool_calls.len() == 1 {
1118                        let execution_strategy_clone = execution_strategy.clone();
1119                        let custom_system_message_clone = custom_system_message.clone();
1120
1121                        match Self::parse_gpt_function_call(&tool_calls.first().unwrap().function) {
1122                            #(#match_arms,)*
1123                            Err(e) => {
1124                                println!("{:#?}", e);
1125                                return Err(Box::new(openai_func_enums::CommandError::new("Error running GPT command")));
1126                            }
1127                        };
1128                    } else {
1129                        match execution_strategy {
1130                            ToolCallExecutionStrategy::Async => {
1131                                let mut tasks = Vec::new();
1132
1133                                let custom_system_message_clone = custom_system_message.clone();
1134                                for tool_call in tool_calls.iter() {
1135                                    match tool_call.r#type {
1136                                        ChatCompletionToolType::Function => {
1137                                            let function = tool_call.function.clone();
1138                                            let prior_result_clone = prior_result.clone();
1139                                            let command_clone = command.clone();
1140                                            let execution_strategy_clone = execution_strategy.clone();
1141                                            let logger_clone = logger.clone();
1142                                            let custom_system_message_clone = custom_system_message.clone();
1143
1144                                            let task = tokio::spawn( async move {
1145                                                match Self::parse_gpt_function_call(&function) {
1146                                                    #(#match_arms_no_return,)*
1147                                                    Err(e) => {
1148                                                        println!("{:#?}", e);
1149                                                    }
1150                                                }
1151                                            });
1152                                            tasks.push(task);
1153                                        },
1154                                    }
1155                                }
1156
1157                                for task in tasks {
1158                                    let _ = task.await;
1159                                }
1160                            },
1161                            ToolCallExecutionStrategy::Synchronous => {
1162                                for tool_call in tool_calls.iter() {
1163                                    match tool_call.r#type {
1164                                        ChatCompletionToolType::Function => {
1165                                            let prior_result_clone = prior_result.clone();
1166                                            let command_clone = command.clone();
1167                                            let execution_strategy_clone = execution_strategy.clone();
1168                                            let logger_clone = logger.clone();
1169                                            let custom_system_message_clone = custom_system_message.clone();
1170
1171                                            match Self::parse_gpt_function_call(&tool_call.function) {
1172                                                #(#match_arms_no_return,)*
1173                                                Err(e) => {
1174                                                    println!("{:#?}", e);
1175                                                }
1176                                            }
1177                                        },
1178                                    }
1179                                }
1180                            },
1181                            ToolCallExecutionStrategy::Parallel => {
1182                                let mut handles = Vec::new();
1183
1184                                for tool_call in tool_calls.iter() {
1185                                    match tool_call.r#type {
1186                                        ChatCompletionToolType::Function => {
1187                                            let function = tool_call.function.clone();
1188                                            let prior_result_clone = prior_result.clone();
1189                                            let command_clone = command.clone();
1190
1191                                            // TODO: Think through. There's a lot of overhead to
1192                                            // make os threads this way. For now assume that if
1193                                            // strategy is set to "Parallel" that we only want to
1194                                            // put the intially returned tool calls on threads, and
1195                                            // if they themselves contain something multi-step we
1196                                            // will run those as if they are io-bound. Potentially
1197                                            // makes sense to support letting variants get
1198                                            // decorated with a execution strategy preference like
1199                                            // "this is io bound" or "this is cpu bound".
1200                                            // This will rarely matter.
1201                                            let execution_strategy_clone = ToolCallExecutionStrategy::Async;
1202                                            let logger_clone = logger.clone();
1203                                            let custom_system_message_clone = custom_system_message.clone();
1204
1205                                            let handle = std::thread::spawn(move || {
1206                                                let rt = tokio::runtime::Runtime::new().unwrap();
1207                                                rt.block_on(async {
1208                                                    match Self::parse_gpt_function_call(&function) {
1209                                                        #(#match_arms_no_return,)*
1210                                                        Err(e) => {
1211                                                            println!("{:#?}", e);
1212                                                        }
1213                                                    }
1214
1215                                                })
1216                                            });
1217                                            handles.push(handle);
1218                                        },
1219                                    }
1220                                }
1221
1222                                for handle in handles {
1223                                    let _ = handle.join();
1224                                }
1225                            },
1226                        }
1227                    }
1228                    Ok(())
1229                } else {
1230                    return Ok(());
1231                }
1232            }
1233        }
1234    };
1235
1236    let embedding_imports = quote! {
1237
1238        #[cfg(any(
1239            feature = "compile_embeddings_all",
1240            feature = "compile_embeddings_update",
1241            feature = "function_filtering"
1242        ))]
1243        use openai_func_enums::FuncEnumsError;
1244
1245        pub const FUNC_ENUMS_EMBED_PATH: &str = #embed_path;
1246
1247        pub const FUNC_ENUMS_EMBED_MODEL: &str = #embed_model;
1248    };
1249
1250    let gen = quote! {
1251        pub const FUNC_ENUMS_MAX_RESPONSE_TOKENS: u16 = #max_response_tokens;
1252        pub const FUNC_ENUMS_MAX_REQUEST_TOKENS: usize = #max_request_tokens;
1253        pub const FUNC_ENUMS_MAX_FUNC_TOKENS: u16 = #max_func_tokens;
1254        pub const FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS: u16 = #max_single_arg_tokens;
1255
1256        use serde::Deserialize;
1257        use serde_json::{json, Value};
1258
1259        use openai_func_enums::{
1260            generate_value_arg_info, get_tool_chat_completion_args,
1261            ArchivedFuncEmbedding,
1262        };
1263
1264        use rkyv::{archived_root, Archived};
1265        use rkyv::vec::ArchivedVec;
1266
1267        use async_trait::async_trait;
1268        use async_openai::{
1269            types::{
1270                ChatCompletionFunctionCall, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage,
1271                ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
1272                ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequestArgs,
1273                CreateEmbeddingRequestArgs, FunctionCall, FunctionName,
1274            },
1275            Client,
1276        };
1277        use tokio::sync::{mpsc};
1278
1279        #embedding_imports
1280
1281        #(#generated_structs)*
1282
1283        #(#json_generator_functions)*
1284
1285        #(#generated_clap_gpt_enum)*
1286
1287        #commands_gpt_impl
1288    };
1289
1290    gen.into()
1291}
1292
1293fn get_comment_from_attr(attr: &Attribute) -> Option<String> {
1294    if attr.path().is_ident("doc") {
1295        if let Meta::NameValue(meta) = &attr.meta {
1296            if meta.path.is_ident("doc") {
1297                let value = meta.value.clone();
1298                match value {
1299                    Expr::Lit(value) => match value.lit {
1300                        Lit::Str(value) => {
1301                            return Some(value.value());
1302                        }
1303                        _ => {
1304                            return None;
1305                        }
1306                    },
1307                    _ => {
1308                        return None;
1309                    }
1310                }
1311            }
1312        }
1313    }
1314    None
1315}
1316
1317/// Calculate the token count of a given text string using the Byte Pair Encoding (BPE) tokenizer.
1318///
1319/// This function utilizes the BPE tokenizer from the `cl100k_base` library. It tokenizes the given text and
1320/// returns the count of the tokens. This can be used to measure how many tokens a particular text string
1321/// consumes, which is often relevant in the context of natural language processing tasks.
1322///
1323/// # Arguments
1324///
1325/// * `text` - A string slice that holds the text to tokenize.
1326///
1327/// # Returns
1328///
1329/// * `usize` - The count of tokens in the text.
1330///
1331/// # Example
1332///
1333/// ```
1334/// let text = "Hello, world!";
1335/// let token_count = calculate_token_count(text);
1336/// println!("Token count: {}", token_count);
1337/// ```
1338///
1339/// Note: This function can fail if the `cl100k_base` tokenizer is not properly initialized or the text cannot be tokenized.
1340fn calculate_token_count(text: &str) -> usize {
1341    let bpe = tiktoken_rs::cl100k_base().unwrap();
1342    bpe.encode_ordinary(text).len()
1343}
1344
/// Turn a camelCase or PascalCase identifier into its snake_case form.
///
/// Every character is lowercased. Each uppercase character other than the very first
/// character of the input is additionally preceded by an underscore, so consecutive
/// capitals each get their own underscore.
///
/// # Arguments
///
/// * `camel_case` - The camelCase or PascalCase text to convert.
///
/// # Returns
///
/// * `String` - The snake_case rendering of `camel_case`.
///
/// # Example
///
/// ```ignore
/// assert_eq!(to_snake_case("HelloWorld"), "hello_world");
/// assert_eq!(to_snake_case("ABC"), "a_b_c");
/// ```
fn to_snake_case(camel_case: &str) -> String {
    camel_case
        .chars()
        .enumerate()
        .fold(String::new(), |mut converted, (position, letter)| {
            // A capital anywhere but the start marks a word boundary.
            let starts_new_word = position != 0 && letter.is_uppercase();
            if starts_new_word {
                converted.push('_');
            }
            // Lowercasing may expand to more than one char, so append them all.
            for lowered in letter.to_lowercase() {
                converted.push(lowered);
            }
            converted
        })
}
1376
/// Requests an embedding for a single piece of text and returns its vector.
///
/// Builds a `CreateEmbeddingRequestArgs` request for `text` using `model`, sends it through
/// a freshly created `Client`, and hands back the embedding of the first entry in the
/// response data.
///
/// # Arguments
///
/// * `text` - The text to embed.
/// * `model` - The embedding model name passed to the request builder.
///
/// # Errors
///
/// Returns an error if the request cannot be built, if the embeddings call fails, or if the
/// response contains no embedding data.
#[cfg(any(
    feature = "compile_embeddings_all",
    feature = "compile_embeddings_update"
))]
async fn get_single_embedding(
    text: &String,
    model: &String,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let api_client = Client::new();
    let embedding_request = CreateEmbeddingRequestArgs::default()
        .model(model)
        .input([text])
        .build()?;

    let embedding_response = api_client
        .embeddings()
        .create(embedding_request)
        .await?;

    // Only one input was sent, so only the first returned entry is of interest.
    if let Some(first_entry) = embedding_response.data.first() {
        return Ok(first_entry.embedding.to_owned());
    }

    let missing_vector: Box<dyn std::error::Error + Send + Sync> =
        Box::new(openai_func_embeddings::FuncEnumsError::OpenAIError(
            String::from("Didn't get embedding vector back."),
        ));
    Err(missing_vector)
}