use anyhow::Result;
use colored::Colorize;
use std::io::{self, Write};
use crate::args::Cli;
use crate::output::OutputStreams;
/// Reports whether the environment variable `name` holds an affirmative value.
///
/// Surrounding whitespace is ignored and the comparison is ASCII
/// case-insensitive: `1`, `true`, `yes` and `on` count as enabled. A variable
/// that is unset (or whose value is not valid Unicode) counts as disabled.
fn env_flag_truthy(name: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    std::env::var(name).is_ok_and(|raw| {
        let value = raw.trim();
        TRUTHY
            .iter()
            .any(|candidate| value.eq_ignore_ascii_case(candidate))
    })
}
/// Per-invocation settings shared by the `sqry ask` response handlers.
struct ResponseConfig<'a> {
    /// Root path that generated commands are executed against.
    path: &'a str,
    /// Global CLI options; `json` selects the output format and the whole
    /// value is forwarded to executed commands.
    cli: &'a Cli,
    /// Only render the response; never execute anything.
    dry_run: bool,
    /// Execute without asking for interactive confirmation.
    auto_execute: bool,
}
/// Emits the JSON envelope for an `Execute` translation result.
///
/// The envelope's `type` is `"execute"` when the command is only being
/// previewed (`dry_run`) or will run unprompted (`auto_execute`). Otherwise it
/// is `"confirm"`. At most one marker field (`dry_run` or `auto_execute`,
/// always `true`) is added, and `dry_run` takes precedence.
fn write_execute_json(
    streams: &mut OutputStreams,
    command: &str,
    confidence: f32,
    intent: &str,
    dry_run: bool,
    auto_execute: bool,
) -> Result<()> {
    let envelope_type = if dry_run || auto_execute {
        "execute"
    } else {
        "confirm"
    };
    let mut payload = serde_json::json!({
        "type": envelope_type,
        "command": command,
        "confidence": confidence,
        "intent": intent
    });

    // The marker is appended after the common fields, matching the original
    // key order even when serde_json preserves insertion order.
    let marker = if dry_run {
        Some("dry_run")
    } else if auto_execute {
        Some("auto_execute")
    } else {
        None
    };
    if let Some(key) = marker {
        payload[key] = serde_json::Value::Bool(true);
    }

    let rendered = serde_json::to_string_pretty(&payload)?;
    streams.write_result(&rendered)?;
    Ok(())
}
/// Renders an `Execute` translation result as human-readable text.
///
/// A dry run shows the command, confidence and intent. Auto-execution shows a
/// one-line "Executing:" banner. Otherwise the command and confidence are shown
/// ahead of the confirmation prompt.
fn write_execute_text(
    streams: &mut OutputStreams,
    command: &str,
    confidence: f32,
    intent: &str,
    dry_run: bool,
    auto_execute: bool,
) -> Result<()> {
    let percent = confidence * 100.0;
    let rendered = match (dry_run, auto_execute) {
        (true, _) => format!(
            "{} {}\n{}: {:.0}%\n{}: {}\n",
            "Command:".bold(),
            command.green(),
            "Confidence".dimmed(),
            percent,
            "Intent".dimmed(),
            intent
        ),
        (false, true) => format!(
            "{} {} ({:.0}% confidence)\n",
            "Executing:".green().bold(),
            command,
            percent
        ),
        (false, false) => format!(
            "{} {}\n{}: {:.0}%\n",
            "Generated command:".bold(),
            command.cyan(),
            "Confidence".dimmed(),
            percent
        ),
    };
    streams.write_result(&rendered)?;
    Ok(())
}
fn handle_execute_response(
streams: &mut OutputStreams,
config: &ResponseConfig,
command: &str,
confidence: f32,
intent: &str,
) -> Result<()> {
if config.cli.json {
write_execute_json(
streams,
command,
confidence,
intent,
config.dry_run,
config.auto_execute,
)?;
} else {
write_execute_text(
streams,
command,
confidence,
intent,
config.dry_run,
config.auto_execute,
)?;
}
if config.dry_run {
return Ok(());
}
if config.auto_execute {
execute_generated_command(command, config.path, config.cli)?;
} else if !config.cli.json {
if prompt_confirmation("Execute this command?")? {
execute_generated_command(command, config.path, config.cli)?;
} else {
streams.write_diagnostic("Cancelled.\n")?;
}
}
Ok(())
}
/// Emits the JSON envelope for a `Confirm` translation result.
///
/// Unlike the execute envelope, both the `dry_run` and `auto_execute` flags
/// are always present, carrying their actual values.
fn write_confirm_json(
    streams: &mut OutputStreams,
    command: &str,
    confidence: f32,
    prompt: &str,
    dry_run: bool,
    auto_execute: bool,
) -> Result<()> {
    let envelope = serde_json::json!({
        "type": "confirm",
        "command": command,
        "confidence": confidence,
        "prompt": prompt,
        "dry_run": dry_run,
        "auto_execute": auto_execute
    });
    let pretty = serde_json::to_string_pretty(&envelope)?;
    streams.write_result(&pretty)?;
    Ok(())
}
/// Renders a `Confirm` translation result as human-readable text.
///
/// A dry run shows the command, its confidence and a note that confirmation
/// would be required. Otherwise the translator's prompt is shown above the
/// command. In that case the confidence is not displayed.
fn write_confirm_text(
    streams: &mut OutputStreams,
    command: &str,
    confidence: f32,
    prompt: &str,
    dry_run: bool,
) -> Result<()> {
    let text = if dry_run {
        format!(
            "{} {}\n{}: {:.0}%\n{}\n",
            "Command:".bold(),
            command.yellow(),
            "Confidence".dimmed(),
            confidence * 100.0,
            "(Medium confidence - would require confirmation)".dimmed()
        )
    } else {
        format!(
            "{}\n{} {}\n",
            prompt.yellow(),
            "Command:".bold(),
            command.cyan()
        )
    };
    streams.write_result(&text)?;
    Ok(())
}
/// Renders a `Confirm` translation and runs it when approved.
///
/// Approval comes from `auto_execute` or, in interactive (non-JSON) sessions
/// only, from a `[y/N]` prompt. Dry runs never execute. "Cancelled." is
/// reported only to interactive sessions.
fn handle_confirm_response(
    streams: &mut OutputStreams,
    config: &ResponseConfig,
    command: &str,
    confidence: f32,
    prompt: &str,
) -> Result<()> {
    if config.cli.json {
        write_confirm_json(
            streams,
            command,
            confidence,
            prompt,
            config.dry_run,
            config.auto_execute,
        )?;
    } else {
        write_confirm_text(streams, command, confidence, prompt, config.dry_run)?;
    }

    if config.dry_run {
        return Ok(());
    }

    let interactive = !config.cli.json;
    // Short-circuits: the prompt is shown only when not auto-executing and
    // only in interactive mode.
    let approved = config.auto_execute || (interactive && prompt_confirmation("")?);

    if approved {
        execute_generated_command(command, config.path, config.cli)?;
    } else if interactive {
        streams.write_diagnostic("Cancelled.\n")?;
    }
    Ok(())
}
/// Presents a `Disambiguate` translation, in JSON or text depending on the
/// CLI mode.
///
/// The highest-confidence option is computed once and passed on for possible
/// auto-execution.
fn handle_disambiguate_response(
    streams: &mut OutputStreams,
    config: &ResponseConfig,
    options: &[sqry_nl::DisambiguationOption],
    prompt: &str,
) -> Result<()> {
    let top_candidate = select_best_disambiguation(options);
    if config.cli.json {
        handle_disambiguate_json(streams, config, options, prompt, top_candidate)
    } else {
        handle_disambiguate_text(streams, config, options, prompt, top_candidate)
    }
}
/// Returns the option with the highest `confidence`, or `None` when `options`
/// is empty.
///
/// Ties go to the option that appears later. A comparison involving NaN is
/// treated as a tie, matching `Iterator::max_by` with a `partial_cmp` that
/// falls back to `Equal`.
fn select_best_disambiguation(
    options: &[sqry_nl::DisambiguationOption],
) -> Option<&sqry_nl::DisambiguationOption> {
    options.iter().reduce(|leader, challenger| {
        // Keep the current leader only when it is strictly greater.
        match leader.confidence.partial_cmp(&challenger.confidence) {
            Some(std::cmp::Ordering::Greater) => leader,
            _ => challenger,
        }
    })
}
/// Emits the JSON envelope listing every disambiguation option.
///
/// When `auto_execute` is set and this is not a dry run, the best option (if
/// any) is executed after the envelope is written.
fn handle_disambiguate_json(
    streams: &mut OutputStreams,
    config: &ResponseConfig,
    options: &[sqry_nl::DisambiguationOption],
    prompt: &str,
    best_option: Option<&sqry_nl::DisambiguationOption>,
) -> Result<()> {
    let option_entries: Vec<serde_json::Value> = options
        .iter()
        .map(|candidate| {
            serde_json::json!({
                "command": candidate.command,
                "intent": candidate.intent.as_str(),
                "description": candidate.description,
                "confidence": candidate.confidence
            })
        })
        .collect();
    let envelope = serde_json::json!({
        "type": "disambiguate",
        "prompt": prompt,
        "options": option_entries,
        "auto_execute": config.auto_execute,
        "dry_run": config.dry_run
    });
    streams.write_result(&serde_json::to_string_pretty(&envelope)?)?;

    let runs_automatically = config.auto_execute && !config.dry_run;
    match best_option {
        Some(selected) if runs_automatically => {
            execute_generated_command(&selected.command, config.path, config.cli)
        }
        _ => Ok(()),
    }
}
/// Prints the numbered disambiguation options and acts on them.
///
/// - Dry runs and empty option lists stop after printing.
/// - With `auto_execute`, the best option is announced and executed.
/// - Otherwise the user is asked to pick one.
fn handle_disambiguate_text(
    streams: &mut OutputStreams,
    config: &ResponseConfig,
    options: &[sqry_nl::DisambiguationOption],
    prompt: &str,
    best_option: Option<&sqry_nl::DisambiguationOption>,
) -> Result<()> {
    streams.write_result(&format!("{}\n\n", prompt.yellow()))?;
    for (number, option) in (1_usize..).zip(options) {
        let percent = format!("{:.0}%", option.confidence * 100.0);
        streams.write_result(&format!(
            " {}. {} - {}\n {}\n\n",
            number,
            option.description.bold(),
            percent.dimmed(),
            option.command.cyan()
        ))?;
    }

    if config.dry_run || options.is_empty() {
        return Ok(());
    }
    if !config.auto_execute {
        return execute_disambiguation_choice(streams, config, options);
    }
    let Some(selected) = best_option else {
        return Ok(());
    };
    streams.write_result(&format!(
        "\n{} {}\n",
        "Auto-executing highest confidence:".green().bold(),
        selected.command
    ))?;
    execute_generated_command(&selected.command, config.path, config.cli)
}
/// Prompts for a 1-based option number and executes the chosen command.
///
/// Reports "Cancelled." when the user cancels or enters an invalid choice.
/// `prompt_choice` only returns indices below `options.len()`, so the index
/// below is in bounds.
fn execute_disambiguation_choice(
    streams: &mut OutputStreams,
    config: &ResponseConfig,
    options: &[sqry_nl::DisambiguationOption],
) -> Result<()> {
    let Some(index) = prompt_choice(options.len())? else {
        streams.write_diagnostic("Cancelled.\n")?;
        return Ok(());
    };
    let chosen = &options[index];
    streams.write_result(&format!(
        "\n{} {}\n",
        "Executing:".green().bold(),
        chosen.command
    ))?;
    execute_generated_command(&chosen.command, config.path, config.cli)
}
/// Reports a rejected translation and returns the error message the caller
/// should fail with.
///
/// JSON mode writes a `reject` envelope to the result stream. Text mode writes
/// the reason and any suggestions to the diagnostic stream.
fn handle_reject_response(
    streams: &mut OutputStreams,
    config: &ResponseConfig,
    reason: &str,
    suggestions: &[String],
) -> Result<String> {
    let summary = format!("Translation rejected: {reason}");

    if config.cli.json {
        let payload = serde_json::json!({
            "type": "reject",
            "reason": reason,
            "suggestions": suggestions
        });
        let rendered = serde_json::to_string_pretty(&payload)?;
        streams.write_result(&rendered)?;
        return Ok(summary);
    }

    let headline = format!("{} {}\n", "Cannot translate:".red().bold(), reason);
    streams.write_diagnostic(&headline)?;
    if !suggestions.is_empty() {
        streams.write_diagnostic(&format!("\n{}:\n", "Suggestions".yellow()))?;
        for tip in suggestions {
            streams.write_diagnostic(&format!(" • {tip}\n"))?;
        }
    }
    Ok(summary)
}
/// Ratio applied to the user's `--threshold` to derive the translator's
/// `confirm_threshold`.
const CONFIRM_THRESHOLD_RATIO: f32 = 0.75;

