1use crate::{
2 content_builder::ContentBuilder,
3 embed_builder::EmbedBuilder,
4 models::{
5 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6 EmbedContentRequest, GenerateContentRequest, GenerationResponse,
7 },
8 Error, Result,
9};
10use futures::stream::Stream;
11use futures_util::StreamExt;
12use reqwest::Client;
13use serde_json::Value;
14use std::pin::Pin;
15use std::sync::Arc;
16use url::Url;
17
/// Root of the public Gemini REST API (`v1beta`), used when no base URL is supplied.
/// Ends with a trailing `/`.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model path used by `Gemini::new` and `Gemini::with_base_url`.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
20
21pub(crate) struct GeminiClient {
23 http_client: Client,
24 api_key: String,
25 pub model: String,
26 base_url: String,
27}
28
29impl GeminiClient {
30 #[allow(dead_code)]
32 fn new(api_key: impl Into<String>, model: String) -> Self {
33 Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
34 }
35
36 fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
38 Self {
39 http_client: Client::new(),
40 api_key: api_key.into(),
41 model,
42 base_url,
43 }
44 }
45
46 pub(crate) async fn generate_content_raw(
48 &self,
49 request: GenerateContentRequest,
50 ) -> Result<GenerationResponse> {
51 let url = self.build_url("generateContent")?;
52
53 let response = self.http_client.post(url).json(&request).send().await?;
54
55 let status = response.status();
56 if !status.is_success() {
57 let error_text = response.text().await?;
58 return Err(Error::ApiError {
59 status_code: status.as_u16(),
60 message: error_text,
61 });
62 }
63
64 let response = response.json().await?;
65
66 Ok(response)
67 }
68
69 pub(crate) async fn generate_content_stream(
71 &self,
72 request: GenerateContentRequest,
73 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
74 let url = self.build_url("streamGenerateContent")?;
75
76 let response = self.http_client.post(url).json(&request).send().await?;
77
78 let status = response.status();
79 if !status.is_success() {
80 let error_text = response.text().await?;
81 return Err(Error::ApiError {
82 status_code: status.as_u16(),
83 message: error_text,
84 });
85 }
86
87 let stream = response
88 .bytes_stream()
89 .map(|result| {
90 match result {
91 Ok(bytes) => {
92 let text = String::from_utf8_lossy(&bytes);
93 let mut responses = Vec::new();
96 for line in text.lines() {
97 if let Some(json_str) = line.strip_prefix("data: ") {
98 if json_str == "[DONE]" {
99 continue;
100 }
101 match serde_json::from_str::<GenerationResponse>(json_str) {
102 Ok(response) => responses.push(Ok(response)),
103 Err(e) => responses.push(Err(Error::JsonError(e))),
104 }
105 }
106 }
107 futures::stream::iter(responses)
108 }
109 Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
110 }
111 })
112 .flatten();
113
114 Ok(Box::pin(stream))
115 }
116
117 pub(crate) async fn embed_content(
119 &self,
120 request: EmbedContentRequest,
121 ) -> Result<ContentEmbeddingResponse> {
122 let value = self.embed(request, "embedContent").await?;
123 let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
124
125 Ok(response)
126 }
127
128 pub(crate) async fn embed_content_batch(
130 &self,
131 request: BatchEmbedContentsRequest,
132 ) -> Result<BatchContentEmbeddingResponse> {
133 let value = self.embed(request, "batchEmbedContents").await?;
134 let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
135
136 Ok(response)
137 }
138
139 async fn embed<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
141 let url = self.build_url(endpoint)?;
142
143 let response = self.http_client.post(url).json(&request).send().await?;
144
145 let status = response.status();
146 if !status.is_success() {
147 let error_text = response.text().await?;
148 return Err(Error::ApiError {
149 status_code: status.as_u16(),
150 message: error_text,
151 });
152 }
153
154 let response = response.json().await?;
155 Ok(response)
156 }
157
158 fn build_url(&self, endpoint: &str) -> Result<Url> {
160 let url_str = format!(
163 "{}{}:{}?key={}",
164 self.base_url, self.model, endpoint, self.api_key
165 );
166 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
167 }
168}
169
170#[derive(Clone)]
172pub struct Gemini {
173 client: Arc<GeminiClient>,
174}
175
176impl Gemini {
177 pub fn new(api_key: impl Into<String>) -> Self {
179 Self::with_model(api_key, DEFAULT_MODEL.to_string())
180 }
181
182 pub fn pro(api_key: impl Into<String>) -> Self {
184 Self::with_model(api_key, "models/gemini-2.0-pro-exp-02-05".to_string())
185 }
186
187 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
189 Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
190 }
191
192 pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
194 Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
195 }
196
197 pub fn with_model_and_base_url(
199 api_key: impl Into<String>,
200 model: String,
201 base_url: String,
202 ) -> Self {
203 let client = GeminiClient::with_base_url(api_key, model, base_url);
204 Self {
205 client: Arc::new(client),
206 }
207 }
208
209 pub fn generate_content(&self) -> ContentBuilder {
211 ContentBuilder::new(self.client.clone())
212 }
213
214 pub fn embed_content(&self) -> EmbedBuilder {
216 EmbedBuilder::new(self.client.clone())
217 }
218}