1use super::{GroupingResponse, SemanticGroup};
2use std::collections::HashSet;
3use std::time::Duration;
4use tokio::process::Command;
5
/// Hard cap on raw bytes read from an LLM subprocess's stdout (1 MiB).
const MAX_RESPONSE_BYTES: usize = 1_048_576;
/// Hard cap on the size of the JSON payload extracted from a response (100 KiB).
const MAX_JSON_SIZE: usize = 102_400;
/// Maximum number of groups accepted from a single LLM response.
const MAX_GROUPS: usize = 20;
/// Maximum number of file/hunk change entries kept per group.
const MAX_CHANGES_PER_GROUP: usize = 200;
/// Maximum length, in characters, of a group label.
const MAX_LABEL_LEN: usize = 80;
/// Maximum length, in characters, of a group description.
const MAX_DESC_LEN: usize = 500;
18
/// Which LLM CLI backend to shell out to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmBackend {
    /// Anthropic's `claude` CLI.
    Claude,
    /// GitHub's `copilot` CLI.
    Copilot,
}
25
26pub async fn request_grouping_with_timeout(
28 backend: LlmBackend,
29 model: &str,
30 summaries: &str,
31) -> anyhow::Result<Vec<SemanticGroup>> {
32 let model = model.to_string();
33 tokio::time::timeout(
34 Duration::from_secs(60),
35 request_grouping(backend, &model, summaries),
36 )
37 .await
38 .map_err(|_| anyhow::anyhow!("LLM timed out after 60s"))?
39}
40
41pub async fn request_grouping(
47 backend: LlmBackend,
48 model: &str,
49 hunk_summaries: &str,
50) -> anyhow::Result<Vec<SemanticGroup>> {
51 let prompt = format!(
52 "Group these code changes by semantic intent at the HUNK level. \
53 Related hunks across different files should be in the same group.\n\
54 Return ONLY valid JSON.\n\
55 Schema: {{\"groups\": [{{\"label\": \"short name\", \"description\": \"one sentence\", \
56 \"changes\": [{{\"file\": \"path\", \"hunks\": [0, 1]}}]}}]}}\n\
57 Rules:\n\
58 - Every hunk of every file must appear in exactly one group\n\
59 - Use 2-5 groups (fewer for small changesets)\n\
60 - Labels should describe the PURPOSE (e.g. \"Auth refactor\", \"Test coverage\")\n\
61 - The \"hunks\" array contains 0-based hunk indices as shown in HUNK N: headers\n\
62 - A single file's hunks may be split across different groups if they serve different purposes\n\n\
63 Changed files and hunks:\n{hunk_summaries}",
64 );
65
66 let output = match backend {
67 LlmBackend::Claude => invoke_claude(&prompt, model).await?,
68 LlmBackend::Copilot => invoke_copilot(&prompt, model).await?,
69 };
70
71 let json_str = extract_json(&output)?;
73
74 if json_str.len() > MAX_JSON_SIZE {
76 anyhow::bail!(
77 "LLM JSON response too large ({} bytes, max {})",
78 json_str.len(),
79 MAX_JSON_SIZE
80 );
81 }
82
83 let response: GroupingResponse = serde_json::from_str(&json_str)?;
84
85 let known_files: HashSet<&str> = hunk_summaries
87 .lines()
88 .filter_map(|line| {
89 let line = line.trim();
90 if let Some(rest) = line.strip_prefix("FILE: ") {
91 let end = rest.find(" (")?;
92 Some(&rest[..end])
93 } else {
94 None
95 }
96 })
97 .collect();
98
99 let validated_groups: Vec<SemanticGroup> = response
101 .groups
102 .into_iter()
103 .take(MAX_GROUPS) .map(|group| {
105 let valid_changes: Vec<super::GroupedChange> = group
106 .changes()
107 .into_iter()
108 .filter(|change| {
109 let known = known_files.contains(change.file.as_str());
111 let safe = !change.file.contains("..") && !change.file.starts_with('/');
113 if !safe {
114 tracing::warn!("Rejected LLM file path with traversal: {}", change.file);
115 }
116 known && safe
117 })
118 .take(MAX_CHANGES_PER_GROUP) .collect();
120 SemanticGroup::new(
122 truncate_string(&group.label, MAX_LABEL_LEN),
123 truncate_string(&group.description, MAX_DESC_LEN),
124 valid_changes,
125 )
126 })
127 .filter(|group| !group.changes().is_empty())
128 .collect();
129
130 Ok(validated_groups)
131}
132
133pub async fn request_incremental_grouping(
138 backend: LlmBackend,
139 model: &str,
140 summaries: &str,
141) -> anyhow::Result<Vec<SemanticGroup>> {
142 let model = model.to_string();
143 tokio::time::timeout(
144 Duration::from_secs(60),
145 request_incremental(backend, &model, summaries),
146 )
147 .await
148 .map_err(|_| anyhow::anyhow!("LLM timed out after 60s"))?
149}
150
151async fn request_incremental(
152 backend: LlmBackend,
153 model: &str,
154 hunk_summaries: &str,
155) -> anyhow::Result<Vec<SemanticGroup>> {
156 let prompt = format!(
157 "You are updating an existing grouping of code changes. \
158 New or modified files have been added to the working tree.\n\
159 Assign the NEW/MODIFIED hunks to the EXISTING groups listed above, or create new groups if they don't fit.\n\
160 Return ONLY valid JSON with assignments for the NEW/MODIFIED files only.\n\
161 Schema: {{\"groups\": [{{\"label\": \"short name\", \"description\": \"one sentence\", \
162 \"changes\": [{{\"file\": \"path\", \"hunks\": [0, 1]}}]}}]}}\n\
163 Rules:\n\
164 - Every hunk of every NEW/MODIFIED file must appear in exactly one group\n\
165 - Reuse existing group labels when the change fits that group's purpose\n\
166 - Create new groups only when a change serves a genuinely different purpose\n\
167 - Use the same label string (case-sensitive) when assigning to an existing group\n\
168 - The \"hunks\" array contains 0-based hunk indices\n\
169 - Do NOT include unchanged files in your response\n\n\
170 {hunk_summaries}",
171 );
172
173 let output = match backend {
174 LlmBackend::Claude => invoke_claude(&prompt, model).await?,
175 LlmBackend::Copilot => invoke_copilot(&prompt, model).await?,
176 };
177
178 let json_str = extract_json(&output)?;
179
180 if json_str.len() > MAX_JSON_SIZE {
181 anyhow::bail!(
182 "LLM JSON response too large ({} bytes, max {})",
183 json_str.len(),
184 MAX_JSON_SIZE
185 );
186 }
187
188 let response: GroupingResponse = serde_json::from_str(&json_str)?;
189
190 let known_files: HashSet<&str> = hunk_summaries
192 .lines()
193 .filter_map(|line| {
194 let line = line.trim();
195 if let Some(rest) = line.strip_prefix("FILE: ") {
196 let end = rest.find(" (")?;
197 Some(&rest[..end])
198 } else {
199 None
200 }
201 })
202 .collect();
203
204 let validated_groups: Vec<SemanticGroup> = response
205 .groups
206 .into_iter()
207 .take(MAX_GROUPS)
208 .map(|group| {
209 let valid_changes: Vec<super::GroupedChange> = group
210 .changes()
211 .into_iter()
212 .filter(|change| {
213 let known = known_files.contains(change.file.as_str());
214 let safe = !change.file.contains("..") && !change.file.starts_with('/');
215 if !safe {
216 tracing::warn!("Rejected LLM file path with traversal: {}", change.file);
217 }
218 known && safe
219 })
220 .take(MAX_CHANGES_PER_GROUP)
221 .collect();
222 SemanticGroup::new(
223 truncate_string(&group.label, MAX_LABEL_LEN),
224 truncate_string(&group.description, MAX_DESC_LEN),
225 valid_changes,
226 )
227 })
228 .filter(|group| !group.changes().is_empty())
229 .collect();
230
231 Ok(validated_groups)
232}
233
234pub async fn invoke_llm_text(
237 backend: LlmBackend,
238 model: &str,
239 prompt: &str,
240) -> anyhow::Result<String> {
241 match backend {
242 LlmBackend::Claude => invoke_claude_text(prompt, model).await,
243 LlmBackend::Copilot => invoke_copilot(prompt, model).await,
244 }
245}
246
247async fn invoke_claude(prompt: &str, model: &str) -> anyhow::Result<String> {
252 use std::process::Stdio;
253 use tokio::io::{AsyncReadExt, AsyncWriteExt};
254
255 let mut child = Command::new("claude")
256 .args([
257 "-p",
258 "--output-format",
259 "json",
260 "--model",
261 model,
262 "--max-turns",
263 "1",
264 ])
265 .stdin(Stdio::piped())
266 .stdout(Stdio::piped())
267 .stderr(Stdio::piped())
268 .spawn()?;
269
270 if let Some(mut stdin) = child.stdin.take() {
272 stdin.write_all(prompt.as_bytes()).await?;
273 }
275
276 let stdout_pipe = child.stdout.take()
278 .ok_or_else(|| anyhow::anyhow!("failed to capture claude stdout"))?;
279 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
280 let mut buf = Vec::with_capacity(8192);
281 let bytes_read = limited.read_to_end(&mut buf).await?;
282
283 if bytes_read >= MAX_RESPONSE_BYTES {
284 child.kill().await.ok();
285 anyhow::bail!("LLM response exceeded {MAX_RESPONSE_BYTES} byte limit");
286 }
287
288 let status = child.wait().await?;
289 if !status.success() {
290 let mut stderr_buf = Vec::new();
292 if let Some(mut stderr) = child.stderr.take() {
293 stderr.read_to_end(&mut stderr_buf).await.ok();
294 }
295 let stderr_str = String::from_utf8_lossy(&stderr_buf);
296 anyhow::bail!("claude exited with status {status}: {stderr_str}");
297 }
298
299 let stdout_str = String::from_utf8(buf)?;
300 let wrapper: serde_json::Value = serde_json::from_str(&stdout_str)?;
301 let result_text = wrapper["result"]
302 .as_str()
303 .ok_or_else(|| anyhow::anyhow!("missing result field in claude JSON output"))?;
304
305 Ok(result_text.to_string())
306}
307
308async fn invoke_claude_text(prompt: &str, model: &str) -> anyhow::Result<String> {
311 use std::process::Stdio;
312 use tokio::io::{AsyncReadExt, AsyncWriteExt};
313
314 let mut child = Command::new("claude")
315 .args([
316 "-p",
317 "--output-format",
318 "text",
319 "--model",
320 model,
321 ])
322 .stdin(Stdio::piped())
323 .stdout(Stdio::piped())
324 .stderr(Stdio::piped())
325 .spawn()?;
326
327 if let Some(mut stdin) = child.stdin.take() {
328 stdin.write_all(prompt.as_bytes()).await?;
329 }
330
331 let stdout_pipe = child.stdout.take()
333 .ok_or_else(|| anyhow::anyhow!("failed to capture claude stdout"))?;
334 let stderr_pipe = child.stderr.take();
335
336 let stdout_fut = async {
337 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
338 let mut buf = Vec::with_capacity(8192);
339 let bytes_read = limited.read_to_end(&mut buf).await?;
340 Ok::<(Vec<u8>, usize), std::io::Error>((buf, bytes_read))
341 };
342 let stderr_fut = async {
343 let mut stderr_buf = Vec::new();
344 if let Some(mut stderr) = stderr_pipe {
345 stderr.read_to_end(&mut stderr_buf).await.ok();
346 }
347 stderr_buf
348 };
349
350 let (stdout_result, stderr_buf) = tokio::join!(stdout_fut, stderr_fut);
351 let (buf, bytes_read) = stdout_result?;
352
353 if bytes_read >= MAX_RESPONSE_BYTES {
354 child.kill().await.ok();
355 anyhow::bail!("LLM response exceeded {MAX_RESPONSE_BYTES} byte limit");
356 }
357
358 let status = child.wait().await?;
359 if !status.success() {
360 let stderr_str = String::from_utf8_lossy(&stderr_buf);
361 anyhow::bail!("claude exited with status {status}: {stderr_str}");
362 }
363
364 Ok(String::from_utf8(buf)?)
365}
366
367async fn invoke_copilot(prompt: &str, model: &str) -> anyhow::Result<String> {
372 use std::process::Stdio;
373 use tokio::io::{AsyncReadExt, AsyncWriteExt};
374
375 let mut child = Command::new("copilot")
376 .args(["--yolo", "--model", model])
377 .stdin(Stdio::piped())
378 .stdout(Stdio::piped())
379 .stderr(Stdio::piped())
380 .spawn()?;
381
382 if let Some(mut stdin) = child.stdin.take() {
384 stdin.write_all(prompt.as_bytes()).await?;
385 }
386
387 let stdout_pipe = child.stdout.take()
389 .ok_or_else(|| anyhow::anyhow!("failed to capture copilot stdout"))?;
390 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
391 let mut buf = Vec::with_capacity(8192);
392 let bytes_read = limited.read_to_end(&mut buf).await?;
393
394 if bytes_read >= MAX_RESPONSE_BYTES {
395 child.kill().await.ok();
396 anyhow::bail!("LLM response exceeded {MAX_RESPONSE_BYTES} byte limit");
397 }
398
399 let status = child.wait().await?;
400 if !status.success() {
401 let mut stderr_buf = Vec::new();
402 if let Some(mut stderr) = child.stderr.take() {
403 stderr.read_to_end(&mut stderr_buf).await.ok();
404 }
405 let stderr_str = String::from_utf8_lossy(&stderr_buf);
406 anyhow::bail!("copilot exited with status {status}: {stderr_str}");
407 }
408
409 Ok(String::from_utf8(buf)?)
410}
411
412fn extract_json(text: &str) -> anyhow::Result<String> {
414 let trimmed = text.trim();
415 if trimmed.starts_with('{') {
417 return Ok(trimmed.to_string());
418 }
419 if let Some(start) = trimmed.find('{') {
421 if let Some(end) = trimmed.rfind('}') {
422 return Ok(trimmed[start..=end].to_string());
423 }
424 }
425 anyhow::bail!("no JSON object found in response")
426}
427
/// Return `s` limited to at most `max` characters (not bytes), allocating a
/// fresh `String` either way.
///
/// Uses `char_indices().nth(max)` to find the byte offset of the cut in a
/// single pass; the original walked the string twice (`chars().count()` then
/// `chars().take().collect()`) and collected char-by-char. Slicing at a
/// `char_indices` offset is always a valid UTF-8 boundary, so multibyte
/// characters are never split.
fn truncate_string(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // There is a char at index `max`, so the string is too long; cut at
        // its byte offset, keeping exactly `max` characters.
        Some((cut, _)) => s[..cut].to_string(),
        // Fewer than or exactly `max` chars: keep everything.
        None => s.to_string(),
    }
}
436
#[cfg(test)]
mod tests {
    use super::*;

