intelli_shell/service/ai.rs

1use std::{fmt::Write, sync::LazyLock};
2
3use futures_util::{Stream, stream};
4use itertools::Itertools;
5use regex::{Captures, Regex};
6use tokio::io::{AsyncRead, AsyncReadExt};
7use tokio_util::sync::CancellationToken;
8use tracing::instrument;
9
10use super::IntelliShellService;
11use crate::{
12    ai::CommandFix,
13    errors::{AppError, Result, UserFacingError},
14    model::{CATEGORY_USER, Command, SOURCE_AI, SearchMode},
15    utils::{
16        add_tags_to_description, execute_shell_command_capture, generate_working_dir_tree, get_executable_version,
17        get_os_info, get_shell_info,
18    },
19};
20
/// Upper bound on the number of entries rendered in the working directory tree
const WD_ENTRY_LIMIT: usize = 30;
/// How many directory levels deep the working directory tree is allowed to go
const WD_MAX_DEPTH: usize = 5;
25
/// Progress events reported through the `on_progress` callback of
/// [`IntelliShellService::fix_command`].
///
/// The variants carry no data, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AiFixProgress {
    /// The command has already been executed and the AI is now processing the request
    Thinking,
}
32
33impl IntelliShellService {
34    /// Tries to fix a failing command by using an AI model.
35    ///
36    /// If the command was successfully executed, this method will return [None].
37    #[instrument(skip_all)]
38    pub async fn fix_command<F>(
39        &self,
40        command: &str,
41        history: Option<&str>,
42        mut on_progress: F,
43        cancellation_token: CancellationToken,
44    ) -> Result<Option<CommandFix>>
45    where
46        F: FnMut(AiFixProgress),
47    {
48        // Check if ai is enabled
49        if !self.ai.enabled {
50            return Err(UserFacingError::AiRequired.into());
51        }
52
53        // Make sure we've got a command to fix
54        if command.trim().is_empty() {
55            return Err(UserFacingError::AiEmptyCommand.into());
56        }
57
58        // Execute the command and capture its output
59        let (status, output, terminated_by_ctrl_c) =
60            execute_shell_command_capture(command, true, cancellation_token.clone()).await?;
61
62        // If the command was interrupted by Ctrl+C, skip the fix
63        if terminated_by_ctrl_c {
64            tracing::info!("Command execution was interrupted by user (Ctrl+C), skipping fix");
65            return Ok(None);
66        }
67
68        // If the command succeeded, return without fix
69        if status.success() {
70            tracing::info!("The command to fix was successfully executed, skipping fix");
71            return Ok(None);
72        }
73
74        on_progress(AiFixProgress::Thinking);
75
76        // Prepare prompts and call the AI provider
77        let root_cmd = command.split_whitespace().next();
78        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.fix, root_cmd, history);
79        let user_prompt = format!(
80            "I've run a command but it failed, help me fix it.\n\ncommand: \
81             {command}\n{status}\noutput:\n```\n{output}\n```"
82        );
83
84        tracing::trace!("System Prompt:\n{sys_prompt}");
85        tracing::trace!("User Prompt:\n{user_prompt}");
86
87        // Call provider
88        let fix = self
89            .ai
90            .fix_client()?
91            .generate_command_fix(&sys_prompt, &user_prompt, cancellation_token)
92            .await?;
93
94        Ok(Some(fix))
95    }
96
97    /// Suggest command templates from an user prompt using an AI model
98    #[instrument(skip_all)]
99    pub async fn suggest_commands(&self, prompt: &str, cancellation_token: CancellationToken) -> Result<Vec<Command>> {
100        // Check if ai is enabled
101        if !self.ai.enabled {
102            return Err(UserFacingError::AiRequired.into());
103        }
104
105        // Prepare prompts and call the AI provider
106        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.suggest, None, None);
107
108        tracing::trace!("System Prompt:\n{sys_prompt}");
109
110        // Call provider
111        let res = self
112            .ai
113            .suggest_client()?
114            .generate_command_suggestions(&sys_prompt, prompt, cancellation_token)
115            .await?;
116
117        Ok(res
118            .suggestions
119            .into_iter()
120            .map(|s| Command::new(CATEGORY_USER, SOURCE_AI, s.command).with_description(Some(s.description)))
121            .collect())
122    }
123
124    /// Suggest a command template from a command and description using an AI model
125    #[instrument(skip_all)]
126    pub async fn suggest_command(
127        &self,
128        cmd: impl AsRef<str>,
129        description: impl AsRef<str>,
130        cancellation_token: CancellationToken,
131    ) -> Result<Option<Command>> {
132        // Check if ai is enabled
133        if !self.ai.enabled {
134            return Err(UserFacingError::AiRequired.into());
135        }
136
137        let cmd = Some(cmd.as_ref().trim()).filter(|c| !c.is_empty());
138        let description = Some(description.as_ref().trim()).filter(|d| !d.is_empty());
139
140        // Prepare prompts and call the AI provider
141        let intro = "Output a single suggestion, with just one command template.";
142        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.suggest, None, None);
143        let user_prompt = match (cmd, description) {
144            (Some(cmd), Some(desc)) => format!("{intro}\nGoal: {desc}\nYou can use this as the base: {cmd}"),
145            (Some(prompt), None) | (None, Some(prompt)) => format!("{intro}\nGoal: {prompt}"),
146            (None, None) => return Ok(None),
147        };
148
149        tracing::trace!("System Prompt:\n{sys_prompt}");
150        tracing::trace!("User Prompt:\n{user_prompt}");
151
152        // Call provider
153        let res = self
154            .ai
155            .suggest_client()?
156            .generate_command_suggestions(&sys_prompt, &user_prompt, cancellation_token)
157            .await?;
158
159        Ok(res
160            .suggestions
161            .into_iter()
162            .next()
163            .map(|s| Command::new(CATEGORY_USER, SOURCE_AI, s.command).with_description(Some(s.description))))
164    }
165
166    /// Extracts command templates from a given content using an AI model
167    #[instrument(skip_all)]
168    pub(super) async fn prompt_commands_import(
169        &self,
170        mut content: impl AsyncRead + Unpin + Send,
171        tags: Vec<String>,
172        category: impl Into<String>,
173        source: impl Into<String>,
174        cancellation_token: CancellationToken,
175    ) -> Result<impl Stream<Item = Result<Command>> + Send + 'static> {
176        // Check if ai is enabled
177        if !self.ai.enabled {
178            return Err(UserFacingError::AiRequired.into());
179        }
180
181        // Read the content
182        let mut prompt = String::new();
183        content.read_to_string(&mut prompt).await?;
184
185        let suggestions = if prompt.is_empty() {
186            Vec::new()
187        } else {
188            // Prepare prompts and call the AI provider
189            let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.import, None, None);
190
191            tracing::trace!("System Prompt:\n{sys_prompt}");
192
193            // Call provider
194            let res = self
195                .ai
196                .suggest_client()?
197                .generate_command_suggestions(&sys_prompt, &prompt, cancellation_token)
198                .await?;
199
200            res.suggestions
201        };
202
203        // Return commands
204        let category = category.into();
205        let source = source.into();
206        Ok(stream::iter(
207            suggestions
208                .into_iter()
209                .map(move |s| {
210                    let mut description = s.description;
211                    if !tags.is_empty() {
212                        description = add_tags_to_description(&tags, description);
213                    }
214                    Command::new(category.clone(), source.clone(), s.command).with_description(Some(description))
215                })
216                .map(Ok),
217        ))
218    }
219
220    /// Suggest a command for a dynamic completion using an AI model
221    #[instrument(skip_all)]
222    pub async fn suggest_completion(
223        &self,
224        root_cmd: impl AsRef<str>,
225        variable: impl AsRef<str>,
226        description: impl AsRef<str>,
227        cancellation_token: CancellationToken,
228    ) -> Result<String> {
229        // Check if ai is enabled
230        if !self.ai.enabled {
231            return Err(UserFacingError::AiRequired.into());
232        }
233
234        // Prepare variables
235        let root_cmd = Some(root_cmd.as_ref().trim()).filter(|c| !c.is_empty());
236        let variable = Some(variable.as_ref().trim()).filter(|v| !v.is_empty());
237        let description = Some(description.as_ref().trim()).filter(|d| !d.is_empty());
238        let Some(variable) = variable else {
239            return Err(UserFacingError::CompletionEmptyVariable.into());
240        };
241
242        // Build a regex to match commands that would use the required completion
243        let escaped_variable = regex::escape(variable);
244        let variable_pattern = format!(r"\{{\{{(?:[^}}]+[|:])?{escaped_variable}(?:[|:][^}}]+)?\}}\}}");
245        let cmd_regex = if let Some(root_cmd) = root_cmd {
246            let root_cmd = regex::escape(root_cmd);
247            format!(r"^{root_cmd}\s.*{variable_pattern}.*$")
248        } else {
249            format!(r"^.*{variable_pattern}.*$")
250        };
251
252        // Find those commands
253        let (commands, _) = self
254            .search_commands(SearchMode::Regex, false, &cmd_regex)
255            .await
256            .map_err(AppError::into_report)?;
257        let commands_str = commands.into_iter().map(|c| c.cmd).join("\n");
258
259        // Prepare prompts and call the AI provider
260        let sys_prompt = replace_prompt_placeholders(&self.ai.prompts.completion, None, None);
261        let mut user_prompt = String::new();
262        writeln!(
263            user_prompt,
264            "Write a shell command that generates completion suggestions for the `{variable}` variable."
265        )
266        .unwrap();
267        if let Some(rc) = root_cmd {
268            writeln!(
269                user_prompt,
270                "This completion will be used only for commands starting with `{rc}`."
271            )
272            .unwrap();
273        }
274        if !commands_str.is_empty() {
275            writeln!(
276                user_prompt,
277                "\nFor context, here are some existing command templates that use this \
278                 variable:\n---\n{commands_str}\n---"
279            )
280            .unwrap();
281        }
282        if let Some(d) = description {
283            writeln!(user_prompt, "\n{d}").unwrap();
284        }
285
286        tracing::trace!("System Prompt:\n{sys_prompt}");
287        tracing::trace!("User Prompt:\n{user_prompt}");
288
289        // Call provider
290        let res = self
291            .ai
292            .completion_client()?
293            .generate_completion_suggestion(&sys_prompt, &user_prompt, cancellation_token)
294            .await?;
295
296        Ok(res.command)
297    }
298}
299
300/// Replace placeholders present on the prompt for its value
301fn replace_prompt_placeholders(prompt: &str, root_cmd: Option<&str>, history: Option<&str>) -> String {
302    // Regex to find placeholders like ##VAR_NAME##
303    static PROMPT_PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"##([A-Z_]+)##").unwrap());
304
305    PROMPT_PLACEHOLDER_RE
306        .replace_all(prompt, |caps: &Captures| match &caps[1] {
307            "OS_SHELL_INFO" => {
308                let shell_info = get_shell_info();
309                let os_info = get_os_info();
310                format!(
311                    "### Context:\n- {os_info}\n- {}{}\n",
312                    shell_info
313                        .version
314                        .clone()
315                        .unwrap_or_else(|| shell_info.kind.to_string()),
316                    root_cmd
317                        .and_then(get_executable_version)
318                        .map(|v| format!("\n- {v}"))
319                        .unwrap_or_default(),
320                )
321            }
322            "WORKING_DIR" => generate_working_dir_tree(WD_MAX_DEPTH, WD_ENTRY_LIMIT).unwrap_or_default(),
323            "SHELL_HISTORY" => history
324                .map(|h| format!("### User Shell History (oldest to newest):\n{h}\n"))
325                .unwrap_or_default(),
326            _ => {
327                tracing::warn!("Prompt placeholder '{}' not recognized", &caps[0]);
328                String::default()
329            }
330        })
331        .to_string()
332}