tt_provider_gemini/
lib.rs1pub mod client;
27pub mod errors;
28pub mod pricing;
29pub mod stream;
30pub mod translate;
31
32use async_trait::async_trait;
33use futures::stream::BoxStream;
34use reqwest::Client;
35use tracing::instrument;
36use tt_shared::{
37 filter_extra_headers, validate_provider_url, ChatCompletionChunk, ChatCompletionRequest,
38 ChatCompletionResponse, EmbeddingsRequest, EmbeddingsResponse, ModelInfo, ModelPricing,
39 Provider, ProviderError, RequestContext,
40};
41
42pub use client::ClientConfig;
43
44const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
46
47pub struct GeminiProvider {
51 client: Client,
52 allow_local: bool,
56}
57
58impl GeminiProvider {
59 pub fn new(cfg: ClientConfig) -> Self {
66 let client =
67 client::build_client(&cfg).expect("failed to build reqwest::Client for Gemini adapter");
68 Self {
69 client,
70 allow_local: false,
71 }
72 }
73
74 #[doc(hidden)]
81 pub fn new_allow_local(cfg: ClientConfig) -> Self {
82 let client =
83 client::build_client(&cfg).expect("failed to build reqwest::Client for Gemini adapter");
84 Self {
85 client,
86 allow_local: true,
87 }
88 }
89
90 fn base_url<'a>(&self, ctx: &'a RequestContext) -> &'a str {
92 ctx.credentials
93 .base_url
94 .as_deref()
95 .unwrap_or(DEFAULT_BASE_URL)
96 }
97}
98
99#[async_trait]
100impl Provider for GeminiProvider {
101 fn id(&self) -> &'static str {
102 "gemini"
103 }
104
105 fn models(&self) -> Vec<ModelInfo> {
106 pricing::all_models()
107 }
108
109 fn pricing(&self, model: &str) -> Option<ModelPricing> {
110 pricing::pricing_for(model)
111 }
112
113 fn dropped_params(&self, req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
114 let mut out = Vec::new();
116 if req.n.is_some() {
117 out.push("n".to_string());
118 }
119 if req.seed.is_some() {
120 out.push("seed".to_string());
121 }
122 if req.presence_penalty.is_some() {
123 out.push("presence_penalty".to_string());
124 }
125 if req.frequency_penalty.is_some() {
126 out.push("frequency_penalty".to_string());
127 }
128 if req.user.is_some() {
129 out.push("user".to_string());
130 }
131 out
132 }
133
134 #[instrument(skip(self, ctx), fields(provider = "gemini", model = %req.model))]
140 async fn chat_completion(
141 &self,
142 req: ChatCompletionRequest,
143 ctx: &RequestContext,
144 ) -> Result<ChatCompletionResponse, ProviderError> {
145 let base_url = self.base_url(ctx);
146 if ctx.credentials.base_url.is_some() {
149 validate_provider_url(base_url, self.allow_local)
150 .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
151 }
152
153 let api_key = ctx.credentials.api_key.expose().to_string();
154 let model = req.model.clone();
155
156 translate::validate_model_id(&model)?;
157 let url = format!("{base_url}/v1beta/models/{model}:generateContent");
158
159 let body = translate::translate_request(req)?;
160
161 let mut request_builder = self
162 .client
163 .post(&url)
164 .header("Content-Type", "application/json")
165 .header("x-goog-api-key", &api_key)
166 .json(&body);
167 for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
170 request_builder = request_builder.header(name, value);
171 }
172 let response = request_builder
173 .send()
174 .await
175 .map_err(errors::map_reqwest_error)?;
176
177 let status = response.status().as_u16();
178 let retry_after = response
179 .headers()
180 .get("retry-after")
181 .and_then(|v| v.to_str().ok())
182 .map(|s| s.to_string());
183
184 let response_text = response.text().await.map_err(errors::map_reqwest_error)?;
185
186 if status >= 400 {
187 return Err(errors::map_response_error(
188 status,
189 &response_text,
190 retry_after.as_deref(),
191 &model,
192 ));
193 }
194
195 translate::deserialize_response(&response_text, &model)
196 }
197
198 #[instrument(skip(self, ctx), fields(provider = "gemini", model = %req.model))]
205 async fn chat_completion_stream(
206 &self,
207 req: ChatCompletionRequest,
208 ctx: &RequestContext,
209 ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, ProviderError>>, ProviderError> {
210 let base_url = self.base_url(ctx);
211 if ctx.credentials.base_url.is_some() {
212 validate_provider_url(base_url, self.allow_local)
213 .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
214 }
215
216 let base_url = base_url.to_string();
217 let client = self.client.clone();
218 stream::stream_chat_completion(client, &base_url, req, ctx).await
219 }
220
221 async fn embeddings(
228 &self,
229 _req: EmbeddingsRequest,
230 _ctx: &RequestContext,
231 ) -> Result<EmbeddingsResponse, ProviderError> {
232 Err(ProviderError::Unsupported(
233 "Gemini embedding models use a separate endpoint; use a dedicated embedding adapter"
234 .to_string(),
235 ))
236 }
237}