1use anyhow::{Context, Result};
7use colored::Colorize;
8use std::io::{self, Write};
9
10use crate::args::Cli;
11use crate::output::OutputStreams;
12
13struct ResponseConfig<'a> {
15 cli: &'a Cli,
16 path: &'a str,
17 auto_execute: bool,
18 dry_run: bool,
19}
20
21fn write_execute_json(
23 streams: &mut OutputStreams,
24 command: &str,
25 confidence: f32,
26 intent: &str,
27 dry_run: bool,
28 auto_execute: bool,
29) -> Result<()> {
30 let output = if dry_run {
31 serde_json::json!({
32 "type": "execute",
33 "command": command,
34 "confidence": confidence,
35 "intent": intent,
36 "dry_run": true
37 })
38 } else if auto_execute {
39 serde_json::json!({
40 "type": "execute",
41 "command": command,
42 "confidence": confidence,
43 "intent": intent,
44 "auto_execute": true
45 })
46 } else {
47 serde_json::json!({
48 "type": "confirm",
49 "command": command,
50 "confidence": confidence,
51 "intent": intent
52 })
53 };
54 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
55 Ok(())
56}
57
58fn write_execute_text(
60 streams: &mut OutputStreams,
61 command: &str,
62 confidence: f32,
63 intent: &str,
64 dry_run: bool,
65 auto_execute: bool,
66) -> Result<()> {
67 if dry_run {
68 streams.write_result(&format!(
69 "{} {}\n{}: {:.0}%\n{}: {}\n",
70 "Command:".bold(),
71 command.green(),
72 "Confidence".dimmed(),
73 confidence * 100.0,
74 "Intent".dimmed(),
75 intent
76 ))?;
77 } else if auto_execute {
78 streams.write_result(&format!(
79 "{} {} ({:.0}% confidence)\n",
80 "Executing:".green().bold(),
81 command,
82 confidence * 100.0
83 ))?;
84 } else {
85 streams.write_result(&format!(
86 "{} {}\n{}: {:.0}%\n",
87 "Generated command:".bold(),
88 command.cyan(),
89 "Confidence".dimmed(),
90 confidence * 100.0
91 ))?;
92 }
93 Ok(())
94}
95
96fn handle_execute_response(
98 streams: &mut OutputStreams,
99 config: &ResponseConfig,
100 command: &str,
101 confidence: f32,
102 intent: &str,
103) -> Result<()> {
104 if config.cli.json {
105 write_execute_json(
106 streams,
107 command,
108 confidence,
109 intent,
110 config.dry_run,
111 config.auto_execute,
112 )?;
113 } else {
114 write_execute_text(
115 streams,
116 command,
117 confidence,
118 intent,
119 config.dry_run,
120 config.auto_execute,
121 )?;
122 }
123
124 if config.dry_run {
125 return Ok(());
126 }
127
128 if config.auto_execute {
129 execute_generated_command(command, config.path, config.cli)?;
130 } else if !config.cli.json {
131 if prompt_confirmation("Execute this command?")? {
133 execute_generated_command(command, config.path, config.cli)?;
134 } else {
135 streams.write_diagnostic("Cancelled.\n")?;
136 }
137 }
138
139 Ok(())
140}
141
142fn write_confirm_json(
144 streams: &mut OutputStreams,
145 command: &str,
146 confidence: f32,
147 prompt: &str,
148 dry_run: bool,
149 auto_execute: bool,
150) -> Result<()> {
151 let output = serde_json::json!({
152 "type": "confirm",
153 "command": command,
154 "confidence": confidence,
155 "prompt": prompt,
156 "dry_run": dry_run,
157 "auto_execute": auto_execute
158 });
159 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
160 Ok(())
161}
162
163fn write_confirm_text(
165 streams: &mut OutputStreams,
166 command: &str,
167 confidence: f32,
168 prompt: &str,
169 dry_run: bool,
170) -> Result<()> {
171 if dry_run {
172 streams.write_result(&format!(
173 "{} {}\n{}: {:.0}%\n{}\n",
174 "Command:".bold(),
175 command.yellow(),
176 "Confidence".dimmed(),
177 confidence * 100.0,
178 "(Medium confidence - would require confirmation)".dimmed()
179 ))?;
180 } else {
181 streams.write_result(&format!(
182 "{}\n{} {}\n",
183 prompt.yellow(),
184 "Command:".bold(),
185 command.cyan()
186 ))?;
187 }
188 Ok(())
189}
190
191fn handle_confirm_response(
193 streams: &mut OutputStreams,
194 config: &ResponseConfig,
195 command: &str,
196 confidence: f32,
197 prompt: &str,
198) -> Result<()> {
199 if config.cli.json {
200 write_confirm_json(
201 streams,
202 command,
203 confidence,
204 prompt,
205 config.dry_run,
206 config.auto_execute,
207 )?;
208 } else {
209 write_confirm_text(streams, command, confidence, prompt, config.dry_run)?;
210 }
211
212 if config.dry_run {
213 return Ok(());
214 }
215
216 let should_execute = if config.cli.json {
218 config.auto_execute
219 } else {
220 config.auto_execute || prompt_confirmation("")?
221 };
222
223 if should_execute {
224 execute_generated_command(command, config.path, config.cli)?;
225 } else if !config.cli.json {
226 streams.write_diagnostic("Cancelled.\n")?;
227 }
228
229 Ok(())
230}
231
232fn handle_disambiguate_response(
234 streams: &mut OutputStreams,
235 config: &ResponseConfig,
236 options: &[sqry_nl::DisambiguationOption],
237 prompt: &str,
238) -> Result<()> {
239 let best_option = select_best_disambiguation(options);
240
241 if config.cli.json {
242 handle_disambiguate_json(streams, config, options, prompt, best_option)?;
243 } else {
244 handle_disambiguate_text(streams, config, options, prompt, best_option)?;
245 }
246
247 Ok(())
248}
249
250fn select_best_disambiguation(
251 options: &[sqry_nl::DisambiguationOption],
252) -> Option<&sqry_nl::DisambiguationOption> {
253 options.iter().max_by(|a, b| {
254 a.confidence
255 .partial_cmp(&b.confidence)
256 .unwrap_or(std::cmp::Ordering::Equal)
257 })
258}
259
260fn handle_disambiguate_json(
261 streams: &mut OutputStreams,
262 config: &ResponseConfig,
263 options: &[sqry_nl::DisambiguationOption],
264 prompt: &str,
265 best_option: Option<&sqry_nl::DisambiguationOption>,
266) -> Result<()> {
267 let output = serde_json::json!({
268 "type": "disambiguate",
269 "prompt": prompt,
270 "options": options.iter().map(|opt| {
271 serde_json::json!({
272 "command": opt.command,
273 "intent": opt.intent.as_str(),
274 "description": opt.description,
275 "confidence": opt.confidence
276 })
277 }).collect::<Vec<_>>(),
278 "auto_execute": config.auto_execute,
279 "dry_run": config.dry_run
280 });
281 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
282
283 if let Some(selected) = best_option.filter(|_| config.auto_execute && !config.dry_run) {
284 execute_generated_command(&selected.command, config.path, config.cli)?;
285 }
286
287 Ok(())
288}
289
290fn handle_disambiguate_text(
291 streams: &mut OutputStreams,
292 config: &ResponseConfig,
293 options: &[sqry_nl::DisambiguationOption],
294 prompt: &str,
295 best_option: Option<&sqry_nl::DisambiguationOption>,
296) -> Result<()> {
297 streams.write_result(&format!("{}\n\n", prompt.yellow()))?;
298
299 for (i, opt) in options.iter().enumerate() {
300 streams.write_result(&format!(
301 " {}. {} - {}\n {}\n\n",
302 i + 1,
303 opt.description.bold(),
304 format!("{:.0}%", opt.confidence * 100.0).dimmed(),
305 opt.command.cyan()
306 ))?;
307 }
308
309 if config.dry_run || options.is_empty() {
310 return Ok(());
311 }
312
313 if config.auto_execute {
314 if let Some(selected) = best_option {
315 streams.write_result(&format!(
316 "\n{} {}\n",
317 "Auto-executing highest confidence:".green().bold(),
318 selected.command
319 ))?;
320 execute_generated_command(&selected.command, config.path, config.cli)?;
321 }
322 return Ok(());
323 }
324
325 execute_disambiguation_choice(streams, config, options)
326}
327
328fn execute_disambiguation_choice(
329 streams: &mut OutputStreams,
330 config: &ResponseConfig,
331 options: &[sqry_nl::DisambiguationOption],
332) -> Result<()> {
333 let choice = prompt_choice(options.len())?;
334 if let Some(idx) = choice {
335 let selected = &options[idx];
336 streams.write_result(&format!(
337 "\n{} {}\n",
338 "Executing:".green().bold(),
339 selected.command
340 ))?;
341 execute_generated_command(&selected.command, config.path, config.cli)?;
342 } else {
343 streams.write_diagnostic("Cancelled.\n")?;
344 }
345 Ok(())
346}
347
348fn handle_reject_response(
351 streams: &mut OutputStreams,
352 config: &ResponseConfig,
353 reason: &str,
354 suggestions: &[String],
355) -> Result<String> {
356 if config.cli.json {
357 let output = serde_json::json!({
358 "type": "reject",
359 "reason": reason,
360 "suggestions": suggestions
361 });
362 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
363 } else {
364 streams.write_diagnostic(&format!(
365 "{} {}\n",
366 "Cannot translate:".red().bold(),
367 reason
368 ))?;
369
370 if !suggestions.is_empty() {
371 streams.write_diagnostic(&format!("\n{}:\n", "Suggestions".yellow()))?;
372 for suggestion in suggestions {
373 streams.write_diagnostic(&format!(" • {suggestion}\n"))?;
374 }
375 }
376 }
377 Ok(format!("Translation rejected: {reason}"))
378}
379
380pub fn run_ask(
388 cli: &Cli,
389 query: &str,
390 path: &str,
391 auto_execute: bool,
392 dry_run: bool,
393 threshold: f32,
394) -> Result<()> {
395 use sqry_nl::{TranslationResponse, Translator, TranslatorConfig};
396
397 let mut streams = OutputStreams::with_pager(cli.pager_config());
398
399 let translator_config = TranslatorConfig {
401 execute_threshold: threshold,
402 confirm_threshold: threshold * 0.75, ..Default::default()
404 };
405
406 let mut translator = Translator::new(translator_config)
407 .context("Failed to initialize natural language translator")?;
408
409 let response = translator.translate(query);
411
412 let config = ResponseConfig {
414 cli,
415 path,
416 auto_execute,
417 dry_run,
418 };
419
420 let reject_error = match response {
422 TranslationResponse::Execute {
423 command,
424 confidence,
425 intent,
426 ..
427 } => {
428 handle_execute_response(&mut streams, &config, &command, confidence, intent.as_str())?;
429 None
430 }
431
432 TranslationResponse::Confirm {
433 command,
434 confidence,
435 prompt,
436 } => {
437 handle_confirm_response(&mut streams, &config, &command, confidence, &prompt)?;
438 None
439 }
440
441 TranslationResponse::Disambiguate { options, prompt } => {
442 handle_disambiguate_response(&mut streams, &config, &options, &prompt)?;
443 None
444 }
445
446 TranslationResponse::Reject {
447 reason,
448 suggestions,
449 } => {
450 let error_msg = handle_reject_response(&mut streams, &config, &reason, &suggestions)?;
451 Some(error_msg)
452 }
453 };
454
455 streams.finish_checked()?;
456
457 if let Some(error_msg) = reject_error {
459 anyhow::bail!("{error_msg}");
460 }
461
462 Ok(())
463}
464
/// Arguments recovered from a generated `sqry ...` command line.
///
/// Field order is significant for the derived `Debug` output.
#[derive(Default, Debug)]
struct ParsedCommandArgs {
    /// First double-quoted positional argument (the main query or search text).
    primary: String,
    /// Value of `--language`, if present.
    language: Option<String>,
    /// Value of `--kind`, if present.
    kind: Option<String>,
    /// Value of `--limit`, if present and a valid `u32`.
    limit: Option<u32>,
    /// Value of `--path`, if present (quoted values may contain spaces).
    path_filter: Option<String>,
    /// Second double-quoted positional argument, if any (e.g. the target of `trace-path`).
    secondary: Option<String>,
    /// Value of `--max-depth`, if present and a valid `u32`.
    max_depth: Option<u32>,
}
483
/// Extracts the value that follows `flag` in a generated command line.
///
/// Only a standalone occurrence of `flag` outside double quotes counts. Flag text that appears
/// inside a quoted argument (e.g. a query mentioning `--path`) is ignored, and so are longer
/// flags that merely start with `flag` (e.g. `--path-prefix` when looking for `--path`).
///
/// The value may follow after whitespace or `=`. A double-quoted value is returned without its
/// quotes; an unterminated quote yields the rest of the line. An unquoted value is the next
/// whitespace-delimited word.
///
/// Returns `None` when the flag is absent or nothing follows it.
fn extract_flag_value(command: &str, flag: &str) -> Option<String> {
    let after_flag = find_standalone_flag(command, flag)?;
    // Accept both `--flag value` and `--flag=value`.
    let after_flag = after_flag.strip_prefix('=').unwrap_or(after_flag);

    let trimmed = after_flag.trim_start();
    if trimmed.is_empty() {
        return None;
    }

    if let Some(stripped) = trimmed.strip_prefix('"') {
        let value = stripped.find('"').map_or(stripped, |end| &stripped[..end]);
        return Some(value.to_string());
    }

    trimmed.split_whitespace().next().map(str::to_string)
}

