agents_runtime/providers/
gemini.rs1use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use agents_core::tools::ToolSchema;
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
/// Connection settings for the Gemini provider.
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key; sent as the `key` query parameter of each request.
    pub api_key: String,
    /// Model identifier inserted into the `models/{model}:generateContent` path.
    pub model: String,
    /// Optional base-URL override; `None` means the provider's built-in default endpoint.
    pub api_url: Option<String>,
    /// Extra HTTP headers attached to every outgoing request.
    pub custom_headers: Vec<(String, String)>,
}

impl GeminiConfig {
    /// Builds a config for `model` authenticated with `api_key`.
    ///
    /// No base-URL override and no custom headers are set.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let (api_key, model) = (api_key.into(), model.into());
        Self {
            api_key,
            model,
            api_url: None,
            custom_headers: Vec::new(),
        }
    }

    /// Returns the config with its custom headers replaced by `headers`.
    ///
    /// Any previously configured headers are discarded, not merged.
    pub fn with_custom_headers(self, headers: Vec<(String, String)>) -> Self {
        Self {
            custom_headers: headers,
            ..self
        }
    }
}
32
33pub struct GeminiChatModel {
34 client: Client,
35 config: GeminiConfig,
36}
37
38impl GeminiChatModel {
39 pub fn new(config: GeminiConfig) -> anyhow::Result<Self> {
40 Ok(Self {
41 client: Client::builder()
42 .user_agent("rust-deep-agents-sdk/0.1")
43 .build()?,
44 config,
45 })
46 }
47}
48
/// Request body for the Gemini `generateContent` endpoint.
///
/// No serde rename attributes are used, so JSON keys are the snake_case field names as
/// written. The `Option` fields are left out of the JSON entirely when `None`.
#[derive(Serialize)]
struct GeminiRequest {
    /// Conversation turns, in message order.
    contents: Vec<GeminiContent>,
    /// System prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    /// Tool declarations exposed to the model; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiToolDeclaration>>,
}

/// One entry of the request's `tools` array: a group of function declarations.
#[derive(Clone, Serialize)]
struct GeminiToolDeclaration {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// A single callable function advertised to the model.
#[derive(Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    /// Parameter schema as a JSON value (converted from the runtime's tool schema).
    parameters: Value,
}

/// One conversation turn: a role label plus its parts.
#[derive(Serialize)]
struct GeminiContent {
    /// Role label sent to the API (this file produces `"user"`, `"model"` and `"system"`).
    role: String,
    parts: Vec<GeminiPart>,
}

