// intelli_shell/service/ai.rs

1use std::{fmt::Write, sync::LazyLock};
2
3use futures_util::{Stream, stream};
4use itertools::Itertools;
5use regex::{Captures, Regex};
6use tokio::io::{AsyncRead, AsyncReadExt};
7use tracing::instrument;
8
9use super::IntelliShellService;
10use crate::{
11    ai::CommandFix,
12    errors::{AppError, Result, UserFacingError},
13    model::{CATEGORY_USER, Command, SOURCE_AI, SearchMode},
14    utils::{
15        add_tags_to_description, execute_shell_command_capture, generate_working_dir_tree, get_executable_version,
16        get_os_info, get_shell_info,
17    },
18};
19
/// Maximum depth level to include in the working directory tree
/// rendered into AI prompts (see the `##WORKING_DIR##` placeholder).
const WD_MAX_DEPTH: usize = 5;
/// Maximum number of entries displayed on the working directory tree
/// rendered into AI prompts, to keep the prompt size bounded.
const WD_ENTRY_LIMIT: usize = 30;

/// Progress events for AI fix command, reported through the `on_progress`
/// callback of [`IntelliShellService::fix_command`]
#[derive(Debug)]
pub enum AiFixProgress {
    /// The command has already been executed and the AI is now processing the request
    Thinking,
}
31
impl IntelliShellService {
    /// Tries to fix a failing command by using an AI model.
    ///
    /// The command is executed first (capturing status and output); the AI provider is
    /// only consulted when the execution actually failed. `history` is an optional shell
    /// history excerpt injected into the system prompt via `##SHELL_HISTORY##`.
    ///
    /// If the command was successfully executed, this method will return [None].
    ///
    /// # Errors
    /// Returns [`UserFacingError::AiRequired`] when AI is disabled and
    /// [`UserFacingError::AiEmptyCommand`] when `command` is blank.
    #[instrument(skip_all)]
    pub async fn fix_command<F>(
        &self,
        command: &str,
        history: Option<&str>,
        mut on_progress: F,
    ) -> Result<Option<CommandFix>>
    where
        F: FnMut(AiFixProgress),
    {
        // Check if ai is enabled
        if !self.ai.enabled {
            return Err(UserFacingError::AiRequired.into());
        }

        // Make sure we've got a command to fix
        if command.trim().is_empty() {
            return Err(UserFacingError::AiEmptyCommand.into());
        }

        // Execute the command and capture its output
        let (status, output, terminated_by_ctrl_c) = execute_shell_command_capture(command, true).await?;

        // If the command was interrupted by Ctrl+C, skip the fix
        if terminated_by_ctrl_c {
            tracing::info!("Command execution was interrupted by user (Ctrl+C), skipping fix");
            return Ok(None);
        }

        // If the command succeeded, return without fix
        if status.success() {
            tracing::info!("The command to fix was successfully executed, skipping fix");
            return Ok(None);
        }

        // Notify the caller that the (potentially slow) AI call is about to start
        on_progress(AiFixProgress::Thinking);

        // Prepare prompts and call the AI provider
        // The root command (first whitespace-separated token) lets the prompt include
        // the executable's version via the ##OS_SHELL_INFO## placeholder
        let root_cmd = command.split_whitespace().next();
        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.fix, root_cmd, history);
        let user_prompt = format!(
            "I've run a command but it failed, help me fix it.\n\ncommand: \
             {command}\n{status}\noutput:\n```\n{output}\n```"
        );

        tracing::trace!("System Prompt:\n{sys_prompt}");
        tracing::trace!("User Prompt:\n{user_prompt}");

        // Call provider
        let fix = self
            .ai
            .fix_client()?
            .generate_command_fix(&sys_prompt, &user_prompt)
            .await?;

