1use anyhow::{Context, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use secrecy::SecretString;
13use tracing::{debug, instrument};
14
15use super::AiResponse;
16use super::types::{
17 ChatCompletionRequest, ChatCompletionResponse, ChatMessage, IssueDetails, ResponseFormat,
18 TriageResponse,
19};
20use crate::history::AiStats;
21
22use super::prompts::{
23 build_create_system_prompt, build_pr_label_system_prompt, build_pr_review_system_prompt,
24 build_release_notes_system_prompt, build_triage_system_prompt,
25};
26
27fn parse_ai_json<T: serde::de::DeserializeOwned>(text: &str, provider: &str) -> Result<T> {
42 match serde_json::from_str::<T>(text) {
43 Ok(value) => Ok(value),
44 Err(e) => {
45 if e.is_eof() {
47 Err(anyhow::anyhow!(
48 crate::error::AptuError::TruncatedResponse {
49 provider: provider.to_string(),
50 }
51 ))
52 } else {
53 Err(anyhow::anyhow!(crate::error::AptuError::InvalidAIResponse(
54 e
55 )))
56 }
57 }
58 }
59}
60
/// Maximum length, in bytes (as measured by `str::len`), of an issue or PR
/// body before the prompt builders in this module truncate it.
pub const MAX_BODY_LENGTH: usize = 4000;

/// Maximum number of issue comments included in the triage user prompt.
pub const MAX_COMMENTS: usize = 5;

/// Maximum number of changed files listed in the PR review user prompt;
/// further files are counted and reported as omitted.
pub const MAX_FILES: usize = 20;

/// Budget, in bytes, for the combined size of patches embedded in the PR
/// review user prompt; patches that would exceed it are omitted.
pub const MAX_TOTAL_DIFF_SIZE: usize = 50_000;

/// Maximum number of available repository labels listed in the triage prompt.
pub const MAX_LABELS: usize = 30;

/// Maximum number of available milestones listed in the triage prompt.
pub const MAX_MILESTONES: usize = 10;
78
79#[async_trait]
84pub trait AiProvider: Send + Sync {
85 fn name(&self) -> &str;
87
88 fn api_url(&self) -> &str;
90
91 fn api_key_env(&self) -> &str;
93
94 fn http_client(&self) -> &Client;
96
97 fn api_key(&self) -> &SecretString;
99
100 fn model(&self) -> &str;
102
103 fn max_tokens(&self) -> u32;
105
106 fn temperature(&self) -> f32;
108
109 fn max_attempts(&self) -> u32 {
114 3
115 }
116
117 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
122 None
123 }
124
125 fn build_headers(&self) -> reqwest::header::HeaderMap {
130 let mut headers = reqwest::header::HeaderMap::new();
131 if let Ok(val) = "application/json".parse() {
132 headers.insert("Content-Type", val);
133 }
134 headers
135 }
136
137 fn validate_model(&self) -> Result<()> {
142 Ok(())
143 }
144
145 fn custom_guidance(&self) -> Option<&str> {
150 None
151 }
152
153 #[instrument(skip(self, request), fields(provider = self.name(), model = self.model()))]
158 async fn send_request_inner(
159 &self,
160 request: &ChatCompletionRequest,
161 ) -> Result<ChatCompletionResponse> {
162 use secrecy::ExposeSecret;
163 use tracing::warn;
164
165 use crate::error::AptuError;
166
167 let mut req = self.http_client().post(self.api_url());
168
169 req = req.header(
171 "Authorization",
172 format!("Bearer {}", self.api_key().expose_secret()),
173 );
174
175 for (key, value) in &self.build_headers() {
177 req = req.header(key.clone(), value.clone());
178 }
179
180 let response = req
181 .json(request)
182 .send()
183 .await
184 .context(format!("Failed to send request to {} API", self.name()))?;
185
186 let status = response.status();
188 if !status.is_success() {
189 if status.as_u16() == 401 {
190 anyhow::bail!(
191 "Invalid {} API key. Check your {} environment variable.",
192 self.name(),
193 self.api_key_env()
194 );
195 } else if status.as_u16() == 429 {
196 warn!("Rate limited by {} API", self.name());
197 let retry_after = response
199 .headers()
200 .get("Retry-After")
201 .and_then(|h| h.to_str().ok())
202 .and_then(|s| s.parse::<u64>().ok())
203 .unwrap_or(0);
204 debug!(retry_after, "Parsed Retry-After header");
205 return Err(AptuError::RateLimited {
206 provider: self.name().to_string(),
207 retry_after,
208 }
209 .into());
210 }
211 let error_body = response.text().await.unwrap_or_default();
212 anyhow::bail!(
213 "{} API error (HTTP {}): {}",
214 self.name(),
215 status.as_u16(),
216 error_body
217 );
218 }
219
220 let completion: ChatCompletionResponse = response
222 .json()
223 .await
224 .context(format!("Failed to parse {} API response", self.name()))?;
225
226 Ok(completion)
227 }
228
229 #[instrument(skip(self, request), fields(provider = self.name(), model = self.model()))]
248 async fn send_and_parse<T: serde::de::DeserializeOwned + Send>(
249 &self,
250 request: &ChatCompletionRequest,
251 ) -> Result<(T, AiStats)> {
252 use tracing::{info, warn};
253
254 use crate::error::AptuError;
255 use crate::retry::{extract_retry_after, is_retryable_anyhow};
256
257 if let Some(cb) = self.circuit_breaker()
259 && cb.is_open()
260 {
261 return Err(AptuError::CircuitOpen.into());
262 }
263
264 let start = std::time::Instant::now();
266
267 let mut attempt: u32 = 0;
269 let max_attempts: u32 = self.max_attempts();
270
271 #[allow(clippy::items_after_statements)]
273 async fn try_request<T: serde::de::DeserializeOwned>(
274 provider: &(impl AiProvider + ?Sized),
275 request: &ChatCompletionRequest,
276 ) -> Result<(T, ChatCompletionResponse)> {
277 let completion = provider.send_request_inner(request).await?;
279
280 let content = completion
282 .choices
283 .first()
284 .map(|c| c.message.content.clone())
285 .context("No response from AI model")?;
286
287 debug!(response_length = content.len(), "Received AI response");
288
289 let parsed: T = parse_ai_json(&content, provider.name())?;
291
292 Ok((parsed, completion))
293 }
294
295 let (parsed, completion): (T, ChatCompletionResponse) = loop {
296 attempt += 1;
297
298 let result = try_request(self, request).await;
299
300 match result {
301 Ok(success) => break success,
302 Err(err) => {
303 if !is_retryable_anyhow(&err) || attempt >= max_attempts {
305 return Err(err);
306 }
307
308 let delay = if let Some(retry_after_duration) = extract_retry_after(&err) {
310 debug!(
311 retry_after_secs = retry_after_duration.as_secs(),
312 "Using Retry-After value from rate limit error"
313 );
314 retry_after_duration
315 } else {
316 let backoff_secs = 2_u64.pow(attempt.saturating_sub(1));
318 let jitter_ms = fastrand::u64(0..500);
319 std::time::Duration::from_millis(backoff_secs * 1000 + jitter_ms)
320 };
321
322 let error_msg = err.to_string();
323 warn!(
324 error = %error_msg,
325 delay_secs = delay.as_secs(),
326 attempt,
327 max_attempts,
328 "Retrying after error"
329 );
330
331 drop(err);
333 tokio::time::sleep(delay).await;
334 }
335 }
336 };
337
338 if let Some(cb) = self.circuit_breaker() {
340 cb.record_success();
341 }
342
343 #[allow(clippy::cast_possible_truncation)]
345 let duration_ms = start.elapsed().as_millis() as u64;
346
347 let (input_tokens, output_tokens, cost_usd) = if let Some(usage) = completion.usage {
349 (usage.prompt_tokens, usage.completion_tokens, usage.cost)
350 } else {
351 debug!("No usage information in API response");
353 (0, 0, None)
354 };
355
356 let ai_stats = AiStats {
357 provider: self.name().to_string(),
358 model: self.model().to_string(),
359 input_tokens,
360 output_tokens,
361 duration_ms,
362 cost_usd,
363 fallback_provider: None,
364 };
365
366 info!(
368 duration_ms,
369 input_tokens,
370 output_tokens,
371 cost_usd = ?cost_usd,
372 model = %self.model(),
373 "AI request completed"
374 );
375
376 Ok((parsed, ai_stats))
377 }
378
379 #[instrument(skip(self, issue), fields(issue_number = issue.number, repo = %format!("{}/{}", issue.owner, issue.repo)))]
393 async fn analyze_issue(&self, issue: &IssueDetails) -> Result<AiResponse> {
394 debug!(model = %self.model(), "Calling {} API", self.name());
395
396 let system_content = if let Some(override_prompt) =
398 super::context::load_system_prompt_override("triage_system").await
399 {
400 override_prompt
401 } else {
402 Self::build_system_prompt(self.custom_guidance())
403 };
404
405 let request = ChatCompletionRequest {
406 model: self.model().to_string(),
407 messages: vec![
408 ChatMessage {
409 role: "system".to_string(),
410 content: system_content,
411 },
412 ChatMessage {
413 role: "user".to_string(),
414 content: Self::build_user_prompt(issue),
415 },
416 ],
417 response_format: Some(ResponseFormat {
418 format_type: "json_object".to_string(),
419 json_schema: None,
420 }),
421 max_tokens: Some(self.max_tokens()),
422 temperature: Some(self.temperature()),
423 };
424
425 let (triage, ai_stats) = self.send_and_parse::<TriageResponse>(&request).await?;
427
428 debug!(
429 input_tokens = ai_stats.input_tokens,
430 output_tokens = ai_stats.output_tokens,
431 duration_ms = ai_stats.duration_ms,
432 cost_usd = ?ai_stats.cost_usd,
433 "AI analysis complete"
434 );
435
436 Ok(AiResponse {
437 triage,
438 stats: ai_stats,
439 })
440 }
441
442 #[instrument(skip(self), fields(repo = %repo))]
459 async fn create_issue(
460 &self,
461 title: &str,
462 body: &str,
463 repo: &str,
464 ) -> Result<(super::types::CreateIssueResponse, AiStats)> {
465 debug!(model = %self.model(), "Calling {} API for issue creation", self.name());
466
467 let system_content = if let Some(override_prompt) =
469 super::context::load_system_prompt_override("create_system").await
470 {
471 override_prompt
472 } else {
473 Self::build_create_system_prompt(self.custom_guidance())
474 };
475
476 let request = ChatCompletionRequest {
477 model: self.model().to_string(),
478 messages: vec![
479 ChatMessage {
480 role: "system".to_string(),
481 content: system_content,
482 },
483 ChatMessage {
484 role: "user".to_string(),
485 content: Self::build_create_user_prompt(title, body, repo),
486 },
487 ],
488 response_format: Some(ResponseFormat {
489 format_type: "json_object".to_string(),
490 json_schema: None,
491 }),
492 max_tokens: Some(self.max_tokens()),
493 temperature: Some(self.temperature()),
494 };
495
496 let (create_response, ai_stats) = self
498 .send_and_parse::<super::types::CreateIssueResponse>(&request)
499 .await?;
500
501 debug!(
502 title_len = create_response.formatted_title.len(),
503 body_len = create_response.formatted_body.len(),
504 labels = create_response.suggested_labels.len(),
505 input_tokens = ai_stats.input_tokens,
506 output_tokens = ai_stats.output_tokens,
507 duration_ms = ai_stats.duration_ms,
508 "Issue formatting complete with stats"
509 );
510
511 Ok((create_response, ai_stats))
512 }
513
514 #[must_use]
516 fn build_system_prompt(custom_guidance: Option<&str>) -> String {
517 let context = super::context::load_custom_guidance(custom_guidance);
518 build_triage_system_prompt(&context)
519 }
520
521 #[must_use]
523 fn build_user_prompt(issue: &IssueDetails) -> String {
524 use std::fmt::Write;
525
526 let mut prompt = String::new();
527
528 prompt.push_str("<issue_content>\n");
529 let _ = writeln!(prompt, "Title: {}\n", issue.title);
530
531 let body = if issue.body.len() > MAX_BODY_LENGTH {
533 format!(
534 "{}...\n[Body truncated - original length: {} chars]",
535 &issue.body[..MAX_BODY_LENGTH],
536 issue.body.len()
537 )
538 } else if issue.body.is_empty() {
539 "[No description provided]".to_string()
540 } else {
541 issue.body.clone()
542 };
543 let _ = writeln!(prompt, "Body:\n{body}\n");
544
545 if !issue.labels.is_empty() {
547 let _ = writeln!(prompt, "Existing Labels: {}\n", issue.labels.join(", "));
548 }
549
550 if !issue.comments.is_empty() {
552 prompt.push_str("Recent Comments:\n");
553 for comment in issue.comments.iter().take(MAX_COMMENTS) {
554 let comment_body = if comment.body.len() > 500 {
555 format!("{}...", &comment.body[..500])
556 } else {
557 comment.body.clone()
558 };
559 let _ = writeln!(prompt, "- @{}: {}", comment.author, comment_body);
560 }
561 prompt.push('\n');
562 }
563
564 if !issue.repo_context.is_empty() {
566 prompt.push_str("Related Issues in Repository (for context):\n");
567 for related in issue.repo_context.iter().take(10) {
568 let _ = writeln!(
569 prompt,
570 "- #{} [{}] {}",
571 related.number, related.state, related.title
572 );
573 }
574 prompt.push('\n');
575 }
576
577 if !issue.repo_tree.is_empty() {
579 prompt.push_str("Repository Structure (source files):\n");
580 for path in issue.repo_tree.iter().take(20) {
581 let _ = writeln!(prompt, "- {path}");
582 }
583 prompt.push('\n');
584 }
585
586 if !issue.available_labels.is_empty() {
588 prompt.push_str("Available Labels:\n");
589 for label in issue.available_labels.iter().take(MAX_LABELS) {
590 let description = if label.description.is_empty() {
591 String::new()
592 } else {
593 format!(" - {}", label.description)
594 };
595 let _ = writeln!(
596 prompt,
597 "- {} (color: #{}){}",
598 label.name, label.color, description
599 );
600 }
601 prompt.push('\n');
602 }
603
604 if !issue.available_milestones.is_empty() {
606 prompt.push_str("Available Milestones:\n");
607 for milestone in issue.available_milestones.iter().take(MAX_MILESTONES) {
608 let description = if milestone.description.is_empty() {
609 String::new()
610 } else {
611 format!(" - {}", milestone.description)
612 };
613 let _ = writeln!(prompt, "- {}{}", milestone.title, description);
614 }
615 prompt.push('\n');
616 }
617
618 prompt.push_str("</issue_content>");
619
620 prompt
621 }
622
623 #[must_use]
625 fn build_create_system_prompt(custom_guidance: Option<&str>) -> String {
626 let context = super::context::load_custom_guidance(custom_guidance);
627 build_create_system_prompt(&context)
628 }
629
630 #[must_use]
632 fn build_create_user_prompt(title: &str, body: &str, _repo: &str) -> String {
633 format!("Please format this GitHub issue:\n\nTitle: {title}\n\nBody:\n{body}")
634 }
635
636 #[instrument(skip(self, pr), fields(pr_number = pr.number, repo = %format!("{}/{}", pr.owner, pr.repo)))]
650 async fn review_pr(
651 &self,
652 pr: &super::types::PrDetails,
653 ) -> Result<(super::types::PrReviewResponse, AiStats)> {
654 debug!(model = %self.model(), "Calling {} API for PR review", self.name());
655
656 let system_content = if let Some(override_prompt) =
658 super::context::load_system_prompt_override("pr_review_system").await
659 {
660 override_prompt
661 } else {
662 Self::build_pr_review_system_prompt(self.custom_guidance())
663 };
664
665 let request = ChatCompletionRequest {
666 model: self.model().to_string(),
667 messages: vec![
668 ChatMessage {
669 role: "system".to_string(),
670 content: system_content,
671 },
672 ChatMessage {
673 role: "user".to_string(),
674 content: Self::build_pr_review_user_prompt(pr),
675 },
676 ],
677 response_format: Some(ResponseFormat {
678 format_type: "json_object".to_string(),
679 json_schema: None,
680 }),
681 max_tokens: Some(self.max_tokens()),
682 temperature: Some(self.temperature()),
683 };
684
685 let (review, ai_stats) = self
687 .send_and_parse::<super::types::PrReviewResponse>(&request)
688 .await?;
689
690 debug!(
691 verdict = %review.verdict,
692 input_tokens = ai_stats.input_tokens,
693 output_tokens = ai_stats.output_tokens,
694 duration_ms = ai_stats.duration_ms,
695 "PR review complete with stats"
696 );
697
698 Ok((review, ai_stats))
699 }
700
701 #[instrument(skip(self), fields(title = %title))]
717 async fn suggest_pr_labels(
718 &self,
719 title: &str,
720 body: &str,
721 file_paths: &[String],
722 ) -> Result<(Vec<String>, AiStats)> {
723 debug!(model = %self.model(), "Calling {} API for PR label suggestion", self.name());
724
725 let system_content = if let Some(override_prompt) =
727 super::context::load_system_prompt_override("pr_label_system").await
728 {
729 override_prompt
730 } else {
731 Self::build_pr_label_system_prompt(self.custom_guidance())
732 };
733
734 let request = ChatCompletionRequest {
735 model: self.model().to_string(),
736 messages: vec![
737 ChatMessage {
738 role: "system".to_string(),
739 content: system_content,
740 },
741 ChatMessage {
742 role: "user".to_string(),
743 content: Self::build_pr_label_user_prompt(title, body, file_paths),
744 },
745 ],
746 response_format: Some(ResponseFormat {
747 format_type: "json_object".to_string(),
748 json_schema: None,
749 }),
750 max_tokens: Some(self.max_tokens()),
751 temperature: Some(self.temperature()),
752 };
753
754 let (response, ai_stats) = self
756 .send_and_parse::<super::types::PrLabelResponse>(&request)
757 .await?;
758
759 debug!(
760 label_count = response.suggested_labels.len(),
761 input_tokens = ai_stats.input_tokens,
762 output_tokens = ai_stats.output_tokens,
763 duration_ms = ai_stats.duration_ms,
764 "PR label suggestion complete with stats"
765 );
766
767 Ok((response.suggested_labels, ai_stats))
768 }
769
770 #[must_use]
772 fn build_pr_review_system_prompt(custom_guidance: Option<&str>) -> String {
773 let context = super::context::load_custom_guidance(custom_guidance);
774 build_pr_review_system_prompt(&context)
775 }
776
777 #[must_use]
779 fn build_pr_review_user_prompt(pr: &super::types::PrDetails) -> String {
780 use std::fmt::Write;
781
782 let mut prompt = String::new();
783
784 prompt.push_str("<pull_request>\n");
785 let _ = writeln!(prompt, "Title: {}\n", pr.title);
786 let _ = writeln!(prompt, "Branch: {} -> {}\n", pr.head_branch, pr.base_branch);
787
788 let body = if pr.body.is_empty() {
790 "[No description provided]".to_string()
791 } else if pr.body.len() > MAX_BODY_LENGTH {
792 format!(
793 "{}...\n[Description truncated - original length: {} chars]",
794 &pr.body[..MAX_BODY_LENGTH],
795 pr.body.len()
796 )
797 } else {
798 pr.body.clone()
799 };
800 let _ = writeln!(prompt, "Description:\n{body}\n");
801
802 prompt.push_str("Files Changed:\n");
804 let mut total_diff_size = 0;
805 let mut files_included = 0;
806 let mut files_skipped = 0;
807
808 for file in &pr.files {
809 if files_included >= MAX_FILES {
811 files_skipped += 1;
812 continue;
813 }
814
815 let _ = writeln!(
816 prompt,
817 "- {} ({}) +{} -{}\n",
818 file.filename, file.status, file.additions, file.deletions
819 );
820
821 if let Some(patch) = &file.patch {
823 const MAX_PATCH_LENGTH: usize = 2000;
824 let patch_content = if patch.len() > MAX_PATCH_LENGTH {
825 format!(
826 "{}...\n[Patch truncated - original length: {} chars]",
827 &patch[..MAX_PATCH_LENGTH],
828 patch.len()
829 )
830 } else {
831 patch.clone()
832 };
833
834 let patch_size = patch_content.len();
836 if total_diff_size + patch_size > MAX_TOTAL_DIFF_SIZE {
837 let _ = writeln!(
838 prompt,
839 "```diff\n[Patch omitted - total diff size limit reached]\n```\n"
840 );
841 files_skipped += 1;
842 continue;
843 }
844
845 let _ = writeln!(prompt, "```diff\n{patch_content}\n```\n");
846 total_diff_size += patch_size;
847 }
848
849 files_included += 1;
850 }
851
852 if files_skipped > 0 {
854 let _ = writeln!(
855 prompt,
856 "\n[{files_skipped} files omitted due to size limits (MAX_FILES={MAX_FILES}, MAX_TOTAL_DIFF_SIZE={MAX_TOTAL_DIFF_SIZE})]"
857 );
858 }
859
860 prompt.push_str("</pull_request>");
861
862 prompt
863 }
864
865 #[must_use]
867 fn build_pr_label_system_prompt(custom_guidance: Option<&str>) -> String {
868 let context = super::context::load_custom_guidance(custom_guidance);
869 build_pr_label_system_prompt(&context)
870 }
871
872 #[must_use]
874 fn build_pr_label_user_prompt(title: &str, body: &str, file_paths: &[String]) -> String {
875 use std::fmt::Write;
876
877 let mut prompt = String::new();
878
879 prompt.push_str("<pull_request>\n");
880 let _ = writeln!(prompt, "Title: {title}\n");
881
882 let body_content = if body.is_empty() {
884 "[No description provided]".to_string()
885 } else if body.len() > MAX_BODY_LENGTH {
886 format!(
887 "{}...\n[Description truncated - original length: {} chars]",
888 &body[..MAX_BODY_LENGTH],
889 body.len()
890 )
891 } else {
892 body.to_string()
893 };
894 let _ = writeln!(prompt, "Description:\n{body_content}\n");
895
896 if !file_paths.is_empty() {
898 prompt.push_str("Files Changed:\n");
899 for path in file_paths.iter().take(20) {
900 let _ = writeln!(prompt, "- {path}");
901 }
902 if file_paths.len() > 20 {
903 let _ = writeln!(prompt, "- ... and {} more files", file_paths.len() - 20);
904 }
905 prompt.push('\n');
906 }
907
908 prompt.push_str("</pull_request>");
909
910 prompt
911 }
912
913 #[instrument(skip(self, prs))]
924 async fn generate_release_notes(
925 &self,
926 prs: Vec<super::types::PrSummary>,
927 version: &str,
928 ) -> Result<(super::types::ReleaseNotesResponse, AiStats)> {
929 let system_content = if let Some(override_prompt) =
930 super::context::load_system_prompt_override("release_notes_system").await
931 {
932 override_prompt
933 } else {
934 let context = super::context::load_custom_guidance(self.custom_guidance());
935 build_release_notes_system_prompt(&context)
936 };
937 let prompt = Self::build_release_notes_prompt(&prs, version);
938 let request = ChatCompletionRequest {
939 model: self.model().to_string(),
940 messages: vec![
941 ChatMessage {
942 role: "system".to_string(),
943 content: system_content,
944 },
945 ChatMessage {
946 role: "user".to_string(),
947 content: prompt,
948 },
949 ],
950 response_format: Some(ResponseFormat {
951 format_type: "json_object".to_string(),
952 json_schema: None,
953 }),
954 temperature: Some(0.7),
955 max_tokens: Some(self.max_tokens()),
956 };
957
958 let (parsed, ai_stats) = self
959 .send_and_parse::<super::types::ReleaseNotesResponse>(&request)
960 .await?;
961
962 debug!(
963 input_tokens = ai_stats.input_tokens,
964 output_tokens = ai_stats.output_tokens,
965 duration_ms = ai_stats.duration_ms,
966 "Release notes generation complete with stats"
967 );
968
969 Ok((parsed, ai_stats))
970 }
971
972 #[must_use]
974 fn build_release_notes_prompt(prs: &[super::types::PrSummary], version: &str) -> String {
975 let pr_list = prs
976 .iter()
977 .map(|pr| {
978 format!(
979 "- #{}: {} (by @{})\n {}",
980 pr.number,
981 pr.title,
982 pr.author,
983 pr.body.lines().next().unwrap_or("")
984 )
985 })
986 .collect::<Vec<_>>()
987 .join("\n");
988
989 format!(
990 "Generate release notes for version {version} based on these merged PRs:\n\n{pr_list}"
991 )
992 }
993}
994
// Unit tests for the prompt builders and `parse_ai_json`. None of them send HTTP
// requests: `TestProvider::http_client` and `TestProvider::api_key` are
// `unimplemented!()`.
#[cfg(test)]
mod tests {
    use super::*;

