agents_runtime/providers/gemini.rs
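//! Gemini provider: a [`LanguageModel`] implementation backed by the Google
//! Generative Language API's `generateContent` endpoint.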

use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Configuration for the Gemini provider: API key, model name, an optional
/// override for the API base URL, and extra headers attached to every request.
#[derive(Clone)]
pub struct GeminiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
    pub custom_headers: Vec<(String, String)>,
}

impl GeminiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
            custom_headers: Vec::new(),
        }
    }

    /// Replace the set of custom headers sent with every request.
    pub fn with_custom_headers(mut self, headers: Vec<(String, String)>) -> Self {
        self.custom_headers = headers;
        self
    }
}

/// HTTP client wrapper that talks to the Gemini `generateContent` endpoint.
pub struct GeminiChatModel {
    client: Client,
    config: GeminiConfig,
}

impl GeminiChatModel {
    pub fn new(config: GeminiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}
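// Usage sketch (illustrative only; the environment variable and model name are
// assumptions, not requirements of this crate):
//
//     let config = GeminiConfig::new(std::env::var("GEMINI_API_KEY")?, "gemini-1.5-flash");
//     let model = GeminiChatModel::new(config)?;
//
// The resulting `model` can then be passed anywhere a `LanguageModel` is expected.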

/// Request body for the `generateContent` call.
#[derive(Serialize)]
struct GeminiRequest {
    contents: Vec<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiToolDeclaration>>,
}

#[derive(Clone, Serialize)]
struct GeminiToolDeclaration {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

#[derive(Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    parameters: Value,
}

#[derive(Serialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}

#[derive(Serialize)]
struct GeminiPart {
    text: String,
}

/// Subset of the `generateContent` response that this provider consumes.
#[derive(Deserialize)]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
}

#[derive(Deserialize)]
struct GeminiCandidate {
    content: Option<GeminiContentResponse>,
}

#[derive(Deserialize)]
struct GeminiContentResponse {
    parts: Vec<GeminiPartResponse>,
}

#[derive(Deserialize)]
struct GeminiPartResponse {
    text: Option<String>,
    #[serde(rename = "functionCall")]
    function_call: Option<GeminiFunctionCall>,
}

#[derive(Deserialize)]
struct GeminiFunctionCall {
    name: String,
    args: Value,
}
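/// Split an `LlmRequest` into Gemini `contents` (one entry per message) and an
/// optional `system_instruction` derived from the system prompt.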
fn to_gemini_contents(request: &LlmRequest) -> (Vec<GeminiContent>, Option<GeminiContent>) {
    let mut contents = Vec::new();
    for message in &request.messages {
        // Gemini `contents` roles are limited to "user" and "model", so tool
        // results and in-history system messages are folded into "user" turns.
        let role = match message.role {
            MessageRole::User => "user",
            MessageRole::Agent => "model",
            MessageRole::Tool => "user",
            MessageRole::System => "user",
        };
        let text = match &message.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };
        contents.push(GeminiContent {
            role: role.into(),
            parts: vec![GeminiPart { text }],
        });
    }

    let system_instruction = if request.system_prompt.trim().is_empty() {
        None
    } else {
        Some(GeminiContent {
            role: "system".into(),
            parts: vec![GeminiPart {
                text: request.system_prompt.clone(),
            }],
        })
    };

    (contents, system_instruction)
}

/// Convert tool schemas to Gemini function declarations format
fn to_gemini_tools(tools: &[ToolSchema]) -> Option<Vec<GeminiToolDeclaration>> {
    if tools.is_empty() {
        return None;
    }

    Some(vec![GeminiToolDeclaration {
        function_declarations: tools
            .iter()
            .map(|tool| GeminiFunctionDeclaration {
                name: tool.name.clone(),
                description: tool.description.clone(),
                parameters: serde_json::to_value(&tool.parameters)
                    .unwrap_or_else(|_| serde_json::json!({})),
            })
            .collect(),
    }])
}
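// Tool calls from Gemini are surfaced to callers as a JSON message of the shape
// {"tool_calls": [{"name": ..., "args": ...}]}; plain text candidates come back
// as a text message. See `generate` below.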

#[async_trait]
impl LanguageModel for GeminiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let (contents, system_instruction) = to_gemini_contents(&request);
        let tools = to_gemini_tools(&request.tools);

        // Debug logging (before moving contents)
        tracing::debug!(
            "Gemini request: model={}, contents={}, tools={}",
            self.config.model,
            contents.len(),
            tools
                .as_ref()
                .map(|t| t
                    .iter()
                    .map(|td| td.function_declarations.len())
                    .sum::<usize>())
                .unwrap_or(0)
        );

        let body = GeminiRequest {
            contents,
            system_instruction,
            tools,
        };

        let base_url = self
            .config
            .api_url
            .clone()
            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".into());
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            base_url, self.config.model, self.config.api_key
        );

        let mut request = self.client.post(&url);

        for (key, value) in &self.config.custom_headers {
            request = request.header(key, value);
        }

        let response = request.json(&body).send().await?.error_for_status()?;

        let data: GeminiResponse = response.json().await?;

        // Check if response contains function calls
        let function_calls: Vec<_> = data
            .candidates
            .iter()
            .filter_map(|candidate| candidate.content.as_ref())
            .flat_map(|content| &content.parts)
            .filter_map(|part| part.function_call.as_ref())
            .collect();

        if !function_calls.is_empty() {
            // Convert Gemini functionCall format to our JSON format
            let tool_calls: Vec<_> = function_calls
                .iter()
                .map(|fc| {
                    serde_json::json!({
                        "name": fc.name,
                        "args": fc.args
                    })
                })
                .collect();

            tracing::debug!(
                "Gemini response contains {} function calls",
                tool_calls.len()
            );

            return Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": tool_calls
                    })),
                    metadata: None,
                },
            });
        }

        // Regular text response (only the first text part is used)
        let text = data
            .candidates
            .into_iter()
            .filter_map(|candidate| candidate.content)
            .flat_map(|content| content.parts)
            .find_map(|part| part.text)
            .unwrap_or_default();

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(text),
                metadata: None,
            },
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gemini_conversion_handles_system_prompt() {
        let request = LlmRequest::new(
            "You are concise",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let (contents, system) = to_gemini_contents(&request);
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].role, "user");
        assert!(system.is_some());
        assert_eq!(system.unwrap().parts[0].text, "You are concise");
    }

    #[test]
    fn gemini_config_new_initializes_empty_custom_headers() {
        let config = GeminiConfig::new("test-key", "gemini-pro");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "gemini-pro");
        assert!(config.custom_headers.is_empty());
        assert!(config.api_url.is_none());
    }

    #[test]
    fn gemini_config_with_custom_headers_sets_headers() {
        let headers = vec![
            ("X-Custom-Header".to_string(), "value1".to_string()),
            ("X-Another-Header".to_string(), "value2".to_string()),
        ];
        let config =
            GeminiConfig::new("test-key", "gemini-pro").with_custom_headers(headers.clone());

        assert_eq!(config.custom_headers.len(), 2);
        assert_eq!(config.custom_headers[0].0, "X-Custom-Header");
        assert_eq!(config.custom_headers[0].1, "value1");
        assert_eq!(config.custom_headers[1].0, "X-Another-Header");
        assert_eq!(config.custom_headers[1].1, "value2");
    }
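
    #[test]
    fn gemini_tools_conversion_returns_none_when_no_tools() {
        // Sketch of the trivial case: with no tools registered, the request
        // should omit the `tools` field entirely.
        assert!(to_gemini_tools(&[]).is_none());
    }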
}