use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::Instant;
use anyhow::Result;
use sqry_nl::{DisambiguationOption, TranslationResponse, Translator, TranslatorConfig};
use crate::engine::{canonicalize_in_workspace, engine_for_workspace};
use crate::execution::utils::duration_to_ms;
use crate::execution::{NlDisambiguationOption, NlTranslationData, ToolExecution};
use crate::tools::SqryAskParams;
/// Report whether `command` contains a standalone `--path` flag outside of quoted text.
///
/// A candidate must start at the beginning of the command or right after whitespace, and
/// must be followed by whitespace, `=`, or the end of the command. So `--pathlike` or
/// `"find --path usage"` do not count. Quote/escape tracking is done by `QuoteScanState`;
/// scanning stops at the first match.
fn has_path_flag_outside_quotes(command: &str) -> bool {
    let flag: Vec<char> = "--path".chars().collect();
    let chars: Vec<char> = command.chars().collect();
    let mut scanner = QuoteScanState::default();
    chars.iter().enumerate().any(|(idx, &ch)| {
        // Characters consumed as quoting syntax can never begin the flag.
        let consumed = scanner.advance(ch);
        !consumed
            && !scanner.in_quotes()
            && matches_path_flag_at(&chars, idx, &flag)
            && has_path_flag_boundaries(&chars, idx, flag.len())
    })
}
/// Incremental scanner that tracks shell-style quoting while a command string is walked
/// one character at a time.
#[derive(Default)]
struct QuoteScanState {
    /// Currently inside a `"…"` span.
    in_double_quotes: bool,
    /// Currently inside a `'…'` span.
    in_single_quotes: bool,
    /// The previous character was an escaping backslash, so the next one is literal.
    prev_was_escape: bool,
}

impl QuoteScanState {
    /// Feed the next character `c` into the scanner.
    ///
    /// Returns `true` when `c` was consumed as quoting syntax: a quote delimiter, an
    /// escaping backslash, or the character such a backslash escapes. Returns `false` for
    /// ordinary content.
    ///
    /// Single quotes follow POSIX shell rules. Every character between `'` delimiters is
    /// literal, so a backslash there does not escape the closing quote. Outside single
    /// quotes, a backslash escapes the following character.
    fn advance(&mut self, c: char) -> bool {
        if self.prev_was_escape {
            self.prev_was_escape = false;
            return true;
        }
        if self.in_single_quotes {
            // Only the closing quote is special inside single quotes.
            if c == '\'' {
                self.in_single_quotes = false;
                return true;
            }
            return false;
        }
        match c {
            '\\' => {
                self.prev_was_escape = true;
                true
            }
            '"' => {
                self.in_double_quotes = !self.in_double_quotes;
                true
            }
            '\'' if !self.in_double_quotes => {
                self.in_single_quotes = true;
                true
            }
            _ => false,
        }
    }

