1mod formatting;
6pub(crate) mod templates;
7
8use crate::error::{AssemblerError, NlResult};
9use crate::types::{ExtractedEntities, Intent, PredicateType, TemplateType, Visibility};
10
11pub fn assemble_command(
19 intent: &Intent,
20 entities: &ExtractedEntities,
21) -> Result<String, AssemblerError> {
22 let assembler = TemplateAssembler::default();
23 assembler
24 .assemble(*intent, entities)
25 .map(|cmd| cmd.command)
26 .map_err(|e| match e {
27 crate::error::NlError::Assembler(ae) => ae,
28 _ => AssemblerError::AmbiguousIntent, })
30}
31
/// A fully assembled `sqry` command together with the template that produced it.
#[derive(Debug, Clone)]
pub struct AssembledCommand {
    /// The complete command line; its tokens are joined by single spaces.
    pub command: String,
    /// The template that was used to build `command`.
    pub template_type: TemplateType,
}
40
/// Tunable defaults and limits used by the template assembler.
#[derive(Debug, Clone)]
pub struct AssemblerConfig {
    /// `--limit` value used for queries when no explicit limit was extracted.
    pub default_limit: u32,
    /// Maximum length, in bytes, of an assembled command line.
    pub max_command_length: usize,
    /// `--max-depth` value used for trace-path commands when no depth was extracted.
    pub default_max_depth: u32,
}

