intelli_shell/ai/gemini.rs

1use std::fmt::Debug;
2
3use color_eyre::eyre::Context;
4use reqwest::{Client, RequestBuilder, Response, header::HeaderName};
5use schemars::{JsonSchema, Schema};
6use serde::{Deserialize, de::DeserializeOwned};
7use serde_json::{Value as Json, json};
8
9use super::{AiProvider, AiProviderBase};
10use crate::{
11    config::GeminiModelConfig,
12    errors::{Result, UserFacingError},
13};
14
15impl AiProviderBase for GeminiModelConfig {
16    fn provider_name(&self) -> &'static str {
17        "Gemini"
18    }
19
20    fn auth_header(&self, api_key: String) -> (HeaderName, String) {
21        (HeaderName::from_static("x-goog-api-key"), api_key)
22    }
23
24    fn api_key_env_var_name(&self) -> &str {
25        &self.api_key_env
26    }
27
28    fn build_request(
29        &self,
30        client: &Client,
31        sys_prompt: &str,
32        user_prompt: &str,
33        json_schema: &Schema,
34    ) -> RequestBuilder {
35        // Request body
36        // https://ai.google.dev/api/rest/v1beta/models/generateContent
37        let request_body = json!({
38            "system_instruction": {
39                "parts": [{ "text": sys_prompt }]
40            },
41            "contents": [{
42                "role": "user",
43                "parts": [{ "text": user_prompt }]
44            }],
45            "generationConfig": {
46                "responseMimeType": "application/json",
47                "responseJsonSchema": json_schema,
48            }
49        });
50
51        tracing::trace!("Request:\n{request_body:#}");
52
53        // Generate content url
54        let url = format!("{}/models/{}:generateContent", self.url, self.model);
55
56        // Request
57        client.post(url).json(&request_body)
58    }
59}
60
61impl AiProvider for GeminiModelConfig {
62    async fn parse_response<T>(&self, res: Response) -> Result<T>
63    where
64        T: DeserializeOwned + JsonSchema + Debug,
65    {
66        // Parse successful response
67        let res: Json = res.json().await.wrap_err("Gemini response not a json")?;
68        tracing::trace!("Response:\n{res:#}");
69        let mut res: GeminiResponse = serde_json::from_value(res).wrap_err("Couldn't parse Gemini response")?;
70
71        // Validate the response content
72        if res.candidates.is_empty() {
73            tracing::error!("Response got no candidates: {res:?}");
74            return Err(UserFacingError::AiRequestFailed(String::from("received response with no candidates")).into());
75        } else if res.candidates.len() > 1 {
76            tracing::warn!("Response got {} candidates", res.candidates.len());
77        }
78
79        let mut candidate = res.candidates.remove(0);
80        if let Some(finish_reason) = candidate.finish_reason
81            && finish_reason != "STOP"
82        {
83            tracing::error!("Gemini response got an invalid finish reason: {finish_reason}");
84            return Err(UserFacingError::AiRequestFailed(format!(
85                "couldn't generate a valid response: {finish_reason}"
86            ))
87            .into());
88        }
89
90        if candidate.content.parts.is_empty() {
91            tracing::error!("Response candidate got no parts");
92            return Err(
93                UserFacingError::AiRequestFailed(String::from("received response candidate with no parts")).into(),
94            );
95        } else if candidate.content.parts.len() > 1 {
96            tracing::warn!("Response candidate got {} parts", candidate.content.parts.len());
97        }
98
99        let part = candidate.content.parts.remove(0);
100        let Some(text) = part.text.filter(|c| !c.trim().is_empty()) else {
101            tracing::error!("Gemini returned an empty candidate part");
102            return Err(UserFacingError::AiRequestFailed(String::from("received an empty response")).into());
103        };
104
105        // Parse the text
106        Ok(serde_json::from_str(&text).map_err(|err| {
107            tracing::error!("Couldn't parse API response into the expected format: {err}\nText:\n{text}");
108            UserFacingError::AiRequestFailed(String::from("couldn't parse api response into the expected format"))
109        })?)
110    }
111}
112
113#[derive(Debug, Deserialize)]
114struct GeminiResponse {
115    candidates: Vec<GeminiCandidate>,
116}
117
118#[derive(Debug, Deserialize)]
119#[serde(rename_all = "camelCase")]
120struct GeminiCandidate {
121    content: GeminiContent,
122    finish_reason: Option<String>,
123}
124
125#[derive(Debug, Deserialize)]
126struct GeminiContent {
127    #[serde(default)]
128    parts: Vec<GeminiPart>,
129}
130
131#[derive(Debug, Deserialize)]
132struct GeminiPart {
133    text: Option<String>,
134}
135
#[cfg(test)]
mod tests {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    use super::*;
    use crate::{ai::AiClient, config::AiModelConfig};

    /// End-to-end check against the real Gemini API.
    ///
    /// Ignored by default. Run with `cargo test -- --ignored` and `GEMINI_API_KEY` set.
    #[tokio::test]
    #[ignore] // Real API calls require valid api keys
    async fn test_gemini_api() -> Result<()> {
        // `try_init` rather than `init`: all tests share one process, so another test may have
        // already installed a global subscriber, and `init` would panic in that case
        let _ = tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().compact())
            .try_init();
        let config = AiModelConfig::Gemini(GeminiModelConfig {
            model: "gemini-2.5-flash-lite".into(),
            url: "https://generativelanguage.googleapis.com/v1beta".into(),
            api_key_env: "GEMINI_API_KEY".into(),
        });
        let client = AiClient::new("test", &config, "", None)?;
        let res = client
            .generate_command_suggestions(
                "you're a cli expert, that will provide command suggestions based on what the user wants to do",
                "undo last n amount of commits",
            )
            .await?;
        tracing::info!("Suggestions:");
        for command in res.suggestions {
            tracing::info!("  # {}", command.description);
            tracing::info!("  {}", command.command);
        }
        Ok(())
    }
}