burn_lm_macros/
lib.rs

1use std::collections::BTreeSet;
2
3use chrono::NaiveDate;
4use darling::{
5    ast::{self, NestedMeta},
6    FromDeriveInput, FromField, FromMeta,
7};
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput, ItemStruct};
11
// InferenceServerConfig -----------------------------------------------------
13
14/// This macro consumes the struct, extracts any `#[config(default = ...)]` attributes
15/// and regenerates a brand-new struct with:
16///   - `#[derive(Parser, Deserialize, Debug)]` derive macros
17///   - each field gets the passed default value of the config field attribute for clap and serde with `#[arg(...)]` and `#[serde(...)]`
18///   - add marker trait `impl InferenceServerConfig for ... {}`
19///   - generated implementation for default values compatible with both clap and serde with the Default trait using generated `fn default_<field>()` functions
20#[proc_macro_attribute]
21pub fn inference_server_config(_attr: TokenStream, item: TokenStream) -> TokenStream {
22    let input_struct = parse_macro_input!(item as ItemStruct);
23    match InferenceServerConfigReceiver::from_item_struct(&input_struct) {
24        Ok(receiver) => receiver.expand(),
25        Err(e) => e.write_errors().into(),
26    }
27}
28
29#[derive(FromDeriveInput)]
30#[darling(
31    attributes(config),
32    supports(struct_named),
33    forward_attrs(allow, doc, cfg)
34)]
35struct InferenceServerConfigReceiver {
36    ident: syn::Ident,
37    vis: syn::Visibility,
38    generics: syn::Generics,
39    attrs: Vec<syn::Attribute>,
40    data: ast::Data<(), InferenceServerConfigField>,
41}
42
43#[derive(FromField)]
44#[darling(attributes(config), forward_attrs(doc))]
45struct InferenceServerConfigField {
46    ident: Option<syn::Ident>,
47    ty: syn::Type,
48    attrs: Vec<syn::Attribute>,
49    #[darling(default)]
50    default: Option<syn::Lit>,
51    openwebui_param: Option<syn::LitStr>,
52}
53
54impl InferenceServerConfigReceiver {
55    /// Darling works with derived structure only
56    /// so we convert the struct AST form the attribute macro to a derive input AST
57    fn from_item_struct(item: &syn::ItemStruct) -> darling::Result<Self> {
58        let di = syn::DeriveInput {
59            attrs: item.attrs.clone(),
60            vis: item.vis.clone(),
61            ident: item.ident.clone(),
62            generics: item.generics.clone(),
63            data: syn::Data::Struct(syn::DataStruct {
64                fields: item.fields.clone(),
65                struct_token: item.struct_token,
66                semi_token: item.semi_token,
67            }),
68        };
69        // now we can call Darling
70        InferenceServerConfigReceiver::from_derive_input(&di)
71    }
72
73    fn expand(&self) -> TokenStream {
74        let struct_name = &self.ident;
75        let vis = &self.vis;
76        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
77        let struct_attrs = &self.attrs;
78        // Extract the named fields
79        let fields = match &self.data {
80            ast::Data::Struct(fields) => &fields.fields,
81            _ => unreachable!("Should only be a named struct."),
82        };
83
84        // Add clap and serde attributes for default value
85        let field_defs = fields.iter().map(|f| {
86            let field_ident = f.ident.as_ref().unwrap();
87            let field_ty = &f.ty;
88            let docs = &f.attrs;
89            // We need to wrap the default values into functions so that we can set the default value
90            // for both clap and serde
91            let default_fn_name =
92                syn::Ident::new(&format!("default_{field_ident}"), field_ident.span());
93            // we need the serde default value as a string
94            let serde_default_string = format!("{struct_name}::{default_fn_name}");
95            let serde_default_lit_str =
96                syn::LitStr::new(&serde_default_string, proc_macro2::Span::call_site());
97            // map config to open webui parameter name
98            let serde_rename = match &f.openwebui_param {
99                Some(lit) => lit.clone(),
100                None => syn::LitStr::new(&field_ident.to_string(), proc_macro2::Span::call_site()),
101            };
102            // rewritten field
103            quote! {
104                #(#docs)*
105                #[arg(long, default_value_t = #struct_name::#default_fn_name())]
106                #[serde(default = #serde_default_lit_str, rename = #serde_rename)]
107                pub #field_ident: #field_ty,
108            }
109        });
110
111        // Generate the wrapper functions for default values
112        let default_fns = fields.iter().map(|f| {
113            let field_ident = f.ident.as_ref().unwrap();
114            let fn_name = syn::Ident::new(&format!("default_{field_ident}"), field_ident.span());
115            let field_ty = &f.ty;
116
117            if let Some(lit) = &f.default {
118                // Example:
119                //   #[config(default = 0.9)]
120                // generates
121                //   fn default_top_p() -> f64 { 0.9 }
122                quote! {
123                    fn #fn_name() -> #field_ty {
124                        #lit
125                    }
126                }
127            } else {
128                // fallback to the type's default implementation if there is no config attribute
129                quote! {
130                    fn #fn_name() -> #field_ty {
131                        <#field_ty as ::std::default::Default>::default()
132                    }
133                }
134            }
135        });
136
137        // Generate Default trait implementation by making use of the function wrappers
138        let default_inits = fields.iter().map(|f| {
139            let field_ident = f.ident.as_ref().unwrap();
140            let fn_name = syn::Ident::new(&format!("default_{field_ident}"), field_ident.span());
141            quote! {
142                #field_ident: Self::#fn_name(),
143            }
144        });
145
146        // Output
147        let expanded = quote! {
148            // Rewritten struct
149            #[derive(Clone, Parser, Deserialize, ::std::fmt::Debug)]
150            #(#struct_attrs)*
151            #vis struct #struct_name #impl_generics #where_clause {
152                #(#field_defs)*
153            }
154            // Marker trait
155            impl #impl_generics InferenceServerConfig for #struct_name #ty_generics #where_clause {}
156            // Function wrappers for default values
157            impl #impl_generics #struct_name #ty_generics #where_clause {
158                #(#default_fns)*
159            }
160            // Default trait implementation
161            impl #impl_generics ::std::default::Default for #struct_name #ty_generics #where_clause {
162                fn default() -> Self {
163                    Self {
164                        #(#default_inits)*
165                    }
166                }
167            }
168        };
169        expanded.into()
170    }
171}
172
173// InferenceServer -----------------------------------------------------------
174
175#[derive(FromDeriveInput)]
176#[darling(attributes(inference_server))]
177struct InferenceServerData {
178    model_name: Option<String>,
179    model_cli_param_name: Option<String>,
180    model_creation_date: Option<String>,
181    owned_by: Option<String>,
182    data: darling::ast::Data<darling::util::Ignored, InferenceServerField>,
183}
184
185#[derive(FromField)]
186struct InferenceServerField {
187    ident: Option<syn::Ident>,
188    ty: syn::Type,
189}
190
191#[proc_macro_derive(InferenceServer, attributes(inference_server))]
192pub fn inference_server(input: TokenStream) -> TokenStream {
193    let input = parse_macro_input!(input as DeriveInput);
194    let input_ident = &input.ident;
195    let (input_generics_impl, input_generics_type, input_generics_where_clause) =
196        &input.generics.split_for_impl();
197    let receiver = match InferenceServerData::from_derive_input(&input) {
198        Ok(r) => r,
199        Err(e) => return e.write_errors().into(),
200    };
201
202    // Verify that the struct has a 'config' field and retrieve its type
203    let config_ty = match receiver.data {
204        ast::Data::Struct(fields) => {
205            let config = fields
206                .fields
207                .iter()
208                .find(|f| f.ident.as_ref().is_some_and(|i| *i == "config"));
209            match config {
210                Some(field) => field.ty.clone(),
211                None => {
212                    let err_msg = "The server struct must have a field named 'config'.";
213                    return TokenStream::from(quote! { compile_error!(#err_msg) });
214                }
215            }
216        }
217        _ => {
218            let err_msg = "The server type must be a struct and not an enum.";
219            return TokenStream::from(quote! { compile_error!(#err_msg) });
220        }
221    };
222
223    // Retrieve the config type
224
225    // retrieve plugin info
226    // handle model_name
227    let model_name = match receiver.model_name {
228        Some(value) => value,
229        None => {
230            let err_msg = "You must provide a 'model_name' using '#[inference_server(model_name=\"MyModel\")]'";
231            return TokenStream::from(quote! { compile_error!(#err_msg) });
232        }
233    };
234    // handle model CLI param name
235    let model_cli_param_name = match receiver.model_cli_param_name {
236        Some(ref param_name) => {
237            let param_name = param_name.to_lowercase().replace(" ", "-");
238            quote! { #param_name }
239        }
240        None => {
241            let param_name = model_name.to_lowercase().replace(" ", "-");
242            quote! { #param_name }
243        }
244    };
245    // handle model_creation_date
246    let model_creation_date = match receiver.model_creation_date {
247        Some(ref date_str) => {
248            if NaiveDate::parse_from_str(date_str, "%m/%d/%Y").is_err() {
249                let err_msg = format!(
250                    "Invalid 'model_creation_date': {date_str}. Must be in MM/DD/YYYY format."
251                );
252                return TokenStream::from(quote! { compile_error!(#err_msg) });
253            }
254            quote! { #date_str }
255        }
256        None => {
257            let err_msg = "You must provide a 'model_creation_date' using '#[inference_server(model_creation_date=\"MM/DD/YYYY\")]'";
258            return TokenStream::from(quote! { compile_error!(#err_msg) });
259        }
260    };
261    // handle owned_by
262    let owned_by = match receiver.owned_by {
263        Some(ref owner) => quote! { #owner },
264        None => {
265            let err_msg = "You must provide an 'owned_by' attribute using '#[inference_server(owned_by=\"OwnerName\")]'";
266            return TokenStream::from(quote! { compile_error!(#err_msg) });
267        }
268    };
269
270    let expanded = quote! {
271        impl #input_generics_impl #input_ident #input_generics_type #input_generics_where_clause {
272            pub const fn model_name() -> &'static str { #model_name }
273            pub const fn model_cli_param_name() -> &'static str { #model_cli_param_name }
274            pub const fn model_creation_date() -> &'static str { #model_creation_date }
275            pub const fn owned_by() -> &'static str { #owned_by }
276        }
277
278        impl #input_generics_impl ServerConfigParsing for #input_ident #input_generics_type #input_generics_where_clause {
279            type Config = #config_ty;
280
281            fn parse_cli_config(&mut self, args: &clap::ArgMatches) {
282                self.config = Self::Config::from_arg_matches(args)
283                    .expect("Should be able to parse arguments from CLI");
284            }
285
286            fn parse_json_config(&mut self, json: &str) {
287                self.config = serde_json::from_str(json)
288                    .expect("Should be able to parse JSON");
289            }
290        }
291    };
292    TokenStream::from(expanded)
293}
294
295// Register inference servers ------------------------------------------------
296
297#[derive(Debug, FromMeta)]
298struct InferenceServerEntry {
299    crate_namespace: String,
300    #[darling(rename = "server_type")]
301    server_ty: String,
302}
303
304#[derive(Debug, Default, FromMeta)]
305#[darling(default)]
306struct InferenceServerEntries {
307    #[darling(default, rename = "server", multiple)]
308    servers: Vec<InferenceServerEntry>,
309}
310
311/// This macro implements the new() function for the Registry struct given a
312/// list of Inference Server entries.
313/// This macro also defines some type aliases to make the generated code more readable.
314/// For instance for "MyModel" server_name the macro will define the following types:
315///   - MyModelS = "type of the server passed as server_type"
316///   - MyModelC = InferenceClient<MyModelS, Channel<MyModelS>
317#[proc_macro_attribute]
318pub fn inference_server_registry(attr: TokenStream, item: TokenStream) -> TokenStream {
319    let parsed_attr =
320        parse_macro_input!(attr with Punctuated::<NestedMeta, syn::Token![,]>::parse_terminated);
321    let attributes_meta: Vec<NestedMeta> = parsed_attr.into_iter().collect();
322    // Use Darling to parse the attributes meta into InferenceServerEntries
323    let registry_args = match InferenceServerEntries::from_list(&attributes_meta) {
324        Ok(args) => args,
325        Err(e) => {
326            return e.write_errors().into();
327        }
328    };
329    let input_struct = parse_macro_input!(item as ItemStruct);
330    let struct_ident = &input_struct.ident;
331    // generate hash map entry for each server entry
332    let mut registry_entries = Vec::new();
333    let mut crate_namespaces = BTreeSet::new();
334    for server in &registry_args.servers {
335        // crate namespaces set
336        crate_namespaces.insert(&server.crate_namespace);
337        // registry hash map entry
338        let server_ty_str = &server.server_ty;
339        let server_ty: syn::Type = match syn::parse_str(server_ty_str) {
340            Ok(ty) => ty,
341            Err(e) => {
342                // Server type is unknown
343                let msg = format!("Invalid server_type `{server_ty_str}`: {e}");
344                return syn::Error::new_spanned(
345                    syn::Lit::Str(syn::LitStr::new(
346                        server_ty_str,
347                        proc_macro2::Span::call_site(),
348                    )),
349                    msg,
350                )
351                .to_compile_error()
352                .into();
353            }
354        };
355        let registry_entry = quote! {
356            {
357                type S = #server_ty;
358                type C = InferenceClient<#server_ty, Channel<#server_ty>>;
359                map.insert(
360                    S::model_name(),
361                    Box::new(C::new(
362                        S::model_name(),
363                        S::model_cli_param_name(),
364                        S::model_creation_date(),
365                        S::owned_by(),
366                        <S as ServerConfigParsing>::Config::command,
367                        Channel::<S>::new(),
368                    )),
369                );
370            }
371        };
372        registry_entries.push(registry_entry);
373    }
374    // Imports
375    let mut crate_imports = Vec::new();
376    for namespace in crate_namespaces {
377        let crate_path: syn::Path =
378            syn::parse_str(namespace).expect("crate namespace should be a valid path");
379        let use_crate = quote! {
380            pub use #crate_path::*;
381        };
382        crate_imports.push(use_crate);
383    }
384
385    // Output
386    let output = quote! {
387        // imports
388        #(#crate_imports)*
389        // Original struct
390        #input_struct
391        // new() implementation
392        impl #struct_ident {
393            pub fn new() -> Self {
394                let mut map: DynClients = ::std::collections::HashMap::new();
395                #(#registry_entries)*
396                Self {
397                    clients: ::std::sync::Arc::new(map),
398                }
399            }
400            pub fn get(&self) -> &DynClients {
401                &self.clients
402            }
403        }
404
405        // Implement default to make clippy happy
406        impl ::std::default::Default for #struct_ident {
407            fn default() -> Self {
408                Self::new()
409            }
410        }
411    };
412
413    output.into()
414}