Skip to main content

rusty_commit/providers/
gemini.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{split_prompt, AIProvider};
7use crate::config::Config;
8
9pub struct GeminiProvider {
10    client: Client,
11    api_key: String,
12    model: String,
13}
14
15#[derive(Serialize)]
16struct GeminiRequest {
17    contents: Vec<Content>,
18    system_instruction: Option<SystemInstruction>,
19    generation_config: GenerationConfig,
20}
21
22#[derive(Serialize)]
23struct Content {
24    role: String,
25    parts: Vec<Part>,
26}
27
28#[derive(Serialize)]
29struct Part {
30    text: String,
31}
32
33#[derive(Serialize)]
34struct SystemInstruction {
35    role: String,
36    parts: Vec<Part>,
37}
38
39#[derive(Serialize)]
40struct GenerationConfig {
41    temperature: f32,
42    max_output_tokens: u32,
43}
44
45#[derive(Deserialize)]
46struct GeminiResponse {
47    candidates: Vec<Candidate>,
48}
49
50#[derive(Deserialize)]
51struct Candidate {
52    content: ResponseContent,
53}
54
55#[derive(Deserialize)]
56struct ResponseContent {
57    parts: Vec<ResponsePart>,
58}
59
60#[derive(Deserialize)]
61struct ResponsePart {
62    text: String,
63}
64
65impl GeminiProvider {
66    pub fn new(config: &Config) -> Result<Self> {
67        let api_key = config
68            .api_key
69            .as_ref()
70            .context("Gemini API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?
71            .clone();
72
73        let client = Client::new();
74        let model = config.model.as_deref().unwrap_or("gemini-pro").to_string();
75
76        Ok(Self {
77            client,
78            api_key,
79            model,
80        })
81    }
82
83    /// Create provider from account configuration
84    #[allow(dead_code)]
85    pub fn from_account(
86        _account: &crate::config::accounts::AccountConfig,
87        api_key: &str,
88        config: &Config,
89    ) -> Result<Self> {
90        let client = Client::new();
91        let model = _account
92            .model
93            .as_deref()
94            .or(config.model.as_deref())
95            .unwrap_or("gemini-pro")
96            .to_string();
97
98        Ok(Self {
99            client,
100            api_key: api_key.to_string(),
101            model,
102        })
103    }
104}
105
106#[async_trait]
107impl AIProvider for GeminiProvider {
108    async fn generate_commit_message(
109        &self,
110        diff: &str,
111        context: Option<&str>,
112        full_gitmoji: bool,
113        config: &Config,
114    ) -> Result<String> {
115        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
116
117        let request = GeminiRequest {
118            contents: vec![Content {
119                role: "user".to_string(),
120                parts: vec![Part { text: user_prompt }],
121            }],
122            system_instruction: Some(SystemInstruction {
123                role: "system".to_string(),
124                parts: vec![Part {
125                    text: system_prompt,
126                }],
127            }),
128            generation_config: GenerationConfig {
129                temperature: 0.7,
130                max_output_tokens: config.tokens_max_output.unwrap_or(500),
131            },
132        };
133
134        let url = format!(
135            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
136            self.model
137        );
138
139        let response = self
140            .client
141            .post(&url)
142            .header("X-Goog-Api-Key", &self.api_key)
143            .json(&request)
144            .send()
145            .await
146            .context("Failed to connect to Gemini")?;
147
148        if !response.status().is_success() {
149            let error_text = response.text().await?;
150            anyhow::bail!("Gemini API error: {}", error_text);
151        }
152
153        let gemini_response: GeminiResponse = response
154            .json()
155            .await
156            .context("Failed to parse Gemini response")?;
157
158        let message = gemini_response
159            .candidates
160            .first()
161            .and_then(|c| c.content.parts.first())
162            .map(|p| p.text.trim().to_string())
163            .context("No response from Gemini")?;
164
165        Ok(message)
166    }
167}
168
169/// ProviderBuilder for Gemini
170pub struct GeminiProviderBuilder;
171
172impl super::registry::ProviderBuilder for GeminiProviderBuilder {
173    fn name(&self) -> &'static str {
174        "gemini"
175    }
176
177    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
178        Ok(Box::new(GeminiProvider::new(config)?))
179    }
180
181    fn requires_api_key(&self) -> bool {
182        true
183    }
184
185    fn default_model(&self) -> Option<&'static str> {
186        Some("gemini-1.5-pro")
187    }
188}