Skip to main content

aster_cli/session/
mod.rs

1mod builder;
2mod completion;
3mod elicitation;
4mod export;
5mod input;
6mod output;
7mod prompt;
8mod task_execution_display;
9mod thinking;
10
11use crate::session::task_execution_display::{
12    format_task_execution_notification, TASK_EXECUTION_NOTIFICATION_TYPE,
13};
14use aster::conversation::Conversation;
15use std::io::Write;
16use std::str::FromStr;
17use tokio::signal::ctrl_c;
18use tokio_util::task::AbortOnDropHandle;
19
20pub use self::export::message_to_markdown;
21use aster::agents::AgentEvent;
22use aster::permission::permission_confirmation::PrincipalType;
23use aster::permission::Permission;
24use aster::permission::PermissionConfirmation;
25use aster::providers::base::Provider;
26use aster::utils::safe_truncate;
27pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
28use console::Color;
29
30use anyhow::{Context, Result};
31use aster::agents::extension::{Envs, ExtensionConfig, PLATFORM_EXTENSIONS};
32use aster::agents::types::RetryConfig;
33use aster::agents::{Agent, SessionConfig, COMPACT_TRIGGERS};
34use aster::config::{AsterMode, Config};
35use aster::session::SessionManager;
36use completion::AsterCompleter;
37use input::InputResult;
38use rmcp::model::PromptMessage;
39use rmcp::model::ServerNotification;
40use rmcp::model::{ErrorCode, ErrorData};
41
42use aster::config::paths::Paths;
43use aster::conversation::message::{ActionRequiredData, Message, MessageContent};
44use rustyline::EditMode;
45use serde::{Deserialize, Serialize};
46use serde_json::Value;
47use std::collections::HashMap;
48use std::path::PathBuf;
49use std::sync::Arc;
50use std::time::Instant;
51use tokio;
52use tokio_util::sync::CancellationToken;
53use tracing::warn;
54
55#[derive(Serialize, Deserialize, Debug)]
56struct JsonOutput {
57    messages: Vec<Message>,
58    metadata: JsonMetadata,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62struct JsonMetadata {
63    total_tokens: Option<i32>,
64    status: String,
65}
66
67#[derive(Serialize, Debug)]
68#[serde(tag = "type", rename_all = "snake_case")]
69enum StreamEvent {
70    Message {
71        message: Message,
72    },
73    Notification {
74        extension_id: String,
75        #[serde(flatten)]
76        data: NotificationData,
77    },
78    ModelChange {
79        model: String,
80        mode: String,
81    },
82    Error {
83        error: String,
84    },
85    Complete {
86        total_tokens: Option<i32>,
87    },
88}
89
90#[derive(Serialize, Debug)]
91#[serde(rename_all = "snake_case")]
92enum NotificationData {
93    Log {
94        message: String,
95    },
96    Progress {
97        progress: f64,
98        total: Option<f64>,
99        message: Option<String>,
100    },
101}
102
/// How a `CliSession` handles plain chat input.
///
/// In `Normal` mode, messages go straight to the agent. In `Plan` mode, they
/// go to a reasoner model that first produces a plan or clarifying questions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunMode {
    Normal,
    Plan,
}
107
/// File locations of the interactive command history.
///
/// History lives under the state directory. An older copy under the config
/// directory is read as a fallback and removed after a successful save.
struct HistoryManager {
    /// Current location: `<state dir>/history.txt`.
    history_file: PathBuf,
    /// Legacy location: `<config dir>/history.txt`.
    old_history_file: PathBuf,
}
112
113impl HistoryManager {
114    fn new() -> Self {
115        Self {
116            history_file: Paths::state_dir().join("history.txt"),
117            old_history_file: Paths::config_dir().join("history.txt"),
118        }
119    }
120
121    fn load(
122        &self,
123        editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
124    ) {
125        if let Some(parent) = self.history_file.parent() {
126            if !parent.exists() {
127                if let Err(e) = std::fs::create_dir_all(parent) {
128                    eprintln!("Warning: Failed to create history directory: {}", e);
129                }
130            }
131        }
132
133        let history_files = [&self.history_file, &self.old_history_file];
134        if let Some(file) = history_files.iter().find(|f| f.exists()) {
135            if let Err(err) = editor.load_history(file) {
136                eprintln!("Warning: Failed to load command history: {}", err);
137            }
138        }
139    }
140
141    fn save(
142        &self,
143        editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
144    ) {
145        if let Err(err) = editor.save_history(&self.history_file) {
146            eprintln!("Warning: Failed to save command history: {}", err);
147        } else if self.old_history_file.exists() {
148            if let Err(err) = std::fs::remove_file(&self.old_history_file) {
149                eprintln!("Warning: Failed to remove old history file: {}", err);
150            }
151        }
152    }
153}
154
155pub struct CliSession {
156    agent: Agent,
157    messages: Conversation,
158    session_id: String,
159    completion_cache: Arc<std::sync::RwLock<CompletionCache>>,
160    debug: bool,
161    run_mode: RunMode,
162    scheduled_job_id: Option<String>, // ID of the scheduled job that triggered this session
163    max_turns: Option<u32>,
164    edit_mode: Option<EditMode>,
165    retry_config: Option<RetryConfig>,
166    output_format: String,
167}
168
169// Cache structure for completion data
170struct CompletionCache {
171    prompts: HashMap<String, Vec<String>>,
172    prompt_info: HashMap<String, output::PromptInfo>,
173    last_updated: Instant,
174}
175
176impl CompletionCache {
177    fn new() -> Self {
178        Self {
179            prompts: HashMap::new(),
180            prompt_info: HashMap::new(),
181            last_updated: Instant::now(),
182        }
183    }
184}
185
/// Classification of a planner (reasoner model) response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlannerResponseType {
    /// The response is a plan the user can act on.
    Plan,
    /// The response asks the user clarifying questions.
    ClarifyingQuestions,
}
190
191/// Decide if the planner's reponse is a plan or a clarifying question
192///
193/// This function is called after the planner has generated a response
194/// to the user's message. The response is either a plan or a clarifying
195/// question.
196pub async fn classify_planner_response(
197    message_text: String,
198    provider: Arc<dyn Provider>,
199) -> Result<PlannerResponseType> {
200    let prompt = format!("The text below is the output from an AI model which can either provide a plan or list of clarifying questions. Based on the text below, decide if the output is a \"plan\" or \"clarifying questions\".\n---\n{message_text}");
201
202    // Generate the description
203    let message = Message::user().with_text(&prompt);
204    let (result, _usage) = provider
205        .complete(
206            "Reply only with the classification label: \"plan\" or \"clarifying questions\"",
207            &[message],
208            &[],
209        )
210        .await?;
211
212    let predicted = result.as_concat_text();
213    if predicted.to_lowercase().contains("plan") {
214        Ok(PlannerResponseType::Plan)
215    } else {
216        Ok(PlannerResponseType::ClarifyingQuestions)
217    }
218}
219
220impl CliSession {
221    #[allow(clippy::too_many_arguments)]
222    pub async fn new(
223        agent: Agent,
224        session_id: String,
225        debug: bool,
226        scheduled_job_id: Option<String>,
227        max_turns: Option<u32>,
228        edit_mode: Option<EditMode>,
229        retry_config: Option<RetryConfig>,
230        output_format: String,
231    ) -> Self {
232        let messages = SessionManager::get_session(&session_id, true)
233            .await
234            .map(|session| session.conversation.unwrap_or_default())
235            .unwrap();
236
237        CliSession {
238            agent,
239            messages,
240            session_id,
241            completion_cache: Arc::new(std::sync::RwLock::new(CompletionCache::new())),
242            debug,
243            run_mode: RunMode::Normal,
244            scheduled_job_id,
245            max_turns,
246            edit_mode,
247            retry_config,
248            output_format,
249        }
250    }
251
252    pub fn session_id(&self) -> &String {
253        &self.session_id
254    }
255
256    /// Add a stdio extension to the session
257    ///
258    /// # Arguments
259    /// * `extension_command` - Full command string including environment variables
260    ///   Format: "ENV1=val1 ENV2=val2 command args..."
261    pub async fn add_extension(&mut self, extension_command: String) -> Result<()> {
262        let mut parts: Vec<&str> = extension_command.split_whitespace().collect();
263        let mut envs = HashMap::new();
264
265        while let Some(part) = parts.first() {
266            if !part.contains('=') {
267                break;
268            }
269            let env_part = parts.remove(0);
270            let (key, value) = env_part.split_once('=').unwrap();
271            envs.insert(key.to_string(), value.to_string());
272        }
273
274        if parts.is_empty() {
275            return Err(anyhow::anyhow!("No command provided in extension string"));
276        }
277
278        let cmd = parts.remove(0).to_string();
279
280        let config = ExtensionConfig::Stdio {
281            name: String::new(),
282            cmd,
283            args: parts.iter().map(|s| s.to_string()).collect(),
284            envs: Envs::new(envs),
285            env_keys: Vec::new(),
286            description: aster::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(),
287            // TODO: should set timeout
288            timeout: Some(aster::config::DEFAULT_EXTENSION_TIMEOUT),
289            bundled: None,
290            available_tools: Vec::new(),
291        };
292
293        self.agent
294            .add_extension(config)
295            .await
296            .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e))?;
297
298        // Invalidate the completion cache when a new extension is added
299        self.invalidate_completion_cache().await;
300
301        Ok(())
302    }
303
304    /// Add a streamable HTTP extension to the session
305    ///
306    /// # Arguments
307    /// * `extension_url` - URL of the server
308    pub async fn add_streamable_http_extension(&mut self, extension_url: String) -> Result<()> {
309        let config = ExtensionConfig::StreamableHttp {
310            name: String::new(),
311            uri: extension_url,
312            envs: Envs::new(HashMap::new()),
313            env_keys: Vec::new(),
314            headers: HashMap::new(),
315            description: aster::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(),
316            // TODO: should set timeout
317            timeout: Some(aster::config::DEFAULT_EXTENSION_TIMEOUT),
318            bundled: None,
319            available_tools: Vec::new(),
320        };
321
322        self.agent
323            .add_extension(config)
324            .await
325            .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e))?;
326
327        // Invalidate the completion cache when a new extension is added
328        self.invalidate_completion_cache().await;
329
330        Ok(())
331    }
332
333    /// Add a builtin extension to the session
334    ///
335    /// # Arguments
336    /// * `builtin_name` - Name of the builtin extension(s), comma separated
337    pub async fn add_builtin(&mut self, builtin_name: String) -> Result<()> {
338        for name in builtin_name.split(',') {
339            let extension_name = name.trim();
340
341            let config = if PLATFORM_EXTENSIONS.contains_key(extension_name) {
342                ExtensionConfig::Platform {
343                    name: extension_name.to_string(),
344                    bundled: None,
345                    description: name.to_string(),
346                    available_tools: Vec::new(),
347                }
348            } else {
349                ExtensionConfig::Builtin {
350                    name: extension_name.to_string(),
351                    display_name: None,
352                    timeout: None,
353                    bundled: None,
354                    description: name.to_string(),
355                    available_tools: Vec::new(),
356                }
357            };
358            self.agent
359                .add_extension(config)
360                .await
361                .map_err(|e| anyhow::anyhow!("Failed to start builtin extension: {}", e))?;
362        }
363
364        // Invalidate the completion cache when a new extension is added
365        self.invalidate_completion_cache().await;
366
367        Ok(())
368    }
369
370    pub async fn list_prompts(
371        &mut self,
372        extension: Option<String>,
373    ) -> Result<HashMap<String, Vec<String>>> {
374        let prompts = self.agent.list_extension_prompts().await;
375
376        // Early validation if filtering by extension
377        if let Some(filter) = &extension {
378            if !prompts.contains_key(filter) {
379                return Err(anyhow::anyhow!("Extension '{}' not found", filter));
380            }
381        }
382
383        // Convert prompts into filtered map of extension names to prompt names
384        Ok(prompts
385            .into_iter()
386            .filter(|(ext, _)| extension.as_ref().is_none_or(|f| f == ext))
387            .map(|(extension, prompt_list)| {
388                let names = prompt_list.into_iter().map(|p| p.name).collect();
389                (extension, names)
390            })
391            .collect())
392    }
393
394    pub async fn get_prompt_info(&mut self, name: &str) -> Result<Option<output::PromptInfo>> {
395        let prompts = self.agent.list_extension_prompts().await;
396
397        // Find which extension has this prompt
398        for (extension, prompt_list) in prompts {
399            if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) {
400                return Ok(Some(output::PromptInfo {
401                    name: prompt.name.clone(),
402                    description: prompt.description.clone(),
403                    arguments: prompt.arguments.clone(),
404                    extension: Some(extension),
405                }));
406            }
407        }
408
409        Ok(None)
410    }
411
412    pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result<Vec<PromptMessage>> {
413        Ok(self.agent.get_prompt(name, arguments).await?.messages)
414    }
415
416    /// Process a single message and get the response
417    pub(crate) async fn process_message(
418        &mut self,
419        message: Message,
420        cancel_token: CancellationToken,
421    ) -> Result<()> {
422        let cancel_token = cancel_token.clone();
423        self.push_message(message);
424        self.process_agent_response(false, cancel_token).await?;
425        Ok(())
426    }
427
428    /// Start an interactive session, optionally with an initial message
429    pub async fn interactive(&mut self, prompt: Option<String>) -> Result<()> {
430        if let Some(prompt) = prompt {
431            let msg = Message::user().with_text(&prompt);
432            self.process_message(msg, CancellationToken::default())
433                .await?;
434        }
435
436        self.update_completion_cache().await?;
437
438        let mut editor = self.create_editor()?;
439        let history_manager = HistoryManager::new();
440        history_manager.load(&mut editor);
441
442        output::display_greeting();
443        loop {
444            self.display_context_usage().await?;
445
446            let input = input::get_input(&mut editor)?;
447            if matches!(input, InputResult::Exit) {
448                break;
449            }
450            self.handle_input(input, &history_manager, &mut editor)
451                .await?;
452        }
453
454        println!(
455            "Closing session. Session ID: {}",
456            console::style(&self.session_id).cyan()
457        );
458
459        Ok(())
460    }
461
462    fn create_editor(
463        &self,
464    ) -> Result<rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>> {
465        let builder =
466            rustyline::Config::builder().completion_type(rustyline::CompletionType::Circular);
467        let builder = match self.edit_mode {
468            Some(mode) => builder.edit_mode(mode),
469            None => builder.edit_mode(EditMode::Emacs),
470        };
471        let config = builder.build();
472        let mut editor =
473            rustyline::Editor::<AsterCompleter, rustyline::history::DefaultHistory>::with_config(
474                config,
475            )?;
476        let completer = AsterCompleter::new(self.completion_cache.clone());
477        editor.set_helper(Some(completer));
478        Ok(editor)
479    }
480
481    async fn handle_input(
482        &mut self,
483        input: InputResult,
484        history: &HistoryManager,
485        editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
486    ) -> Result<()> {
487        match input {
488            InputResult::Message(content) => {
489                self.handle_message_input(&content, history, editor).await?;
490            }
491            InputResult::Exit => unreachable!("Exit is handled in the main loop"),
492            InputResult::AddExtension(cmd) => {
493                history.save(editor);
494                match self.add_extension(cmd.clone()).await {
495                    Ok(_) => output::render_extension_success(&cmd),
496                    Err(e) => output::render_extension_error(&cmd, &e.to_string()),
497                }
498            }
499            InputResult::AddBuiltin(names) => {
500                history.save(editor);
501                match self.add_builtin(names.clone()).await {
502                    Ok(_) => output::render_builtin_success(&names),
503                    Err(e) => output::render_builtin_error(&names, &e.to_string()),
504                }
505            }
506            InputResult::ToggleTheme => {
507                history.save(editor);
508                self.handle_toggle_theme();
509            }
510            InputResult::SelectTheme(theme_name) => {
511                history.save(editor);
512                self.handle_select_theme(&theme_name);
513            }
514            InputResult::Retry => {}
515            InputResult::ListPrompts(extension) => {
516                history.save(editor);
517                match self.list_prompts(extension).await {
518                    Ok(prompts) => output::render_prompts(&prompts),
519                    Err(e) => output::render_error(&e.to_string()),
520                }
521            }
522            InputResult::AsterMode(mode) => {
523                history.save(editor);
524                self.handle_aster_mode(&mode)?;
525            }
526            InputResult::Plan(options) => {
527                self.handle_plan_mode(options).await?;
528            }
529            InputResult::EndPlan => {
530                self.run_mode = RunMode::Normal;
531                output::render_exit_plan_mode();
532            }
533            InputResult::Clear => {
534                history.save(editor);
535                self.handle_clear().await?;
536            }
537            InputResult::PromptCommand(opts) => {
538                history.save(editor);
539                self.handle_prompt_command(opts).await?;
540            }
541            InputResult::Recipe(filepath_opt) => {
542                history.save(editor);
543                self.handle_recipe(filepath_opt).await;
544            }
545            InputResult::Compact => {
546                history.save(editor);
547                self.handle_compact().await?;
548            }
549        }
550        Ok(())
551    }
552
553    async fn handle_message_input(
554        &mut self,
555        content: &str,
556        history: &HistoryManager,
557        editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
558    ) -> Result<()> {
559        match self.run_mode {
560            RunMode::Normal => {
561                history.save(editor);
562                self.push_message(Message::user().with_text(content));
563
564                if let Err(e) = crate::project_tracker::update_project_tracker(
565                    Some(content),
566                    Some(&self.session_id),
567                ) {
568                    eprintln!(
569                        "Warning: Failed to update project tracker with instruction: {}",
570                        e
571                    );
572                }
573
574                let _provider = self.agent.provider().await?;
575
576                output::show_thinking();
577                let start_time = Instant::now();
578                self.process_agent_response(true, CancellationToken::default())
579                    .await?;
580                output::hide_thinking();
581
582                let elapsed = start_time.elapsed();
583                let elapsed_str = format_elapsed_time(elapsed);
584                println!(
585                    "\n{}",
586                    console::style(format!("⏱️  Elapsed time: {}", elapsed_str)).dim()
587                );
588            }
589            RunMode::Plan => {
590                let mut plan_messages = self.messages.clone();
591                plan_messages.push(Message::user().with_text(content));
592                let reasoner = get_reasoner().await?;
593                self.plan_with_reasoner_model(plan_messages, reasoner)
594                    .await?;
595            }
596        }
597        Ok(())
598    }
599
600    fn handle_toggle_theme(&self) {
601        let current = output::get_theme();
602        let new_theme = match current {
603            output::Theme::Ansi => {
604                println!("Switching to Light theme");
605                output::Theme::Light
606            }
607            output::Theme::Light => {
608                println!("Switching to Dark theme");
609                output::Theme::Dark
610            }
611            output::Theme::Dark => {
612                println!("Switching to Ansi theme");
613                output::Theme::Ansi
614            }
615        };
616        output::set_theme(new_theme);
617    }
618
619    fn handle_select_theme(&self, theme_name: &str) {
620        let new_theme = match theme_name {
621            "light" => {
622                println!("Switching to Light theme");
623                output::Theme::Light
624            }
625            "dark" => {
626                println!("Switching to Dark theme");
627                output::Theme::Dark
628            }
629            "ansi" => {
630                println!("Switching to Ansi theme");
631                output::Theme::Ansi
632            }
633            _ => output::Theme::Dark,
634        };
635        output::set_theme(new_theme);
636    }
637
638    fn handle_aster_mode(&self, mode: &str) -> Result<()> {
639        let config = Config::global();
640        let mode = match AsterMode::from_str(&mode.to_lowercase()) {
641            Ok(mode) => mode,
642            Err(_) => {
643                output::render_error(&format!(
644                    "Invalid mode '{}'. Mode must be one of: auto, approve, chat, smart_approve",
645                    mode
646                ));
647                return Ok(());
648            }
649        };
650        config.set_aster_mode(mode)?;
651        output::aster_mode_message(&format!("Aster mode set to '{:?}'", mode));
652        Ok(())
653    }
654
655    async fn handle_plan_mode(&mut self, options: input::PlanCommandOptions) -> Result<()> {
656        self.run_mode = RunMode::Plan;
657        output::render_enter_plan_mode();
658
659        if options.message_text.is_empty() {
660            return Ok(());
661        }
662
663        let mut plan_messages = self.messages.clone();
664        plan_messages.push(Message::user().with_text(&options.message_text));
665
666        let reasoner = get_reasoner().await?;
667        self.plan_with_reasoner_model(plan_messages, reasoner).await
668    }
669
670    async fn handle_clear(&mut self) -> Result<()> {
671        if let Err(e) =
672            SessionManager::replace_conversation(&self.session_id, &Conversation::default()).await
673        {
674            output::render_error(&format!("Failed to clear session: {}", e));
675            return Ok(());
676        }
677
678        if let Err(e) = SessionManager::update_session(&self.session_id)
679            .total_tokens(Some(0))
680            .input_tokens(Some(0))
681            .output_tokens(Some(0))
682            .apply()
683            .await
684        {
685            output::render_error(&format!("Failed to reset token counts: {}", e));
686            return Ok(());
687        }
688
689        self.messages.clear();
690        tracing::info!("Chat context cleared by user.");
691        output::render_message(
692            &Message::assistant().with_text("Chat context cleared.\n"),
693            self.debug,
694        );
695        Ok(())
696    }
697
698    async fn handle_recipe(&mut self, filepath_opt: Option<String>) {
699        println!("{}", console::style("Generating Recipe").green());
700
701        output::show_thinking();
702        let recipe = self.agent.create_recipe(self.messages.clone()).await;
703        output::hide_thinking();
704
705        match recipe {
706            Ok(recipe) => {
707                let filepath_str = filepath_opt.as_deref().unwrap_or("recipe.yaml");
708                match self.save_recipe(&recipe, filepath_str) {
709                    Ok(path) => println!(
710                        "{}",
711                        console::style(format!("Saved recipe to {}", path.display())).green()
712                    ),
713                    Err(e) => println!("{}", console::style(e).red()),
714                }
715            }
716            Err(e) => {
717                println!(
718                    "{}: {:?}",
719                    console::style("Failed to generate recipe").red(),
720                    e
721                );
722            }
723        }
724    }
725
726    async fn handle_compact(&mut self) -> Result<()> {
727        let prompt = "Are you sure you want to compact this conversation? This will condense the message history.";
728        let should_summarize = match cliclack::confirm(prompt).initial_value(true).interact() {
729            Ok(choice) => choice,
730            Err(e) => {
731                if e.kind() == std::io::ErrorKind::Interrupted {
732                    false
733                } else {
734                    return Err(e.into());
735                }
736            }
737        };
738
739        if should_summarize {
740            self.push_message(Message::user().with_text(COMPACT_TRIGGERS[0]));
741            output::show_thinking();
742            self.process_agent_response(true, CancellationToken::default())
743                .await?;
744            output::hide_thinking();
745        } else {
746            println!("{}", console::style("Compaction cancelled.").yellow());
747        }
748        Ok(())
749    }
750
751    async fn plan_with_reasoner_model(
752        &mut self,
753        plan_messages: Conversation,
754        reasoner: Arc<dyn Provider>,
755    ) -> Result<(), anyhow::Error> {
756        let plan_prompt = self.agent.get_plan_prompt().await?;
757        output::show_thinking();
758        let (plan_response, _usage) = reasoner
759            .complete(&plan_prompt, plan_messages.messages(), &[])
760            .await?;
761        output::render_message(&plan_response, self.debug);
762        output::hide_thinking();
763        let planner_response_type =
764            classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?)
765                .await?;
766
767        match planner_response_type {
768            PlannerResponseType::Plan => {
769                println!();
770                let should_act = match cliclack::confirm(
771                    "Do you want to clear message history & act on this plan?",
772                )
773                .initial_value(true)
774                .interact()
775                {
776                    Ok(choice) => choice,
777                    Err(e) => {
778                        if e.kind() == std::io::ErrorKind::Interrupted {
779                            false // If interrupted, set should_act to false
780                        } else {
781                            return Err(e.into());
782                        }
783                    }
784                };
785                if should_act {
786                    output::render_act_on_plan();
787                    self.run_mode = RunMode::Normal;
788                    // set aster mode: auto if that isn't already the case
789                    let config = Config::global();
790                    let curr_aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
791                    if curr_aster_mode != AsterMode::Auto {
792                        config.set_aster_mode(AsterMode::Auto).unwrap();
793                    }
794
795                    // clear the messages before acting on the plan
796                    self.messages.clear();
797                    // add the plan response as a user message
798                    let plan_message = Message::user().with_text(plan_response.as_concat_text());
799                    self.push_message(plan_message);
800                    // act on the plan
801                    output::show_thinking();
802                    self.process_agent_response(true, CancellationToken::default())
803                        .await?;
804                    output::hide_thinking();
805
806                    // Reset run & aster mode
807                    if curr_aster_mode != AsterMode::Auto {
808                        config.set_aster_mode(curr_aster_mode)?;
809                    }
810                } else {
811                    // add the plan response (assistant message) & carry the conversation forward
812                    // in the next round, the user might wanna slightly modify the plan
813                    self.push_message(plan_response);
814                }
815            }
816            PlannerResponseType::ClarifyingQuestions => {
817                // add the plan response (assistant message) & carry the conversation forward
818                // in the next round, the user will answer the clarifying questions
819                self.push_message(plan_response);
820            }
821        }
822
823        Ok(())
824    }
825
826    /// Process a single message and exit
827    pub async fn headless(&mut self, prompt: String) -> Result<()> {
828        let message = Message::user().with_text(&prompt);
829        self.process_message(message, CancellationToken::default())
830            .await?;
831        Ok(())
832    }
833
    /// Drive one agent reply for the last message in `self.messages` and present it.
    ///
    /// Streams `AgentEvent`s from `self.agent.reply(...)` until the stream ends,
    /// the stream yields an error, or `cancel_token` is cancelled (a spawned task
    /// cancels it on Ctrl-C). While streaming it:
    /// - asks the user to approve tool calls that require confirmation,
    /// - collects user input for elicitation requests and sends it back to the agent,
    /// - appends received messages to `self.messages` and renders them,
    /// - displays MCP logging and progress notifications.
    ///
    /// Output depends on `self.output_format`:
    /// - `"json"`: per-message rendering is suppressed and one pretty-printed
    ///   `JsonOutput` document is printed at the end;
    /// - `"stream-json"`: one JSON line per event, then a final `Complete` event;
    /// - anything else: human-readable rendering.
    ///
    /// `interactive` changes how thinking indicators, progress bars and subagent
    /// notifications are displayed.
    ///
    /// # Errors
    ///
    /// Returns an error if `self.messages` is empty, if starting the agent reply
    /// fails, if the tool-confirmation prompt fails with an I/O error other than
    /// `Interrupted`, if the follow-up `reply` for an elicitation response fails,
    /// or if serializing the final JSON output fails. Errors yielded by the stream
    /// itself are reported and end the loop instead of being propagated.
    async fn process_agent_response(
        &mut self,
        interactive: bool,
        cancel_token: CancellationToken,
    ) -> Result<()> {
        let is_json_mode = self.output_format == "json";
        let is_stream_json_mode = self.output_format == "stream-json";

        // Helper to emit a streaming JSON event
        // (one JSON line on stdout; events that fail to serialize are skipped silently)
        let emit_stream_event = |event: &StreamEvent| {
            if let Ok(json) = serde_json::to_string(event) {
                println!("{}", json);
            }
        };

        let session_config = SessionConfig {
            id: self.session_id.clone(),
            schedule_id: self.scheduled_job_id.clone(),
            max_turns: self.max_turns,
            retry_config: self.retry_config.clone(),
            system_prompt: None,
        };
        // The agent replies to whatever is last in the history (presumably the
        // user's message, pushed by the caller — verify against callers).
        let user_message = self
            .messages
            .last()
            .ok_or_else(|| anyhow::anyhow!("No user message"))?;

        // Cancel the reply on Ctrl-C. The listener task is aborted when
        // `_drop_handle` is dropped at the end of this function.
        let cancel_token_interrupt = cancel_token.clone();
        let handle = tokio::spawn(async move {
            if ctrl_c().await.is_ok() {
                cancel_token_interrupt.cancel();
            }
        });
        let _drop_handle = AbortOnDropHandle::new(handle);

        let mut stream = self
            .agent
            .reply(
                user_message.clone(),
                session_config.clone(),
                Some(cancel_token.clone()),
            )
            .await?;

        let mut progress_bars = output::McpSpinners::new();
        let cancel_token_clone = cancel_token.clone();

        use futures::StreamExt;
        // Race each stream event against cancellation; every exit path `break`s.
        loop {
            tokio::select! {
                result = stream.next() => {
                    match result {
                        Some(Ok(AgentEvent::Message(message))) => {
                            // First tool-confirmation request in this message, if any.
                            let tool_call_confirmation = message.content.iter().find_map(|content| {
                                if let MessageContent::ActionRequired(action) = content {
                                    #[allow(irrefutable_let_patterns)] // NOTE(review): ActionRequiredData also has an Elicitation variant (matched below), so this allow may no longer be needed
                                    if let ActionRequiredData::ToolConfirmation { id, tool_name, arguments, prompt } = &action.data {
                                        Some((id.clone(), tool_name.clone(), arguments.clone(), prompt.clone()))
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            });

                            // First elicitation (structured user input) request in this message, if any.
                            let elicitation_request = message.content.iter().find_map(|content| {
                                if let MessageContent::ActionRequired(action) = content {
                                    if let ActionRequiredData::Elicitation { id, message, requested_schema } = &action.data {
                                        Some((id.clone(), message.clone(), requested_schema.clone()))
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            });

                            if let Some((id, _tool_name, _arguments, security_prompt)) = tool_call_confirmation {
                                output::hide_thinking();

                                // Format the confirmation prompt - use security message if present, otherwise use generic message
                                let prompt = if let Some(security_message) = &security_prompt {
                                    println!("\n{}", security_message);
                                    "Do you allow this tool call?".to_string()
                                } else {
                                    "Aster would like to call the above tool, do you allow?".to_string()
                                };

                                // Get confirmation from user
                                let permission_result = if security_prompt.is_none() {
                                    // No security message - show all options including "Always Allow"
                                    cliclack::select(prompt)
                                        .item(Permission::AllowOnce, "Allow", "Allow the tool call once")
                                        .item(Permission::AlwaysAllow, "Always Allow", "Always allow the tool call")
                                        .item(Permission::DenyOnce, "Deny", "Deny the tool call")
                                        .item(Permission::Cancel, "Cancel", "Cancel the AI response and tool call")
                                        .interact()
                                } else {
                                    // Security message present - don't show "Always Allow"
                                    cliclack::select(prompt)
                                        .item(Permission::AllowOnce, "Allow", "Allow the tool call once")
                                        .item(Permission::DenyOnce, "Deny", "Deny the tool call")
                                        .item(Permission::Cancel, "Cancel", "Cancel the AI response and tool call")
                                        .interact()
                                };

                                // An `Interrupted` error while prompting counts as Cancel;
                                // any other I/O error is propagated.
                                let permission = match permission_result {
                                    Ok(p) => p,
                                    Err(e) => {
                                        if e.kind() == std::io::ErrorKind::Interrupted {
                                            Permission::Cancel
                                        } else {
                                            return Err(e.into());
                                        }
                                    }
                                };

                                if permission == Permission::Cancel {
                                    output::render_text("Tool call cancelled. Returning to chat...", Some(Color::Yellow), true);

                                    // Answer the pending tool request with an error so it is not
                                    // left unanswered, then stop. Breaking here exits the loop
                                    // without running `handle_interrupted_messages`.
                                    let mut response_message = Message::user();
                                    response_message.content.push(MessageContent::tool_response(
                                        id.clone(),
                                        Err(ErrorData { code: ErrorCode::INVALID_REQUEST, message: std::borrow::Cow::from("Tool call cancelled by user".to_string()), data: None })
                                    ));
                                    self.messages.push(response_message);
                                    cancel_token_clone.cancel();
                                    drop(stream);
                                    break;
                                } else {
                                    // Forward the user's decision to the agent.
                                    self.agent.handle_confirmation(id.clone(), PermissionConfirmation {
                                        principal_type: PrincipalType::Tool,
                                        permission,
                                    }).await;
                                }
                            }
                            else if let Some((elicitation_id, elicitation_message, schema)) = elicitation_request {
                                output::hide_thinking();
                                let _ = progress_bars.hide();

                                match elicitation::collect_elicitation_input(&elicitation_message, &schema) {
                                    Ok(Some(user_data)) => {
                                        // Serialization failure falls back to an empty JSON object.
                                        let user_data_value = serde_json::to_value(user_data)
                                            .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));

                                        let response_message = Message::user()
                                            .with_content(MessageContent::action_required_elicitation_response(
                                                elicitation_id.clone(),
                                                user_data_value,
                                            ))
                                            .with_visibility(false, true);

                                        self.messages.push(response_message.clone());
                                        // Elicitation responses return an empty stream - the response
                                        // unblocks the waiting tool call via ActionRequiredManager
                                        let _ = self
                                            .agent
                                            .reply(
                                                response_message,
                                                session_config.clone(),
                                                Some(cancel_token.clone()),
                                            )
                                            .await?;
                                    }
                                    Ok(None) => {
                                        // The user declined to provide the requested information.
                                        output::render_text("Information request cancelled.", Some(Color::Yellow), true);
                                        cancel_token_clone.cancel();
                                        drop(stream);
                                        break;
                                    }
                                    Err(e) => {
                                        output::render_error(&format!("Failed to collect input: {}", e));
                                        cancel_token_clone.cancel();
                                        drop(stream);
                                        break;
                                    }
                                }
                            }
                            else {
                                // Regular message: emit tool-call telemetry, store it, then render it.
                                for content in &message.content {
                                    if let MessageContent::ToolRequest(tool_request) = content {
                                        if let Ok(tool_call) = &tool_request.tool_call {
                                            tracing::info!(counter.aster.tool_calls = 1,
                                                tool_name = %tool_call.name,
                                                "Tool call started"
                                            );
                                        }
                                    }
                                    if let MessageContent::ToolResponse(tool_response) = content {
                                        // Resolve the tool name by searching the history, newest
                                        // first, for the request with the same id.
                                        let tool_name = self.messages
                                            .iter()
                                            .rev()
                                            .find_map(|msg| {
                                                msg.content.iter().find_map(|c| {
                                                    if let MessageContent::ToolRequest(req) = c {
                                                        if req.id == tool_response.id {
                                                            if let Ok(tool_call) = &req.tool_call {
                                                                Some(tool_call.name.clone())
                                                            } else {
                                                                None
                                                            }
                                                        } else {
                                                            None
                                                        }
                                                    } else {
                                                        None
                                                    }
                                                })
                                            })
                                            .unwrap_or_else(|| "unknown".to_string().into());

                                        let success = tool_response.tool_result.is_ok();
                                        let result_status = if success { "success" } else { "error" };
                                        tracing::info!(
                                            counter.aster.tool_completions = 1,
                                            tool_name = %tool_name,
                                            result = %result_status,
                                            "Tool call completed"
                                        );
                                    }
                                }

                                self.messages.push(message.clone());

                                if interactive {output::hide_thinking()};
                                let _ = progress_bars.hide();

                                // Handle different output formats
                                if is_stream_json_mode {
                                    emit_stream_event(&StreamEvent::Message { message: message.clone() });
                                } else if !is_json_mode {
                                    output::render_message(&message, self.debug);
                                }
                            }
                        }
                        Some(Ok(AgentEvent::McpNotification((extension_id, message)))) => {
                            match &message {
                                ServerNotification::LoggingMessageNotification(notification) => {
                                    // Normalize the log payload into
                                    // (display text, subagent id, notification type).
                                    let data = &notification.params.data;
                                    let (formatted_message, subagent_id, message_notification_type) = match data {
                                        Value::String(s) => (s.clone(), None, None),
                                        Value::Object(o) => {
                                            // Check for subagent notification structure first
                                            if let Some(Value::String(msg)) = o.get("message") {
                                                // Extract subagent info for better display
                                                let subagent_id = o.get("subagent_id")
                                                    .and_then(|v| v.as_str());
                                                let notification_type = o.get("type")
                                                    .and_then(|v| v.as_str());

                                                let formatted = match notification_type {
                                                    Some("subagent_created") | Some("completed") | Some("terminated") => {
                                                        format!("🤖 {}", msg)
                                                    }
                                                    Some("tool_usage") | Some("tool_completed") | Some("tool_error") => {
                                                        format!("🔧 {}", msg)
                                                    }
                                                    Some("message_processing") | Some("turn_progress") => {
                                                        format!("💭 {}", msg)
                                                    }
                                                    Some("response_generated") => {
                                                        // Check verbosity setting for subagent response content
                                                        // (ASTER_CLI_MIN_PRIORITY defaults to 0.5 when unset/unreadable)
                                                        let config = Config::global();
                                                        let min_priority = config
                                                            .get_param::<f32>("ASTER_CLI_MIN_PRIORITY")
                                                            .ok()
                                                            .unwrap_or(0.5);

                                                        if min_priority > 0.1 && !self.debug {
                                                            // High/Medium verbosity: show truncated response
                                                            if let Some(response_content) = msg.strip_prefix("Responded: ") {
                                                                format!("🤖 Responded: {}", safe_truncate(response_content, 100))
                                                            } else {
                                                                format!("🤖 {}", msg)
                                                            }
                                                        } else {
                                                            // All verbosity or debug: show full response
                                                            format!("🤖 {}", msg)
                                                        }
                                                    }
                                                    _ => {
                                                        msg.to_string()
                                                    }
                                                };
                                                (formatted, subagent_id.map(str::to_string), notification_type.map(str::to_string))
                                            } else if let Some(Value::String(output)) = o.get("output") {
                                                // Extract type if present (e.g., "shell_output")
                                                let notification_type = o.get("type")
                                                    .and_then(|v| v.as_str())
                                                    .map(str::to_string);

                                                (output.to_owned(), None, notification_type)
                                            } else if let Some(result) = format_task_execution_notification(data) {
                                                result
                                            } else {
                                                (data.to_string(), None, None)
                                            }
                                        },
                                        v => {
                                            (v.to_string(), None, None)
                                        },
                                    };

                                    // Route the formatted text: stream-json event, subagent line,
                                    // typed output (task execution / shell), thinking indicator,
                                    // or progress-bar log.
                                    if is_stream_json_mode {
                                        emit_stream_event(&StreamEvent::Notification {
                                            extension_id: extension_id.clone(),
                                            data: NotificationData::Log { message: formatted_message.clone() },
                                        });
                                    }
                                    // Handle subagent notifications - show immediately
                                    else if let Some(_id) = subagent_id {
                                        if interactive {
                                            let _ = progress_bars.hide();
                                            if !is_json_mode {
                                                println!("{}", console::style(&formatted_message).green().dim());
                                            }
                                        } else if !is_json_mode {
                                            progress_bars.log(&formatted_message);
                                        }
                                    } else if let Some(ref notification_type) = message_notification_type {
                                        // NOTE(review): a typed notification that is neither task
                                        // execution nor "shell_output" (e.g. "tool_usage" without a
                                        // subagent_id) is not displayed at all — confirm intended.
                                        if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
                                            // Printed without a trailing newline, with an explicit flush.
                                            // NOTE(review): `flush().unwrap()` panics if stdout is closed.
                                            if interactive {
                                                let _ = progress_bars.hide();
                                                if !is_json_mode {
                                                    print!("{}", formatted_message);
                                                    std::io::stdout().flush().unwrap();
                                                }
                                            } else if !is_json_mode {
                                                print!("{}", formatted_message);
                                                std::io::stdout().flush().unwrap();
                                            }
                                        } else if notification_type == "shell_output" {
                                            if interactive {
                                                let _ = progress_bars.hide();
                                            }
                                            if !is_json_mode {
                                                println!("{}", formatted_message);
                                            }
                                        }
                                    }
                                    // NOTE(review): these two branches do not check `is_json_mode`;
                                    // confirm they cannot write non-JSON text to stdout in json mode.
                                    else if output::is_showing_thinking() {
                                        output::set_thinking_message(&formatted_message);
                                    } else {
                                        progress_bars.log(&formatted_message);
                                    }
                                },
                                ServerNotification::ProgressNotification(notification) => {
                                    // MCP progress: a stream-json event, or an update to the
                                    // spinner keyed by the progress token.
                                    let progress = notification.params.progress;
                                    let text = notification.params.message.as_deref();
                                    let total = notification.params.total;
                                    let token = &notification.params.progress_token;

                                    if is_stream_json_mode {
                                        emit_stream_event(&StreamEvent::Notification {
                                            extension_id: extension_id.clone(),
                                            data: NotificationData::Progress {
                                                progress,
                                                total,
                                                message: text.map(String::from),
                                            },
                                        });
                                    } else {
                                        progress_bars.update(
                                            &token.0.to_string(),
                                            progress,
                                            total,
                                            text,
                                        );
                                    }
                                },
                                _ => (),
                            }
                        }
                        Some(Ok(AgentEvent::HistoryReplaced(updated_conversation))) => {
                            // The agent replaced the conversation; adopt it wholesale.
                            self.messages = updated_conversation;
                        }
                        Some(Ok(AgentEvent::ModelChange { model, mode })) => {
                            if is_stream_json_mode {
                                emit_stream_event(&StreamEvent::ModelChange {
                                    model: model.clone(),
                                    mode: mode.clone(),
                                });
                            } else if self.debug {
                                eprintln!("Model changed to {} in {} mode", model, mode);
                            }
                        }

                        Some(Err(e)) => {
                            // Stream error: report it, stop the reply, and repair the
                            // history via `handle_interrupted_messages(false)`.
                            let error_msg = e.to_string();

                            if is_stream_json_mode {
                                emit_stream_event(&StreamEvent::Error { error: error_msg.clone() });
                            }

                            // A context-length error reaching this point means the agent
                            // did not compact the conversation itself.
                            if e.downcast_ref::<aster::providers::errors::ProviderError>()
                                .map(|provider_error| matches!(provider_error, aster::providers::errors::ProviderError::ContextLengthExceeded(_)))
                                .unwrap_or(false) {

                                if !is_stream_json_mode {
                                    output::render_text(
                                        "Compaction requested. Should have happened in the agent!",
                                        Some(Color::Yellow),
                                        true
                                    );
                                }
                                warn!("Compaction requested. Should have happened in the agent!");
                            }
                            if !is_stream_json_mode {
                                eprintln!("Error: {}", error_msg);
                            }
                            cancel_token_clone.cancel();
                            drop(stream);
                            if let Err(e) = self.handle_interrupted_messages(false).await {
                                eprintln!("Error handling interruption: {}", e);
                            } else if !is_stream_json_mode {
                                output::render_error(
                                    "The error above was an exception we were not able to handle.\n\
                                    These errors are often related to connection or authentication\n\
                                    We've removed the conversation up to the most recent user message\n\
                                    - depending on the error you may be able to continue",
                                );
                            }
                            break;
                        }
                        None => break,
                    }
                }
                // Cancelled from outside this loop (e.g. the Ctrl-C listener above or the
                // caller's token): repair the history for a user interruption.
                _ = cancel_token_clone.cancelled() => {
                    drop(stream);
                    if let Err(e) = self.handle_interrupted_messages(true).await {
                        eprintln!("Error handling interruption: {}", e);
                    }
                    break;
                }
            }
        }

        // Output based on format
        // NOTE(review): status is "completed" even when the loop ended on an error or
        // cancellation; a failed session lookup only drops `total_tokens`.
        if is_json_mode {
            let metadata = match SessionManager::get_session(&self.session_id, false).await {
                Ok(session) => JsonMetadata {
                    total_tokens: session.total_tokens,
                    status: "completed".to_string(),
                },
                Err(_) => JsonMetadata {
                    total_tokens: None,
                    status: "completed".to_string(),
                },
            };

            let json_output = JsonOutput {
                messages: self.messages.messages().to_vec(),
                metadata,
            };

            println!("{}", serde_json::to_string_pretty(&json_output)?);
        } else if is_stream_json_mode {
            let total_tokens = SessionManager::get_session(&self.session_id, false)
                .await
                .ok()
                .and_then(|s| s.total_tokens);
            emit_stream_event(&StreamEvent::Complete { total_tokens });
        } else {
            println!();
        }

        Ok(())
    }
1303
1304    async fn handle_interrupted_messages(&mut self, interrupt: bool) -> Result<()> {
1305        // First, get any tool requests from the last message if it exists
1306        let tool_requests = self
1307            .messages
1308            .last()
1309            .filter(|msg| msg.role == rmcp::model::Role::Assistant)
1310            .map_or(Vec::new(), |msg| {
1311                msg.content
1312                    .iter()
1313                    .filter_map(|content| {
1314                        if let MessageContent::ToolRequest(req) = content {
1315                            Some((req.id.clone(), req.tool_call.clone()))
1316                        } else {
1317                            None
1318                        }
1319                    })
1320                    .collect()
1321            });
1322
1323        if !tool_requests.is_empty() {
1324            // Interrupted during a tool request
1325            // Create tool responses for all interrupted tool requests
1326            let mut response_message = Message::user();
1327            let last_tool_name = tool_requests
1328                .last()
1329                .and_then(|(_, tool_call)| {
1330                    tool_call
1331                        .as_ref()
1332                        .ok()
1333                        .map(|tool| tool.name.to_string().clone())
1334                })
1335                .unwrap_or_else(|| "tool".to_string());
1336
1337            let notification = if interrupt {
1338                "Interrupted by the user to make a correction".to_string()
1339            } else {
1340                "An uncaught error happened during tool use".to_string()
1341            };
1342            for (req_id, _) in &tool_requests {
1343                response_message.content.push(MessageContent::tool_response(
1344                    req_id.clone(),
1345                    Err(ErrorData {
1346                        code: ErrorCode::INTERNAL_ERROR,
1347                        message: std::borrow::Cow::from(notification.clone()),
1348                        data: None,
1349                    }),
1350                ));
1351            }
1352            // TODO(Douwe): update also db
1353            self.push_message(response_message);
1354            let prompt = format!(
1355                "The existing call to {} was interrupted. How would you like to proceed?",
1356                last_tool_name
1357            );
1358            self.push_message(Message::assistant().with_text(&prompt));
1359            output::render_message(&Message::assistant().with_text(&prompt), self.debug);
1360        } else {
1361            // An interruption occurred outside of a tool request-response.
1362            if let Some(last_msg) = self.messages.last() {
1363                if last_msg.role == rmcp::model::Role::User {
1364                    match last_msg.content.first() {
1365                        Some(MessageContent::ToolResponse(_)) => {
1366                            // Interruption occurred after a tool had completed but not assistant reply
1367                            let prompt = "The tool calling loop was interrupted. How would you like to proceed?";
1368                            self.push_message(Message::assistant().with_text(prompt));
1369                            output::render_message(
1370                                &Message::assistant().with_text(prompt),
1371                                self.debug,
1372                            );
1373                        }
1374                        Some(_) => {
1375                            // A real users message
1376                            self.messages.pop();
1377                            let prompt = "Interrupted before the model replied and removed the last message.";
1378                            output::render_message(
1379                                &Message::assistant().with_text(prompt),
1380                                self.debug,
1381                            );
1382                        }
1383                        None => panic!("No content in last message"),
1384                    }
1385                }
1386            }
1387        }
1388        Ok(())
1389    }
1390
1391    /// Update the completion cache with fresh data
1392    /// This should be called before the interactive session starts
1393    pub async fn update_completion_cache(&mut self) -> Result<()> {
1394        // Get fresh data
1395        let prompts = self.agent.list_extension_prompts().await;
1396
1397        // Update the cache with write lock
1398        let mut cache = self.completion_cache.write().unwrap();
1399        cache.prompts.clear();
1400        cache.prompt_info.clear();
1401
1402        for (extension, prompt_list) in prompts {
1403            let names: Vec<String> = prompt_list.iter().map(|p| p.name.clone()).collect();
1404            cache.prompts.insert(extension.clone(), names);
1405
1406            for prompt in prompt_list {
1407                cache.prompt_info.insert(
1408                    prompt.name.clone(),
1409                    output::PromptInfo {
1410                        name: prompt.name.clone(),
1411                        description: prompt.description.clone(),
1412                        arguments: prompt.arguments.clone(),
1413                        extension: Some(extension.clone()),
1414                    },
1415                );
1416            }
1417        }
1418
1419        cache.last_updated = Instant::now();
1420        Ok(())
1421    }
1422
1423    /// Invalidate the completion cache
1424    /// This should be called when extensions are added or removed
1425    async fn invalidate_completion_cache(&self) {
1426        let mut cache = self.completion_cache.write().unwrap();
1427        cache.prompts.clear();
1428        cache.prompt_info.clear();
1429        cache.last_updated = Instant::now();
1430    }
1431
1432    pub fn message_history(&self) -> Conversation {
1433        self.messages.clone()
1434    }
1435
1436    /// Render all past messages from the session history
1437    pub fn render_message_history(&self) {
1438        if self.messages.is_empty() {
1439            return;
1440        }
1441
1442        // Print session restored message
1443        println!(
1444            "\n{} {} messages loaded into context.",
1445            console::style("Session restored:").green().bold(),
1446            console::style(self.messages.len()).green()
1447        );
1448
1449        // Render each message
1450        for message in self.messages.iter() {
1451            output::render_message(message, self.debug);
1452        }
1453
1454        // Add a visual separator after restored messages
1455        println!(
1456            "\n{}\n",
1457            console::style("──────── New Messages ────────").dim()
1458        );
1459    }
1460
1461    pub async fn get_session(&self) -> Result<aster::session::Session> {
1462        SessionManager::get_session(&self.session_id, false).await
1463    }
1464
1465    // Get the session's total token usage
1466    pub async fn get_total_token_usage(&self) -> Result<Option<i32>> {
1467        let metadata = self.get_session().await?;
1468        Ok(metadata.total_tokens)
1469    }
1470
1471    /// Display enhanced context usage with session totals
1472    pub async fn display_context_usage(&self) -> Result<()> {
1473        let provider = self.agent.provider().await?;
1474        let model_config = provider.get_model_config();
1475        let context_limit = model_config.context_limit();
1476
1477        let config = Config::global();
1478        let show_cost = config
1479            .get_param::<bool>("ASTER_CLI_SHOW_COST")
1480            .unwrap_or(false);
1481
1482        let provider_name = config
1483            .get_aster_provider()
1484            .unwrap_or_else(|_| "unknown".to_string());
1485
1486        match self.get_session().await {
1487            Ok(metadata) => {
1488                let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
1489
1490                output::display_context_usage(total_tokens, context_limit);
1491
1492                if show_cost {
1493                    let input_tokens = metadata.input_tokens.unwrap_or(0) as usize;
1494                    let output_tokens = metadata.output_tokens.unwrap_or(0) as usize;
1495                    output::display_cost_usage(
1496                        &provider_name,
1497                        &model_config.model_name,
1498                        input_tokens,
1499                        output_tokens,
1500                    );
1501                }
1502            }
1503            Err(_) => {
1504                output::display_context_usage(0, context_limit);
1505            }
1506        }
1507
1508        Ok(())
1509    }
1510
1511    /// Handle prompt command execution
1512    async fn handle_prompt_command(&mut self, opts: input::PromptCommandOptions) -> Result<()> {
1513        // name is required
1514        if opts.name.is_empty() {
1515            output::render_error("Prompt name argument is required");
1516            return Ok(());
1517        }
1518
1519        if opts.info {
1520            match self.get_prompt_info(&opts.name).await? {
1521                Some(info) => output::render_prompt_info(&info),
1522                None => output::render_error(&format!("Prompt '{}' not found", opts.name)),
1523            }
1524        } else {
1525            // Convert the arguments HashMap to a Value
1526            let arguments = serde_json::to_value(opts.arguments)
1527                .map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?;
1528
1529            match self.get_prompt(&opts.name, arguments).await {
1530                Ok(messages) => {
1531                    let start_len = self.messages.len();
1532                    let mut valid = true;
1533                    for (i, prompt_message) in messages.into_iter().enumerate() {
1534                        let msg = Message::from(prompt_message);
1535                        // ensure we get a User - Assistant - User type pattern
1536                        let expected_role = if i % 2 == 0 {
1537                            rmcp::model::Role::User
1538                        } else {
1539                            rmcp::model::Role::Assistant
1540                        };
1541
1542                        if msg.role != expected_role {
1543                            output::render_error(&format!(
1544                                "Expected {:?} message at position {}, but found {:?}",
1545                                expected_role, i, msg.role
1546                            ));
1547                            valid = false;
1548                            // get rid of everything we added to messages
1549                            self.messages.truncate(start_len);
1550                            break;
1551                        }
1552
1553                        if msg.role == rmcp::model::Role::User {
1554                            output::render_message(&msg, self.debug);
1555                        }
1556                        self.push_message(msg);
1557                    }
1558
1559                    if valid {
1560                        output::show_thinking();
1561                        self.process_agent_response(true, CancellationToken::default())
1562                            .await?;
1563                        output::hide_thinking();
1564                    }
1565                }
1566                Err(e) => output::render_error(&e.to_string()),
1567            }
1568        }
1569
1570        Ok(())
1571    }
1572
1573    /// Save a recipe to a file
1574    ///
1575    /// # Arguments
1576    /// * `recipe` - The recipe to save
1577    /// * `filepath_str` - The path to save the recipe to
1578    ///
1579    /// # Returns
1580    /// * `Result<PathBuf, String>` - The path the recipe was saved to or an error message
1581    fn save_recipe(
1582        &self,
1583        recipe: &aster::recipe::Recipe,
1584        filepath_str: &str,
1585    ) -> anyhow::Result<PathBuf> {
1586        let path_buf = PathBuf::from(filepath_str);
1587        let mut path = path_buf.clone();
1588
1589        // Update the final path if it's relative
1590        if path_buf.is_relative() {
1591            // If the path is relative, resolve it relative to the current working directory
1592            let cwd = std::env::current_dir().context("Failed to get current directory")?;
1593            path = cwd.join(&path_buf);
1594        }
1595
1596        // Check if parent directory exists
1597        if let Some(parent) = path.parent() {
1598            if !parent.exists() {
1599                return Err(anyhow::anyhow!(
1600                    "Directory '{}' does not exist",
1601                    parent.display()
1602                ));
1603            }
1604        }
1605
1606        // Try creating the file
1607        let file = std::fs::File::create(path.as_path())
1608            .context(format!("Failed to create file '{}'", path.display()))?;
1609
1610        // Write YAML
1611        serde_yaml::to_writer(file, recipe).context("Failed to save recipe")?;
1612
1613        Ok(path)
1614    }
1615
1616    fn push_message(&mut self, message: Message) {
1617        self.messages.push(message);
1618    }
1619}
1620
1621async fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
1622    use aster::model::ModelConfig;
1623    use aster::providers::create;
1624
1625    let config = Config::global();
1626
1627    // Try planner-specific provider first, fallback to default provider
1628    let provider = if let Ok(provider) = config.get_param::<String>("ASTER_PLANNER_PROVIDER") {
1629        provider
1630    } else {
1631        println!("WARNING: ASTER_PLANNER_PROVIDER not found. Using default provider...");
1632        config
1633            .get_aster_provider()
1634            .expect("No provider configured. Run 'aster configure' first")
1635    };
1636
1637    // Try planner-specific model first, fallback to default model
1638    let model = if let Ok(model) = config.get_param::<String>("ASTER_PLANNER_MODEL") {
1639        model
1640    } else {
1641        println!("WARNING: ASTER_PLANNER_MODEL not found. Using default model...");
1642        config
1643            .get_aster_model()
1644            .expect("No model configured. Run 'aster configure' first")
1645    };
1646
1647    let model_config =
1648        ModelConfig::new_with_context_env(model, Some("ASTER_PLANNER_CONTEXT_LIMIT"))?;
1649    let reasoner = create(&provider, model_config).await?;
1650
1651    Ok(reasoner)
1652}
1653
1654/// Format elapsed time duration
1655/// Shows seconds if less than 60, otherwise shows minutes:seconds
1656fn format_elapsed_time(duration: std::time::Duration) -> String {
1657    let total_secs = duration.as_secs();
1658    if total_secs < 60 {
1659        format!("{:.2}s", duration.as_secs_f64())
1660    } else {
1661        let minutes = total_secs / 60;
1662        let seconds = total_secs % 60;
1663        format!("{}m {:02}s", minutes, seconds)
1664    }
1665}
1666
1667#[cfg(test)]
1668mod tests {
1669    use super::*;
1670    use std::time::Duration;
1671
1672    #[test]
1673    fn test_format_elapsed_time_under_60_seconds() {
1674        // Test sub-second duration
1675        let duration = Duration::from_millis(500);
1676        assert_eq!(format_elapsed_time(duration), "0.50s");
1677
1678        // Test exactly 1 second
1679        let duration = Duration::from_secs(1);
1680        assert_eq!(format_elapsed_time(duration), "1.00s");
1681
1682        // Test 45.75 seconds
1683        let duration = Duration::from_millis(45750);
1684        assert_eq!(format_elapsed_time(duration), "45.75s");
1685
1686        // Test 59.99 seconds
1687        let duration = Duration::from_millis(59990);
1688        assert_eq!(format_elapsed_time(duration), "59.99s");
1689    }
1690
1691    #[test]
1692    fn test_format_elapsed_time_minutes() {
1693        // Test exactly 60 seconds (1 minute)
1694        let duration = Duration::from_secs(60);
1695        assert_eq!(format_elapsed_time(duration), "1m 00s");
1696
1697        // Test 61 seconds (1 minute 1 second)
1698        let duration = Duration::from_secs(61);
1699        assert_eq!(format_elapsed_time(duration), "1m 01s");
1700
1701        // Test 90 seconds (1 minute 30 seconds)
1702        let duration = Duration::from_secs(90);
1703        assert_eq!(format_elapsed_time(duration), "1m 30s");
1704
1705        // Test 119 seconds (1 minute 59 seconds)
1706        let duration = Duration::from_secs(119);
1707        assert_eq!(format_elapsed_time(duration), "1m 59s");
1708
1709        // Test 120 seconds (2 minutes)
1710        let duration = Duration::from_secs(120);
1711        assert_eq!(format_elapsed_time(duration), "2m 00s");
1712
1713        // Test 605 seconds (10 minutes 5 seconds)
1714        let duration = Duration::from_secs(605);
1715        assert_eq!(format_elapsed_time(duration), "10m 05s");
1716
1717        // Test 3661 seconds (61 minutes 1 second)
1718        let duration = Duration::from_secs(3661);
1719        assert_eq!(format_elapsed_time(duration), "61m 01s");
1720    }
1721
1722    #[test]
1723    fn test_format_elapsed_time_edge_cases() {
1724        // Test zero duration
1725        let duration = Duration::from_secs(0);
1726        assert_eq!(format_elapsed_time(duration), "0.00s");
1727
1728        // Test very small duration (1 millisecond)
1729        let duration = Duration::from_millis(1);
1730        assert_eq!(format_elapsed_time(duration), "0.00s");
1731
1732        // Test fractional seconds are truncated for minute display
1733        // 60.5 seconds should still show as 1m 00s (not 1m 00.5s)
1734        let duration = Duration::from_millis(60500);
1735        assert_eq!(format_elapsed_time(duration), "1m 00s");
1736    }
1737}