Skip to main content

rs_genai/client/
mod.rs

1//! Unified Gemini API client — wraps both Live (WebSocket) and REST API access.
2//!
3//! The [`Client`] struct provides a single entry point for all Gemini APIs.
4//! REST API modules are feature-gated behind their respective features
5//! (e.g., `generate`, `embed`, `models`) so that live-only users pay zero cost.
6
7#[cfg(feature = "http")]
8pub mod http;
9
10use std::sync::Arc;
11
12use crate::protocol::types::{ApiEndpoint, GeminiModel, SessionConfig};
13use crate::session::SessionError;
14use crate::session::SessionHandle;
15use crate::transport::auth::{
16    AuthProvider, GoogleAIAuth, GoogleAITokenAuth, ServiceEndpoint, VertexAIAuth,
17};
18use crate::transport::{connect, TransportConfig};
19
20/// Unified Gemini API client.
21///
22/// Mirrors the `GoogleGenAI` class from `@google/genai` (js-genai).
23/// Provides access to both Live (WebSocket) and REST APIs through a single
24/// authenticated entry point.
25///
26/// # Construction
27///
28/// ```ignore
29/// // From API key (Google AI)
30/// let client = Client::from_api_key("your-api-key");
31///
32/// // From Vertex AI credentials
33/// let client = Client::from_vertex("project-id", "us-central1", "access-token");
34///
35/// // Live WebSocket session
36/// let session = client.live("gemini-2.5-flash").connect().await?;
37/// ```
38pub struct Client {
39    endpoint: ApiEndpoint,
40    model: GeminiModel,
41    auth: Arc<dyn AuthProvider>,
42    #[cfg(feature = "http")]
43    http: http::HttpClient,
44}
45
46impl Client {
47    /// Create a client with Google AI API key authentication.
48    pub fn from_api_key(api_key: impl Into<String>) -> Self {
49        let key: String = api_key.into();
50        let endpoint = ApiEndpoint::google_ai(key.clone());
51        let auth: Arc<dyn AuthProvider> = Arc::new(GoogleAIAuth::new(key));
52        Self {
53            endpoint,
54            model: GeminiModel::default(),
55            auth,
56            #[cfg(feature = "http")]
57            http: http::HttpClient::new(http::HttpConfig::default()),
58        }
59    }
60
61    /// Create a client with Google AI OAuth2 token authentication.
62    pub fn from_access_token(access_token: impl Into<String>) -> Self {
63        let token: String = access_token.into();
64        let endpoint = ApiEndpoint::google_ai_token(token.clone());
65        let auth: Arc<dyn AuthProvider> = Arc::new(GoogleAITokenAuth::new(token));
66        Self {
67            endpoint,
68            model: GeminiModel::default(),
69            auth,
70            #[cfg(feature = "http")]
71            http: http::HttpClient::new(http::HttpConfig::default()),
72        }
73    }
74
75    /// Create a client with Vertex AI authentication.
76    pub fn from_vertex(
77        project: impl Into<String>,
78        location: impl Into<String>,
79        access_token: impl Into<String>,
80    ) -> Self {
81        let proj: String = project.into();
82        let loc: String = location.into();
83        let tok: String = access_token.into();
84        let endpoint = ApiEndpoint::vertex(proj.clone(), loc.clone(), tok.clone());
85        let auth: Arc<dyn AuthProvider> = Arc::new(VertexAIAuth::new(proj, loc, tok));
86        Self {
87            endpoint,
88            model: GeminiModel::default(),
89            auth,
90            #[cfg(feature = "http")]
91            http: http::HttpClient::new(http::HttpConfig::default()),
92        }
93    }
94
95    /// Create a client with Vertex AI authentication and dynamic token refresh.
96    ///
97    /// The `refresher` closure is called on every REST API request to obtain
98    /// a fresh Bearer token. It should handle caching internally to avoid
99    /// unnecessary overhead (see `GcloudTokenProvider` in rs-adk for an example).
100    ///
101    /// This is the recommended constructor for long-running HTTP clients
102    /// (e.g., extraction LLMs) where tokens may expire during the session.
103    pub fn from_vertex_refreshable(
104        project: impl Into<String>,
105        location: impl Into<String>,
106        refresher: impl Fn() -> String + Send + Sync + 'static,
107    ) -> Self {
108        let proj: String = project.into();
109        let loc: String = location.into();
110        // Get initial token for the ApiEndpoint (used if .live() is called)
111        let initial_token = refresher();
112        let endpoint = ApiEndpoint::vertex(proj.clone(), loc.clone(), initial_token);
113        let auth: Arc<dyn AuthProvider> =
114            Arc::new(VertexAIAuth::with_token_refresher(proj, loc, refresher));
115        Self {
116            endpoint,
117            model: GeminiModel::default(),
118            auth,
119            #[cfg(feature = "http")]
120            http: http::HttpClient::new(http::HttpConfig::default()),
121        }
122    }
123
124    /// Set the default model for all API calls.
125    pub fn model(mut self, model: impl Into<GeminiModel>) -> Self {
126        self.model = model.into();
127        self
128    }
129
130    /// Configure the HTTP client (timeouts, retries, etc.).
131    #[cfg(feature = "http")]
132    pub fn http_config(mut self, config: http::HttpConfig) -> Self {
133        self.http = http::HttpClient::new(config);
134        self
135    }
136
137    /// Get a reference to the underlying auth provider.
138    pub fn auth(&self) -> &dyn AuthProvider {
139        &*self.auth
140    }
141
142    /// Get the default model.
143    pub fn default_model(&self) -> &GeminiModel {
144        &self.model
145    }
146
147    /// Build the REST URL for a given service endpoint, using the default model.
148    pub fn rest_url(&self, endpoint: ServiceEndpoint) -> String {
149        self.auth.rest_url(endpoint, Some(&self.model))
150    }
151
152    /// Build the REST URL for a given service endpoint with a specific model.
153    pub fn rest_url_for(&self, endpoint: ServiceEndpoint, model: &GeminiModel) -> String {
154        self.auth.rest_url(endpoint, Some(model))
155    }
156
157    /// Get auth headers for REST API calls.
158    pub async fn auth_headers(&self) -> Result<Vec<(String, String)>, crate::session::AuthError> {
159        self.auth.auth_headers().await
160    }
161
162    /// Start a Live WebSocket session builder.
163    ///
164    /// Returns a [`LiveSessionBuilder`] that can be customized before connecting.
165    pub fn live(&self, model: GeminiModel) -> LiveSessionBuilder {
166        LiveSessionBuilder {
167            endpoint: self.endpoint.clone(),
168            model,
169            transport_config: TransportConfig::default(),
170            config_fn: None,
171        }
172    }
173
174    /// Get a reference to the HTTP client for making REST API calls.
175    #[cfg(feature = "http")]
176    pub fn http_client(&self) -> &http::HttpClient {
177        &self.http
178    }
179
180    /// Make a raw REST API request (low-level).
181    ///
182    /// Higher-level module methods (e.g., `generate_content()`) should be preferred.
183    #[cfg(feature = "http")]
184    pub async fn rest_request(
185        &self,
186        endpoint: ServiceEndpoint,
187        body: &impl serde::Serialize,
188    ) -> Result<serde_json::Value, http::HttpError> {
189        let url = self.rest_url(endpoint);
190        let headers = self
191            .auth
192            .auth_headers()
193            .await
194            .map_err(|e| http::HttpError::Auth(e.to_string()))?;
195        self.http.post_json(&url, headers, body).await
196    }
197}
198
199/// Builder for Live WebSocket sessions initiated from a [`Client`].
200pub struct LiveSessionBuilder {
201    endpoint: ApiEndpoint,
202    model: GeminiModel,
203    transport_config: TransportConfig,
204    config_fn: Option<Box<dyn FnOnce(SessionConfig) -> SessionConfig>>,
205}
206
207impl LiveSessionBuilder {
208    /// Set transport configuration (timeouts, reconnection, etc.).
209    pub fn transport_config(mut self, config: TransportConfig) -> Self {
210        self.transport_config = config;
211        self
212    }
213
214    /// Apply a customization function to the session config before connecting.
215    pub fn configure(mut self, f: impl FnOnce(SessionConfig) -> SessionConfig + 'static) -> Self {
216        self.config_fn = Some(Box::new(f));
217        self
218    }
219
220    /// Connect and return a [`SessionHandle`].
221    pub async fn connect(self) -> Result<SessionHandle, SessionError> {
222        let mut config = SessionConfig::from_endpoint(self.endpoint).model(self.model);
223
224        if let Some(f) = self.config_fn {
225            config = f(config);
226        }
227
228        connect(config, self.transport_config).await
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn client_from_api_key() {
        let client = Client::from_api_key("test-key");
        // With no explicit model, the client reports `GeminiModel::default()`.
        let uses_default = matches!(
            client.default_model(),
            GeminiModel::GeminiLive2_5FlashNativeAudio
        );
        assert!(uses_default);
    }

    #[test]
    fn client_from_vertex() {
        let client = Client::from_vertex("proj", "us-central1", "tok");
        let ws = client.auth().ws_url(&GeminiModel::default());
        // Vertex traffic goes through the regional aiplatform host.
        assert!(ws.contains("us-central1-aiplatform.googleapis.com"));
    }

    #[test]
    fn client_model_override() {
        let overridden = Client::from_api_key("key").model(GeminiModel::Gemini2_0FlashLive);
        assert!(matches!(
            overridden.default_model(),
            GeminiModel::Gemini2_0FlashLive
        ));
    }

    #[test]
    fn client_rest_url_generate() {
        let url = Client::from_api_key("my-key")
            .model(GeminiModel::Gemini2_0FlashLive)
            .rest_url(ServiceEndpoint::GenerateContent);
        assert!(url.contains(":generateContent"));
        // Google AI API keys travel as a query parameter.
        assert!(url.contains("key=my-key"));
    }

    #[test]
    fn client_rest_url_vertex() {
        let url = Client::from_vertex("proj", "us-east1", "tok")
            .model(GeminiModel::Gemini2_0FlashLive)
            .rest_url(ServiceEndpoint::GenerateContent);
        assert!(url.contains("us-east1-aiplatform.googleapis.com"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn live_session_builder_created() {
        // Creating the builder must not need a network connection.
        let client = Client::from_api_key("key");
        let _builder = client.live(GeminiModel::Gemini2_0FlashLive);
    }

    #[tokio::test]
    async fn client_from_vertex_refreshable() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let invocations = Arc::new(AtomicU32::new(0));
        let client = {
            let counter = Arc::clone(&invocations);
            Client::from_vertex_refreshable("proj", "us-central1", move || {
                counter.fetch_add(1, Ordering::SeqCst);
                "refreshed-token".to_string()
            })
        };
        let seen = || invocations.load(Ordering::SeqCst);

        // The constructor fetches an initial token for the Live endpoint.
        assert!(seen() >= 1);

        // Requesting REST headers goes back to the refresher.
        let headers = client.auth_headers().await.unwrap();
        assert_eq!(headers[0].1, "Bearer refreshed-token");
        assert!(seen() >= 2);
    }
}