        Ok(Some(fix))
    }

    /// Suggest command templates from an user prompt using an AI model
    ///
    /// Every suggestion returned by the provider is mapped to a user-category
    /// [`Command`] tagged with the AI source.
    ///
    /// # Errors
    /// Returns [`UserFacingError::AiRequired`] when AI is disabled.
    #[instrument(skip_all)]
    pub async fn suggest_commands(&self, prompt: &str) -> Result<Vec<Command>> {
        // Check if ai is enabled
        if !self.ai.enabled {
            return Err(UserFacingError::AiRequired.into());
        }

        // Prepare prompts and call the AI provider
        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.suggest, None, None);

        tracing::trace!("System Prompt:\n{sys_prompt}");

        // Call provider
        let res = self
            .ai
            .suggest_client()?
            .generate_command_suggestions(&sys_prompt, prompt)
            .await?;

        Ok(res
            .suggestions
            .into_iter()
            .map(|s| Command::new(CATEGORY_USER, SOURCE_AI, s.command).with_description(Some(s.description)))
            .collect())
    }

    /// Suggest a command template from a command and description using an AI model
    ///
    /// Either input may be blank: whichever non-empty text remains is used as the goal,
    /// and a non-empty `cmd` is additionally offered to the model as a base to build on.
    /// Only the first suggestion returned by the provider is kept.
    ///
    /// Returns [`None`] when both inputs are blank.
    ///
    /// # Errors
    /// Returns [`UserFacingError::AiRequired`] when AI is disabled.
    #[instrument(skip_all)]
    pub async fn suggest_command(&self, cmd: impl AsRef<str>, description: impl AsRef<str>) -> Result<Option<Command>> {
        // Check if ai is enabled
        if !self.ai.enabled {
            return Err(UserFacingError::AiRequired.into());
        }

        // Normalize inputs: treat blank strings as absent
        let cmd = Some(cmd.as_ref().trim()).filter(|c| !c.is_empty());
        let description = Some(description.as_ref().trim()).filter(|d| !d.is_empty());

        // Prepare prompts and call the AI provider
        let intro = "Output a single suggestion, with just one command template.";
        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.suggest, None, None);
        let user_prompt = match (cmd, description) {
            (Some(cmd), Some(desc)) => format!("{intro}\nGoal: {desc}\nYou can use this as the base: {cmd}"),
            (Some(prompt), None) | (None, Some(prompt)) => format!("{intro}\nGoal: {prompt}"),
            (None, None) => return Ok(None),
        };

        tracing::trace!("System Prompt:\n{sys_prompt}");
        tracing::trace!("User Prompt:\n{user_prompt}");

        // Call provider
        let res = self
            .ai
            .suggest_client()?
            .generate_command_suggestions(&sys_prompt, &user_prompt)
            .await?;

        Ok(res
            .suggestions
            .into_iter()
            .next()
            .map(|s| Command::new(CATEGORY_USER, SOURCE_AI, s.command).with_description(Some(s.description))))
    }

    /// Extracts command templates from a given content using an AI model
    ///
    /// The whole `content` is read into memory and sent as a single prompt; an empty
    /// content short-circuits to an empty stream without calling the provider. Each
    /// extracted command gets the given `category`/`source`, and `tags` (when any) are
    /// appended to its description.
    ///
    /// # Errors
    /// Returns [`UserFacingError::AiRequired`] when AI is disabled, or an error if
    /// reading `content` fails.
    #[instrument(skip_all)]
    pub(super) async fn prompt_commands_import(
        &self,
        mut content: impl AsyncRead + Unpin + Send,
        tags: Vec<String>,
        category: impl Into<String>,
        source: impl Into<String>,
    ) -> Result<impl Stream<Item = Result<Command>> + Send + 'static> {
        // Check if ai is enabled
        if !self.ai.enabled {
            return Err(UserFacingError::AiRequired.into());
        }

        // Read the content
        let mut prompt = String::new();
        content.read_to_string(&mut prompt).await?;

        let suggestions = if prompt.is_empty() {
            Vec::new()
        } else {
            // Prepare prompts and call the AI provider
            let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.import, None, None);

            tracing::trace!("System Prompt:\n{sys_prompt}");

            // Call provider
            let res = self
                .ai
                .suggest_client()?
                .generate_command_suggestions(&sys_prompt, &prompt)
                .await?;

            res.suggestions
        };

        // Return commands
        // `category` and `source` are converted once and cloned per command inside the
        // `move` closure, since the stream must be 'static
        let category = category.into();
        let source = source.into();
        Ok(stream::iter(
            suggestions
                .into_iter()
                .map(move |s| {
                    let mut description = s.description;
                    if !tags.is_empty() {
                        description = add_tags_to_description(&tags, description);
                    }
                    Command::new(category.clone(), source.clone(), s.command).with_description(Some(description))
                })
                .map(Ok),
        ))
    }

    /// Suggest a command for a dynamic completion using an AI model
    ///
    /// Searches the stored command templates that reference the `{{variable}}`
    /// placeholder (optionally restricted to those starting with `root_cmd`) and feeds
    /// them to the model as context, together with the optional free-form `description`.
    ///
    /// # Errors
    /// Returns [`UserFacingError::AiRequired`] when AI is disabled and
    /// [`UserFacingError::CompletionEmptyVariable`] when `variable` is blank.
    #[instrument(skip_all)]
    pub async fn suggest_completion(
        &self,
        root_cmd: impl AsRef<str>,
        variable: impl AsRef<str>,
        description: impl AsRef<str>,
    ) -> Result<String> {
        // Check if ai is enabled
        if !self.ai.enabled {
            return Err(UserFacingError::AiRequired.into());
        }

        // Prepare variables: treat blank strings as absent
        let root_cmd = Some(root_cmd.as_ref().trim()).filter(|c| !c.is_empty());
        let variable = Some(variable.as_ref().trim()).filter(|v| !v.is_empty());
        let description = Some(description.as_ref().trim()).filter(|d| !d.is_empty());
        let Some(variable) = variable else {
            return Err(UserFacingError::CompletionEmptyVariable.into());
        };

        // Build a regex to match commands that would use the required completion
        // Matches `{{variable}}` allowing extra `|`/`:`-separated parts on either side
        let escaped_variable = regex::escape(variable);
        let variable_pattern = format!(r"\{{\{{(?:[^}}]+[|:])?{escaped_variable}(?:[|:][^}}]+)?\}}\}}");
        let cmd_regex = if let Some(root_cmd) = root_cmd {
            let root_cmd = regex::escape(root_cmd);
            format!(r"^{root_cmd}\s.*{variable_pattern}.*$")
        } else {
            format!(r"^.*{variable_pattern}.*$")
        };

        // Find those commands
        let (commands, _) = self
            .search_commands(SearchMode::Regex, false, &cmd_regex)
            .await
            .map_err(AppError::into_report)?;
        let commands_str = commands.into_iter().map(|c| c.cmd).join("\n");

        // Prepare prompts and call the AI provider
        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.completion, None, None);
        let mut user_prompt = String::new();
        // Writing to a String is infallible, hence the unwraps below
        writeln!(
            user_prompt,
            "Write a shell command that generates completion suggestions for the `{variable}` variable."
        )
        .unwrap();
        if let Some(rc) = root_cmd {
            writeln!(
                user_prompt,
                "This completion will be used only for commands starting with `{rc}`."
            )
            .unwrap();
        }
        if !commands_str.is_empty() {
            writeln!(
                user_prompt,
                "\nFor context, here are some existing command templates that use this \
                 variable:\n---\n{commands_str}\n---"
            )
            .unwrap();
        }
        if let Some(d) = description {
            writeln!(user_prompt, "\n{d}").unwrap();
        }

        tracing::trace!("System Prompt:\n{sys_prompt}");
        tracing::trace!("User Prompt:\n{user_prompt}");

        // Call provider
        let res = self
            .ai
            .completion_client()?
            .generate_completion_suggestion(&sys_prompt, &user_prompt)
            .await?;

        Ok(res.command)
    }
}
289
290/// Replace placeholders present on the prompt for its value
291fn replace_prompt_placeholders(prompt: &str, root_cmd: Option<&str>, history: Option<&str>) -> String {
292    // Regex to find placeholders like ##VAR_NAME##
293    static PROMPT_PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"##([A-Z_]+)##").unwrap());
294
295    PROMPT_PLACEHOLDER_RE
296        .replace_all(prompt, |caps: &Captures| match &caps[1] {
297            "OS_SHELL_INFO" => {
298                let shell_info = get_shell_info();
299                let os_info = get_os_info();
300                format!(
301                    "### Context:\n- {os_info}\n- {}{}\n",
302                    shell_info
303                        .version
304                        .clone()
305                        .unwrap_or_else(|| shell_info.kind.to_string()),
306                    root_cmd
307                        .and_then(get_executable_version)
308                        .map(|v| format!("\n- {v}"))
309                        .unwrap_or_default(),
310                )
311            }
312            "WORKING_DIR" => generate_working_dir_tree(WD_MAX_DEPTH, WD_ENTRY_LIMIT).unwrap_or_default(),
313            "SHELL_HISTORY" => history
314                .map(|h| format!("### User Shell History (oldest to newest):\n{h}\n"))
315                .unwrap_or_default(),
316            _ => {
317                tracing::warn!("Prompt placeholder '{}' not recognized", &caps[0]);
318                String::default()
319            }
320        })
321        .to_string()
322}