1mod api;
2mod cache;
3mod data;
4mod generator;
5
6use std::fs::read_to_string;
7
8use api::generic_chat_completion;
9use cache::ResolvedSourceCode;
10use dotenv::dotenv;
11
12use ignore::Walk;
13use proc_macro::TokenStream;
14use syn::{ItemFn, __private::ToTokens};
17
18use crate::generator::{generate_body_function_from_head, minimal_llm_function};
21
22#[proc_macro]
24pub fn auto_generate(item: TokenStream) -> TokenStream {
25 dotenv().ok();
26
27 let res = minimal_llm_function(item.to_string());
28 res.parse().unwrap()
31}
32
/// One project file gathered as context for the LLM prompt.
struct RawSourceCode {
    /// Path of the file, rendered for display in the prompt heading.
    path: String,
    /// Name of the detected language; lower-cased when used as the fence tag.
    language: String,
    /// Full text of the file.
    content: String,
}
38
39#[proc_macro_attribute]
40pub fn llm_tool(args: TokenStream, input: TokenStream) -> TokenStream {
41 dotenv().ok();
42
43 println!("args: {:?}", args);
44
45 let is_live = args.to_string().contains("live");
48
49 let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");
50
51 let cargo_toml_path = std::env::var("CARGO_MANIFEST_DIR").unwrap_or("".to_string());
52
53 println!("{:?}", cargo_toml_path);
54
55 let mut source_code = vec![];
56
57 for result in Walk::new(cargo_toml_path) {
58 match result {
59 Ok(entry) => {
60 if entry.path().is_file() {
61 if let Ok(Some(kind)) = hyperpolyglot::detect(entry.path()) {
62 let path = format!("{}", entry.path().display());
63
64 let content = read_to_string(path.clone()).unwrap();
65 if content.lines().count() > 500 {
66 continue;
67 }
68
69 println!("{}: {:?}", path, kind);
70
71 let language = kind.language().to_string();
72
73 source_code.push(RawSourceCode {
74 path,
75 content,
76 language,
77 });
78 }
79 }
80 }
81 Err(err) => println!("ERROR: {}", err),
82 }
83 }
84
85 let source_code_context = source_code
86 .iter()
87 .map(|x| {
88 format!(
89 "## {}\n```{}\n{}\n```\n",
90 x.path,
91 x.language.to_lowercase(),
92 x.content
93 )
94 })
95 .collect::<Vec<String>>()
96 .join("\n");
97
98 let system_message = format!("
101 You are an advanced AI, trained on the most modern architecture, with expertise in Rust programming. Your task is to generate the body of a Rust function based on its signature. Please adhere to these guidelines:
102
103 1. Receive the Function Signature: The signature will be provided in a standard Rust format, e.g., 'fn calculate_pi_with_n_iterations(n: u64) -> f64'. Focus on understanding the function's name, parameters, and return type.
104 2. Generate Only the Function Body: You are required to write Rust code that fulfills the requirements of the function signature. This code should be the function body only, without including the function signature or any other wrapping code.
105 3. Exclude Non-Essential Content: Your response must strictly contain valid Rust code applicable within the function's curly braces. Do not include comments, attributes, nested functions, or any redundant repetitions of the function signature. Do not include any explanation or additional text outside of the function body.
106 4. Maintain Simplicity and Clarity: Avoid external crates, unnecessary imports, or extra features like feature flags. Use standard Rust libraries and functionalities. The code should be clear, maintainable, and compile-ready.
107 5. Adhere to Rust Best Practices: Ensure that the generated code is idiomatic, efficient, and adheres to Rust standards and best practices.
108
109 Example:
110 INPUT SIGNATURE: 'fn calculate_pi_with_n_iterations(n: u64) -> f64'
111 EXPECTED OUTPUT (Function Body Only):
112 let mut pi = 0.0;
113 let mut sign = 1.0;
114 for i in 0..n {{
115 pi += sign / (2 * i + 1) as f64;
116 sign = -sign;
117 }}
118 4.0 * pi
119
120 Don't forget only respond with the function body. Don't include nature language text or explanation in your response.
121
122 Global Context:
123 {}
124 ", source_code_context);
125
126 let mut prompt_input = String::new();
127
128 let fn_header = ast.sig.to_token_stream().to_string();
129
130 for attr in ast.attrs {
131 let data = attr.to_token_stream().to_string();
132
133 prompt_input.push_str(&data);
134 prompt_input.push('\n');
135 }
136
137 prompt_input.push_str(&fn_header);
138
139 println!("prompt_input: {}", prompt_input);
140
141 let hash = md5::compute(prompt_input.as_bytes());
142
143 let hash_string = format!("{:x}", hash);
144
145 let existing_resolution = ResolvedSourceCode::load(&hash_string);
146
147 if !is_live {
148 if let Some(resolution) = existing_resolution {
149 return resolution.implementation.parse().unwrap();
150 }
151 }
152
153 let res = generic_chat_completion(system_message, prompt_input.clone()).unwrap();
154
155 println!("res: {:?}", res);
156
157 let body_str = res
158 .choices
159 .first()
160 .unwrap()
161 .message
162 .content
163 .trim()
164 .trim_matches('`')
165 .trim_matches('\'')
166 .to_string()
167 .lines()
168 .skip_while(|line| line.starts_with("rust") || line.starts_with("#["))
169 .collect::<Vec<&str>>()
170 .join("\n");
171
172 let implementation = format!(
173 "{} {{
174 {}
175 }}",
176 prompt_input.clone(),
177 body_str
178 );
179
180 println!("impl:\n {}", implementation);
181
182 let new_resolution = ResolvedSourceCode {
183 implementation: implementation.clone(),
184 hash: format!("{:x}", hash),
185 prompt_input,
186 };
187
188 new_resolution.save();
189
190 implementation.parse().unwrap()
191}
192
193#[proc_macro_attribute]
194pub fn auto_implement(args: TokenStream, input: TokenStream) -> TokenStream {
195 let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");
196
197 let context = args.to_string();
198
199 let mut prompt_input = String::new();
200
201 let fn_header = ast.sig.to_token_stream().to_string();
202
203 for attr in ast.attrs {
204 let data = attr.to_token_stream().to_string();
205
206 prompt_input.push_str(&data);
207 prompt_input.push('\n');
208 }
209
210 prompt_input.push_str(&fn_header);
211
212 dotenv().ok();
213
214 let implemented_fn = generate_body_function_from_head(prompt_input, Some(context)).unwrap();
215
216 implemented_fn.parse().unwrap()
217}