1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use super::AI;
6use crate::ReviseResult;
7
/// Client for the Gemini `generateContent` endpoint that turns a diff or a
/// description into commit message suggestions.
///
/// `Debug` is implemented by hand: `url` carries the API key as a query
/// parameter, and a derived `Debug` would print it verbatim.
#[derive(Clone)]
pub struct Gemini {
    /// Instruction text sent ahead of the user's input on every request.
    prompt: String,
    /// Endpoint URL, including the `key` query parameter.
    url: String,
}

impl std::fmt::Debug for Gemini {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Drop everything after `?` so the API key never reaches logs or panics.
        let url = match self.url.split_once('?') {
            Some((endpoint, _)) => format!("{endpoint}?<redacted>"),
            None => self.url.clone(),
        };
        f.debug_struct("Gemini")
            .field("prompt", &self.prompt)
            .field("url", &url)
            .finish()
    }
}
13
14#[derive(Serialize, Deserialize, Debug)]
15pub struct Commit {
16 #[serde(rename = "type")]
17 pub kind: String,
18 pub message: String,
19 pub body: String,
20}
21
22impl Gemini {
23 pub fn new(key: &str) -> Self {
24 let prompt = r#"
25 # Character
26 You're a brilliant coding buddy with top-notch proficiency in Git. Your main duty is to assist users in crafting clear and precise Git commit messages.
27
28 ## Skills
29
30 ### Skill 1: Multilingual Translation
31 - Recognize translation requests in the format: "SourceLanguage:TargetLanguage; Content"
32 - Identify requests starting with "这是一个翻译commit" as translation tasks
33 - Translate the given content from the source language to the target language
34 - Preserve the original text alongside the translation in the output
35 - Adapt the translation to fit the context of Git commit messages
36 - Example input: "中文:English; 这是一个翻译commit, 优化用户界面布局"
37 - Example output:
38 ```json
39 [
40 {
41 "type": "translation",
42 "message": "Optimize user interface layout",
43 "body": "A long body with details about the changes made"
44 },
45 {
46 "type": "translation",
47 "message": "Optimize user interface layout",
48 "body": "A long body with details about the changes made"
49 },
50 {
51 "type": "translation",
52 "message": "Optimize user interface layout",
53 "body": "A long body with details about the changes made"
54 }
55 ]
56 ```
57
58 ### Skill 2: The Commit Message Maverick
59 - Process the git diff or description given by the user
60 - Curate commit messages that confidently and tersely summarize the changes made
61 - Always provide exactly three alternative commit messages for each request
62 - Ensure diversity in style and content among the three alternatives
63
64 ## Output Format
65 The outcome should adhere to the following structure:
66 ```json
67 [
68 {"type": "<type>", "message": "<message>", "body": "<body>"},
69 {"type": "<type>", "message": "<message>", "body": "<body>"},
70 {"type": "<type>", "message": "<message>", "body": "<body>"}
71 ]
72 ```
73
74 ## Constraints
75 - Commit messages should be between 5-20 words
76 - If the message surpasses this limit, abbreviate it without shedding essential details while employing the 'body' part for detailed elaboration
77 - Do not include prefixes like "feat:", "fix:", etc. in the commit message, just put it in <type> part, and start the <message> with a verb
78 - Guarantee that all dialogues are carried out in the English language, except for translation requests
79 - Remain concentrated on tasks strictly linked with creating Git commit messages
80 - Remember to always provide three distinct commit message options.
81
82 ## Error Handling
83 If the user's submission doesn't correspond with the demanded parameters, generate this response:
84 ```json
85 [{"type": "error", "message": "Request processing failure", "body":"The submitted input isn't compatible with the required parameters"}]
86 ```
87
88 "#;
89 let url = format!(
90 "{}/models/{}:{}?key={}",
91 "https://generativelanguage.googleapis.com/v1beta",
92 "gemini-1.5-pro-latest",
93 "generateContent",
94 key
95 );
96 Self {
97 prompt: prompt.to_string(),
98 url,
99 }
100 }
101
102 pub async fn call(
103 &self,
104 input: &str,
105 ) -> ReviseResult<HashMap<String, Commit>> {
106 let txt_request = Request {
107 contents: vec![
108 Content {
109 role: Role::User,
110 parts: vec![Part {
111 text: Some(self.prompt.clone()),
112 inline_data: None,
113 file_data: None,
114 video_metadata: None,
115 }],
116 },
117 Content {
118 role: Role::User,
119 parts: vec![Part {
120 text: Some(input.to_string()),
121 inline_data: None,
122 file_data: None,
123 video_metadata: None,
124 }],
125 },
126 ],
127 tools: vec![],
128 safety_settings: vec![],
129 generation_config: Some(GenerationConfig {
130 temperature: None,
131 top_p: None,
132 top_k: None,
133 candidate_count: None,
134 max_output_tokens: None,
135 stop_sequences: None,
136 response_mime_type: Some("application/json".to_string()),
137 }),
138
139 system_instruction: None,
140 };
141
142 let client: reqwest::Client = reqwest::Client::builder()
143 .timeout(std::time::Duration::from_secs(30))
144 .build()
145 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
146 let request_builder = client
147 .post(&self.url)
148 .header(reqwest::header::USER_AGENT, "crate/revise")
149 .header(reqwest::header::CONTENT_TYPE, "application/json");
150 let result = request_builder.json(&txt_request).send().await?;
151 match result.status() {
152 reqwest::StatusCode::OK => {
153 let response = result.json::<GeminiResponse>().await?;
154
155 let text = response
156 .candidates
157 .first()
158 .ok_or_else(|| anyhow::anyhow!("No candidates found"))?
159 .content
160 .parts
161 .first()
162 .ok_or_else(|| anyhow::anyhow!("No parts found"))?
163 .text
164 .clone()
165 .ok_or_else(|| anyhow::anyhow!("No text found"))?
166 .clone();
167 let messages: Vec<Commit> = serde_json::from_str(&text)?;
168 let mut m = HashMap::new();
169 for message in messages {
170 let msg = format!("Message: {}", message.message);
171 let body = format!("Body: {}", message.body);
172 m.insert(msg + "\n\r" + &body, message);
173 }
174
175 Ok(m)
176 }
177 _ => Err(anyhow::anyhow!(
178 "Failed to get response from Gemini API: {}, response: {}",
179 result.status(),
180 result.text().await?
181 )),
182 }
183 }
184}
185
186impl AI<HashMap<String, Commit>> for Gemini {
187 async fn generate_response(
188 &self,
189 input: &str,
190 ) -> ReviseResult<HashMap<String, Commit>> {
191 self.call(input).await
192 }
193}
194
195#[derive(Debug, Clone, Deserialize, Serialize)]
196pub struct Request {
197 pub contents: Vec<Content>,
198 #[serde(skip_serializing_if = "Vec::is_empty")]
199 pub tools: Vec<Tools>,
200 #[serde(skip_serializing_if = "Vec::is_empty")]
201 #[serde(default, rename = "safetySettings")]
202 pub safety_settings: Vec<SafetySettings>,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 #[serde(default, rename = "generationConfig")]
205 pub generation_config: Option<GenerationConfig>,
206 #[serde(skip_serializing_if = "Option::is_none")]
207 #[serde(default, rename = "system_instruction")]
208 pub system_instruction: Option<SystemInstructionContent>,
209}
210
211#[derive(Debug, Clone, Deserialize)]
212#[serde(rename_all = "camelCase")]
213pub struct GeminiResponse {
214 pub candidates: Vec<Candidate>,
215}
216
217#[derive(Debug, Clone, Deserialize, Serialize)]
218pub struct Content {
219 pub role: Role,
220 #[serde(default)]
221 pub parts: Vec<Part>,
222}
223
224#[derive(Debug, Clone, Deserialize, Serialize)]
225#[serde(rename_all = "camelCase")]
226pub struct Part {
227 #[serde(skip_serializing_if = "Option::is_none")]
228 pub text: Option<String>,
229 #[serde(skip_serializing_if = "Option::is_none")]
230 pub inline_data: Option<InlineData>,
231 #[serde(skip_serializing_if = "Option::is_none")]
232 pub file_data: Option<FileData>,
233 #[serde(skip_serializing_if = "Option::is_none")]
234 pub video_metadata: Option<VideoMetadata>,
235}
236
237#[derive(Debug, Clone, Deserialize, Serialize)]
238#[serde(rename_all = "lowercase")]
239pub enum Role {
240 User,
241 Model,
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize)]
245#[serde(rename_all = "camelCase")]
246pub struct InlineData {
247 pub mime_type: String,
248 pub data: String,
249}
250#[derive(Debug, Clone, Deserialize, Serialize)]
251#[serde(rename_all = "camelCase")]
252pub struct FileData {
253 pub mime_type: String,
254 pub file_uri: String,
255}
256#[derive(Debug, Clone, Deserialize, Serialize)]
257#[serde(rename_all = "camelCase")]
258pub struct VideoMetadata {
259 pub start_offset: StartOffset,
260 pub end_offset: EndOffset,
261}
262#[derive(Debug, Clone, Deserialize, Serialize)]
263pub struct StartOffset {
264 pub seconds: i32,
265 pub nanos: i32,
266}
267#[derive(Debug, Clone, Deserialize, Serialize)]
268pub struct EndOffset {
269 pub seconds: i32,
270 pub nanos: i32,
271}
272#[derive(Debug, Clone, Deserialize, Serialize)]
273pub struct Tools {
274 #[serde(rename = "functionDeclarations")]
275 pub function_declarations: Vec<FunctionDeclaration>,
276}
277
278#[derive(Debug, Clone, Deserialize, Serialize)]
279pub struct FunctionDeclaration {
280 pub name: String,
281 pub description: String,
282 pub parameters: serde_json::Value,
283}
284
285#[derive(Debug, Clone, Deserialize, Serialize)]
286pub struct SafetySettings {
287 pub category: HarmCategory,
288 pub threshold: HarmBlockThreshold,
289}
290#[derive(Debug, Clone, Deserialize, Serialize)]
291#[serde(rename_all = "camelCase")]
292pub struct GenerationConfig {
293 pub temperature: Option<f32>,
294 pub top_p: Option<f32>,
295 pub top_k: Option<i32>,
296 pub candidate_count: Option<i32>,
297 pub max_output_tokens: Option<i32>,
298 pub stop_sequences: Option<Vec<String>>,
299 pub response_mime_type: Option<String>,
300}
301
302#[derive(Debug, Clone, Deserialize, Serialize)]
303pub struct SystemInstructionContent {
304 #[serde(default)]
305 pub parts: Vec<SystemInstructionPart>,
306}
307
308#[derive(Debug, Clone, Deserialize, Serialize)]
309#[serde(rename_all = "camelCase")]
310pub struct SystemInstructionPart {
311 #[serde(skip_serializing_if = "Option::is_none")]
312 pub text: Option<String>,
313}
314
315#[derive(Debug, Clone, Deserialize)]
316#[serde(rename_all = "camelCase")]
317pub struct Candidate {
318 pub content: Content,
319}
320
321#[allow(clippy::enum_variant_names)]
322#[derive(Debug, Clone, Deserialize, Serialize)]
323#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
324pub enum HarmCategory {
325 HarmCategorySexuallyExplicit,
326 HarmCategoryHateSpeech,
327 HarmCategoryHarassment,
328 HarmCategoryDangerousContent,
329}
330
331#[allow(clippy::enum_variant_names)]
332#[derive(Debug, Clone, Deserialize, Serialize)]
333#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
334pub enum HarmProbability {
335 HarmProbabilityUnspecified,
336 Negligible,
337 Low,
338 Medium,
339 High,
340}
341
342#[allow(clippy::enum_variant_names)]
343#[derive(Debug, Clone, Deserialize, Serialize)]
344#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
345pub enum HarmBlockThreshold {
346 BlockNone,
347 BlockLowAndAbove,
348 BlockMedAndAbove,
349 BlockHighAndAbove,
350}
351
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::oneshot;

    use super::*;

    /// Live smoke test: requests suggestions from Gemini and prints them.
    ///
    /// Ignored by default because it needs network access and an API key in
    /// `REVISE_GEMINI_KEY`; a `.env` file is loaded first if present.
    #[ignore]
    #[tokio::test]
    // Kept as-is; presumably silences a lint raised by the `tokio::select!`
    // expansion. TODO confirm it is still needed.
    #[allow(clippy::needless_return)]
    async fn test_gemini_call() {
        dotenvy::dotenv().ok();
        let key = std::env::var("REVISE_GEMINI_KEY")
            .expect("REVISE_GEMINI_KEY must be set to run this test");
        let gemini = Gemini::new(&key);

        let (tx, mut rx) = oneshot::channel::<()>();

        // `#[tokio::test]` runs on a single-threaded runtime by default, so
        // this spinner shares a thread with the HTTP request. It must wait
        // only with async sleeps; a blocking `std::thread::sleep` here would
        // stall the request too.
        let spinner = tokio::spawn(async move {
            const FRAMES: [char; 4] = ['|', '/', '-', '\\'];
            for frame in FRAMES.iter().cycle() {
                tokio::select! {
                    () = tokio::time::sleep(Duration::from_millis(400)) => {
                        print!("\rGenerating... {frame}");
                        std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    }
                    _ = &mut rx => break,
                }
            }
        });

        let result = gemini.call("翻译: 这是一个测试").await;

        // Stop the spinner before reporting, so a failure message is not
        // interleaved with spinner frames.
        let _ = tx.send(());
        let _ = spinner.await;

        let commits = result.expect("Gemini request failed");
        eprintln!("{commits:#?}");
    }
}