/// Builds the translator configuration for `run_ask`.
///
/// Each CLI opt-in flag is OR-ed with its environment variable
/// (`SQRY_NL_ALLOW_UNVERIFIED_MODEL`, `SQRY_NL_ALLOW_DOWNLOAD`, interpreted by
/// `env_flag_truthy`). The variable is only read when the flag is `false`.
fn ask_translator_config(
    threshold: f32,
    model_dir_override: Option<&std::path::Path>,
    allow_unverified_model_flag: bool,
    allow_model_download_flag: bool,
) -> sqry_nl::TranslatorConfig {
    let allow_unverified_model =
        allow_unverified_model_flag || env_flag_truthy("SQRY_NL_ALLOW_UNVERIFIED_MODEL");
    let allow_model_download =
        allow_model_download_flag || env_flag_truthy("SQRY_NL_ALLOW_DOWNLOAD");
    sqry_nl::TranslatorConfig {
        execute_threshold: threshold,
        confirm_threshold: threshold * CONFIRM_THRESHOLD_RATIO,
        model_dir_override: model_dir_override.map(std::path::Path::to_path_buf),
        allow_unverified_model,
        allow_model_download,
        ..Default::default()
    }
}

/// Entry point for `sqry ask`: translates `query` into a sqry command, then
/// renders and (depending on flags and the response kind) executes it.
///
/// # Errors
///
/// - Returns `CliError::OnnxRuntimeMissing` when the ONNX runtime is
///   unavailable.
/// - Returns a contextualised error for any other translator initialisation
///   failure.
/// - Propagates errors from rendering or executing the response.
/// - Fails with "Translation rejected: …" after the output streams are
///   finished, when the translator rejects the query.
// One parameter per `sqry ask` CLI option, so the argument count and the
// number of bools are dictated by the command-line surface.
#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
pub fn run_ask(
    cli: &Cli,
    query: &str,
    path: &str,
    auto_execute: bool,
    dry_run: bool,
    threshold: f32,
    model_dir_override: Option<&std::path::Path>,
    allow_unverified_model_flag: bool,
    allow_model_download_flag: bool,
) -> Result<()> {
    use sqry_nl::{TranslationResponse, Translator};

    let mut streams = OutputStreams::with_pager(cli.pager_config());
    let translator_config = ask_translator_config(
        threshold,
        model_dir_override,
        allow_unverified_model_flag,
        allow_model_download_flag,
    );
    let mut translator = match Translator::new(translator_config) {
        Ok(t) => t,
        // Re-map onto the CLI-level error type, carrying the hint through.
        Err(sqry_nl::NlError::OnnxRuntimeMissing { hint }) => {
            return Err(crate::error::CliError::OnnxRuntimeMissing { hint }.into());
        }
        Err(e) => {
            return Err(
                anyhow::Error::new(e).context("Failed to initialize natural language translator")
            );
        }
    };

    let response = translator.translate(query);
    let config = ResponseConfig {
        cli,
        path,
        auto_execute,
        dry_run,
    };
    let reject_error = match response {
        TranslationResponse::Execute {
            command,
            confidence,
            intent,
            ..
        } => {
            handle_execute_response(&mut streams, &config, &command, confidence, intent.as_str())?;
            None
        }
        TranslationResponse::Confirm {
            command,
            confidence,
            prompt,
        } => {
            handle_confirm_response(&mut streams, &config, &command, confidence, &prompt)?;
            None
        }
        TranslationResponse::Disambiguate { options, prompt } => {
            handle_disambiguate_response(&mut streams, &config, &options, &prompt)?;
            None
        }
        TranslationResponse::Reject {
            reason,
            suggestions,
        } => Some(handle_reject_response(
            &mut streams,
            &config,
            &reason,
            &suggestions,
        )?),
    };

    // Flush/finish output before turning a rejection into a process error.
    streams.finish_checked()?;
    if let Some(error_msg) = reject_error {
        anyhow::bail!("{error_msg}");
    }
    Ok(())
}
/// Arguments recovered from a translator-generated `sqry …` command line by
/// `parse_generated_command`.
#[derive(Debug, Default)]
struct ParsedCommandArgs {
    /// First double-quoted string in the command (the query/search text).
    primary: String,
    /// Value of `--language`, if present.
    language: Option<String>,
    /// Value of `--kind`, if present.
    kind: Option<String>,
    /// Value of `--limit`; `None` when absent or not a valid `u32`.
    limit: Option<u32>,
    /// Value of `--path`, if present.
    path_filter: Option<String>,
    /// Second double-quoted string (e.g. the target of `graph trace-path`).
    secondary: Option<String>,
    /// Value of `--max-depth`; `None` when absent or not a valid `u32`.
    max_depth: Option<u32>,
}
/// Returns the value that follows `flag` in a generated command line.
///
/// `flag` only matches as a standalone token outside double quotes. Flag-like
/// text inside a quoted argument (e.g. `sqry query "--path"`) and longer flags
/// that merely start with `flag` (e.g. `--paths`) are ignored.
///
/// The value is one of:
/// - the contents of a double-quoted string, spaces preserved (an unterminated
///   quote runs to the end of the input);
/// - otherwise, the next whitespace-delimited word.
///
/// Returns `None` when `flag` is empty, absent, or has nothing after it.
fn extract_flag_value(command: &str, flag: &str) -> Option<String> {
    let flag_pos = find_flag_token(command, flag)?;
    let trimmed = command[flag_pos + flag.len()..].trim_start();
    if trimmed.is_empty() {
        return None;
    }
    if let Some(stripped) = trimmed.strip_prefix('"') {
        let end = stripped.find('"').unwrap_or(stripped.len());
        return Some(stripped[..end].to_string());
    }
    trimmed.split_whitespace().next().map(str::to_string)
}

