git_iris/commit/service.rs

use anyhow::Result;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};

use super::prompt::{create_system_prompt, create_user_prompt, process_commit_message};
use crate::config::Config;
use crate::context::{CommitContext, GeneratedMessage};
use crate::git::{CommitResult, GitRepo};
use crate::llm;
use crate::llm_providers::{get_provider_metadata, LLMProviderType};
use crate::log_debug;
use crate::token_optimizer::TokenOptimizer;

/// Service for handling Git commit operations with AI assistance
pub struct IrisCommitService {
    config: Config,
    repo: Arc<GitRepo>,
    provider_type: LLMProviderType,
    use_gitmoji: bool,
    verify: bool,
    /// Lazily populated repository context, shared and reused across calls
    cached_context: Arc<RwLock<Option<CommitContext>>>,
}

impl IrisCommitService {
    /// Create a new `IrisCommitService` instance
    ///
    /// # Arguments
    ///
    /// * `config` - The configuration for the service
    /// * `repo_path` - The path to the Git repository
    /// * `provider_type` - The type of LLM provider to use
    /// * `use_gitmoji` - Whether to use Gitmoji in commit messages
    /// * `verify` - Whether to run verification hooks (e.g. pre-commit) when committing
    ///
    /// # Returns
    ///
    /// A Result containing the new `IrisCommitService` instance or an error
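    ///
    /// # Examples
    ///
    /// A minimal usage sketch. How a `Config` is obtained and the exact
    /// `LLMProviderType` variant are assumptions here, not this crate's
    /// confirmed API:
    ///
    /// ```ignore
    /// use std::path::Path;
    ///
    /// let config = Config::default(); // hypothetical constructor
    /// let service = IrisCommitService::new(
    ///     config,
    ///     Path::new("."),
    ///     LLMProviderType::OpenAI, // hypothetical variant name
    ///     true, // use_gitmoji
    ///     true, // verify
    /// )?;
    /// ```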
    pub fn new(
        config: Config,
        repo_path: &Path,
        provider_type: LLMProviderType,
        use_gitmoji: bool,
        verify: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            repo: Arc::new(GitRepo::new(repo_path)?),
            provider_type,
            use_gitmoji,
            verify,
            cached_context: Arc::new(RwLock::new(None)),
        })
    }

    /// Check the environment for necessary prerequisites
    pub fn check_environment(&self) -> Result<()> {
        self.config.check_environment()
    }

    /// Get Git information for the current repository, using a cached copy if available
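    ///
    /// # Examples
    ///
    /// A sketch of repeated calls hitting the cache (assumes `service` was
    /// built as in [`IrisCommitService::new`]; the `branch` field is an
    /// assumption about `CommitContext`):
    ///
    /// ```ignore
    /// let first = service.get_git_info().await?;  // reads the repository
    /// let second = service.get_git_info().await?; // served from the cache
    /// assert_eq!(first.branch, second.branch);
    /// ```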
    pub async fn get_git_info(&self) -> Result<CommitContext> {
        // Fast path: return the cached context if one exists. The read guard
        // is scoped so it is dropped before the write lock is taken below.
        {
            let cached_context = self.cached_context.read().await;
            if let Some(context) = &*cached_context {
                return Ok(context.clone());
            }
        }

        let context = self.repo.get_git_info(&self.config).await?;

        // Slow path: populate the cache for subsequent calls.
        {
            let mut cached_context = self.cached_context.write().await;
            *cached_context = Some(context.clone());
        }
        Ok(context)
    }

    /// Generate a commit message using AI
    ///
    /// # Arguments
    ///
    /// * `preset` - The instruction preset to use
    /// * `instructions` - Custom instructions for the AI
    ///
    /// # Returns
    ///
    /// A Result containing the generated commit message or an error
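    ///
    /// # Examples
    ///
    /// A sketch of a typical call; the preset name `"default"` is an
    /// assumption, not a documented preset:
    ///
    /// ```ignore
    /// let message = service
    ///     .generate_message("default", "Mention the ticket number")
    ///     .await?;
    /// println!("{message:?}");
    /// ```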
    pub async fn generate_message(
        &self,
        preset: &str,
        instructions: &str,
    ) -> Result<GeneratedMessage> {
        let mut config_clone = self.config.clone();
        config_clone.instruction_preset = preset.to_string();
        config_clone.instructions = instructions.to_string();

        let mut context = self.get_git_info().await?;

        // Get the token limit from the provider config, falling back to the
        // provider's default when none is configured
        let token_limit = config_clone
            .providers
            .get(&self.provider_type.to_string())
            .and_then(|p| p.token_limit)
            .unwrap_or_else(|| get_provider_metadata(&self.provider_type).default_token_limit);

        // Create the system prompt first so its token count is known
        let system_prompt = create_system_prompt(&config_clone)?;

        // Create a token optimizer to count tokens
        let optimizer = TokenOptimizer::new(token_limit);
        let system_tokens = optimizer.count_tokens(&system_prompt);

        log_debug!("Token limit: {}", token_limit);
        log_debug!("System prompt tokens: {}", system_tokens);

        // Reserve tokens for the system prompt plus a 1000-token safety buffer
        // for formatting overhead
        let context_token_limit = token_limit.saturating_sub(system_tokens + 1000);
        log_debug!("Available tokens for context: {}", context_token_limit);

        // Count tokens before optimization
        let user_prompt_before = create_user_prompt(&context);
        let total_tokens_before = system_tokens + optimizer.count_tokens(&user_prompt_before);
        log_debug!("Total tokens before optimization: {}", total_tokens_before);

        // Optimize the context to fit the remaining token budget
        context.optimize(context_token_limit);

        let user_prompt = create_user_prompt(&context);
        let user_tokens = optimizer.count_tokens(&user_prompt);
        let total_tokens = system_tokens + user_tokens;

        log_debug!("User prompt tokens after optimization: {}", user_tokens);
        log_debug!("Total tokens after optimization: {}", total_tokens);

        // If we're still over the limit, truncate the user prompt directly
        let final_user_prompt = if total_tokens > token_limit {
            log_debug!(
                "Total tokens {} still exceeds limit {}, truncating user prompt",
                total_tokens,
                token_limit
            );
            let max_user_tokens = token_limit.saturating_sub(system_tokens + 100); // 100-token safety buffer
            optimizer.truncate_string(&user_prompt, max_user_tokens)
        } else {
            user_prompt
        };

        let final_tokens = system_tokens + optimizer.count_tokens(&final_user_prompt);
        log_debug!(
            "Final total tokens after potential truncation: {}",
            final_tokens
        );

        let mut generated_message = llm::get_refined_message::<GeneratedMessage>(
            &config_clone,
            &self.provider_type,
            &system_prompt,
            &final_user_prompt,
        )
        .await?;

        // Apply the gitmoji setting
        if !self.use_gitmoji {
            generated_message.emoji = None;
        }

        Ok(generated_message)
    }

    /// Perform a commit with the given message
    ///
    /// # Arguments
    ///
    /// * `message` - The commit message to use
    ///
    /// # Returns
    ///
    /// A Result containing the `CommitResult` or an error
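    ///
    /// # Examples
    ///
    /// A sketch of committing a message string; turning a [`GeneratedMessage`]
    /// into that string is elided here:
    ///
    /// ```ignore
    /// let result = service.perform_commit("feat: add commit service")?;
    /// println!("committed: {result:?}");
    /// ```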
    pub fn perform_commit(&self, message: &str) -> Result<CommitResult> {
        let processed_message = process_commit_message(message.to_string(), self.use_gitmoji);
        if self.verify {
            self.repo.commit_and_verify(&processed_message)
        } else {
            self.repo.commit(&processed_message)
        }
    }

    /// Execute the pre-commit hook if verification is enabled
    pub fn pre_commit(&self) -> Result<()> {
        if self.verify {
            self.repo.execute_hook("pre-commit")
        } else {
            Ok(())
        }
    }

    /// Create a bounded channel for passing message-generation results between tasks
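    ///
    /// # Examples
    ///
    /// A sketch of driving generation from a background task. The preset name
    /// and wrapping the service in an `Arc` are illustrative assumptions:
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// // Assumes `service` is an Arc<IrisCommitService> so it can move into the task.
    /// let (tx, mut rx) = service.create_message_channel();
    /// let svc = Arc::clone(&service);
    /// tokio::spawn(async move {
    ///     let result = svc.generate_message("default", "").await;
    ///     let _ = tx.send(result).await;
    /// });
    /// if let Some(generated) = rx.recv().await {
    ///     println!("{:?}", generated?);
    /// }
    /// ```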
    pub fn create_message_channel(
        &self,
    ) -> (
        mpsc::Sender<Result<GeneratedMessage>>,
        mpsc::Receiver<Result<GeneratedMessage>>,
    ) {
        // Bounded channel with capacity 1: at most one generated message is
        // in flight at a time
        mpsc::channel(1)
    }
}