use proc_macro::{TokenStream, TokenTree};
use quote::{format_ident, quote, ToTokens};
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, Ident, Lit, Meta};

#[cfg(any(
    feature = "compile_embeddings_all",
    feature = "compile_embeddings_update"
))]
use async_openai::{types::CreateEmbeddingRequestArgs, Client};

#[cfg(any(
    feature = "compile_embeddings_all",
    feature = "compile_embeddings_update"
))]
use std::io::Write;

#[cfg(feature = "compile_embeddings_update")]
use std::io::Read;

#[cfg(feature = "compile_embeddings_update")]
use std::path::Path;

#[cfg(feature = "compile_embeddings_update")]
use rkyv::Deserialize;

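/// Marker attribute consumed by the `EnumDescriptor` derive below. On its own it
/// is a pass-through; the derive re-parses the attribute to extract the
/// `description` value. A minimal usage sketch (the enum and description text
/// are illustrative):
///
/// ```ignore
/// #[derive(EnumDescriptor)]
/// #[arg_description(description = "The unit the temperature should be returned in.")]
/// enum TemperatureUnit {
///     Celsius,
///     Fahrenheit,
/// }
/// ```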
#[proc_macro_attribute]
pub fn arg_description(_args: TokenStream, input: TokenStream) -> TokenStream {
    input
}

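/// Derives `openai_func_enums::EnumDescriptor`: the enum's name, the text from
/// its `arg_description` attribute, and precomputed token counts for both. A
/// sketch of calling the generated methods, assuming the `TemperatureUnit` enum
/// from the example above:
///
/// ```ignore
/// let (name, name_tokens) = *TemperatureUnit::name_with_token_count();
/// let (desc, desc_tokens) = *TemperatureUnit::arg_description_with_token_count();
/// assert_eq!(name, "TemperatureUnit");
/// ```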
#[proc_macro_derive(EnumDescriptor, attributes(arg_description))]
pub fn enum_descriptor_derive(input: TokenStream) -> TokenStream {
    let DeriveInput { ident, attrs, .. } = parse_macro_input!(input as DeriveInput);

    let name_str = ident.to_string();
    let name_token_count = calculate_token_count(&name_str);

    let mut description: &'static str = "";
    let mut desc_tokens = 0_usize;

    for attr in &attrs {
        if attr.path().is_ident("arg_description") {
            let result = attr.parse_nested_meta(|meta| {
                let content = meta.input;

                if !content.is_empty() {
                    if meta.path.is_ident("description") {
                        let value = meta.value()?;
                        if let Ok(Lit::Str(value)) = value.parse() {
                            // Leak the string so the generated impl can return a
                            // &'static str without allocating at call time.
                            description = Box::leak(value.value().into_boxed_str());
                            desc_tokens = calculate_token_count(description);
                        }
                    }
                    return Ok(());
                }

                Err(meta.error("expected `description = \"...\"`"))
            });

            if result.is_err() {
                println!("Error parsing attribute: {:#?}", result);
            }
        }
    }

    let expanded = quote! {
        impl openai_func_enums::EnumDescriptor for #ident {
            fn name_with_token_count() -> &'static (&'static str, usize) {
                static NAME_DATA: (&'static str, usize) = (stringify!(#ident), #name_token_count);
                &NAME_DATA
            }

            fn arg_description_with_token_count() -> &'static (&'static str, usize) {
                static DESC_DATA: (&'static str, usize) = (#description, #desc_tokens);
                &DESC_DATA
            }
        }
    };

    TokenStream::from(expanded)
}

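/// Derives `VariantDescriptors`, which exposes an enum's variant names along
/// with per-variant token counts, their total, and a fixed per-name overhead
/// allowance. A sketch, again assuming the illustrative `TemperatureUnit` enum:
///
/// ```ignore
/// let (names, counts, total, overhead) =
///     *TemperatureUnit::variant_names_with_token_counts();
/// assert_eq!(names, &["Celsius", "Fahrenheit"]);
/// ```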
#[proc_macro_derive(VariantDescriptors)]
pub fn variant_descriptors_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);

    let enum_name = &ast.ident;

    let variants = if let syn::Data::Enum(ref e) = ast.data {
        e.variants
            .iter()
            .map(|v| {
                let variant_name = &v.ident;
                let token_count = calculate_token_count(&variant_name.to_string());

                (variant_name, token_count)
            })
            .collect::<Vec<_>>()
    } else {
        panic!("VariantDescriptors can only be used with enums");
    };

    let variant_name_with_token_count: Vec<_> = variants
        .iter()
        .map(|(variant_name, token_count)| {
            quote! { Self::#variant_name => (stringify!(#variant_name), #token_count) }
        })
        .collect();

    let variant_names: Vec<_> = variants
        .iter()
        .map(|(variant_name, _)| quote! { stringify!(#variant_name) })
        .collect();

    // Rough allowance of three extra tokens per variant name, presumably for the
    // JSON punctuation (quotes and comma) that surrounds each name in the schema.
    let variant_name_additional_tokens = variant_names.len() * 3;

    let token_counts: Vec<_> = variants
        .iter()
        .map(|(_, token_count)| quote! { #token_count })
        .collect();

    let total_token_count = variants
        .iter()
        .map(|(_, token_count)| *token_count)
        .sum::<usize>();

    let expanded = quote! {
        impl VariantDescriptors for #enum_name {
            fn variant_names_with_token_counts() -> &'static (&'static [&'static str], &'static [usize], usize, usize) {
                static VARIANT_DATA: (&'static [&'static str], &'static [usize], usize, usize) = (
                    &[#(#variant_names),*],
                    &[#(#token_counts),*],
                    #total_token_count,
                    #variant_name_additional_tokens
                );
                &VARIANT_DATA
            }

            fn variant_name_with_token_count(&self) -> (&'static str, usize) {
                match self {
                    #(#variant_name_with_token_count,)*
                }
            }
        }
    };

    TokenStream::from(expanded)
}

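/// Function-like macro that expands to a block evaluating to
/// `(serde_json::Value, usize)`: the JSON schema fragment describing an enum
/// argument, plus an estimated token count for it. A sketch of the result's
/// shape for an illustrative enum:
///
/// ```ignore
/// let (json, tokens) = generate_enum_info!(TemperatureUnit);
/// // json ~ { "TemperatureUnit": { "type": "string",
/// //                               "enum": ["Celsius", "Fahrenheit"],
/// //                               "description": "..." } }
/// ```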
#[proc_macro]
pub fn generate_enum_info(input: TokenStream) -> TokenStream {
    let enum_ident = parse_macro_input!(input as Ident);

    let output = quote! {
        {
            let arg_desc_and_tokens: &'static (&'static str, usize) = <#enum_ident as openai_func_enums::EnumDescriptor>::arg_description_with_token_count();
            let enum_name_and_tokens: &'static (&'static str, usize) = <#enum_ident as openai_func_enums::EnumDescriptor>::name_with_token_count();
            let enum_variants_info: &'static (&'static [&'static str], &'static [usize], usize, usize) = <#enum_ident as openai_func_enums::VariantDescriptors>::variant_names_with_token_counts();

            // The fixed offsets (6, 6, 7, 7) approximate the JSON scaffolding
            // ("type": "string", "enum", "description", braces, and quotes)
            // that surrounds the payload tokens counted above.
            let token_count = 6 + arg_desc_and_tokens.1 + 6 + enum_name_and_tokens.1 + 7 + 7 + enum_variants_info.2 + enum_variants_info.3;

            let json_enum: Value = serde_json::json!({
                enum_name_and_tokens.0: {
                    "type": "string",
                    "enum": enum_variants_info.0.to_vec(),
                    "description": arg_desc_and_tokens.0,
                }
            });

            (json_enum, token_count)
        }
    };

    output.into()
}

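/// Function-like macro for plain (non-enum) arguments. It expects exactly two
/// identifiers, a JSON type and an argument name, and expands to a
/// `(serde_json::Value, usize)` pair. A sketch with illustrative identifiers:
///
/// ```ignore
/// let (json, tokens) = generate_value_arg_info!(string, city_name);
/// // json ~ { "city_name": { "type": "string" } }
/// ```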
#[proc_macro]
pub fn generate_value_arg_info(input: TokenStream) -> TokenStream {
    let mut type_and_name_values = Vec::new();

    let tokens = input.into_iter().collect::<Vec<TokenTree>>();
    for token in tokens {
        if let TokenTree::Ident(ident) = &token {
            type_and_name_values.push(ident.to_string());
        }
    }

    let output = if type_and_name_values.len() == 2 {
        let name = &type_and_name_values[1];
        let type_name = &type_and_name_values[0];

        let name_tokens = calculate_token_count(name);
        let type_name_tokens = calculate_token_count(type_name);
        let mut total_tokens = name_tokens + type_name_tokens;

        let json_string = if type_name == "array" {
            // Arrays carry more JSON scaffolding ("items", nested braces),
            // hence the larger fixed token allowance.
            total_tokens += 22;
            format!(
                r#"{{"{}": {{"type": "array", "items": {{"type": "string"}}}}}}"#,
                name
            )
        } else {
            total_tokens += 11;
            format!(r#"{{"{}": {{"type": "{}"}}}}"#, name, type_name)
        };

        quote! {
            {
                static JSON_STR: &str = #json_string;
                let json_enum: serde_json::Value = serde_json::from_str(JSON_STR).unwrap();
                (json_enum, #total_tokens)
            }
        }
    } else {
        // Anything other than exactly two identifiers expands to nothing.
        quote! {}
    };

    output.into()
}

#[deprecated(since = "0.3.0", note = "Use a doc comment (`///`) instead.")]
#[proc_macro_attribute]
pub fn func_description(_args: TokenStream, input: TokenStream) -> TokenStream {
    input
}

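/// Derives the `ToolSet` machinery: a struct per variant, JSON and token-count
/// generators for each tool, a `CommandsGPT` enum, and the `CommandsGPT::run`
/// dispatch entry point. The deriving enum must include a `GPT` variant. A
/// minimal sketch (variant names, fields, and doc comments are illustrative):
///
/// ```ignore
/// #[derive(ToolSet)]
/// enum Commands {
///     /// Adds two numbers and reports the sum.
///     Add { a: f64, b: f64 },
///     /// Dispatches a natural-language prompt to the model.
///     GPT { prompt: String },
/// }
/// ```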
#[proc_macro_derive(ToolSet)]
pub fn derive_subcommand_gpt(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = input.ident;

    let data = match input.data {
        Data::Enum(data) => data,
        _ => panic!("ToolSet can only be implemented for enums"),
    };

    let mut generated_structs = Vec::new();
    let mut json_generator_functions = Vec::new();

    let mut generated_clap_gpt_enum: Vec<proc_macro2::TokenStream> = Vec::new();
    let mut generated_struct_names = Vec::new();

    #[cfg(any(
        feature = "compile_embeddings_all",
        feature = "compile_embeddings_update"
    ))]
    let rt = tokio::runtime::Runtime::new().unwrap();

    #[cfg(any(
        feature = "compile_embeddings_all",
        feature = "compile_embeddings_update",
        feature = "function_filtering"
    ))]
    let embed_path = std::env::var("FUNC_ENUMS_EMBED_PATH")
        .expect("Functionality for embeddings requires environment variable FUNC_ENUMS_EMBED_PATH to be set.");

    #[cfg(not(any(
        feature = "compile_embeddings_all",
        feature = "compile_embeddings_update",
        feature = "function_filtering"
    )))]
    let embed_path = "";

    #[cfg(any(
        feature = "compile_embeddings_all",
        feature = "compile_embeddings_update",
        feature = "function_filtering"
    ))]
    let embed_model = std::env::var("FUNC_ENUMS_EMBED_MODEL")
        .expect("Functionality for embeddings requires environment variable FUNC_ENUMS_EMBED_MODEL to be set.");

    #[cfg(not(any(
        feature = "compile_embeddings_all",
        feature = "compile_embeddings_update",
        feature = "function_filtering"
    )))]
    let embed_model = "";

    let max_response_tokens: u16 = std::env::var("FUNC_ENUMS_MAX_RESPONSE_TOKENS")
        .expect("Environment variable FUNC_ENUMS_MAX_RESPONSE_TOKENS is required. See build.rs files in the examples.")
        .parse()
        .expect("Failed to parse u16 value from FUNC_ENUMS_MAX_RESPONSE_TOKENS");

    let max_request_tokens: usize = std::env::var("FUNC_ENUMS_MAX_REQUEST_TOKENS")
        .expect("Environment variable FUNC_ENUMS_MAX_REQUEST_TOKENS is required. See build.rs files in the examples.")
        .parse()
        .expect("Failed to parse usize value from FUNC_ENUMS_MAX_REQUEST_TOKENS");

    let max_func_tokens: u16 = std::env::var("FUNC_ENUMS_MAX_FUNC_TOKENS")
        .expect("Environment variable FUNC_ENUMS_MAX_FUNC_TOKENS is required. See build.rs files in the examples.")
        .parse()
        .expect("Failed to parse u16 value from FUNC_ENUMS_MAX_FUNC_TOKENS");

    let max_single_arg_tokens: u16 = std::env::var("FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS")
        .expect("Environment variable FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS is required. See build.rs files in the examples.")
        .parse()
        .expect("Failed to parse u16 value from FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS");

    #[cfg(any(
        feature = "compile_embeddings_all",
        feature = "compile_embeddings_update"
    ))]
    let mut embeddings: Vec<openai_func_embeddings::FuncEmbedding> = Vec::new();

    #[cfg(feature = "compile_embeddings_update")]
    {
        if Path::new(&embed_path).exists() {
            let mut file = std::fs::File::open(&embed_path).unwrap();
            let mut bytes = Vec::new();
            file.read_to_end(&mut bytes).unwrap();
            let archived_data =
                rkyv::check_archived_root::<Vec<openai_func_embeddings::FuncEmbedding>>(&bytes)
                    .unwrap();
            embeddings = archived_data.deserialize(&mut rkyv::Infallible).unwrap();
        }
    }

    let mut has_gpt_variant = false;
    let gpt_variant_name = "GPT";
    for variant in data.variants.iter() {
        let variant_name = &variant.ident;
        if variant_name.to_string() == gpt_variant_name {
            has_gpt_variant = true;
        }

        let struct_name = format_ident!("{}", variant_name);
        let struct_name_tokens = calculate_token_count(struct_name.to_string().as_str());
        generated_struct_names.push(struct_name.clone());
        let mut variant_desc = String::new();
        let mut variant_desc_tokens = 0_usize;

        for variant_attrs in &variant.attrs {
            let description = get_comment_from_attr(variant_attrs);
            if let Some(description) = description {
                variant_desc = description;
                variant_desc_tokens = calculate_token_count(variant_desc.as_str());

                #[cfg(feature = "compile_embeddings_all")]
                {
                    println!("Writing embeddings");
                    let mut name_and_desc = variant_name.to_string();
                    name_and_desc.push(':');
                    name_and_desc.push_str(&variant_desc);

                    rt.block_on(async {
                        let embedding = get_single_embedding(&name_and_desc, &embed_model).await;
                        if let Ok(embedding) = embedding {
                            let data = openai_func_embeddings::FuncEmbedding {
                                name: variant_name.to_string(),
                                description: variant_desc.clone(),
                                embedding,
                            };

                            embeddings.push(data);
                        }
                    });
                }

                #[cfg(feature = "compile_embeddings_update")]
                {
                    let mut name_and_desc = variant_name.to_string();
                    name_and_desc.push(':');
                    name_and_desc.push_str(&variant_desc);

                    rt.block_on(async {
                        // Look up this function by variant name; refresh its
                        // stored embedding only when the description changed.
                        let existing = embeddings
                            .iter_mut()
                            .find(|x| x.name == variant_name.to_string());

                        if let Some(existing) = existing {
                            if existing.description != variant_desc {
                                let embedding =
                                    get_single_embedding(&name_and_desc, &embed_model).await;

                                if let Ok(embedding) = embedding {
                                    existing.description = variant_desc.clone();
                                    existing.embedding = embedding;
                                }
                            }
                        } else {
                            let embedding =
                                get_single_embedding(&name_and_desc, &embed_model).await;
                            if let Ok(embedding) = embedding {
                                let data = openai_func_embeddings::FuncEmbedding {
                                    name: variant_name.to_string(),
                                    description: variant_desc.clone(),
                                    embedding,
                                };

                                embeddings.push(data);
                            }
                        }
                    });
                }
            }
        }

        // Note: this rewrites the embeddings file on every variant; the final
        // iteration leaves the complete, accumulated set on disk.
        #[cfg(any(
            feature = "compile_embeddings_all",
            feature = "compile_embeddings_update"
        ))]
        {
            let serialized_data = rkyv::to_bytes::<_, 256>(&embeddings).unwrap();
            let mut file = std::fs::File::create(&embed_path).unwrap();
            file.write_all(&serialized_data).unwrap();
        }

        let fields: Vec<_> = variant
            .fields
            .iter()
            .map(|f| {
                let field_name = if let Some(ident) = &f.ident {
                    format_ident!("{}", ident)
                } else {
                    // Tuple-style fields have no name; fall back to a snake_case
                    // rendering of the field's type.
                    format_ident!("{}", to_snake_case(&f.ty.to_token_stream().to_string()))
                };
                let field_type = &f.ty;
                quote! {
                    pub #field_name: #field_type,
                }
            })
            .collect();

        let execute_command_parameters: Vec<_> = variant
            .fields
            .iter()
            .map(|field| {
                let field_name = &field.ident;
                quote! { #field_name: self.#field_name.clone() }
            })
            .collect();

        let number_type = "number";
        let number_ident = format_ident!("{}", number_type);
        let integer_type = "integer";
        let integer_ident = format_ident!("{}", integer_type);
        let string_type = "string";
        let string_ident = format_ident!("{}", string_type);
        let array_type = "array";
        let array_ident = format_ident!("{}", array_type);

        // Map each field's Rust type to the schema generator it needs:
        // primitives go through generate_value_arg_info; any other path type is
        // assumed to be an enum deriving EnumDescriptor/VariantDescriptors.
        let field_info: Vec<_> = variant
            .fields
            .iter()
            .map(|f| {
                let field_name = if let Some(ident) = &f.ident {
                    format_ident!("{}", ident)
                } else {
                    format_ident!("{}", to_snake_case(&f.ty.to_token_stream().to_string()))
                };
                let field_type = &f.ty;

                match field_type {
                    syn::Type::Path(typepath) if typepath.qself.is_none() => {
                        let type_ident = &typepath.path.segments.last().unwrap().ident;

                        match type_ident.to_string().as_str() {
                            "f32" | "f64" => {
                                return quote! {
                                    generate_value_arg_info!(#number_ident, #field_name)
                                };
                            }
                            "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16"
                            | "i32" | "i64" | "i128" | "isize" => {
                                return quote! {
                                    generate_value_arg_info!(#integer_ident, #field_name)
                                };
                            }
                            "String" | "&str" => {
                                return quote! {
                                    generate_value_arg_info!(#string_ident, #field_name)
                                };
                            }
                            "Vec" => {
                                return quote! {
                                    generate_value_arg_info!(#array_ident, #field_name)
                                };
                            }
                            _ => {
                                return quote! {
                                    openai_func_enums::generate_enum_info!(#field_type)
                                };
                            }
                        }
                    }
                    syn::Type::Tuple(_) => {
                        println!("Field {} is of tuple type", field_name);
                    }
                    syn::Type::Array(_) => {
                        println!("Field {} is of array type", field_name);
                        return quote! {
                            generate_value_arg_info!(#array_ident, #field_name)
                        };
                    }
                    _ => {
                        println!("Field {} is of another type.", field_name);
                    }
                }
                // Unsupported types contribute no schema fragment.
                quote! {}
            })
            .collect();

        json_generator_functions.push(quote! {
            impl #struct_name {
                pub fn name() -> String {
                    stringify!(#struct_name).to_string()
                }

                pub fn to_function_call() -> ChatCompletionFunctionCall {
                    ChatCompletionFunctionCall::Function {
                        name: stringify!(#struct_name).to_string(),
                    }
                }

                pub fn to_tool_choice() -> ChatCompletionToolChoiceOption {
                    ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice {
                        r#type: ChatCompletionToolType::Function,
                        function: FunctionName { name: stringify!(#struct_name).to_string() }
                    })
                }

                pub fn execute_command(&self) -> #name {
                    #name::#variant_name {
                        #(#execute_command_parameters),*
                    }
                }

                pub fn get_function_json() -> (serde_json::Value, usize) {
                    let mut parameters = serde_json::Map::new();
                    let mut total_tokens = 0;

                    for (arg_json, arg_tokens) in vec![#(#field_info),*] {
                        total_tokens += arg_tokens;
                        // Small fixed allowance per argument for JSON punctuation.
                        total_tokens += 3;

                        parameters.insert(
                            arg_json.as_object().unwrap().keys().next().unwrap().clone(),
                            arg_json
                                .as_object()
                                .unwrap()
                                .values()
                                .next()
                                .unwrap()
                                .clone(),
                        );
                    }

                    let function_json = serde_json::json!({
                        "name": stringify!(#struct_name),
                        "description": #variant_desc,
                        "parameters": {
                            "type": "object",
                            "properties": parameters,
                            "required": parameters.keys().collect::<Vec<_>>()
                        }
                    });

                    // Fixed allowance for the function-object scaffolding
                    // ("name", "description", "parameters", "required", braces).
                    total_tokens += 43;
                    total_tokens += #struct_name_tokens;
                    total_tokens += #variant_desc_tokens;

                    (function_json, total_tokens)
                }
            }
        });

        generated_structs.push(quote! {
            #[derive(Clone, serde::Deserialize, Debug)]
            pub struct #struct_name {
                #(#fields)*
            }
        });
    }

    if !has_gpt_variant {
        panic!("Enums that derive ToolSet must define a variant called 'GPT'.")
    }

    let all_function_calls = quote! {
        pub fn all_function_jsons() -> (serde_json::Value, usize) {
            let results = vec![#(#generated_struct_names::get_function_json(),)*];
            let combined_json = serde_json::Value::Array(results.iter().map(|(json, _)| json.clone()).collect());
            let total_tokens = results.iter().map(|(_, tokens)| tokens).sum();
            (combined_json, total_tokens)
        }

        pub fn function_jsons_under_limit(ranked_func_names: Vec<String>) -> (serde_json::Value, usize) {
            let results = vec![#(#generated_struct_names::get_function_json(),)*];

            // Walk the functions in ranked order, keeping each one that still
            // fits under the token budget.
            let limit = #max_func_tokens as usize;
            let (functions_to_present, total_tokens) = ranked_func_names.iter()
                .filter_map(|name| results.iter().find(|(json, _)| json["name"] == *name))
                .fold((vec![], 0_usize), |(mut acc, token_count), (json, tokens)| {
                    if token_count + tokens <= limit {
                        acc.push((json.clone(), tokens));
                        (acc, token_count + tokens)
                    } else {
                        (acc, token_count)
                    }
                });

            let combined_json = serde_json::Value::Array(functions_to_present.iter().map(|(json, _)| json.clone()).collect());
            (combined_json, total_tokens)
        }

        pub fn function_jsons_allowed_with_required(
            allowed_func_names: Vec<String>,
            required_func_names: Option<Vec<String>>
        ) -> (serde_json::Value, usize) {
            let results = vec![#(#generated_struct_names::get_function_json(),)*];
            let required_func_names = required_func_names.unwrap_or_default();

            // Required functions come first, then any allowed ones not already present.
            let updated_func_names = required_func_names.iter()
                .chain(allowed_func_names.iter().filter(|name| !required_func_names.contains(name)))
                .cloned()
                .collect::<Vec<String>>();

            let (functions_to_present, total_tokens) = updated_func_names.iter()
                .filter_map(|name| results.iter().find(|(json, _)| json["name"] == *name))
                .fold((vec![], 0_usize), |(mut acc, token_count), (json, tokens)| {
                    acc.push((json.clone(), tokens));
                    (acc, token_count + tokens)
                });

            let combined_json = serde_json::Value::Array(functions_to_present.iter().map(|(json, _)| json.clone()).collect());
            (combined_json, total_tokens)
        }

        pub fn function_jsons_with_required_under_limit(
            ranked_func_names: Vec<String>,
            required_func_names: Option<Vec<String>>
        ) -> (serde_json::Value, usize) {
            let results = vec![#(#generated_struct_names::get_function_json(),)*];
            let required_func_names = required_func_names.unwrap_or_default();

            let updated_func_names = required_func_names.iter()
                .chain(ranked_func_names.iter().filter(|name| !required_func_names.contains(name)))
                .cloned()
                .collect::<Vec<String>>();

            let limit = #max_func_tokens as usize;

            let (functions_to_present, total_tokens) = updated_func_names.iter()
                .filter_map(|name| results.iter().find(|(json, _)| json["name"] == *name))
                .fold((vec![], 0_usize), |(mut acc, token_count), (json, tokens)| {
                    if token_count + tokens <= limit {
                        acc.push((json.clone(), tokens));
                        (acc, token_count + tokens)
                    } else {
                        (acc, token_count)
                    }
                });

            let combined_json = serde_json::Value::Array(functions_to_present.iter().map(|(json, _)| json.clone()).collect());
            (combined_json, total_tokens)
        }
    };

    generated_clap_gpt_enum.push(quote! {
        pub enum CommandsGPT {
            GPT { a: String },
        }
    });

    let struct_names: Vec<String> = generated_struct_names
        .iter()
        .map(|name| name.to_string())
        .collect();

    let match_arms: Vec<_> = generated_struct_names
        .iter()
        .map(|struct_name| {
            let response_name = format_ident!("{}", struct_name);

            quote! {
                Ok(FunctionResponse::#response_name(response)) => {
                    let result = response.execute_command();
                    let command_clone = command.clone();
                    let custom_system_message_clone = custom_system_message.clone();
                    let logger_clone = logger.clone();
                    let command_lock = command_clone.lock().await;
                    let command_inner_value = command_lock.as_ref().cloned();
                    drop(command_lock);

                    let run_result = result.run(execution_strategy_clone, command_inner_value, logger_clone, custom_system_message_clone).await;
                    match run_result {
                        Ok(run_result) => {
                            {
                                let prior_result_clone = prior_result.clone();
                                let mut prior_result_lock = prior_result_clone.lock().await;
                                *prior_result_lock = run_result.0;

                                let command_clone = command.clone();
                                let mut command_lock = command_clone.lock().await;
                                *command_lock = run_result.1;
                            }
                            return Ok(());
                        }
                        Err(e) => {
                            println!("{:#?}", e);
                        }
                    }
                }
            }
        })
        .collect();

    let match_arms_no_return: Vec<_> = generated_struct_names
        .iter()
        .map(|struct_name| {
            let response_name = format_ident!("{}", struct_name);

            quote! {
                Ok(FunctionResponse::#response_name(response)) => {
                    let result = response.execute_command();
                    let run_result = result.run(execution_strategy_clone, None, logger_clone, custom_system_message_clone).await;
                    match run_result {
                        Ok(run_result) => {
                            {
                                let mut prior_result_lock = prior_result_clone.lock().await;
                                *prior_result_lock = run_result.0;

                                let mut command_lock = command_clone.lock().await;
                                *command_lock = run_result.1;
                            }
                        }
                        Err(e) => {
                            println!("{:#?}", e);
                        }
                    }
                }
            }
        })
        .collect();

    #[cfg(feature = "function_filtering")]
    let filtering_delegate = quote! {
        openai_func_enums::get_tools_limited(CommandsGPT::function_jsons_with_required_under_limit, allowed_functions, required_functions)?
    };

    #[cfg(not(feature = "function_filtering"))]
    let filtering_delegate = quote! {
        openai_func_enums::get_tools_limited(CommandsGPT::function_jsons_allowed_with_required, allowed_functions, required_functions)?
    };

    let commands_gpt_impl = quote! {
        #[derive(Clone, Debug, serde::Deserialize)]
        pub enum FunctionResponse {
            #(
                #generated_struct_names(#generated_struct_names),
            )*
        }

        impl CommandsGPT {
            #all_function_calls

            fn to_snake_case(camel_case: &str) -> String {
                let mut snake_case = String::new();
                for (i, ch) in camel_case.char_indices() {
                    if i > 0 && ch.is_uppercase() {
                        snake_case.push('_');
                    }
                    snake_case.extend(ch.to_lowercase());
                }
                snake_case
            }

            pub fn parse_gpt_function_call(function_call: &FunctionCall) -> Result<FunctionResponse, Box<dyn std::error::Error + Send + Sync + 'static>> {
                match function_call.name.as_str() {
                    #(
                        #struct_names => {
                            match serde_json::from_str::<#generated_struct_names>(&function_call.arguments) {
                                Ok(arguments) => Ok(FunctionResponse::#generated_struct_names(arguments)),
                                Err(_) => {
                                    // Deserialization can fail when the model returns
                                    // camelCase keys. Best-effort repair: rewrite each
                                    // top-level key to snake_case and try again.
                                    let snake_case_args = function_call.arguments
                                        .as_str()
                                        .split(',')
                                        .map(|s| {
                                            let mut parts = s.split(':');
                                            match (parts.next(), parts.next()) {
                                                (Some(key), Some(value)) => {
                                                    let key_trimmed = key.trim_matches(|c: char| !c.is_alphanumeric()).trim();
                                                    let key_snake_case = Self::to_snake_case(key_trimmed);
                                                    format!("\"{}\":{}", key_snake_case, value)
                                                },
                                                _ => s.to_string()
                                            }
                                        })
                                        .collect::<Vec<String>>()
                                        .join(",");

                                    // Re-add the opening brace stripped by the key trimming.
                                    let snake_case_args = format!("{{{}", snake_case_args);

                                    match serde_json::from_str::<#generated_struct_names>(&snake_case_args) {
                                        Ok(arguments) => {
                                            Ok(FunctionResponse::#generated_struct_names(arguments))
                                        }
                                        Err(_) => {
                                            Err(Box::new(openai_func_enums::CommandError::new("There was an issue deserializing function arguments.")))
                                        }
                                    }
                                }
                            }
                        },
                    )*
                    _ => {
                        println!("{:#?}", function_call);
                        Err(Box::new(openai_func_enums::CommandError::new("Unknown function name")))
                    }
                }
            }

            fn calculate_token_count(text: &str) -> usize {
                let bpe = tiktoken_rs::cl100k_base().unwrap();
                bpe.encode_ordinary(text).len()
            }

            #[allow(clippy::too_many_arguments)]
            pub async fn run(
                prompt: &String,
                model_name: &str,
                request_token_limit: Option<usize>,
                max_response_tokens: Option<u16>,
                custom_system_message: Option<(String, usize)>,
                prior_result: std::sync::Arc<tokio::sync::Mutex<Option<String>>>,
                execution_strategy: ToolCallExecutionStrategy,
                command: std::sync::Arc<tokio::sync::Mutex<Option<Vec<String>>>>,
                allowed_functions: Option<Vec<String>>,
                required_functions: Option<Vec<String>>,
                logger: std::sync::Arc<openai_func_enums::Logger>,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
                let tool_args: (Vec<async_openai::types::ChatCompletionTool>, usize) = if let Some(allowed_functions) = allowed_functions {
                    if !allowed_functions.is_empty() {
                        #filtering_delegate
                    } else {
                        get_tool_chat_completion_args(CommandsGPT::all_function_jsons)?
                    }
                } else {
                    get_tool_chat_completion_args(CommandsGPT::all_function_jsons)?
                };

                let custom_system_message_clone = custom_system_message.clone();
                let (this_system_message, system_message_tokens) = match custom_system_message_clone {
                    Some((message, tokens)) => (message, tokens),
                    None => (String::from("You are a helpful function calling bot."), 7)
                };

                let word_count = prompt.split_whitespace().count();

                // For short prompts, estimate with the usual rule of thumb (one
                // token is roughly 0.75 words); otherwise pay for an exact count.
                let request_token_total = tool_args.1 + system_message_tokens + if word_count < 200 {
                    (word_count as f64 / 0.75).round() as usize
                } else {
                    Self::calculate_token_count(prompt.as_str())
                };

                if request_token_total > request_token_limit.unwrap_or(FUNC_ENUMS_MAX_REQUEST_TOKENS) {
                    return Err(Box::new(openai_func_enums::CommandError::new("Request token count is too high")));
                }

                let this_system_message_clone = this_system_message.clone();

                let request = CreateChatCompletionRequestArgs::default()
                    .max_tokens(max_response_tokens.unwrap_or(FUNC_ENUMS_MAX_RESPONSE_TOKENS))
                    .model(model_name)
                    .temperature(0.0)
                    .messages([
                        ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessageArgs::default()
                            .content(this_system_message_clone)
                            .build()?),
                        ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessageArgs::default()
                            .content(prompt.to_string())
                            .build()?),
                    ])
                    .tools(tool_args.0)
                    .tool_choice("auto")
                    .build()?;

                let client = Client::new();
                let response_message = client
                    .chat()
                    .create(request)
                    .await?
                    .choices
                    .first()
                    .unwrap()
                    .message
                    .clone();

                if let Some(tool_calls) = response_message.tool_calls {
                    if tool_calls.len() == 1 {
                        let execution_strategy_clone = execution_strategy.clone();
                        let custom_system_message_clone = custom_system_message.clone();

                        match Self::parse_gpt_function_call(&tool_calls.first().unwrap().function) {
                            #(#match_arms,)*
                            Err(e) => {
                                println!("{:#?}", e);
                                return Err(Box::new(openai_func_enums::CommandError::new("Error running GPT command")));
                            }
                        };
                    } else {
                        match execution_strategy {
                            ToolCallExecutionStrategy::Async => {
                                let mut tasks = Vec::new();

                                for tool_call in tool_calls.iter() {
                                    match tool_call.r#type {
                                        ChatCompletionToolType::Function => {
                                            let function = tool_call.function.clone();
                                            let prior_result_clone = prior_result.clone();
                                            let command_clone = command.clone();
                                            let execution_strategy_clone = execution_strategy.clone();
                                            let logger_clone = logger.clone();
                                            let custom_system_message_clone = custom_system_message.clone();

                                            let task = tokio::spawn(async move {
                                                match Self::parse_gpt_function_call(&function) {
                                                    #(#match_arms_no_return,)*
                                                    Err(e) => {
                                                        println!("{:#?}", e);
                                                    }
                                                }
                                            });
                                            tasks.push(task);
                                        },
                                    }
                                }

                                for task in tasks {
                                    let _ = task.await;
                                }
                            },
                            ToolCallExecutionStrategy::Synchronous => {
                                for tool_call in tool_calls.iter() {
                                    match tool_call.r#type {
                                        ChatCompletionToolType::Function => {
                                            let prior_result_clone = prior_result.clone();
                                            let command_clone = command.clone();
                                            let execution_strategy_clone = execution_strategy.clone();
                                            let logger_clone = logger.clone();
                                            let custom_system_message_clone = custom_system_message.clone();

                                            match Self::parse_gpt_function_call(&tool_call.function) {
                                                #(#match_arms_no_return,)*
                                                Err(e) => {
                                                    println!("{:#?}", e);
                                                }
                                            }
                                        },
                                    }
                                }
                            },
                            ToolCallExecutionStrategy::Parallel => {
                                let mut handles = Vec::new();

                                for tool_call in tool_calls.iter() {
                                    match tool_call.r#type {
                                        ChatCompletionToolType::Function => {
                                            let function = tool_call.function.clone();
                                            let prior_result_clone = prior_result.clone();
                                            let command_clone = command.clone();

                                            // Nested tool calls made from these worker
                                            // threads fall back to the Async strategy
                                            // rather than spawning yet more threads.
                                            let execution_strategy_clone = ToolCallExecutionStrategy::Async;
                                            let logger_clone = logger.clone();
                                            let custom_system_message_clone = custom_system_message.clone();

                                            let handle = std::thread::spawn(move || {
                                                let rt = tokio::runtime::Runtime::new().unwrap();
                                                rt.block_on(async {
                                                    match Self::parse_gpt_function_call(&function) {
                                                        #(#match_arms_no_return,)*
                                                        Err(e) => {
                                                            println!("{:#?}", e);
                                                        }
                                                    }
                                                })
                                            });
                                            handles.push(handle);
                                        },
                                    }
                                }

                                for handle in handles {
                                    let _ = handle.join();
                                }
                            },
                        }
                    }
                    Ok(())
                } else {
                    Ok(())
                }
            }
        }
    };

    let embedding_imports = quote! {
        #[cfg(any(
            feature = "compile_embeddings_all",
            feature = "compile_embeddings_update",
            feature = "function_filtering"
        ))]
        use openai_func_enums::FuncEnumsError;

        pub const FUNC_ENUMS_EMBED_PATH: &str = #embed_path;

        pub const FUNC_ENUMS_EMBED_MODEL: &str = #embed_model;
    };

    let gen = quote! {
        pub const FUNC_ENUMS_MAX_RESPONSE_TOKENS: u16 = #max_response_tokens;
        pub const FUNC_ENUMS_MAX_REQUEST_TOKENS: usize = #max_request_tokens;
        pub const FUNC_ENUMS_MAX_FUNC_TOKENS: u16 = #max_func_tokens;
        pub const FUNC_ENUMS_MAX_SINGLE_ARG_TOKENS: u16 = #max_single_arg_tokens;

        use serde::Deserialize;
        use serde_json::{json, Value};

        use openai_func_enums::{
            generate_value_arg_info, get_tool_chat_completion_args,
            ArchivedFuncEmbedding,
        };

        use rkyv::{archived_root, Archived};
        use rkyv::vec::ArchivedVec;

        use async_trait::async_trait;
        use async_openai::{
            types::{
                ChatCompletionFunctionCall, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage,
                ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
                ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequestArgs,
                CreateEmbeddingRequestArgs, FunctionCall, FunctionName,
            },
            Client,
        };
        use tokio::sync::mpsc;

        #embedding_imports

        #(#generated_structs)*

        #(#json_generator_functions)*

        #(#generated_clap_gpt_enum)*

        #commands_gpt_impl
    };

    gen.into()
}

fn get_comment_from_attr(attr: &Attribute) -> Option<String> {
    if attr.path().is_ident("doc") {
        if let Meta::NameValue(meta) = &attr.meta {
            if meta.path.is_ident("doc") {
                if let Expr::Lit(expr_lit) = &meta.value {
                    if let Lit::Str(lit_str) = &expr_lit.lit {
                        return Some(lit_str.value());
                    }
                }
            }
        }
    }
    None
}

fn calculate_token_count(text: &str) -> usize {
    let bpe = tiktoken_rs::cl100k_base().unwrap();
    bpe.encode_ordinary(text).len()
}

fn to_snake_case(camel_case: &str) -> String {
    let mut snake_case = String::new();
    for (i, ch) in camel_case.char_indices() {
        if i > 0 && ch.is_uppercase() {
            snake_case.push('_');
        }
        snake_case.extend(ch.to_lowercase());
    }
    snake_case
}

#[cfg(any(
    feature = "compile_embeddings_all",
    feature = "compile_embeddings_update"
))]
async fn get_single_embedding(
    text: &str,
    model: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let client = Client::new();
    let request = CreateEmbeddingRequestArgs::default()
        .model(model)
        .input([text])
        .build()?;

    let response = client.embeddings().create(request).await?;

    match response.data.first() {
        Some(data) => Ok(data.embedding.to_owned()),
        None => {
            let embedding_error = openai_func_embeddings::FuncEnumsError::OpenAIError(
                String::from("Didn't get embedding vector back."),
            );
            Err(Box::new(embedding_error))
        }
    }
}