Skip to main content

aster_cli/session/
builder.rs

1use super::output;
2use super::CliSession;
3use aster::agents::types::{RetryConfig, SessionConfig};
4use aster::agents::Agent;
5use aster::config::{
6    extensions::get_extension_by_name, get_all_extensions, get_enabled_extensions, Config,
7    ExtensionConfig,
8};
9use aster::providers::create;
10use aster::recipe::{Response, SubRecipe};
11use console::style;
12
13use aster::agents::extension::PlatformExtensionContext;
14use aster::session::session_manager::SessionType;
15use aster::session::SessionManager;
16use aster::session::{EnabledExtensionsState, ExtensionState};
17use rustyline::EditMode;
18use std::collections::HashSet;
19use std::process;
20use std::sync::Arc;
21use tokio::task::JoinSet;
22
23/// Configuration for building a new Aster session
24///
25/// This struct contains all the parameters needed to create a new session,
26/// including session identification, extension configuration, and debug settings.
27#[derive(Clone, Debug)]
28pub struct SessionBuilderConfig {
29    /// Session id, optional need to deduce from context
30    pub session_id: Option<String>,
31    /// Whether to resume an existing session
32    pub resume: bool,
33    /// Whether to run without a session file
34    pub no_session: bool,
35    /// List of stdio extension commands to add
36    pub extensions: Vec<String>,
37    /// List of streamable HTTP extension commands to add
38    pub streamable_http_extensions: Vec<String>,
39    /// List of builtin extension commands to add
40    pub builtins: Vec<String>,
41    /// List of extensions to enable, enable only this set and ignore configured ones
42    pub extensions_override: Option<Vec<ExtensionConfig>>,
43    /// Any additional system prompt to append to the default
44    pub additional_system_prompt: Option<String>,
45    /// Settings to override the global Aster settings
46    pub settings: Option<SessionSettings>,
47    /// Provider override from CLI arguments
48    pub provider: Option<String>,
49    /// Model override from CLI arguments
50    pub model: Option<String>,
51    /// Enable debug printing
52    pub debug: bool,
53    /// Maximum number of consecutive identical tool calls allowed
54    pub max_tool_repetitions: Option<u32>,
55    /// Maximum number of turns (iterations) allowed without user input
56    pub max_turns: Option<u32>,
57    /// ID of the scheduled job that triggered this session (if any)
58    pub scheduled_job_id: Option<String>,
59    /// Whether this session will be used interactively (affects debugging prompts)
60    pub interactive: bool,
61    /// Quiet mode - suppress non-response output
62    pub quiet: bool,
63    /// Sub-recipes to add to the session
64    pub sub_recipes: Option<Vec<SubRecipe>>,
65    /// Final output expected response
66    pub final_output_response: Option<Response>,
67    /// Retry configuration for automated validation and recovery
68    pub retry_config: Option<RetryConfig>,
69    /// Output format (text, json)
70    pub output_format: String,
71}
72
73/// Manual implementation of Default to ensure proper initialization of output_format
74/// This struct requires explicit default value for output_format field
75impl Default for SessionBuilderConfig {
76    fn default() -> Self {
77        SessionBuilderConfig {
78            session_id: None,
79            resume: false,
80            no_session: false,
81            extensions: Vec::new(),
82            streamable_http_extensions: Vec::new(),
83            builtins: Vec::new(),
84            extensions_override: None,
85            additional_system_prompt: None,
86            settings: None,
87            provider: None,
88            model: None,
89            debug: false,
90            max_tool_repetitions: None,
91            max_turns: None,
92            scheduled_job_id: None,
93            interactive: false,
94            quiet: false,
95            sub_recipes: None,
96            final_output_response: None,
97            retry_config: None,
98            output_format: "text".to_string(),
99        }
100    }
101}
102
103/// Offers to help debug an extension failure by creating a minimal debugging session
104async fn offer_extension_debugging_help(
105    extension_name: &str,
106    error_message: &str,
107    provider: Arc<dyn aster::providers::base::Provider>,
108    interactive: bool,
109) -> Result<(), anyhow::Error> {
110    // Only offer debugging help in interactive mode
111    if !interactive {
112        return Ok(());
113    }
114
115    let help_prompt = format!(
116        "Would you like me to help debug the '{}' extension failure?",
117        extension_name
118    );
119
120    let should_help = match cliclack::confirm(help_prompt)
121        .initial_value(false)
122        .interact()
123    {
124        Ok(choice) => choice,
125        Err(e) => {
126            if e.kind() == std::io::ErrorKind::Interrupted {
127                return Ok(());
128            } else {
129                return Err(e.into());
130            }
131        }
132    };
133
134    if !should_help {
135        return Ok(());
136    }
137
138    println!("{}", style("🔧 Starting debugging session...").cyan());
139
140    // Create a debugging prompt with context about the extension failure
141    let debug_prompt = format!(
142        "I'm having trouble starting an extension called '{}'. Here's the error I encountered:\n\n{}\n\nCan you help me diagnose what might be wrong and suggest how to fix it? Please consider common issues like:\n- Missing dependencies or tools\n- Configuration problems\n- Network connectivity (for remote extensions)\n- Permission issues\n- Path or environment variable problems",
143        extension_name,
144        error_message
145    );
146
147    // Create a minimal agent for debugging
148    let debug_agent = Agent::new();
149
150    let session = SessionManager::create_session(
151        std::env::current_dir()?,
152        "CLI Session".to_string(),
153        SessionType::Hidden,
154    )
155    .await?;
156
157    debug_agent.update_provider(provider, &session.id).await?;
158
159    // Add the developer extension if available to help with debugging
160    let extensions = get_all_extensions();
161    for ext_wrapper in extensions {
162        if ext_wrapper.enabled && ext_wrapper.config.name() == "developer" {
163            if let Err(e) = debug_agent.add_extension(ext_wrapper.config).await {
164                // If we can't add developer extension, continue without it
165                eprintln!(
166                    "Note: Could not load developer extension for debugging: {}",
167                    e
168                );
169            }
170            break;
171        }
172    }
173
174    let mut debug_session = CliSession::new(
175        debug_agent,
176        session.id,
177        false,
178        None,
179        None,
180        None,
181        None,
182        "text".to_string(),
183    )
184    .await;
185
186    // Process the debugging request
187    println!("{}", style("Analyzing the extension failure...").yellow());
188    match debug_session.headless(debug_prompt).await {
189        Ok(_) => {
190            println!(
191                "{}",
192                style("✅ Debugging session completed. Check the suggestions above.").green()
193            );
194        }
195        Err(e) => {
196            eprintln!(
197                "{}",
198                style(format!("❌ Debugging session failed: {}", e)).red()
199            );
200        }
201    }
202    Ok(())
203}
204
205fn check_missing_extensions_or_exit(saved_extensions: &[ExtensionConfig]) {
206    let missing: Vec<_> = saved_extensions
207        .iter()
208        .filter(|ext| get_extension_by_name(&ext.name()).is_none())
209        .cloned()
210        .collect();
211
212    if !missing.is_empty() {
213        let names = missing
214            .iter()
215            .map(|e| e.name())
216            .collect::<Vec<_>>()
217            .join(", ");
218
219        if !cliclack::confirm(format!(
220            "Extension(s) {} from previous session are no longer available. Restore for this session?",
221            names
222        ))
223        .initial_value(true)
224        .interact()
225        .unwrap_or(false)
226        {
227            println!("{}", style("Resume cancelled.").yellow());
228            process::exit(0);
229        }
230    }
231}
232
/// Optional per-session overrides for model-related settings.
///
/// Every field is optional; `None` means "no override".
#[derive(Debug, Clone, Default)]
pub struct SessionSettings {
    /// Model name to use for this session
    pub aster_model: Option<String>,
    /// Provider name to use for this session
    pub aster_provider: Option<String>,
    /// Sampling temperature to apply to the model configuration
    pub temperature: Option<f32>,
}
239
240pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
241    aster::posthog::set_session_context("cli", session_config.resume);
242
243    let config = Config::global();
244
245    let (saved_provider, saved_model_config) = if session_config.resume {
246        if let Some(ref session_id) = session_config.session_id {
247            match SessionManager::get_session(session_id, false).await {
248                Ok(session_data) => (session_data.provider_name, session_data.model_config),
249                Err(_) => (None, None),
250            }
251        } else {
252            (None, None)
253        }
254    } else {
255        (None, None)
256    };
257
258    let provider_name = session_config
259        .provider
260        .or(saved_provider)
261        .or_else(|| {
262            session_config
263                .settings
264                .as_ref()
265                .and_then(|s| s.aster_provider.clone())
266        })
267        .or_else(|| config.get_aster_provider().ok())
268        .expect("No provider configured. Run 'aster configure' first");
269
270    let model_name = session_config
271        .model
272        .or_else(|| saved_model_config.as_ref().map(|mc| mc.model_name.clone()))
273        .or_else(|| {
274            session_config
275                .settings
276                .as_ref()
277                .and_then(|s| s.aster_model.clone())
278        })
279        .or_else(|| config.get_aster_model().ok())
280        .expect("No model configured. Run 'aster configure' first");
281
282    let model_config = if session_config.resume
283        && saved_model_config
284            .as_ref()
285            .is_some_and(|mc| mc.model_name == model_name)
286    {
287        let mut config = saved_model_config.unwrap();
288        if let Some(temp) = session_config.settings.as_ref().and_then(|s| s.temperature) {
289            config = config.with_temperature(Some(temp));
290        }
291        config
292    } else {
293        let temperature = session_config.settings.as_ref().and_then(|s| s.temperature);
294        aster::model::ModelConfig::new(&model_name)
295            .unwrap_or_else(|e| {
296                output::render_error(&format!("Failed to create model configuration: {}", e));
297                process::exit(1);
298            })
299            .with_temperature(temperature)
300    };
301
302    let agent: Agent = Agent::new();
303
304    agent
305        .apply_recipe_components(
306            session_config.sub_recipes,
307            session_config.final_output_response,
308            true,
309        )
310        .await;
311
312    let new_provider = match create(&provider_name, model_config).await {
313        Ok(provider) => provider,
314        Err(e) => {
315            output::render_error(&format!(
316                "Error {}.\n\
317                Please check your system keychain and run 'aster configure' again.\n\
318                If your system is unable to use the keyring, please try setting secret key(s) via environment variables.\n\
319                For more info, see: https://astercloud.github.io/aster-rust/docs/troubleshooting/#keychainkeyring-errors",
320                e
321            ));
322            process::exit(1);
323        }
324    };
325    let provider_for_display = Arc::clone(&new_provider);
326
327    if let Some(lead_worker) = new_provider.as_lead_worker() {
328        let (lead_model, worker_model) = lead_worker.get_model_info();
329        tracing::info!(
330            "🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
331            lead_model,
332            worker_model
333        );
334    } else {
335        tracing::info!("🤖 Using model: {}", model_name);
336    }
337
338    let session_id: String = if session_config.no_session {
339        let working_dir = std::env::current_dir().expect("Could not get working directory");
340        let session = SessionManager::create_session(
341            working_dir,
342            "CLI Session".to_string(),
343            SessionType::Hidden,
344        )
345        .await
346        .expect("Could not create session");
347        session.id
348    } else if session_config.resume {
349        if let Some(session_id) = session_config.session_id {
350            match SessionManager::get_session(&session_id, false).await {
351                Ok(_) => session_id,
352                Err(_) => {
353                    output::render_error(&format!(
354                        "Cannot resume session {} - no such session exists",
355                        style(&session_id).cyan()
356                    ));
357                    process::exit(1);
358                }
359            }
360        } else {
361            match SessionManager::list_sessions().await {
362                Ok(sessions) if !sessions.is_empty() => sessions[0].id.clone(),
363                _ => {
364                    output::render_error("Cannot resume - no previous sessions found");
365                    process::exit(1);
366                }
367            }
368        }
369    } else {
370        session_config.session_id.unwrap()
371    };
372
373    agent
374        .update_provider(new_provider, &session_id)
375        .await
376        .unwrap_or_else(|e| {
377            output::render_error(&format!("Failed to initialize agent: {}", e));
378            process::exit(1);
379        });
380
381    agent
382        .extension_manager
383        .set_context(PlatformExtensionContext {
384            session_id: Some(session_id.clone()),
385            extension_manager: Some(Arc::downgrade(&agent.extension_manager)),
386        })
387        .await;
388
389    if session_config.resume {
390        let session = SessionManager::get_session(&session_id, false)
391            .await
392            .unwrap_or_else(|e| {
393                output::render_error(&format!("Failed to read session metadata: {}", e));
394                process::exit(1);
395            });
396
397        let current_workdir =
398            std::env::current_dir().expect("Failed to get current working directory");
399        if current_workdir != session.working_dir {
400            let change_workdir = cliclack::confirm(format!("{} The original working directory of this session was set to {}. Your current directory is {}. Do you want to switch back to the original working directory?", style("WARNING:").yellow(), style(session.working_dir.display()).cyan(), style(current_workdir.display()).cyan()))
401                    .initial_value(true)
402                    .interact().expect("Failed to get user input");
403
404            if change_workdir {
405                if !session.working_dir.exists() {
406                    output::render_error(&format!(
407                        "Cannot switch to original working directory - {} no longer exists",
408                        style(session.working_dir.display()).cyan()
409                    ));
410                } else if let Err(e) = std::env::set_current_dir(&session.working_dir) {
411                    output::render_error(&format!(
412                        "Failed to switch to original working directory: {}",
413                        e
414                    ));
415                }
416            }
417        }
418    }
419
420    // Setup extensions for the agent
421    // Extensions need to be added after the session is created because we change directory when resuming a session
422
423    for warning in aster::config::get_warnings() {
424        eprintln!("{}", style(format!("Warning: {}", warning)).yellow());
425    }
426
427    // If we get extensions_override, only run those extensions and none other
428    let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override {
429        extensions.into_iter().collect()
430    } else if session_config.resume {
431        match SessionManager::get_session(&session_id, false).await {
432            Ok(session_data) => {
433                if let Some(saved_state) =
434                    EnabledExtensionsState::from_extension_data(&session_data.extension_data)
435                {
436                    check_missing_extensions_or_exit(&saved_state.extensions);
437                    saved_state.extensions
438                } else {
439                    get_enabled_extensions()
440                }
441            }
442            _ => get_enabled_extensions(),
443        }
444    } else {
445        get_enabled_extensions()
446    };
447
448    let mut set = JoinSet::new();
449    let agent_ptr = Arc::new(agent);
450
451    let mut waiting_on = HashSet::new();
452    for extension in extensions_to_run {
453        waiting_on.insert(extension.name());
454        let agent_ptr = agent_ptr.clone();
455        set.spawn(async move {
456            (
457                extension.name(),
458                agent_ptr.add_extension(extension.clone()).await,
459            )
460        });
461    }
462
463    let get_message = |waiting_on: &HashSet<String>| {
464        let mut names: Vec<_> = waiting_on.iter().cloned().collect();
465        names.sort();
466        format!("starting {} extensions: {}", names.len(), names.join(", "))
467    };
468
469    let spinner = cliclack::spinner();
470    spinner.start(get_message(&waiting_on));
471
472    let mut offer_debug = Vec::new();
473    while let Some(result) = set.join_next().await {
474        match result {
475            Ok((name, Ok(_))) => {
476                waiting_on.remove(&name);
477                spinner.set_message(get_message(&waiting_on));
478            }
479            Ok((name, Err(e))) => offer_debug.push((name, e)),
480            Err(e) => tracing::error!("failed to add extension: {}", e),
481        }
482    }
483
484    spinner.clear();
485
486    for (name, err) in offer_debug {
487        if let Err(debug_err) = offer_extension_debugging_help(
488            &name,
489            &err.to_string(),
490            Arc::clone(&provider_for_display),
491            session_config.interactive,
492        )
493        .await
494        {
495            eprintln!("Note: Could not start debugging session: {}", debug_err);
496        }
497    }
498
499    // Determine editor mode
500    let edit_mode = config
501        .get_param::<String>("EDIT_MODE")
502        .ok()
503        .and_then(|edit_mode| match edit_mode.to_lowercase().as_str() {
504            "emacs" => Some(EditMode::Emacs),
505            "vi" => Some(EditMode::Vi),
506            _ => {
507                eprintln!("Invalid EDIT_MODE specified, defaulting to Emacs");
508                None
509            }
510        });
511
512    let debug_mode = session_config.debug || config.get_param("ASTER_DEBUG").unwrap_or(false);
513
514    // Create new session
515    let mut session = CliSession::new(
516        Arc::try_unwrap(agent_ptr).unwrap_or_else(|_| panic!("There should be no more references")),
517        session_id.clone(),
518        debug_mode,
519        session_config.scheduled_job_id.clone(),
520        session_config.max_turns,
521        edit_mode,
522        session_config.retry_config.clone(),
523        session_config.output_format.clone(),
524    )
525    .await;
526
527    // Add stdio extensions if provided
528    for extension_str in session_config.extensions {
529        if let Err(e) = session.add_extension(extension_str.clone()).await {
530            eprintln!(
531                "{}",
532                style(format!(
533                    "Warning: Failed to start stdio extension '{}' ({}), continuing without it",
534                    extension_str, e
535                ))
536                .yellow()
537            );
538
539            // Offer debugging help
540            if let Err(debug_err) = offer_extension_debugging_help(
541                &extension_str,
542                &e.to_string(),
543                Arc::clone(&provider_for_display),
544                session_config.interactive,
545            )
546            .await
547            {
548                eprintln!("Note: Could not start debugging session: {}", debug_err);
549            }
550        }
551    }
552
553    // Add streamable HTTP extensions if provided
554    for extension_str in session_config.streamable_http_extensions {
555        if let Err(e) = session
556            .add_streamable_http_extension(extension_str.clone())
557            .await
558        {
559            eprintln!(
560                "{}",
561                style(format!(
562                    "Warning: Failed to start streamable HTTP extension '{}' ({}), continuing without it",
563                    extension_str, e
564                ))
565                .yellow()
566            );
567
568            // Offer debugging help
569            if let Err(debug_err) = offer_extension_debugging_help(
570                &extension_str,
571                &e.to_string(),
572                Arc::clone(&provider_for_display),
573                session_config.interactive,
574            )
575            .await
576            {
577                eprintln!("Note: Could not start debugging session: {}", debug_err);
578            }
579        }
580    }
581
582    // Add builtin extensions
583    for builtin in session_config.builtins {
584        if let Err(e) = session.add_builtin(builtin.clone()).await {
585            eprintln!(
586                "{}",
587                style(format!(
588                    "Warning: Failed to start builtin extension '{}' ({}), continuing without it",
589                    builtin, e
590                ))
591                .yellow()
592            );
593
594            // Offer debugging help
595            if let Err(debug_err) = offer_extension_debugging_help(
596                &builtin,
597                &e.to_string(),
598                Arc::clone(&provider_for_display),
599                session_config.interactive,
600            )
601            .await
602            {
603                eprintln!("Note: Could not start debugging session: {}", debug_err);
604            }
605        }
606    }
607
608    let session_config_for_save = SessionConfig {
609        id: session_id.clone(),
610        schedule_id: None,
611        max_turns: None,
612        retry_config: None,
613        system_prompt: None,
614    };
615
616    if let Err(e) = session
617        .agent
618        .save_extension_state(&session_config_for_save)
619        .await
620    {
621        tracing::warn!("Failed to save initial extension state: {}", e);
622    }
623
624    // Add CLI-specific system prompt extension
625    session
626        .agent
627        .extend_system_prompt(super::prompt::get_cli_prompt())
628        .await;
629
630    if let Some(additional_prompt) = session_config.additional_system_prompt {
631        session.agent.extend_system_prompt(additional_prompt).await;
632    }
633
634    // Only override system prompt if a system override exists
635    let system_prompt_file: Option<String> = config.get_param("ASTER_SYSTEM_PROMPT_FILE_PATH").ok();
636    if let Some(ref path) = system_prompt_file {
637        let override_prompt =
638            std::fs::read_to_string(path).expect("Failed to read system prompt file");
639        session.agent.override_system_prompt(override_prompt).await;
640    }
641
642    // Display session information unless in quiet mode
643    if !session_config.quiet {
644        output::display_session_info(
645            session_config.resume,
646            &provider_name,
647            &model_name,
648            &Some(session_id),
649            Some(&provider_for_display),
650        );
651    }
652    session
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// A config built with explicit values keeps exactly those values, and the
    /// fields left to `Default` stay at their defaults.
    #[test]
    fn test_session_builder_config_creation() {
        let config = SessionBuilderConfig {
            extensions: vec!["echo test".to_string()],
            streamable_http_extensions: vec!["http://localhost:8080/mcp".to_string()],
            builtins: vec!["developer".to_string()],
            additional_system_prompt: Some("Test prompt".to_string()),
            debug: true,
            max_tool_repetitions: Some(5),
            interactive: true,
            ..Default::default()
        };

        assert_eq!(config.extensions, vec!["echo test".to_string()]);
        assert_eq!(
            config.streamable_http_extensions,
            vec!["http://localhost:8080/mcp".to_string()]
        );
        assert_eq!(config.builtins, vec!["developer".to_string()]);
        assert_eq!(
            config.additional_system_prompt.as_deref(),
            Some("Test prompt")
        );
        assert!(config.debug);
        assert_eq!(config.max_tool_repetitions, Some(5));
        assert!(config.max_turns.is_none());
        assert!(config.scheduled_job_id.is_none());
        assert!(config.interactive);
        assert!(!config.quiet);
        assert_eq!(config.output_format, "text");
    }

    /// `Default` yields empty/`None`/`false` everywhere except `output_format`,
    /// which must be `"text"` — the reason the impl is written by hand.
    #[test]
    fn test_session_builder_config_default() {
        let config = SessionBuilderConfig::default();

        assert!(config.session_id.is_none());
        assert!(!config.resume);
        assert!(!config.no_session);
        assert!(config.extensions.is_empty());
        assert!(config.streamable_http_extensions.is_empty());
        assert!(config.builtins.is_empty());
        assert!(config.extensions_override.is_none());
        assert!(config.additional_system_prompt.is_none());
        assert!(config.settings.is_none());
        assert!(config.provider.is_none());
        assert!(config.model.is_none());
        assert!(!config.debug);
        assert!(config.max_tool_repetitions.is_none());
        assert!(config.max_turns.is_none());
        assert!(config.scheduled_job_id.is_none());
        assert!(!config.interactive);
        assert!(!config.quiet);
        assert!(config.sub_recipes.is_none());
        assert!(config.final_output_response.is_none());
        assert!(config.retry_config.is_none());
        assert_eq!(config.output_format, "text");
    }

    /// Compile-time check that the debugging helper exists and is nameable.
    ///
    /// Calling it for real needs a provider and user interaction, which this unit
    /// test cannot supply, so the function is only referenced, never invoked.
    #[test]
    fn test_offer_extension_debugging_help_function_exists() {
        let _callable = offer_extension_debugging_help;
    }
}