1use super::output;
2use super::CliSession;
3use aster::agents::types::{RetryConfig, SessionConfig};
4use aster::agents::Agent;
5use aster::config::{
6 extensions::get_extension_by_name, get_all_extensions, get_enabled_extensions, Config,
7 ExtensionConfig,
8};
9use aster::providers::create;
10use aster::recipe::{Response, SubRecipe};
11use console::style;
12
13use aster::agents::extension::PlatformExtensionContext;
14use aster::session::session_manager::SessionType;
15use aster::session::SessionManager;
16use aster::session::{EnabledExtensionsState, ExtensionState};
17use rustyline::EditMode;
18use std::collections::HashSet;
19use std::process;
20use std::sync::Arc;
21use tokio::task::JoinSet;
22
23#[derive(Clone, Debug)]
28pub struct SessionBuilderConfig {
29 pub session_id: Option<String>,
31 pub resume: bool,
33 pub no_session: bool,
35 pub extensions: Vec<String>,
37 pub streamable_http_extensions: Vec<String>,
39 pub builtins: Vec<String>,
41 pub extensions_override: Option<Vec<ExtensionConfig>>,
43 pub additional_system_prompt: Option<String>,
45 pub settings: Option<SessionSettings>,
47 pub provider: Option<String>,
49 pub model: Option<String>,
51 pub debug: bool,
53 pub max_tool_repetitions: Option<u32>,
55 pub max_turns: Option<u32>,
57 pub scheduled_job_id: Option<String>,
59 pub interactive: bool,
61 pub quiet: bool,
63 pub sub_recipes: Option<Vec<SubRecipe>>,
65 pub final_output_response: Option<Response>,
67 pub retry_config: Option<RetryConfig>,
69 pub output_format: String,
71}
72
73impl Default for SessionBuilderConfig {
76 fn default() -> Self {
77 SessionBuilderConfig {
78 session_id: None,
79 resume: false,
80 no_session: false,
81 extensions: Vec::new(),
82 streamable_http_extensions: Vec::new(),
83 builtins: Vec::new(),
84 extensions_override: None,
85 additional_system_prompt: None,
86 settings: None,
87 provider: None,
88 model: None,
89 debug: false,
90 max_tool_repetitions: None,
91 max_turns: None,
92 scheduled_job_id: None,
93 interactive: false,
94 quiet: false,
95 sub_recipes: None,
96 final_output_response: None,
97 retry_config: None,
98 output_format: "text".to_string(),
99 }
100 }
101}
102
103async fn offer_extension_debugging_help(
105 extension_name: &str,
106 error_message: &str,
107 provider: Arc<dyn aster::providers::base::Provider>,
108 interactive: bool,
109) -> Result<(), anyhow::Error> {
110 if !interactive {
112 return Ok(());
113 }
114
115 let help_prompt = format!(
116 "Would you like me to help debug the '{}' extension failure?",
117 extension_name
118 );
119
120 let should_help = match cliclack::confirm(help_prompt)
121 .initial_value(false)
122 .interact()
123 {
124 Ok(choice) => choice,
125 Err(e) => {
126 if e.kind() == std::io::ErrorKind::Interrupted {
127 return Ok(());
128 } else {
129 return Err(e.into());
130 }
131 }
132 };
133
134 if !should_help {
135 return Ok(());
136 }
137
138 println!("{}", style("🔧 Starting debugging session...").cyan());
139
140 let debug_prompt = format!(
142 "I'm having trouble starting an extension called '{}'. Here's the error I encountered:\n\n{}\n\nCan you help me diagnose what might be wrong and suggest how to fix it? Please consider common issues like:\n- Missing dependencies or tools\n- Configuration problems\n- Network connectivity (for remote extensions)\n- Permission issues\n- Path or environment variable problems",
143 extension_name,
144 error_message
145 );
146
147 let debug_agent = Agent::new();
149
150 let session = SessionManager::create_session(
151 std::env::current_dir()?,
152 "CLI Session".to_string(),
153 SessionType::Hidden,
154 )
155 .await?;
156
157 debug_agent.update_provider(provider, &session.id).await?;
158
159 let extensions = get_all_extensions();
161 for ext_wrapper in extensions {
162 if ext_wrapper.enabled && ext_wrapper.config.name() == "developer" {
163 if let Err(e) = debug_agent.add_extension(ext_wrapper.config).await {
164 eprintln!(
166 "Note: Could not load developer extension for debugging: {}",
167 e
168 );
169 }
170 break;
171 }
172 }
173
174 let mut debug_session = CliSession::new(
175 debug_agent,
176 session.id,
177 false,
178 None,
179 None,
180 None,
181 None,
182 "text".to_string(),
183 )
184 .await;
185
186 println!("{}", style("Analyzing the extension failure...").yellow());
188 match debug_session.headless(debug_prompt).await {
189 Ok(_) => {
190 println!(
191 "{}",
192 style("✅ Debugging session completed. Check the suggestions above.").green()
193 );
194 }
195 Err(e) => {
196 eprintln!(
197 "{}",
198 style(format!("❌ Debugging session failed: {}", e)).red()
199 );
200 }
201 }
202 Ok(())
203}
204
205fn check_missing_extensions_or_exit(saved_extensions: &[ExtensionConfig]) {
206 let missing: Vec<_> = saved_extensions
207 .iter()
208 .filter(|ext| get_extension_by_name(&ext.name()).is_none())
209 .cloned()
210 .collect();
211
212 if !missing.is_empty() {
213 let names = missing
214 .iter()
215 .map(|e| e.name())
216 .collect::<Vec<_>>()
217 .join(", ");
218
219 if !cliclack::confirm(format!(
220 "Extension(s) {} from previous session are no longer available. Restore for this session?",
221 names
222 ))
223 .initial_value(true)
224 .interact()
225 .unwrap_or(false)
226 {
227 println!("{}", style("Resume cancelled.").yellow());
228 process::exit(0);
229 }
230 }
231}
232
/// Optional fallback provider, model and temperature for a session.
///
/// `build_session` consults the provider and model here only after explicit options
/// and any values saved with a resumed session. The temperature is applied to the
/// resulting model configuration.
#[derive(Debug, Clone, Default)]
pub struct SessionSettings {
    /// Model name used when none was given explicitly or saved with the session.
    pub aster_model: Option<String>,
    /// Provider name used when none was given explicitly or saved with the session.
    pub aster_provider: Option<String>,
    /// Temperature applied to the model configuration via `with_temperature`.
    pub temperature: Option<f32>,
}
239
240pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
241 aster::posthog::set_session_context("cli", session_config.resume);
242
243 let config = Config::global();
244
245 let (saved_provider, saved_model_config) = if session_config.resume {
246 if let Some(ref session_id) = session_config.session_id {
247 match SessionManager::get_session(session_id, false).await {
248 Ok(session_data) => (session_data.provider_name, session_data.model_config),
249 Err(_) => (None, None),
250 }
251 } else {
252 (None, None)
253 }
254 } else {
255 (None, None)
256 };
257
258 let provider_name = session_config
259 .provider
260 .or(saved_provider)
261 .or_else(|| {
262 session_config
263 .settings
264 .as_ref()
265 .and_then(|s| s.aster_provider.clone())
266 })
267 .or_else(|| config.get_aster_provider().ok())
268 .expect("No provider configured. Run 'aster configure' first");
269
270 let model_name = session_config
271 .model
272 .or_else(|| saved_model_config.as_ref().map(|mc| mc.model_name.clone()))
273 .or_else(|| {
274 session_config
275 .settings
276 .as_ref()
277 .and_then(|s| s.aster_model.clone())
278 })
279 .or_else(|| config.get_aster_model().ok())
280 .expect("No model configured. Run 'aster configure' first");
281
282 let model_config = if session_config.resume
283 && saved_model_config
284 .as_ref()
285 .is_some_and(|mc| mc.model_name == model_name)
286 {
287 let mut config = saved_model_config.unwrap();
288 if let Some(temp) = session_config.settings.as_ref().and_then(|s| s.temperature) {
289 config = config.with_temperature(Some(temp));
290 }
291 config
292 } else {
293 let temperature = session_config.settings.as_ref().and_then(|s| s.temperature);
294 aster::model::ModelConfig::new(&model_name)
295 .unwrap_or_else(|e| {
296 output::render_error(&format!("Failed to create model configuration: {}", e));
297 process::exit(1);
298 })
299 .with_temperature(temperature)
300 };
301
302 let agent: Agent = Agent::new();
303
304 agent
305 .apply_recipe_components(
306 session_config.sub_recipes,
307 session_config.final_output_response,
308 true,
309 )
310 .await;
311
312 let new_provider = match create(&provider_name, model_config).await {
313 Ok(provider) => provider,
314 Err(e) => {
315 output::render_error(&format!(
316 "Error {}.\n\
317 Please check your system keychain and run 'aster configure' again.\n\
318 If your system is unable to use the keyring, please try setting secret key(s) via environment variables.\n\
319 For more info, see: https://astercloud.github.io/aster-rust/docs/troubleshooting/#keychainkeyring-errors",
320 e
321 ));
322 process::exit(1);
323 }
324 };
325 let provider_for_display = Arc::clone(&new_provider);
326
327 if let Some(lead_worker) = new_provider.as_lead_worker() {
328 let (lead_model, worker_model) = lead_worker.get_model_info();
329 tracing::info!(
330 "🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
331 lead_model,
332 worker_model
333 );
334 } else {
335 tracing::info!("🤖 Using model: {}", model_name);
336 }
337
338 let session_id: String = if session_config.no_session {
339 let working_dir = std::env::current_dir().expect("Could not get working directory");
340 let session = SessionManager::create_session(
341 working_dir,
342 "CLI Session".to_string(),
343 SessionType::Hidden,
344 )
345 .await
346 .expect("Could not create session");
347 session.id
348 } else if session_config.resume {
349 if let Some(session_id) = session_config.session_id {
350 match SessionManager::get_session(&session_id, false).await {
351 Ok(_) => session_id,
352 Err(_) => {
353 output::render_error(&format!(
354 "Cannot resume session {} - no such session exists",
355 style(&session_id).cyan()
356 ));
357 process::exit(1);
358 }
359 }
360 } else {
361 match SessionManager::list_sessions().await {
362 Ok(sessions) if !sessions.is_empty() => sessions[0].id.clone(),
363 _ => {
364 output::render_error("Cannot resume - no previous sessions found");
365 process::exit(1);
366 }
367 }
368 }
369 } else {
370 session_config.session_id.unwrap()
371 };
372
373 agent
374 .update_provider(new_provider, &session_id)
375 .await
376 .unwrap_or_else(|e| {
377 output::render_error(&format!("Failed to initialize agent: {}", e));
378 process::exit(1);
379 });
380
381 agent
382 .extension_manager
383 .set_context(PlatformExtensionContext {
384 session_id: Some(session_id.clone()),
385 extension_manager: Some(Arc::downgrade(&agent.extension_manager)),
386 })
387 .await;
388
389 if session_config.resume {
390 let session = SessionManager::get_session(&session_id, false)
391 .await
392 .unwrap_or_else(|e| {
393 output::render_error(&format!("Failed to read session metadata: {}", e));
394 process::exit(1);
395 });
396
397 let current_workdir =
398 std::env::current_dir().expect("Failed to get current working directory");
399 if current_workdir != session.working_dir {
400 let change_workdir = cliclack::confirm(format!("{} The original working directory of this session was set to {}. Your current directory is {}. Do you want to switch back to the original working directory?", style("WARNING:").yellow(), style(session.working_dir.display()).cyan(), style(current_workdir.display()).cyan()))
401 .initial_value(true)
402 .interact().expect("Failed to get user input");
403
404 if change_workdir {
405 if !session.working_dir.exists() {
406 output::render_error(&format!(
407 "Cannot switch to original working directory - {} no longer exists",
408 style(session.working_dir.display()).cyan()
409 ));
410 } else if let Err(e) = std::env::set_current_dir(&session.working_dir) {
411 output::render_error(&format!(
412 "Failed to switch to original working directory: {}",
413 e
414 ));
415 }
416 }
417 }
418 }
419
420 for warning in aster::config::get_warnings() {
424 eprintln!("{}", style(format!("Warning: {}", warning)).yellow());
425 }
426
427 let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override {
429 extensions.into_iter().collect()
430 } else if session_config.resume {
431 match SessionManager::get_session(&session_id, false).await {
432 Ok(session_data) => {
433 if let Some(saved_state) =
434 EnabledExtensionsState::from_extension_data(&session_data.extension_data)
435 {
436 check_missing_extensions_or_exit(&saved_state.extensions);
437 saved_state.extensions
438 } else {
439 get_enabled_extensions()
440 }
441 }
442 _ => get_enabled_extensions(),
443 }
444 } else {
445 get_enabled_extensions()
446 };
447
448 let mut set = JoinSet::new();
449 let agent_ptr = Arc::new(agent);
450
451 let mut waiting_on = HashSet::new();
452 for extension in extensions_to_run {
453 waiting_on.insert(extension.name());
454 let agent_ptr = agent_ptr.clone();
455 set.spawn(async move {
456 (
457 extension.name(),
458 agent_ptr.add_extension(extension.clone()).await,
459 )
460 });
461 }
462
463 let get_message = |waiting_on: &HashSet<String>| {
464 let mut names: Vec<_> = waiting_on.iter().cloned().collect();
465 names.sort();
466 format!("starting {} extensions: {}", names.len(), names.join(", "))
467 };
468
469 let spinner = cliclack::spinner();
470 spinner.start(get_message(&waiting_on));
471
472 let mut offer_debug = Vec::new();
473 while let Some(result) = set.join_next().await {
474 match result {
475 Ok((name, Ok(_))) => {
476 waiting_on.remove(&name);
477 spinner.set_message(get_message(&waiting_on));
478 }
479 Ok((name, Err(e))) => offer_debug.push((name, e)),
480 Err(e) => tracing::error!("failed to add extension: {}", e),
481 }
482 }
483
484 spinner.clear();
485
486 for (name, err) in offer_debug {
487 if let Err(debug_err) = offer_extension_debugging_help(
488 &name,
489 &err.to_string(),
490 Arc::clone(&provider_for_display),
491 session_config.interactive,
492 )
493 .await
494 {
495 eprintln!("Note: Could not start debugging session: {}", debug_err);
496 }
497 }
498
499 let edit_mode = config
501 .get_param::<String>("EDIT_MODE")
502 .ok()
503 .and_then(|edit_mode| match edit_mode.to_lowercase().as_str() {
504 "emacs" => Some(EditMode::Emacs),
505 "vi" => Some(EditMode::Vi),
506 _ => {
507 eprintln!("Invalid EDIT_MODE specified, defaulting to Emacs");
508 None
509 }
510 });
511
512 let debug_mode = session_config.debug || config.get_param("ASTER_DEBUG").unwrap_or(false);
513
514 let mut session = CliSession::new(
516 Arc::try_unwrap(agent_ptr).unwrap_or_else(|_| panic!("There should be no more references")),
517 session_id.clone(),
518 debug_mode,
519 session_config.scheduled_job_id.clone(),
520 session_config.max_turns,
521 edit_mode,
522 session_config.retry_config.clone(),
523 session_config.output_format.clone(),
524 )
525 .await;
526
527 for extension_str in session_config.extensions {
529 if let Err(e) = session.add_extension(extension_str.clone()).await {
530 eprintln!(
531 "{}",
532 style(format!(
533 "Warning: Failed to start stdio extension '{}' ({}), continuing without it",
534 extension_str, e
535 ))
536 .yellow()
537 );
538
539 if let Err(debug_err) = offer_extension_debugging_help(
541 &extension_str,
542 &e.to_string(),
543 Arc::clone(&provider_for_display),
544 session_config.interactive,
545 )
546 .await
547 {
548 eprintln!("Note: Could not start debugging session: {}", debug_err);
549 }
550 }
551 }
552
553 for extension_str in session_config.streamable_http_extensions {
555 if let Err(e) = session
556 .add_streamable_http_extension(extension_str.clone())
557 .await
558 {
559 eprintln!(
560 "{}",
561 style(format!(
562 "Warning: Failed to start streamable HTTP extension '{}' ({}), continuing without it",
563 extension_str, e
564 ))
565 .yellow()
566 );
567
568 if let Err(debug_err) = offer_extension_debugging_help(
570 &extension_str,
571 &e.to_string(),
572 Arc::clone(&provider_for_display),
573 session_config.interactive,
574 )
575 .await
576 {
577 eprintln!("Note: Could not start debugging session: {}", debug_err);
578 }
579 }
580 }
581
582 for builtin in session_config.builtins {
584 if let Err(e) = session.add_builtin(builtin.clone()).await {
585 eprintln!(
586 "{}",
587 style(format!(
588 "Warning: Failed to start builtin extension '{}' ({}), continuing without it",
589 builtin, e
590 ))
591 .yellow()
592 );
593
594 if let Err(debug_err) = offer_extension_debugging_help(
596 &builtin,
597 &e.to_string(),
598 Arc::clone(&provider_for_display),
599 session_config.interactive,
600 )
601 .await
602 {
603 eprintln!("Note: Could not start debugging session: {}", debug_err);
604 }
605 }
606 }
607
608 let session_config_for_save = SessionConfig {
609 id: session_id.clone(),
610 schedule_id: None,
611 max_turns: None,
612 retry_config: None,
613 system_prompt: None,
614 };
615
616 if let Err(e) = session
617 .agent
618 .save_extension_state(&session_config_for_save)
619 .await
620 {
621 tracing::warn!("Failed to save initial extension state: {}", e);
622 }
623
624 session
626 .agent
627 .extend_system_prompt(super::prompt::get_cli_prompt())
628 .await;
629
630 if let Some(additional_prompt) = session_config.additional_system_prompt {
631 session.agent.extend_system_prompt(additional_prompt).await;
632 }
633
634 let system_prompt_file: Option<String> = config.get_param("ASTER_SYSTEM_PROMPT_FILE_PATH").ok();
636 if let Some(ref path) = system_prompt_file {
637 let override_prompt =
638 std::fs::read_to_string(path).expect("Failed to read system prompt file");
639 session.agent.override_system_prompt(override_prompt).await;
640 }
641
642 if !session_config.quiet {
644 output::display_session_info(
645 session_config.resume,
646 &provider_name,
647 &model_name,
648 &Some(session_id),
649 Some(&provider_for_display),
650 );
651 }
652 session
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_builder_config_creation() {
        let config = SessionBuilderConfig {
            session_id: None,
            resume: false,
            no_session: false,
            extensions: vec!["echo test".to_string()],
            streamable_http_extensions: vec!["http://localhost:8080/mcp".to_string()],
            builtins: vec!["developer".to_string()],
            extensions_override: None,
            additional_system_prompt: Some("Test prompt".to_string()),
            settings: None,
            provider: None,
            model: None,
            debug: true,
            max_tool_repetitions: Some(5),
            max_turns: None,
            scheduled_job_id: None,
            interactive: true,
            quiet: false,
            sub_recipes: None,
            final_output_response: None,
            retry_config: None,
            output_format: "text".to_string(),
        };

        assert_eq!(config.extensions.len(), 1);
        assert_eq!(config.streamable_http_extensions.len(), 1);
        assert_eq!(config.builtins.len(), 1);
        assert!(config.debug);
        assert_eq!(config.max_tool_repetitions, Some(5));
        assert!(config.max_turns.is_none());
        assert!(config.scheduled_job_id.is_none());
        assert!(config.interactive);
        assert!(!config.quiet);
    }

    #[test]
    fn test_session_builder_config_default() {
        let config = SessionBuilderConfig::default();

        assert!(config.session_id.is_none());
        assert!(!config.resume);
        assert!(!config.no_session);
        assert!(config.extensions.is_empty());
        assert!(config.streamable_http_extensions.is_empty());
        assert!(config.builtins.is_empty());
        assert!(config.extensions_override.is_none());
        assert!(config.additional_system_prompt.is_none());
        assert!(config.settings.is_none());
        assert!(config.provider.is_none());
        assert!(config.model.is_none());
        assert!(!config.debug);
        assert!(config.max_tool_repetitions.is_none());
        assert!(config.max_turns.is_none());
        assert!(config.scheduled_job_id.is_none());
        assert!(!config.interactive);
        assert!(!config.quiet);
        assert!(config.sub_recipes.is_none());
        assert!(config.final_output_response.is_none());
        assert!(config.retry_config.is_none());
        // The only non-trivial default value.
        assert_eq!(config.output_format, "text");
    }

    #[test]
    fn test_session_settings_default_is_empty() {
        let settings = SessionSettings::default();

        assert!(settings.aster_model.is_none());
        assert!(settings.aster_provider.is_none());
        assert!(settings.temperature.is_none());
    }

    #[test]
    fn test_offer_extension_debugging_help_function_exists() {
        // Exercising the helper needs a live provider and a terminal, so this only
        // guards against the function being removed or renamed.
        let _helper = offer_extension_debugging_help;
    }
}