// aptu_core/ai/provider.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AI provider trait and shared implementations.
4//!
5//! Defines the `AiProvider` trait that all AI providers must implement,
6//! along with default implementations for shared logic like prompt building,
7//! request sending, and response parsing.
8
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use secrecy::SecretString;
13use tracing::{debug, instrument};
14
15use super::AiResponse;
16use super::types::{
17    ChatCompletionRequest, ChatCompletionResponse, ChatMessage, IssueDetails, ResponseFormat,
18    TriageResponse,
19};
20use crate::history::AiStats;
21
22use super::prompts::{
23    build_create_system_prompt, build_pr_label_system_prompt, build_pr_review_system_prompt,
24    build_release_notes_system_prompt, build_triage_system_prompt,
25};
26
27/// Parses JSON response from AI provider, detecting truncated responses.
28///
29/// If the JSON parsing fails with an EOF error (indicating the response was cut off),
30/// returns a `TruncatedResponse` error that can be retried. Other JSON errors are
31/// wrapped as `InvalidAIResponse`.
32///
33/// # Arguments
34///
35/// * `text` - The JSON text to parse
36/// * `provider` - The name of the AI provider (for error context)
37///
38/// # Returns
39///
40/// Parsed value of type T, or an error if parsing fails
41fn parse_ai_json<T: serde::de::DeserializeOwned>(text: &str, provider: &str) -> Result<T> {
42    match serde_json::from_str::<T>(text) {
43        Ok(value) => Ok(value),
44        Err(e) => {
45            // Check if this is an EOF error (truncated response)
46            if e.is_eof() {
47                Err(anyhow::anyhow!(
48                    crate::error::AptuError::TruncatedResponse {
49                        provider: provider.to_string(),
50                    }
51                ))
52            } else {
53                Err(anyhow::anyhow!(crate::error::AptuError::InvalidAIResponse(
54                    e
55                )))
56            }
57        }
58    }
59}
60
/// Maximum length for issue/PR body to stay within token limits.
/// Measured in bytes (compared against `str::len`).
pub const MAX_BODY_LENGTH: usize = 4000;

/// Maximum number of comments to include in the prompt.
pub const MAX_COMMENTS: usize = 5;

/// Maximum number of files to include in PR review prompt.
pub const MAX_FILES: usize = 20;

/// Maximum total diff size (in characters) for PR review prompt.
pub const MAX_TOTAL_DIFF_SIZE: usize = 50_000;

/// Maximum number of labels to include in the prompt.
pub const MAX_LABELS: usize = 30;

/// Maximum number of milestones to include in the prompt.
pub const MAX_MILESTONES: usize = 10;
78
/// AI provider trait for issue triage and creation.
///
/// Defines the interface that all AI providers must implement.
/// Default implementations are provided for shared logic (prompt building,
/// request sending with retry, response parsing).
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Returns the name of the provider (e.g., "gemini", "openrouter").
    fn name(&self) -> &str;

    /// Returns the API URL for this provider.
    fn api_url(&self) -> &str;

    /// Returns the environment variable name for the API key.
    fn api_key_env(&self) -> &str;

    /// Returns the HTTP client for making requests.
    fn http_client(&self) -> &Client;

    /// Returns the API key for authentication.
    fn api_key(&self) -> &SecretString;

    /// Returns the model name.
    fn model(&self) -> &str;

    /// Returns the maximum tokens for API responses.
    fn max_tokens(&self) -> u32;

    /// Returns the temperature for API requests.
    fn temperature(&self) -> f32;

    /// Returns the maximum retry attempts for rate-limited requests.
    ///
    /// This counts the first try too: `send_and_parse` stops once
    /// `attempt >= max_attempts`, so the value is the total number of tries.
    ///
    /// Default implementation returns 3. Providers can override
    /// to use a different retry limit.
    fn max_attempts(&self) -> u32 {
        3
    }

    /// Returns the circuit breaker for this provider (optional).
    ///
    /// Default implementation returns None. Providers can override
    /// to provide circuit breaker functionality.
    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
        None
    }