/// Byte offset of the first occurrence of `flag` that forms a whole token
/// outside double quotes, or `None` if there is none.
fn find_flag_token(command: &str, flag: &str) -> Option<usize> {
    if flag.is_empty() {
        return None;
    }
    let mut in_quote = false;
    let mut at_token_start = true;
    for (pos, ch) in command.char_indices() {
        if ch == '"' {
            in_quote = !in_quote;
        } else if !in_quote && at_token_start && command[pos..].starts_with(flag) {
            let rest = &command[pos + flag.len()..];
            // Reject prefixes of longer tokens such as `--paths`.
            if rest.is_empty() || rest.starts_with(char::is_whitespace) {
                return Some(pos);
            }
        }
        at_token_start = ch.is_whitespace();
    }
    None
}
/// Parses a translator-generated `sqry …` command into `ParsedCommandArgs`.
///
/// - The first and second complete double-quoted strings become `primary` and
///   `secondary`.
/// - `--language`, `--kind`, `--limit` and `--max-depth` are read from a
///   quote-aware token stream. A double-quoted span counts as part of a single
///   token, with the quotes removed.
/// - A token containing quoted text is never treated as a flag, so flag-like
///   text inside the query string (e.g. `"where --limit is parsed"`) is
///   ignored.
/// - `--path` is resolved separately by `extract_flag_value`; here it and its
///   value are skipped.
/// - Numeric flags whose value does not parse as `u32` yield `None`.
/// - A repeated flag keeps its last occurrence.
///
/// # Errors
///
/// Fails when the command has no non-empty double-quoted primary argument.
fn parse_generated_command(command: &str) -> Result<ParsedCommandArgs> {
    let mut args = ParsedCommandArgs::default();

    // Collect every complete `"…"` span, in order of appearance.
    let mut quoted_strings = Vec::new();
    let mut in_quote = false;
    let mut current_quoted = String::new();
    for c in command.chars() {
        if c == '"' {
            if in_quote {
                quoted_strings.push(std::mem::take(&mut current_quoted));
            }
            in_quote = !in_quote;
        } else if in_quote {
            current_quoted.push(c);
        }
    }
    if let Some(primary) = quoted_strings.first() {
        args.primary.clone_from(primary);
    }
    if let Some(secondary) = quoted_strings.get(1) {
        args.secondary = Some(secondary.clone());
    }
    args.path_filter = extract_flag_value(command, "--path");

    let tokens = tokenize_generated_command(command);
    let mut idx = 0;
    while idx < tokens.len() {
        let (text, quoted) = &tokens[idx];
        let next = tokens.get(idx + 1).map(|(value, _)| value.as_str());
        idx += match (*quoted, text.as_str(), next) {
            (false, "--language", Some(value)) => {
                args.language = Some(value.to_string());
                2
            }
            (false, "--kind", Some(value)) => {
                args.kind = Some(value.to_string());
                2
            }
            (false, "--limit", Some(value)) => {
                args.limit = value.parse().ok();
                2
            }
            (false, "--max-depth", Some(value)) => {
                args.max_depth = value.parse().ok();
                2
            }
            // Handled by `extract_flag_value`; skip the flag and its value.
            (false, "--path", _) => 2,
            _ => 1,
        };
    }

    if args.primary.is_empty() {
        anyhow::bail!("Could not extract primary argument from command: {command}");
    }
    Ok(args)
}

