git_revise/ai/
gemini.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use super::AI;
6use crate::ReviseResult;
7
/// Client for Google's Gemini `generateContent` endpoint that turns a diff or
/// description into commit-message suggestions.
#[derive(Clone, Debug)]
pub struct Gemini {
    /// Instruction prompt sent ahead of every user input.
    prompt: String,
    /// Fully-qualified endpoint URL (model, method and API key included).
    url: String,
}
13
14#[derive(Serialize, Deserialize, Debug)]
15pub struct Commit {
16    #[serde(rename = "type")]
17    pub kind: String,
18    pub message: String,
19    pub body: String,
20}
21
22impl Gemini {
23    pub fn new(key: &str) -> Self {
24        let prompt = r#"
25        # Character
26            You're a brilliant coding buddy with top-notch proficiency in Git. Your main duty is to assist users in crafting clear and precise Git commit messages.
27
28        ## Skills
29
30        ### Skill 1: Multilingual Translation
31        - Recognize translation requests in the format: "SourceLanguage:TargetLanguage; Content"
32        - Identify requests starting with "这是一个翻译commit" as translation tasks
33        - Translate the given content from the source language to the target language
34        - Preserve the original text alongside the translation in the output
35        - Adapt the translation to fit the context of Git commit messages
36        - Example input: "中文:English; 这是一个翻译commit, 优化用户界面布局"
37        - Example output: 
38          ```json
39          [
40            {
41              "type": "translation",
42              "message": "Optimize user interface layout",
43              "body": "A long body with details about the changes made"
44            },
45            {
46              "type": "translation",
47              "message": "Optimize user interface layout",
48              "body": "A long body with details about the changes made"
49            },
50            {
51              "type": "translation",
52              "message": "Optimize user interface layout",
53              "body": "A long body with details about the changes made"
54            }
55          ]
56          ```
57
58        ### Skill 2: The Commit Message Maverick
59        - Process the git diff or description given by the user
60        - Curate commit messages that confidently and tersely summarize the changes made
61        - Always provide exactly three alternative commit messages for each request
62        - Ensure diversity in style and content among the three alternatives
63
64        ## Output Format
65        The outcome should adhere to the following structure:
66        ```json
67        [
68          {"type": "<type>", "message": "<message>", "body": "<body>"},
69          {"type": "<type>", "message": "<message>", "body": "<body>"},
70          {"type": "<type>", "message": "<message>", "body": "<body>"}
71        ]
72        ```
73
74        ## Constraints
75        - Commit messages should be between 5-20 words
76        - If the message surpasses this limit, abbreviate it without shedding essential details while employing the 'body' part for detailed elaboration
77        - Do not include prefixes like "feat:", "fix:", etc. in the commit message, just put it in <type> part, and start the <message> with a verb
78        - Guarantee that all dialogues are carried out in the English language, except for translation requests
79        - Remain concentrated on tasks strictly linked with creating Git commit messages
80        - Remember to always provide three distinct commit message options.
81
82        ## Error Handling
83        If the user's submission doesn't correspond with the demanded parameters, generate this response:
84        ```json
85        [{"type": "error", "message": "Request processing failure", "body":"The submitted input isn't compatible with the required parameters"}]
86        ```
87
88        "#;
89        let url = format!(
90            "{}/models/{}:{}?key={}",
91            "https://generativelanguage.googleapis.com/v1beta",
92            "gemini-1.5-pro-latest",
93            "generateContent",
94            key
95        );
96        Self {
97            prompt: prompt.to_string(),
98            url,
99        }
100    }
101
102    pub async fn call(
103        &self,
104        input: &str,
105    ) -> ReviseResult<HashMap<String, Commit>> {
106        let txt_request = Request {
107            contents: vec![
108                Content {
109                    role: Role::User,
110                    parts: vec![Part {
111                        text: Some(self.prompt.clone()),
112                        inline_data: None,
113                        file_data: None,
114                        video_metadata: None,
115                    }],
116                },
117                Content {
118                    role: Role::User,
119                    parts: vec![Part {
120                        text: Some(input.to_string()),
121                        inline_data: None,
122                        file_data: None,
123                        video_metadata: None,
124                    }],
125                },
126            ],
127            tools: vec![],
128            safety_settings: vec![],
129            generation_config: Some(GenerationConfig {
130                temperature: None,
131                top_p: None,
132                top_k: None,
133                candidate_count: None,
134                max_output_tokens: None,
135                stop_sequences: None,
136                response_mime_type: Some("application/json".to_string()),
137            }),
138
139            system_instruction: None,
140        };
141
142        let client: reqwest::Client = reqwest::Client::builder()
143            .timeout(std::time::Duration::from_secs(30))
144            .build()
145            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
146        let request_builder = client
147            .post(&self.url)
148            .header(reqwest::header::USER_AGENT, "crate/revise")
149            .header(reqwest::header::CONTENT_TYPE, "application/json");
150        let result = request_builder.json(&txt_request).send().await?;
151        match result.status() {
152            reqwest::StatusCode::OK => {
153                let response = result.json::<GeminiResponse>().await?;
154
155                let text = response
156                    .candidates
157                    .first()
158                    .ok_or_else(|| anyhow::anyhow!("No candidates found"))?
159                    .content
160                    .parts
161                    .first()
162                    .ok_or_else(|| anyhow::anyhow!("No parts found"))?
163                    .text
164                    .clone()
165                    .ok_or_else(|| anyhow::anyhow!("No text found"))?
166                    .clone();
167                let messages: Vec<Commit> = serde_json::from_str(&text)?;
168                let mut m = HashMap::new();
169                for message in messages {
170                    let msg = format!("Message: {}", message.message);
171                    let body = format!("Body: {}", message.body);
172                    m.insert(msg + "\n\r" + &body, message);
173                }
174
175                Ok(m)
176            }
177            _ => Err(anyhow::anyhow!(
178                "Failed to get response from Gemini API: {}, response: {}",
179                result.status(),
180                result.text().await?
181            )),
182        }
183    }
184}
185
186impl AI<HashMap<String, Commit>> for Gemini {
187    async fn generate_response(
188        &self,
189        input: &str,
190    ) -> ReviseResult<HashMap<String, Commit>> {
191        self.call(input).await
192    }
193}
194
195#[derive(Debug, Clone, Deserialize, Serialize)]
196pub struct Request {
197    pub contents: Vec<Content>,
198    #[serde(skip_serializing_if = "Vec::is_empty")]
199    pub tools: Vec<Tools>,
200    #[serde(skip_serializing_if = "Vec::is_empty")]
201    #[serde(default, rename = "safetySettings")]
202    pub safety_settings: Vec<SafetySettings>,
203    #[serde(skip_serializing_if = "Option::is_none")]
204    #[serde(default, rename = "generationConfig")]
205    pub generation_config: Option<GenerationConfig>,
206    #[serde(skip_serializing_if = "Option::is_none")]
207    #[serde(default, rename = "system_instruction")]
208    pub system_instruction: Option<SystemInstructionContent>,
209}
210
211#[derive(Debug, Clone, Deserialize)]
212#[serde(rename_all = "camelCase")]
213pub struct GeminiResponse {
214    pub candidates: Vec<Candidate>,
215}
216
217#[derive(Debug, Clone, Deserialize, Serialize)]
218pub struct Content {
219    pub role: Role,
220    #[serde(default)]
221    pub parts: Vec<Part>,
222}
223
224#[derive(Debug, Clone, Deserialize, Serialize)]
225#[serde(rename_all = "camelCase")]
226pub struct Part {
227    #[serde(skip_serializing_if = "Option::is_none")]
228    pub text: Option<String>,
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub inline_data: Option<InlineData>,
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub file_data: Option<FileData>,
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub video_metadata: Option<VideoMetadata>,
235}
236
237#[derive(Debug, Clone, Deserialize, Serialize)]
238#[serde(rename_all = "lowercase")]
239pub enum Role {
240    User,
241    Model,
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize)]
245#[serde(rename_all = "camelCase")]
246pub struct InlineData {
247    pub mime_type: String,
248    pub data: String,
249}
250#[derive(Debug, Clone, Deserialize, Serialize)]
251#[serde(rename_all = "camelCase")]
252pub struct FileData {
253    pub mime_type: String,
254    pub file_uri: String,
255}
256#[derive(Debug, Clone, Deserialize, Serialize)]
257#[serde(rename_all = "camelCase")]
258pub struct VideoMetadata {
259    pub start_offset: StartOffset,
260    pub end_offset: EndOffset,
261}
262#[derive(Debug, Clone, Deserialize, Serialize)]
263pub struct StartOffset {
264    pub seconds: i32,
265    pub nanos: i32,
266}
267#[derive(Debug, Clone, Deserialize, Serialize)]
268pub struct EndOffset {
269    pub seconds: i32,
270    pub nanos: i32,
271}
272#[derive(Debug, Clone, Deserialize, Serialize)]
273pub struct Tools {
274    #[serde(rename = "functionDeclarations")]
275    pub function_declarations: Vec<FunctionDeclaration>,
276}
277
278#[derive(Debug, Clone, Deserialize, Serialize)]
279pub struct FunctionDeclaration {
280    pub name: String,
281    pub description: String,
282    pub parameters: serde_json::Value,
283}
284
285#[derive(Debug, Clone, Deserialize, Serialize)]
286pub struct SafetySettings {
287    pub category: HarmCategory,
288    pub threshold: HarmBlockThreshold,
289}
290#[derive(Debug, Clone, Deserialize, Serialize)]
291#[serde(rename_all = "camelCase")]
292pub struct GenerationConfig {
293    pub temperature: Option<f32>,
294    pub top_p: Option<f32>,
295    pub top_k: Option<i32>,
296    pub candidate_count: Option<i32>,
297    pub max_output_tokens: Option<i32>,
298    pub stop_sequences: Option<Vec<String>>,
299    pub response_mime_type: Option<String>,
300}
301
302#[derive(Debug, Clone, Deserialize, Serialize)]
303pub struct SystemInstructionContent {
304    #[serde(default)]
305    pub parts: Vec<SystemInstructionPart>,
306}
307
308#[derive(Debug, Clone, Deserialize, Serialize)]
309#[serde(rename_all = "camelCase")]
310pub struct SystemInstructionPart {
311    #[serde(skip_serializing_if = "Option::is_none")]
312    pub text: Option<String>,
313}
314
315#[derive(Debug, Clone, Deserialize)]
316#[serde(rename_all = "camelCase")]
317pub struct Candidate {
318    pub content: Content,
319}
320
321#[allow(clippy::enum_variant_names)]
322#[derive(Debug, Clone, Deserialize, Serialize)]
323#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
324pub enum HarmCategory {
325    HarmCategorySexuallyExplicit,
326    HarmCategoryHateSpeech,
327    HarmCategoryHarassment,
328    HarmCategoryDangerousContent,
329}
330
331#[allow(clippy::enum_variant_names)]
332#[derive(Debug, Clone, Deserialize, Serialize)]
333#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
334pub enum HarmProbability {
335    HarmProbabilityUnspecified,
336    Negligible,
337    Low,
338    Medium,
339    High,
340}
341
342#[allow(clippy::enum_variant_names)]
343#[derive(Debug, Clone, Deserialize, Serialize)]
344#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
345pub enum HarmBlockThreshold {
346    BlockNone,
347    BlockLowAndAbove,
348    BlockMedAndAbove,
349    BlockHighAndAbove,
350}
351
#[cfg(test)]
mod tests {
    use tokio::sync::oneshot;