impl Default for AssemblerConfig {
    /// Stock settings: 100 results, a 512-byte command cap, and a trace depth of 10.
    fn default() -> Self {
        let default_limit = 100;
        let max_command_length = 512;
        let default_max_depth = 10;

        Self {
            default_limit,
            max_command_length,
            default_max_depth,
        }
    }
}
61
62pub struct TemplateAssembler {
64 config: AssemblerConfig,
65}
66
67impl Default for TemplateAssembler {
68 fn default() -> Self {
69 Self::new(AssemblerConfig::default())
70 }
71}
72
73impl TemplateAssembler {
74 #[must_use]
76 pub fn new(config: AssemblerConfig) -> Self {
77 Self { config }
78 }
79
80 pub fn assemble(
89 &self,
90 intent: Intent,
91 entities: &ExtractedEntities,
92 ) -> NlResult<AssembledCommand> {
93 match intent {
94 Intent::SymbolQuery => self.build_query_command(entities),
95 Intent::TextSearch => self.build_search_command(entities),
96 Intent::TracePath => self.build_trace_path_command(entities),
97 Intent::FindCallers => self.build_callers_command(entities),
98 Intent::FindCallees => self.build_callees_command(entities),
99 Intent::Visualize => self.build_visualize_command(entities),
100 Intent::IndexStatus => self.build_index_status_command(entities),
101 Intent::Ambiguous => Err(AssemblerError::AmbiguousIntent.into()),
102 }
103 }
104
105 fn build_command(
106 &self,
107 parts: &[String],
108 template_type: TemplateType,
109 ) -> NlResult<AssembledCommand> {
110 let command = parts.join(" ");
111 self.validate_length(&command)?;
112
113 Ok(AssembledCommand {
114 command,
115 template_type,
116 })
117 }
118
119 fn push_languages(parts: &mut Vec<String>, languages: &[String]) {
120 for lang in languages {
121 parts.push(format!("--language {lang}"));
122 }
123 }
124
125 fn push_path(parts: &mut Vec<String>, paths: &[String]) {
126 if let Some(path) = paths.first() {
127 parts.push(format!("--path \"{}\"", formatting::escape_quotes(path)));
128 }
129 }
130
131 fn require_primary_symbol(
132 entities: &ExtractedEntities,
133 error: AssemblerError,
134 ) -> NlResult<&str> {
135 entities.primary_symbol().ok_or_else(|| error.into())
136 }
137
138 fn build_query_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
139 let query_expr = Self::build_query_expression(entities)?;
141
142 let mut parts = vec![
143 "sqry".to_string(),
144 "query".to_string(),
145 format!("\"{}\"", formatting::escape_quotes(&query_expr)),
146 ];
147
148 Self::push_languages(&mut parts, &entities.languages);
150
151 Self::push_path(&mut parts, &entities.paths);
153
154 let limit = entities.limit.unwrap_or(self.config.default_limit);
159 parts.push(format!("--limit {limit}"));
160
161 self.build_command(&parts, TemplateType::Query)
162 }
163
164 fn build_query_expression(entities: &ExtractedEntities) -> NlResult<String> {
168 let mut expr_parts = Self::collect_predicates(entities);
169
170 if expr_parts.is_empty() {
171 return Self::build_symbol_only_query(entities);
172 }
173
174 if let Some(symbol) = entities.primary_symbol()
175 && Self::should_include_symbol(entities, symbol)
176 {
177 expr_parts.push(symbol.to_string());
178 }
179
180 Ok(expr_parts.join(" AND "))
183 }
184
185 fn build_search_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
186 let pattern = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
187
188 let mut parts = vec![
189 "sqry".to_string(),
190 "search".to_string(),
191 format!("\"{}\"", formatting::escape_quotes(pattern)),
192 ];
193
194 Self::push_languages(&mut parts, &entities.languages);
196
197 Self::push_path(&mut parts, &entities.paths);
199
200 self.build_command(&parts, TemplateType::Search)
201 }
202
203 fn build_trace_path_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
204 let from = entities
205 .from_symbol
206 .as_deref()
207 .or_else(|| entities.symbols.first().map(String::as_str))
208 .ok_or(AssemblerError::MissingTracePath)?;
209
210 let to = entities
211 .to_symbol
212 .as_deref()
213 .or_else(|| entities.symbols.get(1).map(String::as_str))
214 .ok_or(AssemblerError::MissingTracePath)?;
215
216 let mut parts = vec![
217 "sqry".to_string(),
218 "graph".to_string(),
219 "trace-path".to_string(),
220 format!("\"{}\"", formatting::escape_quotes(from)),
221 format!("\"{}\"", formatting::escape_quotes(to)),
222 ];
223
224 let depth = entities.depth.unwrap_or(self.config.default_max_depth);
226 parts.push(format!("--max-depth {depth}"));
227
228 let command = parts.join(" ");
229 self.validate_length(&command)?;
230
231 Ok(AssembledCommand {
232 command,
233 template_type: TemplateType::TracePath,
234 })
235 }
236
237 fn build_callers_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
238 let symbol = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
239
240 let mut parts = vec![
241 "sqry".to_string(),
242 "graph".to_string(),
243 "direct-callers".to_string(),
244 format!("\"{}\"", formatting::escape_quotes(symbol)),
245 ];
246
247 if let Some(lang) = entities.languages.first() {
249 parts.push(format!("--language {lang}"));
250 }
251
252 self.build_command(&parts, TemplateType::GraphCallers)
253 }
254
255 fn build_callees_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
256 let symbol = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
257
258 let mut parts = vec![
259 "sqry".to_string(),
260 "graph".to_string(),
261 "direct-callees".to_string(),
262 format!("\"{}\"", formatting::escape_quotes(symbol)),
263 ];
264
265 if let Some(lang) = entities.languages.first() {
267 parts.push(format!("--language {lang}"));
268 }
269
270 self.build_command(&parts, TemplateType::GraphCallees)
271 }
272
273 fn build_visualize_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
274 let symbol = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
275
276 let relation = entities.relation.as_deref().unwrap_or("call");
277
278 let mut parts = vec![
279 "sqry".to_string(),
280 "visualize".to_string(),
281 format!("--relation {}", relation),
282 format!("--symbol \"{}\"", formatting::escape_quotes(symbol)),
283 ];
284
285 if let Some(format) = entities.format {
287 parts.push(format!("--format {}", format.as_str()));
288 }
289
290 self.build_command(&parts, TemplateType::Visualize)
291 }
292
293 fn build_index_status_command(
294 &self,
295 entities: &ExtractedEntities,
296 ) -> NlResult<AssembledCommand> {
297 let mut parts = vec![
298 "sqry".to_string(),
299 "index".to_string(),
300 "--status".to_string(),
301 ];
302
303 if let Some(path) = entities.paths.first() {
305 parts.push(format!("--path \"{}\"", formatting::escape_quotes(path)));
306 }
307
308 if entities.format == Some(crate::types::OutputFormat::Json) {
310 parts.push("--json".to_string());
311 }
312
313 self.build_command(&parts, TemplateType::IndexStatus)
314 }
315
316 fn collect_predicates(entities: &ExtractedEntities) -> Vec<String> {
317 let mut expr_parts = Vec::new();
318
319 if let Some(trait_name) = &entities.impl_trait {
321 expr_parts.push(format!("impl:{trait_name}"));
322 }
323
324 if entities.predicate_type == Some(PredicateType::Duplicates) {
326 let arg = entities.predicate_arg.as_deref().unwrap_or("body");
327 expr_parts.push(format!("duplicates:{arg}"));
328 }
329
330 if entities.predicate_type == Some(PredicateType::Circular) {
332 let arg = entities.predicate_arg.as_deref().unwrap_or("calls");
333 expr_parts.push(format!("circular:{arg}"));
334 }
335
336 if entities.predicate_type == Some(PredicateType::Unused) {
338 expr_parts.push("unused:".to_string());
339 }
340
341 if let Some(visibility) = entities.visibility {
343 match visibility {
344 Visibility::Public => expr_parts.push("visibility:public".to_string()),
345 Visibility::Private => expr_parts.push("visibility:private".to_string()),
346 }
347 }
348
349 if entities.is_async == Some(true) {
351 expr_parts.push("async:true".to_string());
352 }
353
354 if entities.is_unsafe == Some(true) {
356 expr_parts.push("unsafe:true".to_string());
357 }
358
359 if let Some(kind) = entities.kind {
361 expr_parts.push(format!("kind:{}", kind.as_str()));
362 }
363
364 expr_parts
365 }
366
367 fn build_symbol_only_query(entities: &ExtractedEntities) -> NlResult<String> {
368 match entities.primary_symbol() {
369 Some(symbol) => Ok(symbol.to_string()),
370 None if entities.kind.is_some() => Ok("*".to_string()),
371 None => Err(AssemblerError::MissingSymbol.into()),
372 }
373 }
374
375 fn should_include_symbol(entities: &ExtractedEntities, symbol: &str) -> bool {
376 entities.impl_trait.is_none() || symbol != entities.impl_trait.as_deref().unwrap_or("")
377 }
378
379 fn validate_length(&self, command: &str) -> NlResult<()> {
380 if command.len() > self.config.max_command_length {
381 return Err(AssemblerError::CommandTooLong {
382 len: command.len(),
383 max: self.config.max_command_length,
384 }
385 .into());
386 }
387 Ok(())
388 }
389}
390
#[cfg(test)]
mod tests {
    // Unit tests for the template assembler: one or more tests per intent,
    // plus the error paths (missing symbol, ambiguous intent, missing trace
    // endpoint, over-long command).
    use super::*;
    use crate::types::{OutputFormat, PredicateType, SymbolKind, Visibility};

