google_ai_rs/
client.rs

1#[allow(unused_imports)]
2use std::collections::VecDeque;
3use std::ops::Deref;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7use tonic::body::Body;
8use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
9use tonic::{IntoRequest, RawRequest};
10
11use crate::auth::{Auth, AuthParsed};
12use crate::content::UpdateFieldMask as _;
13use crate::error::{status_into_error, Error, NetError, SetupError, TonicTransportError};
14use crate::full_model_name;
15use crate::proto::model_service_client::ModelServiceClient;
16use crate::proto::{
17    cache_service_client::CacheServiceClient, generative_service_client::GenerativeServiceClient,
18    CachedContent, CreateCachedContentRequest, DeleteCachedContentRequest, GetCachedContentRequest,
19    ListCachedContentsRequest, UpdateCachedContentRequest,
20};
21use crate::proto::{
22    DeleteTunedModelRequest, GetModelRequest, GetTunedModelRequest, ListModelsRequest,
23    ListTunedModelsRequest, Model, TunedModel, UpdateTunedModelRequest,
24};
25
/// Request timeout applied by `Client::new`: two minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Root URL of Google's Generative Language API service.
const BASE_API_URL: &str = "https://generativelanguage.googleapis.com";
/// Page size sent with every list request; `0` lets the server choose the size.
const DEFAULT_PAGE_SIZE: i32 = 0;
/// This crate's user-agent token (intended to be appended to tonic's own).
const USER_AGENT: &str = "google-ai-rs/0.1 (Rust)";
34
/// A thread-safe client for interacting with Google's Generative Language API.
///
/// # Features
/// - Manages authentication tokens and TLS configuration
/// - Provides access to generative AI operations
/// - Implements content caching functionality
/// - Supports automatic pagination of list operations
///
/// # Example
/// ```
/// use google_ai_rs::Client;
///
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// let client = Client::new("your-api-key").await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct Client {
    /// Generative service gRPC client
    pub(super) gc: GenerativeServiceClient<Channel>,
    /// Cache service gRPC client
    pub(super) cc: CacheServiceClient<Channel>,
    /// Model service gRPC client (models and tuned models)
    pub(super) mc: ModelServiceClient<Channel>,
    /// Shared handle to the parsed credentials. The channel's request modifier
    /// reads them for every request, and `update_auth_fallibly` replaces them.
    /// Only present when the `auth_update` feature is enabled.
    #[cfg(feature = "auth_update")]
    auth_update: Arc<RwLock<AuthParsed>>,
}
64
/// A thread-safe, cheaply clonable client for interacting with the Generative Language API.
///
/// This client wraps a standard [`Client`] in an [`Arc`], making it easy to share
/// across threads without lifetime issues. Methods on the regular `Client` borrow
/// it (`&'c self`), so the models they return are tied to that borrow. Methods on
/// `SharedClient` instead return models with a `'static` lifetime, which allows
/// them to be moved and stored independently of the client.
///
/// Use `SharedClient` when you need to pass the client to different threads,
/// store it in global state, or when the client is intended to live for the
/// duration of the application.
///
/// # Example
/// ```
/// use google_ai_rs::{Client, SharedClient};
///
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// let client = Client::new("your-api-key").await?;
/// let shared_client: SharedClient = client.into_shared();
///
/// let model = shared_client.generative_model("models/gemini-pro");
/// // The model can now be used in a different thread or stored.
///
/// drop(shared_client);
///
/// // You can still use model
/// model.generate_content("Hello, AI").await?;
///
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct SharedClient {
    /// The wrapped client. Cloning a `SharedClient` clones this `Arc`, so it
    /// only bumps the reference count.
    inner: Arc<Client>,
}
100
101impl Deref for SharedClient {
102    type Target = Client;
103
104    fn deref(&self) -> &Self::Target {
105        &self.inner
106    }
107}
108
109impl From<Client> for SharedClient {
110    fn from(value: Client) -> Self {
111        SharedClient {
112            inner: Arc::new(value),
113        }
114    }
115}
116
117impl Client {
118    /// Constructs a new client with authentication and optional configuration.
119    ///
120    /// # Arguments
121    /// * `auth` - API authentication credentials
122    ///
123    /// # Errors
124    /// Returns [`Error::Setup`] for configuration issues or [`Error::Net`] for connection failures.
125    pub async fn new(auth: impl Into<Auth> + Send) -> Result<Self, Error> {
126        ClientBuilder::new()
127            .timeout(DEFAULT_TIMEOUT)
128            .user_agent(USER_AGENT)
129            .unwrap()
130            .build(auth)
131            .await
132    }
133
134    /// Create a builder for configuring client options
135    pub fn builder() -> ClientBuilder {
136        ClientBuilder::new()
137    }
138
139    /// Converts the `Client` into a `SharedClient`.
140    ///
141    /// This moves the `Client` into an `Arc`, making it suitable for
142    /// multithreaded applications or long-lived static contexts.
143    pub fn into_shared(self) -> SharedClient {
144        self.into()
145    }
146
147    /// Updates authentication credentials atomically
148    ///
149    /// Subsequent requests will use the new credentials immediately. This operation
150    /// is thread-safe.
151    ///
152    /// # Panics
153    ///
154    /// May panic if auth cannot parsed
155    #[cfg(feature = "auth_update")]
156    pub async fn update_auth(&self, new_auth: impl Into<Auth> + Send) {
157        self.update_auth_fallibly(new_auth)
158            .await
159            .expect("Auth parsing failed in update_auth — ensure input was valid")
160    }
161
162    /// Fallible [`Self::update_auth`].
163    #[cfg(feature = "auth_update")]
164    pub async fn update_auth_fallibly(
165        &self,
166        new_auth: impl Into<Auth> + Send,
167    ) -> Result<(), crate::auth::Error> {
168        *self.auth_update.write().await = new_auth.into().parsed()?;
169        Ok(())
170    }
171
172    /// Creates a new cached content entry
173    ///
174    /// # Arguments
175    /// * `content` - Content to cache without name (server-generated)
176    ///
177    /// # Errors
178    /// Returns [`Error::InvalidArgument`] if content contains a name
179    pub async fn create_cached_content(
180        &self,
181        content: CachedContent,
182    ) -> Result<CachedContent, Error> {
183        if content.name.is_some() {
184            return Err(Error::InvalidArgument(
185                "CachedContent name must be empty for creation".into(),
186            ));
187        }
188
189        let request = CreateCachedContentRequest {
190            cached_content: Some(content),
191        }
192        .into_request();
193
194        self.cc
195            .clone()
196            .create_cached_content(request)
197            .await
198            .map_err(status_into_error)
199            .map(|r| r.into_inner())
200    }
201
202    /// Retrieves the `CachedContent` with the given name.
203    pub async fn get_cached_content(&self, name: &str) -> Result<CachedContent, Error> {
204        let request = GetCachedContentRequest {
205            name: name.to_owned(),
206        }
207        .into_request();
208
209        self.cc
210            .clone()
211            .get_cached_content(request)
212            .await
213            .map_err(status_into_error)
214            .map(|r| r.into_inner())
215    }
216
217    /// Deletes the `CachedContent` with the given name.
218    pub async fn delete_cached_content(&self, name: &str) -> Result<(), Error> {
219        let request = DeleteCachedContentRequest {
220            name: name.to_owned(),
221        }
222        .into_request();
223
224        self.cc
225            .clone()
226            .delete_cached_content(request)
227            .await
228            .map_err(status_into_error)
229            .map(|r| r.into_inner())
230    }
231
232    /// Modifies the `CachedContent`.
233    ///
234    /// It returns the modified CachedContent.
235    ///
236    /// The argument CachedContent must have its name field and fields to update populated.
237    pub async fn update_cached_content(&self, cc: &CachedContent) -> Result<CachedContent, Error> {
238        let request = UpdateCachedContentRequest {
239            cached_content: Some(cc.to_owned()),
240            update_mask: Some(cc.field_mask()),
241        }
242        .into_request();
243
244        self.cc
245            .clone()
246            .update_cached_content(request)
247            .await
248            .map_err(status_into_error)
249            .map(|r| r.into_inner())
250    }
251
252    /// Returns an async iterator over cached content entries
253    ///
254    /// Automatically handles pagination through server-side results.
255    pub fn list_cached_contents(&self) -> CachedContentIterator<'_> {
256        PageIterator::<CachedContentPager>::new(self)
257    }
258
259    /// Gets information about a specific `Model` such as its version number, token
260    /// limits, etc
261    pub async fn get_model(&self, name: &str) -> Result<Model, Error> {
262        let request = GetModelRequest {
263            name: full_model_name(name).to_string(),
264        }
265        .into_request();
266
267        self.mc
268            .clone()
269            .get_model(request)
270            .await
271            .map_err(status_into_error)
272            .map(|r| r.into_inner())
273    }
274
275    /// Gets information about a specific `TunedModel`.
276    pub async fn get_tuned_model(&self, resource_name: &str) -> Result<TunedModel, Error> {
277        let request = GetTunedModelRequest {
278            name: resource_name.to_owned(),
279        }
280        .into_request();
281
282        self.mc
283            .clone()
284            .get_tuned_model(request)
285            .await
286            .map_err(status_into_error)
287            .map(|r| r.into_inner())
288    }
289
290    /// Returns an async iterator over models list results
291    ///
292    /// Automatically handles pagination through server-side results.
293    pub async fn list_models(&self) -> ModelsListIterator<'_> {
294        PageIterator::<ModelsListPager>::new(self)
295    }
296
297    /// Returns an async iterator over tuned models list results
298    ///
299    /// Automatically handles pagination through server-side results.
300    pub async fn list_tuned_models(&self) -> TunedModelsListIterator<'_> {
301        PageIterator::<TunedModelsListPager>::new(self)
302    }
303
304    /// Updates a tuned model.
305    pub async fn update_tuned_model(&self, m: &TunedModel) -> Result<TunedModel, Error> {
306        let request = UpdateTunedModelRequest {
307            tuned_model: Some(m.to_owned()),
308            update_mask: Some(m.field_mask()),
309        }
310        .into_request();
311
312        self.mc
313            .clone()
314            .update_tuned_model(request)
315            .await
316            .map_err(status_into_error)
317            .map(|r| r.into_inner())
318    }
319
320    /// Deletes the `TunedModel` with the given name.
321    pub async fn delete_tuned_model(&self, name: &str) -> Result<(), Error> {
322        let request = DeleteTunedModelRequest {
323            name: name.to_owned(),
324        }
325        .into_request();
326
327        self.mc
328            .clone()
329            .delete_tuned_model(request)
330            .await
331            .map_err(status_into_error)
332            .map(|r| r.into_inner())
333    }
334}
335
/// Builder for configuring and constructing a [`Client`] or [`SharedClient`].
///
/// Wraps a tonic [`Endpoint`] pointed at the Generative Language API. The
/// configuration methods forward to the corresponding `Endpoint` settings.
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Endpoint configuration accumulated so far.
    endpoint: Endpoint,
}
340
341impl Default for ClientBuilder {
342    fn default() -> Self {
343        Self::new()
344    }
345}
346
347impl ClientBuilder {
348    /// Creates new builder with required authentication
349    pub fn new() -> Self {
350        Self {
351            endpoint: Endpoint::from_static(BASE_API_URL),
352        }
353    }
354
355    /// Sets overall request timeout (default: 120s)
356    pub fn timeout(mut self, duration: Duration) -> Self {
357        self.endpoint = self.endpoint.timeout(duration);
358        self
359    }
360
361    /// Set connection establishment timeout
362    pub fn connect_timeout(mut self, duration: Duration) -> Self {
363        self.endpoint = self.endpoint.connect_timeout(duration);
364        self
365    }
366
367    /// Set custom user agent string
368    pub fn user_agent(mut self, ua: impl Into<String>) -> Result<Self, Error> {
369        self.endpoint = self
370            .endpoint
371            .user_agent(ua.into())
372            .map_err(|e| SetupError::new("User-Agent configuration", e))?;
373        Ok(self)
374    }
375
376    /// Set maximum concurrent requests per connection
377    pub fn concurrency_limit(mut self, limit: usize) -> Self {
378        self.endpoint = self.endpoint.concurrency_limit(limit);
379        self
380    }
381
382    /// Finalizes configuration and constructs a [`SharedClient`]
383    pub async fn build_shared(self, auth: impl Into<Auth> + Send) -> Result<SharedClient, Error> {
384        self.build(auth).await.map(Into::into)
385    }
386
387    /// Finalizes configuration and constructs client
388    ///
389    /// # Arguments
390    /// * `auth` - Authentication credentials (API key or service account)
391    ///
392    /// # Errors
393    /// - Returns [`Error::Setup`] for invalid configurations
394    /// - Returns [`Error::Net`] for connection failures  
395    pub async fn build(self, auth: impl Into<Auth> + Send) -> Result<Client, Error> {
396        let endpoint = self
397            .endpoint
398            .tls_config(ClientTlsConfig::new().with_enabled_roots())
399            .map_err(|e| SetupError::new("TLS configuration", e))?;
400
401        // We make sure to parse to avoid 'after init' error
402        let auth = auth.into().parsed()?;
403
404        // We need exclusive access when we may need to update
405        #[cfg(feature = "auth_update")]
406        let auth = Arc::new(RwLock::new(auth));
407        let auth_update = auth.clone();
408
409        // This is done to reduce client size and eliminate calls to add_auth
410        // in library methods.
411        let auth_adder = async move |mut raw_request: RawRequest<Body>| {
412            #[cfg(not(feature = "auth_update"))]
413            let _jwt_fut = auth._into_request(raw_request.headers_mut());
414
415            #[cfg(feature = "auth_update")]
416            let binding = auth.read().await;
417            let _jwt_fut = binding.to_request(raw_request.headers_mut());
418
419            #[cfg(feature = "jwt")]
420            _jwt_fut.await;
421
422            raw_request
423        };
424
425        let channel = unsafe { endpoint.connect_with_modifier_fn(auth_adder) };
426
427        let channel = channel.await.map_err(|e| {
428            Error::Net(NetError::TransportFailure(TonicTransportError(Box::new(e))))
429        })?;
430
431        let client = Client {
432            gc: GenerativeServiceClient::new(channel.clone()),
433            cc: CacheServiceClient::new(channel.clone()),
434            mc: ModelServiceClient::new(channel),
435            #[cfg(feature = "auth_update")]
436            auth_update,
437        };
438
439        Ok(client)
440    }
441}
442
/// A client handle that either owns a [`SharedClient`] or borrows a [`Client`].
///
/// Similar in spirit to `std::borrow::Cow`: `Shared` carries no borrow (usable
/// as `CClient<'static>`), while `Borrowed` is tied to the lifetime of the
/// `&Client` it holds. Both variants deref to [`Client`].
#[derive(Clone, Debug)]
pub(crate) enum CClient<'a> {
    Shared(SharedClient),
    Borrowed(&'a Client),
}
449
450impl CClient<'_> {
451    pub(crate) fn cloned(&self) -> CClient<'_> {
452        match self {
453            CClient::Shared(shared_client) => CClient::Borrowed(shared_client),
454            CClient::Borrowed(client) => CClient::Borrowed(client),
455        }
456    }
457}
458
459#[allow(clippy::from_over_into)]
460impl Into<CClient<'static>> for SharedClient {
461    fn into(self) -> CClient<'static> {
462        CClient::Shared(self)
463    }
464}
465
466#[allow(clippy::from_over_into)]
467impl<'a> Into<CClient<'a>> for &'a Client {
468    fn into(self) -> CClient<'a> {
469        CClient::Borrowed(self)
470    }
471}
472
473impl Deref for CClient<'_> {
474    type Target = Client;
475
476    fn deref(&self) -> &Self::Target {
477        match self {
478            CClient::Shared(shared_client) => &shared_client.inner,
479            CClient::Borrowed(client) => client,
480        }
481    }
482}
483
/// Async iterator over cached content entries, returned by
/// [`Client::list_cached_contents`].
pub type CachedContentIterator<'a> = PageIterator<'a, CachedContentPager>;

