1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
mod api;
mod data;
mod generator;

use std::fs::read_to_string;

use api::generic_chat_completion;
use dotenv::dotenv;

use ignore::Walk;
use proc_macro::TokenStream;
// use quote::quote;

use syn::{ItemFn, __private::ToTokens};

// use git2::Repository;

use crate::generator::{generate_body_function_from_head, minimal_llm_function};

/// Function-like macro that asks an LLM to generate a complete function.
///
/// The macro input, e.g. `auto_generate!(String, "This is a llm generated function")`,
/// is stringified and handed to [`minimal_llm_function`]. The Rust source text it
/// returns becomes the macro expansion. For that example, the intended result is a
/// function returning a `String`.
///
/// # Panics
///
/// Panics if the generated source cannot be tokenized as Rust. rustc reports this as
/// an error at the macro call site. The panic message includes the generated code so
/// the failure can be diagnosed.
#[proc_macro]
pub fn auto_generate(item: TokenStream) -> TokenStream {
    // Populate the process environment from a `.env` file if one exists.
    // `.ok()` deliberately ignores a missing or unreadable file.
    dotenv().ok();

    let generated = minimal_llm_function(item.to_string());

    generated.parse::<TokenStream>().unwrap_or_else(|err| {
        panic!(
            "auto_generate: generated code is not valid Rust tokens ({}):\n{}",
            err, generated
        )
    })
}

/// One source file collected from the crate directory and embedded in the LLM prompt
/// as project context.
struct RawSourceCode {
    /// Language name reported by language detection; lower-cased when used as the
    /// Markdown code-fence tag.
    language: String,
    /// Path of the file, rendered for display.
    path: String,
    /// Full text of the file.
    content: String,
}

#[proc_macro_attribute]
pub fn llm_tool(_args: TokenStream, input: TokenStream) -> TokenStream {
    dotenv().ok();

    let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");

    let cargo_toml_path = std::env::var("CARGO_MANIFEST_DIR").unwrap_or("".to_string());

    println!("{:?}", cargo_toml_path);

    let mut source_code = vec![];

    for result in Walk::new(cargo_toml_path) {
        match result {
            Ok(entry) => {
                if entry.path().is_file() {
                    if let Ok(Some(kind)) = hyperpolyglot::detect(entry.path()) {
                        let path = format!("{}", entry.path().display());

                        let content = read_to_string(path.clone()).unwrap();
                        if content.lines().count() > 500 {
                            continue;
                        }

                        println!("{}: {:?}", path, kind);

                        let language = kind.language().to_string();

                        source_code.push(RawSourceCode {
                            path,
                            content,
                            language,
                        });
                    }
                }
            }
            Err(err) => println!("ERROR: {}", err),
        }
    }

    let source_code_context = source_code
        .iter()
        .map(|x| {
            format!(
                "## {}\n```{}\n{}\n```\n",
                x.path,
                x.language.to_lowercase(),
                x.content
            )
        })
        .collect::<Vec<String>>()
        .join("\n");

    // println!("{}", source_code_context);

    let system_message = format!("
    You are an advanced AI, trained on the GPT-4 architecture, with expertise in Rust programming. Your task is to generate the body of a Rust function based on its signature. Please adhere to these guidelines:
    
    1. Receive the Function Signature: The signature will be provided in a standard Rust format, e.g., 'fn calculate_pi_with_n_iterations(n: u64) -> f64'. Focus on understanding the function's name, parameters, and return type.
    2. Generate Only the Function Body: You are required to write Rust code that fulfills the requirements of the function signature. This code should be the function body only, without including the function signature or any other wrapping code.
    3. Exclude Non-Essential Content: Your response must strictly contain valid Rust code applicable within the function's curly braces. Do not include comments, attributes, nested functions, or any redundant repetitions of the function signature. Do not include any explanation or additional text outside of the function body.
    4. Maintain Simplicity and Clarity: Avoid external crates, unnecessary imports, or extra features like feature flags. Use standard Rust libraries and functionalities. The code should be clear, maintainable, and compile-ready.
    5. Adhere to Rust Best Practices: Ensure that the generated code is idiomatic, efficient, and adheres to Rust standards and best practices.
    
    Example:
    INPUT SIGNATURE: 'fn calculate_pi_with_n_iterations(n: u64) -> f64'
    EXPECTED OUTPUT (Function Body Only):
        let mut pi = 0.0;
        let mut sign = 1.0;
        for i in 0..n {{
            pi += sign / (2 * i + 1) as f64;
            sign = -sign;
        }}
        4.0 * pi
    
    Global Context:
    {}
    ", source_code_context);

    let mut prompt_input = String::new();

    let fn_header = ast.sig.to_token_stream().to_string();

    for attr in ast.attrs {
        let data = attr.to_token_stream().to_string();

        prompt_input.push_str(&data);
        prompt_input.push('\n');
    }

    prompt_input.push_str(&fn_header);

    println!("prompt_input: {}", prompt_input);

    let res = generic_chat_completion(system_message, prompt_input.clone()).unwrap();

    println!("res: {:?}", res);

    let body_str = res
        .choices
        .first()
        .unwrap()
        .message
        .content
        .trim()
        .trim_matches('`')
        .to_string()
        .lines()
        .skip_while(|line| line.starts_with("rust") || line.starts_with("#["))
        .collect::<Vec<&str>>()
        .join("\n");

    let implementation = format!(
        "{} {{
            {}
        }}",
        prompt_input, body_str
    );

    println!("impl:\n {}", implementation);

    implementation.parse().unwrap()
}

/// Attribute macro that replaces the annotated function with an LLM-generated
/// implementation.
///
/// Inputs and output:
/// - The attribute arguments are stringified and passed as extra context.
/// - The function's attributes (including doc comments) and signature are passed as
///   the "head" to [`generate_body_function_from_head`].
/// - The source text returned by the generator becomes the macro expansion.
///
/// # Panics
///
/// Panics (reported by rustc as an error at the attribute) in any of these cases:
/// - the item is not a function;
/// - the generator returns an error;
/// - the generated code cannot be tokenized as Rust.
#[proc_macro_attribute]
pub fn auto_implement(args: TokenStream, input: TokenStream) -> TokenStream {
    // Populate the process environment from a `.env` file if one exists.
    // `.ok()` deliberately ignores a missing or unreadable file.
    dotenv().ok();

    let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");

    let context = args.to_string();

    // One attribute per line, followed by the signature.
    // NOTE(review): `ast.sig` does not include the function's visibility (`ast.vis`),
    // so a `pub` qualifier is not part of this head. Whether the generated function
    // stays `pub` depends on `generate_body_function_from_head` — confirm.
    let mut prompt_input: String = ast
        .attrs
        .iter()
        .map(|attr| format!("{}\n", attr.to_token_stream()))
        .collect();
    prompt_input.push_str(&ast.sig.to_token_stream().to_string());

    let implemented_fn = generate_body_function_from_head(prompt_input, Some(context))
        .expect("auto_implement: failed to generate the function implementation");

    implemented_fn.parse::<TokenStream>().unwrap_or_else(|err| {
        panic!(
            "auto_implement: generated code is not valid Rust tokens ({}):\n{}",
            err, implemented_fn
        )
    })
}