kernelx_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_crate::{crate_name, FoundCrate};
3use quote::quote;
4use syn::{parse::Parse, parse::ParseStream, parse_macro_input, DeriveInput, Ident, Token, Type};
5
6fn get_crate_name() -> proc_macro2::TokenStream {
7    let found_crate =
8        crate_name("kernelx-core").expect("kernelx-core should be present in `Cargo.toml`");
9    match found_crate {
10        FoundCrate::Itself => quote!(crate),
11        FoundCrate::Name(name) => {
12            let ident = proc_macro2::Ident::new(&name, proc_macro2::Span::call_site());
13            quote!(#ident)
14        }
15    }
16}
17
18// Schema derive macro
19#[proc_macro_attribute]
20pub fn lmp_schema(_attr: TokenStream, input: TokenStream) -> TokenStream {
21    let input = parse_macro_input!(input as DeriveInput);
22    let expanded = quote! {
23        #[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
24        #[serde(rename_all = "camelCase")]
25        #input
26    };
27    TokenStream::from(expanded)
28}
29
30// Provider configuration macro
31#[proc_macro]
32pub fn provider(input: TokenStream) -> TokenStream {
33    let input = parse_macro_input!(input as ProviderDef);
34    // let crate_name = get_crate_name();
35
36    let provider_type = &input.provider_type;
37    let api_base = &input.api_base;
38    let api_key = &input.api_key;
39    let models = &input.models;
40
41    // Check if the type is a path with segments
42    let provider_path = match provider_type {
43        Type::Path(type_path) if type_path.path.segments.len() > 1 => {
44            quote!(#provider_type)
45        }
46        _ => {
47            quote!(crate::#provider_type)
48        }
49    };
50
51    let expanded = quote! {
52        thread_local! {
53            static PROVIDER: std::sync::OnceLock<#provider_path> = {
54                let cell = std::sync::OnceLock::new();
55                let provider = #provider_path::builder()
56                    .api_base(#api_base)
57                    .api_key(#api_key)
58                    .models(#models)
59                    .build()
60                    .expect("Failed to create provider");
61                cell.set(provider).unwrap();
62                cell
63            };
64        }
65
66        fn get_provider() -> #provider_path {
67            PROVIDER.with(|p| p.get().unwrap().clone())
68        }
69    };
70
71    TokenStream::from(expanded)
72}
73
74// Provider definition parser
75struct ProviderDef {
76    provider_type: syn::Type,
77    api_base: syn::LitStr,
78    api_key: syn::LitStr,
79    models: syn::Expr, // Using Expr to capture the entire models![] macro invocation
80}
81
82impl Parse for ProviderDef {
83    fn parse(input: ParseStream) -> syn::Result<Self> {
84        let provider_type: syn::Type = input.parse()?;
85
86        input.parse::<Token![,]>()?;
87        input.parse::<syn::Ident>()?; // parse "api_base:"
88        input.parse::<Token![:]>()?;
89        let api_base: syn::LitStr = input.parse()?;
90
91        input.parse::<Token![,]>()?;
92        input.parse::<syn::Ident>()?; // parse "api_key:"
93        input.parse::<Token![:]>()?;
94        let api_key: syn::LitStr = input.parse()?;
95
96        input.parse::<Token![,]>()?;
97        input.parse::<syn::Ident>()?; // parse "models:"
98        input.parse::<Token![:]>()?;
99        let models: syn::Expr = input.parse()?;
100
101        Ok(ProviderDef {
102            provider_type,
103            api_base,
104            api_key,
105            models,
106        })
107    }
108}
109
110// LMP function macro
111#[proc_macro_attribute]
112pub fn lmp(args: TokenStream, input: TokenStream) -> TokenStream {
113    let args = parse_macro_input!(args as LmpArgs);
114    let input_fn = parse_macro_input!(input as syn::ItemFn);
115
116    let system_prompt = input_fn
117        .attrs
118        .iter()
119        .filter(|attr| attr.path().is_ident("doc"))
120        .filter_map(|attr| {
121            if let syn::Meta::NameValue(meta) = &attr.meta {
122                if let syn::Expr::Lit(syn::ExprLit {
123                    lit: syn::Lit::Str(s),
124                    ..
125                }) = &meta.value
126                {
127                    Some(s.value())
128                } else {
129                    None
130                }
131            } else {
132                None
133            }
134        })
135        .collect::<Vec<_>>()
136        .join("\n");
137
138    let system_prompt_literal = syn::LitStr::new(&system_prompt, proc_macro2::Span::call_site());
139    let model = args.model;
140    let temperature = args.temperature.unwrap_or(0.0);
141    let max_tokens = args.max_tokens.unwrap_or(512);
142
143    let fn_vis = &input_fn.vis;
144    let fn_sig = &input_fn.sig;
145    let original_body = &input_fn.block;
146    let crate_name = get_crate_name();
147
148    let expanded = if let Some(rf_type) = args.response_format {
149        quote! {
150            #fn_vis #fn_sig {
151                use #crate_name::{StructuredLM, Result, Error};
152
153                let prompt = { #original_body };
154                let model = get_provider().get_model::<dyn StructuredLM>(#model)?
155                    .system_prompt(#system_prompt_literal)
156                    .temperature(#temperature)
157                    .max_tokens(#max_tokens);
158
159                let schema = schemars::schema_for!(#rf_type);
160                let json = model.structured_complete(&prompt, &serde_json::to_value(schema.schema)?)
161                    .await?;
162                serde_json::from_value(json).map_err(Error::Serialization)
163            }
164        }
165    } else {
166        quote! {
167            #fn_vis #fn_sig {
168                use #crate_name::{LLM, Result};
169
170                let prompt = { #original_body };
171                let model = get_provider().get_model::<dyn LLM>(#model)?
172                    .system_prompt(#system_prompt_literal)
173                    .temperature(#temperature)
174                    .max_tokens(#max_tokens);
175
176                model.complete(&prompt).await
177            }
178        }
179    };
180
181    TokenStream::from(expanded)
182}
183
184// Args parser for LMP
185#[derive(Default)]
186struct LmpArgs {
187    model: String,
188    temperature: Option<f32>,
189    max_tokens: Option<u32>,
190    response_format: Option<Type>,
191}
192
193impl Parse for LmpArgs {
194    fn parse(input: ParseStream) -> syn::Result<Self> {
195        let mut args = LmpArgs::default();
196        while !input.is_empty() {
197            let key: Ident = input.parse()?;
198            input.parse::<Token![=]>()?;
199            match key.to_string().as_str() {
200                "model" => args.model = input.parse::<syn::LitStr>()?.value(),
201                "temperature" => {
202                    args.temperature = Some(input.parse::<syn::LitFloat>()?.base10_parse()?)
203                }
204                "max_tokens" => {
205                    args.max_tokens = Some(input.parse::<syn::LitInt>()?.base10_parse()?)
206                }
207                "response_format" => args.response_format = Some(input.parse()?),
208                _ => return Err(syn::Error::new(key.span(), format!("Unknown key: {}", key))),
209            }
210            if input.peek(Token![,]) {
211                input.parse::<Token![,]>()?;
212            }
213        }
214        Ok(args)
215    }
216}
217
218#[proc_macro_attribute]
219pub fn json_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
220    let input = parse_macro_input!(item as DeriveInput);
221    let name = &input.ident;
222
223    let expanded = quote! {
224        #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
225        #input
226
227        impl #name {
228            pub fn schema() -> serde_json::Value {
229                let schema = schemars::schema_for!(Self);
230                serde_json::to_value(schema.schema).expect("Failed to serialize schema")
231            }
232
233            pub fn from_json(value: serde_json::Value) -> Result<Self> {
234                serde_json::from_value(value).map_err(Error::Serialization)
235            }
236        }
237    };
238
239    TokenStream::from(expanded)
240}