    // ---- extract_json ---------------------------------------------------

    #[test]
    fn test_extract_json_direct() {
        // A reply that is already a bare JSON object passes through verbatim.
        let input = r#"{"groups": []}"#;
        assert_eq!(extract_json(input).unwrap(), input);
    }

    #[test]
    fn test_extract_json_code_fences() {
        // Markdown fences are stripped by the first-`{` / last-`}` fallback.
        let input = "```json\n{\"groups\": []}\n```";
        assert_eq!(extract_json(input).unwrap(), r#"{"groups": []}"#);
    }

    #[test]
    fn test_extract_json_no_json() {
        // No braces anywhere -> error.
        assert!(extract_json("no json here").is_err());
    }

    // ---- response deserialization ---------------------------------------

    #[test]
    fn test_parse_hunk_level_response() {
        // Happy path: the hunk-level schema round-trips through serde.
        let json = r#"{
            "groups": [{
                "label": "Auth refactor",
                "description": "Refactored auth flow",
                "changes": [
                    {"file": "src/auth.rs", "hunks": [0, 2]},
                    {"file": "src/middleware.rs", "hunks": [1]}
                ]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.groups.len(), 1);
        assert_eq!(response.groups[0].changes().len(), 2);
        assert_eq!(response.groups[0].changes()[0].hunks, vec![0, 2]);
    }

    #[test]
    fn test_parse_empty_hunks_means_all() {
        // An empty "hunks" array deserializes as empty; by convention that
        // means "all hunks of the file" (interpreted downstream).
        let json = r#"{
            "groups": [{
                "label": "Config",
                "description": "Config changes",
                "changes": [{"file": "config.toml", "hunks": []}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert!(response.groups[0].changes()[0].hunks.is_empty());
    }

    // ---- source-scanning tests ------------------------------------------
    // These read this very file with include_str! and assert structural
    // properties of the invoke_* functions (prompt must go over stdin, never
    // argv). They are brittle by design: they depend on exact source text.

    #[test]
    fn test_invoke_claude_uses_stdin_pipe() {
        let src = include_str!("llm.rs");
        let claude_start = src.find("async fn invoke_claude").expect("invoke_claude not found");
        let claude_body = &src[claude_start..];
        // NOTE(review): the find is over `claude_body[1..]` but the result is
        // used to index `claude_body` directly, so the slice ends one byte
        // early — harmless for these contains() checks, but worth confirming.
        let end = claude_body[1..].find("\nasync fn").unwrap_or(claude_body.len());
        let claude_fn = &claude_body[..end];

        assert!(
            claude_fn.contains("Stdio::piped()"),
            "invoke_claude must use Stdio::piped() for stdin"
        );
        assert!(
            claude_fn.contains("write_all"),
            "invoke_claude must write prompt to stdin via write_all"
        );
        // The argv list (everything inside .args([...])) must not reference
        // the prompt variable: argv is world-readable via the process table.
        if let Some(args_start) = claude_fn.find(".args([") {
            let args_section = &claude_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_claude must not pass prompt in .args()"
            );
        }
    }

    #[test]
    fn test_invoke_copilot_uses_stdin_pipe() {
        let src = include_str!("llm.rs");
        let copilot_start = src.find("async fn invoke_copilot").expect("invoke_copilot not found");
        let copilot_body = &src[copilot_start..];
        // End of the scanned region: next doc comment or the test module.
        let end = copilot_body[1..].find("\n/// ").or_else(|| copilot_body[1..].find("\n#[cfg(test)]")).unwrap_or(copilot_body.len());
        let copilot_fn = &copilot_body[..end];

        assert!(
            copilot_fn.contains("Stdio::piped()"),
            "invoke_copilot must use Stdio::piped() for stdin"
        );
        assert!(
            copilot_fn.contains("write_all"),
            "invoke_copilot must write prompt to stdin via write_all"
        );
    }

    #[test]
    fn test_no_prompt_in_args() {
        // Belt-and-braces re-check of both backends' argv lists.
        let src = include_str!("llm.rs");
        let claude_start = src.find("async fn invoke_claude").expect("invoke_claude not found");
        let claude_body = &src[claude_start..];
        let end = claude_body[1..].find("\nasync fn").unwrap_or(claude_body.len());
        let claude_fn = &claude_body[..end];

        if let Some(args_start) = claude_fn.find(".args([") {
            let args_section = &claude_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_claude .args() must not contain prompt variable"
            );
        }

        let copilot_start = src.find("async fn invoke_copilot").expect("invoke_copilot not found");
        let copilot_body = &src[copilot_start..];
        let end2 = copilot_body[1..].find("\n/// ").or_else(|| copilot_body[1..].find("\n#[cfg(test)]")).unwrap_or(copilot_body.len());
        let copilot_fn = &copilot_body[..end2];

        if let Some(args_start) = copilot_fn.find(".args([") {
            let args_section = &copilot_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_copilot .args() must not contain prompt variable"
            );
        }
    }

    #[test]
    fn test_parse_files_fallback() {
        // Legacy shape: a "files" array instead of "changes"; changes() must
        // synthesize one change per file with an empty hunks list.
        let json = r#"{
            "groups": [{
                "label": "Refactor",
                "description": "Code cleanup",
                "files": ["src/app.rs", "src/main.rs"]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let changes = response.groups[0].changes();
        assert_eq!(changes.len(), 2);
        assert_eq!(changes[0].file, "src/app.rs");
        assert!(changes[0].hunks.is_empty());
    }

    // ---- limit constants -------------------------------------------------
    // NOTE(review): these pin constant values rather than exercising the
    // bounded-read code path itself (which would need a subprocess).

    #[test]
    fn test_read_bounded_under_limit() {
        let data = "hello world";
        assert!(data.len() < MAX_RESPONSE_BYTES);
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
    }

    #[test]
    fn test_read_bounded_over_limit_constant() {
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
        let oversized = vec![b'x'; MAX_RESPONSE_BYTES];
        assert!(oversized.len() >= MAX_RESPONSE_BYTES);
    }

    #[test]
    fn test_validate_rejects_oversized_json() {
        // NOTE(review): asserts only that the fixture exceeds MAX_JSON_SIZE;
        // the actual rejection lives in the async request path.
        let large_json = format!(r#"{{"groups": [{{"label": "x", "description": "{}", "changes": []}}]}}"#,
            "a".repeat(MAX_JSON_SIZE + 1));
        assert!(large_json.len() > MAX_JSON_SIZE);
    }

    #[test]
    fn test_validate_caps_groups_at_max() {
        // 30 groups parse fine; the take(MAX_GROUPS) cap trims them to 20.
        let mut groups_json = Vec::new();
        for i in 0..30 {
            groups_json.push(format!(
                r#"{{"label": "Group {}", "description": "desc", "changes": [{{"file": "src/f{}.rs", "hunks": [0]}}]}}"#,
                i, i
            ));
        }
        let json = format!(r#"{{"groups": [{}]}}"#, groups_json.join(","));
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups.len(), 30);
        let capped: Vec<_> = response.groups.into_iter().take(MAX_GROUPS).collect();
        assert_eq!(capped.len(), 20);
    }

    #[test]
    fn test_validate_rejects_path_traversal() {
        // NOTE(review): only asserts the fixture contains "..". The actual
        // filter runs inside the async request path — not exercised here.
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "traversal",
                "changes": [{"file": "../../../etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        assert!(change.file.contains(".."), "path should contain traversal");
    }

    #[test]
    fn test_validate_rejects_absolute_paths() {
        // NOTE(review): same caveat as the traversal test above.
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "absolute",
                "changes": [{"file": "/etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        assert!(change.file.starts_with('/'), "path should be absolute");
    }

    // ---- truncate_string -------------------------------------------------

    #[test]
    fn test_truncate_string_label() {
        let long_label = "a".repeat(100);
        let truncated = truncate_string(&long_label, MAX_LABEL_LEN);
        assert_eq!(truncated.chars().count(), MAX_LABEL_LEN);
    }

    #[test]
    fn test_truncate_string_description() {
        let long_desc = "b".repeat(600);
        let truncated = truncate_string(&long_desc, MAX_DESC_LEN);
        assert_eq!(truncated.chars().count(), MAX_DESC_LEN);
    }

    #[test]
    fn test_validate_caps_changes_per_group() {
        // 250 changes parse fine; take(MAX_CHANGES_PER_GROUP) trims to 200.
        let mut changes = Vec::new();
        for i in 0..250 {
            changes.push(format!(r#"{{"file": "src/f{}.rs", "hunks": [0]}}"#, i));
        }
        let json = format!(
            r#"{{"groups": [{{"label": "Big", "description": "lots", "changes": [{}]}}]}}"#,
            changes.join(",")
        );
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups[0].changes().len(), 250);
        let capped: Vec<_> = response.groups[0].changes().into_iter().take(MAX_CHANGES_PER_GROUP).collect();
        assert_eq!(capped.len(), 200);
    }
}