1use anyhow::{Context, Result};
4use tracing::{debug, info, warn};
5
6use crate::claude::token_budget::TokenBudget;
7use crate::claude::{ai::bedrock::BedrockAiClient, ai::claude::ClaudeAiClient};
8use crate::claude::{
9 ai::{AiClient, RequestOptions, ResponseFormat},
10 error::ClaudeError,
11 prompts, response_schema,
12};
13use crate::data::{
14 amendments::{Amendment, AmendmentFile},
15 context::CommitContext,
16 RepositoryView, RepositoryViewForAI,
17};
18
/// Outcome reported when a fully built prompt does not pass the token-budget check.
///
/// It records the input-token allowance that was in force so the caller can size
/// a split (chunked) dispatch from it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct BudgetExceeded {
    /// Input tokens the budget reported as available when the check failed.
    available_input_tokens: usize,
}
27
/// Number of extra attempts made after the first one when requesting and parsing
/// an amendment response (the retry loop runs `0..=AMENDMENT_PARSE_MAX_RETRIES`).
const AMENDMENT_PARSE_MAX_RETRIES: u32 = 2;
30
31pub struct ClaudeClient {
33 ai_client: Box<dyn AiClient>,
35}
36
37impl ClaudeClient {
38 pub fn new(ai_client: Box<dyn AiClient>) -> Self {
40 Self { ai_client }
41 }
42
43 pub fn get_ai_client_metadata(&self) -> crate::claude::ai::AiClientMetadata {
45 self.ai_client.get_metadata()
46 }
47
48 #[must_use]
56 pub fn into_ai_client(self) -> Box<dyn AiClient> {
57 self.ai_client
58 }
59
60 fn adjusted_system_prompt(&self, system_prompt: String) -> String {
71 let format = ResponseFormat::from_capabilities(&self.ai_client.capabilities());
72 prompts::apply_response_format_to_system_prompt(system_prompt, format)
73 }
74
75 fn schema_if_supported<'a>(
82 &self,
83 schema: &'a serde_json::Value,
84 ) -> Option<&'a serde_json::Value> {
85 if self.ai_client.capabilities().supports_response_schema {
86 Some(schema)
87 } else {
88 None
89 }
90 }
91
92 async fn send_with_optional_schema(
103 &self,
104 system_prompt: &str,
105 user_prompt: &str,
106 schema: Option<&serde_json::Value>,
107 ) -> Result<String> {
108 match schema {
109 Some(s) => {
110 let opts = RequestOptions::default().with_response_schema(s.clone());
111 self.ai_client
112 .send_request_with_options(system_prompt, user_prompt, opts)
113 .await
114 }
115 None => {
116 self.ai_client
117 .send_request(system_prompt, user_prompt)
118 .await
119 }
120 }
121 }
122
123 fn validate_prompt_budget(&self, system_prompt: &str, user_prompt: &str) -> Result<()> {
128 let metadata = self.ai_client.get_metadata();
129 let budget = TokenBudget::from_metadata(&metadata);
130 let estimate = budget.validate_prompt(system_prompt, user_prompt)?;
131
132 debug!(
133 model = %metadata.model,
134 estimated_tokens = estimate.estimated_tokens,
135 available_tokens = estimate.available_tokens,
136 utilization_pct = format!("{:.1}%", estimate.utilization_pct),
137 "Token budget check passed"
138 );
139
140 Ok(())
141 }
142
143 fn build_prompt_fitting_budget(
149 &self,
150 ai_view: &RepositoryViewForAI,
151 system_prompt: &str,
152 build_user_prompt: &(impl Fn(&str) -> String + ?Sized),
153 ) -> Result<String> {
154 let metadata = self.ai_client.get_metadata();
155 let budget = TokenBudget::from_metadata(&metadata);
156
157 let yaml =
158 crate::data::to_yaml(ai_view).context("Failed to serialize repository view to YAML")?;
159 let user_prompt = build_user_prompt(&yaml);
160
161 let estimate = budget.validate_prompt(system_prompt, &user_prompt)?;
162 debug!(
163 model = %metadata.model,
164 estimated_tokens = estimate.estimated_tokens,
165 available_tokens = estimate.available_tokens,
166 utilization_pct = format!("{:.1}%", estimate.utilization_pct),
167 "Token budget check passed"
168 );
169
170 Ok(user_prompt)
171 }
172
173 fn try_full_diff_budget(
179 &self,
180 ai_view: &RepositoryViewForAI,
181 system_prompt: &str,
182 build_user_prompt: &(impl Fn(&str) -> String + ?Sized),
183 ) -> Result<std::result::Result<String, BudgetExceeded>> {
184 let metadata = self.ai_client.get_metadata();
185 let budget = TokenBudget::from_metadata(&metadata);
186
187 let yaml =
188 crate::data::to_yaml(ai_view).context("Failed to serialize repository view to YAML")?;
189 let user_prompt = build_user_prompt(&yaml);
190
191 if let Ok(estimate) = budget.validate_prompt(system_prompt, &user_prompt) {
192 debug!(
193 model = %metadata.model,
194 estimated_tokens = estimate.estimated_tokens,
195 available_tokens = estimate.available_tokens,
196 utilization_pct = format!("{:.1}%", estimate.utilization_pct),
197 "Token budget check passed"
198 );
199 return Ok(Ok(user_prompt));
200 }
201
202 Ok(Err(BudgetExceeded {
203 available_input_tokens: budget.available_input_tokens(),
204 }))
205 }
206
207 async fn generate_amendment_split(
214 &self,
215 commit: &crate::git::CommitInfo,
216 repo_view_for_ai: &RepositoryViewForAI,
217 system_prompt: &str,
218 build_user_prompt: &(dyn Fn(&str) -> String + Sync),
219 available_input_tokens: usize,
220 fresh: bool,
221 ) -> Result<Amendment> {
222 use crate::claude::batch::{
223 PER_COMMIT_METADATA_OVERHEAD_TOKENS, USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS,
224 VIEW_ENVELOPE_OVERHEAD_TOKENS,
225 };
226 use crate::claude::diff_pack::pack_file_diffs;
227 use crate::claude::token_budget;
228 use crate::git::commit::CommitInfoForAI;
229
230 let system_prompt_tokens = token_budget::estimate_tokens(system_prompt);
238 let commit_text_tokens = token_budget::estimate_tokens(&commit.original_message)
239 + token_budget::estimate_tokens(&commit.analysis.diff_summary);
240 let chunk_capacity = available_input_tokens
241 .saturating_sub(system_prompt_tokens)
242 .saturating_sub(VIEW_ENVELOPE_OVERHEAD_TOKENS)
243 .saturating_sub(PER_COMMIT_METADATA_OVERHEAD_TOKENS)
244 .saturating_sub(USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS)
245 .saturating_sub(commit_text_tokens);
246
247 debug!(
248 commit = %&commit.hash[..8],
249 available_input_tokens,
250 system_prompt_tokens,
251 envelope_overhead = VIEW_ENVELOPE_OVERHEAD_TOKENS,
252 metadata_overhead = PER_COMMIT_METADATA_OVERHEAD_TOKENS,
253 template_overhead = USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS,
254 commit_text_tokens,
255 chunk_capacity,
256 "Split dispatch: computed chunk capacity"
257 );
258
259 let plan = pack_file_diffs(&commit.hash, &commit.analysis.file_diffs, chunk_capacity)
260 .with_context(|| {
261 format!(
262 "Failed to plan diff chunks for commit {}",
263 &commit.hash[..8]
264 )
265 })?;
266
267 let total_chunks = plan.chunks.len();
268 debug!(
269 commit = %&commit.hash[..8],
270 chunks = total_chunks,
271 chunk_capacity,
272 "Split dispatch: processing commit in chunks"
273 );
274
275 let mut chunk_amendments = Vec::with_capacity(total_chunks);
276 for (i, chunk) in plan.chunks.iter().enumerate() {
277 let mut partial = CommitInfoForAI::from_commit_info_partial_with_overrides(
278 commit.clone(),
279 &chunk.file_paths,
280 &chunk.diff_overrides,
281 )
282 .with_context(|| {
283 format!(
284 "Failed to build partial view for chunk {}/{} of commit {}",
285 i + 1,
286 total_chunks,
287 &commit.hash[..8]
288 )
289 })?;
290
291 if fresh {
292 partial.base.original_message =
293 "(Original message hidden - generate fresh message from diff)".to_string();
294 }
295
296 let partial_view = repo_view_for_ai.single_commit_view_for_ai(&partial);
297
298 let diff_content_len = partial.base.analysis.diff_content.len();
300 let diff_content_tokens =
301 token_budget::estimate_tokens_from_char_count(diff_content_len);
302 debug!(
303 commit = %&commit.hash[..8],
304 chunk_index = i,
305 diff_content_len,
306 diff_content_tokens,
307 "Split dispatch: chunk diff content size"
308 );
309
310 let user_prompt =
311 self.build_prompt_fitting_budget(&partial_view, system_prompt, build_user_prompt)?;
312
313 info!(
314 commit = %&commit.hash[..8],
315 chunk = i + 1,
316 total_chunks,
317 user_prompt_len = user_prompt.len(),
318 "Split dispatch: sending chunk to AI"
319 );
320
321 let content = match self
322 .send_with_optional_schema(
323 system_prompt,
324 &user_prompt,
325 self.schema_if_supported(response_schema::amendment_file_schema()),
326 )
327 .await
328 {
329 Ok(content) => content,
330 Err(e) => {
331 tracing::error!(
333 commit = %&commit.hash[..8],
334 chunk = i + 1,
335 error = %e,
336 error_debug = ?e,
337 "Split dispatch: AI request failed"
338 );
339 return Err(e).with_context(|| {
340 format!(
341 "Chunk {}/{} failed for commit {}",
342 i + 1,
343 total_chunks,
344 &commit.hash[..8]
345 )
346 });
347 }
348 };
349
350 info!(
351 commit = %&commit.hash[..8],
352 chunk = i + 1,
353 response_len = content.len(),
354 "Split dispatch: received chunk response"
355 );
356
357 let amendment_file = self.parse_amendment_response(&content).with_context(|| {
358 format!(
359 "Failed to parse chunk {}/{} response for commit {}",
360 i + 1,
361 total_chunks,
362 &commit.hash[..8]
363 )
364 })?;
365
366 if let Some(amendment) = amendment_file.amendments.into_iter().next() {
367 chunk_amendments.push(amendment);
368 }
369 }
370
371 self.merge_amendment_chunks(
372 &commit.hash,
373 &commit.original_message,
374 &commit.analysis.diff_summary,
375 &chunk_amendments,
376 )
377 .await
378 }
379
380 async fn merge_amendment_chunks(
386 &self,
387 commit_hash: &str,
388 original_message: &str,
389 diff_summary: &str,
390 chunk_amendments: &[Amendment],
391 ) -> Result<Amendment> {
392 let system_prompt =
393 self.adjusted_system_prompt(prompts::AMENDMENT_CHUNK_MERGE_SYSTEM_PROMPT.to_string());
394 let user_prompt = prompts::generate_chunk_merge_user_prompt(
395 commit_hash,
396 original_message,
397 diff_summary,
398 chunk_amendments,
399 );
400
401 self.validate_prompt_budget(&system_prompt, &user_prompt)?;
402
403 let content = self
404 .send_with_optional_schema(
405 &system_prompt,
406 &user_prompt,
407 self.schema_if_supported(response_schema::amendment_file_schema()),
408 )
409 .await
410 .context("Merge pass failed for chunk amendments")?;
411
412 let amendment_file = self
413 .parse_amendment_response(&content)
414 .context("Failed to parse merge pass response")?;
415
416 amendment_file
417 .amendments
418 .into_iter()
419 .next()
420 .context("Merge pass returned no amendments")
421 }
422
423 async fn generate_amendment_for_commit(
430 &self,
431 commit: &crate::git::CommitInfo,
432 repo_view_for_ai: &RepositoryViewForAI,
433 system_prompt: &str,
434 build_user_prompt: &(dyn Fn(&str) -> String + Sync),
435 fresh: bool,
436 ) -> Result<Amendment> {
437 let mut ai_commit = crate::git::commit::CommitInfoForAI::from_commit_info(commit.clone())?;
438 if fresh {
439 ai_commit.base.original_message =
440 "(Original message hidden - generate fresh message from diff)".to_string();
441 }
442 let single_view = repo_view_for_ai.single_commit_view_for_ai(&ai_commit);
443
444 match self.try_full_diff_budget(&single_view, system_prompt, build_user_prompt)? {
445 Ok(user_prompt) => {
446 let amendment_file = self
447 .send_and_parse_amendment_with_retry(system_prompt, &user_prompt)
448 .await?;
449 amendment_file
450 .amendments
451 .into_iter()
452 .next()
453 .context("AI returned no amendments for commit")
454 }
455 Err(exceeded) => {
456 if commit.analysis.file_diffs.is_empty() {
457 anyhow::bail!(
458 "Token budget exceeded for commit {} but no file-level diffs available for split dispatch",
459 &commit.hash[..8]
460 );
461 }
462 self.generate_amendment_split(
463 commit,
464 repo_view_for_ai,
465 system_prompt,
466 build_user_prompt,
467 exceeded.available_input_tokens,
468 fresh,
469 )
470 .await
471 }
472 }
473 }
474
475 async fn check_commit_split(
483 &self,
484 commit: &crate::git::CommitInfo,
485 repo_view: &RepositoryView,
486 system_prompt: &str,
487 valid_scopes: &[crate::data::context::ScopeDefinition],
488 include_suggestions: bool,
489 available_input_tokens: usize,
490 ) -> Result<crate::data::check::CheckReport> {
491 use crate::claude::batch::{
492 PER_COMMIT_METADATA_OVERHEAD_TOKENS, USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS,
493 VIEW_ENVELOPE_OVERHEAD_TOKENS,
494 };
495 use crate::claude::diff_pack::pack_file_diffs;
496 use crate::claude::token_budget;
497 use crate::data::check::{CommitCheckResult, CommitIssue, IssueSeverity};
498 use crate::git::commit::CommitInfoForAI;
499
500 let system_prompt_tokens = token_budget::estimate_tokens(system_prompt);
508 let commit_text_tokens = token_budget::estimate_tokens(&commit.original_message)
509 + token_budget::estimate_tokens(&commit.analysis.diff_summary);
510 let chunk_capacity = available_input_tokens
511 .saturating_sub(system_prompt_tokens)
512 .saturating_sub(VIEW_ENVELOPE_OVERHEAD_TOKENS)
513 .saturating_sub(PER_COMMIT_METADATA_OVERHEAD_TOKENS)
514 .saturating_sub(USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS)
515 .saturating_sub(commit_text_tokens);
516
517 debug!(
518 commit = %&commit.hash[..8],
519 available_input_tokens,
520 system_prompt_tokens,
521 envelope_overhead = VIEW_ENVELOPE_OVERHEAD_TOKENS,
522 metadata_overhead = PER_COMMIT_METADATA_OVERHEAD_TOKENS,
523 template_overhead = USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS,
524 commit_text_tokens,
525 chunk_capacity,
526 "Check split dispatch: computed chunk capacity"
527 );
528
529 let plan = pack_file_diffs(&commit.hash, &commit.analysis.file_diffs, chunk_capacity)
530 .with_context(|| {
531 format!(
532 "Failed to plan diff chunks for commit {}",
533 &commit.hash[..8]
534 )
535 })?;
536
537 let total_chunks = plan.chunks.len();
538 debug!(
539 commit = %&commit.hash[..8],
540 chunks = total_chunks,
541 chunk_capacity,
542 "Check split dispatch: processing commit in chunks"
543 );
544
545 let build_user_prompt =
546 |yaml: &str| prompts::generate_check_user_prompt(yaml, include_suggestions);
547
548 let mut chunk_results = Vec::with_capacity(total_chunks);
549 for (i, chunk) in plan.chunks.iter().enumerate() {
550 let mut partial = CommitInfoForAI::from_commit_info_partial_with_overrides(
551 commit.clone(),
552 &chunk.file_paths,
553 &chunk.diff_overrides,
554 )
555 .with_context(|| {
556 format!(
557 "Failed to build partial view for chunk {}/{} of commit {}",
558 i + 1,
559 total_chunks,
560 &commit.hash[..8]
561 )
562 })?;
563
564 partial.run_pre_validation_checks(valid_scopes);
565
566 let partial_view = RepositoryViewForAI::from_repository_view(repo_view.clone())
567 .context("Failed to enhance repository view with diff content")?
568 .single_commit_view_for_ai(&partial);
569
570 let user_prompt =
571 self.build_prompt_fitting_budget(&partial_view, system_prompt, &build_user_prompt)?;
572
573 let content = self
574 .send_with_optional_schema(
575 system_prompt,
576 &user_prompt,
577 self.schema_if_supported(response_schema::check_response_schema()),
578 )
579 .await
580 .with_context(|| {
581 format!(
582 "Check chunk {}/{} failed for commit {}",
583 i + 1,
584 total_chunks,
585 &commit.hash[..8]
586 )
587 })?;
588
589 let report = self
590 .parse_check_response(&content, repo_view)
591 .with_context(|| {
592 format!(
593 "Failed to parse check chunk {}/{} response for commit {}",
594 i + 1,
595 total_chunks,
596 &commit.hash[..8]
597 )
598 })?;
599
600 if let Some(result) = report.commits.into_iter().next() {
601 chunk_results.push(result);
602 }
603 }
604
605 let mut seen = std::collections::HashSet::new();
607 let mut merged_issues: Vec<CommitIssue> = Vec::new();
608 for result in &chunk_results {
609 for issue in &result.issues {
610 let key: (String, IssueSeverity, String) =
611 (issue.rule.clone(), issue.severity, issue.section.clone());
612 if seen.insert(key) {
613 merged_issues.push(issue.clone());
614 }
615 }
616 }
617
618 let passes = chunk_results.iter().all(|r| r.passes);
619
620 let has_suggestions = chunk_results.iter().any(|r| r.suggestion.is_some());
622
623 let (merged_suggestion, merged_summary) = if has_suggestions {
624 self.merge_check_chunks(
625 &commit.hash,
626 &commit.original_message,
627 &commit.analysis.diff_summary,
628 passes,
629 &chunk_results,
630 repo_view,
631 )
632 .await?
633 } else {
634 let summary = chunk_results.iter().find_map(|r| r.summary.clone());
636 (None, summary)
637 };
638
639 let original_message = commit
640 .original_message
641 .lines()
642 .next()
643 .unwrap_or("")
644 .to_string();
645
646 let merged_result = CommitCheckResult {
647 hash: commit.hash.clone(),
648 message: original_message,
649 issues: merged_issues,
650 suggestion: merged_suggestion,
651 passes,
652 summary: merged_summary,
653 };
654
655 Ok(crate::data::check::CheckReport::new(vec![merged_result]))
656 }
657
658 async fn merge_check_chunks(
663 &self,
664 commit_hash: &str,
665 original_message: &str,
666 diff_summary: &str,
667 passes: bool,
668 chunk_results: &[crate::data::check::CommitCheckResult],
669 repo_view: &RepositoryView,
670 ) -> Result<(Option<crate::data::check::CommitSuggestion>, Option<String>)> {
671 let suggestions: Vec<&crate::data::check::CommitSuggestion> = chunk_results
672 .iter()
673 .filter_map(|r| r.suggestion.as_ref())
674 .collect();
675
676 let summaries: Vec<Option<&str>> =
677 chunk_results.iter().map(|r| r.summary.as_deref()).collect();
678
679 let system_prompt =
680 self.adjusted_system_prompt(prompts::CHECK_CHUNK_MERGE_SYSTEM_PROMPT.to_string());
681 let user_prompt = prompts::generate_check_chunk_merge_user_prompt(
682 commit_hash,
683 original_message,
684 diff_summary,
685 passes,
686 &suggestions,
687 &summaries,
688 );
689
690 self.validate_prompt_budget(&system_prompt, &user_prompt)?;
691
692 let content = self
693 .send_with_optional_schema(
694 &system_prompt,
695 &user_prompt,
696 self.schema_if_supported(response_schema::check_response_schema()),
697 )
698 .await
699 .context("Merge pass failed for check chunk suggestions")?;
700
701 let report = self
702 .parse_check_response(&content, repo_view)
703 .context("Failed to parse check merge pass response")?;
704
705 let result = report.commits.into_iter().next();
706 Ok(match result {
707 Some(r) => (r.suggestion, r.summary),
708 None => (None, None),
709 })
710 }
711
712 pub async fn send_message(&self, system_prompt: &str, user_prompt: &str) -> Result<String> {
714 self.validate_prompt_budget(system_prompt, user_prompt)?;
715 self.ai_client
716 .send_request(system_prompt, user_prompt)
717 .await
718 }
719
720 pub fn from_env(model: String) -> Result<Self> {
722 let api_key = std::env::var("CLAUDE_API_KEY")
724 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
725 .map_err(|_| ClaudeError::ApiKeyNotFound)?;
726
727 let ai_client = ClaudeAiClient::new(model, api_key, None)?;
728 Ok(Self::new(Box::new(ai_client)))
729 }
730
731 pub async fn generate_amendments(&self, repo_view: &RepositoryView) -> Result<AmendmentFile> {
733 self.generate_amendments_with_options(repo_view, false)
734 .await
735 }
736
737 pub async fn generate_amendments_with_options(
748 &self,
749 repo_view: &RepositoryView,
750 fresh: bool,
751 ) -> Result<AmendmentFile> {
752 let ai_repo_view =
754 RepositoryViewForAI::from_repository_view_with_options(repo_view.clone(), fresh)
755 .context("Failed to enhance repository view with diff content")?;
756
757 let system_prompt = self.adjusted_system_prompt(prompts::SYSTEM_PROMPT.to_string());
758 let build_user_prompt = |yaml: &str| prompts::generate_user_prompt(yaml);
759
760 match self.try_full_diff_budget(&ai_repo_view, &system_prompt, &build_user_prompt)? {
762 Ok(user_prompt) => {
763 self.send_and_parse_amendment_with_retry(&system_prompt, &user_prompt)
764 .await
765 }
766 Err(_exceeded) => {
767 let mut amendments = Vec::new();
768 for commit in &repo_view.commits {
769 let amendment = self
770 .generate_amendment_for_commit(
771 commit,
772 &ai_repo_view,
773 &system_prompt,
774 &build_user_prompt,
775 fresh,
776 )
777 .await?;
778 amendments.push(amendment);
779 }
780 Ok(AmendmentFile { amendments })
781 }
782 }
783 }
784
785 pub async fn generate_contextual_amendments(
787 &self,
788 repo_view: &RepositoryView,
789 context: &CommitContext,
790 ) -> Result<AmendmentFile> {
791 self.generate_contextual_amendments_with_options(repo_view, context, false)
792 .await
793 }
794
795 pub async fn generate_contextual_amendments_with_options(
805 &self,
806 repo_view: &RepositoryView,
807 context: &CommitContext,
808 fresh: bool,
809 ) -> Result<AmendmentFile> {
810 let ai_repo_view =
812 RepositoryViewForAI::from_repository_view_with_options(repo_view.clone(), fresh)
813 .context("Failed to enhance repository view with diff content")?;
814
815 let prompt_style = self.ai_client.get_metadata().prompt_style();
817 let system_prompt = self.adjusted_system_prompt(
818 prompts::generate_contextual_system_prompt_for_provider(context, prompt_style),
819 );
820
821 match &context.project.commit_guidelines {
823 Some(guidelines) => {
824 debug!(length = guidelines.len(), "Project commit guidelines found");
825 debug!(guidelines = %guidelines, "Commit guidelines content");
826 }
827 None => {
828 debug!("No project commit guidelines found");
829 }
830 }
831
832 let build_user_prompt =
833 |yaml: &str| prompts::generate_contextual_user_prompt(yaml, context);
834
835 match self.try_full_diff_budget(&ai_repo_view, &system_prompt, &build_user_prompt)? {
837 Ok(user_prompt) => {
838 self.send_and_parse_amendment_with_retry(&system_prompt, &user_prompt)
839 .await
840 }
841 Err(_exceeded) => {
842 let mut amendments = Vec::new();
843 for commit in &repo_view.commits {
844 let amendment = self
845 .generate_amendment_for_commit(
846 commit,
847 &ai_repo_view,
848 &system_prompt,
849 &build_user_prompt,
850 fresh,
851 )
852 .await?;
853 amendments.push(amendment);
854 }
855 Ok(AmendmentFile { amendments })
856 }
857 }
858 }
859
860 fn parse_amendment_response(&self, content: &str) -> Result<AmendmentFile> {
862 let yaml_content = self.extract_yaml_from_response(content);
864
865 let amendment_file: AmendmentFile = crate::data::from_yaml(&yaml_content).map_err(|e| {
867 debug!(
868 error = %e,
869 content_length = content.len(),
870 yaml_length = yaml_content.len(),
871 "YAML parsing failed"
872 );
873 debug!(content = %content, "Raw Claude response");
874 debug!(yaml = %yaml_content, "Extracted YAML content");
875
876 if yaml_content.lines().any(|line| line.contains('\t')) {
878 ClaudeError::AmendmentParsingFailed("YAML parsing error: Found tab characters. YAML requires spaces for indentation.".to_string())
879 } else if yaml_content.lines().any(|line| line.trim().starts_with('-') && !line.trim().starts_with("- ")) {
880 ClaudeError::AmendmentParsingFailed("YAML parsing error: List items must have a space after the dash (- item).".to_string())
881 } else {
882 ClaudeError::AmendmentParsingFailed(format!("YAML parsing error: {e}"))
883 }
884 })?;
885
886 amendment_file
888 .validate()
889 .map_err(|e| ClaudeError::AmendmentParsingFailed(format!("Validation error: {e}")))?;
890
891 Ok(amendment_file)
892 }
893
894 async fn send_and_parse_amendment_with_retry(
902 &self,
903 system_prompt: &str,
904 user_prompt: &str,
905 ) -> Result<AmendmentFile> {
906 let mut last_error = None;
907 for attempt in 0..=AMENDMENT_PARSE_MAX_RETRIES {
908 match self
909 .send_with_optional_schema(
910 system_prompt,
911 user_prompt,
912 self.schema_if_supported(response_schema::amendment_file_schema()),
913 )
914 .await
915 {
916 Ok(content) => match self.parse_amendment_response(&content) {
917 Ok(amendment_file) => return Ok(amendment_file),
918 Err(e) => {
919 if attempt < AMENDMENT_PARSE_MAX_RETRIES {
920 eprintln!(
921 "warning: failed to parse amendment response (attempt {}), retrying...",
922 attempt + 1
923 );
924 debug!(error = %e, attempt = attempt + 1, "Amendment response parse failed, retrying");
925 }
926 last_error = Some(e);
927 }
928 },
929 Err(e) => {
930 if attempt < AMENDMENT_PARSE_MAX_RETRIES {
931 eprintln!(
932 "warning: AI request failed (attempt {}), retrying...",
933 attempt + 1
934 );
935 debug!(error = %e, attempt = attempt + 1, "AI request failed, retrying");
936 }
937 last_error = Some(e);
938 }
939 }
940 }
941 Err(last_error
942 .unwrap_or_else(|| anyhow::anyhow!("Amendment generation failed after retries")))
943 }
944
945 fn parse_pr_response(&self, content: &str) -> Result<crate::cli::git::PrContent> {
947 let yaml_content = content.trim();
948 crate::data::from_yaml(yaml_content)
949 .context("Failed to parse AI response as YAML. AI may have returned malformed output.")
950 }
951
952 async fn generate_pr_content_split(
959 &self,
960 commit: &crate::git::CommitInfo,
961 repo_view_for_ai: &RepositoryViewForAI,
962 system_prompt: &str,
963 build_user_prompt: &(dyn Fn(&str) -> String + Sync),
964 available_input_tokens: usize,
965 pr_template: &str,
966 ) -> Result<crate::cli::git::PrContent> {
967 use crate::claude::batch::{
968 PER_COMMIT_METADATA_OVERHEAD_TOKENS, USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS,
969 VIEW_ENVELOPE_OVERHEAD_TOKENS,
970 };
971 use crate::claude::diff_pack::pack_file_diffs;
972 use crate::claude::token_budget;
973 use crate::git::commit::CommitInfoForAI;
974
975 let system_prompt_tokens = token_budget::estimate_tokens(system_prompt);
983 let commit_text_tokens = token_budget::estimate_tokens(&commit.original_message)
984 + token_budget::estimate_tokens(&commit.analysis.diff_summary);
985 let chunk_capacity = available_input_tokens
986 .saturating_sub(system_prompt_tokens)
987 .saturating_sub(VIEW_ENVELOPE_OVERHEAD_TOKENS)
988 .saturating_sub(PER_COMMIT_METADATA_OVERHEAD_TOKENS)
989 .saturating_sub(USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS)
990 .saturating_sub(commit_text_tokens);
991
992 debug!(
993 commit = %&commit.hash[..8],
994 available_input_tokens,
995 system_prompt_tokens,
996 envelope_overhead = VIEW_ENVELOPE_OVERHEAD_TOKENS,
997 metadata_overhead = PER_COMMIT_METADATA_OVERHEAD_TOKENS,
998 template_overhead = USER_PROMPT_TEMPLATE_OVERHEAD_TOKENS,
999 commit_text_tokens,
1000 chunk_capacity,
1001 "PR split dispatch: computed chunk capacity"
1002 );
1003
1004 let plan = pack_file_diffs(&commit.hash, &commit.analysis.file_diffs, chunk_capacity)
1005 .with_context(|| {
1006 format!(
1007 "Failed to plan diff chunks for commit {}",
1008 &commit.hash[..8]
1009 )
1010 })?;
1011
1012 let total_chunks = plan.chunks.len();
1013 debug!(
1014 commit = %&commit.hash[..8],
1015 chunks = total_chunks,
1016 chunk_capacity,
1017 "PR split dispatch: processing commit in chunks"
1018 );
1019
1020 let mut chunk_contents = Vec::with_capacity(total_chunks);
1021 for (i, chunk) in plan.chunks.iter().enumerate() {
1022 let partial = CommitInfoForAI::from_commit_info_partial_with_overrides(
1023 commit.clone(),
1024 &chunk.file_paths,
1025 &chunk.diff_overrides,
1026 )
1027 .with_context(|| {
1028 format!(
1029 "Failed to build partial view for chunk {}/{} of commit {}",
1030 i + 1,
1031 total_chunks,
1032 &commit.hash[..8]
1033 )
1034 })?;
1035
1036 let partial_view = repo_view_for_ai.single_commit_view_for_ai(&partial);
1037
1038 let user_prompt =
1039 self.build_prompt_fitting_budget(&partial_view, system_prompt, build_user_prompt)?;
1040
1041 let content = self
1042 .send_with_optional_schema(
1043 system_prompt,
1044 &user_prompt,
1045 self.schema_if_supported(response_schema::pr_content_schema()),
1046 )
1047 .await
1048 .with_context(|| {
1049 format!(
1050 "PR chunk {}/{} failed for commit {}",
1051 i + 1,
1052 total_chunks,
1053 &commit.hash[..8]
1054 )
1055 })?;
1056
1057 let pr_content = self.parse_pr_response(&content).with_context(|| {
1058 format!(
1059 "Failed to parse PR chunk {}/{} response for commit {}",
1060 i + 1,
1061 total_chunks,
1062 &commit.hash[..8]
1063 )
1064 })?;
1065
1066 chunk_contents.push(pr_content);
1067 }
1068
1069 self.merge_pr_content_chunks(&chunk_contents, pr_template)
1070 .await
1071 }
1072
1073 async fn merge_pr_content_chunks(
1076 &self,
1077 partial_contents: &[crate::cli::git::PrContent],
1078 pr_template: &str,
1079 ) -> Result<crate::cli::git::PrContent> {
1080 let system_prompt =
1081 self.adjusted_system_prompt(prompts::PR_CONTENT_MERGE_SYSTEM_PROMPT.to_string());
1082 let user_prompt =
1083 prompts::generate_pr_content_merge_user_prompt(partial_contents, pr_template);
1084
1085 self.validate_prompt_budget(&system_prompt, &user_prompt)?;
1086
1087 let content = self
1088 .send_with_optional_schema(
1089 &system_prompt,
1090 &user_prompt,
1091 self.schema_if_supported(response_schema::pr_content_schema()),
1092 )
1093 .await
1094 .context("Merge pass failed for PR content chunks")?;
1095
1096 self.parse_pr_response(&content)
1097 .context("Failed to parse PR content merge pass response")
1098 }
1099
1100 async fn generate_pr_content_for_commit(
1102 &self,
1103 commit: &crate::git::CommitInfo,
1104 repo_view_for_ai: &RepositoryViewForAI,
1105 system_prompt: &str,
1106 build_user_prompt: &(dyn Fn(&str) -> String + Sync),
1107 pr_template: &str,
1108 ) -> Result<crate::cli::git::PrContent> {
1109 let ai_commit = crate::git::commit::CommitInfoForAI::from_commit_info(commit.clone())?;
1110 let single_view = repo_view_for_ai.single_commit_view_for_ai(&ai_commit);
1111
1112 match self.try_full_diff_budget(&single_view, system_prompt, build_user_prompt)? {
1113 Ok(user_prompt) => {
1114 let content = self
1115 .send_with_optional_schema(
1116 system_prompt,
1117 &user_prompt,
1118 self.schema_if_supported(response_schema::pr_content_schema()),
1119 )
1120 .await?;
1121 self.parse_pr_response(&content)
1122 }
1123 Err(exceeded) => {
1124 if commit.analysis.file_diffs.is_empty() {
1125 anyhow::bail!(
1126 "Token budget exceeded for commit {} but no file-level diffs available for split dispatch",
1127 &commit.hash[..8]
1128 );
1129 }
1130 self.generate_pr_content_split(
1131 commit,
1132 repo_view_for_ai,
1133 system_prompt,
1134 build_user_prompt,
1135 exceeded.available_input_tokens,
1136 pr_template,
1137 )
1138 .await
1139 }
1140 }
1141 }
1142
1143 pub async fn generate_pr_content(
1145 &self,
1146 repo_view: &RepositoryView,
1147 pr_template: &str,
1148 ) -> Result<crate::cli::git::PrContent> {
1149 let ai_repo_view = RepositoryViewForAI::from_repository_view(repo_view.clone())
1151 .context("Failed to enhance repository view with diff content")?;
1152
1153 let system_prompt =
1154 self.adjusted_system_prompt(prompts::PR_GENERATION_SYSTEM_PROMPT.to_string());
1155
1156 let build_user_prompt =
1157 |yaml: &str| prompts::generate_pr_description_prompt(yaml, pr_template);
1158
1159 match self.try_full_diff_budget(&ai_repo_view, &system_prompt, &build_user_prompt)? {
1161 Ok(user_prompt) => {
1162 let content = self
1163 .send_with_optional_schema(
1164 &system_prompt,
1165 &user_prompt,
1166 self.schema_if_supported(response_schema::pr_content_schema()),
1167 )
1168 .await?;
1169 self.parse_pr_response(&content)
1170 }
1171 Err(_exceeded) => {
1172 let mut per_commit_contents = Vec::new();
1173 for commit in &repo_view.commits {
1174 let pr = self
1175 .generate_pr_content_for_commit(
1176 commit,
1177 &ai_repo_view,
1178 &system_prompt,
1179 &build_user_prompt,
1180 pr_template,
1181 )
1182 .await?;
1183 per_commit_contents.push(pr);
1184 }
1185 if per_commit_contents.len() == 1 {
1186 return per_commit_contents
1187 .into_iter()
1188 .next()
1189 .context("Per-commit PR contents unexpectedly empty");
1190 }
1191 self.merge_pr_content_chunks(&per_commit_contents, pr_template)
1192 .await
1193 }
1194 }
1195 }
1196
1197 pub async fn generate_pr_content_with_context(
1199 &self,
1200 repo_view: &RepositoryView,
1201 pr_template: &str,
1202 context: &crate::data::context::CommitContext,
1203 ) -> Result<crate::cli::git::PrContent> {
1204 let ai_repo_view = RepositoryViewForAI::from_repository_view(repo_view.clone())
1206 .context("Failed to enhance repository view with diff content")?;
1207
1208 let prompt_style = self.ai_client.get_metadata().prompt_style();
1210 let system_prompt = self.adjusted_system_prompt(
1211 prompts::generate_pr_system_prompt_with_context_for_provider(context, prompt_style),
1212 );
1213
1214 let build_user_prompt = |yaml: &str| {
1215 prompts::generate_pr_description_prompt_with_context(yaml, pr_template, context)
1216 };
1217
1218 match self.try_full_diff_budget(&ai_repo_view, &system_prompt, &build_user_prompt)? {
1220 Ok(user_prompt) => {
1221 let content = self
1222 .send_with_optional_schema(
1223 &system_prompt,
1224 &user_prompt,
1225 self.schema_if_supported(response_schema::pr_content_schema()),
1226 )
1227 .await?;
1228
1229 debug!(
1230 content_length = content.len(),
1231 "Received AI response for PR content"
1232 );
1233
1234 let pr_content = self.parse_pr_response(&content)?;
1235
1236 debug!(
1237 parsed_title = %pr_content.title,
1238 parsed_description_length = pr_content.description.len(),
1239 parsed_description_preview = %pr_content.description.lines().take(3).collect::<Vec<_>>().join("\\n"),
1240 "Successfully parsed PR content from YAML"
1241 );
1242
1243 Ok(pr_content)
1244 }
1245 Err(_exceeded) => {
1246 let mut per_commit_contents = Vec::new();
1247 for commit in &repo_view.commits {
1248 let pr = self
1249 .generate_pr_content_for_commit(
1250 commit,
1251 &ai_repo_view,
1252 &system_prompt,
1253 &build_user_prompt,
1254 pr_template,
1255 )
1256 .await?;
1257 per_commit_contents.push(pr);
1258 }
1259 if per_commit_contents.len() == 1 {
1260 return per_commit_contents
1261 .into_iter()
1262 .next()
1263 .context("Per-commit PR contents unexpectedly empty");
1264 }
1265 self.merge_pr_content_chunks(&per_commit_contents, pr_template)
1266 .await
1267 }
1268 }
1269 }
1270
1271 pub async fn check_commits(
1276 &self,
1277 repo_view: &RepositoryView,
1278 guidelines: Option<&str>,
1279 include_suggestions: bool,
1280 ) -> Result<crate::data::check::CheckReport> {
1281 self.check_commits_with_scopes(repo_view, guidelines, &[], include_suggestions)
1282 .await
1283 }
1284
1285 pub async fn check_commits_with_scopes(
1290 &self,
1291 repo_view: &RepositoryView,
1292 guidelines: Option<&str>,
1293 valid_scopes: &[crate::data::context::ScopeDefinition],
1294 include_suggestions: bool,
1295 ) -> Result<crate::data::check::CheckReport> {
1296 self.check_commits_with_retry(repo_view, guidelines, valid_scopes, include_suggestions, 2)
1297 .await
1298 }
1299
1300 async fn check_commits_with_retry(
1308 &self,
1309 repo_view: &RepositoryView,
1310 guidelines: Option<&str>,
1311 valid_scopes: &[crate::data::context::ScopeDefinition],
1312 include_suggestions: bool,
1313 max_retries: u32,
1314 ) -> Result<crate::data::check::CheckReport> {
1315 let system_prompt = self.adjusted_system_prompt(
1317 prompts::generate_check_system_prompt_with_scopes(guidelines, valid_scopes),
1318 );
1319
1320 let build_user_prompt =
1321 |yaml: &str| prompts::generate_check_user_prompt(yaml, include_suggestions);
1322
1323 let mut ai_repo_view = RepositoryViewForAI::from_repository_view(repo_view.clone())
1324 .context("Failed to enhance repository view with diff content")?;
1325 for commit in &mut ai_repo_view.commits {
1326 commit.run_pre_validation_checks(valid_scopes);
1327 }
1328
1329 match self.try_full_diff_budget(&ai_repo_view, &system_prompt, &build_user_prompt)? {
1331 Ok(user_prompt) => {
1332 let mut last_error = None;
1334 for attempt in 0..=max_retries {
1335 match self
1336 .send_with_optional_schema(
1337 &system_prompt,
1338 &user_prompt,
1339 self.schema_if_supported(response_schema::check_response_schema()),
1340 )
1341 .await
1342 {
1343 Ok(content) => match self.parse_check_response(&content, repo_view) {
1344 Ok(report) => return Ok(report),
1345 Err(e) => {
1346 if attempt < max_retries {
1347 eprintln!(
1348 "warning: failed to parse AI response (attempt {}), retrying...",
1349 attempt + 1
1350 );
1351 debug!(error = %e, attempt = attempt + 1, "Check response parse failed, retrying");
1352 }
1353 last_error = Some(e);
1354 }
1355 },
1356 Err(e) => {
1357 if attempt < max_retries {
1358 eprintln!(
1359 "warning: AI request failed (attempt {}), retrying...",
1360 attempt + 1
1361 );
1362 debug!(error = %e, attempt = attempt + 1, "AI request failed, retrying");
1363 }
1364 last_error = Some(e);
1365 }
1366 }
1367 }
1368 Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Check failed after retries")))
1369 }
1370 Err(_exceeded) => {
1371 let mut all_results = Vec::new();
1373 for commit in &repo_view.commits {
1374 let single_view = repo_view.single_commit_view(commit);
1375 let mut single_ai_view =
1376 RepositoryViewForAI::from_repository_view(single_view.clone())
1377 .context("Failed to enhance single-commit view with diff content")?;
1378 for c in &mut single_ai_view.commits {
1379 c.run_pre_validation_checks(valid_scopes);
1380 }
1381
1382 match self.try_full_diff_budget(
1383 &single_ai_view,
1384 &system_prompt,
1385 &build_user_prompt,
1386 )? {
1387 Ok(user_prompt) => {
1388 let content = self
1389 .send_with_optional_schema(
1390 &system_prompt,
1391 &user_prompt,
1392 self.schema_if_supported(
1393 response_schema::check_response_schema(),
1394 ),
1395 )
1396 .await?;
1397 let report = self.parse_check_response(&content, &single_view)?;
1398 all_results.extend(report.commits);
1399 }
1400 Err(exceeded) => {
1401 if commit.analysis.file_diffs.is_empty() {
1402 anyhow::bail!(
1403 "Token budget exceeded for commit {} but no file-level diffs available for split dispatch",
1404 &commit.hash[..8]
1405 );
1406 }
1407 let report = self
1408 .check_commit_split(
1409 commit,
1410 &single_view,
1411 &system_prompt,
1412 valid_scopes,
1413 include_suggestions,
1414 exceeded.available_input_tokens,
1415 )
1416 .await?;
1417 all_results.extend(report.commits);
1418 }
1419 }
1420 }
1421 Ok(crate::data::check::CheckReport::new(all_results))
1422 }
1423 }
1424 }
1425
1426 fn parse_check_response(
1428 &self,
1429 content: &str,
1430 repo_view: &RepositoryView,
1431 ) -> Result<crate::data::check::CheckReport> {
1432 use crate::data::check::{
1433 AiCheckResponse, CheckReport, CommitCheckResult as CheckResultType,
1434 };
1435
1436 let yaml_content = self.extract_yaml_from_check_response(content);
1438
1439 let ai_response: AiCheckResponse = crate::data::from_yaml(&yaml_content).map_err(|e| {
1441 debug!(
1442 error = %e,
1443 content_length = content.len(),
1444 yaml_length = yaml_content.len(),
1445 "Check YAML parsing failed"
1446 );
1447 debug!(content = %content, "Raw AI response");
1448 debug!(yaml = %yaml_content, "Extracted YAML content");
1449 ClaudeError::AmendmentParsingFailed(format!("Check response parsing error: {e}"))
1450 })?;
1451
1452 let commit_messages: std::collections::HashMap<&str, &str> = repo_view
1454 .commits
1455 .iter()
1456 .map(|c| (c.hash.as_str(), c.original_message.as_str()))
1457 .collect();
1458
1459 let results: Vec<CheckResultType> = ai_response
1461 .checks
1462 .into_iter()
1463 .map(|check| {
1464 let mut result: CheckResultType = check.into();
1465 if let Some(msg) = commit_messages.get(result.hash.as_str()) {
1467 result.message = msg.lines().next().unwrap_or("").to_string();
1468 } else {
1469 for (hash, msg) in &commit_messages {
1471 if hash.starts_with(&result.hash) || result.hash.starts_with(*hash) {
1472 result.message = msg.lines().next().unwrap_or("").to_string();
1473 break;
1474 }
1475 }
1476 }
1477 result
1478 })
1479 .collect();
1480
1481 Ok(CheckReport::new(results))
1482 }
1483
1484 fn extract_yaml_from_check_response(&self, content: &str) -> String {
1486 let content = content.trim();
1487
1488 if content.starts_with("checks:") {
1490 return content.to_string();
1491 }
1492
1493 if let Some(yaml_start) = content.find("```yaml") {
1495 if let Some(yaml_content) = content[yaml_start + 7..].split("```").next() {
1496 return yaml_content.trim().to_string();
1497 }
1498 }
1499
1500 if let Some(code_start) = content.find("```") {
1502 if let Some(code_content) = content[code_start + 3..].split("```").next() {
1503 let potential_yaml = code_content.trim();
1504 if potential_yaml.starts_with("checks:") {
1506 return potential_yaml.to_string();
1507 }
1508 }
1509 }
1510
1511 content.to_string()
1513 }
1514
1515 pub async fn refine_amendments_coherence(
1520 &self,
1521 items: &[(crate::data::amendments::Amendment, String)],
1522 ) -> Result<AmendmentFile> {
1523 let system_prompt =
1524 self.adjusted_system_prompt(prompts::AMENDMENT_COHERENCE_SYSTEM_PROMPT.to_string());
1525 let user_prompt = prompts::generate_amendment_coherence_user_prompt(items);
1526
1527 self.validate_prompt_budget(&system_prompt, &user_prompt)?;
1528
1529 let content = self
1530 .send_with_optional_schema(
1531 &system_prompt,
1532 &user_prompt,
1533 self.schema_if_supported(response_schema::amendment_file_schema()),
1534 )
1535 .await?;
1536
1537 self.parse_amendment_response(&content)
1538 }
1539
1540 pub async fn refine_checks_coherence(
1546 &self,
1547 items: &[(crate::data::check::CommitCheckResult, String)],
1548 repo_view: &RepositoryView,
1549 ) -> Result<crate::data::check::CheckReport> {
1550 let system_prompt =
1551 self.adjusted_system_prompt(prompts::CHECK_COHERENCE_SYSTEM_PROMPT.to_string());
1552 let user_prompt = prompts::generate_check_coherence_user_prompt(items);
1553
1554 self.validate_prompt_budget(&system_prompt, &user_prompt)?;
1555
1556 let content = self
1557 .send_with_optional_schema(
1558 &system_prompt,
1559 &user_prompt,
1560 self.schema_if_supported(response_schema::check_response_schema()),
1561 )
1562 .await?;
1563
1564 self.parse_check_response(&content, repo_view)
1565 }
1566
1567 fn extract_yaml_from_response(&self, content: &str) -> String {
1569 let content = content.trim();
1570
1571 if content.starts_with("amendments:") {
1573 return content.to_string();
1574 }
1575
1576 if let Some(yaml_start) = content.find("```yaml") {
1578 if let Some(yaml_content) = content[yaml_start + 7..].split("```").next() {
1579 return yaml_content.trim().to_string();
1580 }
1581 }
1582
1583 if let Some(code_start) = content.find("```") {
1585 if let Some(code_content) = content[code_start + 3..].split("```").next() {
1586 let potential_yaml = code_content.trim();
1587 if potential_yaml.starts_with("amendments:") {
1589 return potential_yaml.to_string();
1590 }
1591 }
1592 }
1593
1594 content.to_string()
1596 }
1597}
1598
1599fn validate_beta_header(model: &str, beta_header: &Option<(String, String)>) -> Result<()> {
1601 if let Some((ref key, ref value)) = beta_header {
1602 let registry = crate::claude::model_config::get_model_registry();
1603 let supported = registry.get_beta_headers(model);
1604 if !supported
1605 .iter()
1606 .any(|bh| bh.key == *key && bh.value == *value)
1607 {
1608 let available: Vec<String> = supported
1609 .iter()
1610 .map(|bh| format!("{}:{}", bh.key, bh.value))
1611 .collect();
1612 if available.is_empty() {
1613 anyhow::bail!("Model '{model}' does not support any beta headers");
1614 }
1615 anyhow::bail!(
1616 "Beta header '{key}:{value}' is not supported for model '{model}'. Supported: {}",
1617 available.join(", ")
1618 );
1619 }
1620 }
1621 Ok(())
1622}
1623
1624pub async fn create_default_claude_client(
1631 model: Option<String>,
1632 beta_header: Option<(String, String)>,
1633) -> Result<ClaudeClient> {
1634 use crate::claude::ai::claude_cli::ClaudeCliAiClient;
1635 use crate::claude::ai::openai::OpenAiAiClient;
1636 use crate::utils::settings::{get_env_var, get_env_vars};
1637
1638 let ai_backend = get_env_var("OMNI_DEV_AI_BACKEND").ok();
1643 let use_claude_cli = ai_backend
1644 .as_deref()
1645 .is_some_and(|v| matches!(v, "claude-cli" | "claude_cli"));
1646
1647 if use_claude_cli {
1648 if beta_header.is_some() {
1649 warn!(
1650 "--beta-header is ignored when OMNI_DEV_AI_BACKEND=claude-cli \
1651 (the CLI's --betas flag has different semantics and is not forwarded)"
1652 );
1653 }
1654 let registry = crate::claude::model_config::get_model_registry();
1655 let cli_model = model
1656 .or_else(|| get_env_var("CLAUDE_MODEL").ok())
1657 .or_else(|| get_env_var("CLAUDE_CODE_MODEL").ok())
1658 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
1659 .unwrap_or_else(|| {
1660 registry
1661 .get_default_model("claude")
1662 .unwrap_or("claude-sonnet-4-6")
1663 .to_string()
1664 });
1665 debug!(model = %cli_model, "Creating claude -p subprocess client");
1666 let ai_client = ClaudeCliAiClient::new(cli_model);
1667 return Ok(ClaudeClient::new(Box::new(ai_client)));
1668 }
1669
1670 let use_openai = get_env_var("USE_OPENAI").is_ok_and(|val| val == "true");
1672
1673 let use_ollama = get_env_var("USE_OLLAMA").is_ok_and(|val| val == "true");
1674
1675 let use_bedrock = get_env_var("CLAUDE_CODE_USE_BEDROCK").is_ok_and(|val| val == "true");
1677
1678 debug!(
1679 use_openai = use_openai,
1680 use_ollama = use_ollama,
1681 use_bedrock = use_bedrock,
1682 "Client selection flags"
1683 );
1684
1685 let registry = crate::claude::model_config::get_model_registry();
1686
1687 if use_ollama {
1689 let ollama_model = model
1690 .or_else(|| get_env_var("OLLAMA_MODEL").ok())
1691 .unwrap_or_else(|| "llama2".to_string());
1692 validate_beta_header(&ollama_model, &beta_header)?;
1693 let base_url = get_env_var("OLLAMA_BASE_URL").ok();
1694 let mut ai_client = OpenAiAiClient::new_ollama(ollama_model, base_url, beta_header)?;
1695 match ai_client.probe_loaded_context_length().await {
1696 Some(source) => {
1697 info!(
1698 loaded_context_length = ai_client.loaded_context_length(),
1699 source = source.as_str(),
1700 model = %ai_client.get_metadata().model,
1701 "Probed loaded context length from local server"
1702 );
1703 }
1704 None => {
1705 debug!(
1706 "Loaded context length probe did not return a value; \
1707 falling back to registry/default for token budget"
1708 );
1709 }
1710 }
1711 return Ok(ClaudeClient::new(Box::new(ai_client)));
1712 }
1713
1714 if use_openai {
1716 debug!("Creating OpenAI client");
1717 let openai_model = model
1718 .or_else(|| get_env_var("OPENAI_MODEL").ok())
1719 .unwrap_or_else(|| {
1720 registry
1721 .get_default_model("openai")
1722 .unwrap_or("gpt-5")
1723 .to_string()
1724 });
1725 debug!(openai_model = %openai_model, "Selected OpenAI model");
1726 validate_beta_header(&openai_model, &beta_header)?;
1727
1728 let api_key = get_env_vars(&["OPENAI_API_KEY", "OPENAI_AUTH_TOKEN"]).map_err(|e| {
1729 debug!(error = ?e, "Failed to get OpenAI API key");
1730 ClaudeError::ApiKeyNotFound
1731 })?;
1732 debug!("OpenAI API key found");
1733
1734 let ai_client = OpenAiAiClient::new_openai(openai_model, api_key, beta_header)?;
1735 debug!("OpenAI client created successfully");
1736 return Ok(ClaudeClient::new(Box::new(ai_client)));
1737 }
1738
1739 let claude_model = model
1741 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
1742 .unwrap_or_else(|| {
1743 registry
1744 .get_default_model("claude")
1745 .unwrap_or("claude-sonnet-4-6")
1746 .to_string()
1747 });
1748 validate_beta_header(&claude_model, &beta_header)?;
1749
1750 if use_bedrock {
1751 let auth_token =
1753 get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| ClaudeError::ApiKeyNotFound)?;
1754
1755 let base_url =
1756 get_env_var("ANTHROPIC_BEDROCK_BASE_URL").map_err(|_| ClaudeError::ApiKeyNotFound)?;
1757
1758 let ai_client = BedrockAiClient::new(claude_model, auth_token, base_url, beta_header)?;
1759 return Ok(ClaudeClient::new(Box::new(ai_client)));
1760 }
1761
1762 debug!("Falling back to Claude client");
1764 let api_key = get_env_vars(&[
1765 "CLAUDE_API_KEY",
1766 "ANTHROPIC_API_KEY",
1767 "ANTHROPIC_AUTH_TOKEN",
1768 ])
1769 .map_err(|_| ClaudeError::ApiKeyNotFound)?;
1770
1771 let ai_client = ClaudeAiClient::new(claude_model, api_key, beta_header)?;
1772 debug!("Claude client created successfully");
1773 Ok(ClaudeClient::new(Box::new(ai_client)))
1774}
1775
1776#[cfg(test)]
1777#[allow(
1778 clippy::unwrap_used,
1779 clippy::expect_used,
1780 clippy::format_in_format_args
1781)]
1782mod tests {
1783 use super::*;
1784 use crate::claude::ai::{AiClient, AiClientCapabilities, AiClientMetadata};
1785 use std::future::Future;
1786 use std::pin::Pin;
1787 use std::sync::{Arc, Mutex};
1788
1789 struct MockAiClient;
1791
1792 impl AiClient for MockAiClient {
1793 fn send_request<'a>(
1794 &'a self,
1795 _system_prompt: &'a str,
1796 _user_prompt: &'a str,
1797 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
1798 Box::pin(async { Ok(String::new()) })
1799 }
1800
1801 fn get_metadata(&self) -> AiClientMetadata {
1802 AiClientMetadata {
1803 provider: "Mock".to_string(),
1804 model: "mock-model".to_string(),
1805 max_context_length: 200_000,
1806 max_response_length: 8_192,
1807 active_beta: None,
1808 }
1809 }
1810 }
1811
1812 fn make_client() -> ClaudeClient {
1813 ClaudeClient::new(Box::new(MockAiClient))
1814 }
1815
1816 struct SchemaRecordingMockAiClient {
1827 capabilities: AiClientCapabilities,
1828 response: String,
1829 recorded_options: Arc<Mutex<Vec<RequestOptions>>>,
1830 recorded_plain: Arc<Mutex<Vec<(String, String)>>>,
1831 }
1832 impl SchemaRecordingMockAiClient {
1833 fn new(supports_response_schema: bool) -> Self {
1834 Self::with_response(supports_response_schema, String::new())
1835 }
1836
1837 fn with_response(supports_response_schema: bool, response: String) -> Self {
1838 Self {
1839 capabilities: AiClientCapabilities {
1840 supports_response_schema,
1841 },
1842 response,
1843 recorded_options: Arc::new(Mutex::new(Vec::new())),
1844 recorded_plain: Arc::new(Mutex::new(Vec::new())),
1845 }
1846 }
1847 }
1848
1849 impl AiClient for SchemaRecordingMockAiClient {
1850 fn send_request<'a>(
1851 &'a self,
1852 system_prompt: &'a str,
1853 user_prompt: &'a str,
1854 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
1855 let plain = self.recorded_plain.clone();
1856 let sys = system_prompt.to_string();
1857 let usr = user_prompt.to_string();
1858 let response = self.response.clone();
1859 Box::pin(async move {
1860 plain.lock().unwrap().push((sys, usr));
1861 Ok(response)
1862 })
1863 }
1864
1865 fn capabilities(&self) -> AiClientCapabilities {
1866 self.capabilities
1867 }
1868
1869 fn send_request_with_options<'a>(
1870 &'a self,
1871 _system_prompt: &'a str,
1872 _user_prompt: &'a str,
1873 options: RequestOptions,
1874 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
1875 let recorded = self.recorded_options.clone();
1876 let response = self.response.clone();
1877 Box::pin(async move {
1878 recorded.lock().unwrap().push(options);
1879 Ok(response)
1880 })
1881 }
1882
1883 fn get_metadata(&self) -> AiClientMetadata {
1884 AiClientMetadata {
1885 provider: "SchemaMock".to_string(),
1886 model: "schema-mock".to_string(),
1887 max_context_length: 200_000,
1888 max_response_length: 8_192,
1889 active_beta: None,
1890 }
1891 }
1892 }
1893
1894 #[tokio::test]
1900 async fn send_with_optional_schema_without_caps_uses_plain_send() {
1901 let inner = SchemaRecordingMockAiClient::new(false);
1902 let plain_log = inner.recorded_plain.clone();
1903 let opts_log = inner.recorded_options.clone();
1904 let client = ClaudeClient::new(Box::new(inner));
1905
1906 let schema = serde_json::json!({"type": "object"});
1907 client
1908 .send_with_optional_schema(
1909 "sys",
1910 "usr",
1911 client.schema_if_supported(&schema), )
1913 .await
1914 .unwrap();
1915
1916 assert_eq!(plain_log.lock().unwrap().len(), 1);
1917 assert!(opts_log.lock().unwrap().is_empty());
1918 }
1919
1920 #[tokio::test]
1924 async fn send_with_optional_schema_with_caps_uses_options_send() {
1925 let inner = SchemaRecordingMockAiClient::new(true);
1926 let plain_log = inner.recorded_plain.clone();
1927 let opts_log = inner.recorded_options.clone();
1928 let client = ClaudeClient::new(Box::new(inner));
1929
1930 let schema = serde_json::json!({"type": "object", "additionalProperties": false});
1931 client
1932 .send_with_optional_schema(
1933 "sys",
1934 "usr",
1935 client.schema_if_supported(&schema), )
1937 .await
1938 .unwrap();
1939
1940 let recorded = opts_log.lock().unwrap();
1941 assert_eq!(recorded.len(), 1);
1942 assert_eq!(recorded[0].response_schema.as_ref(), Some(&schema));
1943 assert!(plain_log.lock().unwrap().is_empty());
1944 }
1945
1946 #[test]
1949 fn adjusted_system_prompt_adds_suffix_when_supported() {
1950 let client = ClaudeClient::new(Box::new(SchemaRecordingMockAiClient::new(true)));
1951 let result = client.adjusted_system_prompt("body".to_string());
1952 assert!(result.starts_with("body"));
1953 assert!(result.contains("STRUCTURED OUTPUT OVERRIDE"));
1954 }
1955
1956 #[test]
1957 fn adjusted_system_prompt_passes_through_when_not_supported() {
1958 let client = ClaudeClient::new(Box::new(SchemaRecordingMockAiClient::new(false)));
1959 let result = client.adjusted_system_prompt("body".to_string());
1960 assert_eq!(result, "body");
1961 }
1962
1963 #[test]
1964 fn schema_if_supported_returns_some_when_supported() {
1965 let client = ClaudeClient::new(Box::new(SchemaRecordingMockAiClient::new(true)));
1966 let schema = serde_json::json!({"type": "object"});
1967 let returned = client.schema_if_supported(&schema);
1968 assert!(returned.is_some());
1969 assert!(std::ptr::eq(
1970 std::ptr::from_ref(returned.unwrap()),
1971 std::ptr::addr_of!(schema)
1972 ));
1973 }
1974
1975 #[test]
1976 fn schema_if_supported_returns_none_when_not_supported() {
1977 let client = ClaudeClient::new(Box::new(SchemaRecordingMockAiClient::new(false)));
1978 let schema = serde_json::json!({"type": "object"});
1979 assert!(client.schema_if_supported(&schema).is_none());
1980 }
1981
1982 #[tokio::test]
1989 async fn refine_amendments_coherence_round_trip() {
1990 let mock = SchemaRecordingMockAiClient::with_response(
1991 true, "amendments: []".to_string(),
1993 );
1994 let recorded_opts = mock.recorded_options.clone();
1995 let client = ClaudeClient::new(Box::new(mock));
1996
1997 let amendment = crate::data::amendments::Amendment {
1998 commit: "abc123".to_string(),
1999 message: "feat: do thing".to_string(),
2000 summary: "did the thing".to_string(),
2001 };
2002 let items = vec![(amendment, "summary text".to_string())];
2003
2004 let result = client
2005 .refine_amendments_coherence(&items)
2006 .await
2007 .expect("coherence refinement should succeed");
2008 assert!(result.amendments.is_empty());
2009
2010 let recorded = recorded_opts.lock().unwrap();
2013 assert_eq!(recorded.len(), 1);
2014 let attached = recorded[0]
2015 .response_schema
2016 .as_ref()
2017 .expect("schema must be attached when capability is true");
2018 assert_eq!(
2019 attached,
2020 response_schema::amendment_file_schema(),
2021 "refine_amendments_coherence should attach the AmendmentFile schema"
2022 );
2023 }
2024
2025 #[tokio::test]
2030 async fn refine_checks_coherence_round_trip() {
2031 let mock = SchemaRecordingMockAiClient::with_response(
2032 true, "checks: []".to_string(),
2034 );
2035 let recorded_opts = mock.recorded_options.clone();
2036 let client = ClaudeClient::new(Box::new(mock));
2037
2038 let check = crate::data::check::CommitCheckResult {
2039 hash: "abc123".to_string(),
2040 message: "feat: do thing".to_string(),
2041 issues: Vec::new(),
2042 suggestion: None,
2043 passes: true,
2044 summary: Some("summary".to_string()),
2045 };
2046 let items = vec![(check, "summary text".to_string())];
2047 let dir = tempfile::TempDir::new().unwrap();
2048 let repo_view = make_test_repo_view(&dir);
2049
2050 let result = client
2051 .refine_checks_coherence(&items, &repo_view)
2052 .await
2053 .expect("coherence refinement should succeed");
2054 assert_eq!(result.summary.total_commits, 0);
2055
2056 let recorded = recorded_opts.lock().unwrap();
2057 assert_eq!(recorded.len(), 1);
2058 let attached = recorded[0]
2059 .response_schema
2060 .as_ref()
2061 .expect("schema must be attached when capability is true");
2062 assert_eq!(
2063 attached,
2064 response_schema::check_response_schema(),
2065 "refine_checks_coherence should attach the AiCheckResponse schema"
2066 );
2067 }
2068
2069 #[tokio::test]
2073 async fn refine_amendments_coherence_without_schema_capability() {
2074 let mock = SchemaRecordingMockAiClient::with_response(
2075 false, "amendments: []".to_string(),
2077 );
2078 let recorded_plain = mock.recorded_plain.clone();
2079 let recorded_opts = mock.recorded_options.clone();
2080 let client = ClaudeClient::new(Box::new(mock));
2081
2082 let amendment = crate::data::amendments::Amendment {
2083 commit: "abc123".to_string(),
2084 message: "feat: do thing".to_string(),
2085 summary: String::new(),
2086 };
2087 let items = vec![(amendment, "summary".to_string())];
2088
2089 client
2090 .refine_amendments_coherence(&items)
2091 .await
2092 .expect("coherence refinement should succeed without schema support");
2093
2094 assert_eq!(recorded_plain.lock().unwrap().len(), 1);
2095 assert!(
2096 recorded_opts.lock().unwrap().is_empty(),
2097 "no-schema backend must not be reached via the options path"
2098 );
2099 }
2100
2101 #[test]
2104 fn extract_yaml_pure_amendments() {
2105 let client = make_client();
2106 let content = "amendments:\n - commit: abc123\n message: test";
2107 let result = client.extract_yaml_from_response(content);
2108 assert!(result.starts_with("amendments:"));
2109 }
2110
2111 #[test]
2112 fn extract_yaml_with_markdown_yaml_block() {
2113 let client = make_client();
2114 let content = "Here is the result:\n```yaml\namendments:\n - commit: abc\n```\n";
2115 let result = client.extract_yaml_from_response(content);
2116 assert!(result.starts_with("amendments:"));
2117 }
2118
2119 #[test]
2120 fn extract_yaml_with_generic_code_block() {
2121 let client = make_client();
2122 let content = "```\namendments:\n - commit: abc\n```";
2123 let result = client.extract_yaml_from_response(content);
2124 assert!(result.starts_with("amendments:"));
2125 }
2126
2127 #[test]
2128 fn extract_yaml_with_whitespace() {
2129 let client = make_client();
2130 let content = " \n amendments:\n - commit: abc\n ";
2131 let result = client.extract_yaml_from_response(content);
2132 assert!(result.starts_with("amendments:"));
2133 }
2134
2135 #[test]
2136 fn extract_yaml_fallback_returns_trimmed() {
2137 let client = make_client();
2138 let content = " some random text ";
2139 let result = client.extract_yaml_from_response(content);
2140 assert_eq!(result, "some random text");
2141 }
2142
2143 #[test]
2146 fn extract_check_yaml_pure() {
2147 let client = make_client();
2148 let content = "checks:\n - commit: abc123";
2149 let result = client.extract_yaml_from_check_response(content);
2150 assert!(result.starts_with("checks:"));
2151 }
2152
2153 #[test]
2154 fn extract_check_yaml_markdown_block() {
2155 let client = make_client();
2156 let content = "```yaml\nchecks:\n - commit: abc\n```";
2157 let result = client.extract_yaml_from_check_response(content);
2158 assert!(result.starts_with("checks:"));
2159 }
2160
2161 #[test]
2162 fn extract_check_yaml_generic_block() {
2163 let client = make_client();
2164 let content = "```\nchecks:\n - commit: abc\n```";
2165 let result = client.extract_yaml_from_check_response(content);
2166 assert!(result.starts_with("checks:"));
2167 }
2168
2169 #[test]
2170 fn extract_check_yaml_fallback() {
2171 let client = make_client();
2172 let content = " unexpected content ";
2173 let result = client.extract_yaml_from_check_response(content);
2174 assert_eq!(result, "unexpected content");
2175 }
2176
2177 #[test]
2180 fn parse_amendment_response_valid() {
2181 let client = make_client();
2182 let yaml = format!(
2183 "amendments:\n - commit: \"{}\"\n message: \"test message\"",
2184 "a".repeat(40)
2185 );
2186 let result = client.parse_amendment_response(&yaml);
2187 assert!(result.is_ok());
2188 assert_eq!(result.unwrap().amendments.len(), 1);
2189 }
2190
2191 #[test]
2192 fn parse_amendment_response_invalid_yaml() {
2193 let client = make_client();
2194 let result = client.parse_amendment_response("not: valid: yaml: [{{");
2195 assert!(result.is_err());
2196 }
2197
2198 #[test]
2199 fn parse_amendment_response_invalid_hash() {
2200 let client = make_client();
2201 let yaml = "amendments:\n - commit: \"short\"\n message: \"test\"";
2202 let result = client.parse_amendment_response(yaml);
2203 assert!(result.is_err());
2204 }
2205
2206 #[test]
2209 fn validate_beta_header_none_passes() {
2210 let result = validate_beta_header("claude-opus-4-1-20250805", &None);
2211 assert!(result.is_ok());
2212 }
2213
2214 #[test]
2215 fn validate_beta_header_unsupported_fails() {
2216 let header = Some(("fake-key".to_string(), "fake-value".to_string()));
2217 let result = validate_beta_header("claude-opus-4-1-20250805", &header);
2218 assert!(result.is_err());
2219 }
2220
2221 #[test]
2224 fn client_metadata() {
2225 let client = make_client();
2226 let metadata = client.get_ai_client_metadata();
2227 assert_eq!(metadata.provider, "Mock");
2228 assert_eq!(metadata.model, "mock-model");
2229 }
2230
2231 mod prop {
2234 use super::*;
2235 use proptest::prelude::*;
2236
2237 proptest! {
2238 #[test]
2239 fn yaml_response_output_trimmed(s in ".*") {
2240 let client = make_client();
2241 let result = client.extract_yaml_from_response(&s);
2242 prop_assert_eq!(&result, result.trim());
2243 }
2244
2245 #[test]
2246 fn yaml_response_amendments_prefix_preserved(tail in ".*") {
2247 let client = make_client();
2248 let input = format!("amendments:{tail}");
2249 let result = client.extract_yaml_from_response(&input);
2250 prop_assert!(result.starts_with("amendments:"));
2251 }
2252
2253 #[test]
2254 fn check_response_checks_prefix_preserved(tail in ".*") {
2255 let client = make_client();
2256 let input = format!("checks:{tail}");
2257 let result = client.extract_yaml_from_check_response(&input);
2258 prop_assert!(result.starts_with("checks:"));
2259 }
2260
2261 #[test]
2262 fn yaml_fenced_block_strips_fences(
2263 content in "[a-zA-Z0-9: _\\-\n]{1,100}",
2264 ) {
2265 let client = make_client();
2266 let input = format!("```yaml\n{content}\n```");
2267 let result = client.extract_yaml_from_response(&input);
2268 prop_assert!(!result.contains("```"));
2269 }
2270 }
2271 }
2272
2273 fn make_configurable_client(responses: Vec<Result<String>>) -> ClaudeClient {
2276 ClaudeClient::new(Box::new(
2277 crate::claude::test_utils::ConfigurableMockAiClient::new(responses),
2278 ))
2279 }
2280
2281 fn make_test_repo_view(dir: &tempfile::TempDir) -> crate::data::RepositoryView {
2282 use crate::data::{AiInfo, FieldExplanation, WorkingDirectoryInfo};
2283 use crate::git::commit::FileChanges;
2284 use crate::git::{CommitAnalysis, CommitInfo};
2285
2286 let diff_path = dir.path().join("0.diff");
2287 std::fs::write(&diff_path, "+added line\n").unwrap();
2288
2289 crate::data::RepositoryView {
2290 versions: None,
2291 explanation: FieldExplanation::default(),
2292 working_directory: WorkingDirectoryInfo {
2293 clean: true,
2294 untracked_changes: Vec::new(),
2295 },
2296 remotes: Vec::new(),
2297 ai: AiInfo {
2298 scratch: String::new(),
2299 },
2300 branch_info: None,
2301 pr_template: None,
2302 pr_template_location: None,
2303 branch_prs: None,
2304 commits: vec![CommitInfo {
2305 hash: format!("{:0>40}", 0),
2306 author: "Test <test@test.com>".to_string(),
2307 date: chrono::Utc::now().fixed_offset(),
2308 original_message: "feat(test): add something".to_string(),
2309 in_main_branches: Vec::new(),
2310 analysis: CommitAnalysis {
2311 detected_type: "feat".to_string(),
2312 detected_scope: "test".to_string(),
2313 proposed_message: "feat(test): add something".to_string(),
2314 file_changes: FileChanges {
2315 total_files: 1,
2316 files_added: 1,
2317 files_deleted: 0,
2318 file_list: Vec::new(),
2319 },
2320 diff_summary: "file.rs | 1 +".to_string(),
2321 diff_file: diff_path.to_string_lossy().to_string(),
2322 file_diffs: Vec::new(),
2323 },
2324 }],
2325 }
2326 }
2327
2328 fn valid_check_yaml() -> String {
2329 format!(
2330 "checks:\n - commit: \"{hash}\"\n passes: true\n issues: []\n",
2331 hash = format!("{:0>40}", 0)
2332 )
2333 }
2334
2335 #[tokio::test]
2336 async fn send_message_propagates_ai_error() {
2337 let client = make_configurable_client(vec![Err(anyhow::anyhow!("mock error"))]);
2338 let result = client.send_message("sys", "usr").await;
2339 assert!(result.is_err());
2340 assert!(result.unwrap_err().to_string().contains("mock error"));
2341 }
2342
2343 #[tokio::test]
2344 async fn check_commits_succeeds_after_request_error() {
2345 let dir = tempfile::tempdir().unwrap();
2346 let repo_view = make_test_repo_view(&dir);
2347 let client = make_configurable_client(vec![
2349 Err(anyhow::anyhow!("rate limit")),
2350 Ok(valid_check_yaml()),
2351 Ok(valid_check_yaml()),
2352 ]);
2353 let result = client
2354 .check_commits_with_scopes(&repo_view, None, &[], false)
2355 .await;
2356 assert!(result.is_ok());
2357 }
2358
2359 #[tokio::test]
2360 async fn check_commits_succeeds_after_parse_error() {
2361 let dir = tempfile::tempdir().unwrap();
2362 let repo_view = make_test_repo_view(&dir);
2363 let client = make_configurable_client(vec![
2365 Ok("not: valid: yaml: [[".to_string()),
2366 Ok(valid_check_yaml()),
2367 Ok(valid_check_yaml()),
2368 ]);
2369 let result = client
2370 .check_commits_with_scopes(&repo_view, None, &[], false)
2371 .await;
2372 assert!(result.is_ok());
2373 }
2374
2375 #[tokio::test]
2376 async fn check_commits_fails_after_all_retries_exhausted() {
2377 let dir = tempfile::tempdir().unwrap();
2378 let repo_view = make_test_repo_view(&dir);
2379 let client = make_configurable_client(vec![
2380 Err(anyhow::anyhow!("first failure")),
2381 Err(anyhow::anyhow!("second failure")),
2382 Err(anyhow::anyhow!("final failure")),
2383 ]);
2384 let result = client
2385 .check_commits_with_scopes(&repo_view, None, &[], false)
2386 .await;
2387 assert!(result.is_err());
2388 }
2389
2390 #[tokio::test]
2391 async fn check_commits_fails_when_all_parses_fail() {
2392 let dir = tempfile::tempdir().unwrap();
2393 let repo_view = make_test_repo_view(&dir);
2394 let client = make_configurable_client(vec![
2395 Ok("bad yaml [[".to_string()),
2396 Ok("bad yaml [[".to_string()),
2397 Ok("bad yaml [[".to_string()),
2398 ]);
2399 let result = client
2400 .check_commits_with_scopes(&repo_view, None, &[], false)
2401 .await;
2402 assert!(result.is_err());
2403 }
2404
2405 fn make_small_context_client(responses: Vec<Result<String>>) -> ClaudeClient {
2412 let mock = crate::claude::test_utils::ConfigurableMockAiClient::new(responses)
2416 .with_context_length(50_000);
2417 ClaudeClient::new(Box::new(mock))
2418 }
2419
2420 fn make_small_context_client_tracked(
2423 responses: Vec<Result<String>>,
2424 ) -> (ClaudeClient, crate::claude::test_utils::ResponseQueueHandle) {
2425 let mock = crate::claude::test_utils::ConfigurableMockAiClient::new(responses)
2426 .with_context_length(50_000);
2427 let handle = mock.response_handle();
2428 (ClaudeClient::new(Box::new(mock)), handle)
2429 }
2430
2431 fn make_large_diff_repo_view(dir: &tempfile::TempDir) -> crate::data::RepositoryView {
2434 use crate::data::{AiInfo, FieldExplanation, WorkingDirectoryInfo};
2435 use crate::git::commit::{FileChange, FileChanges, FileDiffRef};
2436 use crate::git::{CommitAnalysis, CommitInfo};
2437
2438 let hash = "a".repeat(40);
2439
2440 let full_diff = "x".repeat(120_000);
2444 let flat_diff_path = dir.path().join("full.diff");
2445 std::fs::write(&flat_diff_path, &full_diff).unwrap();
2446
2447 let diff_a = format!("diff --git a/src/a.rs b/src/a.rs\n{}\n", "a".repeat(30_000));
2450 let diff_b = format!("diff --git a/src/b.rs b/src/b.rs\n{}\n", "b".repeat(30_000));
2451
2452 let path_a = dir.path().join("0000.diff");
2453 let path_b = dir.path().join("0001.diff");
2454 std::fs::write(&path_a, &diff_a).unwrap();
2455 std::fs::write(&path_b, &diff_b).unwrap();
2456
2457 crate::data::RepositoryView {
2458 versions: None,
2459 explanation: FieldExplanation::default(),
2460 working_directory: WorkingDirectoryInfo {
2461 clean: true,
2462 untracked_changes: Vec::new(),
2463 },
2464 remotes: Vec::new(),
2465 ai: AiInfo {
2466 scratch: String::new(),
2467 },
2468 branch_info: None,
2469 pr_template: None,
2470 pr_template_location: None,
2471 branch_prs: None,
2472 commits: vec![CommitInfo {
2473 hash,
2474 author: "Test <test@test.com>".to_string(),
2475 date: chrono::Utc::now().fixed_offset(),
2476 original_message: "feat(test): large commit".to_string(),
2477 in_main_branches: Vec::new(),
2478 analysis: CommitAnalysis {
2479 detected_type: "feat".to_string(),
2480 detected_scope: "test".to_string(),
2481 proposed_message: "feat(test): large commit".to_string(),
2482 file_changes: FileChanges {
2483 total_files: 2,
2484 files_added: 2,
2485 files_deleted: 0,
2486 file_list: vec![
2487 FileChange {
2488 status: "A".to_string(),
2489 file: "src/a.rs".to_string(),
2490 },
2491 FileChange {
2492 status: "A".to_string(),
2493 file: "src/b.rs".to_string(),
2494 },
2495 ],
2496 },
2497 diff_summary: " src/a.rs | 100 ++++\n src/b.rs | 100 ++++\n".to_string(),
2498 diff_file: flat_diff_path.to_string_lossy().to_string(),
2499 file_diffs: vec![
2500 FileDiffRef {
2501 path: "src/a.rs".to_string(),
2502 diff_file: path_a.to_string_lossy().to_string(),
2503 byte_len: diff_a.len(),
2504 },
2505 FileDiffRef {
2506 path: "src/b.rs".to_string(),
2507 diff_file: path_b.to_string_lossy().to_string(),
2508 byte_len: diff_b.len(),
2509 },
2510 ],
2511 },
2512 }],
2513 }
2514 }
2515
/// Renders an `amendments` YAML document with a single entry for `hash`
/// carrying `message`.
///
/// Both values are inserted verbatim between double quotes; no escaping is
/// performed, and the output has no trailing newline.
fn valid_amendment_yaml(hash: &str, message: &str) -> String {
    let mut yaml = String::from("amendments:\n");
    yaml.push_str(&format!(" - commit: \"{hash}\"\n"));
    yaml.push_str(&format!(" message: \"{message}\""));
    yaml
}
2519
2520 #[tokio::test]
2521 async fn generate_amendments_split_dispatch() {
2522 let dir = tempfile::tempdir().unwrap();
2523 let repo_view = make_large_diff_repo_view(&dir);
2524 let hash = "a".repeat(40);
2525
2526 let client = make_small_context_client(vec![
2528 Ok(valid_amendment_yaml(&hash, "feat(a): add a.rs")),
2529 Ok(valid_amendment_yaml(&hash, "feat(b): add b.rs")),
2530 Ok(valid_amendment_yaml(&hash, "feat(test): add a.rs and b.rs")),
2531 ]);
2532
2533 let result = client
2534 .generate_amendments_with_options(&repo_view, false)
2535 .await;
2536
2537 assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
2538 let amendments = result.unwrap();
2539 assert_eq!(amendments.amendments.len(), 1);
2540 assert_eq!(amendments.amendments[0].commit, hash);
2541 assert!(amendments.amendments[0]
2542 .message
2543 .contains("add a.rs and b.rs"));
2544 }
2545
2546 #[tokio::test]
2547 async fn generate_amendments_split_chunk_failure() {
2548 let dir = tempfile::tempdir().unwrap();
2549 let repo_view = make_large_diff_repo_view(&dir);
2550 let hash = "a".repeat(40);
2551
2552 let client = make_small_context_client(vec![
2554 Ok(valid_amendment_yaml(&hash, "feat(a): add a.rs")),
2555 Err(anyhow::anyhow!("rate limit exceeded")),
2556 ]);
2557
2558 let result = client
2559 .generate_amendments_with_options(&repo_view, false)
2560 .await;
2561
2562 assert!(result.is_err());
2563 }
2564
2565 #[tokio::test]
2566 async fn generate_amendments_no_split_when_fits() {
2567 let dir = tempfile::tempdir().unwrap();
2568 let repo_view = make_test_repo_view(&dir); let hash = format!("{:0>40}", 0);
2570
2571 let client = make_configurable_client(vec![Ok(valid_amendment_yaml(
2573 &hash,
2574 "feat(test): improved message",
2575 ))]);
2576
2577 let result = client
2578 .generate_amendments_with_options(&repo_view, false)
2579 .await;
2580
2581 assert!(result.is_ok());
2582 assert_eq!(result.unwrap().amendments.len(), 1);
2583 }
2584
/// Renders a `checks` YAML document with one entry for `hash`, the given
/// `passes` flag, an empty issue list and the fixed summary `test summary`.
fn valid_check_yaml_for(hash: &str, passes: bool) -> String {
    let commit_line = format!(" - commit: \"{hash}\"\n");
    let passes_line = format!(" passes: {passes}\n");
    [
        "checks:\n",
        commit_line.as_str(),
        passes_line.as_str(),
        " issues: []\n",
        " summary: \"test summary\"\n",
    ]
    .concat()
}
2592
/// Renders a failing `checks` YAML document for `hash` that carries one
/// `imperative-mood` error issue, a suggestion block and a summary.
fn valid_check_yaml_with_issues(hash: &str) -> String {
    // Only the commit line depends on the input; every other line is fixed.
    let commit_line = format!(" - commit: \"{hash}\"\n");
    [
        "checks:\n",
        commit_line.as_str(),
        " passes: false\n",
        " issues:\n",
        " - severity: error\n",
        " section: \"Subject Line\"\n",
        " rule: \"imperative-mood\"\n",
        " explanation: \"Subject uses past tense\"\n",
        " suggestion:\n",
        " message: \"feat(test): shorter subject\"\n",
        " explanation: \"Shortened subject line\"\n",
        " summary: \"Large commit with issues\"\n",
    ]
    .concat()
}
2612
/// Renders a passing `checks` YAML document for `hash` with no issues, no
/// suggestion block and the fixed summary `chunk summary`.
fn valid_check_yaml_chunk_no_suggestion(hash: &str) -> String {
    let mut yaml = String::from("checks:\n");
    yaml.push_str(&format!(" - commit: \"{hash}\"\n"));
    yaml.push_str(" passes: true\n");
    yaml.push_str(" issues: []\n");
    yaml.push_str(" summary: \"chunk summary\"\n");
    yaml
}
2625
2626 #[tokio::test]
2627 async fn check_commits_split_dispatch() {
2628 let dir = tempfile::tempdir().unwrap();
2629 let repo_view = make_large_diff_repo_view(&dir);
2630 let hash = "a".repeat(40);
2631
2632 let client = make_small_context_client(vec![
2634 Ok(valid_check_yaml_with_issues(&hash)),
2635 Ok(valid_check_yaml_with_issues(&hash)),
2636 Ok(valid_check_yaml_with_issues(&hash)), ]);
2638
2639 let result = client
2640 .check_commits_with_scopes(&repo_view, None, &[], true)
2641 .await;
2642
2643 assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
2644 let report = result.unwrap();
2645 assert_eq!(report.commits.len(), 1);
2646 assert!(!report.commits[0].passes);
2647 assert_eq!(report.commits[0].issues.len(), 1);
2649 assert_eq!(report.commits[0].issues[0].rule, "imperative-mood");
2650 }
2651
2652 #[tokio::test]
2653 async fn check_commits_split_dispatch_no_merge_when_no_suggestions() {
2654 let dir = tempfile::tempdir().unwrap();
2655 let repo_view = make_large_diff_repo_view(&dir);
2656 let hash = "a".repeat(40);
2657
2658 let client = make_small_context_client(vec![
2661 Ok(valid_check_yaml_chunk_no_suggestion(&hash)),
2662 Ok(valid_check_yaml_chunk_no_suggestion(&hash)),
2663 ]);
2664
2665 let result = client
2666 .check_commits_with_scopes(&repo_view, None, &[], false)
2667 .await;
2668
2669 assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
2670 let report = result.unwrap();
2671 assert_eq!(report.commits.len(), 1);
2672 assert!(report.commits[0].passes);
2673 assert!(report.commits[0].issues.is_empty());
2674 assert!(report.commits[0].suggestion.is_none());
2675 assert_eq!(report.commits[0].summary.as_deref(), Some("chunk summary"));
2677 }
2678
2679 #[tokio::test]
2680 async fn check_commits_split_chunk_failure() {
2681 let dir = tempfile::tempdir().unwrap();
2682 let repo_view = make_large_diff_repo_view(&dir);
2683 let hash = "a".repeat(40);
2684
2685 let client = make_small_context_client(vec![
2687 Ok(valid_check_yaml_for(&hash, true)),
2688 Err(anyhow::anyhow!("rate limit exceeded")),
2689 ]);
2690
2691 let result = client
2692 .check_commits_with_scopes(&repo_view, None, &[], false)
2693 .await;
2694
2695 assert!(result.is_err());
2696 }
2697
2698 #[tokio::test]
2699 async fn check_commits_no_split_when_fits() {
2700 let dir = tempfile::tempdir().unwrap();
2701 let repo_view = make_test_repo_view(&dir); let hash = format!("{:0>40}", 0);
2703
2704 let client = make_configurable_client(vec![Ok(valid_check_yaml_for(&hash, true))]);
2706
2707 let result = client
2708 .check_commits_with_scopes(&repo_view, None, &[], false)
2709 .await;
2710
2711 assert!(result.is_ok());
2712 assert_eq!(result.unwrap().commits.len(), 1);
2713 }
2714
2715 #[tokio::test]
2716 async fn check_commits_split_dedup_across_chunks() {
2717 let dir = tempfile::tempdir().unwrap();
2718 let repo_view = make_large_diff_repo_view(&dir);
2719 let hash = "a".repeat(40);
2720
2721 let chunk1 = format!(
2723 concat!(
2724 "checks:\n",
2725 " - commit: \"{hash}\"\n",
2726 " passes: false\n",
2727 " issues:\n",
2728 " - severity: error\n",
2729 " section: \"Subject Line\"\n",
2730 " rule: \"imperative-mood\"\n",
2731 " explanation: \"Subject uses past tense\"\n",
2732 " - severity: warning\n",
2733 " section: \"Content\"\n",
2734 " rule: \"body-required\"\n",
2735 " explanation: \"Large change needs body\"\n",
2736 ),
2737 hash = hash,
2738 );
2739
2740 let chunk2 = format!(
2742 concat!(
2743 "checks:\n",
2744 " - commit: \"{hash}\"\n",
2745 " passes: false\n",
2746 " issues:\n",
2747 " - severity: error\n",
2748 " section: \"Subject Line\"\n",
2749 " rule: \"imperative-mood\"\n",
2750 " explanation: \"Subject line is too long\"\n",
2751 " - severity: info\n",
2752 " section: \"Style\"\n",
2753 " rule: \"scope-suggestion\"\n",
2754 " explanation: \"Consider more specific scope\"\n",
2755 ),
2756 hash = hash,
2757 );
2758
2759 let client = make_small_context_client(vec![Ok(chunk1), Ok(chunk2)]);
2761
2762 let result = client
2763 .check_commits_with_scopes(&repo_view, None, &[], false)
2764 .await;
2765
2766 assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
2767 let report = result.unwrap();
2768 assert_eq!(report.commits.len(), 1);
2769 assert!(!report.commits[0].passes);
2770 assert_eq!(report.commits[0].issues.len(), 3);
2773 }
2774
2775 #[tokio::test]
2776 async fn check_commits_split_passes_only_when_all_chunks_pass() {
2777 let dir = tempfile::tempdir().unwrap();
2778 let repo_view = make_large_diff_repo_view(&dir);
2779 let hash = "a".repeat(40);
2780
2781 let client = make_small_context_client(vec![
2783 Ok(valid_check_yaml_for(&hash, true)),
2784 Ok(valid_check_yaml_for(&hash, false)),
2785 ]);
2786
2787 let result = client
2788 .check_commits_with_scopes(&repo_view, None, &[], false)
2789 .await;
2790
2791 assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
2792 let report = result.unwrap();
2793 assert!(
2794 !report.commits[0].passes,
2795 "should fail when any chunk fails"
2796 );
2797 }
2798
2799 fn make_multi_commit_repo_view(dir: &tempfile::TempDir) -> crate::data::RepositoryView {
2803 use crate::data::{AiInfo, FieldExplanation, WorkingDirectoryInfo};
2804 use crate::git::commit::FileChanges;
2805 use crate::git::{CommitAnalysis, CommitInfo};
2806
2807 let diff_a = dir.path().join("0.diff");
2808 let diff_b = dir.path().join("1.diff");
2809 std::fs::write(&diff_a, "+line a\n").unwrap();
2810 std::fs::write(&diff_b, "+line b\n").unwrap();
2811
2812 let hash_a = "a".repeat(40);
2813 let hash_b = "b".repeat(40);
2814
2815 crate::data::RepositoryView {
2816 versions: None,
2817 explanation: FieldExplanation::default(),
2818 working_directory: WorkingDirectoryInfo {
2819 clean: true,
2820 untracked_changes: Vec::new(),
2821 },
2822 remotes: Vec::new(),
2823 ai: AiInfo {
2824 scratch: String::new(),
2825 },
2826 branch_info: None,
2827 pr_template: None,
2828 pr_template_location: None,
2829 branch_prs: None,
2830 commits: vec![
2831 CommitInfo {
2832 hash: hash_a,
2833 author: "Test <test@test.com>".to_string(),
2834 date: chrono::Utc::now().fixed_offset(),
2835 original_message: "feat(a): add a".to_string(),
2836 in_main_branches: Vec::new(),
2837 analysis: CommitAnalysis {
2838 detected_type: "feat".to_string(),
2839 detected_scope: "a".to_string(),
2840 proposed_message: "feat(a): add a".to_string(),
2841 file_changes: FileChanges {
2842 total_files: 1,
2843 files_added: 1,
2844 files_deleted: 0,
2845 file_list: Vec::new(),
2846 },
2847 diff_summary: "a.rs | 1 +".to_string(),
2848 diff_file: diff_a.to_string_lossy().to_string(),
2849 file_diffs: Vec::new(),
2850 },
2851 },
2852 CommitInfo {
2853 hash: hash_b,
2854 author: "Test <test@test.com>".to_string(),
2855 date: chrono::Utc::now().fixed_offset(),
2856 original_message: "feat(b): add b".to_string(),
2857 in_main_branches: Vec::new(),
2858 analysis: CommitAnalysis {
2859 detected_type: "feat".to_string(),
2860 detected_scope: "b".to_string(),
2861 proposed_message: "feat(b): add b".to_string(),
2862 file_changes: FileChanges {
2863 total_files: 1,
2864 files_added: 1,
2865 files_deleted: 0,
2866 file_list: Vec::new(),
2867 },
2868 diff_summary: "b.rs | 1 +".to_string(),
2869 diff_file: diff_b.to_string_lossy().to_string(),
2870 file_diffs: Vec::new(),
2871 },
2872 },
2873 ],
2874 }
2875 }
2876
2877 #[tokio::test]
2878 async fn generate_amendments_multi_commit() {
2879 let dir = tempfile::tempdir().unwrap();
2880 let repo_view = make_multi_commit_repo_view(&dir);
2881 let hash_a = "a".repeat(40);
2882 let hash_b = "b".repeat(40);
2883
2884 let response = format!(
2885 concat!(
2886 "amendments:\n",
2887 " - commit: \"{hash_a}\"\n",
2888 " message: \"feat(a): improved a\"\n",
2889 " - commit: \"{hash_b}\"\n",
2890 " message: \"feat(b): improved b\"\n",
2891 ),
2892 hash_a = hash_a,
2893 hash_b = hash_b,
2894 );
2895 let client = make_configurable_client(vec![Ok(response)]);
2896
2897 let result = client
2898 .generate_amendments_with_options(&repo_view, false)
2899 .await;
2900
2901 assert!(
2902 result.is_ok(),
2903 "multi-commit amendment failed: {:?}",
2904 result.err()
2905 );
2906 let amendments = result.unwrap();
2907 assert_eq!(amendments.amendments.len(), 2);
2908 }
2909
2910 #[tokio::test]
2911 async fn generate_contextual_amendments_multi_commit() {
2912 let dir = tempfile::tempdir().unwrap();
2913 let repo_view = make_multi_commit_repo_view(&dir);
2914 let hash_a = "a".repeat(40);
2915 let hash_b = "b".repeat(40);
2916
2917 let response = format!(
2918 concat!(
2919 "amendments:\n",
2920 " - commit: \"{hash_a}\"\n",
2921 " message: \"feat(a): improved a\"\n",
2922 " - commit: \"{hash_b}\"\n",
2923 " message: \"feat(b): improved b\"\n",
2924 ),
2925 hash_a = hash_a,
2926 hash_b = hash_b,
2927 );
2928 let client = make_configurable_client(vec![Ok(response)]);
2929 let context = crate::data::context::CommitContext::default();
2930
2931 let result = client
2932 .generate_contextual_amendments_with_options(&repo_view, &context, false)
2933 .await;
2934
2935 assert!(
2936 result.is_ok(),
2937 "multi-commit contextual amendment failed: {:?}",
2938 result.err()
2939 );
2940 let amendments = result.unwrap();
2941 assert_eq!(amendments.amendments.len(), 2);
2942 }
2943
2944 #[tokio::test]
2945 async fn generate_pr_content_succeeds() {
2946 let dir = tempfile::tempdir().unwrap();
2947 let repo_view = make_test_repo_view(&dir);
2948
2949 let response = "title: \"feat: add something\"\ndescription: \"Adds a new feature.\"\n";
2950 let client = make_configurable_client(vec![Ok(response.to_string())]);
2951
2952 let result = client.generate_pr_content(&repo_view, "").await;
2953
2954 assert!(result.is_ok(), "PR generation failed: {:?}", result.err());
2955 let pr = result.unwrap();
2956 assert_eq!(pr.title, "feat: add something");
2957 assert_eq!(pr.description, "Adds a new feature.");
2958 }
2959
2960 #[tokio::test]
2961 async fn generate_pr_content_with_context_succeeds() {
2962 let dir = tempfile::tempdir().unwrap();
2963 let repo_view = make_test_repo_view(&dir);
2964 let context = crate::data::context::CommitContext::default();
2965
2966 let response = "title: \"feat: add something\"\ndescription: \"Adds a new feature.\"\n";
2967 let client = make_configurable_client(vec![Ok(response.to_string())]);
2968
2969 let result = client
2970 .generate_pr_content_with_context(&repo_view, "", &context)
2971 .await;
2972
2973 assert!(
2974 result.is_ok(),
2975 "PR generation with context failed: {:?}",
2976 result.err()
2977 );
2978 let pr = result.unwrap();
2979 assert_eq!(pr.title, "feat: add something");
2980 }
2981
2982 #[tokio::test]
2983 async fn check_commits_multi_commit() {
2984 let dir = tempfile::tempdir().unwrap();
2985 let repo_view = make_multi_commit_repo_view(&dir);
2986 let hash_a = "a".repeat(40);
2987 let hash_b = "b".repeat(40);
2988
2989 let response = format!(
2990 concat!(
2991 "checks:\n",
2992 " - commit: \"{hash_a}\"\n",
2993 " passes: true\n",
2994 " issues: []\n",
2995 " - commit: \"{hash_b}\"\n",
2996 " passes: true\n",
2997 " issues: []\n",
2998 ),
2999 hash_a = hash_a,
3000 hash_b = hash_b,
3001 );
3002 let client = make_configurable_client(vec![Ok(response)]);
3003
3004 let result = client
3005 .check_commits_with_scopes(&repo_view, None, &[], false)
3006 .await;
3007
3008 assert!(
3009 result.is_ok(),
3010 "multi-commit check failed: {:?}",
3011 result.err()
3012 );
3013 let report = result.unwrap();
3014 assert_eq!(report.commits.len(), 2);
3015 assert!(report.commits[0].passes);
3016 assert!(report.commits[1].passes);
3017 }
3018
3019 fn make_large_multi_commit_repo_view(dir: &tempfile::TempDir) -> crate::data::RepositoryView {
3024 use crate::data::{AiInfo, FieldExplanation, WorkingDirectoryInfo};
3025 use crate::git::commit::{FileChange, FileChanges, FileDiffRef};
3026 use crate::git::{CommitAnalysis, CommitInfo};
3027
3028 let hash_a = "a".repeat(40);
3029 let hash_b = "b".repeat(40);
3030
3031 let diff_content_a = "x".repeat(60_000);
3034 let diff_content_b = "y".repeat(60_000);
3035 let flat_a = dir.path().join("flat_a.diff");
3036 let flat_b = dir.path().join("flat_b.diff");
3037 std::fs::write(&flat_a, &diff_content_a).unwrap();
3038 std::fs::write(&flat_b, &diff_content_b).unwrap();
3039
3040 let file_diff_a = format!("diff --git a/src/a.rs b/src/a.rs\n{}\n", "a".repeat(30_000));
3042 let file_diff_b = format!("diff --git a/src/b.rs b/src/b.rs\n{}\n", "b".repeat(30_000));
3043 let per_file_a = dir.path().join("pf_a.diff");
3044 let per_file_b = dir.path().join("pf_b.diff");
3045 std::fs::write(&per_file_a, &file_diff_a).unwrap();
3046 std::fs::write(&per_file_b, &file_diff_b).unwrap();
3047
3048 crate::data::RepositoryView {
3049 versions: None,
3050 explanation: FieldExplanation::default(),
3051 working_directory: WorkingDirectoryInfo {
3052 clean: true,
3053 untracked_changes: Vec::new(),
3054 },
3055 remotes: Vec::new(),
3056 ai: AiInfo {
3057 scratch: String::new(),
3058 },
3059 branch_info: None,
3060 pr_template: None,
3061 pr_template_location: None,
3062 branch_prs: None,
3063 commits: vec![
3064 CommitInfo {
3065 hash: hash_a,
3066 author: "Test <test@test.com>".to_string(),
3067 date: chrono::Utc::now().fixed_offset(),
3068 original_message: "feat(a): add module a".to_string(),
3069 in_main_branches: Vec::new(),
3070 analysis: CommitAnalysis {
3071 detected_type: "feat".to_string(),
3072 detected_scope: "a".to_string(),
3073 proposed_message: "feat(a): add module a".to_string(),
3074 file_changes: FileChanges {
3075 total_files: 1,
3076 files_added: 1,
3077 files_deleted: 0,
3078 file_list: vec![FileChange {
3079 status: "A".to_string(),
3080 file: "src/a.rs".to_string(),
3081 }],
3082 },
3083 diff_summary: " src/a.rs | 100 ++++\n".to_string(),
3084 diff_file: flat_a.to_string_lossy().to_string(),
3085 file_diffs: vec![FileDiffRef {
3086 path: "src/a.rs".to_string(),
3087 diff_file: per_file_a.to_string_lossy().to_string(),
3088 byte_len: file_diff_a.len(),
3089 }],
3090 },
3091 },
3092 CommitInfo {
3093 hash: hash_b,
3094 author: "Test <test@test.com>".to_string(),
3095 date: chrono::Utc::now().fixed_offset(),
3096 original_message: "feat(b): add module b".to_string(),
3097 in_main_branches: Vec::new(),
3098 analysis: CommitAnalysis {
3099 detected_type: "feat".to_string(),
3100 detected_scope: "b".to_string(),
3101 proposed_message: "feat(b): add module b".to_string(),
3102 file_changes: FileChanges {
3103 total_files: 1,
3104 files_added: 1,
3105 files_deleted: 0,
3106 file_list: vec![FileChange {
3107 status: "A".to_string(),
3108 file: "src/b.rs".to_string(),
3109 }],
3110 },
3111 diff_summary: " src/b.rs | 100 ++++\n".to_string(),
3112 diff_file: flat_b.to_string_lossy().to_string(),
3113 file_diffs: vec![FileDiffRef {
3114 path: "src/b.rs".to_string(),
3115 diff_file: per_file_b.to_string_lossy().to_string(),
3116 byte_len: file_diff_b.len(),
3117 }],
3118 },
3119 },
3120 ],
3121 }
3122 }
3123
/// Renders a two-line PR YAML document with `title` and `description`.
///
/// Both values are inserted verbatim between double quotes; no escaping is
/// performed.
fn valid_pr_yaml(title: &str, description: &str) -> String {
    let title_line = format!("title: \"{title}\"\n");
    let description_line = format!("description: \"{description}\"\n");
    title_line + &description_line
}
3127
3128 #[tokio::test]
3131 async fn generate_amendments_multi_commit_split_dispatch() {
3132 let dir = tempfile::tempdir().unwrap();
3133 let repo_view = make_large_multi_commit_repo_view(&dir);
3134 let hash_a = "a".repeat(40);
3135 let hash_b = "b".repeat(40);
3136
3137 let (client, handle) = make_small_context_client_tracked(vec![
3140 Ok(valid_amendment_yaml(&hash_a, "feat(a): improved a")),
3141 Ok(valid_amendment_yaml(&hash_b, "feat(b): improved b")),
3142 ]);
3143
3144 let result = client
3145 .generate_amendments_with_options(&repo_view, false)
3146 .await;
3147
3148 assert!(
3149 result.is_ok(),
3150 "multi-commit split dispatch failed: {:?}",
3151 result.err()
3152 );
3153 let amendments = result.unwrap();
3154 assert_eq!(amendments.amendments.len(), 2);
3155 assert_eq!(amendments.amendments[0].commit, hash_a);
3156 assert_eq!(amendments.amendments[1].commit, hash_b);
3157 assert!(amendments.amendments[0].message.contains("improved a"));
3158 assert!(amendments.amendments[1].message.contains("improved b"));
3159 assert_eq!(handle.remaining(), 0, "expected all responses consumed");
3160 }
3161
3162 #[tokio::test]
3163 async fn generate_contextual_amendments_multi_commit_split_dispatch() {
3164 let dir = tempfile::tempdir().unwrap();
3165 let repo_view = make_large_multi_commit_repo_view(&dir);
3166 let hash_a = "a".repeat(40);
3167 let hash_b = "b".repeat(40);
3168 let context = crate::data::context::CommitContext::default();
3169
3170 let (client, handle) = make_small_context_client_tracked(vec![
3171 Ok(valid_amendment_yaml(&hash_a, "feat(a): improved a")),
3172 Ok(valid_amendment_yaml(&hash_b, "feat(b): improved b")),
3173 ]);
3174
3175 let result = client
3176 .generate_contextual_amendments_with_options(&repo_view, &context, false)
3177 .await;
3178
3179 assert!(
3180 result.is_ok(),
3181 "multi-commit contextual split dispatch failed: {:?}",
3182 result.err()
3183 );
3184 let amendments = result.unwrap();
3185 assert_eq!(amendments.amendments.len(), 2);
3186 assert_eq!(amendments.amendments[0].commit, hash_a);
3187 assert_eq!(amendments.amendments[1].commit, hash_b);
3188 assert_eq!(handle.remaining(), 0, "expected all responses consumed");
3189 }
3190
3191 #[tokio::test]
3194 async fn check_commits_multi_commit_split_dispatch() {
3195 let dir = tempfile::tempdir().unwrap();
3196 let repo_view = make_large_multi_commit_repo_view(&dir);
3197 let hash_a = "a".repeat(40);
3198 let hash_b = "b".repeat(40);
3199
3200 let (client, handle) = make_small_context_client_tracked(vec![
3202 Ok(valid_check_yaml_for(&hash_a, true)),
3203 Ok(valid_check_yaml_for(&hash_b, true)),
3204 ]);
3205
3206 let result = client
3207 .check_commits_with_scopes(&repo_view, None, &[], false)
3208 .await;
3209
3210 assert!(
3211 result.is_ok(),
3212 "multi-commit check split dispatch failed: {:?}",
3213 result.err()
3214 );
3215 let report = result.unwrap();
3216 assert_eq!(report.commits.len(), 2);
3217 assert!(report.commits[0].passes);
3218 assert!(report.commits[1].passes);
3219 assert_eq!(handle.remaining(), 0, "expected all responses consumed");
3220 }
3221
3222 #[tokio::test]
3225 async fn generate_pr_content_split_dispatch() {
3226 let dir = tempfile::tempdir().unwrap();
3227 let repo_view = make_large_diff_repo_view(&dir);
3228
3229 let (client, handle) = make_small_context_client_tracked(vec![
3233 Ok(valid_pr_yaml("feat(a): add a.rs", "Adds a.rs module")),
3234 Ok(valid_pr_yaml("feat(b): add b.rs", "Adds b.rs module")),
3235 Ok(valid_pr_yaml(
3236 "feat(test): add modules",
3237 "Adds a.rs and b.rs",
3238 )),
3239 ]);
3240
3241 let result = client.generate_pr_content(&repo_view, "").await;
3242
3243 assert!(
3244 result.is_ok(),
3245 "PR split dispatch failed: {:?}",
3246 result.err()
3247 );
3248 let pr = result.unwrap();
3249 assert!(pr.title.contains("add modules"));
3250 assert_eq!(handle.remaining(), 0, "expected all responses consumed");
3251 }
3252
3253 #[tokio::test]
3254 async fn generate_pr_content_multi_commit_split_dispatch() {
3255 let dir = tempfile::tempdir().unwrap();
3256 let repo_view = make_large_multi_commit_repo_view(&dir);
3257
3258 let (client, handle) = make_small_context_client_tracked(vec![
3261 Ok(valid_pr_yaml("feat(a): add module a", "Adds module a")),
3262 Ok(valid_pr_yaml("feat(b): add module b", "Adds module b")),
3263 Ok(valid_pr_yaml(
3264 "feat: add modules a and b",
3265 "Adds both modules",
3266 )),
3267 ]);
3268
3269 let result = client.generate_pr_content(&repo_view, "").await;
3270
3271 assert!(
3272 result.is_ok(),
3273 "PR multi-commit split dispatch failed: {:?}",
3274 result.err()
3275 );
3276 let pr = result.unwrap();
3277 assert!(pr.title.contains("modules"));
3278 assert_eq!(handle.remaining(), 0, "expected all responses consumed");
3279 }
3280
3281 #[tokio::test]
3282 async fn generate_pr_content_with_context_split_dispatch() {
3283 let dir = tempfile::tempdir().unwrap();
3284 let repo_view = make_large_multi_commit_repo_view(&dir);
3285 let context = crate::data::context::CommitContext::default();
3286
3287 let (client, handle) = make_small_context_client_tracked(vec![
3289 Ok(valid_pr_yaml("feat(a): add module a", "Adds module a")),
3290 Ok(valid_pr_yaml("feat(b): add module b", "Adds module b")),
3291 Ok(valid_pr_yaml(
3292 "feat: add modules a and b",
3293 "Adds both modules",
3294 )),
3295 ]);
3296
3297 let result = client
3298 .generate_pr_content_with_context(&repo_view, "", &context)
3299 .await;
3300
3301 assert!(
3302 result.is_ok(),
3303 "PR with context split dispatch failed: {:?}",
3304 result.err()
3305 );
3306 let pr = result.unwrap();
3307 assert!(pr.title.contains("modules"));
3308 assert_eq!(handle.remaining(), 0, "expected all responses consumed");
3309 }
3310
3311 fn make_small_context_client_with_prompts(
3316 responses: Vec<Result<String>>,
3317 ) -> (
3318 ClaudeClient,
3319 crate::claude::test_utils::ResponseQueueHandle,
3320 crate::claude::test_utils::PromptRecordHandle,
3321 ) {
3322 let mock = crate::claude::test_utils::ConfigurableMockAiClient::new(responses)
3323 .with_context_length(50_000);
3324 let response_handle = mock.response_handle();
3325 let prompt_handle = mock.prompt_handle();
3326 (
3327 ClaudeClient::new(Box::new(mock)),
3328 response_handle,
3329 prompt_handle,
3330 )
3331 }
3332
3333 fn make_configurable_client_with_prompts(
3335 responses: Vec<Result<String>>,
3336 ) -> (
3337 ClaudeClient,
3338 crate::claude::test_utils::ResponseQueueHandle,
3339 crate::claude::test_utils::PromptRecordHandle,
3340 ) {
3341 let mock = crate::claude::test_utils::ConfigurableMockAiClient::new(responses);
3342 let response_handle = mock.response_handle();
3343 let prompt_handle = mock.prompt_handle();
3344 (
3345 ClaudeClient::new(Box::new(mock)),
3346 response_handle,
3347 prompt_handle,
3348 )
3349 }
3350
3351 fn make_single_oversized_file_repo_view(
3358 dir: &tempfile::TempDir,
3359 ) -> crate::data::RepositoryView {
3360 use crate::data::{AiInfo, FieldExplanation, WorkingDirectoryInfo};
3361 use crate::git::commit::{FileChange, FileChanges, FileDiffRef};
3362 use crate::git::{CommitAnalysis, CommitInfo};
3363
3364 let hash = "c".repeat(40);
3365
3366 let diff_content = format!(
3369 "diff --git a/src/big.rs b/src/big.rs\n{}\n",
3370 "x".repeat(80_000)
3371 );
3372
3373 let flat_diff_path = dir.path().join("full.diff");
3374 std::fs::write(&flat_diff_path, &diff_content).unwrap();
3375
3376 let per_file_path = dir.path().join("0000.diff");
3377 std::fs::write(&per_file_path, &diff_content).unwrap();
3378
3379 crate::data::RepositoryView {
3380 versions: None,
3381 explanation: FieldExplanation::default(),
3382 working_directory: WorkingDirectoryInfo {
3383 clean: true,
3384 untracked_changes: Vec::new(),
3385 },
3386 remotes: Vec::new(),
3387 ai: AiInfo {
3388 scratch: String::new(),
3389 },
3390 branch_info: None,
3391 pr_template: None,
3392 pr_template_location: None,
3393 branch_prs: None,
3394 commits: vec![CommitInfo {
3395 hash,
3396 author: "Test <test@test.com>".to_string(),
3397 date: chrono::Utc::now().fixed_offset(),
3398 original_message: "feat(big): add large module".to_string(),
3399 in_main_branches: Vec::new(),
3400 analysis: CommitAnalysis {
3401 detected_type: "feat".to_string(),
3402 detected_scope: "big".to_string(),
3403 proposed_message: "feat(big): add large module".to_string(),
3404 file_changes: FileChanges {
3405 total_files: 1,
3406 files_added: 1,
3407 files_deleted: 0,
3408 file_list: vec![FileChange {
3409 status: "A".to_string(),
3410 file: "src/big.rs".to_string(),
3411 }],
3412 },
3413 diff_summary: " src/big.rs | 80 ++++\n".to_string(),
3414 diff_file: flat_diff_path.to_string_lossy().to_string(),
3415 file_diffs: vec![FileDiffRef {
3416 path: "src/big.rs".to_string(),
3417 diff_file: per_file_path.to_string_lossy().to_string(),
3418 byte_len: diff_content.len(),
3419 }],
3420 },
3421 }],
3422 }
3423 }
3424
3425 #[tokio::test]
3432 async fn amendment_single_file_under_budget_no_split() {
3433 let dir = tempfile::tempdir().unwrap();
3434 let repo_view = make_test_repo_view(&dir);
3435 let hash = format!("{:0>40}", 0);
3436
3437 let (client, response_handle, prompt_handle) =
3438 make_configurable_client_with_prompts(vec![Ok(valid_amendment_yaml(
3439 &hash,
3440 "feat(test): improved message",
3441 ))]);
3442
3443 let result = client
3444 .generate_amendments_with_options(&repo_view, false)
3445 .await;
3446
3447 assert!(result.is_ok());
3448 assert_eq!(result.unwrap().amendments.len(), 1);
3449 assert_eq!(response_handle.remaining(), 0);
3450
3451 let prompts = prompt_handle.prompts();
3452 assert_eq!(
3453 prompts.len(),
3454 1,
3455 "expected exactly one AI request, no split"
3456 );
3457
3458 let (_, user_prompt) = &prompts[0];
3459 assert!(
3460 user_prompt.contains("added line"),
3461 "user prompt should contain the diff content"
3462 );
3463 }
3464
3465 #[tokio::test]
3476 async fn amendment_two_chunks_prompt_content() {
3477 let dir = tempfile::tempdir().unwrap();
3478 let repo_view = make_large_diff_repo_view(&dir);
3479 let hash = "a".repeat(40);
3480
3481 let (client, response_handle, prompt_handle) =
3482 make_small_context_client_with_prompts(vec![
3483 Ok(valid_amendment_yaml(&hash, "feat(a): add a.rs")),
3484 Ok(valid_amendment_yaml(&hash, "feat(b): add b.rs")),
3485 Ok(valid_amendment_yaml(&hash, "feat(test): add a.rs and b.rs")),
3486 ]);
3487
3488 let result = client
3489 .generate_amendments_with_options(&repo_view, false)
3490 .await;
3491
3492 assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
3493 let amendments = result.unwrap();
3494 assert_eq!(amendments.amendments.len(), 1);
3495 assert!(amendments.amendments[0]
3496 .message
3497 .contains("add a.rs and b.rs"));
3498 assert_eq!(response_handle.remaining(), 0);
3499
3500 let prompts = prompt_handle.prompts();
3501 assert_eq!(prompts.len(), 3, "expected 2 chunks + 1 merge = 3 requests");
3502
3503 let (_, chunk1_user) = &prompts[0];
3505 assert!(
3506 chunk1_user.contains("aaa"),
3507 "chunk 1 prompt should contain file-a diff content"
3508 );
3509
3510 let (_, chunk2_user) = &prompts[1];
3512 assert!(
3513 chunk2_user.contains("bbb"),
3514 "chunk 2 prompt should contain file-b diff content"
3515 );
3516
3517 let (merge_sys, merge_user) = &prompts[2];
3519 assert!(
3520 merge_sys.contains("synthesiz"),
3521 "merge system prompt should contain synthesis instructions"
3522 );
3523 assert!(
3525 merge_user.contains("feat(a): add a.rs") && merge_user.contains("feat(b): add b.rs"),
3526 "merge user prompt should contain both partial amendment messages"
3527 );
3528 }
3529
3530 #[tokio::test]
3542 async fn amendment_single_oversized_file_gets_placeholder() {
3543 let dir = tempfile::tempdir().unwrap();
3544 let repo_view = make_single_oversized_file_repo_view(&dir);
3545 let hash = "c".repeat(40);
3546
3547 let (client, _, prompt_handle) = make_small_context_client_with_prompts(vec![
3552 Ok(valid_amendment_yaml(&hash, "feat(big): add large module")),
3553 Ok(valid_amendment_yaml(&hash, "feat(big): add large module")),
3554 ]);
3555
3556 let result = client
3557 .generate_amendments_with_options(&repo_view, false)
3558 .await;
3559
3560 assert!(
3562 result.is_ok(),
3563 "expected success with placeholder, got: {result:?}"
3564 );
3565
3566 assert!(
3568 prompt_handle.request_count() >= 1,
3569 "expected at least 1 request, got {}",
3570 prompt_handle.request_count()
3571 );
3572 }
3573
3574 #[tokio::test]
3583 async fn amendment_chunk_failure_stops_dispatch() {
3584 let dir = tempfile::tempdir().unwrap();
3585 let repo_view = make_large_diff_repo_view(&dir);
3586 let hash = "a".repeat(40);
3587
3588 let (client, _, prompt_handle) = make_small_context_client_with_prompts(vec![
3590 Ok(valid_amendment_yaml(&hash, "feat(a): add a.rs")),
3591 Err(anyhow::anyhow!("rate limit exceeded")),
3592 ]);
3593
3594 let result = client
3595 .generate_amendments_with_options(&repo_view, false)
3596 .await;
3597
3598 assert!(result.is_err());
3599
3600 let prompts = prompt_handle.prompts();
3602 assert_eq!(
3603 prompts.len(),
3604 2,
3605 "should stop after the failing chunk, got {} requests",
3606 prompts.len()
3607 );
3608
3609 let (_, first_user) = &prompts[0];
3611 assert!(
3612 first_user.contains("src/a.rs") || first_user.contains("src/b.rs"),
3613 "first chunk prompt should reference a file"
3614 );
3615 }
3616
3617 #[tokio::test]
3628 async fn amendment_reduce_pass_prompt_content() {
3629 let dir = tempfile::tempdir().unwrap();
3630 let repo_view = make_large_diff_repo_view(&dir);
3631 let hash = "a".repeat(40);
3632
3633 let (client, _, prompt_handle) = make_small_context_client_with_prompts(vec![
3634 Ok(valid_amendment_yaml(
3635 &hash,
3636 "feat(a): add module a implementation",
3637 )),
3638 Ok(valid_amendment_yaml(
3639 &hash,
3640 "feat(b): add module b implementation",
3641 )),
3642 Ok(valid_amendment_yaml(
3643 &hash,
3644 "feat(test): add modules a and b",
3645 )),
3646 ]);
3647
3648 let result = client
3649 .generate_amendments_with_options(&repo_view, false)
3650 .await;
3651
3652 assert!(result.is_ok());
3653
3654 let prompts = prompt_handle.prompts();
3655 assert_eq!(prompts.len(), 3);
3656
3657 let (merge_system, merge_user) = &prompts[2];
3659
3660 assert!(
3662 merge_system.contains("synthesiz"),
3663 "merge system prompt should contain synthesis instructions"
3664 );
3665
3666 assert!(
3668 merge_user.contains("feat(a): add module a implementation"),
3669 "merge user prompt should contain chunk 1's partial message"
3670 );
3671 assert!(
3672 merge_user.contains("feat(b): add module b implementation"),
3673 "merge user prompt should contain chunk 2's partial message"
3674 );
3675
3676 assert!(
3678 merge_user.contains("feat(test): large commit"),
3679 "merge user prompt should contain the original commit message"
3680 );
3681
3682 assert!(
3684 merge_user.contains("src/a.rs") && merge_user.contains("src/b.rs"),
3685 "merge user prompt should contain the diff_summary"
3686 );
3687
3688 assert!(
3690 merge_user.contains(&hash),
3691 "merge user prompt should reference the commit hash"
3692 );
3693 }
3694
    /// Split-dispatch path of `check_commits_with_scopes`: two chunk responses
    /// plus one merge response produce a single report entry.
    ///
    /// The test asserts that:
    /// - the report contains one failing commit and all three queued responses
    ///   were consumed;
    /// - the final issue list has 3 entries even though the two chunk responses
    ///   list 4 between them and the merge response lists none (both chunks
    ///   report an `imperative-mood` issue, with different explanations);
    /// - the suggestion text is the one from the merge response;
    /// - the chunk prompts together mention both files, and the merge prompt
    ///   contains both partial suggestions and both file paths.
    ///
    /// NOTE(review): the YAML fixtures below are whitespace-sensitive; confirm
    /// the leading spaces inside each literal produce the intended nesting.
    #[tokio::test]
    async fn check_split_dedup_and_merge_prompt() {
        let dir = tempfile::tempdir().unwrap();
        let repo_view = make_large_diff_repo_view(&dir);
        let hash = "a".repeat(40);

        // Chunk 1 response: an `imperative-mood` error and a `body-required` warning.
        let chunk1_yaml = format!(
            concat!(
                "checks:\n",
                " - commit: \"{hash}\"\n",
                " passes: false\n",
                " issues:\n",
                " - severity: error\n",
                " section: \"Subject Line\"\n",
                " rule: \"imperative-mood\"\n",
                " explanation: \"Subject uses past tense\"\n",
                " - severity: warning\n",
                " section: \"Content\"\n",
                " rule: \"body-required\"\n",
                " explanation: \"Large change needs body\"\n",
                " suggestion:\n",
                " message: \"feat(a): shorter subject for a\"\n",
                " explanation: \"Shortened subject for file a\"\n",
                " summary: \"Adds module a\"\n",
            ),
            hash = hash,
        );

        // Chunk 2 response: the same `imperative-mood` rule (different
        // explanation) and a `scope-suggestion` info entry.
        let chunk2_yaml = format!(
            concat!(
                "checks:\n",
                " - commit: \"{hash}\"\n",
                " passes: false\n",
                " issues:\n",
                " - severity: error\n",
                " section: \"Subject Line\"\n",
                " rule: \"imperative-mood\"\n",
                " explanation: \"Subject line is way too long\"\n",
                " - severity: info\n",
                " section: \"Style\"\n",
                " rule: \"scope-suggestion\"\n",
                " explanation: \"Consider more specific scope\"\n",
                " suggestion:\n",
                " message: \"feat(b): shorter subject for b\"\n",
                " explanation: \"Shortened subject for file b\"\n",
                " summary: \"Adds module b\"\n",
            ),
            hash = hash,
        );

        // Merge response: no issues of its own, only a combined suggestion and summary.
        let merge_yaml = format!(
            concat!(
                "checks:\n",
                " - commit: \"{hash}\"\n",
                " passes: false\n",
                " issues: []\n",
                " suggestion:\n",
                " message: \"feat(test): add modules a and b\"\n",
                " explanation: \"Combined suggestion\"\n",
                " summary: \"Adds modules a and b\"\n",
            ),
            hash = hash,
        );

        let (client, response_handle, prompt_handle) =
            make_small_context_client_with_prompts(vec![
                Ok(chunk1_yaml),
                Ok(chunk2_yaml),
                Ok(merge_yaml),
            ]);

        let result = client
            .check_commits_with_scopes(&repo_view, None, &[], true)
            .await;

        assert!(result.is_ok(), "split dispatch failed: {:?}", result.err());
        let report = result.unwrap();
        assert_eq!(report.commits.len(), 1);
        assert!(!report.commits[0].passes);
        assert_eq!(response_handle.remaining(), 0);

        // Four issues were reported across the chunks; three are expected to
        // remain once the repeated `imperative-mood` entry is collapsed.
        assert_eq!(
            report.commits[0].issues.len(),
            3,
            "expected 3 unique issues after dedup, got {:?}",
            report.commits[0]
                .issues
                .iter()
                .map(|i| &i.rule)
                .collect::<Vec<_>>()
        );

        assert!(report.commits[0].suggestion.is_some());
        // The final suggestion must be the merge response's, not a chunk's.
        assert!(
            report.commits[0]
                .suggestion
                .as_ref()
                .unwrap()
                .message
                .contains("add modules a and b"),
            "suggestion should come from the merge pass"
        );

        let prompts = prompt_handle.prompts();
        // Recorded prompts are (system, user) pairs: two chunks, then the merge.
        assert_eq!(prompts.len(), 3, "expected 2 chunks + 1 merge");

        let (_, chunk1_user) = &prompts[0];
        // Which chunk holds which file is not asserted; only joint coverage is.
        let (_, chunk2_user) = &prompts[1];
        let combined_chunk_prompts = format!("{chunk1_user}{chunk2_user}");
        assert!(
            combined_chunk_prompts.contains("src/a.rs")
                && combined_chunk_prompts.contains("src/b.rs"),
            "chunk prompts should collectively cover both files"
        );

        let (merge_sys, merge_user) = &prompts[2];
        // Either keyword is accepted as evidence of the merge system prompt.
        assert!(
            merge_sys.contains("synthesiz") || merge_sys.contains("reviewer"),
            "merge system prompt should be the check chunk merge prompt"
        );
        assert!(
            merge_user.contains("feat(a): shorter subject for a")
                && merge_user.contains("feat(b): shorter subject for b"),
            "merge user prompt should contain both partial suggestions"
        );
        assert!(
            merge_user.contains("src/a.rs") && merge_user.contains("src/b.rs"),
            "merge user prompt should contain the diff_summary"
        );
    }
