Skip to main content

rusty_commit/providers/gemini.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::prompt::split_prompt;
7use super::AIProvider;
8use crate::config::Config;
9
10pub struct GeminiProvider {
11    client: Client,
12    api_key: String,
13    model: String,
14}
15
16#[derive(Serialize)]
17struct GeminiRequest {
18    contents: Vec<Content>,
19    system_instruction: Option<SystemInstruction>,
20    generation_config: GenerationConfig,
21}
22
23#[derive(Serialize)]
24struct Content {
25    role: String,
26    parts: Vec<Part>,
27}
28
29#[derive(Serialize)]
30struct Part {
31    text: String,
32}
33
34#[derive(Serialize)]
35struct SystemInstruction {
36    role: String,
37    parts: Vec<Part>,
38}
39
40#[derive(Serialize)]
41struct GenerationConfig {
42    temperature: f32,
43    max_output_tokens: u32,
44}
45
46#[derive(Deserialize)]
47struct GeminiResponse {
48    candidates: Vec<Candidate>,
49}
50
51#[derive(Deserialize)]
52struct Candidate {
53    content: ResponseContent,
54}
55
56#[derive(Deserialize)]
57struct ResponseContent {
58    parts: Vec<ResponsePart>,
59}
60
61#[derive(Deserialize)]
62struct ResponsePart {
63    text: String,
64}
65
66impl GeminiProvider {
67    pub fn new(config: &Config) -> Result<Self> {
68        let api_key = config
69            .api_key
70            .as_ref()
71            .context("Gemini API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?
72            .clone();
73
74        let client = Client::new();
75        let model = config.model.as_deref().unwrap_or("gemini-pro").to_string();
76
77        Ok(Self {
78            client,
79            api_key,
80            model,
81        })
82    }
83
84    /// Create provider from account configuration
85    #[allow(dead_code)]
86    pub fn from_account(
87        _account: &crate::config::accounts::AccountConfig,
88        api_key: &str,
89        config: &Config,
90    ) -> Result<Self> {
91        let client = Client::new();
92        let model = _account
93            .model
94            .as_deref()
95            .or(config.model.as_deref())
96            .unwrap_or("gemini-pro")
97            .to_string();
98
99        Ok(Self {
100            client,
101            api_key: api_key.to_string(),
102            model,
103        })
104    }
105}
106
107#[async_trait]
108impl AIProvider for GeminiProvider {
109    async fn generate_commit_message(
110        &self,
111        diff: &str,
112        context: Option<&str>,
113        full_gitmoji: bool,
114        config: &Config,
115    ) -> Result<String> {
116        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
117
118        let request = GeminiRequest {
119            contents: vec![Content {
120                role: "user".to_string(),
121                parts: vec![Part { text: user_prompt }],
122            }],
123            system_instruction: Some(SystemInstruction {
124                role: "system".to_string(),
125                parts: vec![Part {
126                    text: system_prompt,
127                }],
128            }),
129            generation_config: GenerationConfig {
130                temperature: 0.7,
131                max_output_tokens: config.tokens_max_output.unwrap_or(500),
132            },
133        };
134
135        let url = format!(
136            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
137            self.model
138        );
139
140        let response = self
141            .client
142            .post(&url)
143            .header("X-Goog-Api-Key", &self.api_key)
144            .json(&request)
145            .send()
146            .await
147            .context("Failed to connect to Gemini")?;
148
149        if !response.status().is_success() {
150            let error_text = response.text().await?;
151            anyhow::bail!("Gemini API error: {}", error_text);
152        }
153
154        let gemini_response: GeminiResponse = response
155            .json()
156            .await
157            .context("Failed to parse Gemini response")?;
158
159        let message = gemini_response
160            .candidates
161            .first()
162            .and_then(|c| c.content.parts.first())
163            .map(|p| p.text.trim().to_string())
164            .context("No response from Gemini")?;
165
166        Ok(message)
167    }
168}
169
170/// ProviderBuilder for Gemini
171pub struct GeminiProviderBuilder;
172
173impl super::registry::ProviderBuilder for GeminiProviderBuilder {
174    fn name(&self) -> &'static str {
175        "gemini"
176    }
177
178    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
179        Ok(Box::new(GeminiProvider::new(config)?))
180    }
181
182    fn requires_api_key(&self) -> bool {
183        true
184    }
185
186    fn default_model(&self) -> Option<&'static str> {
187        Some("gemini-1.5-pro")
188    }
189}