    use super::*;

    /// Live smoke test against the real Gemini API while showing a spinner.
    ///
    /// Ignored by default: it needs network access and `REVISE_GEMINI_KEY`
    /// (from the environment or a `.env` file). Run it with
    /// `cargo test -- --ignored test_gemini_call`.
    #[ignore]
    #[tokio::test]
    // Kept from the original; presumably silences a lint triggered inside `tokio::select!`.
    #[allow(clippy::needless_return)]
    async fn test_gemini_call() {
        dotenvy::dotenv().ok();
        let key = std::env::var("REVISE_GEMINI_KEY").expect("REVISE_GEMINI_KEY must be set");
        let gemini = Gemini::new(&key);

        let (stop_tx, mut stop_rx) = oneshot::channel();

        // Spinner task: redraws the progress indicator until told to stop.
        let spinner_task = tokio::spawn(async move {
            let spinner = ['|', '/', '-', '\\'];
            let mut idx = 0;
            loop {
                tokio::select! {
                    // Async sleep only. `#[tokio::test]` runs on a current-thread runtime,
                    // so a blocking `std::thread::sleep` here would also stall the HTTP request.
                    () = tokio::time::sleep(std::time::Duration::from_millis(400)) => {
                        print!("\rGenerating... {}", spinner[idx]);
                        // Flush so the frame is shown immediately.
                        std::io::Write::flush(&mut std::io::stdout()).unwrap();
                        idx = (idx + 1) % spinner.len();
                    }
                    _ = &mut stop_rx => {
                        break;
                    }
                }
            }
        });

        let result = gemini.call("翻译: 这是一个测试").await.unwrap();

        let _ = stop_tx.send(());
        let _ = spinner_task.await;

        eprintln!("{result:#?}");
    }
}