    /// Whether the scanner is currently inside a single- or double-quoted span.
    fn in_quotes(&self) -> bool {
        self.in_double_quotes || self.in_single_quotes
    }
}
/// Whether `pattern` occurs in `chars` starting exactly at index `offset`.
///
/// Returns `false` when the pattern would extend past the end of `chars`.
fn matches_path_flag_at(chars: &[char], offset: usize, pattern: &[char]) -> bool {
    chars
        .get(offset..)
        .is_some_and(|tail| tail.starts_with(pattern))
}
/// Check that a flag of length `pattern_len` at `offset` forms a whole token.
///
/// It must start at the beginning of `chars` or after whitespace. It must end at the end
/// of `chars`, or be followed by whitespace or `=` (as in `--path=src`).
fn has_path_flag_boundaries(chars: &[char], offset: usize, pattern_len: usize) -> bool {
    let end = offset + pattern_len;
    let starts_token = match offset {
        0 => true,
        _ => chars[offset - 1].is_whitespace(),
    };
    let ends_token = if end == chars.len() {
        true
    } else {
        let next = chars[end];
        next.is_whitespace() || next == '='
    };
    starts_token && ends_token
}
/// Append a `--path` flag scoping `command` to `scoped_path`, relative to `workspace_root`.
///
/// The command is returned unchanged in two cases: when `scoped_path` equals
/// `workspace_root`, or when the command already carries a `--path` flag outside quoted
/// text. Otherwise the path is appended as a double-quoted argument, forward-slashed and
/// made relative to the workspace root. If `scoped_path` is not under `workspace_root`,
/// it is used as-is.
///
/// Any `\` or `"` inside the path is backslash-escaped so the appended quoting stays
/// balanced for shell-style parsers (including `has_path_flag_outside_quotes`).
fn augment_command_with_path(command: &str, scoped_path: &Path, workspace_root: &Path) -> String {
    if scoped_path == workspace_root || has_path_flag_outside_quotes(command) {
        return command.to_string();
    }
    let relative_path = scoped_path
        .strip_prefix(workspace_root)
        .unwrap_or(scoped_path);
    let display_path = crate::execution::symbol_utils::path_to_forward_slash(relative_path);
    // Escape backslashes first so the escapes added for quotes are not doubled.
    let quoted_path = display_path.replace('\\', "\\\\").replace('"', "\\\"");
    format!("{command} --path \"{quoted_path}\"")
}
/// Turn the tool's `path` argument into the optional workspace hint passed to
/// `engine_for_workspace`.
///
/// `"."` yields `None`; any other value is returned unchanged as a `PathBuf`.
fn resolve_workspace_path(path: &str) -> Option<PathBuf> {
    match path {
        "." => None,
        explicit => Some(PathBuf::from(explicit)),
    }
}
pub fn execute_sqry_ask(args: &SqryAskParams) -> Result<ToolExecution<NlTranslationData>> {
let start = Instant::now();
let workspace_path = resolve_workspace_path(&args.path);
let engine = engine_for_workspace(workspace_path.as_ref())?;
let workspace_root = engine.workspace_root();
let scoped_path = canonicalize_in_workspace(&args.path, workspace_root)?;
tracing::debug!(
query = %args.query,
path = %args.path,
scoped_path = %scoped_path.display(),
workspace = %workspace_root.display(),
"Executing sqry_ask tool"
);
let mut translator = build_translator(&scoped_path)?;
let response = translator.translate(&args.query);
let mut data = build_translation_data(response, &scoped_path, workspace_root);
if args.execute
&& let Some(cmd_str) = &data.command
{
tracing::debug!(command = %cmd_str, "Executing translated sqry command");
let parts: Vec<&str> = cmd_str.split_whitespace().collect();
if !parts.is_empty() {
let bin = parts[0];
let cmd_args = &parts[1..];
let output = Command::new(bin)
.args(cmd_args)
.current_dir(workspace_root)
.output();
match output {
Ok(out) => {
let stdout = String::from_utf8_lossy(&out.stdout).to_string();
let stderr = String::from_utf8_lossy(&out.stderr).to_string();
if out.status.success() {
data.execution_output = Some(stdout);
} else {
data.execution_output = Some(format!("Error: {stderr}\n{stdout}"));
}
}
Err(e) => {
data.execution_output = Some(format!("Failed to execute command: {e}"));
}
}
}
}
tracing::debug!(
response_type = %data.response_type,
"sqry_ask tool completed"
);
Ok(ToolExecution {
data,
used_index: false,
used_graph: false,
graph_metadata: None,
execution_ms: duration_to_ms(start.elapsed()),
next_page_token: None,
total: Some(1),
truncated: Some(false),
candidates_scanned: None,
workspace_path: crate::execution::symbol_utils::path_to_forward_slash(workspace_root),
})
}
/// Build an NL `Translator` whose working directory is `scoped_path` (forward-slashed).
/// All other settings come from `TranslatorConfig::default()`.
///
/// # Errors
///
/// Propagates any error returned by `Translator::new`, converted into `anyhow::Error`.
fn build_translator(scoped_path: &Path) -> Result<Translator> {
    let working_directory = crate::execution::symbol_utils::path_to_forward_slash(scoped_path);
    let config = TranslatorConfig {
        working_directory: Some(working_directory),
        ..TranslatorConfig::default()
    };
    Translator::new(config).map_err(Into::into)
}
/// Convert a translator response into the tool's `NlTranslationData` payload.
///
/// The `Execute`, `Confirm` and `Disambiguate` builders receive `scoped_path` and
/// `workspace_root` so their commands can be scoped. `Reject` responses carry no command.
fn build_translation_data(
    response: TranslationResponse,
    scoped_path: &Path,
    workspace_root: &Path,
) -> NlTranslationData {
    match response {
        TranslationResponse::Reject {
            reason,
            suggestions,
        } => build_reject_data(reason, suggestions),
        TranslationResponse::Disambiguate { options, prompt } => {
            build_disambiguate_data(options, prompt, scoped_path, workspace_root)
        }
        TranslationResponse::Confirm {
            command,
            confidence,
            prompt,
        } => build_confirm_data(&command, confidence, &prompt, scoped_path, workspace_root),
        TranslationResponse::Execute {
            command,
            confidence,
            intent,
            ..
        } => {
            let intent_name = intent.as_str();
            build_execute_data(&command, confidence, intent_name, scoped_path, workspace_root)
        }
    }
}
/// Build the payload for a directly executable translation.
///
/// `command` is scoped with `augment_command_with_path` before being stored alongside the
/// translator's `confidence` and `intent`. No prompt, reason, suggestions or options are
/// set.
fn build_execute_data(
    command: &str,
    confidence: f32,
    intent: &str,
    scoped_path: &Path,
    workspace_root: &Path,
) -> NlTranslationData {
    let scoped = augment_command_with_path(command, scoped_path, workspace_root);
    NlTranslationData {
        response_type: String::from("execute"),
        command: Some(scoped),
        intent: Some(intent.to_owned()),
        confidence: Some(confidence),
        prompt: None,
        reason: None,
        options: Vec::new(),
        suggestions: Vec::new(),
        execution_output: None,
    }
}
/// Build the payload for a translation that needs user confirmation before running.
///
/// `command` is scoped with `augment_command_with_path`. Occurrences of the original
/// command inside `prompt` are rewritten to the scoped command, so the prompt shows what
/// would actually run.
fn build_confirm_data(
    command: &str,
    confidence: f32,
    prompt: &str,
    scoped_path: &Path,
    workspace_root: &Path,
) -> NlTranslationData {
    let scoped_command = augment_command_with_path(command, scoped_path, workspace_root);
    // `str::replace` with an empty pattern inserts the replacement between every character,
    // so only rewrite the prompt when there is a real, changed command to substitute.
    let scoped_prompt = if command.is_empty() || scoped_command == command {
        prompt.to_string()
    } else {
        prompt.replace(command, &scoped_command)
    };
    NlTranslationData {
        response_type: "confirm".to_string(),
        command: Some(scoped_command),
        confidence: Some(confidence),
        intent: None,
        prompt: Some(scoped_prompt),
        reason: None,
        suggestions: Vec::new(),
        options: Vec::new(),
        execution_output: None,
    }
}
/// Build the payload for an ambiguous query.
///
/// It carries the translator's `prompt` and its candidate `options`. Each option's command
/// is scoped via `build_disambiguation_options`. No single command is selected.
fn build_disambiguate_data(
    options: Vec<DisambiguationOption>,
    prompt: String,
    scoped_path: &Path,
    workspace_root: &Path,
) -> NlTranslationData {
    let scoped_options = build_disambiguation_options(options, scoped_path, workspace_root);
    NlTranslationData {
        response_type: String::from("disambiguate"),
        prompt: Some(prompt),
        options: scoped_options,
        command: None,
        confidence: None,
        intent: None,
        reason: None,
        suggestions: Vec::new(),
        execution_output: None,
    }
}
/// Convert translator disambiguation options into tool options, in the same order.
///
/// Each option's command is scoped with `augment_command_with_path`, and its intent is
/// rendered as a string.
fn build_disambiguation_options(
    options: Vec<DisambiguationOption>,
    scoped_path: &Path,
    workspace_root: &Path,
) -> Vec<NlDisambiguationOption> {
    let mut scoped_options = Vec::with_capacity(options.len());
    for option in options {
        let command = augment_command_with_path(&option.command, scoped_path, workspace_root);
        scoped_options.push(NlDisambiguationOption {
            command,
            intent: option.intent.as_str().to_owned(),
            description: option.description,
            confidence: option.confidence,
        });
    }
    scoped_options
}
/// Build the payload for a query the translator rejected.
///
/// Only `reason` and the translator's `suggestions` are populated; all command-related
/// fields are empty.
fn build_reject_data(reason: String, suggestions: Vec<String>) -> NlTranslationData {
    NlTranslationData {
        response_type: String::from("reject"),
        reason: Some(reason),
        suggestions,
        command: None,
        confidence: None,
        intent: None,
        prompt: None,
        options: Vec::new(),
        execution_output: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Every response type `execute_sqry_ask` is allowed to report.
    const VALID_RESPONSE_TYPES: [&str; 4] = ["execute", "confirm", "disambiguate", "reject"];

    /// Call `augment_command_with_path` with plain string paths.
    fn augment(command: &str, scoped: &str, workspace: &str) -> String {
        let workspace = PathBuf::from(workspace);
        let scoped = PathBuf::from(scoped);
        augment_command_with_path(command, &scoped, &workspace)
    }

    /// Initialise the shared engine cache and report whether a default workspace engine can
    /// be built. The integration tests below return early (and pass) when it cannot.
    fn engine_ready() -> bool {
        crate::engine::init_engine_cache(std::num::NonZeroUsize::new(4).unwrap());
        engine_for_workspace(None).is_ok()
    }

    /// Parameters for a translation-only (`execute: false`) request.
    fn ask(query: &str, path: &str) -> SqryAskParams {
        SqryAskParams {
            query: query.to_string(),
            path: path.to_string(),
            execute: false,
        }
    }

    #[test]
    fn test_augment_command_same_as_workspace() {
        let command = "sqry query \"kind:function\"";
        assert_eq!(
            augment(command, "/workspace", "/workspace"),
            command,
            "Should not modify command when path == workspace"
        );
    }

    #[test]
    fn test_augment_command_with_subdirectory() {
        let command = "sqry query \"kind:function\"";
        assert_eq!(
            augment(command, "/workspace/src/lib", "/workspace"),
            "sqry query \"kind:function\" --path \"src/lib\"",
            "Should append relative --path"
        );
    }

    #[test]
    fn test_augment_command_already_has_path() {
        let command = "sqry query \"kind:function\" --path \"other\"";
        let result = augment(command, "/workspace/src", "/workspace");
        assert_eq!(result, command, "Should not add --path if already present");
    }

    #[test]
    fn test_augment_command_with_spaces_in_path() {
        let result = augment(
            "sqry query \"kind:function\"",
            "/workspace/my project/src",
            "/workspace",
        );
        assert!(
            result.contains("--path \"my project/src\""),
            "Path with spaces should be quoted: {result}"
        );
    }

    #[test]
    fn test_augment_command_path_in_query_text() {
        let result = augment(
            "sqry query \"find --path flag usage\"",
            "/workspace/src/lib",
            "/workspace",
        );
        assert!(
            result.contains("--path \"src/lib\""),
            "Should append --path when it only appears inside query: {result}"
        );
    }

    #[test]
    fn test_augment_command_path_in_single_quotes() {
        let result = augment("sqry query 'find --path'", "/workspace/src", "/workspace");
        assert!(
            result.contains("--path \"src\""),
            "Should append --path when it only appears inside single quotes: {result}"
        );
    }

    #[test]
    fn test_has_path_flag_no_path() {
        assert!(!has_path_flag_outside_quotes("sqry query \"kind:function\""));
    }

    #[test]
    fn test_has_path_flag_real_flag() {
        assert!(has_path_flag_outside_quotes(
            "sqry query \"kind:function\" --path \"src\""
        ));
    }

    #[test]
    fn test_has_path_flag_with_equals() {
        assert!(has_path_flag_outside_quotes(
            "sqry query \"kind:function\" --path=\"src\""
        ));
    }

    #[test]
    fn test_has_path_flag_inside_double_quotes() {
        assert!(!has_path_flag_outside_quotes("sqry query \"find --path usage\""));
    }

    #[test]
    fn test_has_path_flag_inside_single_quotes() {
        assert!(!has_path_flag_outside_quotes("sqry query 'find --path'"));
    }

    #[test]
    fn test_has_path_flag_escaped_quote() {
        assert!(!has_path_flag_outside_quotes(
            "sqry query \"find \\\"--path\\\" usage\""
        ));
    }

    #[test]
    fn test_has_path_flag_partial_match() {
        assert!(!has_path_flag_outside_quotes(
            "sqry query \"kind:function\" --pathlike \"src\""
        ));
    }

    #[test]
    #[serial_test::serial(engine_cache)]
    #[serial_test::serial(workspace_env)]
    fn test_execute_sqry_ask_basic() {
        if !engine_ready() {
            return;
        }
        let result = execute_sqry_ask(&ask("find public functions", "."));
        assert!(result.is_ok());
        let execution = result.unwrap();
        let response_type = execution.data.response_type.as_str();
        assert!(
            VALID_RESPONSE_TYPES.contains(&response_type),
            "Unexpected response type: {}",
            response_type
        );
    }

    #[test]
    #[serial_test::serial(engine_cache)]
    #[serial_test::serial(workspace_env)]
    fn test_execute_sqry_ask_response_types() {
        if !engine_ready() {
            return;
        }
        // The translator is heuristic, so only the validity of each response type is
        // asserted, not a specific type per query.
        let queries = ["find all public functions", "show me methods", "xyz123"];
        for query in queries {
            let result = execute_sqry_ask(&ask(query, "."));
            assert!(
                result.is_ok(),
                "Query '{}' should not error: {:?}",
                query,
                result.as_ref().err()
            );
            let execution = result.unwrap();
            let response_type = execution.data.response_type.as_str();
            assert!(
                VALID_RESPONSE_TYPES.contains(&response_type),
                "Query '{}' produced invalid response type: {}",
                query,
                response_type
            );
        }
    }

    #[test]
    #[serial_test::serial(engine_cache)]
    #[serial_test::serial(workspace_env)]
    fn test_execute_sqry_ask_path_validation() {
        if !engine_ready() {
            return;
        }
        assert!(execute_sqry_ask(&ask("find functions", ".")).is_ok());
        assert!(execute_sqry_ask(&ask("find functions", "/etc/passwd")).is_err());
    }
}