1use std::{fmt, sync::Arc};
2
3use agent_sdk_core::{
4 AgentError, ProviderAdapter, ProviderCapabilities, ProviderMessageRole,
5 ProviderProjectionPolicy, ProviderRequest, ProviderResponse, ProviderStopReason,
6 ProviderToolCall, ProviderUsage, RetryClassification, ToolCallId,
7 tool_records::CanonicalToolName,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11
12use crate::{
13 ProviderApiKey, ProviderToolArgumentSink,
14 error::{provider_failure, unsupported_response},
15 http::{CurlJsonHttpTransport, JsonHttpRequest, JsonHttpTransport},
16};
17
/// Settings that select the Gemini model and endpoint used by the adapter.
///
/// Start from [`GeminiGenerateContentConfig::new`] and refine the value with
/// the consuming builder methods.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GeminiGenerateContentConfig {
    /// Provider reference string; defaults to `provider.gemini.generate_content`.
    pub provider_ref: String,
    /// Model name: bare (`gemini-…`), prefixed (`models/…`), or a qualified
    /// resource name in another collection (`tunedModels/…`).
    pub model: String,
    /// Base URL of the API; trailing slashes are ignored when building URLs.
    pub endpoint_base: String,
    /// Optional input-token limit; `None` unless set through the builder.
    pub max_input_tokens: Option<u32>,
}