    // Target type for the malformed/truncated `parse_ai_json` tests. Only the
    // error path is exercised, so its field is never read.
    #[derive(Debug, serde::Deserialize)]
    struct ErrorTestResponse {
        _message: String,
    }

    // Minimal provider used to call the trait's prompt-building associated
    // functions. The network-facing accessors panic if called.
    struct TestProvider;

    impl AiProvider for TestProvider {
        fn name(&self) -> &'static str {
            "test"
        }

        fn api_url(&self) -> &'static str {
            "https://test.example.com"
        }

        fn api_key_env(&self) -> &'static str {
            "TEST_API_KEY"
        }

        fn http_client(&self) -> &Client {
            unimplemented!()
        }

        fn api_key(&self) -> &SecretString {
            unimplemented!()
        }

        fn model(&self) -> &'static str {
            "test-model"
        }

        fn max_tokens(&self) -> u32 {
            2048
        }

        fn temperature(&self) -> f32 {
            0.3
        }
    }

    // The triage system prompt must mention every field of the expected JSON reply.
    #[test]
    fn test_build_system_prompt_contains_json_schema() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(prompt.contains("summary"));
        assert!(prompt.contains("suggested_labels"));
        assert!(prompt.contains("clarifying_questions"));
        assert!(prompt.contains("potential_duplicates"));
        assert!(prompt.contains("status_note"));
    }

    // The user prompt is wrapped in <issue_content> tags and includes title,
    // body and existing labels.
    #[test]
    fn test_build_user_prompt_with_delimiters() {
        let issue = IssueDetails::builder()
            .owner("test".to_string())
            .repo("repo".to_string())
            .number(1)
            .title("Test issue".to_string())
            .body("This is the body".to_string())
            .labels(vec!["bug".to_string()])
            .comments(vec![])
            .url("https://github.com/test/repo/issues/1".to_string())
            .build();

        let prompt = TestProvider::build_user_prompt(&issue);
        assert!(prompt.starts_with("<issue_content>"));
        assert!(prompt.ends_with("</issue_content>"));
        assert!(prompt.contains("Title: Test issue"));
        assert!(prompt.contains("This is the body"));
        assert!(prompt.contains("Existing Labels: bug"));
    }

    // A body longer than MAX_BODY_LENGTH (4000) is truncated, and the original
    // length is reported.
    #[test]
    fn test_build_user_prompt_truncates_long_body() {
        let long_body = "x".repeat(5000);
        let issue = IssueDetails::builder()
            .owner("test".to_string())
            .repo("repo".to_string())
            .number(1)
            .title("Test".to_string())
            .body(long_body)
            .labels(vec![])
            .comments(vec![])
            .url("https://github.com/test/repo/issues/1".to_string())
            .build();

        let prompt = TestProvider::build_user_prompt(&issue);
        assert!(prompt.contains("[Body truncated"));
        assert!(prompt.contains("5000 chars"));
    }

    // An empty body is replaced by a placeholder.
    #[test]
    fn test_build_user_prompt_empty_body() {
        let issue = IssueDetails::builder()
            .owner("test".to_string())
            .repo("repo".to_string())
            .number(1)
            .title("Test".to_string())
            .body(String::new())
            .labels(vec![])
            .comments(vec![])
            .url("https://github.com/test/repo/issues/1".to_string())
            .build();

        let prompt = TestProvider::build_user_prompt(&issue);
        assert!(prompt.contains("[No description provided]"));
    }

    // The issue-creation system prompt must mention the fields of its JSON reply.
    #[test]
    fn test_build_create_system_prompt_contains_json_schema() {
        let prompt = TestProvider::build_create_system_prompt(None);
        assert!(prompt.contains("formatted_title"));
        assert!(prompt.contains("formatted_body"));
        assert!(prompt.contains("suggested_labels"));
    }

    // 25 files exceed MAX_FILES (20), so the omission notice must appear.
    #[test]
    fn test_build_pr_review_user_prompt_respects_file_limit() {
        use super::super::types::{PrDetails, PrFile};

        let mut files = Vec::new();
        for i in 0..25 {
            files.push(PrFile {
                filename: format!("file{i}.rs"),
                status: "modified".to_string(),
                additions: 10,
                deletions: 5,
                patch: Some(format!("patch content {i}")),
            });
        }

        let pr = PrDetails {
            owner: "test".to_string(),
            repo: "repo".to_string(),
            number: 1,
            title: "Test PR".to_string(),
            body: "Description".to_string(),
            head_branch: "feature".to_string(),
            base_branch: "main".to_string(),
            url: "https://github.com/test/repo/pull/1".to_string(),
            files,
            labels: vec![],
            head_sha: String::new(),
        };

        let prompt = TestProvider::build_pr_review_user_prompt(&pr);
        assert!(prompt.contains("files omitted due to size limits"));
        assert!(prompt.contains("MAX_FILES=20"));
    }

    // Two large patches: both files stay listed and the prompt stays bounded.
    // NOTE(review): each 30_000-char patch is first cut to the builder's
    // 2000-byte per-patch limit, so the MAX_TOTAL_DIFF_SIZE (50_000) branch is
    // never reached here. This test covers per-patch truncation, not the total
    // diff budget.
    #[test]
    fn test_build_pr_review_user_prompt_respects_diff_size_limit() {
        use super::super::types::{PrDetails, PrFile};

        let patch1 = "x".repeat(30_000);
        let patch2 = "y".repeat(30_000);

        let files = vec![
            PrFile {
                filename: "file1.rs".to_string(),
                status: "modified".to_string(),
                additions: 100,
                deletions: 50,
                patch: Some(patch1),
            },
            PrFile {
                filename: "file2.rs".to_string(),
                status: "modified".to_string(),
                additions: 100,
                deletions: 50,
                patch: Some(patch2),
            },
        ];

        let pr = PrDetails {
            owner: "test".to_string(),
            repo: "repo".to_string(),
            number: 1,
            title: "Test PR".to_string(),
            body: "Description".to_string(),
            head_branch: "feature".to_string(),
            base_branch: "main".to_string(),
            url: "https://github.com/test/repo/pull/1".to_string(),
            files,
            labels: vec![],
            head_sha: String::new(),
        };

        let prompt = TestProvider::build_pr_review_user_prompt(&pr);
        assert!(prompt.contains("file1.rs"));
        assert!(prompt.contains("file2.rs"));
        assert!(prompt.len() < 65_000);
    }

    // A file without a patch is still listed, and no omission notice is emitted.
    #[test]
    fn test_build_pr_review_user_prompt_with_no_patches() {
        use super::super::types::{PrDetails, PrFile};

        let files = vec![PrFile {
            filename: "file1.rs".to_string(),
            status: "added".to_string(),
            additions: 10,
            deletions: 0,
            patch: None,
        }];

        let pr = PrDetails {
            owner: "test".to_string(),
            repo: "repo".to_string(),
            number: 1,
            title: "Test PR".to_string(),
            body: "Description".to_string(),
            head_branch: "feature".to_string(),
            base_branch: "main".to_string(),
            url: "https://github.com/test/repo/pull/1".to_string(),
            files,
            labels: vec![],
            head_sha: String::new(),
        };

        let prompt = TestProvider::build_pr_review_user_prompt(&pr);
        assert!(prompt.contains("file1.rs"));
        assert!(prompt.contains("added"));
        assert!(!prompt.contains("files omitted"));
    }

    // The PR-label system prompt must describe its JSON reply and example labels.
    #[test]
    fn test_build_pr_label_system_prompt_contains_json_schema() {
        let prompt = TestProvider::build_pr_label_system_prompt(None);
        assert!(prompt.contains("suggested_labels"));
        assert!(prompt.contains("json_object"));
        assert!(prompt.contains("bug"));
        assert!(prompt.contains("enhancement"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_with_title_and_body() {
        let title = "feat: add new feature";
        let body = "This PR adds a new feature";
        let files = vec!["src/main.rs".to_string(), "tests/test.rs".to_string()];

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.starts_with("<pull_request>"));
        assert!(prompt.ends_with("</pull_request>"));
        assert!(prompt.contains("feat: add new feature"));
        assert!(prompt.contains("This PR adds a new feature"));
        assert!(prompt.contains("src/main.rs"));
        assert!(prompt.contains("tests/test.rs"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_empty_body() {
        let title = "fix: bug fix";
        let body = "";
        let files = vec!["src/lib.rs".to_string()];

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.contains("[No description provided]"));
        assert!(prompt.contains("src/lib.rs"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_truncates_long_body() {
        let title = "test";
        let long_body = "x".repeat(5000);
        let files = vec![];

        let prompt = TestProvider::build_pr_label_user_prompt(title, &long_body, &files);
        assert!(prompt.contains("[Description truncated"));
        assert!(prompt.contains("5000 chars"));
    }

    // Only the first 20 paths are listed; the rest are summarised as a count.
    #[test]
    fn test_build_pr_label_user_prompt_respects_file_limit() {
        let title = "test";
        let body = "test";
        let mut files = Vec::new();
        for i in 0..25 {
            files.push(format!("file{i}.rs"));
        }

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.contains("file0.rs"));
        assert!(prompt.contains("file19.rs"));
        assert!(!prompt.contains("file20.rs"));
        assert!(prompt.contains("... and 5 more files"));
    }

    // With no paths, the "Files Changed:" section is left out entirely.
    #[test]
    fn test_build_pr_label_user_prompt_empty_files() {
        let title = "test";
        let body = "test";
        let files: Vec<String> = vec![];

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.contains("Title: test"));
        assert!(prompt.contains("Description:\ntest"));
        assert!(!prompt.contains("Files Changed:"));
    }

    #[test]
    fn test_parse_ai_json_with_valid_json() {
        #[derive(serde::Deserialize)]
        struct TestResponse {
            message: String,
        }

        let json = r#"{"message": "hello"}"#;
        let result: Result<TestResponse> = parse_ai_json(json, "test-provider");
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.message, "hello");
    }

    // Input that ends mid-string is an EOF error, so it must map to the
    // truncated-response error (checked through its Display text).
    #[test]
    fn test_parse_ai_json_with_truncated_json() {
        let json = r#"{"message": "hello"#;
        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test-provider");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string()
                .contains("Truncated response from test-provider")
        );
    }

    // Syntactically invalid (non-EOF) input maps to the invalid-response error.
    #[test]
    fn test_parse_ai_json_with_malformed_json() {
        let json = r#"{"message": invalid}"#;
        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test-provider");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Invalid JSON response from AI"));
    }

    #[test]
    fn test_build_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    #[test]
    fn test_build_system_prompt_has_cot_directive() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(prompt.contains("Reason through each step before producing output."));
    }

    #[test]
    fn test_build_system_prompt_has_examples_section() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(prompt.contains("## Examples"));
    }

    #[test]
    fn test_build_create_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_create_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    #[test]
    fn test_build_pr_review_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_pr_review_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    #[test]
    fn test_build_pr_label_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_pr_label_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    // An override name that does not exist must yield None.
    #[tokio::test]
    async fn test_load_system_prompt_override_returns_none_when_absent() {
        let result =
            super::super::context::load_system_prompt_override("__nonexistent_test_override__")
                .await;
        assert!(result.is_none());
    }

    // NOTE(review): despite its name, this test never calls
    // `load_system_prompt_override`. It only writes a temp file and reads it
    // back with `tokio::fs`, so it does not exercise the override loader.
    #[tokio::test]
    async fn test_load_system_prompt_override_returns_content_when_present() {
        use std::io::Write;
        let dir = tempfile::tempdir().expect("create tempdir");
        let file_path = dir.path().join("test_override.md");
        let mut f = std::fs::File::create(&file_path).expect("create file");
        writeln!(f, "Custom override content").expect("write file");
        drop(f);

        let content = tokio::fs::read_to_string(&file_path).await.ok();
        assert_eq!(content.as_deref(), Some("Custom override content\n"));
    }
}