openai_proc_macros/
lib.rs

1use convert_case::{Case, Casing};
2use openai_bootstrap::{authorization, ApiResponse, BASE_URL};
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use reqwest::blocking::Client;
6use serde::Deserialize;
7
8#[derive(Deserialize)]
9struct Models {
10    data: Vec<Model>,
11}
12
13#[derive(Deserialize)]
14struct Model {
15    id: String,
16}
17
18#[proc_macro]
19pub fn generate_model_id_enum(_input: TokenStream) -> TokenStream {
20    let client = Client::new();
21    let request = client.get(BASE_URL.to_owned() + "models");
22    let api_response: ApiResponse<Models> = authorization!(request)
23        .send()
24        .unwrap_or_else(|error| panic!("{error}"))
25        .json()
26        .unwrap();
27
28    match api_response {
29        ApiResponse::Ok(models) => {
30            let mut model_id_idents = Vec::new();
31            let mut model_ids = Vec::new();
32            let mut model_indexes = Vec::new();
33            let mut index: u32 = 0;
34
35            for model in models.data {
36                if model.id.contains(':') || model.id.contains("deprecated") {
37                    continue;
38                }
39
40                model_id_idents.push(format_ident!(
41                    "{}",
42                    model.id.to_case(Case::Pascal).replace('.', "_")
43                ));
44                model_ids.push(model.id);
45                model_indexes.push(index);
46
47                index += 1;
48            }
49
50            quote! {
51                use serde::{ Serialize, de };
52
53                #[derive(Debug, PartialEq, Default, Clone)]
54                pub enum ModelID {
55                    #[default]
56                    #(#model_id_idents),*,
57                    Custom(String),
58                }
59
60                impl Serialize for ModelID {
61                    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62                    where
63                        S: serde::Serializer,
64                    {
65                        match *self {
66                            #( ModelID::#model_id_idents => serializer.serialize_unit_variant("ModelID", #model_indexes, #model_ids) ),*,
67                            ModelID::Custom(ref string) => serializer.serialize_str(string),
68                        }
69                    }
70                }
71
72                impl<'de> Deserialize<'de> for ModelID {
73                    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74                    where
75                        D: serde::Deserializer<'de>,
76                    {
77                        struct ModelIDVisitor;
78
79                        impl<'de> de::Visitor<'de> for ModelIDVisitor {
80                            type Value = ModelID;
81
82                            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
83                                write!(formatter, "one of {}", "".to_owned() + #( " `" + #model_ids + "`" )+*)
84                            }
85
86                            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
87                            where
88                                E: de::Error,
89                            {
90                                match v {
91                                    #( #model_ids => Ok(ModelID::#model_id_idents) ),*,
92                                    _ => Ok(ModelID::Custom(v.to_string())),
93                                }
94                            }
95                        }
96
97                        deserializer.deserialize_identifier(ModelIDVisitor)
98                    }
99                }
100            }.into()
101        }
102        ApiResponse::Err { error } => panic!("{error}"),
103    }
104}