impl GeminiGenerateContentConfig {
    /// Creates a configuration for `model` that targets the public `v1beta`
    /// endpoint and carries no input-token limit.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            provider_ref: "provider.gemini.generate_content".to_string(),
            model: model.into(),
            endpoint_base: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            max_input_tokens: None,
        }
    }

    /// Replaces the provider reference string.
    pub fn provider_ref(mut self, provider_ref: impl Into<String>) -> Self {
        self.provider_ref = provider_ref.into();
        self
    }

    /// Replaces the endpoint base URL (useful for proxies and test servers).
    pub fn endpoint_base(mut self, endpoint_base: impl Into<String>) -> Self {
        self.endpoint_base = endpoint_base.into();
        self
    }

    /// Sets the input-token limit.
    pub fn max_input_tokens(mut self, max_input_tokens: u32) -> Self {
        self.max_input_tokens = Some(max_input_tokens);
        self
    }

    /// Builds the full `:generateContent` URL for the configured model.
    ///
    /// Bare model names and `models/…` names resolve to
    /// `{base}/models/{name}:generateContent`. A name that already points into
    /// another resource collection (it contains a `/` and does not start with
    /// `models/`, e.g. `tunedModels/my-model`) is used verbatim instead of
    /// being nested under `models/`, which would produce an invalid path.
    fn endpoint_url(&self) -> String {
        let base = self.endpoint_base.trim_end_matches('/');
        let model = self.model.as_str();
        if model.contains('/') && !model.starts_with("models/") {
            format!("{base}/{model}:generateContent")
        } else {
            let model = model.trim_start_matches("models/");
            format!("{base}/models/{model}:generateContent")
        }
    }
}
68
69#[derive(Clone)]
70pub struct GeminiGenerateContentAdapter {
72 config: GeminiGenerateContentConfig,
73 api_key: ProviderApiKey,
74 http: Arc<dyn JsonHttpTransport>,
75 argument_sink: Option<Arc<dyn ProviderToolArgumentSink>>,
76}
77
78impl GeminiGenerateContentAdapter {
79 pub fn from_env(model: impl Into<String>) -> Result<Self, AgentError> {
81 Self::new(
82 GeminiGenerateContentConfig::new(model),
83 ProviderApiKey::from_env("GEMINI_API_KEY")?,
84 )
85 }
86
87 pub fn new(
89 config: GeminiGenerateContentConfig,
90 api_key: ProviderApiKey,
91 ) -> Result<Self, AgentError> {
92 Self::with_transport(config, api_key, Arc::new(CurlJsonHttpTransport::new()))
93 }
94
95 pub fn with_transport(
97 config: GeminiGenerateContentConfig,
98 api_key: ProviderApiKey,
99 http: Arc<dyn JsonHttpTransport>,
100 ) -> Result<Self, AgentError> {
101 Ok(Self {
102 config,
103 api_key,
104 http,
105 argument_sink: None,
106 })
107 }
108
109 pub fn with_argument_sink(mut self, sink: Arc<dyn ProviderToolArgumentSink>) -> Self {
111 self.argument_sink = Some(sink);
112 self
113 }
114
115 fn wire_request(&self, request: &ProviderRequest) -> Value {
116 let mut system = Vec::new();
117 let mut contents = Vec::new();
118 for message in &request.messages {
119 match message.role {
120 ProviderMessageRole::System | ProviderMessageRole::Developer => {
121 system.push(message.content.clone());
122 }
123 ProviderMessageRole::Assistant => {
124 contents.push(gemini_text_content("model", message.content.clone()));
125 }
126 ProviderMessageRole::Tool => {
127 contents.push(gemini_text_content(
128 "user",
129 format!("Tool result:\n{}", message.content),
130 ));
131 }
132 ProviderMessageRole::Context | ProviderMessageRole::User => {
133 contents.push(gemini_text_content("user", message.content.clone()));
134 }
135 }
136 }
137
138 let mut body = json!({ "contents": contents });
139 if !system.is_empty() {
140 body["systemInstruction"] = json!({
141 "parts": [{ "text": system.join("\n\n") }]
142 });
143 }
144 if let Some(generation_config) = gemini_generation_config(request) {
145 body["generationConfig"] = generation_config;
146 }
147 body
148 }
149
150 fn map_response(
151 &self,
152 response: GeminiGenerateContentResponse,
153 ) -> Result<ProviderResponse, AgentError> {
154 let tool_calls = self.tool_calls_from_response(&response)?;
155 let usage = response.usage_metadata.clone().map(ProviderUsage::from);
156 if !tool_calls.is_empty() {
157 let mut mapped = ProviderResponse::tool_use(tool_calls);
158 mapped.usage = usage;
159 return Ok(mapped);
160 }
161 Ok(ProviderResponse {
162 schema_version: ProviderResponse::SCHEMA_VERSION,
163 output_text: response.output_text(),
164 stop_reason: response.stop_reason(),
165 tool_calls: Vec::new(),
166 usage,
167 })
168 }
169
170 fn tool_calls_from_response(
171 &self,
172 response: &GeminiGenerateContentResponse,
173 ) -> Result<Vec<ProviderToolCall>, AgentError> {
174 let mut calls = Vec::new();
175 for candidate in &response.candidates {
176 if let Some(content) = &candidate.content {
177 for part in &content.parts {
178 let Some(function_call) = &part.function_call else {
179 continue;
180 };
181 let name = function_call.name.as_deref().ok_or_else(|| {
182 unsupported_response("Gemini generateContent", "functionCall missing name")
183 })?;
184 let call_id = function_call
185 .id
186 .clone()
187 .unwrap_or_else(|| format!("gemini_call_{}", calls.len()));
188 let canonical_tool_name = CanonicalToolName::new(name);
189 let mut call = ProviderToolCall::new(
190 ToolCallId::new(call_id.clone()),
191 canonical_tool_name.clone(),
192 format!(
193 "provider requested tool {name} with arguments stored as content refs"
194 ),
195 );
196 if let (Some(sink), Some(args)) =
197 (self.argument_sink.as_ref(), function_call.args.as_ref())
198 {
199 let raw_arguments = serde_json::to_string(args).map_err(|error| {
200 provider_failure(
201 RetryClassification::RepairNeeded,
202 format!(
203 "Gemini functionCall args could not be serialized: {error}"
204 ),
205 )
206 })?;
207 if let Some(args_ref) = sink.store_tool_arguments(
208 &self.config.provider_ref,
209 &call_id,
210 &canonical_tool_name,
211 &raw_arguments,
212 )? {
213 call = call.with_args_ref(args_ref);
214 }
215 }
216 calls.push(call);
217 }
218 }
219 }
220 Ok(calls)
221 }
222}
223
224impl ProviderAdapter for GeminiGenerateContentAdapter {
225 fn capabilities(&self) -> ProviderCapabilities {
226 let mut capabilities = ProviderCapabilities::text_only(self.config.provider_ref.clone());
227 capabilities.supports_usage = true;
228 capabilities.max_input_tokens = self.config.max_input_tokens;
229 capabilities
230 }
231
232 fn project_request(
233 &self,
234 projection: &agent_sdk_core::ContextProjection,
235 policy: &ProviderProjectionPolicy,
236 ) -> Result<ProviderRequest, AgentError> {
237 agent_sdk_core::projection::project_context_projection(projection, policy)
238 }
239
240 fn complete(&self, request: &ProviderRequest) -> Result<ProviderResponse, AgentError> {
241 let http_request =
242 JsonHttpRequest::new(self.config.endpoint_url(), self.wire_request(request))
243 .header("x-goog-api-key", self.api_key.expose_secret())
244 .header("Content-Type", "application/json");
245 let response = self.http.post_json(http_request)?;
246 let message = serde_json::from_value::<GeminiGenerateContentResponse>(response.body)
247 .map_err(|error| unsupported_response("Gemini generateContent", error.to_string()))?;
248 self.map_response(message)
249 }
250}
251
252fn gemini_text_content(role: &str, text: String) -> Value {
253 json!({
254 "role": role,
255 "parts": [{ "text": text }],
256 })
257}
258
259fn gemini_generation_config(request: &ProviderRequest) -> Option<Value> {
260 let hint = request.structured_output_hint.as_ref()?;
261 let schema = hint.redacted_schema.clone()?;
262 Some(json!({
263 "responseMimeType": "application/json",
264 "responseJsonSchema": schema,
265 }))
266}
267
268#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
269pub struct GeminiGenerateContentResponse {
271 #[serde(default)]
273 pub candidates: Vec<GeminiCandidate>,
274 #[serde(rename = "usageMetadata")]
276 pub usage_metadata: Option<GeminiUsage>,
277}
278
279impl GeminiGenerateContentResponse {
280 pub fn text(text: impl Into<String>) -> Self {
282 Self {
283 candidates: vec![GeminiCandidate {
284 content: Some(GeminiContent {
285 role: Some("model".to_string()),
286 parts: vec![GeminiPart::text(text)],
287 }),
288 finish_reason: Some("STOP".to_string()),
289 }],
290 usage_metadata: None,
291 }
292 }
293
294 pub fn function_call(id: impl Into<String>, name: impl Into<String>, args: Value) -> Self {
296 Self {
297 candidates: vec![GeminiCandidate {
298 content: Some(GeminiContent {
299 role: Some("model".to_string()),
300 parts: vec![GeminiPart::function_call(id, name, args)],
301 }),
302 finish_reason: Some("STOP".to_string()),
303 }],
304 usage_metadata: None,
305 }
306 }
307
308 fn output_text(&self) -> String {
309 self.candidates
310 .iter()
311 .filter_map(|candidate| candidate.content.as_ref())
312 .flat_map(|content| content.parts.iter())
313 .filter_map(|part| part.text.as_deref())
314 .collect::<Vec<_>>()
315 .join("")
316 }
317
318 fn stop_reason(&self) -> ProviderStopReason {
319 let reason = self
320 .candidates
321 .first()
322 .and_then(|candidate| candidate.finish_reason.as_deref())
323 .unwrap_or("STOP");
324 match reason {
325 "STOP" => ProviderStopReason::EndTurn,
326 "MAX_TOKENS" => ProviderStopReason::MaxTokens,
327 "SAFETY" | "RECITATION" | "MALFORMED_FUNCTION_CALL" => {
328 ProviderStopReason::ProviderError
329 }
330 _ => ProviderStopReason::Unknown,
331 }
332 }
333}
334
335impl fmt::Debug for GeminiGenerateContentResponse {
336 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
337 formatter
338 .debug_struct("GeminiGenerateContentResponse")
339 .field("candidate_count", &self.candidates.len())
340 .field("candidates", &self.candidates)
341 .field("usage_metadata", &self.usage_metadata)
342 .finish()
343 }
344}
345
346#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
347pub struct GeminiCandidate {
349 pub content: Option<GeminiContent>,
351 #[serde(rename = "finishReason")]
353 pub finish_reason: Option<String>,
354}
355
356impl fmt::Debug for GeminiCandidate {
357 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358 formatter
359 .debug_struct("GeminiCandidate")
360 .field("content", &self.content)
361 .field("finish_reason", &self.finish_reason)
362 .finish()
363 }
364}
365
366#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
367pub struct GeminiContent {
369 pub role: Option<String>,
371 #[serde(default)]
373 pub parts: Vec<GeminiPart>,
374}
375
376impl fmt::Debug for GeminiContent {
377 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
378 formatter
379 .debug_struct("GeminiContent")
380 .field("role", &self.role)
381 .field("part_count", &self.parts.len())
382 .field("parts", &self.parts)
383 .finish()
384 }
385}
386
387#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
388pub struct GeminiPart {
390 pub text: Option<String>,
392 #[serde(rename = "functionCall")]
394 pub function_call: Option<GeminiFunctionCall>,
395}
396
397impl GeminiPart {
398 pub fn text(text: impl Into<String>) -> Self {
400 Self {
401 text: Some(text.into()),
402 function_call: None,
403 }
404 }
405
406 pub fn function_call(id: impl Into<String>, name: impl Into<String>, args: Value) -> Self {
408 Self {
409 text: None,
410 function_call: Some(GeminiFunctionCall {
411 id: Some(id.into()),
412 name: Some(name.into()),
413 args: Some(args),
414 }),
415 }
416 }
417}
418
419impl fmt::Debug for GeminiPart {
420 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
421 formatter
422 .debug_struct("GeminiPart")
423 .field(
424 "text_chars",
425 &self.text.as_ref().map(|value| value.chars().count()),
426 )
427 .field("function_call", &self.function_call)
428 .finish()
429 }
430}
431
432#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
433pub struct GeminiFunctionCall {
435 pub id: Option<String>,
437 pub name: Option<String>,
439 pub args: Option<Value>,
441}
442
443impl fmt::Debug for GeminiFunctionCall {
444 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
445 formatter
446 .debug_struct("GeminiFunctionCall")
447 .field("id", &self.id)
448 .field("name", &self.name)
449 .field("args", &"<redacted>")
450 .field("args_present", &self.args.is_some())
451 .finish()
452 }
453}
454
455#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
456pub struct GeminiUsage {
458 #[serde(rename = "promptTokenCount")]
460 pub prompt_token_count: Option<u32>,
461 #[serde(rename = "candidatesTokenCount")]
463 pub candidates_token_count: Option<u32>,
464 #[serde(rename = "totalTokenCount")]
466 pub total_token_count: Option<u32>,
467}
468
469impl From<GeminiUsage> for ProviderUsage {
470 fn from(value: GeminiUsage) -> Self {
471 Self {
472 input_tokens: value.prompt_token_count,
473 output_tokens: value.candidates_token_count,
474 total_tokens: value.total_token_count,
475 }
476 }
477}