1use crate::{
2 content_builder::ContentBuilder,
3 embed_builder::EmbedBuilder,
4 models::{
5 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6 EmbedContentRequest, GenerateContentRequest, GenerationResponse,
7 },
8 Error, Result,
9};
10use futures::stream::Stream;
11use reqwest::Client;
12use serde_json::Value;
13use std::pin::Pin;
14use std::sync::Arc;
15use url::Url;
16
/// Base URL of the Gemini REST API (`v1beta`); request URLs are built by
/// appending the model path and `:<method>` to it.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model used by `Gemini::new` and `Gemini::with_base_url` when the caller
/// does not choose one explicitly.
const DEFAULT_MODEL: &str = "models/gemini-2.5-flash";
19
20pub(crate) struct GeminiClient {
22 http_client: Client,
23 api_key: String,
24 pub model: String,
25 base_url: String,
26}
27
28impl GeminiClient {
29 #[allow(dead_code)]
31 fn new(api_key: impl Into<String>, model: String) -> Self {
32 Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
33 }
34
35 fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
37 Self {
38 http_client: Client::new(),
39 api_key: api_key.into(),
40 model,
41 base_url,
42 }
43 }
44
45 pub(crate) async fn generate_content_raw(
47 &self,
48 request: GenerateContentRequest,
49 ) -> Result<GenerationResponse> {
50 let url = self.build_url("generateContent")?;
51
52 let response = self.http_client.post(url).json(&request).send().await?;
53
54 let status = response.status();
55 if !status.is_success() {
56 let error_text = response.text().await?;
57 return Err(Error::ApiError {
58 status_code: status.as_u16(),
59 message: error_text,
60 });
61 }
62
63 let response = response.json().await?;
64
65 Ok(response)
66 }
67
68 pub(crate) async fn generate_content_stream(
70 &self,
71 request: GenerateContentRequest,
72 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
73 let url = self.build_url("streamGenerateContent")?;
74
75 let response = self.http_client.post(url).json(&request).send().await?;
76
77 let status = response.status();
78 if !status.is_success() {
79 let error_text = response.text().await?;
80 return Err(Error::ApiError {
81 status_code: status.as_u16(),
82 message: error_text,
83 });
84 }
85
86 let bytes = response.bytes().await?;
88 let text = String::from_utf8_lossy(&bytes);
89
90 let responses: Vec<Result<GenerationResponse>> =
92 match serde_json::from_str::<Vec<GenerationResponse>>(&text) {
93 Ok(json_array) => json_array.into_iter().map(Ok).collect(),
94 Err(e) => {
95 vec![Err(Error::JsonError(e))]
96 }
97 };
98
99 let stream = futures::stream::iter(responses);
100 Ok(Box::pin(stream))
101 }
102
103 pub(crate) async fn embed_content(
105 &self,
106 request: EmbedContentRequest,
107 ) -> Result<ContentEmbeddingResponse> {
108 let value = self.embed(request, "embedContent").await?;
109 let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
110
111 Ok(response)
112 }
113
114 pub(crate) async fn embed_content_batch(
116 &self,
117 request: BatchEmbedContentsRequest,
118 ) -> Result<BatchContentEmbeddingResponse> {
119 let value = self.embed(request, "batchEmbedContents").await?;
120 let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
121
122 Ok(response)
123 }
124
125 async fn embed<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
127 let url = self.build_url(endpoint)?;
128
129 let response = self.http_client.post(url).json(&request).send().await?;
130
131 let status = response.status();
132 if !status.is_success() {
133 let error_text = response.text().await?;
134 return Err(Error::ApiError {
135 status_code: status.as_u16(),
136 message: error_text,
137 });
138 }
139
140 let response = response.json().await?;
141 Ok(response)
142 }
143
144 fn build_url(&self, endpoint: &str) -> Result<Url> {
146 let url_str = format!(
149 "{}{}:{}?key={}",
150 self.base_url, self.model, endpoint, self.api_key
151 );
152 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
153 }
154}
155
156#[derive(Clone)]
158pub struct Gemini {
159 client: Arc<GeminiClient>,
160}
161
162impl Gemini {
163 pub fn new(api_key: impl Into<String>) -> Self {
165 Self::with_model(api_key, DEFAULT_MODEL.to_string())
166 }
167
168 pub fn pro(api_key: impl Into<String>) -> Self {
170 Self::with_model(api_key, "models/gemini-2.5-pro".to_string())
171 }
172
173 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
175 Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
176 }
177
178 pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
180 Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
181 }
182
183 pub fn with_model_and_base_url(
185 api_key: impl Into<String>,
186 model: String,
187 base_url: String,
188 ) -> Self {
189 let client = GeminiClient::with_base_url(api_key, model, base_url);
190 Self {
191 client: Arc::new(client),
192 }
193 }
194
195 pub fn generate_content(&self) -> ContentBuilder {
197 ContentBuilder::new(self.client.clone())
198 }
199
200 pub fn embed_content(&self) -> EmbedBuilder {
202 EmbedBuilder::new(self.client.clone())
203 }
204}