openai_proc_macros/
lib.rs1use convert_case::{Case, Casing};
2use openai_bootstrap::{authorization, ApiResponse, BASE_URL};
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use reqwest::blocking::Client;
6use serde::Deserialize;
7
8#[derive(Deserialize)]
9struct Models {
10 data: Vec<Model>,
11}
12
13#[derive(Deserialize)]
14struct Model {
15 id: String,
16}
17
18#[proc_macro]
19pub fn generate_model_id_enum(_input: TokenStream) -> TokenStream {
20 let client = Client::new();
21 let request = client.get(BASE_URL.to_owned() + "models");
22 let api_response: ApiResponse<Models> = authorization!(request)
23 .send()
24 .unwrap_or_else(|error| panic!("{error}"))
25 .json()
26 .unwrap();
27
28 match api_response {
29 ApiResponse::Ok(models) => {
30 let mut model_id_idents = Vec::new();
31 let mut model_ids = Vec::new();
32 let mut model_indexes = Vec::new();
33 let mut index: u32 = 0;
34
35 for model in models.data {
36 if model.id.contains(':') || model.id.contains("deprecated") {
37 continue;
38 }
39
40 model_id_idents.push(format_ident!(
41 "{}",
42 model.id.to_case(Case::Pascal).replace('.', "_")
43 ));
44 model_ids.push(model.id);
45 model_indexes.push(index);
46
47 index += 1;
48 }
49
50 quote! {
51 use serde::{ Serialize, de };
52
53 #[derive(Debug, PartialEq, Default, Clone)]
54 pub enum ModelID {
55 #[default]
56 #(#model_id_idents),*,
57 Custom(String),
58 }
59
60 impl Serialize for ModelID {
61 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62 where
63 S: serde::Serializer,
64 {
65 match *self {
66 #( ModelID::#model_id_idents => serializer.serialize_unit_variant("ModelID", #model_indexes, #model_ids) ),*,
67 ModelID::Custom(ref string) => serializer.serialize_str(string),
68 }
69 }
70 }
71
72 impl<'de> Deserialize<'de> for ModelID {
73 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74 where
75 D: serde::Deserializer<'de>,
76 {
77 struct ModelIDVisitor;
78
79 impl<'de> de::Visitor<'de> for ModelIDVisitor {
80 type Value = ModelID;
81
82 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
83 write!(formatter, "one of {}", "".to_owned() + #( " `" + #model_ids + "`" )+*)
84 }
85
86 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
87 where
88 E: de::Error,
89 {
90 match v {
91 #( #model_ids => Ok(ModelID::#model_id_idents) ),*,
92 _ => Ok(ModelID::Custom(v.to_string())),
93 }
94 }
95 }
96
97 deserializer.deserialize_identifier(ModelIDVisitor)
98 }
99 }
100 }.into()
101 }
102 ApiResponse::Err { error } => panic!("{error}"),
103 }
104}