Skip to main content

aster/agents/
execute_commands.rs

1use std::collections::HashMap;
2
3use anyhow::{anyhow, Result};
4
5use crate::context_mgmt::compact_messages;
6use crate::conversation::message::{Message, SystemNotificationType};
7use crate::recipe::build_recipe::build_recipe_from_template_with_positional_params;
8use crate::session::SessionManager;
9
10use super::Agent;
11
/// User inputs that request a manual compaction of the conversation.
///
/// The first entry (`/compact`) is the canonical slash command; the other
/// entries are aliases that are normalized to it before command dispatch.
pub const COMPACT_TRIGGERS: &[&str] = &[
    "/compact",
    "Please compact this conversation",
    "/summarize",
];
14
/// Static description of a built-in slash command.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CommandDef {
    /// Command name without the leading `/` (e.g. `"compact"`).
    pub name: &'static str,
    /// Short human-readable summary of what the command does.
    pub description: &'static str,
}
19
20static COMMANDS: &[CommandDef] = &[
21    CommandDef {
22        name: "prompts",
23        description: "List available prompts, optionally filtered by extension",
24    },
25    CommandDef {
26        name: "prompt",
27        description: "Execute a prompt or show its info with --info",
28    },
29    CommandDef {
30        name: "compact",
31        description: "Compact the conversation history",
32    },
33    CommandDef {
34        name: "clear",
35        description: "Clear the conversation history",
36    },
37];
38
39pub fn list_commands() -> &'static [CommandDef] {
40    COMMANDS
41}
42
43impl Agent {
44    pub async fn execute_command(
45        &self,
46        message_text: &str,
47        session_id: &str,
48    ) -> Result<Option<Message>> {
49        let mut trimmed = message_text.trim().to_string();
50
51        if COMPACT_TRIGGERS.contains(&trimmed.as_str()) {
52            trimmed = COMPACT_TRIGGERS[0].to_string();
53        }
54
55        if !trimmed.starts_with('/') {
56            return Ok(None);
57        }
58
59        let command_str = trimmed.strip_prefix('/').unwrap_or(&trimmed);
60        let (command, params_str) = command_str
61            .split_once(' ')
62            .map(|(cmd, p)| (cmd, p.trim()))
63            .unwrap_or((command_str, ""));
64
65        let params: Vec<&str> = if params_str.is_empty() {
66            vec![]
67        } else {
68            params_str.split_whitespace().collect()
69        };
70
71        match command {
72            "prompts" => self.handle_prompts_command(&params, session_id).await,
73            "prompt" => self.handle_prompt_command(&params, session_id).await,
74            "compact" => self.handle_compact_command(session_id).await,
75            "clear" => self.handle_clear_command(session_id).await,
76            _ => {
77                self.handle_recipe_command(command, params_str, session_id)
78                    .await
79            }
80        }
81    }
82
83    async fn handle_compact_command(&self, session_id: &str) -> Result<Option<Message>> {
84        let session = self.store_get_session(session_id, true).await?;
85        let conversation = session
86            .conversation
87            .ok_or_else(|| anyhow!("Session has no conversation"))?;
88
89        let (compacted_conversation, _usage) = compact_messages(
90            self.provider().await?.as_ref(),
91            &conversation,
92            true, // is_manual_compact
93        )
94        .await?;
95
96        self.store_replace_conversation(session_id, &compacted_conversation)
97            .await?;
98
99        Ok(Some(Message::assistant().with_system_notification(
100            SystemNotificationType::InlineMessage,
101            "Compaction complete",
102        )))
103    }
104
105    async fn handle_clear_command(&self, session_id: &str) -> Result<Option<Message>> {
106        use crate::conversation::Conversation;
107
108        self.store_replace_conversation(session_id, &Conversation::default())
109            .await?;
110
111        if let Some(store) = &self.session_store {
112            use crate::session::TokenStatsUpdate;
113            store
114                .update_token_stats(
115                    session_id,
116                    TokenStatsUpdate {
117                        schedule_id: None,
118                        total_tokens: Some(0),
119                        input_tokens: Some(0),
120                        output_tokens: Some(0),
121                        accumulated_total: None,
122                        accumulated_input: None,
123                        accumulated_output: None,
124                    },
125                )
126                .await?;
127        } else {
128            SessionManager::update_session(session_id)
129                .total_tokens(Some(0))
130                .input_tokens(Some(0))
131                .output_tokens(Some(0))
132                .apply()
133                .await?;
134        }
135
136        Ok(Some(Message::assistant().with_system_notification(
137            SystemNotificationType::InlineMessage,
138            "Conversation cleared",
139        )))
140    }
141
142    async fn handle_prompts_command(
143        &self,
144        params: &[&str],
145        _session_id: &str,
146    ) -> Result<Option<Message>> {
147        let extension_filter = params.first().map(|s| s.to_string());
148
149        let prompts = self.list_extension_prompts().await;
150
151        if let Some(filter) = &extension_filter {
152            if !prompts.contains_key(filter) {
153                let error_msg = format!("Extension '{}' not found", filter);
154                return Ok(Some(Message::assistant().with_text(error_msg)));
155            }
156        }
157
158        let filtered_prompts: HashMap<String, Vec<String>> = prompts
159            .into_iter()
160            .filter(|(ext, _)| extension_filter.as_ref().is_none_or(|f| f == ext))
161            .map(|(extension, prompt_list)| {
162                let names = prompt_list.into_iter().map(|p| p.name).collect();
163                (extension, names)
164            })
165            .collect();
166
167        let mut output = String::new();
168        if filtered_prompts.is_empty() {
169            output.push_str("No prompts available.\n");
170        } else {
171            output.push_str("Available prompts:\n\n");
172            for (extension, prompt_names) in filtered_prompts {
173                output.push_str(&format!("**{}**:\n", extension));
174                for name in prompt_names {
175                    output.push_str(&format!("  - {}\n", name));
176                }
177                output.push('\n');
178            }
179        }
180
181        Ok(Some(Message::assistant().with_text(output)))
182    }
183
184    async fn handle_prompt_command(
185        &self,
186        params: &[&str],
187        session_id: &str,
188    ) -> Result<Option<Message>> {
189        if params.is_empty() {
190            return Ok(Some(
191                Message::assistant().with_text("Prompt name argument is required"),
192            ));
193        }
194
195        let prompt_name = params[0].to_string();
196        let is_info = params.get(1).map(|s| *s == "--info").unwrap_or(false);
197
198        if is_info {
199            let prompts = self.list_extension_prompts().await;
200            let mut prompt_info = None;
201
202            for (extension, prompt_list) in prompts {
203                if let Some(prompt) = prompt_list.iter().find(|p| p.name == prompt_name) {
204                    let mut output = format!("**Prompt: {}**\n\n", prompt.name);
205                    if let Some(desc) = &prompt.description {
206                        output.push_str(&format!("Description: {}\n\n", desc));
207                    }
208                    output.push_str(&format!("Extension: {}\n\n", extension));
209
210                    if let Some(args) = &prompt.arguments {
211                        output.push_str("Arguments:\n");
212                        for arg in args {
213                            output.push_str(&format!("  - {}", arg.name));
214                            if let Some(desc) = &arg.description {
215                                output.push_str(&format!(": {}", desc));
216                            }
217                            output.push('\n');
218                        }
219                    }
220
221                    prompt_info = Some(output);
222                    break;
223                }
224            }
225
226            return Ok(Some(Message::assistant().with_text(
227                prompt_info.unwrap_or_else(|| format!("Prompt '{}' not found", prompt_name)),
228            )));
229        }
230
231        let mut arguments = HashMap::new();
232        for param in params.iter().skip(1) {
233            if let Some((key, value)) = param.split_once('=') {
234                let value = value.trim_matches('"');
235                arguments.insert(key.to_string(), value.to_string());
236            }
237        }
238
239        let arguments_value = serde_json::to_value(arguments)
240            .map_err(|e| anyhow!("Failed to serialize arguments: {}", e))?;
241
242        match self.get_prompt(&prompt_name, arguments_value).await {
243            Ok(prompt_result) => {
244                for (i, prompt_message) in prompt_result.messages.into_iter().enumerate() {
245                    let msg = Message::from(prompt_message);
246
247                    let expected_role = if i % 2 == 0 {
248                        rmcp::model::Role::User
249                    } else {
250                        rmcp::model::Role::Assistant
251                    };
252
253                    if msg.role != expected_role {
254                        let error_msg = format!(
255                            "Expected {:?} message at position {}, but found {:?}",
256                            expected_role, i, msg.role
257                        );
258                        return Ok(Some(Message::assistant().with_text(error_msg)));
259                    }
260
261                    self.store_add_message(session_id, &msg).await?;
262                }
263
264                let last_message = self
265                    .store_get_session(session_id, true)
266                    .await?
267                    .conversation
268                    .ok_or_else(|| anyhow!("No conversation found"))?
269                    .messages()
270                    .last()
271                    .cloned()
272                    .ok_or_else(|| anyhow!("No messages in conversation"))?;
273
274                Ok(Some(last_message))
275            }
276            Err(e) => Ok(Some(
277                Message::assistant().with_text(format!("Error getting prompt: {}", e)),
278            )),
279        }
280    }
281
282    async fn handle_recipe_command(
283        &self,
284        command: &str,
285        params_str: &str,
286        _session_id: &str,
287    ) -> Result<Option<Message>> {
288        let full_command = format!("/{}", command);
289        let recipe_path = match crate::slash_commands::get_recipe_for_command(&full_command) {
290            Some(path) => path,
291            None => return Ok(None),
292        };
293
294        if !recipe_path.exists() {
295            return Ok(None);
296        }
297
298        let recipe_content = std::fs::read_to_string(&recipe_path)
299            .map_err(|e| anyhow!("Failed to read recipe file: {}", e))?;
300
301        let recipe_dir = recipe_path
302            .parent()
303            .ok_or_else(|| anyhow!("Recipe path has no parent directory"))?;
304
305        let recipe_dir_str = recipe_dir.display().to_string();
306        let validation_result =
307            crate::recipe::validate_recipe::validate_recipe_template_from_content(
308                &recipe_content,
309                Some(recipe_dir_str),
310            )
311            .map_err(|e| anyhow!("Failed to parse recipe: {}", e))?;
312
313        let param_values: Vec<String> = if params_str.is_empty() {
314            vec![]
315        } else {
316            let params_without_default = validation_result
317                .parameters
318                .as_ref()
319                .map(|params| params.iter().filter(|p| p.default.is_none()).count())
320                .unwrap_or(0);
321
322            if params_without_default <= 1 {
323                vec![params_str.to_string()]
324            } else {
325                let param_names: Vec<String> = validation_result
326                    .parameters
327                    .as_ref()
328                    .map(|params| {
329                        params
330                            .iter()
331                            .filter(|p| p.default.is_none())
332                            .map(|p| p.key.clone())
333                            .collect()
334                    })
335                    .unwrap_or_default();
336
337                let error_message = format!(
338                    "The /{} recipe requires {} parameters: {}.\n\n\
339                    Slash command recipes only support 1 parameter.\n\n\
340                    **To use this recipe:**\n\
341                    • **CLI:** `aster run --recipe {} {}`\n\
342                    • **Desktop:** Launch from the recipes sidebar to fill in parameters",
343                    command,
344                    params_without_default,
345                    param_names
346                        .iter()
347                        .map(|name| format!("**{}**", name))
348                        .collect::<Vec<_>>()
349                        .join(", "),
350                    command,
351                    param_names
352                        .iter()
353                        .map(|name| format!("--params {}=\"...\"", name))
354                        .collect::<Vec<_>>()
355                        .join(" ")
356                );
357
358                return Err(anyhow!(error_message));
359            }
360        };
361
362        let param_values_len = param_values.len();
363
364        let recipe = match build_recipe_from_template_with_positional_params(
365            recipe_content,
366            recipe_dir,
367            param_values,
368            None::<fn(&str, &str) -> Result<String>>,
369        ) {
370            Ok(recipe) => recipe,
371            Err(crate::recipe::build_recipe::RecipeError::MissingParams { parameters }) => {
372                return Ok(Some(Message::assistant().with_text(format!(
373                    "Recipe requires {} parameter(s): {}. Provided: {}",
374                    parameters.len(),
375                    parameters.join(", "),
376                    param_values_len
377                ))));
378            }
379            Err(e) => return Err(anyhow!("Failed to build recipe: {}", e)),
380        };
381
382        let prompt = [recipe.instructions.as_deref(), recipe.prompt.as_deref()]
383            .into_iter()
384            .flatten()
385            .collect::<Vec<_>>()
386            .join("\n\n");
387
388        Ok(Some(Message::user().with_text(prompt)))
389    }
390}