1use std::collections::BTreeSet;
2
3use chrono::NaiveDate;
4use darling::{
5 ast::{self, NestedMeta},
6 FromDeriveInput, FromField, FromMeta,
7};
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput, ItemStruct};
11
12#[proc_macro_attribute]
21pub fn inference_server_config(_attr: TokenStream, item: TokenStream) -> TokenStream {
22 let input_struct = parse_macro_input!(item as ItemStruct);
23 match InferenceServerConfigReceiver::from_item_struct(&input_struct) {
24 Ok(receiver) => receiver.expand(),
25 Err(e) => e.write_errors().into(),
26 }
27}
28
29#[derive(FromDeriveInput)]
30#[darling(
31 attributes(config),
32 supports(struct_named),
33 forward_attrs(allow, doc, cfg)
34)]
35struct InferenceServerConfigReceiver {
36 ident: syn::Ident,
37 vis: syn::Visibility,
38 generics: syn::Generics,
39 attrs: Vec<syn::Attribute>,
40 data: ast::Data<(), InferenceServerConfigField>,
41}
42
43#[derive(FromField)]
44#[darling(attributes(config), forward_attrs(doc))]
45struct InferenceServerConfigField {
46 ident: Option<syn::Ident>,
47 ty: syn::Type,
48 attrs: Vec<syn::Attribute>,
49 #[darling(default)]
50 default: Option<syn::Lit>,
51 openwebui_param: Option<syn::LitStr>,
52}
53
54impl InferenceServerConfigReceiver {
55 fn from_item_struct(item: &syn::ItemStruct) -> darling::Result<Self> {
58 let di = syn::DeriveInput {
59 attrs: item.attrs.clone(),
60 vis: item.vis.clone(),
61 ident: item.ident.clone(),
62 generics: item.generics.clone(),
63 data: syn::Data::Struct(syn::DataStruct {
64 fields: item.fields.clone(),
65 struct_token: item.struct_token,
66 semi_token: item.semi_token,
67 }),
68 };
69 InferenceServerConfigReceiver::from_derive_input(&di)
71 }
72
73 fn expand(&self) -> TokenStream {
74 let struct_name = &self.ident;
75 let vis = &self.vis;
76 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
77 let struct_attrs = &self.attrs;
78 let fields = match &self.data {
80 ast::Data::Struct(fields) => &fields.fields,
81 _ => unreachable!("Should only be a named struct."),
82 };
83
84 let field_defs = fields.iter().map(|f| {
86 let field_ident = f.ident.as_ref().unwrap();
87 let field_ty = &f.ty;
88 let docs = &f.attrs;
89 let default_fn_name =
92 syn::Ident::new(&format!("default_{field_ident}"), field_ident.span());
93 let serde_default_string = format!("{struct_name}::{default_fn_name}");
95 let serde_default_lit_str =
96 syn::LitStr::new(&serde_default_string, proc_macro2::Span::call_site());
97 let serde_rename = match &f.openwebui_param {
99 Some(lit) => lit.clone(),
100 None => syn::LitStr::new(&field_ident.to_string(), proc_macro2::Span::call_site()),
101 };
102 quote! {
104 #(#docs)*
105 #[arg(long, default_value_t = #struct_name::#default_fn_name())]
106 #[serde(default = #serde_default_lit_str, rename = #serde_rename)]
107 pub #field_ident: #field_ty,
108 }
109 });
110
111 let default_fns = fields.iter().map(|f| {
113 let field_ident = f.ident.as_ref().unwrap();
114 let fn_name = syn::Ident::new(&format!("default_{field_ident}"), field_ident.span());
115 let field_ty = &f.ty;
116
117 if let Some(lit) = &f.default {
118 quote! {
123 fn #fn_name() -> #field_ty {
124 #lit
125 }
126 }
127 } else {
128 quote! {
130 fn #fn_name() -> #field_ty {
131 <#field_ty as ::std::default::Default>::default()
132 }
133 }
134 }
135 });
136
137 let default_inits = fields.iter().map(|f| {
139 let field_ident = f.ident.as_ref().unwrap();
140 let fn_name = syn::Ident::new(&format!("default_{field_ident}"), field_ident.span());
141 quote! {
142 #field_ident: Self::#fn_name(),
143 }
144 });
145
146 let expanded = quote! {
148 #[derive(Clone, Parser, Deserialize, ::std::fmt::Debug)]
150 #(#struct_attrs)*
151 #vis struct #struct_name #impl_generics #where_clause {
152 #(#field_defs)*
153 }
154 impl #impl_generics InferenceServerConfig for #struct_name #ty_generics #where_clause {}
156 impl #impl_generics #struct_name #ty_generics #where_clause {
158 #(#default_fns)*
159 }
160 impl #impl_generics ::std::default::Default for #struct_name #ty_generics #where_clause {
162 fn default() -> Self {
163 Self {
164 #(#default_inits)*
165 }
166 }
167 }
168 };
169 expanded.into()
170 }
171}
172
173#[derive(FromDeriveInput)]
176#[darling(attributes(inference_server))]
177struct InferenceServerData {
178 model_name: Option<String>,
179 model_cli_param_name: Option<String>,
180 model_creation_date: Option<String>,
181 owned_by: Option<String>,
182 data: darling::ast::Data<darling::util::Ignored, InferenceServerField>,
183}
184
185#[derive(FromField)]
186struct InferenceServerField {
187 ident: Option<syn::Ident>,
188 ty: syn::Type,
189}
190
191#[proc_macro_derive(InferenceServer, attributes(inference_server))]
192pub fn inference_server(input: TokenStream) -> TokenStream {
193 let input = parse_macro_input!(input as DeriveInput);
194 let input_ident = &input.ident;
195 let (input_generics_impl, input_generics_type, input_generics_where_clause) =
196 &input.generics.split_for_impl();
197 let receiver = match InferenceServerData::from_derive_input(&input) {
198 Ok(r) => r,
199 Err(e) => return e.write_errors().into(),
200 };
201
202 let config_ty = match receiver.data {
204 ast::Data::Struct(fields) => {
205 let config = fields
206 .fields
207 .iter()
208 .find(|f| f.ident.as_ref().is_some_and(|i| *i == "config"));
209 match config {
210 Some(field) => field.ty.clone(),
211 None => {
212 let err_msg = "The server struct must have a field named 'config'.";
213 return TokenStream::from(quote! { compile_error!(#err_msg) });
214 }
215 }
216 }
217 _ => {
218 let err_msg = "The server type must be a struct and not an enum.";
219 return TokenStream::from(quote! { compile_error!(#err_msg) });
220 }
221 };
222
223 let model_name = match receiver.model_name {
228 Some(value) => value,
229 None => {
230 let err_msg = "You must provide a 'model_name' using '#[inference_server(model_name=\"MyModel\")]'";
231 return TokenStream::from(quote! { compile_error!(#err_msg) });
232 }
233 };
234 let model_cli_param_name = match receiver.model_cli_param_name {
236 Some(ref param_name) => {
237 let param_name = param_name.to_lowercase().replace(" ", "-");
238 quote! { #param_name }
239 }
240 None => {
241 let param_name = model_name.to_lowercase().replace(" ", "-");
242 quote! { #param_name }
243 }
244 };
245 let model_creation_date = match receiver.model_creation_date {
247 Some(ref date_str) => {
248 if NaiveDate::parse_from_str(date_str, "%m/%d/%Y").is_err() {
249 let err_msg = format!(
250 "Invalid 'model_creation_date': {date_str}. Must be in MM/DD/YYYY format."
251 );
252 return TokenStream::from(quote! { compile_error!(#err_msg) });
253 }
254 quote! { #date_str }
255 }
256 None => {
257 let err_msg = "You must provide a 'model_creation_date' using '#[inference_server(model_creation_date=\"MM/DD/YYYY\")]'";
258 return TokenStream::from(quote! { compile_error!(#err_msg) });
259 }
260 };
261 let owned_by = match receiver.owned_by {
263 Some(ref owner) => quote! { #owner },
264 None => {
265 let err_msg = "You must provide an 'owned_by' attribute using '#[inference_server(owned_by=\"OwnerName\")]'";
266 return TokenStream::from(quote! { compile_error!(#err_msg) });
267 }
268 };
269
270 let expanded = quote! {
271 impl #input_generics_impl #input_ident #input_generics_type #input_generics_where_clause {
272 pub const fn model_name() -> &'static str { #model_name }
273 pub const fn model_cli_param_name() -> &'static str { #model_cli_param_name }
274 pub const fn model_creation_date() -> &'static str { #model_creation_date }
275 pub const fn owned_by() -> &'static str { #owned_by }
276 }
277
278 impl #input_generics_impl ServerConfigParsing for #input_ident #input_generics_type #input_generics_where_clause {
279 type Config = #config_ty;
280
281 fn parse_cli_config(&mut self, args: &clap::ArgMatches) {
282 self.config = Self::Config::from_arg_matches(args)
283 .expect("Should be able to parse arguments from CLI");
284 }
285
286 fn parse_json_config(&mut self, json: &str) {
287 self.config = serde_json::from_str(json)
288 .expect("Should be able to parse JSON");
289 }
290 }
291 };
292 TokenStream::from(expanded)
293}
294
295#[derive(Debug, FromMeta)]
298struct InferenceServerEntry {
299 crate_namespace: String,
300 #[darling(rename = "server_type")]
301 server_ty: String,
302}
303
304#[derive(Debug, Default, FromMeta)]
305#[darling(default)]
306struct InferenceServerEntries {
307 #[darling(default, rename = "server", multiple)]
308 servers: Vec<InferenceServerEntry>,
309}
310
311#[proc_macro_attribute]
318pub fn inference_server_registry(attr: TokenStream, item: TokenStream) -> TokenStream {
319 let parsed_attr =
320 parse_macro_input!(attr with Punctuated::<NestedMeta, syn::Token![,]>::parse_terminated);
321 let attributes_meta: Vec<NestedMeta> = parsed_attr.into_iter().collect();
322 let registry_args = match InferenceServerEntries::from_list(&attributes_meta) {
324 Ok(args) => args,
325 Err(e) => {
326 return e.write_errors().into();
327 }
328 };
329 let input_struct = parse_macro_input!(item as ItemStruct);
330 let struct_ident = &input_struct.ident;
331 let mut registry_entries = Vec::new();
333 let mut crate_namespaces = BTreeSet::new();
334 for server in ®istry_args.servers {
335 crate_namespaces.insert(&server.crate_namespace);
337 let server_ty_str = &server.server_ty;
339 let server_ty: syn::Type = match syn::parse_str(server_ty_str) {
340 Ok(ty) => ty,
341 Err(e) => {
342 let msg = format!("Invalid server_type `{server_ty_str}`: {e}");
344 return syn::Error::new_spanned(
345 syn::Lit::Str(syn::LitStr::new(
346 server_ty_str,
347 proc_macro2::Span::call_site(),
348 )),
349 msg,
350 )
351 .to_compile_error()
352 .into();
353 }
354 };
355 let registry_entry = quote! {
356 {
357 type S = #server_ty;
358 type C = InferenceClient<#server_ty, Channel<#server_ty>>;
359 map.insert(
360 S::model_name(),
361 Box::new(C::new(
362 S::model_name(),
363 S::model_cli_param_name(),
364 S::model_creation_date(),
365 S::owned_by(),
366 <S as ServerConfigParsing>::Config::command,
367 Channel::<S>::new(),
368 )),
369 );
370 }
371 };
372 registry_entries.push(registry_entry);
373 }
374 let mut crate_imports = Vec::new();
376 for namespace in crate_namespaces {
377 let crate_path: syn::Path =
378 syn::parse_str(namespace).expect("crate namespace should be a valid path");
379 let use_crate = quote! {
380 pub use #crate_path::*;
381 };
382 crate_imports.push(use_crate);
383 }
384
385 let output = quote! {
387 #(#crate_imports)*
389 #input_struct
391 impl #struct_ident {
393 pub fn new() -> Self {
394 let mut map: DynClients = ::std::collections::HashMap::new();
395 #(#registry_entries)*
396 Self {
397 clients: ::std::sync::Arc::new(map),
398 }
399 }
400 pub fn get(&self) -> &DynClients {
401 &self.clients
402 }
403 }
404
405 impl ::std::default::Default for #struct_ident {
407 fn default() -> Self {
408 Self::new()
409 }
410 }
411 };
412
413 output.into()
414}