1mod builder;
2mod completion;
3mod elicitation;
4mod export;
5mod input;
6mod output;
7mod prompt;
8mod task_execution_display;
9mod thinking;
10
11use crate::session::task_execution_display::{
12 format_task_execution_notification, TASK_EXECUTION_NOTIFICATION_TYPE,
13};
14use aster::conversation::Conversation;
15use std::io::Write;
16use std::str::FromStr;
17use tokio::signal::ctrl_c;
18use tokio_util::task::AbortOnDropHandle;
19
20pub use self::export::message_to_markdown;
21use aster::agents::AgentEvent;
22use aster::permission::permission_confirmation::PrincipalType;
23use aster::permission::Permission;
24use aster::permission::PermissionConfirmation;
25use aster::providers::base::Provider;
26use aster::utils::safe_truncate;
27pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
28use console::Color;
29
30use anyhow::{Context, Result};
31use aster::agents::extension::{Envs, ExtensionConfig, PLATFORM_EXTENSIONS};
32use aster::agents::types::RetryConfig;
33use aster::agents::{Agent, SessionConfig, COMPACT_TRIGGERS};
34use aster::config::{AsterMode, Config};
35use aster::session::SessionManager;
36use completion::AsterCompleter;
37use input::InputResult;
38use rmcp::model::PromptMessage;
39use rmcp::model::ServerNotification;
40use rmcp::model::{ErrorCode, ErrorData};
41
42use aster::config::paths::Paths;
43use aster::conversation::message::{ActionRequiredData, Message, MessageContent};
44use rustyline::EditMode;
45use serde::{Deserialize, Serialize};
46use serde_json::Value;
47use std::collections::HashMap;
48use std::path::PathBuf;
49use std::sync::Arc;
50use std::time::Instant;
51use tokio;
52use tokio_util::sync::CancellationToken;
53use tracing::warn;
54
55#[derive(Serialize, Deserialize, Debug)]
56struct JsonOutput {
57 messages: Vec<Message>,
58 metadata: JsonMetadata,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62struct JsonMetadata {
63 total_tokens: Option<i32>,
64 status: String,
65}
66
67#[derive(Serialize, Debug)]
68#[serde(tag = "type", rename_all = "snake_case")]
69enum StreamEvent {
70 Message {
71 message: Message,
72 },
73 Notification {
74 extension_id: String,
75 #[serde(flatten)]
76 data: NotificationData,
77 },
78 ModelChange {
79 model: String,
80 mode: String,
81 },
82 Error {
83 error: String,
84 },
85 Complete {
86 total_tokens: Option<i32>,
87 },
88}
89
90#[derive(Serialize, Debug)]
91#[serde(rename_all = "snake_case")]
92enum NotificationData {
93 Log {
94 message: String,
95 },
96 Progress {
97 progress: f64,
98 total: Option<f64>,
99 message: Option<String>,
100 },
101}
102
/// How chat input is handled: sent straight to the agent (`Normal`) or first
/// to the reasoner model to produce a plan (`Plan`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunMode {
    Normal,
    Plan,
}
107
/// Persists the line editor's input history between CLI runs.
struct HistoryManager {
    /// Current history location: `history.txt` in the state directory.
    history_file: PathBuf,
    /// Legacy location: `history.txt` in the config directory. It is read as a
    /// fallback and removed after a successful save to `history_file`.
    old_history_file: PathBuf,
}
112
113impl HistoryManager {
114 fn new() -> Self {
115 Self {
116 history_file: Paths::state_dir().join("history.txt"),
117 old_history_file: Paths::config_dir().join("history.txt"),
118 }
119 }
120
121 fn load(
122 &self,
123 editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
124 ) {
125 if let Some(parent) = self.history_file.parent() {
126 if !parent.exists() {
127 if let Err(e) = std::fs::create_dir_all(parent) {
128 eprintln!("Warning: Failed to create history directory: {}", e);
129 }
130 }
131 }
132
133 let history_files = [&self.history_file, &self.old_history_file];
134 if let Some(file) = history_files.iter().find(|f| f.exists()) {
135 if let Err(err) = editor.load_history(file) {
136 eprintln!("Warning: Failed to load command history: {}", err);
137 }
138 }
139 }
140
141 fn save(
142 &self,
143 editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
144 ) {
145 if let Err(err) = editor.save_history(&self.history_file) {
146 eprintln!("Warning: Failed to save command history: {}", err);
147 } else if self.old_history_file.exists() {
148 if let Err(err) = std::fs::remove_file(&self.old_history_file) {
149 eprintln!("Warning: Failed to remove old history file: {}", err);
150 }
151 }
152 }
153}
154
155pub struct CliSession {
156 agent: Agent,
157 messages: Conversation,
158 session_id: String,
159 completion_cache: Arc<std::sync::RwLock<CompletionCache>>,
160 debug: bool,
161 run_mode: RunMode,
162 scheduled_job_id: Option<String>, max_turns: Option<u32>,
164 edit_mode: Option<EditMode>,
165 retry_config: Option<RetryConfig>,
166 output_format: String,
167}
168
169struct CompletionCache {
171 prompts: HashMap<String, Vec<String>>,
172 prompt_info: HashMap<String, output::PromptInfo>,
173 last_updated: Instant,
174}
175
176impl CompletionCache {
177 fn new() -> Self {
178 Self {
179 prompts: HashMap::new(),
180 prompt_info: HashMap::new(),
181 last_updated: Instant::now(),
182 }
183 }
184}
185
/// Classification of a reasoner (planner) reply, as decided by
/// `classify_planner_response`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlannerResponseType {
    /// The reply is a plan the user may choose to act on.
    Plan,
    /// The reply asks the user clarifying questions instead of planning.
    ClarifyingQuestions,
}
190
191pub async fn classify_planner_response(
197 message_text: String,
198 provider: Arc<dyn Provider>,
199) -> Result<PlannerResponseType> {
200 let prompt = format!("The text below is the output from an AI model which can either provide a plan or list of clarifying questions. Based on the text below, decide if the output is a \"plan\" or \"clarifying questions\".\n---\n{message_text}");
201
202 let message = Message::user().with_text(&prompt);
204 let (result, _usage) = provider
205 .complete(
206 "Reply only with the classification label: \"plan\" or \"clarifying questions\"",
207 &[message],
208 &[],
209 )
210 .await?;
211
212 let predicted = result.as_concat_text();
213 if predicted.to_lowercase().contains("plan") {
214 Ok(PlannerResponseType::Plan)
215 } else {
216 Ok(PlannerResponseType::ClarifyingQuestions)
217 }
218}
219
220impl CliSession {
221 #[allow(clippy::too_many_arguments)]
222 pub async fn new(
223 agent: Agent,
224 session_id: String,
225 debug: bool,
226 scheduled_job_id: Option<String>,
227 max_turns: Option<u32>,
228 edit_mode: Option<EditMode>,
229 retry_config: Option<RetryConfig>,
230 output_format: String,
231 ) -> Self {
232 let messages = SessionManager::get_session(&session_id, true)
233 .await
234 .map(|session| session.conversation.unwrap_or_default())
235 .unwrap();
236
237 CliSession {
238 agent,
239 messages,
240 session_id,
241 completion_cache: Arc::new(std::sync::RwLock::new(CompletionCache::new())),
242 debug,
243 run_mode: RunMode::Normal,
244 scheduled_job_id,
245 max_turns,
246 edit_mode,
247 retry_config,
248 output_format,
249 }
250 }
251
252 pub fn session_id(&self) -> &String {
253 &self.session_id
254 }
255
256 pub async fn add_extension(&mut self, extension_command: String) -> Result<()> {
262 let mut parts: Vec<&str> = extension_command.split_whitespace().collect();
263 let mut envs = HashMap::new();
264
265 while let Some(part) = parts.first() {
266 if !part.contains('=') {
267 break;
268 }
269 let env_part = parts.remove(0);
270 let (key, value) = env_part.split_once('=').unwrap();
271 envs.insert(key.to_string(), value.to_string());
272 }
273
274 if parts.is_empty() {
275 return Err(anyhow::anyhow!("No command provided in extension string"));
276 }
277
278 let cmd = parts.remove(0).to_string();
279
280 let config = ExtensionConfig::Stdio {
281 name: String::new(),
282 cmd,
283 args: parts.iter().map(|s| s.to_string()).collect(),
284 envs: Envs::new(envs),
285 env_keys: Vec::new(),
286 description: aster::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(),
287 timeout: Some(aster::config::DEFAULT_EXTENSION_TIMEOUT),
289 bundled: None,
290 available_tools: Vec::new(),
291 };
292
293 self.agent
294 .add_extension(config)
295 .await
296 .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e))?;
297
298 self.invalidate_completion_cache().await;
300
301 Ok(())
302 }
303
304 pub async fn add_streamable_http_extension(&mut self, extension_url: String) -> Result<()> {
309 let config = ExtensionConfig::StreamableHttp {
310 name: String::new(),
311 uri: extension_url,
312 envs: Envs::new(HashMap::new()),
313 env_keys: Vec::new(),
314 headers: HashMap::new(),
315 description: aster::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(),
316 timeout: Some(aster::config::DEFAULT_EXTENSION_TIMEOUT),
318 bundled: None,
319 available_tools: Vec::new(),
320 };
321
322 self.agent
323 .add_extension(config)
324 .await
325 .map_err(|e| anyhow::anyhow!("Failed to start extension: {}", e))?;
326
327 self.invalidate_completion_cache().await;
329
330 Ok(())
331 }
332
333 pub async fn add_builtin(&mut self, builtin_name: String) -> Result<()> {
338 for name in builtin_name.split(',') {
339 let extension_name = name.trim();
340
341 let config = if PLATFORM_EXTENSIONS.contains_key(extension_name) {
342 ExtensionConfig::Platform {
343 name: extension_name.to_string(),
344 bundled: None,
345 description: name.to_string(),
346 available_tools: Vec::new(),
347 }
348 } else {
349 ExtensionConfig::Builtin {
350 name: extension_name.to_string(),
351 display_name: None,
352 timeout: None,
353 bundled: None,
354 description: name.to_string(),
355 available_tools: Vec::new(),
356 }
357 };
358 self.agent
359 .add_extension(config)
360 .await
361 .map_err(|e| anyhow::anyhow!("Failed to start builtin extension: {}", e))?;
362 }
363
364 self.invalidate_completion_cache().await;
366
367 Ok(())
368 }
369
370 pub async fn list_prompts(
371 &mut self,
372 extension: Option<String>,
373 ) -> Result<HashMap<String, Vec<String>>> {
374 let prompts = self.agent.list_extension_prompts().await;
375
376 if let Some(filter) = &extension {
378 if !prompts.contains_key(filter) {
379 return Err(anyhow::anyhow!("Extension '{}' not found", filter));
380 }
381 }
382
383 Ok(prompts
385 .into_iter()
386 .filter(|(ext, _)| extension.as_ref().is_none_or(|f| f == ext))
387 .map(|(extension, prompt_list)| {
388 let names = prompt_list.into_iter().map(|p| p.name).collect();
389 (extension, names)
390 })
391 .collect())
392 }
393
394 pub async fn get_prompt_info(&mut self, name: &str) -> Result<Option<output::PromptInfo>> {
395 let prompts = self.agent.list_extension_prompts().await;
396
397 for (extension, prompt_list) in prompts {
399 if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) {
400 return Ok(Some(output::PromptInfo {
401 name: prompt.name.clone(),
402 description: prompt.description.clone(),
403 arguments: prompt.arguments.clone(),
404 extension: Some(extension),
405 }));
406 }
407 }
408
409 Ok(None)
410 }
411
412 pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result<Vec<PromptMessage>> {
413 Ok(self.agent.get_prompt(name, arguments).await?.messages)
414 }
415
416 pub(crate) async fn process_message(
418 &mut self,
419 message: Message,
420 cancel_token: CancellationToken,
421 ) -> Result<()> {
422 let cancel_token = cancel_token.clone();
423 self.push_message(message);
424 self.process_agent_response(false, cancel_token).await?;
425 Ok(())
426 }
427
428 pub async fn interactive(&mut self, prompt: Option<String>) -> Result<()> {
430 if let Some(prompt) = prompt {
431 let msg = Message::user().with_text(&prompt);
432 self.process_message(msg, CancellationToken::default())
433 .await?;
434 }
435
436 self.update_completion_cache().await?;
437
438 let mut editor = self.create_editor()?;
439 let history_manager = HistoryManager::new();
440 history_manager.load(&mut editor);
441
442 output::display_greeting();
443 loop {
444 self.display_context_usage().await?;
445
446 let input = input::get_input(&mut editor)?;
447 if matches!(input, InputResult::Exit) {
448 break;
449 }
450 self.handle_input(input, &history_manager, &mut editor)
451 .await?;
452 }
453
454 println!(
455 "Closing session. Session ID: {}",
456 console::style(&self.session_id).cyan()
457 );
458
459 Ok(())
460 }
461
462 fn create_editor(
463 &self,
464 ) -> Result<rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>> {
465 let builder =
466 rustyline::Config::builder().completion_type(rustyline::CompletionType::Circular);
467 let builder = match self.edit_mode {
468 Some(mode) => builder.edit_mode(mode),
469 None => builder.edit_mode(EditMode::Emacs),
470 };
471 let config = builder.build();
472 let mut editor =
473 rustyline::Editor::<AsterCompleter, rustyline::history::DefaultHistory>::with_config(
474 config,
475 )?;
476 let completer = AsterCompleter::new(self.completion_cache.clone());
477 editor.set_helper(Some(completer));
478 Ok(editor)
479 }
480
481 async fn handle_input(
482 &mut self,
483 input: InputResult,
484 history: &HistoryManager,
485 editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
486 ) -> Result<()> {
487 match input {
488 InputResult::Message(content) => {
489 self.handle_message_input(&content, history, editor).await?;
490 }
491 InputResult::Exit => unreachable!("Exit is handled in the main loop"),
492 InputResult::AddExtension(cmd) => {
493 history.save(editor);
494 match self.add_extension(cmd.clone()).await {
495 Ok(_) => output::render_extension_success(&cmd),
496 Err(e) => output::render_extension_error(&cmd, &e.to_string()),
497 }
498 }
499 InputResult::AddBuiltin(names) => {
500 history.save(editor);
501 match self.add_builtin(names.clone()).await {
502 Ok(_) => output::render_builtin_success(&names),
503 Err(e) => output::render_builtin_error(&names, &e.to_string()),
504 }
505 }
506 InputResult::ToggleTheme => {
507 history.save(editor);
508 self.handle_toggle_theme();
509 }
510 InputResult::SelectTheme(theme_name) => {
511 history.save(editor);
512 self.handle_select_theme(&theme_name);
513 }
514 InputResult::Retry => {}
515 InputResult::ListPrompts(extension) => {
516 history.save(editor);
517 match self.list_prompts(extension).await {
518 Ok(prompts) => output::render_prompts(&prompts),
519 Err(e) => output::render_error(&e.to_string()),
520 }
521 }
522 InputResult::AsterMode(mode) => {
523 history.save(editor);
524 self.handle_aster_mode(&mode)?;
525 }
526 InputResult::Plan(options) => {
527 self.handle_plan_mode(options).await?;
528 }
529 InputResult::EndPlan => {
530 self.run_mode = RunMode::Normal;
531 output::render_exit_plan_mode();
532 }
533 InputResult::Clear => {
534 history.save(editor);
535 self.handle_clear().await?;
536 }
537 InputResult::PromptCommand(opts) => {
538 history.save(editor);
539 self.handle_prompt_command(opts).await?;
540 }
541 InputResult::Recipe(filepath_opt) => {
542 history.save(editor);
543 self.handle_recipe(filepath_opt).await;
544 }
545 InputResult::Compact => {
546 history.save(editor);
547 self.handle_compact().await?;
548 }
549 }
550 Ok(())
551 }
552
553 async fn handle_message_input(
554 &mut self,
555 content: &str,
556 history: &HistoryManager,
557 editor: &mut rustyline::Editor<AsterCompleter, rustyline::history::DefaultHistory>,
558 ) -> Result<()> {
559 match self.run_mode {
560 RunMode::Normal => {
561 history.save(editor);
562 self.push_message(Message::user().with_text(content));
563
564 if let Err(e) = crate::project_tracker::update_project_tracker(
565 Some(content),
566 Some(&self.session_id),
567 ) {
568 eprintln!(
569 "Warning: Failed to update project tracker with instruction: {}",
570 e
571 );
572 }
573
574 let _provider = self.agent.provider().await?;
575
576 output::show_thinking();
577 let start_time = Instant::now();
578 self.process_agent_response(true, CancellationToken::default())
579 .await?;
580 output::hide_thinking();
581
582 let elapsed = start_time.elapsed();
583 let elapsed_str = format_elapsed_time(elapsed);
584 println!(
585 "\n{}",
586 console::style(format!("⏱️ Elapsed time: {}", elapsed_str)).dim()
587 );
588 }
589 RunMode::Plan => {
590 let mut plan_messages = self.messages.clone();
591 plan_messages.push(Message::user().with_text(content));
592 let reasoner = get_reasoner().await?;
593 self.plan_with_reasoner_model(plan_messages, reasoner)
594 .await?;
595 }
596 }
597 Ok(())
598 }
599
600 fn handle_toggle_theme(&self) {
601 let current = output::get_theme();
602 let new_theme = match current {
603 output::Theme::Ansi => {
604 println!("Switching to Light theme");
605 output::Theme::Light
606 }
607 output::Theme::Light => {
608 println!("Switching to Dark theme");
609 output::Theme::Dark
610 }
611 output::Theme::Dark => {
612 println!("Switching to Ansi theme");
613 output::Theme::Ansi
614 }
615 };
616 output::set_theme(new_theme);
617 }
618
619 fn handle_select_theme(&self, theme_name: &str) {
620 let new_theme = match theme_name {
621 "light" => {
622 println!("Switching to Light theme");
623 output::Theme::Light
624 }
625 "dark" => {
626 println!("Switching to Dark theme");
627 output::Theme::Dark
628 }
629 "ansi" => {
630 println!("Switching to Ansi theme");
631 output::Theme::Ansi
632 }
633 _ => output::Theme::Dark,
634 };
635 output::set_theme(new_theme);
636 }
637
638 fn handle_aster_mode(&self, mode: &str) -> Result<()> {
639 let config = Config::global();
640 let mode = match AsterMode::from_str(&mode.to_lowercase()) {
641 Ok(mode) => mode,
642 Err(_) => {
643 output::render_error(&format!(
644 "Invalid mode '{}'. Mode must be one of: auto, approve, chat, smart_approve",
645 mode
646 ));
647 return Ok(());
648 }
649 };
650 config.set_aster_mode(mode)?;
651 output::aster_mode_message(&format!("Aster mode set to '{:?}'", mode));
652 Ok(())
653 }
654
655 async fn handle_plan_mode(&mut self, options: input::PlanCommandOptions) -> Result<()> {
656 self.run_mode = RunMode::Plan;
657 output::render_enter_plan_mode();
658
659 if options.message_text.is_empty() {
660 return Ok(());
661 }
662
663 let mut plan_messages = self.messages.clone();
664 plan_messages.push(Message::user().with_text(&options.message_text));
665
666 let reasoner = get_reasoner().await?;
667 self.plan_with_reasoner_model(plan_messages, reasoner).await
668 }
669
670 async fn handle_clear(&mut self) -> Result<()> {
671 if let Err(e) =
672 SessionManager::replace_conversation(&self.session_id, &Conversation::default()).await
673 {
674 output::render_error(&format!("Failed to clear session: {}", e));
675 return Ok(());
676 }
677
678 if let Err(e) = SessionManager::update_session(&self.session_id)
679 .total_tokens(Some(0))
680 .input_tokens(Some(0))
681 .output_tokens(Some(0))
682 .apply()
683 .await
684 {
685 output::render_error(&format!("Failed to reset token counts: {}", e));
686 return Ok(());
687 }
688
689 self.messages.clear();
690 tracing::info!("Chat context cleared by user.");
691 output::render_message(
692 &Message::assistant().with_text("Chat context cleared.\n"),
693 self.debug,
694 );
695 Ok(())
696 }
697
698 async fn handle_recipe(&mut self, filepath_opt: Option<String>) {
699 println!("{}", console::style("Generating Recipe").green());
700
701 output::show_thinking();
702 let recipe = self.agent.create_recipe(self.messages.clone()).await;
703 output::hide_thinking();
704
705 match recipe {
706 Ok(recipe) => {
707 let filepath_str = filepath_opt.as_deref().unwrap_or("recipe.yaml");
708 match self.save_recipe(&recipe, filepath_str) {
709 Ok(path) => println!(
710 "{}",
711 console::style(format!("Saved recipe to {}", path.display())).green()
712 ),
713 Err(e) => println!("{}", console::style(e).red()),
714 }
715 }
716 Err(e) => {
717 println!(
718 "{}: {:?}",
719 console::style("Failed to generate recipe").red(),
720 e
721 );
722 }
723 }
724 }
725
726 async fn handle_compact(&mut self) -> Result<()> {
727 let prompt = "Are you sure you want to compact this conversation? This will condense the message history.";
728 let should_summarize = match cliclack::confirm(prompt).initial_value(true).interact() {
729 Ok(choice) => choice,
730 Err(e) => {
731 if e.kind() == std::io::ErrorKind::Interrupted {
732 false
733 } else {
734 return Err(e.into());
735 }
736 }
737 };
738
739 if should_summarize {
740 self.push_message(Message::user().with_text(COMPACT_TRIGGERS[0]));
741 output::show_thinking();
742 self.process_agent_response(true, CancellationToken::default())
743 .await?;
744 output::hide_thinking();
745 } else {
746 println!("{}", console::style("Compaction cancelled.").yellow());
747 }
748 Ok(())
749 }
750
751 async fn plan_with_reasoner_model(
752 &mut self,
753 plan_messages: Conversation,
754 reasoner: Arc<dyn Provider>,
755 ) -> Result<(), anyhow::Error> {
756 let plan_prompt = self.agent.get_plan_prompt().await?;
757 output::show_thinking();
758 let (plan_response, _usage) = reasoner
759 .complete(&plan_prompt, plan_messages.messages(), &[])
760 .await?;
761 output::render_message(&plan_response, self.debug);
762 output::hide_thinking();
763 let planner_response_type =
764 classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?)
765 .await?;
766
767 match planner_response_type {
768 PlannerResponseType::Plan => {
769 println!();
770 let should_act = match cliclack::confirm(
771 "Do you want to clear message history & act on this plan?",
772 )
773 .initial_value(true)
774 .interact()
775 {
776 Ok(choice) => choice,
777 Err(e) => {
778 if e.kind() == std::io::ErrorKind::Interrupted {
779 false } else {
781 return Err(e.into());
782 }
783 }
784 };
785 if should_act {
786 output::render_act_on_plan();
787 self.run_mode = RunMode::Normal;
788 let config = Config::global();
790 let curr_aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
791 if curr_aster_mode != AsterMode::Auto {
792 config.set_aster_mode(AsterMode::Auto).unwrap();
793 }
794
795 self.messages.clear();
797 let plan_message = Message::user().with_text(plan_response.as_concat_text());
799 self.push_message(plan_message);
800 output::show_thinking();
802 self.process_agent_response(true, CancellationToken::default())
803 .await?;
804 output::hide_thinking();
805
806 if curr_aster_mode != AsterMode::Auto {
808 config.set_aster_mode(curr_aster_mode)?;
809 }
810 } else {
811 self.push_message(plan_response);
814 }
815 }
816 PlannerResponseType::ClarifyingQuestions => {
817 self.push_message(plan_response);
820 }
821 }
822
823 Ok(())
824 }
825
826 pub async fn headless(&mut self, prompt: String) -> Result<()> {
828 let message = Message::user().with_text(&prompt);
829 self.process_message(message, CancellationToken::default())
830 .await?;
831 Ok(())
832 }
833
834 async fn process_agent_response(
835 &mut self,
836 interactive: bool,
837 cancel_token: CancellationToken,
838 ) -> Result<()> {
839 let is_json_mode = self.output_format == "json";
840 let is_stream_json_mode = self.output_format == "stream-json";
841
842 let emit_stream_event = |event: &StreamEvent| {
844 if let Ok(json) = serde_json::to_string(event) {
845 println!("{}", json);
846 }
847 };
848
849 let session_config = SessionConfig {
850 id: self.session_id.clone(),
851 schedule_id: self.scheduled_job_id.clone(),
852 max_turns: self.max_turns,
853 retry_config: self.retry_config.clone(),
854 system_prompt: None,
855 };
856 let user_message = self
857 .messages
858 .last()
859 .ok_or_else(|| anyhow::anyhow!("No user message"))?;
860
861 let cancel_token_interrupt = cancel_token.clone();
862 let handle = tokio::spawn(async move {
863 if ctrl_c().await.is_ok() {
864 cancel_token_interrupt.cancel();
865 }
866 });
867 let _drop_handle = AbortOnDropHandle::new(handle);
868
869 let mut stream = self
870 .agent
871 .reply(
872 user_message.clone(),
873 session_config.clone(),
874 Some(cancel_token.clone()),
875 )
876 .await?;
877
878 let mut progress_bars = output::McpSpinners::new();
879 let cancel_token_clone = cancel_token.clone();
880
881 use futures::StreamExt;
882 loop {
883 tokio::select! {
884 result = stream.next() => {
885 match result {
886 Some(Ok(AgentEvent::Message(message))) => {
887 let tool_call_confirmation = message.content.iter().find_map(|content| {
888 if let MessageContent::ActionRequired(action) = content {
889 #[allow(irrefutable_let_patterns)] if let ActionRequiredData::ToolConfirmation { id, tool_name, arguments, prompt } = &action.data {
891 Some((id.clone(), tool_name.clone(), arguments.clone(), prompt.clone()))
892 } else {
893 None
894 }
895 } else {
896 None
897 }
898 });
899
900 let elicitation_request = message.content.iter().find_map(|content| {
901 if let MessageContent::ActionRequired(action) = content {
902 if let ActionRequiredData::Elicitation { id, message, requested_schema } = &action.data {
903 Some((id.clone(), message.clone(), requested_schema.clone()))
904 } else {
905 None
906 }
907 } else {
908 None
909 }
910 });
911
912 if let Some((id, _tool_name, _arguments, security_prompt)) = tool_call_confirmation {
913 output::hide_thinking();
914
915 let prompt = if let Some(security_message) = &security_prompt {
917 println!("\n{}", security_message);
918 "Do you allow this tool call?".to_string()
919 } else {
920 "Aster would like to call the above tool, do you allow?".to_string()
921 };
922
923 let permission_result = if security_prompt.is_none() {
925 cliclack::select(prompt)
927 .item(Permission::AllowOnce, "Allow", "Allow the tool call once")
928 .item(Permission::AlwaysAllow, "Always Allow", "Always allow the tool call")
929 .item(Permission::DenyOnce, "Deny", "Deny the tool call")
930 .item(Permission::Cancel, "Cancel", "Cancel the AI response and tool call")
931 .interact()
932 } else {
933 cliclack::select(prompt)
935 .item(Permission::AllowOnce, "Allow", "Allow the tool call once")
936 .item(Permission::DenyOnce, "Deny", "Deny the tool call")
937 .item(Permission::Cancel, "Cancel", "Cancel the AI response and tool call")
938 .interact()
939 };
940
941 let permission = match permission_result {
942 Ok(p) => p,
943 Err(e) => {
944 if e.kind() == std::io::ErrorKind::Interrupted {
945 Permission::Cancel
946 } else {
947 return Err(e.into());
948 }
949 }
950 };
951
952 if permission == Permission::Cancel {
953 output::render_text("Tool call cancelled. Returning to chat...", Some(Color::Yellow), true);
954
955 let mut response_message = Message::user();
956 response_message.content.push(MessageContent::tool_response(
957 id.clone(),
958 Err(ErrorData { code: ErrorCode::INVALID_REQUEST, message: std::borrow::Cow::from("Tool call cancelled by user".to_string()), data: None })
959 ));
960 self.messages.push(response_message);
961 cancel_token_clone.cancel();
962 drop(stream);
963 break;
964 } else {
965 self.agent.handle_confirmation(id.clone(), PermissionConfirmation {
966 principal_type: PrincipalType::Tool,
967 permission,
968 }).await;
969 }
970 }
971 else if let Some((elicitation_id, elicitation_message, schema)) = elicitation_request {
972 output::hide_thinking();
973 let _ = progress_bars.hide();
974
975 match elicitation::collect_elicitation_input(&elicitation_message, &schema) {
976 Ok(Some(user_data)) => {
977 let user_data_value = serde_json::to_value(user_data)
978 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
979
980 let response_message = Message::user()
981 .with_content(MessageContent::action_required_elicitation_response(
982 elicitation_id.clone(),
983 user_data_value,
984 ))
985 .with_visibility(false, true);
986
987 self.messages.push(response_message.clone());
988 let _ = self
991 .agent
992 .reply(
993 response_message,
994 session_config.clone(),
995 Some(cancel_token.clone()),
996 )
997 .await?;
998 }
999 Ok(None) => {
1000 output::render_text("Information request cancelled.", Some(Color::Yellow), true);
1001 cancel_token_clone.cancel();
1002 drop(stream);
1003 break;
1004 }
1005 Err(e) => {
1006 output::render_error(&format!("Failed to collect input: {}", e));
1007 cancel_token_clone.cancel();
1008 drop(stream);
1009 break;
1010 }
1011 }
1012 }
1013 else {
1014 for content in &message.content {
1015 if let MessageContent::ToolRequest(tool_request) = content {
1016 if let Ok(tool_call) = &tool_request.tool_call {
1017 tracing::info!(counter.aster.tool_calls = 1,
1018 tool_name = %tool_call.name,
1019 "Tool call started"
1020 );
1021 }
1022 }
1023 if let MessageContent::ToolResponse(tool_response) = content {
1024 let tool_name = self.messages
1025 .iter()
1026 .rev()
1027 .find_map(|msg| {
1028 msg.content.iter().find_map(|c| {
1029 if let MessageContent::ToolRequest(req) = c {
1030 if req.id == tool_response.id {
1031 if let Ok(tool_call) = &req.tool_call {
1032 Some(tool_call.name.clone())
1033 } else {
1034 None
1035 }
1036 } else {
1037 None
1038 }
1039 } else {
1040 None
1041 }
1042 })
1043 })
1044 .unwrap_or_else(|| "unknown".to_string().into());
1045
1046 let success = tool_response.tool_result.is_ok();
1047 let result_status = if success { "success" } else { "error" };
1048 tracing::info!(
1049 counter.aster.tool_completions = 1,
1050 tool_name = %tool_name,
1051 result = %result_status,
1052 "Tool call completed"
1053 );
1054 }
1055 }
1056
1057 self.messages.push(message.clone());
1058
1059 if interactive {output::hide_thinking()};
1060 let _ = progress_bars.hide();
1061
1062 if is_stream_json_mode {
1064 emit_stream_event(&StreamEvent::Message { message: message.clone() });
1065 } else if !is_json_mode {
1066 output::render_message(&message, self.debug);
1067 }
1068 }
1069 }
1070 Some(Ok(AgentEvent::McpNotification((extension_id, message)))) => {
1071 match &message {
1072 ServerNotification::LoggingMessageNotification(notification) => {
1073 let data = ¬ification.params.data;
1074 let (formatted_message, subagent_id, message_notification_type) = match data {
1075 Value::String(s) => (s.clone(), None, None),
1076 Value::Object(o) => {
1077 if let Some(Value::String(msg)) = o.get("message") {
1079 let subagent_id = o.get("subagent_id")
1081 .and_then(|v| v.as_str());
1082 let notification_type = o.get("type")
1083 .and_then(|v| v.as_str());
1084
1085 let formatted = match notification_type {
1086 Some("subagent_created") | Some("completed") | Some("terminated") => {
1087 format!("🤖 {}", msg)
1088 }
1089 Some("tool_usage") | Some("tool_completed") | Some("tool_error") => {
1090 format!("🔧 {}", msg)
1091 }
1092 Some("message_processing") | Some("turn_progress") => {
1093 format!("💭 {}", msg)
1094 }
1095 Some("response_generated") => {
1096 let config = Config::global();
1098 let min_priority = config
1099 .get_param::<f32>("ASTER_CLI_MIN_PRIORITY")
1100 .ok()
1101 .unwrap_or(0.5);
1102
1103 if min_priority > 0.1 && !self.debug {
1104 if let Some(response_content) = msg.strip_prefix("Responded: ") {
1106 format!("🤖 Responded: {}", safe_truncate(response_content, 100))
1107 } else {
1108 format!("🤖 {}", msg)
1109 }
1110 } else {
1111 format!("🤖 {}", msg)
1113 }
1114 }
1115 _ => {
1116 msg.to_string()
1117 }
1118 };
1119 (formatted, subagent_id.map(str::to_string), notification_type.map(str::to_string))
1120 } else if let Some(Value::String(output)) = o.get("output") {
1121 let notification_type = o.get("type")
1123 .and_then(|v| v.as_str())
1124 .map(str::to_string);
1125
1126 (output.to_owned(), None, notification_type)
1127 } else if let Some(result) = format_task_execution_notification(data) {
1128 result
1129 } else {
1130 (data.to_string(), None, None)
1131 }
1132 },
1133 v => {
1134 (v.to_string(), None, None)
1135 },
1136 };
1137
1138 if is_stream_json_mode {
1139 emit_stream_event(&StreamEvent::Notification {
1140 extension_id: extension_id.clone(),
1141 data: NotificationData::Log { message: formatted_message.clone() },
1142 });
1143 }
1144 else if let Some(_id) = subagent_id {
1146 if interactive {
1147 let _ = progress_bars.hide();
1148 if !is_json_mode {
1149 println!("{}", console::style(&formatted_message).green().dim());
1150 }
1151 } else if !is_json_mode {
1152 progress_bars.log(&formatted_message);
1153 }
1154 } else if let Some(ref notification_type) = message_notification_type {
1155 if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
1156 if interactive {
1157 let _ = progress_bars.hide();
1158 if !is_json_mode {
1159 print!("{}", formatted_message);
1160 std::io::stdout().flush().unwrap();
1161 }
1162 } else if !is_json_mode {
1163 print!("{}", formatted_message);
1164 std::io::stdout().flush().unwrap();
1165 }
1166 } else if notification_type == "shell_output" {
1167 if interactive {
1168 let _ = progress_bars.hide();
1169 }
1170 if !is_json_mode {
1171 println!("{}", formatted_message);
1172 }
1173 }
1174 }
1175 else if output::is_showing_thinking() {
1176 output::set_thinking_message(&formatted_message);
1177 } else {
1178 progress_bars.log(&formatted_message);
1179 }
1180 },
1181 ServerNotification::ProgressNotification(notification) => {
1182 let progress = notification.params.progress;
1183 let text = notification.params.message.as_deref();
1184 let total = notification.params.total;
1185 let token = ¬ification.params.progress_token;
1186
1187 if is_stream_json_mode {
1188 emit_stream_event(&StreamEvent::Notification {
1189 extension_id: extension_id.clone(),
1190 data: NotificationData::Progress {
1191 progress,
1192 total,
1193 message: text.map(String::from),
1194 },
1195 });
1196 } else {
1197 progress_bars.update(
1198 &token.0.to_string(),
1199 progress,
1200 total,
1201 text,
1202 );
1203 }
1204 },
1205 _ => (),
1206 }
1207 }
1208 Some(Ok(AgentEvent::HistoryReplaced(updated_conversation))) => {
1209 self.messages = updated_conversation;
1210 }
1211 Some(Ok(AgentEvent::ModelChange { model, mode })) => {
1212 if is_stream_json_mode {
1213 emit_stream_event(&StreamEvent::ModelChange {
1214 model: model.clone(),
1215 mode: mode.clone(),
1216 });
1217 } else if self.debug {
1218 eprintln!("Model changed to {} in {} mode", model, mode);
1219 }
1220 }
1221
1222 Some(Err(e)) => {
1223 let error_msg = e.to_string();
1224
1225 if is_stream_json_mode {
1226 emit_stream_event(&StreamEvent::Error { error: error_msg.clone() });
1227 }
1228
1229 if e.downcast_ref::<aster::providers::errors::ProviderError>()
1230 .map(|provider_error| matches!(provider_error, aster::providers::errors::ProviderError::ContextLengthExceeded(_)))
1231 .unwrap_or(false) {
1232
1233 if !is_stream_json_mode {
1234 output::render_text(
1235 "Compaction requested. Should have happened in the agent!",
1236 Some(Color::Yellow),
1237 true
1238 );
1239 }
1240 warn!("Compaction requested. Should have happened in the agent!");
1241 }
1242 if !is_stream_json_mode {
1243 eprintln!("Error: {}", error_msg);
1244 }
1245 cancel_token_clone.cancel();
1246 drop(stream);
1247 if let Err(e) = self.handle_interrupted_messages(false).await {
1248 eprintln!("Error handling interruption: {}", e);
1249 } else if !is_stream_json_mode {
1250 output::render_error(
1251 "The error above was an exception we were not able to handle.\n\
1252 These errors are often related to connection or authentication\n\
1253 We've removed the conversation up to the most recent user message\n\
1254 - depending on the error you may be able to continue",
1255 );
1256 }
1257 break;
1258 }
1259 None => break,
1260 }
1261 }
1262 _ = cancel_token_clone.cancelled() => {
1263 drop(stream);
1264 if let Err(e) = self.handle_interrupted_messages(true).await {
1265 eprintln!("Error handling interruption: {}", e);
1266 }
1267 break;
1268 }
1269 }
1270 }
1271
1272 if is_json_mode {
1274 let metadata = match SessionManager::get_session(&self.session_id, false).await {
1275 Ok(session) => JsonMetadata {
1276 total_tokens: session.total_tokens,
1277 status: "completed".to_string(),
1278 },
1279 Err(_) => JsonMetadata {
1280 total_tokens: None,
1281 status: "completed".to_string(),
1282 },
1283 };
1284
1285 let json_output = JsonOutput {
1286 messages: self.messages.messages().to_vec(),
1287 metadata,
1288 };
1289
1290 println!("{}", serde_json::to_string_pretty(&json_output)?);
1291 } else if is_stream_json_mode {
1292 let total_tokens = SessionManager::get_session(&self.session_id, false)
1293 .await
1294 .ok()
1295 .and_then(|s| s.total_tokens);
1296 emit_stream_event(&StreamEvent::Complete { total_tokens });
1297 } else {
1298 println!();
1299 }
1300
1301 Ok(())
1302 }
1303
1304 async fn handle_interrupted_messages(&mut self, interrupt: bool) -> Result<()> {
1305 let tool_requests = self
1307 .messages
1308 .last()
1309 .filter(|msg| msg.role == rmcp::model::Role::Assistant)
1310 .map_or(Vec::new(), |msg| {
1311 msg.content
1312 .iter()
1313 .filter_map(|content| {
1314 if let MessageContent::ToolRequest(req) = content {
1315 Some((req.id.clone(), req.tool_call.clone()))
1316 } else {
1317 None
1318 }
1319 })
1320 .collect()
1321 });
1322
1323 if !tool_requests.is_empty() {
1324 let mut response_message = Message::user();
1327 let last_tool_name = tool_requests
1328 .last()
1329 .and_then(|(_, tool_call)| {
1330 tool_call
1331 .as_ref()
1332 .ok()
1333 .map(|tool| tool.name.to_string().clone())
1334 })
1335 .unwrap_or_else(|| "tool".to_string());
1336
1337 let notification = if interrupt {
1338 "Interrupted by the user to make a correction".to_string()
1339 } else {
1340 "An uncaught error happened during tool use".to_string()
1341 };
1342 for (req_id, _) in &tool_requests {
1343 response_message.content.push(MessageContent::tool_response(
1344 req_id.clone(),
1345 Err(ErrorData {
1346 code: ErrorCode::INTERNAL_ERROR,
1347 message: std::borrow::Cow::from(notification.clone()),
1348 data: None,
1349 }),
1350 ));
1351 }
1352 self.push_message(response_message);
1354 let prompt = format!(
1355 "The existing call to {} was interrupted. How would you like to proceed?",
1356 last_tool_name
1357 );
1358 self.push_message(Message::assistant().with_text(&prompt));
1359 output::render_message(&Message::assistant().with_text(&prompt), self.debug);
1360 } else {
1361 if let Some(last_msg) = self.messages.last() {
1363 if last_msg.role == rmcp::model::Role::User {
1364 match last_msg.content.first() {
1365 Some(MessageContent::ToolResponse(_)) => {
1366 let prompt = "The tool calling loop was interrupted. How would you like to proceed?";
1368 self.push_message(Message::assistant().with_text(prompt));
1369 output::render_message(
1370 &Message::assistant().with_text(prompt),
1371 self.debug,
1372 );
1373 }
1374 Some(_) => {
1375 self.messages.pop();
1377 let prompt = "Interrupted before the model replied and removed the last message.";
1378 output::render_message(
1379 &Message::assistant().with_text(prompt),
1380 self.debug,
1381 );
1382 }
1383 None => panic!("No content in last message"),
1384 }
1385 }
1386 }
1387 }
1388 Ok(())
1389 }
1390
1391 pub async fn update_completion_cache(&mut self) -> Result<()> {
1394 let prompts = self.agent.list_extension_prompts().await;
1396
1397 let mut cache = self.completion_cache.write().unwrap();
1399 cache.prompts.clear();
1400 cache.prompt_info.clear();
1401
1402 for (extension, prompt_list) in prompts {
1403 let names: Vec<String> = prompt_list.iter().map(|p| p.name.clone()).collect();
1404 cache.prompts.insert(extension.clone(), names);
1405
1406 for prompt in prompt_list {
1407 cache.prompt_info.insert(
1408 prompt.name.clone(),
1409 output::PromptInfo {
1410 name: prompt.name.clone(),
1411 description: prompt.description.clone(),
1412 arguments: prompt.arguments.clone(),
1413 extension: Some(extension.clone()),
1414 },
1415 );
1416 }
1417 }
1418
1419 cache.last_updated = Instant::now();
1420 Ok(())
1421 }
1422
1423 async fn invalidate_completion_cache(&self) {
1426 let mut cache = self.completion_cache.write().unwrap();
1427 cache.prompts.clear();
1428 cache.prompt_info.clear();
1429 cache.last_updated = Instant::now();
1430 }
1431
1432 pub fn message_history(&self) -> Conversation {
1433 self.messages.clone()
1434 }
1435
1436 pub fn render_message_history(&self) {
1438 if self.messages.is_empty() {
1439 return;
1440 }
1441
1442 println!(
1444 "\n{} {} messages loaded into context.",
1445 console::style("Session restored:").green().bold(),
1446 console::style(self.messages.len()).green()
1447 );
1448
1449 for message in self.messages.iter() {
1451 output::render_message(message, self.debug);
1452 }
1453
1454 println!(
1456 "\n{}\n",
1457 console::style("──────── New Messages ────────").dim()
1458 );
1459 }
1460
1461 pub async fn get_session(&self) -> Result<aster::session::Session> {
1462 SessionManager::get_session(&self.session_id, false).await
1463 }
1464
1465 pub async fn get_total_token_usage(&self) -> Result<Option<i32>> {
1467 let metadata = self.get_session().await?;
1468 Ok(metadata.total_tokens)
1469 }
1470
1471 pub async fn display_context_usage(&self) -> Result<()> {
1473 let provider = self.agent.provider().await?;
1474 let model_config = provider.get_model_config();
1475 let context_limit = model_config.context_limit();
1476
1477 let config = Config::global();
1478 let show_cost = config
1479 .get_param::<bool>("ASTER_CLI_SHOW_COST")
1480 .unwrap_or(false);
1481
1482 let provider_name = config
1483 .get_aster_provider()
1484 .unwrap_or_else(|_| "unknown".to_string());
1485
1486 match self.get_session().await {
1487 Ok(metadata) => {
1488 let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
1489
1490 output::display_context_usage(total_tokens, context_limit);
1491
1492 if show_cost {
1493 let input_tokens = metadata.input_tokens.unwrap_or(0) as usize;
1494 let output_tokens = metadata.output_tokens.unwrap_or(0) as usize;
1495 output::display_cost_usage(
1496 &provider_name,
1497 &model_config.model_name,
1498 input_tokens,
1499 output_tokens,
1500 );
1501 }
1502 }
1503 Err(_) => {
1504 output::display_context_usage(0, context_limit);
1505 }
1506 }
1507
1508 Ok(())
1509 }
1510
1511 async fn handle_prompt_command(&mut self, opts: input::PromptCommandOptions) -> Result<()> {
1513 if opts.name.is_empty() {
1515 output::render_error("Prompt name argument is required");
1516 return Ok(());
1517 }
1518
1519 if opts.info {
1520 match self.get_prompt_info(&opts.name).await? {
1521 Some(info) => output::render_prompt_info(&info),
1522 None => output::render_error(&format!("Prompt '{}' not found", opts.name)),
1523 }
1524 } else {
1525 let arguments = serde_json::to_value(opts.arguments)
1527 .map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?;
1528
1529 match self.get_prompt(&opts.name, arguments).await {
1530 Ok(messages) => {
1531 let start_len = self.messages.len();
1532 let mut valid = true;
1533 for (i, prompt_message) in messages.into_iter().enumerate() {
1534 let msg = Message::from(prompt_message);
1535 let expected_role = if i % 2 == 0 {
1537 rmcp::model::Role::User
1538 } else {
1539 rmcp::model::Role::Assistant
1540 };
1541
1542 if msg.role != expected_role {
1543 output::render_error(&format!(
1544 "Expected {:?} message at position {}, but found {:?}",
1545 expected_role, i, msg.role
1546 ));
1547 valid = false;
1548 self.messages.truncate(start_len);
1550 break;
1551 }
1552
1553 if msg.role == rmcp::model::Role::User {
1554 output::render_message(&msg, self.debug);
1555 }
1556 self.push_message(msg);
1557 }
1558
1559 if valid {
1560 output::show_thinking();
1561 self.process_agent_response(true, CancellationToken::default())
1562 .await?;
1563 output::hide_thinking();
1564 }
1565 }
1566 Err(e) => output::render_error(&e.to_string()),
1567 }
1568 }
1569
1570 Ok(())
1571 }
1572
1573 fn save_recipe(
1582 &self,
1583 recipe: &aster::recipe::Recipe,
1584 filepath_str: &str,
1585 ) -> anyhow::Result<PathBuf> {
1586 let path_buf = PathBuf::from(filepath_str);
1587 let mut path = path_buf.clone();
1588
1589 if path_buf.is_relative() {
1591 let cwd = std::env::current_dir().context("Failed to get current directory")?;
1593 path = cwd.join(&path_buf);
1594 }
1595
1596 if let Some(parent) = path.parent() {
1598 if !parent.exists() {
1599 return Err(anyhow::anyhow!(
1600 "Directory '{}' does not exist",
1601 parent.display()
1602 ));
1603 }
1604 }
1605
1606 let file = std::fs::File::create(path.as_path())
1608 .context(format!("Failed to create file '{}'", path.display()))?;
1609
1610 serde_yaml::to_writer(file, recipe).context("Failed to save recipe")?;
1612
1613 Ok(path)
1614 }
1615
1616 fn push_message(&mut self, message: Message) {
1617 self.messages.push(message);
1618 }
1619}
1620
1621async fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
1622 use aster::model::ModelConfig;
1623 use aster::providers::create;
1624
1625 let config = Config::global();
1626
1627 let provider = if let Ok(provider) = config.get_param::<String>("ASTER_PLANNER_PROVIDER") {
1629 provider
1630 } else {
1631 println!("WARNING: ASTER_PLANNER_PROVIDER not found. Using default provider...");
1632 config
1633 .get_aster_provider()
1634 .expect("No provider configured. Run 'aster configure' first")
1635 };
1636
1637 let model = if let Ok(model) = config.get_param::<String>("ASTER_PLANNER_MODEL") {
1639 model
1640 } else {
1641 println!("WARNING: ASTER_PLANNER_MODEL not found. Using default model...");
1642 config
1643 .get_aster_model()
1644 .expect("No model configured. Run 'aster configure' first")
1645 };
1646
1647 let model_config =
1648 ModelConfig::new_with_context_env(model, Some("ASTER_PLANNER_CONTEXT_LIMIT"))?;
1649 let reasoner = create(&provider, model_config).await?;
1650
1651 Ok(reasoner)
1652}
1653
/// Formats an elapsed duration for display.
///
/// Durations under one minute (by whole seconds) are shown as seconds with
/// two decimals, e.g. `"45.75s"`. Longer durations are shown as whole minutes
/// plus zero-padded whole seconds, e.g. `"10m 05s"`; sub-second precision is
/// dropped in that form.
fn format_elapsed_time(duration: std::time::Duration) -> String {
    match duration.as_secs() {
        secs if secs < 60 => format!("{:.2}s", duration.as_secs_f64()),
        secs => format!("{}m {:02}s", secs / 60, secs % 60),
    }
}
1666
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Asserts that every `(duration, expected)` pair formats as expected.
    fn check_cases(cases: &[(Duration, &str)]) {
        for (duration, expected) in cases {
            assert_eq!(
                format_elapsed_time(*duration),
                *expected,
                "unexpected formatting for {:?}",
                duration
            );
        }
    }

    #[test]
    fn test_format_elapsed_time_under_60_seconds() {
        check_cases(&[
            (Duration::from_millis(500), "0.50s"),
            (Duration::from_secs(1), "1.00s"),
            (Duration::from_millis(45750), "45.75s"),
            (Duration::from_millis(59990), "59.99s"),
        ]);
    }

    #[test]
    fn test_format_elapsed_time_minutes() {
        check_cases(&[
            (Duration::from_secs(60), "1m 00s"),
            (Duration::from_secs(61), "1m 01s"),
            (Duration::from_secs(90), "1m 30s"),
            (Duration::from_secs(119), "1m 59s"),
            (Duration::from_secs(120), "2m 00s"),
            (Duration::from_secs(605), "10m 05s"),
            (Duration::from_secs(3661), "61m 01s"),
        ]);
    }

    #[test]
    fn test_format_elapsed_time_edge_cases() {
        check_cases(&[
            (Duration::from_secs(0), "0.00s"),
            (Duration::from_millis(1), "0.00s"),
            // Sub-second precision is dropped once the minute form is used.
            (Duration::from_millis(60500), "1m 00s"),
        ]);
    }
}