3853
3854 #[tokio::test]
3857 async fn amendment_retry_parse_failure_then_success() {
3858 let dir = tempfile::tempdir().unwrap();
3859 let repo_view = make_test_repo_view(&dir);
3860 let hash = format!("{:0>40}", 0);
3861
3862 let (client, response_handle, prompt_handle) = make_configurable_client_with_prompts(vec![
3863 Ok("not valid yaml {{[".to_string()),
3864 Ok(valid_amendment_yaml(&hash, "feat(test): improved")),
3865 ]);
3866
3867 let result = client
3868 .generate_amendments_with_options(&repo_view, false)
3869 .await;
3870
3871 assert!(
3872 result.is_ok(),
3873 "should succeed after retry: {:?}",
3874 result.err()
3875 );
3876 assert_eq!(result.unwrap().amendments.len(), 1);
3877 assert_eq!(response_handle.remaining(), 0, "both responses consumed");
3878 assert_eq!(prompt_handle.request_count(), 2, "exactly 2 AI requests");
3879 }
3880
3881 #[tokio::test]
3882 async fn amendment_retry_request_failure_then_success() {
3883 let dir = tempfile::tempdir().unwrap();
3884 let repo_view = make_test_repo_view(&dir);
3885 let hash = format!("{:0>40}", 0);
3886
3887 let (client, response_handle, prompt_handle) = make_configurable_client_with_prompts(vec![
3888 Err(anyhow::anyhow!("rate limit")),
3889 Ok(valid_amendment_yaml(&hash, "feat(test): improved")),
3890 ]);
3891
3892 let result = client
3893 .generate_amendments_with_options(&repo_view, false)
3894 .await;
3895
3896 assert!(
3897 result.is_ok(),
3898 "should succeed after retry: {:?}",
3899 result.err()
3900 );
3901 assert_eq!(result.unwrap().amendments.len(), 1);
3902 assert_eq!(response_handle.remaining(), 0);
3903 assert_eq!(prompt_handle.request_count(), 2);
3904 }
3905
3906 #[tokio::test]
3907 async fn amendment_retry_all_attempts_exhausted() {
3908 let dir = tempfile::tempdir().unwrap();
3909 let repo_view = make_test_repo_view(&dir);
3910
3911 let (client, response_handle, prompt_handle) = make_configurable_client_with_prompts(vec![
3912 Ok("bad yaml 1".to_string()),
3913 Ok("bad yaml 2".to_string()),
3914 Ok("bad yaml 3".to_string()),
3915 ]);
3916
3917 let result = client
3918 .generate_amendments_with_options(&repo_view, false)
3919 .await;
3920
3921 assert!(result.is_err(), "should fail after all retries exhausted");
3922 assert_eq!(response_handle.remaining(), 0, "all 3 responses consumed");
3923 assert_eq!(
3924 prompt_handle.request_count(),
3925 3,
3926 "exactly 3 AI requests (1 + 2 retries)"
3927 );
3928 }
3929
3930 #[tokio::test]
3931 async fn amendment_retry_success_first_attempt() {
3932 let dir = tempfile::tempdir().unwrap();
3933 let repo_view = make_test_repo_view(&dir);
3934 let hash = format!("{:0>40}", 0);
3935
3936 let (client, response_handle, prompt_handle) =
3937 make_configurable_client_with_prompts(vec![Ok(valid_amendment_yaml(
3938 &hash,
3939 "feat(test): works first time",
3940 ))]);
3941
3942 let result = client
3943 .generate_amendments_with_options(&repo_view, false)
3944 .await;
3945
3946 assert!(result.is_ok());
3947 assert_eq!(response_handle.remaining(), 0);
3948 assert_eq!(prompt_handle.request_count(), 1, "only 1 request, no retry");
3949 }
3950
3951 #[tokio::test]
3952 async fn amendment_retry_mixed_request_and_parse_failures() {
3953 let dir = tempfile::tempdir().unwrap();
3954 let repo_view = make_test_repo_view(&dir);
3955 let hash = format!("{:0>40}", 0);
3956
3957 let (client, response_handle, prompt_handle) = make_configurable_client_with_prompts(vec![
3958 Err(anyhow::anyhow!("network error")),
3959 Ok("invalid yaml {{".to_string()),
3960 Ok(valid_amendment_yaml(&hash, "feat(test): third time")),
3961 ]);
3962
3963 let result = client
3964 .generate_amendments_with_options(&repo_view, false)
3965 .await;
3966
3967 assert!(
3968 result.is_ok(),
3969 "should succeed on third attempt: {:?}",
3970 result.err()
3971 );
3972 assert_eq!(result.unwrap().amendments.len(), 1);
3973 assert_eq!(response_handle.remaining(), 0);
3974 assert_eq!(prompt_handle.request_count(), 3, "all 3 attempts used");
3975 }
3976
3977 static FACTORY_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
3981
3982 struct FactoryEnvGuard {
3983 _lock: std::sync::MutexGuard<'static, ()>,
3984 saved: Vec<(&'static str, Option<String>)>,
3985 }
3986
3987 impl FactoryEnvGuard {
3988 fn new(keys: &[&'static str]) -> Self {
3989 let lock = FACTORY_ENV_LOCK
3990 .lock()
3991 .unwrap_or_else(std::sync::PoisonError::into_inner);
3992 let saved = keys.iter().map(|k| (*k, std::env::var(k).ok())).collect();
3993 for k in keys {
3994 std::env::remove_var(k);
3995 }
3996 Self { _lock: lock, saved }
3997 }
3998
3999 fn set(&self, key: &str, value: &str) {
4000 std::env::set_var(key, value);
4001 }
4002 }
4003
4004 impl Drop for FactoryEnvGuard {
4005 fn drop(&mut self) {
4006 for (k, v) in self.saved.drain(..) {
4007 match v {
4008 Some(val) => std::env::set_var(k, val),
4009 None => std::env::remove_var(k),
4010 }
4011 }
4012 }
4013 }
4014
4015 #[tokio::test]
4016 async fn factory_claude_cli_backend_dispatches_to_claude_cli_client() {
4017 let guard = FactoryEnvGuard::new(&[
4018 "OMNI_DEV_AI_BACKEND",
4019 "USE_OPENAI",
4020 "USE_OLLAMA",
4021 "CLAUDE_CODE_USE_BEDROCK",
4022 "CLAUDE_MODEL",
4023 "CLAUDE_CODE_MODEL",
4024 "ANTHROPIC_MODEL",
4025 ]);
4026 guard.set("OMNI_DEV_AI_BACKEND", "claude-cli");
4027
4028 let client = create_default_claude_client(None, None)
4029 .await
4030 .expect("factory should succeed");
4031 let metadata = client.get_ai_client_metadata();
4032 assert_eq!(metadata.provider, "Claude CLI");
4033 assert_eq!(metadata.model, "claude-sonnet-4-6");
4035 }
4036
4037 #[tokio::test]
4038 async fn factory_claude_cli_backend_honours_model_precedence() {
4039 let guard = FactoryEnvGuard::new(&[
4040 "OMNI_DEV_AI_BACKEND",
4041 "USE_OPENAI",
4042 "USE_OLLAMA",
4043 "CLAUDE_CODE_USE_BEDROCK",
4044 "CLAUDE_MODEL",
4045 "CLAUDE_CODE_MODEL",
4046 "ANTHROPIC_MODEL",
4047 ]);
4048 guard.set("OMNI_DEV_AI_BACKEND", "claude-cli");
4049 guard.set("CLAUDE_CODE_MODEL", "opus");
4050 guard.set("CLAUDE_MODEL", "haiku");
4052
4053 let client = create_default_claude_client(None, None)
4054 .await
4055 .expect("factory should succeed");
4056 let metadata = client.get_ai_client_metadata();
4057 assert_eq!(metadata.provider, "Claude CLI");
4058 assert_eq!(metadata.model, "haiku");
4059 }
4060
4061 #[tokio::test]
4062 async fn factory_claude_cli_backend_explicit_model_wins_over_env() {
4063 let guard = FactoryEnvGuard::new(&[
4064 "OMNI_DEV_AI_BACKEND",
4065 "USE_OPENAI",
4066 "USE_OLLAMA",
4067 "CLAUDE_CODE_USE_BEDROCK",
4068 "CLAUDE_MODEL",
4069 "CLAUDE_CODE_MODEL",
4070 "ANTHROPIC_MODEL",
4071 ]);
4072 guard.set("OMNI_DEV_AI_BACKEND", "claude-cli");
4073 guard.set("CLAUDE_MODEL", "haiku");
4074
4075 let client = create_default_claude_client(Some("opus".to_string()), None)
4076 .await
4077 .expect("factory should succeed");
4078 let metadata = client.get_ai_client_metadata();
4079 assert_eq!(metadata.model, "opus");
4080 }
4081
4082 #[tokio::test]
4083 async fn factory_claude_cli_backend_accepts_underscore_alias() {
4084 let guard = FactoryEnvGuard::new(&[
4085 "OMNI_DEV_AI_BACKEND",
4086 "USE_OPENAI",
4087 "USE_OLLAMA",
4088 "CLAUDE_CODE_USE_BEDROCK",
4089 "CLAUDE_MODEL",
4090 "CLAUDE_CODE_MODEL",
4091 "ANTHROPIC_MODEL",
4092 ]);
4093 guard.set("OMNI_DEV_AI_BACKEND", "claude_cli");
4094
4095 let client = create_default_claude_client(None, None)
4096 .await
4097 .expect("factory should succeed");
4098 let metadata = client.get_ai_client_metadata();
4099 assert_eq!(metadata.provider, "Claude CLI");
4100 }
4101
4102 #[tokio::test]
4103 async fn factory_ollama_branch_probes_loaded_context_length() {
4104 use wiremock::matchers::{method, path};
4105 use wiremock::{Mock, MockServer, ResponseTemplate};
4106
4107 let server = MockServer::start().await;
4108 Mock::given(method("GET"))
4109 .and(path("/api/v0/models"))
4110 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
4111 "data": [
4112 { "id": "lm-loaded", "state": "loaded", "loaded_context_length": 6144_u64 }
4113 ]
4114 })))
4115 .mount(&server)
4116 .await;
4117
4118 let guard = FactoryEnvGuard::new(&[
4119 "OMNI_DEV_AI_BACKEND",
4120 "USE_OPENAI",
4121 "USE_OLLAMA",
4122 "CLAUDE_CODE_USE_BEDROCK",
4123 "OLLAMA_BASE_URL",
4124 "OLLAMA_MODEL",
4125 ]);
4126 guard.set("USE_OLLAMA", "true");
4127 guard.set("OLLAMA_BASE_URL", &server.uri());
4128 guard.set("OLLAMA_MODEL", "lm-loaded");
4129
4130 let client = create_default_claude_client(None, None)
4131 .await
4132 .expect("factory should succeed");
4133 let metadata = client.get_ai_client_metadata();
4134 assert_eq!(metadata.provider, "Ollama");
4135 assert_eq!(metadata.model, "lm-loaded");
4136 assert_eq!(metadata.max_context_length, 6144);
4138 }
4139
4140 #[tokio::test]
4141 async fn factory_ollama_branch_falls_back_when_probe_fails() {
4142 use wiremock::matchers::{method, path};
4143 use wiremock::{Mock, MockServer, ResponseTemplate};
4144
4145 let server = MockServer::start().await;
4146 Mock::given(method("GET"))
4147 .and(path("/api/v0/models"))
4148 .respond_with(ResponseTemplate::new(500))
4149 .mount(&server)
4150 .await;
4151 Mock::given(method("POST"))
4152 .and(path("/api/show"))
4153 .respond_with(ResponseTemplate::new(500))
4154 .mount(&server)
4155 .await;
4156
4157 let guard = FactoryEnvGuard::new(&[
4158 "OMNI_DEV_AI_BACKEND",
4159 "USE_OPENAI",
4160 "USE_OLLAMA",
4161 "CLAUDE_CODE_USE_BEDROCK",
4162 "OLLAMA_BASE_URL",
4163 "OLLAMA_MODEL",
4164 ]);
4165 guard.set("USE_OLLAMA", "true");
4166 guard.set("OLLAMA_BASE_URL", &server.uri());
4167 guard.set("OLLAMA_MODEL", "no-such-model");
4168
4169 let client = create_default_claude_client(None, None)
4170 .await
4171 .expect("factory should succeed");
4172 let metadata = client.get_ai_client_metadata();
4173 let registry_value =
4176 crate::claude::model_config::get_model_registry().get_input_context("no-such-model");
4177 assert_eq!(metadata.max_context_length, registry_value);
4178 }
4179
4180 #[tokio::test]
4184 async fn factory_ollama_branch_probes_via_ollama_native() {
4185 use wiremock::matchers::{method, path};
4186 use wiremock::{Mock, MockServer, ResponseTemplate};
4187
4188 let server = MockServer::start().await;
4189 Mock::given(method("GET"))
4190 .and(path("/api/v0/models"))
4191 .respond_with(ResponseTemplate::new(404))
4192 .mount(&server)
4193 .await;
4194 Mock::given(method("POST"))
4195 .and(path("/api/show"))
4196 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
4197 "model_info": { "llama.context_length": 12288_u64 }
4198 })))
4199 .mount(&server)
4200 .await;
4201
4202 let guard = FactoryEnvGuard::new(&[
4203 "OMNI_DEV_AI_BACKEND",
4204 "USE_OPENAI",
4205 "USE_OLLAMA",
4206 "CLAUDE_CODE_USE_BEDROCK",
4207 "OLLAMA_BASE_URL",
4208 "OLLAMA_MODEL",
4209 ]);
4210 guard.set("USE_OLLAMA", "true");
4211 guard.set("OLLAMA_BASE_URL", &server.uri());
4212 guard.set("OLLAMA_MODEL", "ollama-native-model");
4213
4214 let client = create_default_claude_client(None, None)
4215 .await
4216 .expect("factory should succeed");
4217 let metadata = client.get_ai_client_metadata();
4218 assert_eq!(metadata.max_context_length, 12288);
4219 }
4220}