Skip to main content

agent/
agent.rs

1use chrono::Local;
2use glob::glob;
3use regex::Regex;
4use rustyline::DefaultEditor;
5use rustyline::error::ReadlineError;
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::fs;
9use std::io::{self, Write};
10use std::path::PathBuf;
11use std::process::Command;
12use std::sync::{
13    Arc,
14    atomic::{AtomicBool, Ordering},
15};
16use std::thread;
17use std::time::Duration;
18use walkdir::WalkDir;
19
20use menta::{GenerateTextRequest, ModelMessage, Tool, ToolChoice, ToolExecute, generate_text};
21
// ANSI SGR escape sequences used to style terminal output.

/// Restores the terminal's default text attributes.
const RESET: &str = "\x1b[0m";
/// Faint (dimmed) text.
const DIM: &str = "\x1b[2m";
/// Blue foreground.
const BLUE: &str = "\x1b[34m";
/// Cyan foreground.
const CYAN: &str = "\x1b[36m";
/// Green foreground.
const GREEN: &str = "\x1b[32m";
/// Yellow foreground.
const YELLOW: &str = "\x1b[33m";
/// Red foreground.
const RED: &str = "\x1b[31m";

/// Returns `value` prefixed with the escape sequence `color` and followed by
/// [`RESET`], so the styling does not leak into whatever is printed next.
fn paint(color: &str, value: impl AsRef<str>) -> String {
    let text = value.as_ref();

    let mut painted = String::with_capacity(color.len() + text.len() + RESET.len());
    painted.push_str(color);
    painted.push_str(text);
    painted.push_str(RESET);
    painted
}
33
/// Flags shared between the spinner's owner and its background drawing thread.
struct SpinnerState {
    /// The drawing thread keeps looping while this is `true` and exits once it is `false`.
    active: AtomicBool,
    /// While this is `true` the drawing thread skips rendering frames.
    paused: AtomicBool,
}
38
39struct LoadingIndicator {
40    state: Arc<SpinnerState>,
41    handle: Option<thread::JoinHandle<()>>,
42}
43
44impl LoadingIndicator {
45    fn start(message: &'static str) -> Self {
46        let state = Arc::new(SpinnerState {
47            active: AtomicBool::new(true),
48            paused: AtomicBool::new(false),
49        });
50        let state_for_thread = Arc::clone(&state);
51
52        let handle = thread::spawn(move || {
53            let frames = ['|', '/', '-', '\\'];
54            let mut index = 0;
55
56            while state_for_thread.active.load(Ordering::Relaxed) {
57                if state_for_thread.paused.load(Ordering::Relaxed) {
58                    thread::sleep(Duration::from_millis(25));
59                    continue;
60                }
61
62                print!(
63                    "\r{} {}",
64                    paint(CYAN, frames[index % frames.len()].to_string()),
65                    paint(DIM, message),
66                );
67                let _ = io::stdout().flush();
68                index += 1;
69                thread::sleep(Duration::from_millis(100));
70            }
71        });
72
73        Self {
74            state,
75            handle: Some(handle),
76        }
77    }
78
79    fn state(&self) -> Arc<SpinnerState> {
80        Arc::clone(&self.state)
81    }
82
83    fn stop(mut self) {
84        self.state.active.store(false, Ordering::Relaxed);
85        if let Some(handle) = self.handle.take() {
86            let _ = handle.join();
87        }
88        print!("\r{}\r", " ".repeat(40));
89        let _ = io::stdout().flush();
90    }
91}
92
93struct SpinnerPause {
94    state: Arc<SpinnerState>,
95}
96
97impl SpinnerPause {
98    fn new(state: Arc<SpinnerState>) -> Self {
99        state.paused.store(true, Ordering::Relaxed);
100        print!("\r{}\r", " ".repeat(60));
101        let _ = io::stdout().flush();
102        Self { state }
103    }
104}
105
106impl Drop for SpinnerPause {
107    fn drop(&mut self) {
108        self.state.paused.store(false, Ordering::Relaxed);
109    }
110}
111
112fn logged_tool<T>(verbose: bool, spinner_state: Arc<SpinnerState>) -> Tool
113where
114    T: ToolExecute,
115{
116    let definition = T::definition();
117    let name = definition.name.clone();
118
119    Tool::new_async(
120        definition.name,
121        definition.description,
122        definition.input_schema,
123        definition.output_schema,
124        move |input| {
125            let name = name.clone();
126            let spinner_state = Arc::clone(&spinner_state);
127            async move {
128                let _pause = SpinnerPause::new(spinner_state);
129                println!("{} {} {}", paint(YELLOW, "tool>"), paint(YELLOW, &name), paint(DIM, &input));
130
131                let parsed = serde_json::from_str::<T>(&input)
132                    .map_err(|error| format!("invalid tool input for {name}: {error}"))?;
133                let output = parsed.execute().await?;
134                let output = serde_json::to_string(&output)
135                    .map_err(|error| format!("invalid tool output for {name}: {error}"))?;
136
137                if verbose {
138                    println!("{} {} {}", paint(YELLOW, "tool<"), paint(YELLOW, &name), paint(DIM, &output));
139                }
140                Ok(output)
141            }
142        },
143    )
144}
145
146#[derive(Deserialize, Tool)]
147#[tool(description = "Get the weather in a location")]
148struct WeatherTool {
149    #[description = "The location to get the weather for"]
150    location: String,
151}
152
153#[derive(Serialize)]
154struct WeatherOutput {
155    temperature: i32,
156    conditions: String,
157}
158
159impl ToolExecute for WeatherTool {
160    type Output = WeatherOutput;
161
162    async fn execute(&self) -> std::result::Result<Self::Output, String> {
163        Ok(WeatherOutput {
164            temperature: 72,
165            conditions: format!("sunny in {}", self.location),
166        })
167    }
168}
169
170#[derive(Deserialize, Tool)]
171#[tool(description = "Run a local shell command")]
172struct BashTool {
173    #[description = "The shell command to run"]
174    command: String,
175}
176
177#[derive(Serialize)]
178struct BashOutput {
179    success: bool,
180    status: i32,
181    stdout: String,
182    stderr: String,
183}
184
185impl ToolExecute for BashTool {
186    type Output = BashOutput;
187
188    async fn execute(&self) -> std::result::Result<Self::Output, String> {
189        let output = Command::new("sh")
190            .arg("-c")
191            .arg(&self.command)
192            .output()
193            .map_err(|error| error.to_string())?;
194
195        Ok(BashOutput {
196            success: output.status.success(),
197            status: output.status.code().unwrap_or(-1),
198            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
199            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
200        })
201    }
202}
203
204#[derive(Deserialize, Tool)]
205#[tool(description = "Write content to a local file")]
206struct WriteTool {
207    #[description = "Path to the file to write"]
208    path: String,
209    #[description = "Full file content to write"]
210    content: String,
211}
212
213#[derive(Serialize)]
214struct WriteOutput {
215    path: String,
216    bytes_written: usize,
217}
218
219impl ToolExecute for WriteTool {
220    type Output = WriteOutput;
221
222    async fn execute(&self) -> std::result::Result<Self::Output, String> {
223        if let Some(parent) = PathBuf::from(&self.path).parent() {
224            if !parent.as_os_str().is_empty() {
225                fs::create_dir_all(parent).map_err(|error| error.to_string())?;
226            }
227        }
228
229        fs::write(&self.path, &self.content).map_err(|error| error.to_string())?;
230
231        Ok(WriteOutput {
232            path: self.path.clone(),
233            bytes_written: self.content.len(),
234        })
235    }
236}
237
238#[derive(Deserialize, Tool)]
239#[tool(description = "Read a local file")]
240struct ReadTool {
241    #[description = "Path to the file to read"]
242    path: String,
243}
244
245#[derive(Serialize)]
246struct ReadOutput {
247    path: String,
248    content: String,
249}
250
251impl ToolExecute for ReadTool {
252    type Output = ReadOutput;
253
254    async fn execute(&self) -> std::result::Result<Self::Output, String> {
255        let content = fs::read_to_string(&self.path).map_err(|error| error.to_string())?;
256
257        Ok(ReadOutput {
258            path: self.path.clone(),
259            content,
260        })
261    }
262}
263
264#[derive(Deserialize, Tool)]
265#[tool(description = "Find files by glob pattern")]
266struct GlobTool {
267    #[description = "Glob pattern like **/*.rs or src/**/*.rs"]
268    pattern: String,
269}
270
271#[derive(Serialize)]
272struct GlobOutput {
273    matches: Vec<String>,
274}
275
276impl ToolExecute for GlobTool {
277    type Output = GlobOutput;
278
279    async fn execute(&self) -> std::result::Result<Self::Output, String> {
280        let matches = glob(&self.pattern)
281            .map_err(|error| error.to_string())?
282            .filter_map(|entry| entry.ok())
283            .map(|path| path.display().to_string())
284            .collect::<Vec<_>>();
285
286        Ok(GlobOutput { matches })
287    }
288}
289
290#[derive(Deserialize, Tool)]
291#[tool(description = "Search file contents with a regex pattern")]
292struct GrepTool {
293    #[description = "Regex pattern to search for"]
294    pattern: String,
295    #[description = "Root path to search from, defaults to current directory"]
296    path: Option<String>,
297}
298
299#[derive(Serialize)]
300struct GrepOutput {
301    matches: Vec<String>,
302}
303
304impl ToolExecute for GrepTool {
305    type Output = GrepOutput;
306
307    async fn execute(&self) -> std::result::Result<Self::Output, String> {
308        let regex = Regex::new(&self.pattern).map_err(|error| error.to_string())?;
309        let root = self.path.as_deref().unwrap_or(".");
310        let mut matches = Vec::new();
311
312        for entry in WalkDir::new(root) {
313            let entry = entry.map_err(|error| error.to_string())?;
314            if !entry.file_type().is_file() {
315                continue;
316            }
317
318            let path = entry.path();
319            let Ok(content) = fs::read_to_string(path) else {
320                continue;
321            };
322
323            for (index, line) in content.lines().enumerate() {
324                if regex.is_match(line) {
325                    matches.push(format!("{}:{}:{}", path.display(), index + 1, line));
326                }
327            }
328        }
329
330        Ok(GrepOutput { matches })
331    }
332}
333
334#[tokio::main]
335async fn main() {
336    let verbose = env::args().skip(1).any(|arg| arg == "--verbose");
337
338    println!("{}", paint(BLUE, "menta agent example"));
339    println!("{}", paint(DIM, "commands: /new, /model [model-id], exit, quit, !<shell-command>"));
340    if verbose {
341        println!("{}", paint(DIM, "verbose tool logging enabled"));
342    }
343
344    let mut model = String::from("openai/gpt-4.1-mini");
345    let mut history = vec![ModelMessage::system(system_prompt())];
346    let mut editor = DefaultEditor::new().expect("failed to initialize line editor");
347
348    loop {
349        let prompt = match editor.readline(&format!("[{model}] > ")) {
350            Ok(prompt) => prompt,
351            Err(ReadlineError::Interrupted) => {
352                println!();
353                continue;
354            }
355            Err(ReadlineError::Eof) => break,
356            Err(error) => {
357                eprintln!("error: failed to read input: {error}");
358                break;
359            }
360        };
361
362        let prompt = prompt.trim();
363        if prompt.is_empty() {
364            continue;
365        }
366
367        let _ = editor.add_history_entry(prompt);
368
369        if matches!(prompt, "exit" | "quit") {
370            break;
371        }
372
373        if prompt == "/new" {
374            history = vec![ModelMessage::system(system_prompt())];
375            println!("started new session");
376            continue;
377        }
378
379        if let Some(next_model) = prompt.strip_prefix("/model") {
380            let next_model = next_model.trim();
381            if next_model.is_empty() {
382                println!("{} {}", paint(BLUE, "current model:"), paint(CYAN, &model));
383            } else {
384                model = next_model.to_string();
385                println!("{} {}", paint(BLUE, "switched model to"), paint(CYAN, &model));
386            }
387            continue;
388        }
389
390        if let Some(command) = prompt.strip_prefix('!') {
391            let command = command.trim();
392            if command.is_empty() {
393                eprintln!("{} missing shell command after !", paint(RED, "error:"));
394                continue;
395            }
396
397            match Command::new("sh").arg("-c").arg(command).output() {
398                Ok(output) => {
399                    let stdout = String::from_utf8_lossy(&output.stdout);
400                    let stderr = String::from_utf8_lossy(&output.stderr);
401
402                    if !stdout.trim().is_empty() {
403                        print!("{stdout}");
404                    }
405
406                    if !stderr.trim().is_empty() {
407                        eprint!("{stderr}");
408                    }
409
410                    if !output.status.success() {
411                        eprintln!("{} command exited with status {}", paint(RED, "error:"), output.status);
412                    }
413                }
414                Err(error) => eprintln!("{} failed to run command: {error}", paint(RED, "error:")),
415            }
416            continue;
417        }
418
419        let mut messages = history.clone();
420        messages.push(ModelMessage::user(prompt));
421
422        let loading = LoadingIndicator::start("thinking...");
423        let spinner_state = loading.state();
424
425        let result = generate_text(
426            GenerateTextRequest::new()
427                .model(model.clone())
428                .messages(messages)
429                .tools(vec![
430                    logged_tool::<WeatherTool>(verbose, Arc::clone(&spinner_state)),
431                    logged_tool::<BashTool>(verbose, Arc::clone(&spinner_state)),
432                    logged_tool::<WriteTool>(verbose, Arc::clone(&spinner_state)),
433                    logged_tool::<ReadTool>(verbose, Arc::clone(&spinner_state)),
434                    logged_tool::<GlobTool>(verbose, Arc::clone(&spinner_state)),
435                    logged_tool::<GrepTool>(verbose, Arc::clone(&spinner_state)),
436                ])
437                .tool_choice(ToolChoice::Auto)
438                .max_steps(4),
439        )
440        .await;
441
442        loading.stop();
443
444        match result {
445            Ok(result) => {
446                println!("{} {}", paint(GREEN, "assistant>"), result.output);
447                history.push(ModelMessage::user(prompt));
448                history.push(ModelMessage::assistant_text(result.text));
449            }
450            Err(error) => eprintln!("{} {error}", paint(RED, "error:")),
451        }
452    }
453}
454
455fn system_prompt() -> String {
456    let date = Local::now().format("%Y-%m-%d").to_string();
457    let pwd = env::current_dir()
458        .map(|path| path.display().to_string())
459        .unwrap_or_else(|_| String::from("unknown"));
460
461    format!(
462        "You are a local coding assistant. Current date: {date}. Current working directory: {pwd}. Use tools when needed: weather, bash, write, read, glob, grep. Prefer tools for filesystem and shell tasks."
463    )
464}