/// Returns the text right after the first occurrence of `flag` that starts a whitespace-delimited
/// token outside double quotes and is followed by whitespace, `=`, or the end of the string.
fn find_standalone_flag<'a>(command: &'a str, flag: &str) -> Option<&'a str> {
    /// True when `rest` (the text after a candidate flag) ends the flag token.
    fn terminates_flag(rest: &str) -> bool {
        match rest.chars().next() {
            None | Some('=') => true,
            Some(c) => c.is_whitespace(),
        }
    }

    if flag.is_empty() {
        return None;
    }

    let mut in_quote = false;
    let mut at_token_start = true;

    for (idx, ch) in command.char_indices() {
        if ch == '"' {
            in_quote = !in_quote;
            at_token_start = false;
            continue;
        }
        if in_quote {
            // Text inside quotes is argument content, never a flag.
            continue;
        }
        if ch.is_whitespace() {
            at_token_start = true;
            continue;
        }

        let candidate = if at_token_start {
            command[idx..].strip_prefix(flag)
        } else {
            None
        };
        if let Some(rest) = candidate.filter(|rest| terminates_flag(rest)) {
            return Some(rest);
        }
        at_token_start = false;
    }

    None
}
513
514fn parse_generated_command(command: &str) -> Result<ParsedCommandArgs> {
516 let mut args = ParsedCommandArgs::default();
517
518 let mut quoted_strings = Vec::new();
520 let mut in_quote = false;
521 let mut current_quoted = String::new();
522
523 for c in command.chars() {
524 if c == '"' {
525 if in_quote {
526 quoted_strings.push(current_quoted.clone());
527 current_quoted.clear();
528 }
529 in_quote = !in_quote;
530 } else if in_quote {
531 current_quoted.push(c);
532 }
533 }
534
535 if let Some(primary) = quoted_strings.first() {
537 args.primary.clone_from(primary);
538 }
539
540 if let Some(secondary) = quoted_strings.get(1) {
542 args.secondary = Some(secondary.clone());
543 }
544
545 args.path_filter = extract_flag_value(command, "--path");
548
549 let parts: Vec<&str> = command.split_whitespace().collect();
551 let mut i = 0;
552 while i < parts.len() {
553 match parts[i] {
554 "--language" if i + 1 < parts.len() => {
555 args.language = Some(parts[i + 1].to_string());
556 i += 2;
557 }
558 "--kind" if i + 1 < parts.len() => {
559 args.kind = Some(parts[i + 1].to_string());
560 i += 2;
561 }
562 "--limit" if i + 1 < parts.len() => {
563 args.limit = parts[i + 1].parse().ok();
564 i += 2;
565 }
566 "--path" => {
567 i += 2;
570 }
571 "--max-depth" if i + 1 < parts.len() => {
572 args.max_depth = parts[i + 1].parse().ok();
573 i += 2;
574 }
575 _ => {
576 i += 1;
577 }
578 }
579 }
580
581 if args.primary.is_empty() {
582 anyhow::bail!("Could not extract primary argument from command: {command}");
583 }
584
585 Ok(args)
586}
587
588fn build_query_expression(args: &ParsedCommandArgs) -> String {
594 let mut expr_parts = vec![args.primary.clone()];
595
596 if let Some(lang) = &args.language
601 && !args.primary.contains("lang:")
602 && !args.primary.contains("language:")
603 {
604 expr_parts.push(format!("language:{lang}"));
605 }
606
607 if let Some(path) = &args.path_filter
609 && !args.primary.contains("path:")
610 {
611 if path.contains(' ') {
612 let escaped = path.replace('"', "\\\"");
614 expr_parts.push(format!("path:\"{escaped}\""));
615 } else {
616 expr_parts.push(format!("path:{path}"));
617 }
618 }
619
620 expr_parts.join(" ")
623}
624
625fn execute_generated_command(command: &str, path: &str, cli: &Cli) -> Result<()> {
627 let parts: Vec<&str> = command.split_whitespace().collect();
629
630 if parts.is_empty() || parts[0] != "sqry" {
631 anyhow::bail!("Invalid generated command: {command}");
632 }
633
634 if parts.len() < 2 {
635 anyhow::bail!("Generated command missing subcommand: {command}");
636 }
637
638 let subcommand = parts[1];
639
640 match subcommand {
641 "query" => {
642 let parsed = parse_generated_command(command)?;
644 let query_expr = build_query_expression(&parsed);
646 let result_limit = parsed.limit.map(|l| l as usize);
648 super::run_query(
649 cli,
650 &query_expr,
651 path,
652 false,
653 false,
654 false,
655 false,
656 None,
657 result_limit,
658 &[],
659 )?;
660 }
661 "search" => {
662 let parsed = parse_generated_command(command)?;
663 super::run_search(cli, &parsed.primary, path)?;
665 }
666 "graph" => {
667 if parts.len() < 3 {
669 anyhow::bail!("Graph command missing operation: {command}");
670 }
671 eprintln!(
673 "{}",
674 format!("Graph commands not yet auto-executable: {command}").yellow()
675 );
676 }
677 "index" => {
678 if command.contains("--status") {
679 super::run_index_status(cli, path, crate::args::MetricsFormat::Json)?;
680 } else {
681 eprintln!(
682 "{}",
683 format!("Index build not auto-executable: {command}").yellow()
684 );
685 }
686 }
687 _ => {
688 anyhow::bail!("Unsupported generated command: {subcommand}");
689 }
690 }
691
692 Ok(())
693}
694
/// Test helper: returns the text between the first pair of double quotes in `command`.
///
/// Without such a pair, it falls back to the third whitespace-separated word with surrounding
/// quotes trimmed. `_position` is currently unused.
#[cfg(test)]
fn extract_quoted_arg(command: &str, _position: usize) -> Result<String> {
    let quoted = command.find('"').and_then(|open| {
        let rest = &command[open + 1..];
        rest.find('"').map(|close| &rest[..close])
    });

    if let Some(arg) = quoted {
        return Ok(arg.to_string());
    }

    match command.split_whitespace().nth(2) {
        Some(word) => Ok(word.trim_matches('"').to_string()),
        None => anyhow::bail!("Could not extract argument from: {command}"),
    }
}
713
714fn prompt_confirmation(message: &str) -> Result<bool> {
716 if message.is_empty() {
717 eprint!("[y/N] ");
718 } else {
719 eprint!("{message} [y/N] ");
720 }
721 io::stderr().flush()?;
722
723 let mut input = String::new();
724 io::stdin().read_line(&mut input)?;
725
726 Ok(input.trim().eq_ignore_ascii_case("y") || input.trim().eq_ignore_ascii_case("yes"))
727}
728
729fn prompt_choice(max: usize) -> Result<Option<usize>> {
731 eprint!("Enter choice (1-{max}) or 'c' to cancel: ");
732 io::stderr().flush()?;
733
734 let mut input = String::new();
735 io::stdin().read_line(&mut input)?;
736
737 let trimmed = input.trim();
738 if trimmed.eq_ignore_ascii_case("c") || trimmed.is_empty() {
739 return Ok(None);
740 }
741
742 match trimmed.parse::<usize>() {
743 Ok(n) if n >= 1 && n <= max => Ok(Some(n - 1)),
744 _ => {
745 eprintln!("Invalid choice");
746 Ok(None)
747 }
748 }
749}
750
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_quoted_arg() {
        let extracted =
            extract_quoted_arg(r#"sqry query "kind:function""#, 2).expect("quoted argument");
        assert_eq!(extracted, "kind:function");
    }

    #[test]
    fn test_extract_quoted_arg_with_spaces() {
        let extracted =
            extract_quoted_arg(r#"sqry search "hello world""#, 2).expect("quoted argument");
        assert_eq!(extracted, "hello world");
    }

    #[test]
    fn test_parse_generated_command_basic() {
        let parsed = parse_generated_command(r#"sqry query "authenticate" --limit 100"#)
            .expect("command should parse");
        assert_eq!(parsed.primary, "authenticate");
        assert_eq!(parsed.limit, Some(100));
        assert_eq!(parsed.language, None);
        assert_eq!(parsed.kind, None);
    }

    #[test]
    fn test_parse_generated_command_with_all_flags() {
        let parsed = parse_generated_command(
            r#"sqry query "login" --language rust --kind function --limit 50"#,
        )
        .expect("command should parse");
        assert_eq!(parsed.primary, "login");
        assert_eq!(parsed.language.as_deref(), Some("rust"));
        assert_eq!(parsed.kind.as_deref(), Some("function"));
        assert_eq!(parsed.limit, Some(50));
    }

    #[test]
    fn test_parse_generated_command_trace_path() {
        let parsed =
            parse_generated_command(r#"sqry graph trace-path "source" "target" --max-depth 5"#)
                .expect("command should parse");
        assert_eq!(parsed.primary, "source");
        assert_eq!(parsed.secondary.as_deref(), Some("target"));
        assert_eq!(parsed.max_depth, Some(5));
    }

    #[test]
    fn test_build_query_expression_basic() {
        let expr = build_query_expression(&ParsedCommandArgs {
            primary: "authenticate".to_string(),
            ..Default::default()
        });
        assert_eq!(expr, "authenticate");
    }

    #[test]
    fn test_build_query_expression_with_predicates() {
        let expr = build_query_expression(&ParsedCommandArgs {
            primary: "kind:function login".to_string(),
            language: Some("rust".to_string()),
            kind: Some("function".to_string()),
            limit: Some(50),
            ..Default::default()
        });
        for expected in ["login", "kind:function", "language:rust"] {
            assert!(expr.contains(expected), "missing {expected:?} in {expr:?}");
        }
        assert!(!expr.contains("limit:"));
    }

    #[test]
    fn test_build_query_expression_with_path() {
        let expr = build_query_expression(&ParsedCommandArgs {
            primary: "test".to_string(),
            path_filter: Some("src/lib.rs".to_string()),
            ..Default::default()
        });
        assert!(expr.contains("path:src/lib.rs"));
    }

    #[test]
    fn test_build_query_expression_with_path_spaces() {
        let expr = build_query_expression(&ParsedCommandArgs {
            primary: "login".to_string(),
            path_filter: Some("src/api services".to_string()),
            language: Some("rust".to_string()),
            ..Default::default()
        });
        assert!(expr.contains(r#"path:"src/api services""#));
        assert!(expr.contains("language:rust"));
    }

    #[test]
    fn test_extract_flag_value_unquoted() {
        let value = extract_flag_value(r#"sqry query "test" --limit 50"#, "--limit");
        assert_eq!(value.as_deref(), Some("50"));
    }

    #[test]
    fn test_extract_flag_value_quoted() {
        let value = extract_flag_value(r#"sqry query "test" --path "src/api services""#, "--path");
        assert_eq!(value.as_deref(), Some("src/api services"));
    }

    #[test]
    fn test_extract_flag_value_not_present() {
        let value = extract_flag_value(r#"sqry query "test""#, "--limit");
        assert!(value.is_none());
    }

    #[test]
    fn test_parse_generated_command_with_path_spaces() {
        let parsed = parse_generated_command(
            r#"sqry query "login" --path "src/api services" --language rust"#,
        )
        .expect("command should parse");
        assert_eq!(parsed.primary, "login");
        assert_eq!(parsed.path_filter.as_deref(), Some("src/api services"));
        assert_eq!(parsed.language.as_deref(), Some("rust"));
    }
}