1use crate::{
2 batch_builder::BatchBuilder,
3 content_builder::ContentBuilder,
4 embed_builder::EmbedBuilder,
5 models::{
6 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, BatchGenerateContentRequest,
7 BatchGenerateContentResponse, ContentEmbeddingResponse, EmbedContentRequest,
8 GenerateContentRequest, GenerationResponse,
9 },
10 Error, Result,
11};
12use futures::stream::Stream;
13use reqwest::Client;
14use serde_json::Value;
15use std::{pin::Pin, sync::Arc};
16use url::Url;
17
/// Base URL used when the caller does not supply one. It ends with `/` so the
/// model path can be appended to it directly when request URLs are built.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model path used by constructors that do not take an explicit model.
const DEFAULT_MODEL: &str = "models/gemini-2.5-flash";
20
/// Crate-internal HTTP client that holds everything needed to build and send
/// a request: the HTTP client, the API key, the model path and the base URL.
pub(crate) struct GeminiClient {
    /// Underlying `reqwest` client used for every request.
    http_client: Client,
    /// API key sent as the `key` query parameter of each request URL.
    api_key: String,
    /// Model path inserted into the request URL between the base URL and the
    /// endpoint name (the default is `models/gemini-2.5-flash`).
    pub model: String,
    /// URL prefix that the model path and endpoint are appended to.
    base_url: String,
}
28
29impl GeminiClient {
30 #[allow(dead_code)]
32 fn new(api_key: impl Into<String>, model: String) -> Self {
33 Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
34 }
35
36 fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
38 Self {
39 http_client: Client::new(),
40 api_key: api_key.into(),
41 model,
42 base_url,
43 }
44 }
45
46 pub(crate) async fn generate_content_raw(
48 &self,
49 request: GenerateContentRequest,
50 ) -> Result<GenerationResponse> {
51 let url = self.build_url("generateContent")?;
52
53 let response = self.http_client.post(url).json(&request).send().await?;
54
55 let status = response.status();
56 if !status.is_success() {
57 let error_text = response.text().await?;
58 return Err(Error::ApiError {
59 status_code: status.as_u16(),
60 message: error_text,
61 });
62 }
63
64 let response = response.json().await?;
65
66 Ok(response)
67 }
68
69 pub(crate) async fn generate_content_stream(
71 &self,
72 request: GenerateContentRequest,
73 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
74 let url = self.build_url("streamGenerateContent")?;
75
76 let response = self.http_client.post(url).json(&request).send().await?;
77
78 let status = response.status();
79 if !status.is_success() {
80 let error_text = response.text().await?;
81 return Err(Error::ApiError {
82 status_code: status.as_u16(),
83 message: error_text,
84 });
85 }
86
87 let bytes = response.bytes().await?;
89 let text = String::from_utf8_lossy(&bytes);
90
91 let responses: Vec<Result<GenerationResponse>> =
93 match serde_json::from_str::<Vec<GenerationResponse>>(&text) {
94 Ok(json_array) => json_array.into_iter().map(Ok).collect(),
95 Err(e) => {
96 vec![Err(Error::JsonError(e))]
97 }
98 };
99
100 let stream = futures::stream::iter(responses);
101 Ok(Box::pin(stream))
102 }
103
104 pub(crate) async fn embed_content(
106 &self,
107 request: EmbedContentRequest,
108 ) -> Result<ContentEmbeddingResponse> {
109 let value = self.post_json(request, "embedContent").await?;
110 let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
111
112 Ok(response)
113 }
114
115 pub(crate) async fn embed_content_batch(
117 &self,
118 request: BatchEmbedContentsRequest,
119 ) -> Result<BatchContentEmbeddingResponse> {
120 let value = self.post_json(request, "batchEmbedContents").await?;
121 let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
122
123 Ok(response)
124 }
125
126 pub(crate) async fn batch_generate_content_sync(
128 &self,
129 request: BatchGenerateContentRequest,
130 ) -> Result<BatchGenerateContentResponse> {
131 let value = self.post_json(request, "batchGenerateContent").await?;
132 let response = serde_json::from_value::<BatchGenerateContentResponse>(value)?;
133 Ok(response)
134 }
135
136 async fn post_json<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
138 let url = self.build_url(endpoint)?;
139
140 let response = self.http_client.post(url).json(&request).send().await?;
141
142 let status = response.status();
143 if !status.is_success() {
144 let error_text = response.text().await?;
145 return Err(Error::ApiError {
146 status_code: status.as_u16(),
147 message: error_text,
148 });
149 }
150
151 let response = response.json().await?;
152 Ok(response)
153 }
154
155 fn build_url(&self, endpoint: &str) -> Result<Url> {
157 let url_str = format!(
158 "{}{}:{}?key={}",
159 self.base_url, self.model, endpoint, self.api_key
160 );
161 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
162 }
163}
164
/// Public entry point of the crate's client API.
///
/// The underlying client is held in an `Arc`, so cloning a `Gemini` shares
/// the same client instead of creating a new one.
#[derive(Clone)]
pub struct Gemini {
    /// Shared client handed to every builder created from this handle.
    client: Arc<GeminiClient>,
}
170
171impl Gemini {
172 pub fn new(api_key: impl Into<String>) -> Self {
174 Self::with_model(api_key, DEFAULT_MODEL.to_string())
175 }
176
177 pub fn pro(api_key: impl Into<String>) -> Self {
179 Self::with_model(api_key, "models/gemini-2.5-pro".to_string())
180 }
181
182 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
184 Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
185 }
186
187 pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
189 Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
190 }
191
192 pub fn with_model_and_base_url(
194 api_key: impl Into<String>,
195 model: String,
196 base_url: String,
197 ) -> Self {
198 let client = GeminiClient::with_base_url(api_key, model, base_url);
199 Self {
200 client: Arc::new(client),
201 }
202 }
203
204 pub fn generate_content(&self) -> ContentBuilder {
206 ContentBuilder::new(self.client.clone())
207 }
208
209 pub fn embed_content(&self) -> EmbedBuilder {
211 EmbedBuilder::new(self.client.clone())
212 }
213
214 pub fn batch_generate_content_sync(&self) -> BatchBuilder {
216 BatchBuilder::new(self.client.clone())
217 }
218}