gitai/features/commit/
service.rs

1use super::prompt::{create_system_prompt, create_user_prompt};
2use super::review::GeneratedReview;
3use super::types::GeneratedMessage;
4use crate::config::Config;
5use crate::core::context::CommitContext;
6use crate::core::llm;
7use crate::core::token_optimizer::TokenOptimizer;
8use crate::debug;
9use crate::git::{CommitResult, GitRepo};
10
11use anyhow::Result;
12use std::path::Path;
13use std::sync::Arc;
14use tokio::sync::{RwLock, mpsc};
15
16/// Service for handling Git commit operations with AI assistance
17pub struct CommitService {
18    config: Config,
19    repo: Arc<GitRepo>,
20    provider_name: String,
21    verify: bool,
22    cached_context: Arc<RwLock<Option<CommitContext>>>,
23}
24
25impl CommitService {
26    /// Create a new `CommitService` instance
27    ///
28    /// # Arguments
29    ///
30    /// * `config` - The configuration for the service
31    /// * `repo_path` - The path to the Git repository (unused but kept for API compatibility)
32    /// * `provider_name` - The name of the LLM provider to use
33    /// * `verify` - Whether to verify commits
34    /// * `git_repo` - An existing `GitRepo` instance
35    ///
36    /// # Returns
37    ///
38    /// A Result containing the new `CommitService` instance or an error
39    pub fn new(
40        config: Config,
41        _repo_path: &Path,
42        provider_name: &str,
43        verify: bool,
44        git_repo: GitRepo,
45    ) -> Result<Self> {
46        Ok(Self {
47            config,
48            repo: Arc::new(git_repo),
49            provider_name: provider_name.to_string(),
50            verify,
51            cached_context: Arc::new(RwLock::new(None)),
52        })
53    }
54
55    /// Check if the repository is remote
56    pub fn is_remote_repository(&self) -> bool {
57        self.repo.is_remote()
58    }
59
60    /// Check the environment for necessary prerequisites
61    pub fn check_environment(&self) -> Result<()> {
62        self.config.check_environment()
63    }
64
65    /// Get Git information for the current repository
66    pub async fn get_git_info(&self) -> Result<CommitContext> {
67        {
68            let cached_context = self.cached_context.read().await;
69            if let Some(context) = &*cached_context {
70                return Ok(context.clone());
71            }
72        }
73
74        let context = self.repo.get_git_info(&self.config).await?;
75
76        {
77            let mut cached_context = self.cached_context.write().await;
78            *cached_context = Some(context.clone());
79        }
80        Ok(context)
81    }
82
83    /// Get Git information including unstaged changes
84    pub async fn get_git_info_with_unstaged(
85        &self,
86        include_unstaged: bool,
87    ) -> Result<CommitContext> {
88        if !include_unstaged {
89            return self.get_git_info().await;
90        }
91
92        {
93            // Only use cached context if we're not including unstaged changes
94            // because unstaged changes might have changed since we last checked
95            let cached_context = self.cached_context.read().await;
96            if let Some(context) = &*cached_context
97                && !include_unstaged
98            {
99                return Ok(context.clone());
100            }
101        }
102
103        let context = self
104            .repo
105            .get_git_info_with_unstaged(&self.config, include_unstaged)
106            .await?;
107
108        // Don't cache the context with unstaged changes since they can be constantly changing
109        if !include_unstaged {
110            let mut cached_context = self.cached_context.write().await;
111            *cached_context = Some(context.clone());
112        }
113
114        Ok(context)
115    }
116
117    /// Get Git information for a specific commit
118    pub async fn get_git_info_for_commit(&self, commit_id: &str) -> Result<CommitContext> {
119        debug!("Getting git info for commit: {}", commit_id);
120
121        let context = self
122            .repo
123            .get_git_info_for_commit(&self.config, commit_id)
124            .await?;
125
126        // We don't cache commit-specific contexts
127        Ok(context)
128    }
129
130    /// Private helper method to handle common token optimization logic
131    ///
132    /// # Arguments
133    ///
134    /// * `config_clone` - Configuration with preset and instructions
135    /// * `system_prompt` - The system prompt to use
136    /// * `context` - The commit context
137    /// * `create_user_prompt_fn` - A function that creates a user prompt from a context
138    ///
139    /// # Returns
140    ///
141    /// A tuple containing the optimized context and final user prompt
142    fn optimize_prompt<F>(
143        &self,
144        config_clone: &Config,
145        system_prompt: &str,
146        mut context: CommitContext,
147        create_user_prompt_fn: F,
148    ) -> (CommitContext, String)
149    where
150        F: Fn(&CommitContext) -> String,
151    {
152        // Get the token limit for the provider from config or default value
153        let token_limit = config_clone
154            .providers
155            .get(&self.provider_name)
156            .and_then(|p| p.token_limit)
157            .unwrap_or({
158                match self.provider_name.as_str() {
159                    "openai" => 16_000,
160                    "anthropic" => 100_000,
161                    "groq" => 32_000,
162                    "openrouter" => 2_000_000,
163                    "google" => 1_000_000,
164                    _ => 8_000,
165                }
166            });
167
168        // Create a token optimizer to count tokens
169        let optimizer = TokenOptimizer::new(token_limit).expect("Failed to create TokenOptimizer");
170        let system_tokens = optimizer.count_tokens(system_prompt);
171
172        debug!("Token limit: {}", token_limit);
173        debug!("System prompt tokens: {}", system_tokens);
174
175        // Reserve tokens for system prompt and some buffer for formatting
176        // 1000 token buffer provides headroom for model responses and formatting
177        let context_token_limit = token_limit.saturating_sub(system_tokens + 1000);
178        debug!("Available tokens for context: {}", context_token_limit);
179
180        // Count tokens before optimization
181        let user_prompt_before = create_user_prompt_fn(&context);
182        let total_tokens_before = system_tokens + optimizer.count_tokens(&user_prompt_before);
183        debug!("Total tokens before optimization: {}", total_tokens_before);
184
185        // Optimize the context with remaining token budget
186        context.optimize(context_token_limit);
187
188        let user_prompt = create_user_prompt_fn(&context);
189        let user_tokens = optimizer.count_tokens(&user_prompt);
190        let total_tokens = system_tokens + user_tokens;
191
192        debug!("User prompt tokens after optimization: {}", user_tokens);
193        debug!("Total tokens after optimization: {}", total_tokens);
194
195        // If we're still over the limit, truncate the user prompt directly
196        // 100 token safety buffer ensures we stay under the limit
197        let final_user_prompt = if total_tokens > token_limit {
198            debug!(
199                "Total tokens {} still exceeds limit {}, truncating user prompt",
200                total_tokens, token_limit
201            );
202            let max_user_tokens = token_limit.saturating_sub(system_tokens + 100);
203            optimizer
204                .truncate_string(&user_prompt, max_user_tokens)
205                .expect("Failed to truncate user prompt")
206        } else {
207            user_prompt
208        };
209
210        let final_tokens = system_tokens + optimizer.count_tokens(&final_user_prompt);
211        debug!(
212            "Final total tokens after potential truncation: {}",
213            final_tokens
214        );
215
216        (context, final_user_prompt)
217    }
218
219    /// Generate a commit message using AI
220    ///
221    /// # Arguments
222    ///
223    /// * `preset` - The instruction preset to use
224    /// * `instructions` - Custom instructions for the AI
225    ///
226    /// # Returns
227    ///
228    /// A Result containing the generated commit message or an error
229    pub async fn generate_message(&self, instructions: &str) -> anyhow::Result<GeneratedMessage> {
230        let mut config_clone = self.config.clone();
231
232        config_clone.instructions = instructions.to_string();
233
234        let context = self.get_git_info().await?;
235
236        // Create system prompt
237        let system_prompt = create_system_prompt(&config_clone)?;
238
239        // Use the shared optimization logic
240        let (_, final_user_prompt) =
241            self.optimize_prompt(&config_clone, &system_prompt, context, create_user_prompt);
242
243        let generated_message = llm::get_message::<GeneratedMessage>(
244            &config_clone,
245            &self.provider_name,
246            &system_prompt,
247            &final_user_prompt,
248        )
249        .await?;
250
251        Ok(generated_message)
252    }
253
254    /// Generate a review for unstaged changes
255    ///
256    /// # Arguments
257    ///
258    /// * `preset` - The instruction preset to use
259    /// * `instructions` - Custom instructions for the AI
260    /// * `include_unstaged` - Whether to include unstaged changes in the review
261    ///
262    /// # Returns
263    ///
264    /// A Result containing the generated code review or an error
265    pub async fn generate_review_with_unstaged(
266        &self,
267        instructions: &str,
268        include_unstaged: bool,
269    ) -> anyhow::Result<GeneratedReview> {
270        let mut config_clone = self.config.clone();
271
272        config_clone.instructions = instructions.to_string();
273
274        // Get context including unstaged changes if requested
275        let context = self.get_git_info_with_unstaged(include_unstaged).await?;
276
277        // Create system prompt
278        let system_prompt = super::prompt::create_review_system_prompt(&config_clone)?;
279
280        // Use the shared optimization logic
281        let (_, final_user_prompt) = self.optimize_prompt(
282            &config_clone,
283            &system_prompt,
284            context,
285            super::prompt::create_review_user_prompt,
286        );
287
288        llm::get_message::<GeneratedReview>(
289            &config_clone,
290            &self.provider_name,
291            &system_prompt,
292            &final_user_prompt,
293        )
294        .await
295    }
296
297    /// Generate a review for a specific commit
298    ///
299    /// # Arguments
300    ///
301    /// * `preset` - The instruction preset to use
302    /// * `instructions` - Custom instructions for the AI
303    /// * `commit_id` - The ID of the commit to review
304    ///
305    /// # Returns
306    ///
307    /// A Result containing the generated code review or an error
308    pub async fn generate_review_for_commit(
309        &self,
310        instructions: &str,
311        commit_id: &str,
312    ) -> anyhow::Result<GeneratedReview> {
313        let mut config_clone = self.config.clone();
314
315        config_clone.instructions = instructions.to_string();
316
317        // Get context for the specific commit
318        let context = self.get_git_info_for_commit(commit_id).await?;
319
320        // Create system prompt
321        let system_prompt = super::prompt::create_review_system_prompt(&config_clone)?;
322
323        // Use the shared optimization logic
324        let (_, final_user_prompt) = self.optimize_prompt(
325            &config_clone,
326            &system_prompt,
327            context,
328            super::prompt::create_review_user_prompt,
329        );
330
331        llm::get_message::<GeneratedReview>(
332            &config_clone,
333            &self.provider_name,
334            &system_prompt,
335            &final_user_prompt,
336        )
337        .await
338    }
339
340    /// Generate a review for branch comparison
341    ///
342    /// # Arguments
343    ///
344    /// * `preset` - The instruction preset to use
345    /// * `instructions` - Custom instructions for the AI
346    /// * `base_branch` - The base branch (e.g., "main")
347    /// * `target_branch` - The target branch (e.g., "feature-branch")
348    ///
349    /// # Returns
350    ///
351    /// A Result containing the generated code review or an error
352    pub async fn generate_review_for_branch_diff(
353        &self,
354        instructions: &str,
355        base_branch: &str,
356        target_branch: &str,
357    ) -> anyhow::Result<GeneratedReview> {
358        let mut config_clone = self.config.clone();
359
360        config_clone.instructions = instructions.to_string();
361
362        // Get context for the branch comparison
363        let context = self
364            .repo
365            .get_git_info_for_branch_diff(&self.config, base_branch, target_branch)
366            .await?;
367
368        // Create system prompt
369        let system_prompt = super::prompt::create_review_system_prompt(&config_clone)?;
370
371        // Use the shared optimization logic
372        let (_, final_user_prompt) = self.optimize_prompt(
373            &config_clone,
374            &system_prompt,
375            context,
376            super::prompt::create_review_user_prompt,
377        );
378
379        llm::get_message::<GeneratedReview>(
380            &config_clone,
381            &self.provider_name,
382            &system_prompt,
383            &final_user_prompt,
384        )
385        .await
386    }
387
388    /// Generate a code review using AI
389    ///
390    /// # Arguments
391    ///
392    /// * `preset` - The instruction preset to use
393    /// * `instructions` - Custom instructions for the AI
394    ///
395    /// # Returns
396    ///
397    /// A Result containing the generated code review or an error
398    pub async fn generate_review(&self, instructions: &str) -> anyhow::Result<GeneratedReview> {
399        let mut config_clone = self.config.clone();
400
401        config_clone.instructions = instructions.to_string();
402
403        let context = self.get_git_info().await?;
404
405        // Create system prompt
406        let system_prompt = super::prompt::create_review_system_prompt(&config_clone)?;
407
408        // Use the shared optimization logic
409        let (_, final_user_prompt) = self.optimize_prompt(
410            &config_clone,
411            &system_prompt,
412            context,
413            super::prompt::create_review_user_prompt,
414        );
415
416        llm::get_message::<GeneratedReview>(
417            &config_clone,
418            &self.provider_name,
419            &system_prompt,
420            &final_user_prompt,
421        )
422        .await
423    }
424
425    /// Generate a PR description for a commit range
426    ///
427    /// # Arguments
428    ///
429    /// * `preset` - The instruction preset to use
430    /// * `instructions` - Custom instructions for the AI
431    /// * `from` - The starting Git reference (exclusive)
432    /// * `to` - The ending Git reference (inclusive)
433    ///
434    /// # Returns
435    ///
436    /// A Result containing the generated PR description or an error
437    pub async fn generate_pr_for_commit_range(
438        &self,
439        instructions: &str,
440        from: &str,
441        to: &str,
442    ) -> anyhow::Result<super::types::GeneratedPullRequest> {
443        let mut config_clone = self.config.clone();
444
445        config_clone.instructions = instructions.to_string();
446
447        // Get context for the commit range
448        let context = self
449            .repo
450            .get_git_info_for_commit_range(&self.config, from, to)
451            .await?;
452
453        // Get commit messages for the PR
454        let commit_messages = self.repo.get_commits_for_pr(from, to)?;
455
456        // Create system prompt
457        let system_prompt = super::prompt::create_pr_system_prompt(&config_clone)?;
458
459        // Use the shared optimization logic
460        let (_, final_user_prompt) =
461            self.optimize_prompt(&config_clone, &system_prompt, context, |ctx| {
462                super::prompt::create_pr_user_prompt(ctx, &commit_messages)
463            });
464
465        let generated_pr = llm::get_message::<super::types::GeneratedPullRequest>(
466            &config_clone,
467            &self.provider_name,
468            &system_prompt,
469            &final_user_prompt,
470        )
471        .await?;
472
473        Ok(generated_pr)
474    }
475
476    /// Generate a PR description for branch comparison
477    ///
478    /// # Arguments
479    ///
480    /// * `preset` - The instruction preset to use
481    /// * `instructions` - Custom instructions for the AI
482    /// * `base_branch` - The base branch (e.g., "main")
483    /// * `target_branch` - The target branch (e.g., "feature-branch")
484    ///
485    /// # Returns
486    ///
487    /// A Result containing the generated PR description or an error
488    pub async fn generate_pr_for_branch_diff(
489        &self,
490        instructions: &str,
491        base_branch: &str,
492        target_branch: &str,
493    ) -> anyhow::Result<super::types::GeneratedPullRequest> {
494        let mut config_clone = self.config.clone();
495
496        config_clone.instructions = instructions.to_string();
497
498        // Get context for the branch comparison
499        let context = self
500            .repo
501            .get_git_info_for_branch_diff(&self.config, base_branch, target_branch)
502            .await?;
503
504        // Get commit messages for the PR (commits in target_branch not in base_branch)
505        let commit_messages = self.repo.get_commits_for_pr(base_branch, target_branch)?;
506
507        // Create system prompt
508        let system_prompt = super::prompt::create_pr_system_prompt(&config_clone)?;
509
510        // Use the shared optimization logic
511        let (_, final_user_prompt) =
512            self.optimize_prompt(&config_clone, &system_prompt, context, |ctx| {
513                super::prompt::create_pr_user_prompt(ctx, &commit_messages)
514            });
515
516        let generated_pr = llm::get_message::<super::types::GeneratedPullRequest>(
517            &config_clone,
518            &self.provider_name,
519            &system_prompt,
520            &final_user_prompt,
521        )
522        .await?;
523
524        Ok(generated_pr)
525    }
526
527    /// Performs a commit with the given message.
528    ///
529    /// # Arguments
530    ///
531    /// * `message` - The commit message.
532    ///
533    /// # Returns
534    ///
535    /// A Result containing the `CommitResult` or an error.
536    pub fn perform_commit(&self, message: &str) -> Result<CommitResult> {
537        // Check if this is a remote repository
538        if self.is_remote_repository() {
539            return Err(anyhow::anyhow!("Cannot commit to a remote repository"));
540        }
541
542        debug!("Performing commit with message: {}", message);
543
544        if !self.verify {
545            debug!("Skipping pre-commit hook (verify=false)");
546            return self.repo.commit(message);
547        }
548
549        // Execute pre-commit hook
550        debug!("Executing pre-commit hook");
551        if let Err(e) = self.repo.execute_hook("pre-commit") {
552            debug!("Pre-commit hook failed: {}", e);
553            return Err(e);
554        }
555        debug!("Pre-commit hook executed successfully");
556
557        // Perform the commit
558        match self.repo.commit(message) {
559            Ok(result) => {
560                // Execute post-commit hook
561                debug!("Executing post-commit hook");
562                if let Err(e) = self.repo.execute_hook("post-commit") {
563                    debug!("Post-commit hook failed: {}", e);
564                    // We don't fail the commit if post-commit hook fails
565                }
566                debug!("Commit performed successfully");
567                Ok(result)
568            }
569            Err(e) => {
570                debug!("Commit failed: {}", e);
571                Err(e)
572            }
573        }
574    }
575
576    /// Execute the pre-commit hook if verification is enabled
577    pub fn pre_commit(&self) -> Result<()> {
578        // Skip pre-commit hook for remote repositories
579        if self.is_remote_repository() {
580            debug!("Skipping pre-commit hook for remote repository");
581            return Ok(());
582        }
583
584        if self.verify {
585            self.repo.execute_hook("pre-commit")
586        } else {
587            Ok(())
588        }
589    }
590
591    /// Create a channel for message generation
592    pub fn create_message_channel(
593        &self,
594    ) -> (
595        mpsc::Sender<Result<GeneratedMessage>>,
596        mpsc::Receiver<Result<GeneratedMessage>>,
597    ) {
598        mpsc::channel(1)
599    }
600}