124
125    /// Builds HTTP headers for API requests.
126    ///
127    /// Default implementation includes Authorization and Content-Type headers.
128    /// Providers can override to add custom headers.
129    fn build_headers(&self) -> reqwest::header::HeaderMap {
130        let mut headers = reqwest::header::HeaderMap::new();
131        if let Ok(val) = "application/json".parse() {
132            headers.insert("Content-Type", val);
133        }
134        headers
135    }
136
    /// Validates the model configuration.
    ///
    /// Default implementation does nothing and always returns `Ok(())`.
    /// Providers can override to enforce constraints (e.g., free tier validation).
    fn validate_model(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the custom guidance string for system prompt injection, if set.
    ///
    /// Default implementation returns `None`. Providers that store custom guidance
    /// (e.g., from `AiConfig`) override this to supply it.
    fn custom_guidance(&self) -> Option<&str> {
        None
    }
152
153    /// Sends a chat completion request to the provider's API (HTTP-only, no retry).
154    ///
155    /// Default implementation handles HTTP headers, error responses (401, 429).
156    /// Does not include retry logic - use `send_and_parse()` for retry behavior.
157    #[instrument(skip(self, request), fields(provider = self.name(), model = self.model()))]
158    async fn send_request_inner(
159        &self,
160        request: &ChatCompletionRequest,
161    ) -> Result<ChatCompletionResponse> {
162        use secrecy::ExposeSecret;
163        use tracing::warn;
164
165        use crate::error::AptuError;
166
167        let mut req = self.http_client().post(self.api_url());
168
169        // Add Authorization header
170        req = req.header(
171            "Authorization",
172            format!("Bearer {}", self.api_key().expose_secret()),
173        );
174
175        // Add custom headers from provider
176        for (key, value) in &self.build_headers() {
177            req = req.header(key.clone(), value.clone());
178        }
179
180        let response = req
181            .json(request)
182            .send()
183            .await
184            .context(format!("Failed to send request to {} API", self.name()))?;
185
186        // Check for HTTP errors
187        let status = response.status();
188        if !status.is_success() {
189            if status.as_u16() == 401 {
190                anyhow::bail!(
191                    "Invalid {} API key. Check your {} environment variable.",
192                    self.name(),
193                    self.api_key_env()
194                );
195            } else if status.as_u16() == 429 {
196                warn!("Rate limited by {} API", self.name());
197                // Parse Retry-After header (seconds), default to 0 if not present
198                let retry_after = response
199                    .headers()
200                    .get("Retry-After")
201                    .and_then(|h| h.to_str().ok())
202                    .and_then(|s| s.parse::<u64>().ok())
203                    .unwrap_or(0);
204                debug!(retry_after, "Parsed Retry-After header");
205                return Err(AptuError::RateLimited {
206                    provider: self.name().to_string(),
207                    retry_after,
208                }
209                .into());
210            }
211            let error_body = response.text().await.unwrap_or_default();
212            anyhow::bail!(
213                "{} API error (HTTP {}): {}",
214                self.name(),
215                status.as_u16(),
216                error_body
217            );
218        }
219
220        // Parse response
221        let completion: ChatCompletionResponse = response
222            .json()
223            .await
224            .context(format!("Failed to parse {} API response", self.name()))?;
225
226        Ok(completion)
227    }
228
    /// Sends a chat completion request and parses the response with retry logic.
    ///
    /// This method wraps both HTTP request and JSON parsing in a single retry loop,
    /// allowing truncated responses to be retried. Includes circuit breaker handling.
    ///
    /// # Arguments
    ///
    /// * `request` - The chat completion request to send
    ///
    /// # Returns
    ///
    /// A tuple of (parsed response, stats) extracted from the API response
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The circuit breaker is open (`AptuError::CircuitOpen`)
    /// - API request fails (network, timeout, rate limit)
    /// - Response cannot be parsed as valid JSON (including truncated responses)
    #[instrument(skip(self, request), fields(provider = self.name(), model = self.model()))]
    async fn send_and_parse<T: serde::de::DeserializeOwned + Send>(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<(T, AiStats)> {
        use tracing::{info, warn};

        use crate::error::AptuError;
        use crate::retry::{extract_retry_after, is_retryable_anyhow};

        // Check circuit breaker before attempting request (fail fast while open)
        if let Some(cb) = self.circuit_breaker()
            && cb.is_open()
        {
            return Err(AptuError::CircuitOpen.into());
        }

        // Start timing (outside retry loop to measure total time including retries)
        let start = std::time::Instant::now();

        // Custom retry loop that respects retry_after from RateLimited errors
        let mut attempt: u32 = 0;
        let max_attempts: u32 = self.max_attempts();

        // Helper function to avoid closure-in-expression clippy warning.
        // One full attempt: HTTP round-trip, first-choice content extraction,
        // and JSON parsing — so a truncated body fails here and is retried.
        #[allow(clippy::items_after_statements)]
        async fn try_request<T: serde::de::DeserializeOwned>(
            provider: &(impl AiProvider + ?Sized),
            request: &ChatCompletionRequest,
        ) -> Result<(T, ChatCompletionResponse)> {
            // Send HTTP request
            let completion = provider.send_request_inner(request).await?;

            // Extract message content (only the first choice is used)
            let content = completion
                .choices
                .first()
                .map(|c| c.message.content.clone())
                .context("No response from AI model")?;

            debug!(response_length = content.len(), "Received AI response");

            // Parse JSON response (inside retry loop, so truncated responses are retried)
            let parsed: T = parse_ai_json(&content, provider.name())?;

            Ok((parsed, completion))
        }

        let (parsed, completion): (T, ChatCompletionResponse) = loop {
            attempt += 1;

            let result = try_request(self, request).await;

            match result {
                Ok(success) => break success,
                Err(err) => {
                    // Check if error is retryable; give up once attempts are exhausted
                    if !is_retryable_anyhow(&err) || attempt >= max_attempts {
                        return Err(err);
                    }

                    // Extract retry_after if present, otherwise use exponential backoff
                    let delay = if let Some(retry_after_duration) = extract_retry_after(&err) {
                        debug!(
                            retry_after_secs = retry_after_duration.as_secs(),
                            "Using Retry-After value from rate limit error"
                        );
                        retry_after_duration
                    } else {
                        // Use exponential backoff with jitter: 1s, 2s, 4s + 0-500ms
                        let backoff_secs = 2_u64.pow(attempt.saturating_sub(1));
                        let jitter_ms = fastrand::u64(0..500);
                        std::time::Duration::from_millis(backoff_secs * 1000 + jitter_ms)
                    };

                    let error_msg = err.to_string();
                    warn!(
                        error = %error_msg,
                        delay_secs = delay.as_secs(),
                        attempt,
                        max_attempts,
                        "Retrying after error"
                    );

                    // Drop err before await to avoid holding non-Send value across await
                    drop(err);
                    tokio::time::sleep(delay).await;
                }
            }
        };

        // Record success in circuit breaker.
        // NOTE(review): only successes are recorded here; failures presumably
        // feed the breaker elsewhere (e.g. a fallback layer) — confirm.
        if let Some(cb) = self.circuit_breaker() {
            cb.record_success();
        }

        // Calculate duration (total time including any retries)
        #[allow(clippy::cast_possible_truncation)]
        let duration_ms = start.elapsed().as_millis() as u64;

        // Build AI stats from usage info (trust API's cost field)
        let (input_tokens, output_tokens, cost_usd) = if let Some(usage) = completion.usage {
            (usage.prompt_tokens, usage.completion_tokens, usage.cost)
        } else {
            // If no usage info, default to 0
            debug!("No usage information in API response");
            (0, 0, None)
        };

        let ai_stats = AiStats {
            provider: self.name().to_string(),
            model: self.model().to_string(),
            input_tokens,
            output_tokens,
            duration_ms,
            cost_usd,
            fallback_provider: None,
        };

        // Emit structured metrics
        info!(
            duration_ms,
            input_tokens,
            output_tokens,
            cost_usd = ?cost_usd,
            model = %self.model(),
            "AI request completed"
        );

        Ok((parsed, ai_stats))
    }
378
379    /// Analyzes a GitHub issue using the provider's API.
380    ///
381    /// Returns a structured triage response with summary, labels, questions, duplicates, and usage stats.
382    ///
383    /// # Arguments
384    ///
385    /// * `issue` - Issue details to analyze
386    ///
387    /// # Errors
388    ///
389    /// Returns an error if:
390    /// - API request fails (network, timeout, rate limit)
391    /// - Response cannot be parsed as valid JSON
392    #[instrument(skip(self, issue), fields(issue_number = issue.number, repo = %format!("{}/{}", issue.owner, issue.repo)))]
393    async fn analyze_issue(&self, issue: &IssueDetails) -> Result<AiResponse> {
394        debug!(model = %self.model(), "Calling {} API", self.name());
395
396        // Build request
397        let system_content = if let Some(override_prompt) =
398            super::context::load_system_prompt_override("triage_system").await
399        {
400            override_prompt
401        } else {
402            Self::build_system_prompt(self.custom_guidance())
403        };
404
405        let request = ChatCompletionRequest {
406            model: self.model().to_string(),
407            messages: vec![
408                ChatMessage {
409                    role: "system".to_string(),
410                    content: system_content,
411                },
412                ChatMessage {
413                    role: "user".to_string(),
414                    content: Self::build_user_prompt(issue),
415                },
416            ],
417            response_format: Some(ResponseFormat {
418                format_type: "json_object".to_string(),
419                json_schema: None,
420            }),
421            max_tokens: Some(self.max_tokens()),
422            temperature: Some(self.temperature()),
423        };
424
425        // Send request and parse JSON with retry logic
426        let (triage, ai_stats) = self.send_and_parse::<TriageResponse>(&request).await?;
427
428        debug!(
429            input_tokens = ai_stats.input_tokens,
430            output_tokens = ai_stats.output_tokens,
431            duration_ms = ai_stats.duration_ms,
432            cost_usd = ?ai_stats.cost_usd,
433            "AI analysis complete"
434        );
435
436        Ok(AiResponse {
437            triage,
438            stats: ai_stats,
439        })
440    }
441
442    /// Creates a formatted GitHub issue using the provider's API.
443    ///
444    /// Takes raw issue title and body, formats them using AI (conventional commit style,
445    /// structured body), and returns the formatted content with suggested labels.
446    ///
447    /// # Arguments
448    ///
449    /// * `title` - Raw issue title from user
450    /// * `body` - Raw issue body/description from user
451    /// * `repo` - Repository name for context (owner/repo format)
452    ///
453    /// # Errors
454    ///
455    /// Returns an error if:
456    /// - API request fails (network, timeout, rate limit)
457    /// - Response cannot be parsed as valid JSON
458    #[instrument(skip(self), fields(repo = %repo))]
459    async fn create_issue(
460        &self,
461        title: &str,
462        body: &str,
463        repo: &str,
464    ) -> Result<(super::types::CreateIssueResponse, AiStats)> {
465        debug!(model = %self.model(), "Calling {} API for issue creation", self.name());
466
467        // Build request
468        let system_content = if let Some(override_prompt) =
469            super::context::load_system_prompt_override("create_system").await
470        {
471            override_prompt
472        } else {
473            Self::build_create_system_prompt(self.custom_guidance())
474        };
475
476        let request = ChatCompletionRequest {
477            model: self.model().to_string(),
478            messages: vec![
479                ChatMessage {
480                    role: "system".to_string(),
481                    content: system_content,
482                },
483                ChatMessage {
484                    role: "user".to_string(),
485                    content: Self::build_create_user_prompt(title, body, repo),
486                },
487            ],
488            response_format: Some(ResponseFormat {
489                format_type: "json_object".to_string(),
490                json_schema: None,
491            }),
492            max_tokens: Some(self.max_tokens()),
493            temperature: Some(self.temperature()),
494        };
495
496        // Send request and parse JSON with retry logic
497        let (create_response, ai_stats) = self
498            .send_and_parse::<super::types::CreateIssueResponse>(&request)
499            .await?;
500
501        debug!(
502            title_len = create_response.formatted_title.len(),
503            body_len = create_response.formatted_body.len(),
504            labels = create_response.suggested_labels.len(),
505            input_tokens = ai_stats.input_tokens,
506            output_tokens = ai_stats.output_tokens,
507            duration_ms = ai_stats.duration_ms,
508            "Issue formatting complete with stats"
509        );
510
511        Ok((create_response, ai_stats))
512    }
513
514    /// Builds the system prompt for issue triage.
515    #[must_use]
516    fn build_system_prompt(custom_guidance: Option<&str>) -> String {
517        let context = super::context::load_custom_guidance(custom_guidance);
518        build_triage_system_prompt(&context)
519    }
520
521    /// Builds the user prompt containing the issue details.
522    #[must_use]
523    fn build_user_prompt(issue: &IssueDetails) -> String {
524        use std::fmt::Write;
525
526        let mut prompt = String::new();
527
528        prompt.push_str("<issue_content>\n");
529        let _ = writeln!(prompt, "Title: {}\n", issue.title);
530
531        // Truncate body if too long
532        let body = if issue.body.len() > MAX_BODY_LENGTH {
533            format!(
534                "{}...\n[Body truncated - original length: {} chars]",
535                &issue.body[..MAX_BODY_LENGTH],
536                issue.body.len()
537            )
538        } else if issue.body.is_empty() {
539            "[No description provided]".to_string()
540        } else {
541            issue.body.clone()
542        };
543        let _ = writeln!(prompt, "Body:\n{body}\n");
544
545        // Include existing labels
546        if !issue.labels.is_empty() {
547            let _ = writeln!(prompt, "Existing Labels: {}\n", issue.labels.join(", "));
548        }
549
550        // Include recent comments (limited)
551        if !issue.comments.is_empty() {
552            prompt.push_str("Recent Comments:\n");
553            for comment in issue.comments.iter().take(MAX_COMMENTS) {
554                let comment_body = if comment.body.len() > 500 {
555                    format!("{}...", &comment.body[..500])
556                } else {
557                    comment.body.clone()
558                };
559                let _ = writeln!(prompt, "- @{}: {}", comment.author, comment_body);
560            }
561            prompt.push('\n');
562        }
563
564        // Include related issues from search (for context)
565        if !issue.repo_context.is_empty() {
566            prompt.push_str("Related Issues in Repository (for context):\n");
567            for related in issue.repo_context.iter().take(10) {
568                let _ = writeln!(
569                    prompt,
570                    "- #{} [{}] {}",
571                    related.number, related.state, related.title
572                );
573            }
574            prompt.push('\n');
575        }
576
577        // Include repository structure (source files)
578        if !issue.repo_tree.is_empty() {
579            prompt.push_str("Repository Structure (source files):\n");
580            for path in issue.repo_tree.iter().take(20) {
581                let _ = writeln!(prompt, "- {path}");
582            }
583            prompt.push('\n');
584        }
585
586        // Include available labels
587        if !issue.available_labels.is_empty() {
588            prompt.push_str("Available Labels:\n");
589            for label in issue.available_labels.iter().take(MAX_LABELS) {
590                let description = if label.description.is_empty() {
591                    String::new()
592                } else {
593                    format!(" - {}", label.description)
594                };
595                let _ = writeln!(
596                    prompt,
597                    "- {} (color: #{}){}",
598                    label.name, label.color, description
599                );
600            }
601            prompt.push('\n');
602        }
603
604        // Include available milestones
605        if !issue.available_milestones.is_empty() {
606            prompt.push_str("Available Milestones:\n");
607            for milestone in issue.available_milestones.iter().take(MAX_MILESTONES) {
608                let description = if milestone.description.is_empty() {
609                    String::new()
610                } else {
611                    format!(" - {}", milestone.description)
612                };
613                let _ = writeln!(prompt, "- {}{}", milestone.title, description);
614            }
615            prompt.push('\n');
616        }
617
618        prompt.push_str("</issue_content>");
619
620        prompt
621    }
622
623    /// Builds the system prompt for issue creation/formatting.
624    #[must_use]
625    fn build_create_system_prompt(custom_guidance: Option<&str>) -> String {
626        let context = super::context::load_custom_guidance(custom_guidance);
627        build_create_system_prompt(&context)
628    }
629
630    /// Builds the user prompt for issue creation/formatting.
631    #[must_use]
632    fn build_create_user_prompt(title: &str, body: &str, _repo: &str) -> String {
633        format!("Please format this GitHub issue:\n\nTitle: {title}\n\nBody:\n{body}")
634    }
635
636    /// Reviews a pull request using the provider's API.
637    ///
638    /// Analyzes PR metadata and file diffs to provide structured review feedback.
639    ///
640    /// # Arguments
641    ///
642    /// * `pr` - Pull request details including files and diffs
643    ///
644    /// # Errors
645    ///
646    /// Returns an error if:
647    /// - API request fails (network, timeout, rate limit)
648    /// - Response cannot be parsed as valid JSON
649    #[instrument(skip(self, pr), fields(pr_number = pr.number, repo = %format!("{}/{}", pr.owner, pr.repo)))]
650    async fn review_pr(
651        &self,
652        pr: &super::types::PrDetails,
653    ) -> Result<(super::types::PrReviewResponse, AiStats)> {
654        debug!(model = %self.model(), "Calling {} API for PR review", self.name());
655
656        // Build request
657        let system_content = if let Some(override_prompt) =
658            super::context::load_system_prompt_override("pr_review_system").await
659        {
660            override_prompt
661        } else {
662            Self::build_pr_review_system_prompt(self.custom_guidance())
663        };
664
665        let request = ChatCompletionRequest {
666            model: self.model().to_string(),
667            messages: vec![
668                ChatMessage {
669                    role: "system".to_string(),
670                    content: system_content,
671                },
672                ChatMessage {
673                    role: "user".to_string(),
674                    content: Self::build_pr_review_user_prompt(pr),
675                },
676            ],
677            response_format: Some(ResponseFormat {
678                format_type: "json_object".to_string(),
679                json_schema: None,
680            }),
681            max_tokens: Some(self.max_tokens()),
682            temperature: Some(self.temperature()),
683        };
684
685        // Send request and parse JSON with retry logic
686        let (review, ai_stats) = self
687            .send_and_parse::<super::types::PrReviewResponse>(&request)
688            .await?;
689
690        debug!(
691            verdict = %review.verdict,
692            input_tokens = ai_stats.input_tokens,
693            output_tokens = ai_stats.output_tokens,
694            duration_ms = ai_stats.duration_ms,
695            "PR review complete with stats"
696        );
697
698        Ok((review, ai_stats))
699    }
700
701    /// Suggests labels for a pull request using the provider's API.
702    ///
703    /// Analyzes PR title, body, and file paths to suggest relevant labels.
704    ///
705    /// # Arguments
706    ///
707    /// * `title` - Pull request title
708    /// * `body` - Pull request description
709    /// * `file_paths` - List of file paths changed in the PR
710    ///
711    /// # Errors
712    ///
713    /// Returns an error if:
714    /// - API request fails (network, timeout, rate limit)
715    /// - Response cannot be parsed as valid JSON
716    #[instrument(skip(self), fields(title = %title))]
717    async fn suggest_pr_labels(
718        &self,
719        title: &str,
720        body: &str,
721        file_paths: &[String],
722    ) -> Result<(Vec<String>, AiStats)> {
723        debug!(model = %self.model(), "Calling {} API for PR label suggestion", self.name());
724
725        // Build request
726        let system_content = if let Some(override_prompt) =
727            super::context::load_system_prompt_override("pr_label_system").await
728        {
729            override_prompt
730        } else {
731            Self::build_pr_label_system_prompt(self.custom_guidance())
732        };
733
734        let request = ChatCompletionRequest {
735            model: self.model().to_string(),
736            messages: vec![
737                ChatMessage {
738                    role: "system".to_string(),
739                    content: system_content,
740                },
741                ChatMessage {
742                    role: "user".to_string(),
743                    content: Self::build_pr_label_user_prompt(title, body, file_paths),
744                },
745            ],
746            response_format: Some(ResponseFormat {
747                format_type: "json_object".to_string(),
748                json_schema: None,
749            }),
750            max_tokens: Some(self.max_tokens()),
751            temperature: Some(self.temperature()),
752        };
753
754        // Send request and parse JSON with retry logic
755        let (response, ai_stats) = self
756            .send_and_parse::<super::types::PrLabelResponse>(&request)
757            .await?;
758
759        debug!(
760            label_count = response.suggested_labels.len(),
761            input_tokens = ai_stats.input_tokens,
762            output_tokens = ai_stats.output_tokens,
763            duration_ms = ai_stats.duration_ms,
764            "PR label suggestion complete with stats"
765        );
766
767        Ok((response.suggested_labels, ai_stats))
768    }
769
770    /// Builds the system prompt for PR review.
771    #[must_use]
772    fn build_pr_review_system_prompt(custom_guidance: Option<&str>) -> String {
773        let context = super::context::load_custom_guidance(custom_guidance);
774        build_pr_review_system_prompt(&context)
775    }
776
777    /// Builds the user prompt for PR review.
778    #[must_use]
779    fn build_pr_review_user_prompt(pr: &super::types::PrDetails) -> String {
780        use std::fmt::Write;
781
782        let mut prompt = String::new();
783
784        prompt.push_str("<pull_request>\n");
785        let _ = writeln!(prompt, "Title: {}\n", pr.title);
786        let _ = writeln!(prompt, "Branch: {} -> {}\n", pr.head_branch, pr.base_branch);
787
788        // PR description
789        let body = if pr.body.is_empty() {
790            "[No description provided]".to_string()
791        } else if pr.body.len() > MAX_BODY_LENGTH {
792            format!(
793                "{}...\n[Description truncated - original length: {} chars]",
794                &pr.body[..MAX_BODY_LENGTH],
795                pr.body.len()
796            )
797        } else {
798            pr.body.clone()
799        };
800        let _ = writeln!(prompt, "Description:\n{body}\n");
801
802        // File changes with limits
803        prompt.push_str("Files Changed:\n");
804        let mut total_diff_size = 0;
805        let mut files_included = 0;
806        let mut files_skipped = 0;
807
808        for file in &pr.files {
809            // Check file count limit
810            if files_included >= MAX_FILES {
811                files_skipped += 1;
812                continue;
813            }
814
815            let _ = writeln!(
816                prompt,
817                "- {} ({}) +{} -{}\n",
818                file.filename, file.status, file.additions, file.deletions
819            );
820
821            // Include patch if available (truncate large patches)
822            if let Some(patch) = &file.patch {
823                const MAX_PATCH_LENGTH: usize = 2000;
824                let patch_content = if patch.len() > MAX_PATCH_LENGTH {
825                    format!(
826                        "{}...\n[Patch truncated - original length: {} chars]",
827                        &patch[..MAX_PATCH_LENGTH],
828                        patch.len()
829                    )
830                } else {
831                    patch.clone()
832                };
833
834                // Check if adding this patch would exceed total diff size limit
835                let patch_size = patch_content.len();
836                if total_diff_size + patch_size > MAX_TOTAL_DIFF_SIZE {
837                    let _ = writeln!(
838                        prompt,
839                        "```diff\n[Patch omitted - total diff size limit reached]\n```\n"
840                    );
841                    files_skipped += 1;
842                    continue;
843                }
844
845                let _ = writeln!(prompt, "```diff\n{patch_content}\n```\n");
846                total_diff_size += patch_size;
847            }
848
849            files_included += 1;
850        }
851
852        // Add truncation message if files were skipped
853        if files_skipped > 0 {
854            let _ = writeln!(
855                prompt,
856                "\n[{files_skipped} files omitted due to size limits (MAX_FILES={MAX_FILES}, MAX_TOTAL_DIFF_SIZE={MAX_TOTAL_DIFF_SIZE})]"
857            );
858        }
859
860        prompt.push_str("</pull_request>");
861
862        prompt
863    }
864
865    /// Builds the system prompt for PR label suggestion.
866    #[must_use]
867    fn build_pr_label_system_prompt(custom_guidance: Option<&str>) -> String {
868        let context = super::context::load_custom_guidance(custom_guidance);
869        build_pr_label_system_prompt(&context)
870    }
871
872    /// Builds the user prompt for PR label suggestion.
873    #[must_use]
874    fn build_pr_label_user_prompt(title: &str, body: &str, file_paths: &[String]) -> String {
875        use std::fmt::Write;
876
877        let mut prompt = String::new();
878
879        prompt.push_str("<pull_request>\n");
880        let _ = writeln!(prompt, "Title: {title}\n");
881
882        // PR description
883        let body_content = if body.is_empty() {
884            "[No description provided]".to_string()
885        } else if body.len() > MAX_BODY_LENGTH {
886            format!(
887                "{}...\n[Description truncated - original length: {} chars]",
888                &body[..MAX_BODY_LENGTH],
889                body.len()
890            )
891        } else {
892            body.to_string()
893        };
894        let _ = writeln!(prompt, "Description:\n{body_content}\n");
895
896        // File paths
897        if !file_paths.is_empty() {
898            prompt.push_str("Files Changed:\n");
899            for path in file_paths.iter().take(20) {
900                let _ = writeln!(prompt, "- {path}");
901            }
902            if file_paths.len() > 20 {
903                let _ = writeln!(prompt, "- ... and {} more files", file_paths.len() - 20);
904            }
905            prompt.push('\n');
906        }
907
908        prompt.push_str("</pull_request>");
909
910        prompt
911    }
912
913    /// Generate release notes from PR summaries.
914    ///
915    /// # Arguments
916    ///
917    /// * `prs` - List of PR summaries to synthesize
918    /// * `version` - Version being released
919    ///
920    /// # Returns
921    ///
922    /// Structured release notes with theme, highlights, and categorized changes.
923    #[instrument(skip(self, prs))]
924    async fn generate_release_notes(
925        &self,
926        prs: Vec<super::types::PrSummary>,
927        version: &str,
928    ) -> Result<(super::types::ReleaseNotesResponse, AiStats)> {
929        let system_content = if let Some(override_prompt) =
930            super::context::load_system_prompt_override("release_notes_system").await
931        {
932            override_prompt
933        } else {
934            let context = super::context::load_custom_guidance(self.custom_guidance());
935            build_release_notes_system_prompt(&context)
936        };
937        let prompt = Self::build_release_notes_prompt(&prs, version);
938        let request = ChatCompletionRequest {
939            model: self.model().to_string(),
940            messages: vec![
941                ChatMessage {
942                    role: "system".to_string(),
943                    content: system_content,
944                },
945                ChatMessage {
946                    role: "user".to_string(),
947                    content: prompt,
948                },
949            ],
950            response_format: Some(ResponseFormat {
951                format_type: "json_object".to_string(),
952                json_schema: None,
953            }),
954            temperature: Some(0.7),
955            max_tokens: Some(self.max_tokens()),
956        };
957
958        let (parsed, ai_stats) = self
959            .send_and_parse::<super::types::ReleaseNotesResponse>(&request)
960            .await?;
961
962        debug!(
963            input_tokens = ai_stats.input_tokens,
964            output_tokens = ai_stats.output_tokens,
965            duration_ms = ai_stats.duration_ms,
966            "Release notes generation complete with stats"
967        );
968
969        Ok((parsed, ai_stats))
970    }
971
972    /// Build the user prompt for release notes generation.
973    #[must_use]
974    fn build_release_notes_prompt(prs: &[super::types::PrSummary], version: &str) -> String {
975        let pr_list = prs
976            .iter()
977            .map(|pr| {
978                format!(
979                    "- #{}: {} (by @{})\n  {}",
980                    pr.number,
981                    pr.title,
982                    pr.author,
983                    pr.body.lines().next().unwrap_or("")
984                )
985            })
986            .collect::<Vec<_>>()
987            .join("\n");
988
989        format!(
990            "Generate release notes for version {version} based on these merged PRs:\n\n{pr_list}"
991        )
992    }
993}
994
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared struct for parse_ai_json error-path tests.
    /// The field is only used via serde deserialization; `_message` silences dead_code.
    #[derive(Debug, serde::Deserialize)]
    struct ErrorTestResponse {
        _message: String,
    }

    // Minimal stub provider for exercising the trait's default prompt-building
    // methods. `http_client`/`api_key` are `unimplemented!` because these tests
    // are pure string construction and never send a request.
    struct TestProvider;

    impl AiProvider for TestProvider {
        fn name(&self) -> &'static str {
            "test"
        }

        fn api_url(&self) -> &'static str {
            "https://test.example.com"
        }

        fn api_key_env(&self) -> &'static str {
            "TEST_API_KEY"
        }

        fn http_client(&self) -> &Client {
            unimplemented!()
        }

        fn api_key(&self) -> &SecretString {
            unimplemented!()
        }

        fn model(&self) -> &'static str {
            "test-model"
        }

        fn max_tokens(&self) -> u32 {
            2048
        }

        fn temperature(&self) -> f32 {
            0.3
        }
    }

    // ---- Issue triage prompt builders ------------------------------------

    #[test]
    fn test_build_system_prompt_contains_json_schema() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(prompt.contains("summary"));
        assert!(prompt.contains("suggested_labels"));
        assert!(prompt.contains("clarifying_questions"));
        assert!(prompt.contains("potential_duplicates"));
        assert!(prompt.contains("status_note"));
    }

    #[test]
    fn test_build_user_prompt_with_delimiters() {
        let issue = IssueDetails::builder()
            .owner("test".to_string())
            .repo("repo".to_string())
            .number(1)
            .title("Test issue".to_string())
            .body("This is the body".to_string())
            .labels(vec!["bug".to_string()])
            .comments(vec![])
            .url("https://github.com/test/repo/issues/1".to_string())
            .build();

        let prompt = TestProvider::build_user_prompt(&issue);
        assert!(prompt.starts_with("<issue_content>"));
        assert!(prompt.ends_with("</issue_content>"));
        assert!(prompt.contains("Title: Test issue"));
        assert!(prompt.contains("This is the body"));
        assert!(prompt.contains("Existing Labels: bug"));
    }

    #[test]
    fn test_build_user_prompt_truncates_long_body() {
        // 5000 chars exceeds the body limit, so the truncation marker must appear.
        let long_body = "x".repeat(5000);
        let issue = IssueDetails::builder()
            .owner("test".to_string())
            .repo("repo".to_string())
            .number(1)
            .title("Test".to_string())
            .body(long_body)
            .labels(vec![])
            .comments(vec![])
            .url("https://github.com/test/repo/issues/1".to_string())
            .build();

        let prompt = TestProvider::build_user_prompt(&issue);
        assert!(prompt.contains("[Body truncated"));
        assert!(prompt.contains("5000 chars"));
    }

    #[test]
    fn test_build_user_prompt_empty_body() {
        let issue = IssueDetails::builder()
            .owner("test".to_string())
            .repo("repo".to_string())
            .number(1)
            .title("Test".to_string())
            .body(String::new())
            .labels(vec![])
            .comments(vec![])
            .url("https://github.com/test/repo/issues/1".to_string())
            .build();

        let prompt = TestProvider::build_user_prompt(&issue);
        assert!(prompt.contains("[No description provided]"));
    }

    #[test]
    fn test_build_create_system_prompt_contains_json_schema() {
        let prompt = TestProvider::build_create_system_prompt(None);
        assert!(prompt.contains("formatted_title"));
        assert!(prompt.contains("formatted_body"));
        assert!(prompt.contains("suggested_labels"));
    }

    // ---- PR review prompt builders ---------------------------------------

    #[test]
    fn test_build_pr_review_user_prompt_respects_file_limit() {
        use super::super::types::{PrDetails, PrFile};

        // 25 files exceeds MAX_FILES (20), so an omission notice is expected.
        let mut files = Vec::new();
        for i in 0..25 {
            files.push(PrFile {
                filename: format!("file{i}.rs"),
                status: "modified".to_string(),
                additions: 10,
                deletions: 5,
                patch: Some(format!("patch content {i}")),
            });
        }

        let pr = PrDetails {
            owner: "test".to_string(),
            repo: "repo".to_string(),
            number: 1,
            title: "Test PR".to_string(),
            body: "Description".to_string(),
            head_branch: "feature".to_string(),
            base_branch: "main".to_string(),
            url: "https://github.com/test/repo/pull/1".to_string(),
            files,
            labels: vec![],
            head_sha: String::new(),
        };

        let prompt = TestProvider::build_pr_review_user_prompt(&pr);
        assert!(prompt.contains("files omitted due to size limits"));
        assert!(prompt.contains("MAX_FILES=20"));
    }

    #[test]
    fn test_build_pr_review_user_prompt_respects_diff_size_limit() {
        use super::super::types::{PrDetails, PrFile};

        // Create patches that will exceed the limit when combined
        // Each patch is ~30KB, so two will exceed 50KB limit
        let patch1 = "x".repeat(30_000);
        let patch2 = "y".repeat(30_000);

        let files = vec![
            PrFile {
                filename: "file1.rs".to_string(),
                status: "modified".to_string(),
                additions: 100,
                deletions: 50,
                patch: Some(patch1),
            },
            PrFile {
                filename: "file2.rs".to_string(),
                status: "modified".to_string(),
                additions: 100,
                deletions: 50,
                patch: Some(patch2),
            },
        ];

        let pr = PrDetails {
            owner: "test".to_string(),
            repo: "repo".to_string(),
            number: 1,
            title: "Test PR".to_string(),
            body: "Description".to_string(),
            head_branch: "feature".to_string(),
            base_branch: "main".to_string(),
            url: "https://github.com/test/repo/pull/1".to_string(),
            files,
            labels: vec![],
            head_sha: String::new(),
        };

        let prompt = TestProvider::build_pr_review_user_prompt(&pr);
        // Both files should be listed
        assert!(prompt.contains("file1.rs"));
        assert!(prompt.contains("file2.rs"));
        // The second patch should be limited - verify the prompt doesn't contain both full patches
        // by checking that the total size is less than what two full 30KB patches would be
        assert!(prompt.len() < 65_000);
    }

    #[test]
    fn test_build_pr_review_user_prompt_with_no_patches() {
        use super::super::types::{PrDetails, PrFile};

        // `patch: None` models files GitHub returns without a diff (e.g. binary files).
        let files = vec![PrFile {
            filename: "file1.rs".to_string(),
            status: "added".to_string(),
            additions: 10,
            deletions: 0,
            patch: None,
        }];

        let pr = PrDetails {
            owner: "test".to_string(),
            repo: "repo".to_string(),
            number: 1,
            title: "Test PR".to_string(),
            body: "Description".to_string(),
            head_branch: "feature".to_string(),
            base_branch: "main".to_string(),
            url: "https://github.com/test/repo/pull/1".to_string(),
            files,
            labels: vec![],
            head_sha: String::new(),
        };

        let prompt = TestProvider::build_pr_review_user_prompt(&pr);
        assert!(prompt.contains("file1.rs"));
        assert!(prompt.contains("added"));
        assert!(!prompt.contains("files omitted"));
    }

    // ---- PR label prompt builders ----------------------------------------

    #[test]
    fn test_build_pr_label_system_prompt_contains_json_schema() {
        let prompt = TestProvider::build_pr_label_system_prompt(None);
        assert!(prompt.contains("suggested_labels"));
        assert!(prompt.contains("json_object"));
        assert!(prompt.contains("bug"));
        assert!(prompt.contains("enhancement"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_with_title_and_body() {
        let title = "feat: add new feature";
        let body = "This PR adds a new feature";
        let files = vec!["src/main.rs".to_string(), "tests/test.rs".to_string()];

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.starts_with("<pull_request>"));
        assert!(prompt.ends_with("</pull_request>"));
        assert!(prompt.contains("feat: add new feature"));
        assert!(prompt.contains("This PR adds a new feature"));
        assert!(prompt.contains("src/main.rs"));
        assert!(prompt.contains("tests/test.rs"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_empty_body() {
        let title = "fix: bug fix";
        let body = "";
        let files = vec!["src/lib.rs".to_string()];

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.contains("[No description provided]"));
        assert!(prompt.contains("src/lib.rs"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_truncates_long_body() {
        let title = "test";
        let long_body = "x".repeat(5000);
        let files = vec![];

        let prompt = TestProvider::build_pr_label_user_prompt(title, &long_body, &files);
        assert!(prompt.contains("[Description truncated"));
        assert!(prompt.contains("5000 chars"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_respects_file_limit() {
        let title = "test";
        let body = "test";
        // 25 paths: the first 20 are listed, the remaining 5 summarized.
        let mut files = Vec::new();
        for i in 0..25 {
            files.push(format!("file{i}.rs"));
        }

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.contains("file0.rs"));
        assert!(prompt.contains("file19.rs"));
        assert!(!prompt.contains("file20.rs"));
        assert!(prompt.contains("... and 5 more files"));
    }

    #[test]
    fn test_build_pr_label_user_prompt_empty_files() {
        let title = "test";
        let body = "test";
        let files: Vec<String> = vec![];

        let prompt = TestProvider::build_pr_label_user_prompt(title, body, &files);
        assert!(prompt.contains("Title: test"));
        assert!(prompt.contains("Description:\ntest"));
        assert!(!prompt.contains("Files Changed:"));
    }

    // ---- parse_ai_json ---------------------------------------------------

    #[test]
    fn test_parse_ai_json_with_valid_json() {
        #[derive(serde::Deserialize)]
        struct TestResponse {
            message: String,
        }

        let json = r#"{"message": "hello"}"#;
        let result: Result<TestResponse> = parse_ai_json(json, "test-provider");
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.message, "hello");
    }

    #[test]
    fn test_parse_ai_json_with_truncated_json() {
        // Missing closing quote/brace triggers serde's EOF path -> TruncatedResponse.
        let json = r#"{"message": "hello"#;
        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test-provider");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string()
                .contains("Truncated response from test-provider")
        );
    }

    #[test]
    fn test_parse_ai_json_with_malformed_json() {
        // Syntactically invalid (non-EOF) JSON maps to InvalidAIResponse, not Truncated.
        let json = r#"{"message": invalid}"#;
        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test-provider");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Invalid JSON response from AI"));
    }

    // ---- System prompt structure (persona / CoT / examples) --------------

    #[test]
    fn test_build_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    #[test]
    fn test_build_system_prompt_has_cot_directive() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(prompt.contains("Reason through each step before producing output."));
    }

    #[test]
    fn test_build_system_prompt_has_examples_section() {
        let prompt = TestProvider::build_system_prompt(None);
        assert!(prompt.contains("## Examples"));
    }

    #[test]
    fn test_build_create_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_create_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    #[test]
    fn test_build_pr_review_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_pr_review_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    #[test]
    fn test_build_pr_label_system_prompt_has_senior_persona() {
        let prompt = TestProvider::build_pr_label_system_prompt(None);
        assert!(
            prompt.contains("You are a senior"),
            "prompt should have senior persona"
        );
        assert!(
            prompt.contains("Your mission is"),
            "prompt should have mission statement"
        );
    }

    // ---- System prompt override loading ----------------------------------

    #[tokio::test]
    async fn test_load_system_prompt_override_returns_none_when_absent() {
        let result =
            super::super::context::load_system_prompt_override("__nonexistent_test_override__")
                .await;
        assert!(result.is_none());
    }

    // NOTE(review): this test only round-trips a temp file through tokio::fs;
    // it never calls `load_system_prompt_override`, so the "returns content
    // when present" behavior is not actually exercised. Consider pointing the
    // override lookup at the temp dir and asserting on its return value.
    #[tokio::test]
    async fn test_load_system_prompt_override_returns_content_when_present() {
        use std::io::Write;
        let dir = tempfile::tempdir().expect("create tempdir");
        let file_path = dir.path().join("test_override.md");
        let mut f = std::fs::File::create(&file_path).expect("create file");
        writeln!(f, "Custom override content").expect("write file");
        drop(f);

        let content = tokio::fs::read_to_string(&file_path).await.ok();
        assert_eq!(content.as_deref(), Some("Custom override content\n"));
    }
}