/// A text-only content part.
#[derive(Serialize)]
struct GeminiPart {
    text: String,
}
80
81#[derive(Deserialize)]
82struct GeminiResponse {
83 candidates: Vec<GeminiCandidate>,
84}
85
86#[derive(Deserialize)]
87struct GeminiCandidate {
88 content: Option<GeminiContentResponse>,
89}
90
91#[derive(Deserialize)]
92struct GeminiContentResponse {
93 parts: Vec<GeminiPartResponse>,
94}
95
96#[derive(Deserialize)]
97struct GeminiPartResponse {
98 text: Option<String>,
99 #[serde(rename = "functionCall")]
100 function_call: Option<GeminiFunctionCall>,
101}
102
103#[derive(Deserialize)]
104struct GeminiFunctionCall {
105 name: String,
106 args: Value,
107}
108
109fn to_gemini_contents(request: &LlmRequest) -> (Vec<GeminiContent>, Option<GeminiContent>) {
110 let mut contents = Vec::new();
111 for message in &request.messages {
112 let role = match message.role {
113 MessageRole::User => "user",
114 MessageRole::Agent => "model",
115 MessageRole::Tool => "user",
116 MessageRole::System => "user",
117 };
118 let text = match &message.content {
119 MessageContent::Text(text) => text.clone(),
120 MessageContent::Json(value) => value.to_string(),
121 };
122 contents.push(GeminiContent {
123 role: role.into(),
124 parts: vec![GeminiPart { text }],
125 });
126 }
127
128 let system_instruction = if request.system_prompt.trim().is_empty() {
129 None
130 } else {
131 Some(GeminiContent {
132 role: "system".into(),
133 parts: vec![GeminiPart {
134 text: request.system_prompt.clone(),
135 }],
136 })
137 };
138
139 (contents, system_instruction)
140}
141
142fn to_gemini_tools(tools: &[ToolSchema]) -> Option<Vec<GeminiToolDeclaration>> {
144 if tools.is_empty() {
145 return None;
146 }
147
148 Some(vec![GeminiToolDeclaration {
149 function_declarations: tools
150 .iter()
151 .map(|tool| GeminiFunctionDeclaration {
152 name: tool.name.clone(),
153 description: tool.description.clone(),
154 parameters: serde_json::to_value(&tool.parameters)
155 .unwrap_or_else(|_| serde_json::json!({})),
156 })
157 .collect(),
158 }])
159}
160
161#[async_trait]
162impl LanguageModel for GeminiChatModel {
163 async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
164 let (contents, system_instruction) = to_gemini_contents(&request);
165 let tools = to_gemini_tools(&request.tools);
166
167 tracing::debug!(
169 "Gemini request: model={}, contents={}, tools={}",
170 self.config.model,
171 contents.len(),
172 tools
173 .as_ref()
174 .map(|t| t
175 .iter()
176 .map(|td| td.function_declarations.len())
177 .sum::<usize>())
178 .unwrap_or(0)
179 );
180
181 let body = GeminiRequest {
182 contents,
183 system_instruction,
184 tools,
185 };
186
187 let base_url = self
188 .config
189 .api_url
190 .clone()
191 .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".into());
192 let url = format!(
193 "{}/models/{}:generateContent?key={}",
194 base_url, self.config.model, self.config.api_key
195 );
196
197 let mut request = self.client.post(&url);
198
199 for (key, value) in &self.config.custom_headers {
200 request = request.header(key, value);
201 }
202
203 let response = request.json(&body).send().await?.error_for_status()?;
204
205 let data: GeminiResponse = response.json().await?;
206
207 let function_calls: Vec<_> = data
209 .candidates
210 .iter()
211 .filter_map(|candidate| candidate.content.as_ref())
212 .flat_map(|content| &content.parts)
213 .filter_map(|part| part.function_call.as_ref())
214 .collect();
215
216 if !function_calls.is_empty() {
217 let tool_calls: Vec<_> = function_calls
219 .iter()
220 .map(|fc| {
221 serde_json::json!({
222 "name": fc.name,
223 "args": fc.args
224 })
225 })
226 .collect();
227
228 tracing::debug!(
229 "Gemini response contains {} function calls",
230 tool_calls.len()
231 );
232
233 return Ok(LlmResponse {
234 message: AgentMessage {
235 role: MessageRole::Agent,
236 content: MessageContent::Json(serde_json::json!({
237 "tool_calls": tool_calls
238 })),
239 metadata: None,
240 },
241 });
242 }
243
244 let text = data
246 .candidates
247 .into_iter()
248 .filter_map(|candidate| candidate.content)
249 .flat_map(|content| content.parts)
250 .find_map(|part| part.text)
251 .unwrap_or_default();
252
253 Ok(LlmResponse {
254 message: AgentMessage {
255 role: MessageRole::Agent,
256 content: MessageContent::Text(text),
257 metadata: None,
258 },
259 })
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gemini_conversion_handles_system_prompt() {
        let request = LlmRequest::new(
            "You are concise",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let (contents, system) = to_gemini_contents(&request);
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].role, "user");
        assert!(system.is_some());
        assert_eq!(system.unwrap().parts[0].text, "You are concise");
    }

    /// A whitespace-only system prompt is dropped. Roles map Agent -> "model" and
    /// Tool -> "user", and JSON content is sent as its compact string form.
    #[test]
    fn gemini_conversion_omits_blank_system_prompt_and_maps_roles() {
        let request = LlmRequest::new(
            "   ",
            vec![
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("Hi".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("Hello!".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Tool,
                    content: MessageContent::Json(serde_json::json!({ "ok": true })),
                    metadata: None,
                },
            ],
        );
        let (contents, system) = to_gemini_contents(&request);
        assert!(system.is_none());
        let roles: Vec<&str> = contents.iter().map(|c| c.role.as_str()).collect();
        assert_eq!(roles, ["user", "model", "user"]);
        assert_eq!(contents[2].parts[0].text, r#"{"ok":true}"#);
    }

    #[test]
    fn gemini_tools_empty_slice_is_omitted() {
        assert!(to_gemini_tools(&[]).is_none());
    }

    #[test]
    fn gemini_config_new_initializes_empty_custom_headers() {
        let config = GeminiConfig::new("test-key", "gemini-pro");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "gemini-pro");
        assert!(config.custom_headers.is_empty());
        assert!(config.api_url.is_none());
    }

    #[test]
    fn gemini_config_with_custom_headers_sets_headers() {
        let headers = vec![
            ("X-Custom-Header".to_string(), "value1".to_string()),
            ("X-Another-Header".to_string(), "value2".to_string()),
        ];
        let config =
            GeminiConfig::new("test-key", "gemini-pro").with_custom_headers(headers.clone());

        assert_eq!(config.custom_headers.len(), 2);
        assert_eq!(config.custom_headers[0].0, "X-Custom-Header");
        assert_eq!(config.custom_headers[0].1, "value1");
        assert_eq!(config.custom_headers[1].0, "X-Another-Header");
        assert_eq!(config.custom_headers[1].1, "value2");
    }
}