auto_rust/
lib.rs

1mod api;
2mod cache;
3mod data;
4mod generator;
5
6use std::fs::read_to_string;
7
8use api::generic_chat_completion;
9use cache::ResolvedSourceCode;
10use dotenv::dotenv;
11
12use ignore::Walk;
13use proc_macro::TokenStream;
14// use quote::quote;
15
16use syn::{ItemFn, __private::ToTokens};
17
18// use git2::Repository;
19
20use crate::generator::{generate_body_function_from_head, minimal_llm_function};
21
22/// This macro gets an input like "String, "This is a llm generated function"" and returns a function that returns a String
23#[proc_macro]
24pub fn auto_generate(item: TokenStream) -> TokenStream {
25    dotenv().ok();
26
27    let res = minimal_llm_function(item.to_string());
28    // println!("{:?}", res);
29
30    res.parse().unwrap()
31}
32
/// A single project file that is embedded into the LLM prompt as context.
#[derive(Debug, Clone)]
struct RawSourceCode {
    /// Path of the file, rendered for display.
    path: String,
    /// Name of the detected language of the file.
    language: String,
    /// Complete text of the file.
    content: String,
}
38
39#[proc_macro_attribute]
40pub fn llm_tool(args: TokenStream, input: TokenStream) -> TokenStream {
41    dotenv().ok();
42
43    println!("args: {:?}", args);
44
45    // detect #[llm_tool(live)] arg
46
47    let is_live = args.to_string().contains("live");
48
49    let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");
50
51    let cargo_toml_path = std::env::var("CARGO_MANIFEST_DIR").unwrap_or("".to_string());
52
53    println!("{:?}", cargo_toml_path);
54
55    let mut source_code = vec![];
56
57    for result in Walk::new(cargo_toml_path) {
58        match result {
59            Ok(entry) => {
60                if entry.path().is_file() {
61                    if let Ok(Some(kind)) = hyperpolyglot::detect(entry.path()) {
62                        let path = format!("{}", entry.path().display());
63
64                        let content = read_to_string(path.clone()).unwrap();
65                        if content.lines().count() > 500 {
66                            continue;
67                        }
68
69                        println!("{}: {:?}", path, kind);
70
71                        let language = kind.language().to_string();
72
73                        source_code.push(RawSourceCode {
74                            path,
75                            content,
76                            language,
77                        });
78                    }
79                }
80            }
81            Err(err) => println!("ERROR: {}", err),
82        }
83    }
84
85    let source_code_context = source_code
86        .iter()
87        .map(|x| {
88            format!(
89                "## {}\n```{}\n{}\n```\n",
90                x.path,
91                x.language.to_lowercase(),
92                x.content
93            )
94        })
95        .collect::<Vec<String>>()
96        .join("\n");
97
98    // println!("{}", source_code_context);
99
100    let system_message = format!("
101    You are an advanced AI, trained on the most modern  architecture, with expertise in Rust programming. Your task is to generate the body of a Rust function based on its signature. Please adhere to these guidelines:
102    
103    1. Receive the Function Signature: The signature will be provided in a standard Rust format, e.g., 'fn calculate_pi_with_n_iterations(n: u64) -> f64'. Focus on understanding the function's name, parameters, and return type.
104    2. Generate Only the Function Body: You are required to write Rust code that fulfills the requirements of the function signature. This code should be the function body only, without including the function signature or any other wrapping code.
105    3. Exclude Non-Essential Content: Your response must strictly contain valid Rust code applicable within the function's curly braces. Do not include comments, attributes, nested functions, or any redundant repetitions of the function signature. Do not include any explanation or additional text outside of the function body.
106    4. Maintain Simplicity and Clarity: Avoid external crates, unnecessary imports, or extra features like feature flags. Use standard Rust libraries and functionalities. The code should be clear, maintainable, and compile-ready.
107    5. Adhere to Rust Best Practices: Ensure that the generated code is idiomatic, efficient, and adheres to Rust standards and best practices.
108    
109    Example:
110    INPUT SIGNATURE: 'fn calculate_pi_with_n_iterations(n: u64) -> f64'
111    EXPECTED OUTPUT (Function Body Only):
112        let mut pi = 0.0;
113        let mut sign = 1.0;
114        for i in 0..n {{
115            pi += sign / (2 * i + 1) as f64;
116            sign = -sign;
117        }}
118        4.0 * pi
119    
120    Don't forget only respond with the function body. Don't include nature language text or explanation in your response.
121
122    Global Context:
123    {}
124    ", source_code_context);
125
126    let mut prompt_input = String::new();
127
128    let fn_header = ast.sig.to_token_stream().to_string();
129
130    for attr in ast.attrs {
131        let data = attr.to_token_stream().to_string();
132
133        prompt_input.push_str(&data);
134        prompt_input.push('\n');
135    }
136
137    prompt_input.push_str(&fn_header);
138
139    println!("prompt_input: {}", prompt_input);
140
141    let hash = md5::compute(prompt_input.as_bytes());
142
143    let hash_string = format!("{:x}", hash);
144
145    let existing_resolution = ResolvedSourceCode::load(&hash_string);
146
147    if !is_live {
148        if let Some(resolution) = existing_resolution {
149            return resolution.implementation.parse().unwrap();
150        }
151    }
152
153    let res = generic_chat_completion(system_message, prompt_input.clone()).unwrap();
154
155    println!("res: {:?}", res);
156
157    let body_str = res
158        .choices
159        .first()
160        .unwrap()
161        .message
162        .content
163        .trim()
164        .trim_matches('`')
165        .trim_matches('\'')
166        .to_string()
167        .lines()
168        .skip_while(|line| line.starts_with("rust") || line.starts_with("#["))
169        .collect::<Vec<&str>>()
170        .join("\n");
171
172    let implementation = format!(
173        "{} {{
174            {}
175        }}",
176        prompt_input.clone(),
177        body_str
178    );
179
180    println!("impl:\n {}", implementation);
181
182    let new_resolution = ResolvedSourceCode {
183        implementation: implementation.clone(),
184        hash: format!("{:x}", hash),
185        prompt_input,
186    };
187
188    new_resolution.save();
189
190    implementation.parse().unwrap()
191}
192
193#[proc_macro_attribute]
194pub fn auto_implement(args: TokenStream, input: TokenStream) -> TokenStream {
195    let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");
196
197    let context = args.to_string();
198
199    let mut prompt_input = String::new();
200
201    let fn_header = ast.sig.to_token_stream().to_string();
202
203    for attr in ast.attrs {
204        let data = attr.to_token_stream().to_string();
205
206        prompt_input.push_str(&data);
207        prompt_input.push('\n');
208    }
209
210    prompt_input.push_str(&fn_header);
211
212    dotenv().ok();
213
214    let implemented_fn = generate_body_function_from_head(prompt_input, Some(context)).unwrap();
215
216    implemented_fn.parse().unwrap()
217}