1#[cfg(feature = "http")]
8pub mod http;
9
10use std::sync::Arc;
11
12use crate::protocol::types::{ApiEndpoint, GeminiModel, SessionConfig};
13use crate::session::SessionError;
14use crate::session::SessionHandle;
15use crate::transport::auth::{
16 AuthProvider, GoogleAIAuth, GoogleAITokenAuth, ServiceEndpoint, VertexAIAuth,
17};
18use crate::transport::{connect, TransportConfig};
19
/// Entry point for the Gemini API.
///
/// Bundles the endpoint handed to live sessions, the default model used for REST URLs, and
/// the shared authentication provider (plus, with the `http` feature, an HTTP client).
pub struct Client {
    /// Endpoint descriptor cloned into every [`LiveSessionBuilder`] created by [`Client::live`].
    endpoint: ApiEndpoint,
    /// Default model; used by [`Client::rest_url`] and set via [`Client::model`].
    model: GeminiModel,
    /// Shared authentication provider used to build REST URLs and auth headers.
    auth: Arc<dyn AuthProvider>,
    /// HTTP client used by `rest_request`; replaced by `http_config`.
    #[cfg(feature = "http")]
    http: http::HttpClient,
}
45
46impl Client {
47 pub fn from_api_key(api_key: impl Into<String>) -> Self {
49 let key: String = api_key.into();
50 let endpoint = ApiEndpoint::google_ai(key.clone());
51 let auth: Arc<dyn AuthProvider> = Arc::new(GoogleAIAuth::new(key));
52 Self {
53 endpoint,
54 model: GeminiModel::default(),
55 auth,
56 #[cfg(feature = "http")]
57 http: http::HttpClient::new(http::HttpConfig::default()),
58 }
59 }
60
61 pub fn from_access_token(access_token: impl Into<String>) -> Self {
63 let token: String = access_token.into();
64 let endpoint = ApiEndpoint::google_ai_token(token.clone());
65 let auth: Arc<dyn AuthProvider> = Arc::new(GoogleAITokenAuth::new(token));
66 Self {
67 endpoint,
68 model: GeminiModel::default(),
69 auth,
70 #[cfg(feature = "http")]
71 http: http::HttpClient::new(http::HttpConfig::default()),
72 }
73 }
74
75 pub fn from_vertex(
77 project: impl Into<String>,
78 location: impl Into<String>,
79 access_token: impl Into<String>,
80 ) -> Self {
81 let proj: String = project.into();
82 let loc: String = location.into();
83 let tok: String = access_token.into();
84 let endpoint = ApiEndpoint::vertex(proj.clone(), loc.clone(), tok.clone());
85 let auth: Arc<dyn AuthProvider> = Arc::new(VertexAIAuth::new(proj, loc, tok));
86 Self {
87 endpoint,
88 model: GeminiModel::default(),
89 auth,
90 #[cfg(feature = "http")]
91 http: http::HttpClient::new(http::HttpConfig::default()),
92 }
93 }
94
95 pub fn from_vertex_refreshable(
104 project: impl Into<String>,
105 location: impl Into<String>,
106 refresher: impl Fn() -> String + Send + Sync + 'static,
107 ) -> Self {
108 let proj: String = project.into();
109 let loc: String = location.into();
110 let initial_token = refresher();
112 let endpoint = ApiEndpoint::vertex(proj.clone(), loc.clone(), initial_token);
113 let auth: Arc<dyn AuthProvider> =
114 Arc::new(VertexAIAuth::with_token_refresher(proj, loc, refresher));
115 Self {
116 endpoint,
117 model: GeminiModel::default(),
118 auth,
119 #[cfg(feature = "http")]
120 http: http::HttpClient::new(http::HttpConfig::default()),
121 }
122 }
123
124 pub fn model(mut self, model: impl Into<GeminiModel>) -> Self {
126 self.model = model.into();
127 self
128 }
129
130 #[cfg(feature = "http")]
132 pub fn http_config(mut self, config: http::HttpConfig) -> Self {
133 self.http = http::HttpClient::new(config);
134 self
135 }
136
137 pub fn auth(&self) -> &dyn AuthProvider {
139 &*self.auth
140 }
141
142 pub fn default_model(&self) -> &GeminiModel {
144 &self.model
145 }
146
147 pub fn rest_url(&self, endpoint: ServiceEndpoint) -> String {
149 self.auth.rest_url(endpoint, Some(&self.model))
150 }
151
152 pub fn rest_url_for(&self, endpoint: ServiceEndpoint, model: &GeminiModel) -> String {
154 self.auth.rest_url(endpoint, Some(model))
155 }
156
157 pub async fn auth_headers(&self) -> Result<Vec<(String, String)>, crate::session::AuthError> {
159 self.auth.auth_headers().await
160 }
161
162 pub fn live(&self, model: GeminiModel) -> LiveSessionBuilder {
166 LiveSessionBuilder {
167 endpoint: self.endpoint.clone(),
168 model,
169 transport_config: TransportConfig::default(),
170 config_fn: None,
171 }
172 }
173
174 #[cfg(feature = "http")]
176 pub fn http_client(&self) -> &http::HttpClient {
177 &self.http
178 }
179
180 #[cfg(feature = "http")]
184 pub async fn rest_request(
185 &self,
186 endpoint: ServiceEndpoint,
187 body: &impl serde::Serialize,
188 ) -> Result<serde_json::Value, http::HttpError> {
189 let url = self.rest_url(endpoint);
190 let headers = self
191 .auth
192 .auth_headers()
193 .await
194 .map_err(|e| http::HttpError::Auth(e.to_string()))?;
195 self.http.post_json(&url, headers, body).await
196 }
197}
198
/// Builder returned by [`Client::live`] that collects the settings for one live session.
///
/// Nothing connects until [`LiveSessionBuilder::connect`] is awaited.
pub struct LiveSessionBuilder {
    /// Endpoint cloned from the originating [`Client`].
    endpoint: ApiEndpoint,
    /// Model for this session, applied to the [`SessionConfig`] before connecting.
    model: GeminiModel,
    /// Transport settings passed to the transport-level `connect`.
    /// Starts as `TransportConfig::default()`.
    transport_config: TransportConfig,
    /// Optional user hook, set via `configure`, that rewrites the [`SessionConfig`]
    /// before connecting.
    config_fn: Option<Box<dyn FnOnce(SessionConfig) -> SessionConfig>>,
}
206
207impl LiveSessionBuilder {
208 pub fn transport_config(mut self, config: TransportConfig) -> Self {
210 self.transport_config = config;
211 self
212 }
213
214 pub fn configure(mut self, f: impl FnOnce(SessionConfig) -> SessionConfig + 'static) -> Self {
216 self.config_fn = Some(Box::new(f));
217 self
218 }
219
220 pub async fn connect(self) -> Result<SessionHandle, SessionError> {
222 let mut config = SessionConfig::from_endpoint(self.endpoint).model(self.model);
223
224 if let Some(f) = self.config_fn {
225 config = f(config);
226 }
227
228 connect(config, self.transport_config).await
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn client_from_api_key() {
        let client = Client::from_api_key("test-key");
        assert!(matches!(
            client.default_model(),
            GeminiModel::GeminiLive2_5FlashNativeAudio
        ));
    }

    #[test]
    fn client_from_access_token() {
        // Previously untested constructor: it must start on the same default model.
        let client = Client::from_access_token("test-token");
        assert!(matches!(
            client.default_model(),
            GeminiModel::GeminiLive2_5FlashNativeAudio
        ));
    }

    #[test]
    fn client_from_vertex() {
        let client = Client::from_vertex("proj", "us-central1", "tok");
        let url = client.auth().ws_url(&GeminiModel::default());
        assert!(url.contains("us-central1-aiplatform.googleapis.com"));
    }

    #[test]
    fn client_model_override() {
        let client = Client::from_api_key("key").model(GeminiModel::Gemini2_0FlashLive);
        assert!(matches!(
            client.default_model(),
            GeminiModel::Gemini2_0FlashLive
        ));
    }

    #[test]
    fn client_rest_url_generate() {
        let client = Client::from_api_key("my-key").model(GeminiModel::Gemini2_0FlashLive);
        let url = client.rest_url(ServiceEndpoint::GenerateContent);
        assert!(url.contains(":generateContent"));
        assert!(url.contains("key=my-key"));
    }

    #[test]
    fn client_rest_url_vertex() {
        let client =
            Client::from_vertex("proj", "us-east1", "tok").model(GeminiModel::Gemini2_0FlashLive);
        let url = client.rest_url(ServiceEndpoint::GenerateContent);
        assert!(url.contains("us-east1-aiplatform.googleapis.com"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn client_rest_url_for_explicit_model() {
        let client = Client::from_vertex("proj", "us-east1", "tok");
        let url = client.rest_url_for(
            ServiceEndpoint::GenerateContent,
            &GeminiModel::Gemini2_0FlashLive,
        );
        assert!(url.contains("us-east1-aiplatform.googleapis.com"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn client_rest_url_matches_rest_url_for_default_model() {
        // `rest_url` must be exactly `rest_url_for` with the client's default model.
        let client = Client::from_api_key("my-key").model(GeminiModel::Gemini2_0FlashLive);
        assert_eq!(
            client.rest_url(ServiceEndpoint::GenerateContent),
            client.rest_url_for(ServiceEndpoint::GenerateContent, client.default_model()),
        );
    }

    #[test]
    fn live_session_builder_created() {
        let client = Client::from_api_key("key");
        let _builder = client.live(GeminiModel::Gemini2_0FlashLive);
    }

    #[tokio::test]
    async fn client_from_vertex_refreshable() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let client = Client::from_vertex_refreshable("proj", "us-central1", move || {
            cc.fetch_add(1, Ordering::SeqCst);
            "refreshed-token".to_string()
        });
        // The constructor calls the refresher once to build the live endpoint.
        assert!(call_count.load(Ordering::SeqCst) >= 1);
        let headers = client.auth_headers().await.unwrap();
        // Producing auth headers goes through the refresher again.
        assert_eq!(headers[0].1, "Bearer refreshed-token");
        assert!(call_count.load(Ordering::SeqCst) >= 2);
    }
}