1#[allow(unused_imports)]
2use std::collections::VecDeque;
3use std::ops::Deref;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7use tonic::body::Body;
8use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
9use tonic::{IntoRequest, RawRequest};
10
11use crate::auth::{Auth, AuthParsed};
12use crate::content::UpdateFieldMask as _;
13use crate::error::{status_into_error, Error, NetError, SetupError, TonicTransportError};
14use crate::full_model_name;
15use crate::proto::model_service_client::ModelServiceClient;
16use crate::proto::{
17 cache_service_client::CacheServiceClient, generative_service_client::GenerativeServiceClient,
18 CachedContent, CreateCachedContentRequest, DeleteCachedContentRequest, GetCachedContentRequest,
19 ListCachedContentsRequest, UpdateCachedContentRequest,
20};
21use crate::proto::{
22 DeleteTunedModelRequest, GetModelRequest, GetTunedModelRequest, ListModelsRequest,
23 ListTunedModelsRequest, Model, TunedModel, UpdateTunedModelRequest,
24};
25
/// Request timeout applied by `Client::new` (two minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Root URL of the Generative Language API; every client connects here.
const BASE_API_URL: &str = "https://generativelanguage.googleapis.com";
/// Page size sent with every list request. `0` is the proto3 default, so the field
/// is effectively unset — presumably letting the server pick its own page size (confirm).
const DEFAULT_PAGE_SIZE: i32 = 0;
/// `User-Agent` header value installed by `Client::new`.
const USER_AGENT: &str = "google-ai-rs/0.1 (Rust)";
34
#[derive(Clone, Debug)]
/// gRPC client for the Generative Language API (`generativelanguage.googleapis.com`).
///
/// Construct it with [`Client::new`] or [`Client::builder`]. All service stubs share
/// the single channel created by [`ClientBuilder::build`].
pub struct Client {
    /// Stub for `GenerativeService`.
    pub(super) gc: GenerativeServiceClient<Channel>,
    /// Stub for `CacheService` (cached contents).
    pub(super) cc: CacheServiceClient<Channel>,
    /// Stub for `ModelService` (models and tuned models).
    pub(super) mc: ModelServiceClient<Channel>,
    /// Credentials shared with the per-request auth modifier installed in
    /// `ClientBuilder::build`; replaced by `update_auth` / `update_auth_fallibly`.
    #[cfg(feature = "auth_update")]
    auth_update: Arc<RwLock<AuthParsed>>,
}
64
#[derive(Clone, Debug)]
/// Reference-counted handle to a [`Client`].
///
/// Cloning only bumps the `Arc` count. The handle dereferences to [`Client`], so every
/// client method is callable on it directly.
pub struct SharedClient {
    /// The shared client.
    inner: Arc<Client>,
}
100
101impl Deref for SharedClient {
102 type Target = Client;
103
104 fn deref(&self) -> &Self::Target {
105 &self.inner
106 }
107}
108
109impl From<Client> for SharedClient {
110 fn from(value: Client) -> Self {
111 SharedClient {
112 inner: Arc::new(value),
113 }
114 }
115}
116
117impl Client {
118 pub async fn new(auth: impl Into<Auth> + Send) -> Result<Self, Error> {
126 ClientBuilder::new()
127 .timeout(DEFAULT_TIMEOUT)
128 .user_agent(USER_AGENT)
129 .unwrap()
130 .build(auth)
131 .await
132 }
133
134 pub fn builder() -> ClientBuilder {
136 ClientBuilder::new()
137 }
138
139 pub fn into_shared(self) -> SharedClient {
144 self.into()
145 }
146
147 #[cfg(feature = "auth_update")]
156 pub async fn update_auth(&self, new_auth: impl Into<Auth> + Send) {
157 self.update_auth_fallibly(new_auth)
158 .await
159 .expect("Auth parsing failed in update_auth — ensure input was valid")
160 }
161
162 #[cfg(feature = "auth_update")]
164 pub async fn update_auth_fallibly(
165 &self,
166 new_auth: impl Into<Auth> + Send,
167 ) -> Result<(), crate::auth::Error> {
168 *self.auth_update.write().await = new_auth.into().parsed()?;
169 Ok(())
170 }
171
172 pub async fn create_cached_content(
180 &self,
181 content: CachedContent,
182 ) -> Result<CachedContent, Error> {
183 if content.name.is_some() {
184 return Err(Error::InvalidArgument(
185 "CachedContent name must be empty for creation".into(),
186 ));
187 }
188
189 let request = CreateCachedContentRequest {
190 cached_content: Some(content),
191 }
192 .into_request();
193
194 self.cc
195 .clone()
196 .create_cached_content(request)
197 .await
198 .map_err(status_into_error)
199 .map(|r| r.into_inner())
200 }
201
202 pub async fn get_cached_content(&self, name: &str) -> Result<CachedContent, Error> {
204 let request = GetCachedContentRequest {
205 name: name.to_owned(),
206 }
207 .into_request();
208
209 self.cc
210 .clone()
211 .get_cached_content(request)
212 .await
213 .map_err(status_into_error)
214 .map(|r| r.into_inner())
215 }
216
217 pub async fn delete_cached_content(&self, name: &str) -> Result<(), Error> {
219 let request = DeleteCachedContentRequest {
220 name: name.to_owned(),
221 }
222 .into_request();
223
224 self.cc
225 .clone()
226 .delete_cached_content(request)
227 .await
228 .map_err(status_into_error)
229 .map(|r| r.into_inner())
230 }
231
232 pub async fn update_cached_content(&self, cc: &CachedContent) -> Result<CachedContent, Error> {
238 let request = UpdateCachedContentRequest {
239 cached_content: Some(cc.to_owned()),
240 update_mask: Some(cc.field_mask()),
241 }
242 .into_request();
243
244 self.cc
245 .clone()
246 .update_cached_content(request)
247 .await
248 .map_err(status_into_error)
249 .map(|r| r.into_inner())
250 }
251
252 pub fn list_cached_contents(&self) -> CachedContentIterator<'_> {
256 PageIterator::<CachedContentPager>::new(self)
257 }
258
259 pub async fn get_model(&self, name: &str) -> Result<Model, Error> {
262 let request = GetModelRequest {
263 name: full_model_name(name).to_string(),
264 }
265 .into_request();
266
267 self.mc
268 .clone()
269 .get_model(request)
270 .await
271 .map_err(status_into_error)
272 .map(|r| r.into_inner())
273 }
274
275 pub async fn get_tuned_model(&self, resource_name: &str) -> Result<TunedModel, Error> {
277 let request = GetTunedModelRequest {
278 name: resource_name.to_owned(),
279 }
280 .into_request();
281
282 self.mc
283 .clone()
284 .get_tuned_model(request)
285 .await
286 .map_err(status_into_error)
287 .map(|r| r.into_inner())
288 }
289
290 pub async fn list_models(&self) -> ModelsListIterator<'_> {
294 PageIterator::<ModelsListPager>::new(self)
295 }
296
297 pub async fn list_tuned_models(&self) -> TunedModelsListIterator<'_> {
301 PageIterator::<TunedModelsListPager>::new(self)
302 }
303
304 pub async fn update_tuned_model(&self, m: &TunedModel) -> Result<TunedModel, Error> {
306 let request = UpdateTunedModelRequest {
307 tuned_model: Some(m.to_owned()),
308 update_mask: Some(m.field_mask()),
309 }
310 .into_request();
311
312 self.mc
313 .clone()
314 .update_tuned_model(request)
315 .await
316 .map_err(status_into_error)
317 .map(|r| r.into_inner())
318 }
319
320 pub async fn delete_tuned_model(&self, name: &str) -> Result<(), Error> {
322 let request = DeleteTunedModelRequest {
323 name: name.to_owned(),
324 }
325 .into_request();
326
327 self.mc
328 .clone()
329 .delete_tuned_model(request)
330 .await
331 .map_err(status_into_error)
332 .map(|r| r.into_inner())
333 }
334}
335
#[derive(Debug, Clone)]
/// Builder for [`Client`]: wraps a tonic `Endpoint` pre-pointed at the API's base URL.
pub struct ClientBuilder {
    /// Endpoint configuration accumulated by the builder methods.
    endpoint: Endpoint,
}
340
341impl Default for ClientBuilder {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
347impl ClientBuilder {
348 pub fn new() -> Self {
350 Self {
351 endpoint: Endpoint::from_static(BASE_API_URL),
352 }
353 }
354
355 pub fn timeout(mut self, duration: Duration) -> Self {
357 self.endpoint = self.endpoint.timeout(duration);
358 self
359 }
360
361 pub fn connect_timeout(mut self, duration: Duration) -> Self {
363 self.endpoint = self.endpoint.connect_timeout(duration);
364 self
365 }
366
367 pub fn user_agent(mut self, ua: impl Into<String>) -> Result<Self, Error> {
369 self.endpoint = self
370 .endpoint
371 .user_agent(ua.into())
372 .map_err(|e| SetupError::new("User-Agent configuration", e))?;
373 Ok(self)
374 }
375
376 pub fn concurrency_limit(mut self, limit: usize) -> Self {
378 self.endpoint = self.endpoint.concurrency_limit(limit);
379 self
380 }
381
382 pub async fn build_shared(self, auth: impl Into<Auth> + Send) -> Result<SharedClient, Error> {
384 self.build(auth).await.map(Into::into)
385 }
386
387 pub async fn build(self, auth: impl Into<Auth> + Send) -> Result<Client, Error> {
396 let endpoint = self
397 .endpoint
398 .tls_config(ClientTlsConfig::new().with_enabled_roots())
399 .map_err(|e| SetupError::new("TLS configuration", e))?;
400
401 let auth = auth.into().parsed()?;
403
404 #[cfg(feature = "auth_update")]
406 let auth = Arc::new(RwLock::new(auth));
407 let auth_update = auth.clone();
408
409 let auth_adder = async move |mut raw_request: RawRequest<Body>| {
412 #[cfg(not(feature = "auth_update"))]
413 let _jwt_fut = auth._into_request(raw_request.headers_mut());
414
415 #[cfg(feature = "auth_update")]
416 let binding = auth.read().await;
417 let _jwt_fut = binding.to_request(raw_request.headers_mut());
418
419 #[cfg(feature = "jwt")]
420 _jwt_fut.await;
421
422 raw_request
423 };
424
425 let channel = unsafe { endpoint.connect_with_modifier_fn(auth_adder) };
426
427 let channel = channel.await.map_err(|e| {
428 Error::Net(NetError::TransportFailure(TonicTransportError(Box::new(e))))
429 })?;
430
431 let client = Client {
432 gc: GenerativeServiceClient::new(channel.clone()),
433 cc: CacheServiceClient::new(channel.clone()),
434 mc: ModelServiceClient::new(channel),
435 #[cfg(feature = "auth_update")]
436 auth_update,
437 };
438
439 Ok(client)
440 }
441}
442
#[derive(Clone, Debug)]
/// Either an owned [`SharedClient`] or a borrowed [`Client`]; dereferences to `Client`.
pub(crate) enum CClient<'a> {
    /// Owns a reference-counted handle.
    Shared(SharedClient),
    /// Borrows a client for `'a`.
    Borrowed(&'a Client),
}
449
450impl CClient<'_> {
451 pub(crate) fn cloned(&self) -> CClient<'_> {
452 match self {
453 CClient::Shared(shared_client) => CClient::Borrowed(shared_client),
454 CClient::Borrowed(client) => CClient::Borrowed(client),
455 }
456 }
457}
458
459#[allow(clippy::from_over_into)]
460impl Into<CClient<'static>> for SharedClient {
461 fn into(self) -> CClient<'static> {
462 CClient::Shared(self)
463 }
464}
465
466#[allow(clippy::from_over_into)]
467impl<'a> Into<CClient<'a>> for &'a Client {
468 fn into(self) -> CClient<'a> {
469 CClient::Borrowed(self)
470 }
471}
472
473impl Deref for CClient<'_> {
474 type Target = Client;
475
476 fn deref(&self) -> &Self::Target {
477 match self {
478 CClient::Shared(shared_client) => &shared_client.inner,
479 CClient::Borrowed(client) => client,
480 }
481 }
482}
483
/// Paginated iterator over cached contents, returned by [`Client::list_cached_contents`].
pub type CachedContentIterator<'a> = PageIterator<'a, CachedContentPager>;