/// Splits `command` on whitespace outside double quotes, stripping the quote
/// characters. Each token is paired with whether it contained quoted text.
fn tokenize_generated_command(command: &str) -> Vec<(String, bool)> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    let mut in_token = false;
    let mut quoted = false;
    let mut in_quote = false;
    for c in command.chars() {
        if c == '"' {
            in_quote = !in_quote;
            in_token = true;
            quoted = true;
        } else if c.is_whitespace() && !in_quote {
            if in_token {
                tokens.push((std::mem::take(&mut current), quoted));
                in_token = false;
                quoted = false;
            }
        } else {
            current.push(c);
            in_token = true;
        }
    }
    if in_token {
        tokens.push((current, quoted));
    }
    tokens
}
fn build_query_expression(args: &ParsedCommandArgs) -> String {
let mut expr_parts = vec![args.primary.clone()];
if let Some(lang) = &args.language
&& !args.primary.contains("lang:")
&& !args.primary.contains("language:")
{
expr_parts.push(format!("language:{lang}"));
}
if let Some(path) = &args.path_filter
&& !args.primary.contains("path:")
{
if path.contains(' ') {
let escaped = path.replace('"', "\\\"");
expr_parts.push(format!("path:\"{escaped}\""));
} else {
expr_parts.push(format!("path:{path}"));
}
}
expr_parts.join(" ")
}
fn execute_generated_command(command: &str, path: &str, cli: &Cli) -> Result<()> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.is_empty() || parts[0] != "sqry" {
anyhow::bail!("Invalid generated command: {command}");
}
if parts.len() < 2 {
anyhow::bail!("Generated command missing subcommand: {command}");
}
let subcommand = parts[1];
match subcommand {
"query" => {
let parsed = parse_generated_command(command)?;
let query_expr = build_query_expression(&parsed);
let result_limit = parsed.limit.map(|l| l as usize);
super::run_query(
cli,
&query_expr,
path,
false,
false,
false,
false,
None,
result_limit,
&[],
)?;
}
"search" => {
let parsed = parse_generated_command(command)?;
super::run_search(cli, &parsed.primary, path, None, false, false)?;
}
"graph" => {
if parts.len() < 3 {
anyhow::bail!("Graph command missing operation: {command}");
}
eprintln!(
"{}",
format!("Graph commands not yet auto-executable: {command}").yellow()
);
}
"index" => {
if command.contains("--status") {
super::run_index_status(cli, path, crate::args::MetricsFormat::Json)?;
} else {
eprintln!(
"{}",
format!("Index build not auto-executable: {command}").yellow()
);
}
}
_ => {
anyhow::bail!("Unsupported generated command: {subcommand}");
}
}
Ok(())
}
/// Test helper: extracts the first double-quoted argument of `command`.
///
/// If no complete `"…"` span exists, falls back to the third
/// whitespace-separated word with surrounding quotes trimmed. `_position` is
/// accepted for call-site compatibility but not used.
#[cfg(test)]
fn extract_quoted_arg(command: &str, _position: usize) -> Result<String> {
    let quoted = command.find('"').and_then(|open| {
        let tail = &command[open + 1..];
        tail.find('"').map(|close| tail[..close].to_owned())
    });
    if let Some(arg) = quoted {
        return Ok(arg);
    }
    match command.split_whitespace().nth(2) {
        Some(word) => Ok(word.trim_matches('"').to_owned()),
        None => anyhow::bail!("Could not extract argument from: {command}"),
    }
}
/// Prints a `[y/N]` prompt to stderr and reads one line from stdin.
///
/// The prompt is prefixed by `message` when it is non-empty. Returns `true`
/// only for `y` or `yes` (ASCII case-insensitive, surrounding whitespace
/// ignored). Any other answer, including end of input, means "no".
fn prompt_confirmation(message: &str) -> Result<bool> {
    let prompt = if message.is_empty() {
        String::from("[y/N] ")
    } else {
        format!("{message} [y/N] ")
    };
    eprint!("{prompt}");
    io::stderr().flush()?;

    let mut reply = String::new();
    io::stdin().read_line(&mut reply)?;
    let reply = reply.trim();
    Ok(["y", "yes"]
        .iter()
        .any(|accepted| reply.eq_ignore_ascii_case(accepted)))
}
/// Asks the user to pick an option numbered `1..=max` and returns its
/// zero-based index.
///
/// Returns `Ok(None)` for `c`/`C`, an empty line, or end of input. Any other
/// out-of-range or non-numeric answer prints "Invalid choice" to stderr and
/// also returns `Ok(None)`.
fn prompt_choice(max: usize) -> Result<Option<usize>> {
    eprint!("Enter choice (1-{max}) or 'c' to cancel: ");
    io::stderr().flush()?;

    let mut line = String::new();
    io::stdin().read_line(&mut line)?;
    let answer = line.trim();
    if answer.is_empty() || answer.eq_ignore_ascii_case("c") {
        return Ok(None);
    }

    let picked = answer
        .parse::<usize>()
        .ok()
        .filter(|n| (1..=max).contains(n));
    if picked.is_none() {
        eprintln!("Invalid choice");
    }
    Ok(picked.map(|n| n - 1))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_quoted_arg() {
        let command = r#"sqry query "kind:function""#;
        assert_eq!(extract_quoted_arg(command, 2).unwrap(), "kind:function");
    }

    #[test]
    fn test_extract_quoted_arg_with_spaces() {
        let command = r#"sqry search "hello world""#;
        assert_eq!(extract_quoted_arg(command, 2).unwrap(), "hello world");
    }

    #[test]
    fn test_parse_generated_command_basic() {
        let command = r#"sqry query "authenticate" --limit 100"#;
        let parsed = parse_generated_command(command).unwrap();
        assert_eq!(parsed.primary, "authenticate");
        assert_eq!(parsed.limit, Some(100));
        assert!(parsed.language.is_none());
        assert!(parsed.kind.is_none());
    }

    #[test]
    fn test_parse_generated_command_with_all_flags() {
        let command = r#"sqry query "login" --language rust --kind function --limit 50"#;
        let parsed = parse_generated_command(command).unwrap();
        assert_eq!(parsed.primary, "login");
        assert_eq!(parsed.language.as_deref(), Some("rust"));
        assert_eq!(parsed.kind.as_deref(), Some("function"));
        assert_eq!(parsed.limit, Some(50));
    }

    #[test]
    fn test_parse_generated_command_trace_path() {
        let command = r#"sqry graph trace-path "source" "target" --max-depth 5"#;
        let parsed = parse_generated_command(command).unwrap();
        assert_eq!(parsed.primary, "source");
        assert_eq!(parsed.secondary.as_deref(), Some("target"));
        assert_eq!(parsed.max_depth, Some(5));
    }

    #[test]
    fn test_build_query_expression_basic() {
        let parsed = ParsedCommandArgs {
            primary: "authenticate".to_string(),
            ..Default::default()
        };
        assert_eq!(build_query_expression(&parsed), "authenticate");
    }

    #[test]
    fn test_build_query_expression_with_predicates() {
        let parsed = ParsedCommandArgs {
            primary: "kind:function login".to_string(),
            language: Some("rust".to_string()),
            kind: Some("function".to_string()),
            limit: Some(50),
            ..Default::default()
        };
        let expression = build_query_expression(&parsed);
        assert!(expression.contains("login"));
        assert!(expression.contains("kind:function"));
        assert!(expression.contains("language:rust"));
        assert!(!expression.contains("limit:"));
    }

    #[test]
    fn test_build_query_expression_with_path() {
        let parsed = ParsedCommandArgs {
            primary: "test".to_string(),
            path_filter: Some("src/lib.rs".to_string()),
            ..Default::default()
        };
        assert!(build_query_expression(&parsed).contains("path:src/lib.rs"));
    }

    #[test]
    fn test_build_query_expression_with_path_spaces() {
        let parsed = ParsedCommandArgs {
            primary: "login".to_string(),
            path_filter: Some("src/api services".to_string()),
            language: Some("rust".to_string()),
            ..Default::default()
        };
        let expression = build_query_expression(&parsed);
        assert!(expression.contains(r#"path:"src/api services""#));
        assert!(expression.contains("language:rust"));
    }

    #[test]
    fn test_extract_flag_value_unquoted() {
        let command = r#"sqry query "test" --limit 50"#;
        assert_eq!(extract_flag_value(command, "--limit"), Some("50".to_string()));
    }

    #[test]
    fn test_extract_flag_value_quoted() {
        let command = r#"sqry query "test" --path "src/api services""#;
        assert_eq!(
            extract_flag_value(command, "--path"),
            Some("src/api services".to_string())
        );
    }

    #[test]
    fn test_extract_flag_value_not_present() {
        let command = r#"sqry query "test""#;
        assert_eq!(extract_flag_value(command, "--limit"), None);
    }

    #[test]
    fn test_parse_generated_command_with_path_spaces() {
        let command = r#"sqry query "login" --path "src/api services" --language rust"#;
        let parsed = parse_generated_command(command).unwrap();
        assert_eq!(parsed.primary, "login");
        assert_eq!(parsed.path_filter.as_deref(), Some("src/api services"));
        assert_eq!(parsed.language.as_deref(), Some("rust"));
    }
}