Skip to main content

aptu_core/ai/
provider.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AI provider trait and shared implementations.
4//!
5//! Defines the `AiProvider` trait that all AI providers must implement,
6//! along with default implementations for shared logic like prompt building,
7//! request sending, and response parsing.
8
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use regex::Regex;
12use reqwest::Client;
13use secrecy::SecretString;
14use std::sync::LazyLock;
15use tracing::{debug, instrument};
16
17use super::AiResponse;
18use super::types::{
19    ChatCompletionRequest, ChatCompletionResponse, ChatMessage, IssueDetails, ResponseFormat,
20    TriageResponse,
21};
22use crate::history::AiStats;
23
24use super::prompts::{
25    build_create_system_prompt, build_pr_label_system_prompt, build_pr_review_system_prompt,
26    build_triage_system_prompt,
27};
28
29/// Maximum number of characters retained from an AI provider error response body.
30const MAX_ERROR_BODY_LENGTH: usize = 200;
31
32/// Redacts error body to prevent leaking sensitive API details.
33/// Truncates to [`MAX_ERROR_BODY_LENGTH`] characters and appends "[truncated]" if longer.
34fn redact_api_error_body(body: &str) -> String {
35    if body.chars().count() <= MAX_ERROR_BODY_LENGTH {
36        body.to_owned()
37    } else {
38        let truncated: String = body.chars().take(MAX_ERROR_BODY_LENGTH).collect();
39        format!("{truncated} [truncated]")
40    }
41}
42
43/// Parses JSON response from AI provider, detecting truncated responses.
44///
45/// If the JSON parsing fails with an EOF error (indicating the response was cut off),
46/// returns a `TruncatedResponse` error that can be retried. Other JSON errors are
47/// wrapped as `InvalidAIResponse`.
48///
49/// # Arguments
50///
51/// * `text` - The JSON text to parse
52/// * `provider` - The name of the AI provider (for error context)
53///
54/// # Returns
55///
56/// Parsed value of type T, or an error if parsing fails
57fn parse_ai_json<T: serde::de::DeserializeOwned>(text: &str, provider: &str) -> Result<T> {
58    match serde_json::from_str::<T>(text) {
59        Ok(value) => Ok(value),
60        Err(e) => {
61            // Check if this is an EOF error (truncated response)
62            if e.is_eof() {
63                Err(anyhow::anyhow!(
64                    crate::error::AptuError::TruncatedResponse {
65                        provider: provider.to_string(),
66                    }
67                ))
68            } else {
69                Err(anyhow::anyhow!(crate::error::AptuError::InvalidAIResponse(
70                    e
71                )))
72            }
73        }
74    }
75}
76
77/// Maximum length for issue body to stay within token limits.
78pub const MAX_BODY_LENGTH: usize = 4000;
79
80/// Maximum number of comments to include in the prompt.
81pub const MAX_COMMENTS: usize = 5;
82
83/// Maximum number of files to include in PR review prompt.
84pub const MAX_FILES: usize = 20;
85
86/// Maximum total diff size (in characters) for PR review prompt.
87pub const MAX_TOTAL_DIFF_SIZE: usize = 50_000;
88
89/// Maximum number of labels to include in the prompt.
90pub const MAX_LABELS: usize = 30;
91
92/// Maximum number of milestones to include in the prompt.
93pub const MAX_MILESTONES: usize = 10;
94
95/// Maximum characters per file's full content included in the PR review prompt.
96/// Content pre-truncated by `fetch_file_contents` may already be within this limit,
97/// but the prompt builder applies it as a second safety cap.
98pub const MAX_FULL_CONTENT_CHARS: usize = 4_000;
99
100/// Estimated overhead for XML tags, section headers, and schema preamble added by
101/// `build_pr_review_user_prompt`. Used to ensure the prompt budget accounts for
102/// non-content characters when estimating total prompt size.
103const PROMPT_OVERHEAD_CHARS: usize = 1_000;
104
105/// Preamble appended to every user-turn prompt to request a JSON response matching the schema.
106const SCHEMA_PREAMBLE: &str = "\n\nRespond with valid JSON matching this schema:\n";
107
108/// Matches `<pull_request>`, `</pull_request>`, `<issue_content>`, and `</issue_content>` tags
109/// (case-insensitive) used as prompt delimiters. These must be stripped from user-controlled
110/// fields to prevent prompt injection.
111///
112/// The pattern uses a simple alternation with no quantifiers, so `ReDoS` is not a concern:
113/// regex engine complexity is O(n) in the input length regardless of content.
114static XML_DELIMITERS: LazyLock<Regex> =
115    LazyLock::new(|| Regex::new(r"(?i)</?(?:pull_request|issue_content)>").expect("valid regex"));
116
117/// Removes `<pull_request>` / `</pull_request>` and `<issue_content>` / `</issue_content>`
118/// XML delimiter tags from a user-supplied string, preventing prompt injection via XML tag
119/// smuggling.
120///
121/// Tags are removed entirely (replaced with empty string) rather than substituted with a
122/// placeholder. A visible placeholder such as `[sanitized]` could cause the LLM to reason
123/// about the substitution marker itself, which is unnecessary and potentially confusing.
124///
125/// Nested or malformed XML is not a concern: the only delimiters this code inserts into
126/// prompts are the exact strings `<pull_request>` / `</pull_request>` and
127/// `<issue_content>` / `</issue_content>` (no attributes, no nesting). Stripping those
128/// fixed forms is sufficient to prevent a user-supplied value from breaking out of the
129/// delimiter boundary.
130///
131/// Applied to all user-controlled fields inside prompt delimiter blocks:
132/// - Issue triage: `issue.title`, `issue.body`, comment author/body, related issue
133///   title/state, label name/description, milestone title/description.
134/// - PR review: `pr.title`, `pr.body`, `file.filename`, `file.status`, patch content.
135fn sanitize_prompt_field(s: &str) -> String {
136    XML_DELIMITERS.replace_all(s, "").into_owned()
137}
138
139/// AI provider trait for issue triage and creation.
140///
141/// Defines the interface that all AI providers must implement.
142/// Default implementations are provided for shared logic.
143#[async_trait]
144pub trait AiProvider: Send + Sync {
145    /// Returns the name of the provider (e.g., "gemini", "openrouter").
146    fn name(&self) -> &str;
147
148    /// Returns the API URL for this provider.
149    fn api_url(&self) -> &str;
150
151    /// Returns the environment variable name for the API key.
152    fn api_key_env(&self) -> &str;
153
154    /// Returns the HTTP client for making requests.
155    fn http_client(&self) -> &Client;
156
157    /// Returns the API key for authentication.
158    fn api_key(&self) -> &SecretString;
159
160    /// Returns the model name.
161    fn model(&self) -> &str;
162
163    /// Returns the maximum tokens for API responses.
164    fn max_tokens(&self) -> u32;
165
166    /// Returns the temperature for API requests.
167    fn temperature(&self) -> f32;
168
169    /// Returns the maximum retry attempts for rate-limited requests.
170    ///
171    /// Default implementation returns 3. Providers can override
172    /// to use a different retry limit.
173    fn max_attempts(&self) -> u32 {
174        3
175    }
176
177    /// Returns the circuit breaker for this provider (optional).
178    ///
179    /// Default implementation returns None. Providers can override
180    /// to provide circuit breaker functionality.
181    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
182        None
183    }
184
185    /// Builds HTTP headers for API requests.
186    ///
187    /// Default implementation includes Authorization and Content-Type headers.
188    /// Providers can override to add custom headers.
189    fn build_headers(&self) -> reqwest::header::HeaderMap {
190        let mut headers = reqwest::header::HeaderMap::new();
191        if let Ok(val) = "application/json".parse() {
192            headers.insert("Content-Type", val);
193        }
194        headers
195    }
196
197    /// Validates the model configuration.
198    ///
199    /// Default implementation does nothing. Providers can override
200    /// to enforce constraints (e.g., free tier validation).
201    fn validate_model(&self) -> Result<()> {
202        Ok(())
203    }
204
205    /// Returns the custom guidance string for system prompt injection, if set.
206    ///
207    /// Default implementation returns `None`. Providers that store custom guidance
208    /// (e.g., from `AiConfig`) override this to supply it.
209    fn custom_guidance(&self) -> Option<&str> {
210        None
211    }
212
213    /// Sends a chat completion request to the provider's API (HTTP-only, no retry).
214    ///
215    /// Default implementation handles HTTP headers, error responses (401, 429).
216    /// Does not include retry logic - use `send_and_parse()` for retry behavior.
217    #[instrument(skip(self, request), fields(provider = self.name(), model = self.model()))]
218    async fn send_request_inner(
219        &self,
220        request: &ChatCompletionRequest,
221    ) -> Result<ChatCompletionResponse> {
222        use secrecy::ExposeSecret;
223        use tracing::warn;
224
225        use crate::error::AptuError;
226
227        let mut req = self.http_client().post(self.api_url());
228
229        // Add Authorization header
230        req = req.header(
231            "Authorization",
232            format!("Bearer {}", self.api_key().expose_secret()),
233        );
234
235        // Add custom headers from provider
236        for (key, value) in &self.build_headers() {
237            req = req.header(key.clone(), value.clone());
238        }
239
240        let response = req
241            .json(request)
242            .send()
243            .await
244            .context(format!("Failed to send request to {} API", self.name()))?;
245
246        // Check for HTTP errors
247        let status = response.status();
248        if !status.is_success() {
249            if status.as_u16() == 401 {
250                anyhow::bail!(
251                    "Invalid {} API key. Check your {} environment variable.",
252                    self.name(),
253                    self.api_key_env()
254                );
255            } else if status.as_u16() == 429 {
256                warn!("Rate limited by {} API", self.name());
257                // Parse Retry-After header (seconds), default to 0 if not present
258                let retry_after = response
259                    .headers()
260                    .get("Retry-After")
261                    .and_then(|h| h.to_str().ok())
262                    .and_then(|s| s.parse::<u64>().ok())
263                    .unwrap_or(0);
264                debug!(retry_after, "Parsed Retry-After header");
265                return Err(AptuError::RateLimited {
266                    provider: self.name().to_string(),
267                    retry_after,
268                }
269                .into());
270            }
271            let error_body = response.text().await.unwrap_or_default();
272            anyhow::bail!(
273                "{} API error (HTTP {}): {}",
274                self.name(),
275                status.as_u16(),
276                redact_api_error_body(&error_body)
277            );
278        }
279
280        // Parse response
281        let completion: ChatCompletionResponse = response
282            .json()
283            .await
284            .context(format!("Failed to parse {} API response", self.name()))?;
285
286        Ok(completion)
287    }
288
289    /// Sends a chat completion request and parses the response with retry logic.
290    ///
291    /// This method wraps both HTTP request and JSON parsing in a single retry loop,
292    /// allowing truncated responses to be retried. Includes circuit breaker handling.
293    ///
294    /// # Arguments
295    ///
296    /// * `request` - The chat completion request to send
297    ///
298    /// # Returns
299    ///
300    /// A tuple of (parsed response, stats) extracted from the API response
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if:
305    /// - API request fails (network, timeout, rate limit)
306    /// - Response cannot be parsed as valid JSON (including truncated responses)
307    #[instrument(skip(self, request), fields(provider = self.name(), model = self.model()))]
308    async fn send_and_parse<T: serde::de::DeserializeOwned + Send>(
309        &self,
310        request: &ChatCompletionRequest,
311    ) -> Result<(T, AiStats)> {
312        use tracing::{info, warn};
313
314        use crate::error::AptuError;
315        use crate::retry::{extract_retry_after, is_retryable_anyhow};
316
317        // Check circuit breaker before attempting request
318        if let Some(cb) = self.circuit_breaker()
319            && cb.is_open()
320        {
321            return Err(AptuError::CircuitOpen.into());
322        }
323
324        // Start timing (outside retry loop to measure total time including retries)
325        let start = std::time::Instant::now();
326
327        // Custom retry loop that respects retry_after from RateLimited errors
328        let mut attempt: u32 = 0;
329        let max_attempts: u32 = self.max_attempts();
330
331        // Helper function to avoid closure-in-expression clippy warning
332        #[allow(clippy::items_after_statements)]
333        async fn try_request<T: serde::de::DeserializeOwned>(
334            provider: &(impl AiProvider + ?Sized),
335            request: &ChatCompletionRequest,
336        ) -> Result<(T, ChatCompletionResponse)> {
337            // Send HTTP request
338            let completion = provider.send_request_inner(request).await?;
339
340            // Extract message content
341            let content = completion
342                .choices
343                .first()
344                .and_then(|c| {
345                    c.message
346                        .content
347                        .clone()
348                        .or_else(|| c.message.reasoning.clone())
349                })
350                .context("No response from AI model")?;
351
352            debug!(response_length = content.len(), "Received AI response");
353
354            // Parse JSON response (inside retry loop, so truncated responses are retried)
355            let parsed: T = parse_ai_json(&content, provider.name())?;
356
357            Ok((parsed, completion))
358        }
359
360        let (parsed, completion): (T, ChatCompletionResponse) = loop {
361            attempt += 1;
362
363            let result = try_request(self, request).await;
364
365            match result {
366                Ok(success) => break success,
367                Err(err) => {
368                    // Check if error is retryable
369                    if !is_retryable_anyhow(&err) || attempt >= max_attempts {
370                        return Err(err);
371                    }
372
373                    // Extract retry_after if present, otherwise use exponential backoff
374                    let delay = if let Some(retry_after_duration) = extract_retry_after(&err) {
375                        debug!(
376                            retry_after_secs = retry_after_duration.as_secs(),
377                            "Using Retry-After value from rate limit error"
378                        );
379                        retry_after_duration
380                    } else {
381                        // Use exponential backoff with jitter: 1s, 2s, 4s + 0-500ms
382                        let backoff_secs = 2_u64.pow(attempt.saturating_sub(1));
383                        let jitter_ms = fastrand::u64(0..500);
384                        std::time::Duration::from_millis(backoff_secs * 1000 + jitter_ms)
385                    };
386
387                    let error_msg = err.to_string();
388                    warn!(
389                        error = %error_msg,
390                        delay_secs = delay.as_secs(),
391                        attempt,
392                        max_attempts,
393                        "Retrying after error"
394                    );
395
396                    // Drop err before await to avoid holding non-Send value across await
397                    drop(err);
398                    tokio::time::sleep(delay).await;
399                }
400            }
401        };
402
403        // Record success in circuit breaker
404        if let Some(cb) = self.circuit_breaker() {
405            cb.record_success();
406        }
407
408        // Calculate duration (total time including any retries)
409        #[allow(clippy::cast_possible_truncation)]
410        let duration_ms = start.elapsed().as_millis() as u64;
411
412        // Build AI stats from usage info (trust API's cost field)
413        let (input_tokens, output_tokens, cost_usd) = if let Some(usage) = completion.usage {
414            (usage.prompt_tokens, usage.completion_tokens, usage.cost)
415        } else {
416            // If no usage info, default to 0
417            debug!("No usage information in API response");
418            (0, 0, None)
419        };
420
421        let ai_stats = AiStats {
422            provider: self.name().to_string(),
423            model: self.model().to_string(),
424            input_tokens,
425            output_tokens,
426            duration_ms,
427            cost_usd,
428            fallback_provider: None,
429            prompt_chars: 0,
430        };
431
432        // Emit structured metrics
433        info!(
434            duration_ms,
435            input_tokens,
436            output_tokens,
437            cost_usd = ?cost_usd,
438            model = %self.model(),
439            "AI request completed"
440        );
441
442        Ok((parsed, ai_stats))
443    }
444
445    /// Analyzes a GitHub issue using the provider's API.
446    ///
447    /// Returns a structured triage response with summary, labels, questions, duplicates, and usage stats.
448    ///
449    /// # Arguments
450    ///
451    /// * `issue` - Issue details to analyze
452    ///
453    /// # Errors
454    ///
455    /// Returns an error if:
456    /// - API request fails (network, timeout, rate limit)
457    /// - Response cannot be parsed as valid JSON
458    #[instrument(skip(self, issue), fields(issue_number = issue.number, repo = %format!("{}/{}", issue.owner, issue.repo)))]
459    async fn analyze_issue(&self, issue: &IssueDetails) -> Result<AiResponse> {
460        debug!(model = %self.model(), "Calling {} API", self.name());
461
462        // Build request
463        let system_content = if let Some(override_prompt) =
464            super::context::load_system_prompt_override("triage_system").await
465        {
466            override_prompt
467        } else {
468            Self::build_system_prompt(self.custom_guidance())
469        };
470
471        let request = ChatCompletionRequest {
472            model: self.model().to_string(),
473            messages: vec![
474                ChatMessage {
475                    role: "system".to_string(),
476                    content: Some(system_content),
477                    reasoning: None,
478                },
479                ChatMessage {
480                    role: "user".to_string(),
481                    content: Some(Self::build_user_prompt(issue)),
482                    reasoning: None,
483                },
484            ],
485            response_format: Some(ResponseFormat {
486                format_type: "json_object".to_string(),
487                json_schema: None,
488            }),
489            max_tokens: Some(self.max_tokens()),
490            temperature: Some(self.temperature()),
491        };
492
493        // Send request and parse JSON with retry logic
494        let (triage, ai_stats) = self.send_and_parse::<TriageResponse>(&request).await?;
495
496        debug!(
497            input_tokens = ai_stats.input_tokens,
498            output_tokens = ai_stats.output_tokens,
499            duration_ms = ai_stats.duration_ms,
500            cost_usd = ?ai_stats.cost_usd,
501            "AI analysis complete"
502        );
503
504        Ok(AiResponse {
505            triage,
506            stats: ai_stats,
507        })
508    }
509
510    /// Creates a formatted GitHub issue using the provider's API.
511    ///
512    /// Takes raw issue title and body, formats them using AI (conventional commit style,
513    /// structured body), and returns the formatted content with suggested labels.
514    ///
515    /// # Arguments
516    ///
517    /// * `title` - Raw issue title from user
518    /// * `body` - Raw issue body/description from user
519    /// * `repo` - Repository name for context (owner/repo format)
520    ///
521    /// # Errors
522    ///
523    /// Returns an error if:
524    /// - API request fails (network, timeout, rate limit)
525    /// - Response cannot be parsed as valid JSON
526    #[instrument(skip(self), fields(repo = %repo))]
527    async fn create_issue(
528        &self,
529        title: &str,
530        body: &str,
531        repo: &str,
532    ) -> Result<(super::types::CreateIssueResponse, AiStats)> {
533        debug!(model = %self.model(), "Calling {} API for issue creation", self.name());
534
535        // Build request
536        let system_content = if let Some(override_prompt) =
537            super::context::load_system_prompt_override("create_system").await
538        {
539            override_prompt
540        } else {
541            Self::build_create_system_prompt(self.custom_guidance())
542        };
543
544        let request = ChatCompletionRequest {
545            model: self.model().to_string(),
546            messages: vec![
547                ChatMessage {
548                    role: "system".to_string(),
549                    content: Some(system_content),
550                    reasoning: None,
551                },
552                ChatMessage {
553                    role: "user".to_string(),
554                    content: Some(Self::build_create_user_prompt(title, body, repo)),
555                    reasoning: None,
556                },
557            ],
558            response_format: Some(ResponseFormat {
559                format_type: "json_object".to_string(),
560                json_schema: None,
561            }),
562            max_tokens: Some(self.max_tokens()),
563            temperature: Some(self.temperature()),
564        };
565
566        // Send request and parse JSON with retry logic
567        let (create_response, ai_stats) = self
568            .send_and_parse::<super::types::CreateIssueResponse>(&request)
569            .await?;
570
571        debug!(
572            title_len = create_response.formatted_title.len(),
573            body_len = create_response.formatted_body.len(),
574            labels = create_response.suggested_labels.len(),
575            input_tokens = ai_stats.input_tokens,
576            output_tokens = ai_stats.output_tokens,
577            duration_ms = ai_stats.duration_ms,
578            "Issue formatting complete with stats"
579        );
580
581        Ok((create_response, ai_stats))
582    }
583
584    /// Builds the system prompt for issue triage.
585    #[must_use]
586    fn build_system_prompt(custom_guidance: Option<&str>) -> String {
587        let context = super::context::load_custom_guidance(custom_guidance);
588        build_triage_system_prompt(&context)
589    }
590
591    /// Builds the user prompt containing the issue details.
592    #[must_use]
593    fn build_user_prompt(issue: &IssueDetails) -> String {
594        use std::fmt::Write;
595
596        let mut prompt = String::new();
597
598        prompt.push_str("<issue_content>\n");
599        let _ = writeln!(prompt, "Title: {}\n", sanitize_prompt_field(&issue.title));
600
601        // Sanitize body before truncation (injection tag could straddle the boundary)
602        let sanitized_body = sanitize_prompt_field(&issue.body);
603        let body = if sanitized_body.len() > MAX_BODY_LENGTH {
604            format!(
605                "{}...\n[Body truncated - original length: {} chars]",
606                &sanitized_body[..MAX_BODY_LENGTH],
607                sanitized_body.len()
608            )
609        } else if sanitized_body.is_empty() {
610            "[No description provided]".to_string()
611        } else {
612            sanitized_body
613        };
614        let _ = writeln!(prompt, "Body:\n{body}\n");
615
616        // Include existing labels
617        if !issue.labels.is_empty() {
618            let _ = writeln!(prompt, "Existing Labels: {}\n", issue.labels.join(", "));
619        }
620
621        // Include recent comments (limited)
622        if !issue.comments.is_empty() {
623            prompt.push_str("Recent Comments:\n");
624            for comment in issue.comments.iter().take(MAX_COMMENTS) {
625                let sanitized_comment_body = sanitize_prompt_field(&comment.body);
626                let comment_body = if sanitized_comment_body.len() > 500 {
627                    format!("{}...", &sanitized_comment_body[..500])
628                } else {
629                    sanitized_comment_body
630                };
631                let _ = writeln!(
632                    prompt,
633                    "- @{}: {}",
634                    sanitize_prompt_field(&comment.author),
635                    comment_body
636                );
637            }
638            prompt.push('\n');
639        }
640
641        // Include related issues from search (for context)
642        if !issue.repo_context.is_empty() {
643            prompt.push_str("Related Issues in Repository (for context):\n");
644            for related in issue.repo_context.iter().take(10) {
645                let _ = writeln!(
646                    prompt,
647                    "- #{} [{}] {}",
648                    related.number,
649                    sanitize_prompt_field(&related.state),
650                    sanitize_prompt_field(&related.title)
651                );
652            }
653            prompt.push('\n');
654        }
655
656        // Include repository structure (source files)
657        if !issue.repo_tree.is_empty() {
658            prompt.push_str("Repository Structure (source files):\n");
659            for path in issue.repo_tree.iter().take(20) {
660                let _ = writeln!(prompt, "- {path}");
661            }
662            prompt.push('\n');
663        }
664
665        // Include available labels
666        if !issue.available_labels.is_empty() {
667            prompt.push_str("Available Labels:\n");
668            for label in issue.available_labels.iter().take(MAX_LABELS) {
669                let description = if label.description.is_empty() {
670                    String::new()
671                } else {
672                    format!(" - {}", sanitize_prompt_field(&label.description))
673                };
674                let _ = writeln!(
675                    prompt,
676                    "- {} (color: #{}){}",
677                    sanitize_prompt_field(&label.name),
678                    label.color,
679                    description
680                );
681            }
682            prompt.push('\n');
683        }
684
685        // Include available milestones
686        if !issue.available_milestones.is_empty() {
687            prompt.push_str("Available Milestones:\n");
688            for milestone in issue.available_milestones.iter().take(MAX_MILESTONES) {
689                let description = if milestone.description.is_empty() {
690                    String::new()
691                } else {
692                    format!(" - {}", sanitize_prompt_field(&milestone.description))
693                };
694                let _ = writeln!(
695                    prompt,
696                    "- {}{}",
697                    sanitize_prompt_field(&milestone.title),
698                    description
699                );
700            }
701            prompt.push('\n');
702        }
703
704        prompt.push_str("</issue_content>");
705        prompt.push_str(SCHEMA_PREAMBLE);
706        prompt.push_str(crate::ai::prompts::TRIAGE_SCHEMA);
707
708        prompt
709    }
710
711    /// Builds the system prompt for issue creation/formatting.
712    #[must_use]
713    fn build_create_system_prompt(custom_guidance: Option<&str>) -> String {
714        let context = super::context::load_custom_guidance(custom_guidance);
715        build_create_system_prompt(&context)
716    }
717
718    /// Builds the user prompt for issue creation/formatting.
719    #[must_use]
720    fn build_create_user_prompt(title: &str, body: &str, _repo: &str) -> String {
721        let sanitized_title = sanitize_prompt_field(title);
722        let sanitized_body = sanitize_prompt_field(body);
723        format!(
724            "Please format this GitHub issue:\n\nTitle: {sanitized_title}\n\nBody:\n{sanitized_body}{}{}",
725            SCHEMA_PREAMBLE,
726            crate::ai::prompts::CREATE_SCHEMA
727        )
728    }
729
730    /// Reviews a pull request using the provider's API.
731    ///
732    /// Analyzes PR metadata and file diffs to provide structured review feedback.
733    ///
734    /// # Arguments
735    ///
736    /// * `pr` - Pull request details including files and diffs
737    ///
738    /// # Errors
739    ///
740    /// Returns an error if:
741    /// - API request fails (network, timeout, rate limit)
742    /// - Response cannot be parsed as valid JSON
743    #[instrument(skip(self, pr, ast_context, call_graph), fields(pr_number = pr.number, repo = %format!("{}/{}", pr.owner, pr.repo)))]
744    async fn review_pr(
745        &self,
746        pr: &super::types::PrDetails,
747        mut ast_context: String,
748        mut call_graph: String,
749        review_config: &crate::config::ReviewConfig,
750    ) -> Result<(super::types::PrReviewResponse, AiStats)> {
751        debug!(model = %self.model(), "Calling {} API for PR review", self.name());
752
753        // Estimate preliminary size; enforce drop order for budget control
754        let mut estimated_size = pr.title.len()
755            + pr.body.len()
756            + pr.files
757                .iter()
758                .map(|f| f.patch.as_ref().map_or(0, String::len))
759                .sum::<usize>()
760            + pr.files
761                .iter()
762                .map(|f| f.full_content.as_ref().map_or(0, String::len))
763                .sum::<usize>()
764            + ast_context.len()
765            + call_graph.len()
766            + PROMPT_OVERHEAD_CHARS;
767
768        let max_prompt_chars = review_config.max_prompt_chars;
769
770        // Drop call_graph if over budget
771        if estimated_size > max_prompt_chars {
772            tracing::warn!(
773                section = "call_graph",
774                chars = call_graph.len(),
775                "Dropping section: prompt budget exceeded"
776            );
777            let dropped_chars = call_graph.len();
778            call_graph.clear();
779            estimated_size -= dropped_chars;
780        }
781
782        // Drop ast_context if still over budget
783        if estimated_size > max_prompt_chars {
784            tracing::warn!(
785                section = "ast_context",
786                chars = ast_context.len(),
787                "Dropping section: prompt budget exceeded"
788            );
789            let dropped_chars = ast_context.len();
790            ast_context.clear();
791            estimated_size -= dropped_chars;
792        }
793
794        // Step 3: Drop largest file patches first if still over budget
795        let mut pr_mut = pr.clone();
796        if estimated_size > max_prompt_chars {
797            // Collect files with their patch sizes
798            let mut file_sizes: Vec<(usize, usize)> = pr_mut
799                .files
800                .iter()
801                .enumerate()
802                .map(|(idx, f)| (idx, f.patch.as_ref().map_or(0, String::len)))
803                .collect();
804            // Sort by patch size descending
805            file_sizes.sort_by(|a, b| b.1.cmp(&a.1));
806
807            for (file_idx, patch_size) in file_sizes {
808                if estimated_size <= max_prompt_chars {
809                    break;
810                }
811                if patch_size > 0 {
812                    tracing::warn!(
813                        file = %pr_mut.files[file_idx].filename,
814                        patch_chars = patch_size,
815                        "Dropping file patch: prompt budget exceeded"
816                    );
817                    pr_mut.files[file_idx].patch = None;
818                    estimated_size -= patch_size;
819                }
820            }
821        }
822
823        // Step 4: drop full_content on all files
824        if estimated_size > max_prompt_chars {
825            for file in &mut pr_mut.files {
826                if let Some(fc) = file.full_content.take() {
827                    estimated_size = estimated_size.saturating_sub(fc.len());
828                    tracing::warn!(
829                        bytes = fc.len(),
830                        filename = %file.filename,
831                        "prompt budget: dropping full_content"
832                    );
833                }
834            }
835        }
836
837        tracing::info!(
838            prompt_chars = estimated_size,
839            max_chars = max_prompt_chars,
840            "PR review prompt assembled"
841        );
842
843        // Build request
844        let system_content = if let Some(override_prompt) =
845            super::context::load_system_prompt_override("pr_review_system").await
846        {
847            override_prompt
848        } else {
849            Self::build_pr_review_system_prompt(self.custom_guidance())
850        };
851
852        // Assemble full prompt to measure actual size
853        let assembled_prompt =
854            Self::build_pr_review_user_prompt(&pr_mut, &ast_context, &call_graph);
855        let actual_prompt_chars = assembled_prompt.len();
856
857        tracing::info!(
858            actual_prompt_chars,
859            estimated_prompt_chars = estimated_size,
860            max_chars = max_prompt_chars,
861            "Actual assembled prompt size vs. estimate"
862        );
863
864        let request = ChatCompletionRequest {
865            model: self.model().to_string(),
866            messages: vec![
867                ChatMessage {
868                    role: "system".to_string(),
869                    content: Some(system_content),
870                    reasoning: None,
871                },
872                ChatMessage {
873                    role: "user".to_string(),
874                    content: Some(assembled_prompt),
875                    reasoning: None,
876                },
877            ],
878            response_format: Some(ResponseFormat {
879                format_type: "json_object".to_string(),
880                json_schema: None,
881            }),
882            max_tokens: Some(self.max_tokens()),
883            temperature: Some(self.temperature()),
884        };
885
886        // Send request and parse JSON with retry logic
887        let (review, mut ai_stats) = self
888            .send_and_parse::<super::types::PrReviewResponse>(&request)
889            .await?;
890
891        ai_stats.prompt_chars = actual_prompt_chars;
892
893        debug!(
894            verdict = %review.verdict,
895            input_tokens = ai_stats.input_tokens,
896            output_tokens = ai_stats.output_tokens,
897            duration_ms = ai_stats.duration_ms,
898            prompt_chars = ai_stats.prompt_chars,
899            "PR review complete with stats"
900        );
901
902        Ok((review, ai_stats))
903    }
904
905    /// Suggests labels for a pull request using the provider's API.
906    ///
907    /// Analyzes PR title, body, and file paths to suggest relevant labels.
908    ///
909    /// # Arguments
910    ///
911    /// * `title` - Pull request title
912    /// * `body` - Pull request description
913    /// * `file_paths` - List of file paths changed in the PR
914    ///
915    /// # Errors
916    ///
917    /// Returns an error if:
918    /// - API request fails (network, timeout, rate limit)
919    /// - Response cannot be parsed as valid JSON
920    #[instrument(skip(self), fields(title = %title))]
921    async fn suggest_pr_labels(
922        &self,
923        title: &str,
924        body: &str,
925        file_paths: &[String],
926    ) -> Result<(Vec<String>, AiStats)> {
927        debug!(model = %self.model(), "Calling {} API for PR label suggestion", self.name());
928
929        // Build request
930        let system_content = if let Some(override_prompt) =
931            super::context::load_system_prompt_override("pr_label_system").await
932        {
933            override_prompt
934        } else {
935            Self::build_pr_label_system_prompt(self.custom_guidance())
936        };
937
938        let request = ChatCompletionRequest {
939            model: self.model().to_string(),
940            messages: vec![
941                ChatMessage {
942                    role: "system".to_string(),
943                    content: Some(system_content),
944                    reasoning: None,
945                },
946                ChatMessage {
947                    role: "user".to_string(),
948                    content: Some(Self::build_pr_label_user_prompt(title, body, file_paths)),
949                    reasoning: None,
950                },
951            ],
952            response_format: Some(ResponseFormat {
953                format_type: "json_object".to_string(),
954                json_schema: None,
955            }),
956            max_tokens: Some(self.max_tokens()),
957            temperature: Some(self.temperature()),
958        };
959
960        // Send request and parse JSON with retry logic
961        let (response, ai_stats) = self
962            .send_and_parse::<super::types::PrLabelResponse>(&request)
963            .await?;
964
965        debug!(
966            label_count = response.suggested_labels.len(),
967            input_tokens = ai_stats.input_tokens,
968            output_tokens = ai_stats.output_tokens,
969            duration_ms = ai_stats.duration_ms,
970            "PR label suggestion complete with stats"
971        );
972
973        Ok((response.suggested_labels, ai_stats))
974    }
975
976    /// Builds the system prompt for PR review.
977    #[must_use]
978    fn build_pr_review_system_prompt(custom_guidance: Option<&str>) -> String {
979        let context = super::context::load_custom_guidance(custom_guidance);
980        build_pr_review_system_prompt(&context)
981    }
982
983    /// Builds the user prompt for PR review.
984    ///
985    /// All user-controlled fields (title, body, filename, status, patch) are sanitized via
986    /// [`sanitize_prompt_field`] before being written into the prompt to prevent prompt
987    /// injection via XML tag smuggling.
988    #[must_use]
989    fn build_pr_review_user_prompt(
990        pr: &super::types::PrDetails,
991        ast_context: &str,
992        call_graph: &str,
993    ) -> String {
994        use std::fmt::Write;
995
996        let mut prompt = String::new();
997
998        prompt.push_str("<pull_request>\n");
999        let _ = writeln!(prompt, "Title: {}\n", sanitize_prompt_field(&pr.title));
1000        let _ = writeln!(prompt, "Branch: {} -> {}\n", pr.head_branch, pr.base_branch);
1001
1002        // PR description - sanitize before truncation
1003        let sanitized_body = sanitize_prompt_field(&pr.body);
1004        let body = if sanitized_body.is_empty() {
1005            "[No description provided]".to_string()
1006        } else if sanitized_body.len() > MAX_BODY_LENGTH {
1007            format!(
1008                "{}...\n[Description truncated - original length: {} chars]",
1009                &sanitized_body[..MAX_BODY_LENGTH],
1010                sanitized_body.len()
1011            )
1012        } else {
1013            sanitized_body
1014        };
1015        let _ = writeln!(prompt, "Description:\n{body}\n");
1016
1017        // File changes with limits
1018        prompt.push_str("Files Changed:\n");
1019        let mut total_diff_size = 0;
1020        let mut files_included = 0;
1021        let mut files_skipped = 0;
1022
1023        for file in &pr.files {
1024            // Check file count limit
1025            if files_included >= MAX_FILES {
1026                files_skipped += 1;
1027                continue;
1028            }
1029
1030            let _ = writeln!(
1031                prompt,
1032                "- {} ({}) +{} -{}\n",
1033                sanitize_prompt_field(&file.filename),
1034                sanitize_prompt_field(&file.status),
1035                file.additions,
1036                file.deletions
1037            );
1038
1039            // Include patch if available (sanitize then truncate large patches)
1040            if let Some(patch) = &file.patch {
1041                const MAX_PATCH_LENGTH: usize = 2000;
1042                let sanitized_patch = sanitize_prompt_field(patch);
1043                let patch_content = if sanitized_patch.len() > MAX_PATCH_LENGTH {
1044                    format!(
1045                        "{}...\n[Patch truncated - original length: {} chars]",
1046                        &sanitized_patch[..MAX_PATCH_LENGTH],
1047                        sanitized_patch.len()
1048                    )
1049                } else {
1050                    sanitized_patch
1051                };
1052
1053                // Check if adding this patch would exceed total diff size limit
1054                let patch_size = patch_content.len();
1055                if total_diff_size + patch_size > MAX_TOTAL_DIFF_SIZE {
1056                    let _ = writeln!(
1057                        prompt,
1058                        "```diff\n[Patch omitted - total diff size limit reached]\n```\n"
1059                    );
1060                    files_skipped += 1;
1061                    continue;
1062                }
1063
1064                let _ = writeln!(prompt, "```diff\n{patch_content}\n```\n");
1065                total_diff_size += patch_size;
1066            }
1067
1068            // Include full file content if available (cap at MAX_FULL_CONTENT_CHARS)
1069            if let Some(content) = &file.full_content {
1070                let sanitized = sanitize_prompt_field(content);
1071                let displayed = if sanitized.len() > MAX_FULL_CONTENT_CHARS {
1072                    sanitized[..MAX_FULL_CONTENT_CHARS].to_string()
1073                } else {
1074                    sanitized
1075                };
1076                let _ = writeln!(
1077                    prompt,
1078                    "<file_content path=\"{}\">\n{}\n</file_content>\n",
1079                    sanitize_prompt_field(&file.filename),
1080                    displayed
1081                );
1082            }
1083
1084            files_included += 1;
1085        }
1086
1087        // Add truncation message if files were skipped
1088        if files_skipped > 0 {
1089            let _ = writeln!(
1090                prompt,
1091                "\n[{files_skipped} files omitted due to size limits (MAX_FILES={MAX_FILES}, MAX_TOTAL_DIFF_SIZE={MAX_TOTAL_DIFF_SIZE})]"
1092            );
1093        }
1094
1095        prompt.push_str("</pull_request>");
1096        if !ast_context.is_empty() {
1097            prompt.push_str(ast_context);
1098        }
1099        if !call_graph.is_empty() {
1100            prompt.push_str(call_graph);
1101        }
1102        prompt.push_str(SCHEMA_PREAMBLE);
1103        prompt.push_str(crate::ai::prompts::PR_REVIEW_SCHEMA);
1104
1105        prompt
1106    }
1107
1108    /// Builds the system prompt for PR label suggestion.
1109    #[must_use]
1110    fn build_pr_label_system_prompt(custom_guidance: Option<&str>) -> String {
1111        let context = super::context::load_custom_guidance(custom_guidance);
1112        build_pr_label_system_prompt(&context)
1113    }
1114
1115    /// Builds the user prompt for PR label suggestion.
1116    #[must_use]
1117    fn build_pr_label_user_prompt(title: &str, body: &str, file_paths: &[String]) -> String {
1118        use std::fmt::Write;
1119
1120        let mut prompt = String::new();
1121
1122        prompt.push_str("<pull_request>\n");
1123        let _ = writeln!(prompt, "Title: {title}\n");
1124
1125        // PR description
1126        let body_content = if body.is_empty() {
1127            "[No description provided]".to_string()
1128        } else if body.len() > MAX_BODY_LENGTH {
1129            format!(
1130                "{}...\n[Description truncated - original length: {} chars]",
1131                &body[..MAX_BODY_LENGTH],
1132                body.len()
1133            )
1134        } else {
1135            body.to_string()
1136        };
1137        let _ = writeln!(prompt, "Description:\n{body_content}\n");
1138
1139        // File paths
1140        if !file_paths.is_empty() {
1141            prompt.push_str("Files Changed:\n");
1142            for path in file_paths.iter().take(20) {
1143                let _ = writeln!(prompt, "- {path}");
1144            }
1145            if file_paths.len() > 20 {
1146                let _ = writeln!(prompt, "- ... and {} more files", file_paths.len() - 20);
1147            }
1148            prompt.push('\n');
1149        }
1150
1151        prompt.push_str("</pull_request>");
1152        prompt.push_str(SCHEMA_PREAMBLE);
1153        prompt.push_str(crate::ai::prompts::PR_LABEL_SCHEMA);
1154
1155        prompt
1156    }
1157}
1158
1159#[cfg(test)]
1160mod tests {
1161    use super::*;
1162
1163    /// Shared struct for parse_ai_json error-path tests.
1164    /// The field is only used via serde deserialization; `_message` silences dead_code.
1165    #[derive(Debug, serde::Deserialize)]
1166    struct ErrorTestResponse {
1167        _message: String,
1168    }
1169
1170    struct TestProvider;
1171
1172    impl AiProvider for TestProvider {
1173        fn name(&self) -> &'static str {
1174            "test"
1175        }
1176
1177        fn api_url(&self) -> &'static str {
1178            "https://test.example.com"
1179        }
1180
1181        fn api_key_env(&self) -> &'static str {
1182            "TEST_API_KEY"
1183        }
1184
1185        fn http_client(&self) -> &Client {
1186            unimplemented!()
1187        }
1188
1189        fn api_key(&self) -> &SecretString {
1190            unimplemented!()
1191        }
1192
1193        fn model(&self) -> &'static str {
1194            "test-model"
1195        }
1196
1197        fn max_tokens(&self) -> u32 {
1198            2048
1199        }
1200
1201        fn temperature(&self) -> f32 {
1202            0.3
1203        }
1204    }
1205
1206    #[test]
1207    fn test_build_system_prompt_contains_json_schema() {
1208        let system_prompt = TestProvider::build_system_prompt(None);
1209        // Schema description strings are unique to the schema file and must NOT appear in the
1210        // system prompt after moving schema injection to the user turn.
1211        assert!(
1212            !system_prompt
1213                .contains("A 2-3 sentence summary of what the issue is about and its impact")
1214        );
1215
1216        // Schema MUST appear in the user prompt
1217        let issue = IssueDetails::builder()
1218            .owner("test".to_string())
1219            .repo("repo".to_string())
1220            .number(1)
1221            .title("Test".to_string())
1222            .body("Body".to_string())
1223            .labels(vec![])
1224            .comments(vec![])
1225            .url("https://github.com/test/repo/issues/1".to_string())
1226            .build();
1227        let user_prompt = TestProvider::build_user_prompt(&issue);
1228        assert!(
1229            user_prompt
1230                .contains("A 2-3 sentence summary of what the issue is about and its impact")
1231        );
1232        assert!(user_prompt.contains("suggested_labels"));
1233    }
1234
1235    #[test]
1236    fn test_build_user_prompt_with_delimiters() {
1237        let issue = IssueDetails::builder()
1238            .owner("test".to_string())
1239            .repo("repo".to_string())
1240            .number(1)
1241            .title("Test issue".to_string())
1242            .body("This is the body".to_string())
1243            .labels(vec!["bug".to_string()])
1244            .comments(vec![])
1245            .url("https://github.com/test/repo/issues/1".to_string())
1246            .build();
1247
1248        let prompt = TestProvider::build_user_prompt(&issue);
1249        assert!(prompt.starts_with("<issue_content>"));
1250        assert!(prompt.contains("</issue_content>"));
1251        assert!(prompt.contains("Respond with valid JSON matching this schema"));
1252        assert!(prompt.contains("Title: Test issue"));
1253        assert!(prompt.contains("This is the body"));
1254        assert!(prompt.contains("Existing Labels: bug"));
1255    }
1256
1257    #[test]
1258    fn test_build_user_prompt_truncates_long_body() {
1259        let long_body = "x".repeat(5000);
1260        let issue = IssueDetails::builder()
1261            .owner("test".to_string())
1262            .repo("repo".to_string())
1263            .number(1)
1264            .title("Test".to_string())
1265            .body(long_body)
1266            .labels(vec![])
1267            .comments(vec![])
1268            .url("https://github.com/test/repo/issues/1".to_string())
1269            .build();
1270
1271        let prompt = TestProvider::build_user_prompt(&issue);
1272        assert!(prompt.contains("[Body truncated"));
1273        assert!(prompt.contains("5000 chars"));
1274    }
1275
1276    #[test]
1277    fn test_build_user_prompt_empty_body() {
1278        let issue = IssueDetails::builder()
1279            .owner("test".to_string())
1280            .repo("repo".to_string())
1281            .number(1)
1282            .title("Test".to_string())
1283            .body(String::new())
1284            .labels(vec![])
1285            .comments(vec![])
1286            .url("https://github.com/test/repo/issues/1".to_string())
1287            .build();
1288
1289        let prompt = TestProvider::build_user_prompt(&issue);
1290        assert!(prompt.contains("[No description provided]"));
1291    }
1292
1293    #[test]
1294    fn test_build_create_system_prompt_contains_json_schema() {
1295        let system_prompt = TestProvider::build_create_system_prompt(None);
1296        // Schema description strings are unique to the schema file and must NOT appear in system prompt.
1297        assert!(
1298            !system_prompt
1299                .contains("Well-formatted issue title following conventional commit style")
1300        );
1301
1302        // Schema MUST appear in the user prompt
1303        let user_prompt =
1304            TestProvider::build_create_user_prompt("My title", "My body", "test/repo");
1305        assert!(
1306            user_prompt.contains("Well-formatted issue title following conventional commit style")
1307        );
1308        assert!(user_prompt.contains("formatted_body"));
1309    }
1310
1311    #[test]
1312    fn test_build_pr_review_user_prompt_respects_file_limit() {
1313        use super::super::types::{PrDetails, PrFile};
1314
1315        let mut files = Vec::new();
1316        for i in 0..25 {
1317            files.push(PrFile {
1318                filename: format!("file{i}.rs"),
1319                status: "modified".to_string(),
1320                additions: 10,
1321                deletions: 5,
1322                patch: Some(format!("patch content {i}")),
1323                full_content: None,
1324            });
1325        }
1326
1327        let pr = PrDetails {
1328            owner: "test".to_string(),
1329            repo: "repo".to_string(),
1330            number: 1,
1331            title: "Test PR".to_string(),
1332            body: "Description".to_string(),
1333            head_branch: "feature".to_string(),
1334            base_branch: "main".to_string(),
1335            url: "https://github.com/test/repo/pull/1".to_string(),
1336            files,
1337            labels: vec![],
1338            head_sha: String::new(),
1339        };
1340
1341        let prompt = TestProvider::build_pr_review_user_prompt(&pr, "", "");
1342        assert!(prompt.contains("files omitted due to size limits"));
1343        assert!(prompt.contains("MAX_FILES=20"));
1344    }
1345
1346    #[test]
1347    fn test_build_pr_review_user_prompt_respects_diff_size_limit() {
1348        use super::super::types::{PrDetails, PrFile};
1349
1350        // Create patches that will exceed the limit when combined
1351        // Each patch is ~30KB, so two will exceed 50KB limit
1352        let patch1 = "x".repeat(30_000);
1353        let patch2 = "y".repeat(30_000);
1354
1355        let files = vec![
1356            PrFile {
1357                filename: "file1.rs".to_string(),
1358                status: "modified".to_string(),
1359                additions: 100,
1360                deletions: 50,
1361                patch: Some(patch1),
1362                full_content: None,
1363            },
1364            PrFile {
1365                filename: "file2.rs".to_string(),
1366                status: "modified".to_string(),
1367                additions: 100,
1368                deletions: 50,
1369                patch: Some(patch2),
1370                full_content: None,
1371            },
1372        ];
1373
1374        let pr = PrDetails {
1375            owner: "test".to_string(),
1376            repo: "repo".to_string(),
1377            number: 1,
1378            title: "Test PR".to_string(),
1379            body: "Description".to_string(),
1380            head_branch: "feature".to_string(),
1381            base_branch: "main".to_string(),
1382            url: "https://github.com/test/repo/pull/1".to_string(),
1383            files,
1384            labels: vec![],
1385            head_sha: String::new(),
1386        };
1387
1388        let prompt = TestProvider::build_pr_review_user_prompt(&pr, "", "");
1389        // Both files should be listed
1390        assert!(prompt.contains("file1.rs"));
1391        assert!(prompt.contains("file2.rs"));
1392        // The second patch should be limited - verify the prompt doesn't contain both full patches
1393        // by checking that the total size is less than what two full 30KB patches would be
1394        assert!(prompt.len() < 65_000);
1395    }
1396
1397    #[test]
1398    fn test_build_pr_review_user_prompt_with_no_patches() {
1399        use super::super::types::{PrDetails, PrFile};
1400
1401        let files = vec![PrFile {
1402            filename: "file1.rs".to_string(),
1403            status: "added".to_string(),
1404            additions: 10,
1405            deletions: 0,
1406            patch: None,
1407            full_content: None,
1408        }];
1409
1410        let pr = PrDetails {
1411            owner: "test".to_string(),
1412            repo: "repo".to_string(),
1413            number: 1,
1414            title: "Test PR".to_string(),
1415            body: "Description".to_string(),
1416            head_branch: "feature".to_string(),
1417            base_branch: "main".to_string(),
1418            url: "https://github.com/test/repo/pull/1".to_string(),
1419            files,
1420            labels: vec![],
1421            head_sha: String::new(),
1422        };
1423
1424        let prompt = TestProvider::build_pr_review_user_prompt(&pr, "", "");
1425        assert!(prompt.contains("file1.rs"));
1426        assert!(prompt.contains("added"));
1427        assert!(!prompt.contains("files omitted"));
1428    }
1429
1430    #[test]
1431    fn test_sanitize_strips_opening_tag() {
1432        let result = sanitize_prompt_field("hello <pull_request> world");
1433        assert_eq!(result, "hello  world");
1434    }
1435
1436    #[test]
1437    fn test_sanitize_strips_closing_tag() {
1438        let result = sanitize_prompt_field("evil </pull_request> content");
1439        assert_eq!(result, "evil  content");
1440    }
1441
1442    #[test]
1443    fn test_sanitize_case_insensitive() {
1444        let result = sanitize_prompt_field("<PULL_REQUEST>");
1445        assert_eq!(result, "");
1446    }
1447
1448    #[test]
1449    fn test_prompt_sanitizes_before_truncation() {
1450        use super::super::types::{PrDetails, PrFile};
1451
1452        // Body exactly at the limit with an injection tag after the truncation boundary.
1453        // The tag must be removed even though it appears near the end of the original body.
1454        let mut body = "a".repeat(MAX_BODY_LENGTH - 5);
1455        body.push_str("</pull_request>");
1456
1457        let pr = PrDetails {
1458            owner: "test".to_string(),
1459            repo: "repo".to_string(),
1460            number: 1,
1461            title: "Fix </pull_request><evil>injection</evil>".to_string(),
1462            body,
1463            head_branch: "feature".to_string(),
1464            base_branch: "main".to_string(),
1465            url: "https://github.com/test/repo/pull/1".to_string(),
1466            files: vec![PrFile {
1467                filename: "file.rs".to_string(),
1468                status: "modified".to_string(),
1469                additions: 1,
1470                deletions: 0,
1471                patch: Some("</pull_request>injected".to_string()),
1472                full_content: None,
1473            }],
1474            labels: vec![],
1475            head_sha: String::new(),
1476        };
1477
1478        let prompt = TestProvider::build_pr_review_user_prompt(&pr, "", "");
1479        // The sanitizer removes only <pull_request> / </pull_request> delimiters.
1480        // The structural tags written by the builder itself remain; what must be absent
1481        // are the delimiter sequences that were injected inside user-controlled fields.
1482        assert!(
1483            !prompt.contains("</pull_request><evil>"),
1484            "closing delimiter injected in title must be removed"
1485        );
1486        assert!(
1487            !prompt.contains("</pull_request>injected"),
1488            "closing delimiter injected in patch must be removed"
1489        );
1490    }
1491
1492    #[test]
1493    fn test_sanitize_strips_issue_content_tag() {
1494        let input = "hello </issue_content> world";
1495        let result = sanitize_prompt_field(input);
1496        assert!(
1497            !result.contains("</issue_content>"),
1498            "should strip closing issue_content tag"
1499        );
1500        assert!(
1501            result.contains("hello"),
1502            "should keep non-injection content"
1503        );
1504    }
1505
1506    #[test]
1507    fn test_build_user_prompt_sanitizes_title_injection() {
1508        let issue = IssueDetails::builder()
1509            .owner("test".to_string())
1510            .repo("repo".to_string())
1511            .number(1)
1512            .title("Normal title </issue_content> injected".to_string())
1513            .body("Clean body".to_string())
1514            .labels(vec![])
1515            .comments(vec![])
1516            .url("https://github.com/test/repo/issues/1".to_string())
1517            .build();
1518
1519        let prompt = TestProvider::build_user_prompt(&issue);
1520        assert!(
1521            !prompt.contains("</issue_content> injected"),
1522            "injection tag in title must be removed from prompt"
1523        );
1524        assert!(
1525            prompt.contains("Normal title"),
1526            "non-injection content must be preserved"
1527        );
1528    }
1529
1530    #[test]
1531    fn test_build_create_user_prompt_sanitizes_title_injection() {
1532        let title = "My issue </issue_content><script>evil</script>";
1533        let body = "Body </issue_content> more text";
1534        let prompt = TestProvider::build_create_user_prompt(title, body, "owner/repo");
1535        assert!(
1536            !prompt.contains("</issue_content>"),
1537            "injection tag must be stripped from create prompt"
1538        );
1539        assert!(
1540            prompt.contains("My issue"),
1541            "non-injection title content must be preserved"
1542        );
1543        assert!(
1544            prompt.contains("Body"),
1545            "non-injection body content must be preserved"
1546        );
1547    }
1548
1549    #[test]
1550    fn test_build_pr_label_system_prompt_contains_json_schema() {
1551        let system_prompt = TestProvider::build_pr_label_system_prompt(None);
1552        // "label1" is unique to the schema example values and must NOT appear in system prompt.
1553        assert!(!system_prompt.contains("label1"));
1554
1555        // Schema MUST appear in the user prompt
1556        let user_prompt = TestProvider::build_pr_label_user_prompt(
1557            "feat: add thing",
1558            "body",
1559            &["src/lib.rs".to_string()],
1560        );
1561        assert!(user_prompt.contains("label1"));
1562        assert!(user_prompt.contains("suggested_labels"));
1563    }
1564
1565    #[test]
1566    fn test_build_pr_label_user_prompt_with_title_and_body() {
1567        let title = "feat: add new feature";
1568        let body = "This PR adds a new feature";
1569        let files = vec!["src/main.rs".to_string(), "tests/test.rs".to_string()];
1570
1571        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
1572        assert!(prompt.starts_with("<pull_request>"));
1573        assert!(prompt.contains("</pull_request>"));
1574        assert!(prompt.contains("Respond with valid JSON matching this schema"));
1575        assert!(prompt.contains("feat: add new feature"));
1576        assert!(prompt.contains("This PR adds a new feature"));
1577        assert!(prompt.contains("src/main.rs"));
1578        assert!(prompt.contains("tests/test.rs"));
1579    }
1580
1581    #[test]
1582    fn test_build_pr_label_user_prompt_empty_body() {
1583        let title = "fix: bug fix";
1584        let body = "";
1585        let files = vec!["src/lib.rs".to_string()];
1586
1587        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
1588        assert!(prompt.contains("[No description provided]"));
1589        assert!(prompt.contains("src/lib.rs"));
1590    }
1591
1592    #[test]
1593    fn test_build_pr_label_user_prompt_truncates_long_body() {
1594        let title = "test";
1595        let long_body = "x".repeat(5000);
1596        let files = vec![];
1597
1598        let prompt = TestProvider::build_pr_label_user_prompt(title, &long_body, &files);
1599        assert!(prompt.contains("[Description truncated"));
1600        assert!(prompt.contains("5000 chars"));
1601    }
1602
1603    #[test]
1604    fn test_build_pr_label_user_prompt_respects_file_limit() {
1605        let title = "test";
1606        let body = "test";
1607        let mut files = Vec::new();
1608        for i in 0..25 {
1609            files.push(format!("file{i}.rs"));
1610        }
1611
1612        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
1613        assert!(prompt.contains("file0.rs"));
1614        assert!(prompt.contains("file19.rs"));
1615        assert!(!prompt.contains("file20.rs"));
1616        assert!(prompt.contains("... and 5 more files"));
1617    }
1618
1619    #[test]
1620    fn test_build_pr_label_user_prompt_empty_files() {
1621        let title = "test";
1622        let body = "test";
1623        let files: Vec<String> = vec![];
1624
1625        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
1626        assert!(prompt.contains("Title: test"));
1627        assert!(prompt.contains("Description:\ntest"));
1628        assert!(!prompt.contains("Files Changed:"));
1629    }
1630
1631    #[test]
1632    fn test_parse_ai_json_with_valid_json() {
1633        #[derive(serde::Deserialize)]
1634        struct TestResponse {
1635            message: String,
1636        }
1637
1638        let json = r#"{"message": "hello"}"#;
1639        let result: Result<TestResponse> = parse_ai_json(json, "test-provider");
1640        assert!(result.is_ok());
1641        let response = result.unwrap();
1642        assert_eq!(response.message, "hello");
1643    }
1644
1645    #[test]
1646    fn test_parse_ai_json_with_truncated_json() {
1647        let json = r#"{"message": "hello"#;
1648        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test-provider");
1649        assert!(result.is_err());
1650        let err = result.unwrap_err();
1651        assert!(
1652            err.to_string()
1653                .contains("Truncated response from test-provider")
1654        );
1655    }
1656
1657    #[test]
1658    fn test_parse_ai_json_with_malformed_json() {
1659        let json = r#"{"message": invalid}"#;
1660        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test-provider");
1661        assert!(result.is_err());
1662        let err = result.unwrap_err();
1663        assert!(err.to_string().contains("Invalid JSON response from AI"));
1664    }
1665
1666    #[tokio::test]
1667    async fn test_load_system_prompt_override_returns_none_when_absent() {
1668        let result =
1669            super::super::context::load_system_prompt_override("__nonexistent_test_override__")
1670                .await;
1671        assert!(result.is_none());
1672    }
1673
1674    #[tokio::test]
1675    async fn test_load_system_prompt_override_returns_content_when_present() {
1676        use std::io::Write;
1677        let dir = tempfile::tempdir().expect("create tempdir");
1678        let file_path = dir.path().join("test_override.md");
1679        let mut f = std::fs::File::create(&file_path).expect("create file");
1680        writeln!(f, "Custom override content").expect("write file");
1681        drop(f);
1682
1683        let content = tokio::fs::read_to_string(&file_path).await.ok();
1684        assert_eq!(content.as_deref(), Some("Custom override content\n"));
1685    }
1686
1687    #[test]
1688    fn test_build_pr_review_prompt_omits_call_graph_when_oversized() {
1689        use super::super::types::{PrDetails, PrFile};
1690
1691        // Arrange: simulate review_pr dropping call_graph due to budget.
1692        // When call_graph is oversized, review_pr clears it before calling build_pr_review_user_prompt.
1693        let pr = PrDetails {
1694            owner: "test".to_string(),
1695            repo: "repo".to_string(),
1696            number: 1,
1697            title: "Budget drop test".to_string(),
1698            body: "body".to_string(),
1699            head_branch: "feat".to_string(),
1700            base_branch: "main".to_string(),
1701            url: "https://github.com/test/repo/pull/1".to_string(),
1702            files: vec![PrFile {
1703                filename: "lib.rs".to_string(),
1704                status: "modified".to_string(),
1705                additions: 1,
1706                deletions: 0,
1707                patch: Some("+line".to_string()),
1708                full_content: None,
1709            }],
1710            labels: vec![],
1711            head_sha: String::new(),
1712        };
1713
1714        // Act: call build_pr_review_user_prompt with empty call_graph (dropped by review_pr)
1715        // and non-empty ast_context (retained because it fits after call_graph drop)
1716        let ast_context = "Y".repeat(500);
1717        let call_graph = "";
1718        let prompt = TestProvider::build_pr_review_user_prompt(&pr, &ast_context, call_graph);
1719
1720        // Assert: call_graph absent, ast_context present
1721        assert!(
1722            !prompt.contains(&"X".repeat(10)),
1723            "call_graph content must not appear in prompt after budget drop"
1724        );
1725        assert!(
1726            prompt.contains(&"Y".repeat(10)),
1727            "ast_context content must appear in prompt (fits within budget)"
1728        );
1729    }
1730
1731    #[test]
1732    fn test_build_pr_review_prompt_omits_ast_after_call_graph() {
1733        use super::super::types::{PrDetails, PrFile};
1734
1735        // Arrange: simulate review_pr dropping both call_graph and ast_context due to budget.
1736        let pr = PrDetails {
1737            owner: "test".to_string(),
1738            repo: "repo".to_string(),
1739            number: 1,
1740            title: "Budget drop test".to_string(),
1741            body: "body".to_string(),
1742            head_branch: "feat".to_string(),
1743            base_branch: "main".to_string(),
1744            url: "https://github.com/test/repo/pull/1".to_string(),
1745            files: vec![PrFile {
1746                filename: "lib.rs".to_string(),
1747                status: "modified".to_string(),
1748                additions: 1,
1749                deletions: 0,
1750                patch: Some("+line".to_string()),
1751                full_content: None,
1752            }],
1753            labels: vec![],
1754            head_sha: String::new(),
1755        };
1756
1757        // Act: call build_pr_review_user_prompt with both empty (dropped by review_pr)
1758        let ast_context = "";
1759        let call_graph = "";
1760        let prompt = TestProvider::build_pr_review_user_prompt(&pr, ast_context, call_graph);
1761
1762        // Assert: both absent, PR title retained
1763        assert!(
1764            !prompt.contains(&"C".repeat(10)),
1765            "call_graph content must not appear after budget drop"
1766        );
1767        assert!(
1768            !prompt.contains(&"A".repeat(10)),
1769            "ast_context content must not appear after budget drop"
1770        );
1771        assert!(
1772            prompt.contains("Budget drop test"),
1773            "PR title must be retained in prompt"
1774        );
1775    }
1776
1777    #[test]
1778    fn test_build_pr_review_prompt_drops_patches_when_over_budget() {
1779        use super::super::types::{PrDetails, PrFile};
1780
1781        // Arrange: simulate review_pr dropping patches due to budget.
1782        // Create 3 files with patches of different sizes.
1783        let pr = PrDetails {
1784            owner: "test".to_string(),
1785            repo: "repo".to_string(),
1786            number: 1,
1787            title: "Patch drop test".to_string(),
1788            body: "body".to_string(),
1789            head_branch: "feat".to_string(),
1790            base_branch: "main".to_string(),
1791            url: "https://github.com/test/repo/pull/1".to_string(),
1792            files: vec![
1793                PrFile {
1794                    filename: "large.rs".to_string(),
1795                    status: "modified".to_string(),
1796                    additions: 100,
1797                    deletions: 50,
1798                    patch: Some("L".repeat(5000)),
1799                    full_content: None,
1800                },
1801                PrFile {
1802                    filename: "medium.rs".to_string(),
1803                    status: "modified".to_string(),
1804                    additions: 50,
1805                    deletions: 25,
1806                    patch: Some("M".repeat(3000)),
1807                    full_content: None,
1808                },
1809                PrFile {
1810                    filename: "small.rs".to_string(),
1811                    status: "modified".to_string(),
1812                    additions: 10,
1813                    deletions: 5,
1814                    patch: Some("S".repeat(1000)),
1815                    full_content: None,
1816                },
1817            ],
1818            labels: vec![],
1819            head_sha: String::new(),
1820        };
1821
1822        // Act: simulate review_pr dropping largest patches first
1823        let mut pr_mut = pr.clone();
1824        pr_mut.files[0].patch = None; // Drop largest patch
1825        pr_mut.files[1].patch = None; // Drop medium patch
1826        // Keep smallest patch
1827
1828        let ast_context = "";
1829        let call_graph = "";
1830        let prompt = TestProvider::build_pr_review_user_prompt(&pr_mut, ast_context, call_graph);
1831
1832        // Assert: largest patches absent, smallest present
1833        assert!(
1834            !prompt.contains(&"L".repeat(10)),
1835            "largest patch must be absent after drop"
1836        );
1837        assert!(
1838            !prompt.contains(&"M".repeat(10)),
1839            "medium patch must be absent after drop"
1840        );
1841        assert!(
1842            prompt.contains(&"S".repeat(10)),
1843            "smallest patch must be present"
1844        );
1845    }
1846
1847    #[test]
1848    fn test_build_pr_review_prompt_drops_full_content_as_last_resort() {
1849        use super::super::types::{PrDetails, PrFile};
1850
1851        // Arrange: simulate review_pr dropping full_content as last resort.
1852        let pr = PrDetails {
1853            owner: "test".to_string(),
1854            repo: "repo".to_string(),
1855            number: 1,
1856            title: "Full content drop test".to_string(),
1857            body: "body".to_string(),
1858            head_branch: "feat".to_string(),
1859            base_branch: "main".to_string(),
1860            url: "https://github.com/test/repo/pull/1".to_string(),
1861            files: vec![
1862                PrFile {
1863                    filename: "file1.rs".to_string(),
1864                    status: "modified".to_string(),
1865                    additions: 10,
1866                    deletions: 5,
1867                    patch: None,
1868                    full_content: Some("F".repeat(5000)),
1869                },
1870                PrFile {
1871                    filename: "file2.rs".to_string(),
1872                    status: "modified".to_string(),
1873                    additions: 10,
1874                    deletions: 5,
1875                    patch: None,
1876                    full_content: Some("C".repeat(3000)),
1877                },
1878            ],
1879            labels: vec![],
1880            head_sha: String::new(),
1881        };
1882
1883        // Act: simulate review_pr dropping all full_content
1884        let mut pr_mut = pr.clone();
1885        for file in &mut pr_mut.files {
1886            file.full_content = None;
1887        }
1888
1889        let ast_context = "";
1890        let call_graph = "";
1891        let prompt = TestProvider::build_pr_review_user_prompt(&pr_mut, ast_context, call_graph);
1892
1893        // Assert: no file_content XML blocks appear
1894        assert!(
1895            !prompt.contains("<file_content"),
1896            "file_content blocks must not appear when full_content is cleared"
1897        );
1898        assert!(
1899            !prompt.contains(&"F".repeat(10)),
1900            "full_content from file1 must not appear"
1901        );
1902        assert!(
1903            !prompt.contains(&"C".repeat(10)),
1904            "full_content from file2 must not appear"
1905        );
1906    }
1907
1908    #[test]
1909    fn test_redact_api_error_body_truncates() {
1910        // Arrange: Create a long error body
1911        let long_body = "x".repeat(300);
1912
1913        // Act: Redact the error body
1914        let result = redact_api_error_body(&long_body);
1915
1916        // Assert: Result should be truncated and marked
1917        assert!(result.len() < long_body.len());
1918        assert!(result.ends_with("[truncated]"));
1919        assert_eq!(result.len(), 200 + " [truncated]".len());
1920    }
1921
1922    #[test]
1923    fn test_redact_api_error_body_short() {
1924        // Arrange: Create a short error body
1925        let short_body = "Short error";
1926
1927        // Act: Redact the error body
1928        let result = redact_api_error_body(short_body);
1929
1930        // Assert: Result should be unchanged
1931        assert_eq!(result, short_body);
1932    }
1933}