/// Paginated iterator over models, returned by [`Client::list_models`].
pub type ModelsListIterator<'a> = PageIterator<'a, ModelsListPager>;

/// Paginated iterator over tuned models, returned by [`Client::list_tuned_models`].
pub type TunedModelsListIterator<'a> = PageIterator<'a, TunedModelsListPager>;
492
/// Async, page-buffering iterator over a paginated list endpoint.
///
/// Call [`PageIterator::next`] repeatedly. Pages are requested from the server only
/// when the local buffer is empty.
pub struct PageIterator<'a, P>
where
    P: Page + Send,
{
    /// Client used to issue page requests.
    client: &'a Client,
    /// Token of the next page to request: `Some("")` before the first request,
    /// `None` once the server has reported the last page.
    next_page_token: Option<String>,
    /// Items already fetched but not yet returned.
    buffer: VecDeque<P::Content>,
}
505
506impl<'a, P> PageIterator<'a, P>
507where
508 P: Page + Send,
509{
510 fn new(client: &'a Client) -> Self {
511 Self {
512 client,
513 next_page_token: Some(String::new()),
514 buffer: VecDeque::new(),
515 }
516 }
517
518 pub async fn next(&mut self) -> Result<Option<P::Content>, Error> {
522 if self.buffer.is_empty() {
523 if self.next_page_token.is_none() {
524 return Ok(None);
526 }
527
528 let (items, next_token) =
529 P::next(self.client, self.next_page_token.as_ref().unwrap()).await?;
530
531 self.next_page_token = if next_token.is_empty() {
532 None
533 } else {
534 Some(next_token)
535 };
536 self.buffer.extend(items);
537 }
538
539 Ok(self.buffer.pop_front())
540 }
541}
542
/// Fetches a single page of a paginated list endpoint.
///
/// The `sealed::Sealed` supertrait is intended to restrict implementations to this crate.
#[tonic::async_trait]
pub trait Page: sealed::Sealed {
    /// Item type yielded by the list endpoint.
    type Content;
    /// Requests the page identified by `page_token` (empty for the first page).
    ///
    /// Returns the page's items and the token of the following page. An empty token
    /// means there are no further pages.
    async fn next(client: &Client, page_token: &str)
        -> Result<(Vec<Self::Content>, String), Error>;
}
550
551impl<T> sealed::Sealed for T {}
552
553mod sealed {
554 pub trait Sealed {}
555}
556
/// [`Page`] implementation for the `ListCachedContents` endpoint.
pub struct CachedContentPager;
558
559#[tonic::async_trait]
560impl Page for CachedContentPager {
561 type Content = CachedContent;
562
563 async fn next(
564 client: &Client,
565 page_token: &str,
566 ) -> Result<(Vec<Self::Content>, String), Error> {
567 let request = ListCachedContentsRequest {
568 page_size: DEFAULT_PAGE_SIZE,
569 page_token: page_token.to_owned(),
570 }
571 .into_request();
572
573 let response = client
574 .cc
575 .clone()
576 .list_cached_contents(request)
577 .await
578 .map_err(status_into_error)?
579 .into_inner();
580 Ok((response.cached_contents, response.next_page_token))
581 }
582}
583
/// [`Page`] implementation for the `ListModels` endpoint.
pub struct ModelsListPager;
585
586#[tonic::async_trait]
587impl Page for ModelsListPager {
588 type Content = Model;
589
590 async fn next(
591 client: &Client,
592 page_token: &str,
593 ) -> Result<(Vec<Self::Content>, String), Error> {
594 let request = ListModelsRequest {
595 page_size: DEFAULT_PAGE_SIZE,
596 page_token: page_token.to_owned(),
597 }
598 .into_request();
599
600 let response = client
601 .mc
602 .clone()
603 .list_models(request)
604 .await
605 .map_err(status_into_error)?
606 .into_inner();
607 Ok((response.models, response.next_page_token))
608 }
609}
610
/// [`Page`] implementation for the `ListTunedModels` endpoint (no filter applied).
pub struct TunedModelsListPager;
612
613#[tonic::async_trait]
614impl Page for TunedModelsListPager {
615 type Content = TunedModel;
616
617 async fn next(
618 client: &Client,
619 page_token: &str,
620 ) -> Result<(Vec<Self::Content>, String), Error> {
621 let request = ListTunedModelsRequest {
622 page_size: DEFAULT_PAGE_SIZE,
623 page_token: page_token.to_owned(),
624 filter: String::new(),
625 }
626 .into_request();
627
628 let response = client
629 .mc
630 .clone()
631 .list_tuned_models(request)
632 .await
633 .map_err(status_into_error)?
634 .into_inner();
635 Ok((response.tuned_models, response.next_page_token))
636 }
637}