agents_runtime/providers/
gemini.rs

1use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
/// Connection settings for a [`GeminiChatModel`].
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key; appended to every request URL as the `key` query parameter.
    pub api_key: String,
    /// Model identifier interpolated into the request path
    /// (`{base}/models/{model}:generateContent`).
    pub model: String,
    /// Optional base URL override. When `None`, requests go to
    /// `https://generativelanguage.googleapis.com/v1beta`.
    pub api_url: Option<String>,
}
13
14pub struct GeminiChatModel {
15    client: Client,
16    config: GeminiConfig,
17}
18
19impl GeminiChatModel {
20    pub fn new(config: GeminiConfig) -> anyhow::Result<Self> {
21        Ok(Self {
22            client: Client::builder()
23                .user_agent("rust-deep-agents-sdk/0.1")
24                .build()?,
25            config,
26        })
27    }
28}
29
30#[derive(Serialize)]
31struct GeminiRequest {
32    contents: Vec<GeminiContent>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    system_instruction: Option<GeminiContent>,
35}
36
37#[derive(Serialize)]
38struct GeminiContent {
39    role: String,
40    parts: Vec<GeminiPart>,
41}
42
43#[derive(Serialize)]
44struct GeminiPart {
45    text: String,
46}
47
48#[derive(Deserialize)]
49struct GeminiResponse {
50    candidates: Vec<GeminiCandidate>,
51}
52
53#[derive(Deserialize)]
54struct GeminiCandidate {
55    content: Option<GeminiContentResponse>,
56}
57
58#[derive(Deserialize)]
59struct GeminiContentResponse {
60    parts: Vec<GeminiPartResponse>,
61}
62
63#[derive(Deserialize)]
64struct GeminiPartResponse {
65    text: Option<String>,
66}
67
68fn to_gemini_contents(request: &LlmRequest) -> (Vec<GeminiContent>, Option<GeminiContent>) {
69    let mut contents = Vec::new();
70    for message in &request.messages {
71        let role = match message.role {
72            MessageRole::User => "user",
73            MessageRole::Agent => "model",
74            MessageRole::Tool => "user",
75            MessageRole::System => "user",
76        };
77        let text = match &message.content {
78            MessageContent::Text(text) => text.clone(),
79            MessageContent::Json(value) => value.to_string(),
80        };
81        contents.push(GeminiContent {
82            role: role.into(),
83            parts: vec![GeminiPart { text }],
84        });
85    }
86
87    let system_instruction = if request.system_prompt.trim().is_empty() {
88        None
89    } else {
90        Some(GeminiContent {
91            role: "system".into(),
92            parts: vec![GeminiPart {
93                text: request.system_prompt.clone(),
94            }],
95        })
96    };
97
98    (contents, system_instruction)
99}
100
101#[async_trait]
102impl LanguageModel for GeminiChatModel {
103    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
104        let (contents, system_instruction) = to_gemini_contents(&request);
105        let body = GeminiRequest {
106            contents,
107            system_instruction,
108        };
109
110        let base_url = self
111            .config
112            .api_url
113            .clone()
114            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".into());
115        let url = format!(
116            "{}/models/{}:generateContent?key={}",
117            base_url, self.config.model, self.config.api_key
118        );
119
120        let response = self
121            .client
122            .post(&url)
123            .json(&body)
124            .send()
125            .await?
126            .error_for_status()?;
127
128        let data: GeminiResponse = response.json().await?;
129        let text = data
130            .candidates
131            .into_iter()
132            .filter_map(|candidate| candidate.content)
133            .flat_map(|content| content.parts)
134            .find_map(|part| part.text)
135            .unwrap_or_default();
136
137        Ok(LlmResponse {
138            message: AgentMessage {
139                role: MessageRole::Agent,
140                content: MessageContent::Text(text),
141                metadata: None,
142            },
143        })
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-empty system prompt becomes the system instruction, while the
    /// single user message is kept as the only entry in `contents`.
    #[test]
    fn gemini_conversion_handles_system_prompt() {
        let greeting = AgentMessage {
            role: MessageRole::User,
            content: MessageContent::Text("Hello".into()),
            metadata: None,
        };
        let request = LlmRequest::new("You are concise", vec![greeting]);

        let (contents, system) = to_gemini_contents(&request);

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].role, "user");

        // Checks presence and text of the system instruction in one step.
        let system_text = system.map(|instruction| instruction.parts[0].text.clone());
        assert_eq!(system_text.as_deref(), Some("You are concise"));
    }
}
167}