1use proc_macro::TokenStream;
2use proc_macro_crate::{crate_name, FoundCrate};
3use quote::quote;
4use syn::{parse::Parse, parse::ParseStream, parse_macro_input, DeriveInput, Ident, Token, Type};
5
6fn get_crate_name() -> proc_macro2::TokenStream {
7 let found_crate =
8 crate_name("kernelx-core").expect("kernelx-core should be present in `Cargo.toml`");
9 match found_crate {
10 FoundCrate::Itself => quote!(crate),
11 FoundCrate::Name(name) => {
12 let ident = proc_macro2::Ident::new(&name, proc_macro2::Span::call_site());
13 quote!(#ident)
14 }
15 }
16}
17
18#[proc_macro_attribute]
20pub fn lmp_schema(_attr: TokenStream, input: TokenStream) -> TokenStream {
21 let input = parse_macro_input!(input as DeriveInput);
22 let expanded = quote! {
23 #[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
24 #[serde(rename_all = "camelCase")]
25 #input
26 };
27 TokenStream::from(expanded)
28}
29
30#[proc_macro]
32pub fn provider(input: TokenStream) -> TokenStream {
33 let input = parse_macro_input!(input as ProviderDef);
34 let provider_type = &input.provider_type;
37 let api_base = &input.api_base;
38 let api_key = &input.api_key;
39 let models = &input.models;
40
41 let provider_path = match provider_type {
43 Type::Path(type_path) if type_path.path.segments.len() > 1 => {
44 quote!(#provider_type)
45 }
46 _ => {
47 quote!(crate::#provider_type)
48 }
49 };
50
51 let expanded = quote! {
52 thread_local! {
53 static PROVIDER: std::sync::OnceLock<#provider_path> = {
54 let cell = std::sync::OnceLock::new();
55 let provider = #provider_path::builder()
56 .api_base(#api_base)
57 .api_key(#api_key)
58 .models(#models)
59 .build()
60 .expect("Failed to create provider");
61 cell.set(provider).unwrap();
62 cell
63 };
64 }
65
66 fn get_provider() -> #provider_path {
67 PROVIDER.with(|p| p.get().unwrap().clone())
68 }
69 };
70
71 TokenStream::from(expanded)
72}
73
74struct ProviderDef {
76 provider_type: syn::Type,
77 api_base: syn::LitStr,
78 api_key: syn::LitStr,
79 models: syn::Expr, }
81
82impl Parse for ProviderDef {
83 fn parse(input: ParseStream) -> syn::Result<Self> {
84 let provider_type: syn::Type = input.parse()?;
85
86 input.parse::<Token![,]>()?;
87 input.parse::<syn::Ident>()?; input.parse::<Token![:]>()?;
89 let api_base: syn::LitStr = input.parse()?;
90
91 input.parse::<Token![,]>()?;
92 input.parse::<syn::Ident>()?; input.parse::<Token![:]>()?;
94 let api_key: syn::LitStr = input.parse()?;
95
96 input.parse::<Token![,]>()?;
97 input.parse::<syn::Ident>()?; input.parse::<Token![:]>()?;
99 let models: syn::Expr = input.parse()?;
100
101 Ok(ProviderDef {
102 provider_type,
103 api_base,
104 api_key,
105 models,
106 })
107 }
108}
109
110#[proc_macro_attribute]
112pub fn lmp(args: TokenStream, input: TokenStream) -> TokenStream {
113 let args = parse_macro_input!(args as LmpArgs);
114 let input_fn = parse_macro_input!(input as syn::ItemFn);
115
116 let system_prompt = input_fn
117 .attrs
118 .iter()
119 .filter(|attr| attr.path().is_ident("doc"))
120 .filter_map(|attr| {
121 if let syn::Meta::NameValue(meta) = &attr.meta {
122 if let syn::Expr::Lit(syn::ExprLit {
123 lit: syn::Lit::Str(s),
124 ..
125 }) = &meta.value
126 {
127 Some(s.value())
128 } else {
129 None
130 }
131 } else {
132 None
133 }
134 })
135 .collect::<Vec<_>>()
136 .join("\n");
137
138 let system_prompt_literal = syn::LitStr::new(&system_prompt, proc_macro2::Span::call_site());
139 let model = args.model;
140 let temperature = args.temperature.unwrap_or(0.0);
141 let max_tokens = args.max_tokens.unwrap_or(512);
142
143 let fn_vis = &input_fn.vis;
144 let fn_sig = &input_fn.sig;
145 let original_body = &input_fn.block;
146 let crate_name = get_crate_name();
147
148 let expanded = if let Some(rf_type) = args.response_format {
149 quote! {
150 #fn_vis #fn_sig {
151 use #crate_name::{StructuredLM, Result, Error};
152
153 let prompt = { #original_body };
154 let model = get_provider().get_model::<dyn StructuredLM>(#model)?
155 .system_prompt(#system_prompt_literal)
156 .temperature(#temperature)
157 .max_tokens(#max_tokens);
158
159 let schema = schemars::schema_for!(#rf_type);
160 let json = model.structured_complete(&prompt, &serde_json::to_value(schema.schema)?)
161 .await?;
162 serde_json::from_value(json).map_err(Error::Serialization)
163 }
164 }
165 } else {
166 quote! {
167 #fn_vis #fn_sig {
168 use #crate_name::{LLM, Result};
169
170 let prompt = { #original_body };
171 let model = get_provider().get_model::<dyn LLM>(#model)?
172 .system_prompt(#system_prompt_literal)
173 .temperature(#temperature)
174 .max_tokens(#max_tokens);
175
176 model.complete(&prompt).await
177 }
178 }
179 };
180
181 TokenStream::from(expanded)
182}
183
184#[derive(Default)]
186struct LmpArgs {
187 model: String,
188 temperature: Option<f32>,
189 max_tokens: Option<u32>,
190 response_format: Option<Type>,
191}
192
193impl Parse for LmpArgs {
194 fn parse(input: ParseStream) -> syn::Result<Self> {
195 let mut args = LmpArgs::default();
196 while !input.is_empty() {
197 let key: Ident = input.parse()?;
198 input.parse::<Token![=]>()?;
199 match key.to_string().as_str() {
200 "model" => args.model = input.parse::<syn::LitStr>()?.value(),
201 "temperature" => {
202 args.temperature = Some(input.parse::<syn::LitFloat>()?.base10_parse()?)
203 }
204 "max_tokens" => {
205 args.max_tokens = Some(input.parse::<syn::LitInt>()?.base10_parse()?)
206 }
207 "response_format" => args.response_format = Some(input.parse()?),
208 _ => return Err(syn::Error::new(key.span(), format!("Unknown key: {}", key))),
209 }
210 if input.peek(Token![,]) {
211 input.parse::<Token![,]>()?;
212 }
213 }
214 Ok(args)
215 }
216}
217
218#[proc_macro_attribute]
219pub fn json_schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
220 let input = parse_macro_input!(item as DeriveInput);
221 let name = &input.ident;
222
223 let expanded = quote! {
224 #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
225 #input
226
227 impl #name {
228 pub fn schema() -> serde_json::Value {
229 let schema = schemars::schema_for!(Self);
230 serde_json::to_value(schema.schema).expect("Failed to serialize schema")
231 }
232
233 pub fn from_json(value: serde_json::Value) -> Result<Self> {
234 serde_json::from_value(value).map_err(Error::Serialization)
235 }
236 }
237 };
238
239 TokenStream::from(expanded)
240}