/// Async iterator over models, returned by [`Client::list_models`].
pub type ModelsListIterator<'a> = PageIterator<'a, ModelsListPager>;

/// Async iterator over tuned models, returned by [`Client::list_tuned_models`].
pub type TunedModelsListIterator<'a> = PageIterator<'a, TunedModelsListPager>;
492
/// Async iterator over the items of a paginated list operation.
///
/// Buffers the items of each fetched page and hands them out one at a time,
/// fetching the next page when the buffer runs dry. `P` selects which list
/// endpoint is paged (see [`Page`]).
pub struct PageIterator<'a, P>
where
    P: Page + Send,
{
    /// Client used to issue the list requests.
    client: &'a Client,
    /// Token of the next page to request. It starts as `Some("")` so that the
    /// first page is requested, and becomes `None` once the server returns an
    /// empty continuation token.
    next_page_token: Option<String>,
    /// Items fetched but not yet yielded, in the order the server returned them.
    buffer: VecDeque<P::Content>,
}
505
506impl<'a, P> PageIterator<'a, P>
507where
508    P: Page + Send,
509{
510    fn new(client: &'a Client) -> Self {
511        Self {
512            client,
513            next_page_token: Some(String::new()),
514            buffer: VecDeque::new(),
515        }
516    }
517
518    /// Returns the next content item
519    ///
520    /// Returns `Ok(None)` when all items have been exhausted.
521    pub async fn next(&mut self) -> Result<Option<P::Content>, Error> {
522        if self.buffer.is_empty() {
523            if self.next_page_token.is_none() {
524                // We've already fetched all pages
525                return Ok(None);
526            }
527
528            let (items, next_token) =
529                P::next(self.client, self.next_page_token.as_ref().unwrap()).await?;
530
531            self.next_page_token = if next_token.is_empty() {
532                None
533            } else {
534                Some(next_token)
535            };
536            self.buffer.extend(items);
537        }
538
539        Ok(self.buffer.pop_front())
540    }
541}
542
543#[tonic::async_trait]
544pub trait Page: sealed::Sealed {
545    type Content;
546    /// Fetches the next page of results
547    async fn next(client: &Client, page_token: &str)
548        -> Result<(Vec<Self::Content>, String), Error>;
549}
550
551impl<T> sealed::Sealed for T {}
552
553mod sealed {
554    pub trait Sealed {}
555}
556
557pub struct CachedContentPager;
558
559#[tonic::async_trait]
560impl Page for CachedContentPager {
561    type Content = CachedContent;
562
563    async fn next(
564        client: &Client,
565        page_token: &str,
566    ) -> Result<(Vec<Self::Content>, String), Error> {
567        let request = ListCachedContentsRequest {
568            page_size: DEFAULT_PAGE_SIZE,
569            page_token: page_token.to_owned(),
570        }
571        .into_request();
572
573        let response = client
574            .cc
575            .clone()
576            .list_cached_contents(request)
577            .await
578            .map_err(status_into_error)?
579            .into_inner();
580        Ok((response.cached_contents, response.next_page_token))
581    }
582}
583
584pub struct ModelsListPager;
585
586#[tonic::async_trait]
587impl Page for ModelsListPager {
588    type Content = Model;
589
590    async fn next(
591        client: &Client,
592        page_token: &str,
593    ) -> Result<(Vec<Self::Content>, String), Error> {
594        let request = ListModelsRequest {
595            page_size: DEFAULT_PAGE_SIZE,
596            page_token: page_token.to_owned(),
597        }
598        .into_request();
599
600        let response = client
601            .mc
602            .clone()
603            .list_models(request)
604            .await
605            .map_err(status_into_error)?
606            .into_inner();
607        Ok((response.models, response.next_page_token))
608    }
609}
610
611pub struct TunedModelsListPager;
612
613#[tonic::async_trait]
614impl Page for TunedModelsListPager {
615    type Content = TunedModel;
616
617    async fn next(
618        client: &Client,
619        page_token: &str,
620    ) -> Result<(Vec<Self::Content>, String), Error> {
621        let request = ListTunedModelsRequest {
622            page_size: DEFAULT_PAGE_SIZE,
623            page_token: page_token.to_owned(),
624            filter: String::new(),
625        }
626        .into_request();
627
628        let response = client
629            .mc
630            .clone()
631            .list_tuned_models(request)
632            .await
633            .map_err(status_into_error)?
634            .into_inner();
635        Ok((response.tuned_models, response.next_page_token))
636    }
637}