    #[test]
    fn test_build_query_basic() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("authenticate".to_string());

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.starts_with("sqry query"));
        assert!(result.command.contains("\"authenticate\""));
    }

    #[test]
    fn test_build_query_with_options() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("foo".to_string());
        entities.languages.push("rust".to_string());
        entities.kind = Some(SymbolKind::Function);
        entities.limit = Some(10);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("--language rust"));
        assert!(result.command.contains("kind:function"));
        assert!(result.command.contains("--limit 10"));
    }

    #[test]
    fn test_build_callers() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("login".to_string());

        let result = assembler.assemble(Intent::FindCallers, &entities).unwrap();
        assert!(result.command.contains("sqry graph direct-callers"));
        assert!(result.command.contains("\"login\""));
    }

    #[test]
    fn test_build_trace_path() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.from_symbol = Some("login".to_string());
        entities.to_symbol = Some("database".to_string());

        let result = assembler.assemble(Intent::TracePath, &entities).unwrap();
        assert!(result.command.contains("sqry graph trace-path"));
        assert!(result.command.contains("\"login\""));
        assert!(result.command.contains("\"database\""));
    }

    #[test]
    fn test_missing_symbol_error() {
        let assembler = TemplateAssembler::default();
        let entities = ExtractedEntities::new();

        let result = assembler.assemble(Intent::SymbolQuery, &entities);
        assert!(matches!(
            result,
            Err(crate::error::NlError::Assembler(
                AssemblerError::MissingSymbol
            ))
        ));
    }

    #[test]
    fn test_ambiguous_intent_error() {
        let assembler = TemplateAssembler::default();
        let entities = ExtractedEntities::new();

        let result = assembler.assemble(Intent::Ambiguous, &entities);
        assert!(matches!(
            result,
            Err(crate::error::NlError::Assembler(
                AssemblerError::AmbiguousIntent
            ))
        ));
    }

    #[test]
    fn test_build_query_impl_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.impl_trait = Some("Future".to_string());
        entities.predicate_type = Some(PredicateType::Impl);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("\"impl:Future\""));
        assert!(result.command.starts_with("sqry query"));
    }

    #[test]
    fn test_build_query_duplicates_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.predicate_type = Some(PredicateType::Duplicates);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("\"duplicates:body\""));
    }

    #[test]
    fn test_build_query_duplicates_signature() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.predicate_type = Some(PredicateType::Duplicates);
        entities.predicate_arg = Some("signature".to_string());

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("\"duplicates:signature\""));
    }

    #[test]
    fn test_build_query_circular_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.predicate_type = Some(PredicateType::Circular);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("\"circular:calls\""));
    }

    #[test]
    fn test_build_query_unused_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.predicate_type = Some(PredicateType::Unused);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("\"unused:\""));
    }

    #[test]
    fn test_build_query_visibility_public() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.visibility = Some(Visibility::Public);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("visibility:public"));
    }

    #[test]
    fn test_build_query_async_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.is_async = Some(true);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("async:true"));
    }

    #[test]
    fn test_build_query_unsafe_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.is_unsafe = Some(true);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("unsafe:true"));
    }

    #[test]
    fn test_build_query_combined_predicates() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.visibility = Some(Visibility::Public);
        entities.is_async = Some(true);

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("visibility:public"));
        assert!(result.command.contains("async:true"));
    }

    #[test]
    fn test_build_query_impl_with_symbol_no_duplicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.impl_trait = Some("Iterator".to_string());
        entities.predicate_type = Some(PredicateType::Impl);
        entities.symbols.push("Iterator".to_string());

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        assert!(result.command.contains("\"impl:Iterator\""));
        // The symbol equals the trait name, so it must not be appended again.
        let count = result.command.matches("Iterator").count();
        assert_eq!(
            count, 1,
            "Iterator should only appear once in: {}",
            result.command
        );
    }

    #[test]
    fn test_build_query_kind_only_uses_kind_predicate() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.kind = Some(SymbolKind::Function);
        entities.limit = None;

        let result = assembler.assemble(Intent::SymbolQuery, &entities).unwrap();
        // A kind alone is a complete predicate; no wildcard symbol is emitted.
        assert!(result.command.contains("\"kind:function\""));
        assert!(!result.command.contains('*'));
        assert!(result.command.contains("--limit 100"));
        assert!(matches!(result.template_type, TemplateType::Query));
    }

    #[test]
    fn test_build_trace_path_falls_back_to_symbols() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.from_symbol = None;
        entities.to_symbol = None;
        entities.depth = None;
        entities.symbols.push("parse".to_string());
        entities.symbols.push("emit".to_string());

        let result = assembler.assemble(Intent::TracePath, &entities).unwrap();
        assert_eq!(
            result.command,
            "sqry graph trace-path \"parse\" \"emit\" --max-depth 10"
        );
        assert!(matches!(result.template_type, TemplateType::TracePath));
    }

    #[test]
    fn test_trace_path_missing_endpoint_error() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.from_symbol = None;
        entities.to_symbol = None;
        entities.symbols.push("only_one".to_string());

        let result = assembler.assemble(Intent::TracePath, &entities);
        assert!(matches!(
            result,
            Err(crate::error::NlError::Assembler(
                AssemblerError::MissingTracePath
            ))
        ));
    }

    #[test]
    fn test_build_callees_forwards_first_language_only() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("login".to_string());
        entities.languages.push("rust".to_string());
        entities.languages.push("python".to_string());

        let result = assembler.assemble(Intent::FindCallees, &entities).unwrap();
        assert!(result
            .command
            .starts_with("sqry graph direct-callees \"login\""));
        assert!(result.command.contains("--language rust"));
        assert!(!result.command.contains("--language python"));
        assert!(matches!(result.template_type, TemplateType::GraphCallees));
    }

    #[test]
    fn test_build_visualize_defaults_to_call_relation() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("main".to_string());
        entities.relation = None;
        entities.format = None;

        let result = assembler.assemble(Intent::Visualize, &entities).unwrap();
        assert_eq!(
            result.command,
            "sqry visualize --relation call --symbol \"main\""
        );
        assert!(matches!(result.template_type, TemplateType::Visualize));
    }

    #[test]
    fn test_build_index_status_json() {
        let assembler = TemplateAssembler::default();
        let mut entities = ExtractedEntities::new();
        entities.format = Some(OutputFormat::Json);

        let result = assembler.assemble(Intent::IndexStatus, &entities).unwrap();
        assert!(result.command.starts_with("sqry index --status"));
        assert!(result.command.ends_with("--json"));
        assert!(matches!(result.template_type, TemplateType::IndexStatus));
    }

    #[test]
    fn test_command_too_long_error() {
        let assembler = TemplateAssembler::new(AssemblerConfig {
            max_command_length: 10,
            ..AssemblerConfig::default()
        });
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("authenticate".to_string());

        let result = assembler.assemble(Intent::SymbolQuery, &entities);
        assert!(matches!(
            result,
            Err(crate::error::NlError::Assembler(
                AssemblerError::CommandTooLong { max: 10, .. }
            ))
        ));
    }
}
595
#[cfg(test)]
mod predicate_assembly_tests {
    // End-to-end checks: natural-language text -> extracted entities -> query.
    use super::*;
    use crate::extractor::extract_entities;

    /// Extracts entities from `text`, assembles a symbol query, and returns the
    /// command string. Panics (failing the test) with the error if assembly fails.
    fn assemble_query(text: &str) -> String {
        let entities = extract_entities(text);
        match TemplateAssembler::default().assemble(Intent::SymbolQuery, &entities) {
            Ok(assembled) => assembled.command,
            Err(err) => panic!("assembly failed for {text:?}: {err:?}"),
        }
    }

    #[test]
    fn test_async_functions_assembly() {
        let command = assemble_query("find async functions");
        assert!(command.contains("async:true"));
    }

    #[test]
    fn test_unsafe_functions_assembly() {
        let command = assemble_query("find unsafe functions");
        assert!(command.contains("unsafe:true"));
    }

    #[test]
    fn test_public_async_functions_assembly() {
        let command = assemble_query("find public async functions");
        assert!(command.contains("visibility:public"));
        assert!(command.contains("async:true"));
    }
}