// google_ai_rs/genai.rs

use std::{
    fmt::Debug,
    io::Write,
    ops::{Deref, DerefMut},
};
use tokio::io::AsyncWrite;
use tonic::Streaming;

use crate::{
    client::{CClient, Client, SharedClient},
    content::{IntoContent, TryFromCandidates, TryIntoContents},
    error::{status_into_error, ActionError, Error},
    full_model_name,
    schema::AsSchema,
};

pub use crate::proto::{
    safety_setting::HarmBlockThreshold, CachedContent, Content, CountTokensRequest,
    CountTokensResponse, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
    HarmCategory, Model, SafetySetting, Schema, Tool, ToolConfig, TunedModel,
};

/// Type-safe wrapper for [`GenerativeModel`] guaranteeing response type `T`.
///
/// This type enforces schema contracts through Rust's type system while maintaining
/// compatibility with Google's Generative AI API. Use it when:
/// - You need structured output from the model
/// - Response schema stability is critical
/// - You want compile-time validation of response handling
///
/// # Example
/// ```
/// use google_ai_rs::{Client, GenerativeModel, AsSchema};
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// # let auth = "YOUR-API-KEY";
///
/// #[derive(AsSchema)]
/// struct Recipe {
///     name: String,
///     ingredients: Vec<String>,
/// }
///
/// let client = Client::new(auth).await?;
/// let model = client.typed_model::<Recipe>("gemini-pro");
/// # Ok(())
/// # }
/// ```
pub struct TypedModel<'c, T> {
    inner: GenerativeModel<'c>,
    _marker: PhantomInvariant<T>,
}

impl<T> Debug for TypedModel<'_, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.inner.fmt(f)
    }
}

impl<T> Clone for TypedModel<'_, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }
}

// Invariant marker for `T`: `fn(T) -> T` makes `T` invariant.
// (std's own invariant phantom types are still unstable.)
struct PhantomInvariant<T>(std::marker::PhantomData<fn(T) -> T>);
impl<'c, T> TypedModel<'c, T>
where
    T: AsSchema,
{
    /// Creates a new typed model configured to return responses of type `T`.
    ///
    /// # Arguments
    /// - `client`: Authenticated API client.
    /// - `name`: Model name (e.g., "gemini-pro").
    pub fn new(client: &'c Client, name: &str) -> Self {
        let inner = GenerativeModel::new(client, name).as_response_schema::<T>();
        Self {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }

    fn new_inner(client: impl Into<CClient<'c>>, name: &str) -> Self {
        let inner = GenerativeModel::new_inner(client, name).as_response_schema::<T>();
        Self {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }

    /// Generates content with full response metadata.
    ///
    /// This method clones the model configuration and returns a `TypedResponse`
    /// containing both the parsed `T` and the raw API response.
    ///
    /// # Example
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel, TypedResponse};
    /// # #[derive(AsSchema, serde::Deserialize, Debug)] struct StockAnalysis;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = TypedModel::<StockAnalysis>::new(&client, "gemini-pro");
    /// let analysis: TypedResponse<StockAnalysis> = model.generate_typed_content((
    ///     "Analyze NVDA stock performance",
    ///     "Consider PE ratio and recent earnings"
    /// )).await?;
    /// println!("Analysis: {:?}", analysis.t);
    /// println!("Token Usage: {:?}", analysis.raw.usage_metadata);
    /// # Ok(()) }
    /// ```
    #[inline]
    pub async fn generate_typed_content<I>(&self, contents: I) -> Result<TypedResponse<T>, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        self.cloned()
            .generate_typed_content_consuming(contents)
            .await
    }

    /// Generates content with metadata, consuming the model instance.
    ///
    /// An efficient alternative to `generate_typed_content` that avoids cloning
    /// the model configuration; useful for one-shot requests.
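    ///
    /// # Example
    ///
    /// A minimal sketch (the `Summary` type and API key are placeholders):
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel};
    /// # #[derive(AsSchema, serde::Deserialize)] struct Summary;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// // Build, use once, and discard; no configuration clone.
    /// let summary = TypedModel::<Summary>::new(&client, "gemini-pro")
    ///     .generate_typed_content_consuming("Summarize the borrow checker")
    ///     .await?;
    /// # Ok(()) }
    /// ```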
    #[inline]
    pub async fn generate_typed_content_consuming<I>(
        self,
        contents: I,
    ) -> Result<TypedResponse<T>, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        let response = self.inner.generate_content_consuming(contents).await?;
        let t = T::try_from_candidates(&response.candidates)?;
        Ok(TypedResponse { t, raw: response })
    }

    /// Generates content and parses it directly into type `T`.
    ///
    /// This is the primary method for most users wanting type-safe responses.
    /// It handles all the details of requesting structured JSON and deserializing
    /// it into your specified Rust type. It clones the model configuration to allow reuse.
    ///
    /// # Serde Integration
    /// When the `serde` feature is enabled, any type implementing `serde::Deserialize`
    /// automatically works with this method. Just define your response structure and
    /// let the library handle parsing.
    ///
    /// # Example: Simple JSON Response
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel};
    /// # use serde::Deserialize;
    /// #[derive(AsSchema, Deserialize)]
    /// struct StoryResponse {
    ///     title: String,
    ///     length: usize,
    ///     tags: Vec<String>,
    /// }
    ///
    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("key".into()).await?;
    /// let model = TypedModel::<StoryResponse>::new(&client, "gemini-pro");
    /// let story = model.generate_content("Write a short story about a robot astronaut").await?;
    ///
    /// println!("{} ({} words)", story.title, story.length);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Example: Multi-part Input
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel, Part};
    /// # use serde::Deserialize;
    /// #[derive(AsSchema, Deserialize)]
    /// struct Analysis { safety_rating: u8 }
    ///
    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("key".into()).await?;
    /// # let image_data = vec![];
    /// let model = TypedModel::<Analysis>::new(&client, "gemini-pro-vision");
    /// let analysis = model.generate_content((
    ///     "Analyze this scene's safety:",
    ///     Part::blob("image/jpeg", image_data),
    ///     "Consider vehicles, pedestrians, and weather"
    /// )).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Returns an error if API communication fails or if the response cannot be
    /// parsed into type `T`.
    #[inline]
    pub async fn generate_content<I>(&self, contents: I) -> Result<T, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        self.cloned().generate_content_consuming(contents).await
    }

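    /// Generates parsed content of type `T`, consuming the model instance.
    ///
    /// Like [`Self::generate_content`], but avoids cloning the model
    /// configuration; useful for one-shot requests.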
    #[inline]
    pub async fn generate_content_consuming<I>(self, contents: I) -> Result<T, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        let response = self.inner.generate_content_consuming(contents).await?;
        let t = T::try_from_candidates(&response.candidates)?;
        Ok(t)
    }

    /// Consumes the `TypedModel`, returning the underlying `GenerativeModel`.
    ///
    /// The returned `GenerativeModel` retains the response schema configuration
    /// that was set for type `T`.
    pub fn into_inner(self) -> GenerativeModel<'c> {
        self.inner
    }

    /// Creates a `TypedModel` from a `GenerativeModel` without validation.
    ///
    /// This is an advanced-use method that assumes the provided `GenerativeModel`
    /// has already been configured with a response schema compatible with `T`.
    ///
    /// # Safety
    /// The caller must ensure that `inner.generation_config.response_schema` is `Some`
    /// and that its schema corresponds exactly to the schema of type `T`. Failure to
    /// uphold this invariant will likely result in API errors or deserialization failures.
    pub unsafe fn from_inner_unchecked(inner: GenerativeModel<'c>) -> Self {
        Self {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }

    fn cloned(&self) -> TypedModel<'_, T> {
        TypedModel {
            inner: self.inner.cloned(),
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }
}

impl<'c, T> Deref for TypedModel<'c, T> {
    type Target = GenerativeModel<'c>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<'c, T> From<GenerativeModel<'c>> for TypedModel<'c, T>
where
    T: AsSchema,
{
    fn from(value: GenerativeModel<'c>) -> Self {
        let inner = value.as_response_schema::<T>();
        TypedModel {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }
}

/// Container for typed responses with raw API data.
///
/// Preserves full response details while providing parsed content.
pub struct TypedResponse<T> {
    /// Parsed content of type `T`
    pub t: T,
    /// Raw API response structure
    pub raw: GenerateContentResponse,
}

impl<T> Debug for TypedResponse<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.t.fmt(f)
    }
}

impl<T> Deref for TypedResponse<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.t
    }
}

impl<T> DerefMut for TypedResponse<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.t
    }
}

/// Configured interface for a specific generative AI model
///
/// # Example
/// ```
/// use google_ai_rs::{Client, GenerativeModel};
///
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// # let auth = "YOUR-API-KEY";
/// let client = Client::new(auth).await?;
/// let model = client.generative_model("gemini-pro")
///     .with_system_instruction("You are a helpful assistant")
///     .with_response_format("application/json");
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct GenerativeModel<'c> {
    /// Backing API client
    pub(super) client: CClient<'c>,
    /// Fully qualified model name (e.g., "models/gemini-1.0-pro")
    model_name: Box<str>,
    /// System prompt guiding model behavior
    pub system_instruction: Option<Content>,
    /// Available functions/tools the model can use
    pub tools: Option<Vec<Tool>>,
    /// Configuration for tool usage
    pub tool_config: Option<ToolConfig>,
    /// Content safety filters
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// Generation parameters (temperature, top-k, etc.)
    pub generation_config: Option<GenerationConfig>,
    /// Full name of the cached content to use as context
    /// (e.g., "cachedContents/NAME")
    pub cached_content: Option<Box<str>>,
}

impl<'c> GenerativeModel<'c> {
    /// Creates a new model interface with default configuration
    ///
    /// # Arguments
    /// * `client` - Authenticated API client
    /// * `name` - Model identifier (e.g., "gemini-pro")
    ///
    /// To access a tuned model named NAME, pass "tunedModels/NAME".
    pub fn new(client: &'c Client, name: &str) -> Self {
        Self::new_inner(client, name)
    }

    fn new_inner(client: impl Into<CClient<'c>>, name: &str) -> Self {
        Self {
            client: client.into(),
            model_name: full_model_name(name).into(),
            system_instruction: None,
            tools: None,
            tool_config: None,
            safety_settings: None,
            generation_config: None,
            cached_content: None,
        }
    }

    /// Converts this `GenerativeModel` into a `TypedModel`.
    ///
    /// This prepares the model to return responses that are automatically
    /// parsed into the specified type `T`.
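    ///
    /// # Example
    ///
    /// A minimal sketch (`Answer` is a placeholder type):
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client};
    /// # use serde::Deserialize;
    /// # #[derive(AsSchema, Deserialize)] struct Answer { text: String }
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let typed = client.generative_model("gemini-pro").to_typed::<Answer>();
    /// let answer = typed.generate_content("What is ownership in Rust?").await?;
    /// # Ok(()) }
    /// ```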
    pub fn to_typed<T: AsSchema>(self) -> TypedModel<'c, T> {
        self.into()
    }

    /// Generates content from flexible input types.
    ///
    /// This method clones the model's configuration for the request, allowing the original
    /// `GenerativeModel` instance to be reused.
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// use google_ai_rs::Part;
    ///
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// # let model = client.generative_model("gemini-pro");
    /// // Simple text generation
    /// let response = model.generate_content("Hello world!").await?;
    ///
    /// // Multi-part content
    /// # let image_data = vec![];
    /// model.generate_content((
    ///     "What's in this image?",
    ///     Part::blob("image/jpeg", image_data)
    /// )).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Returns [`Error::Service`] for model errors or [`Error::Net`] for transport failures.
    pub async fn generate_content<T>(&self, contents: T) -> Result<GenerateContentResponse, Error>
    where
        T: TryIntoContents,
    {
        self.cloned().generate_content_consuming(contents).await
    }

    /// Generates content by consuming the model instance.
    ///
    /// This is an efficient alternative to `generate_content` if you don't need to reuse the
    /// model instance, as it avoids cloning the model's configuration. This is useful
    /// for one-shot requests where the model is built, used, and then discarded.
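    ///
    /// # Example
    ///
    /// A minimal sketch:
    /// ```rust,ignore
    /// # use google_ai_rs::Client;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let response = client
    ///     .generative_model("gemini-pro")
    ///     .generate_content_consuming("Hello world!")
    ///     .await?;
    /// # Ok(()) }
    /// ```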
    pub async fn generate_content_consuming<T>(
        self,
        contents: T,
    ) -> Result<GenerateContentResponse, Error>
    where
        T: TryIntoContents,
    {
        let mut gc = self.client.gc.clone();
        let request = self.build_request(contents)?;
        gc.generate_content(request)
            .await
            .map_err(status_into_error)
            .map(|r| r.into_inner())
    }

    /// A convenience method to generate a structured response of type `T`.
    ///
    /// This method internally converts the `GenerativeModel` to a `TypedModel<T>`,
    /// makes the request, and returns the parsed result directly. It clones the model
    /// configuration for the request.
    ///
    /// For repeated calls with the same target type, it may be more efficient to create a
    /// `TypedModel` instance directly.
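    ///
    /// # Example
    ///
    /// A minimal sketch (`Fact` is a placeholder type):
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client};
    /// # use serde::Deserialize;
    /// # #[derive(AsSchema, Deserialize)] struct Fact { text: String }
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = client.generative_model("gemini-pro");
    /// let fact: Fact = model.typed_generate_content("Share a Rust fact").await?;
    /// # Ok(()) }
    /// ```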
    pub async fn typed_generate_content<I, T>(&self, contents: I) -> Result<T, Error>
    where
        I: TryIntoContents + Send,
        T: AsSchema + TryFromCandidates + Send,
    {
        // Clone once, then let the consuming call take ownership.
        self.cloned()
            .to_typed()
            .generate_content_consuming(contents)
            .await
    }

    /// A convenience method to generate a structured response with metadata.
    ///
    /// Similar to `typed_generate_content`, but returns a `TypedResponse<T>` which includes
    /// both the parsed data and the raw API response metadata.
    pub async fn generate_typed_content<I, T>(&self, contents: I) -> Result<TypedResponse<T>, Error>
    where
        I: TryIntoContents + Send,
        T: AsSchema + TryFromCandidates + Send,
    {
        self.cloned()
            .to_typed()
            .generate_typed_content_consuming(contents)
            .await
    }

    /// Generates a streaming response from flexible input.
    ///
    /// This method clones the model's configuration for the request, allowing the original
    /// `GenerativeModel` instance to be reused.
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// # let model = client.generative_model("gemini-pro");
    /// let mut stream = model.stream_generate_content("Tell me a story.").await?;
    /// while let Some(chunk) = stream.next().await? {
    ///     // Process streaming response
    /// }
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Returns [`Error::Service`] for model errors or [`Error::Net`] for transport failures.
    pub async fn stream_generate_content<T>(&self, contents: T) -> Result<ResponseStream, Error>
    where
        T: TryIntoContents,
    {
        self.cloned()
            .stream_generate_content_consuming(contents)
            .await
    }

    /// Generates a streaming response by consuming the model instance.
    ///
    /// This is an efficient alternative to `stream_generate_content` if you don't need to
    /// reuse the model instance, as it avoids cloning the model's configuration.
    pub async fn stream_generate_content_consuming<T>(
        self,
        contents: T,
    ) -> Result<ResponseStream, Error>
    where
        T: TryIntoContents,
    {
        let mut gc = self.client.gc.clone();
        let request = self.build_request(contents)?;
        gc.stream_generate_content(request)
            .await
            .map_err(status_into_error)
            .map(|s| ResponseStream(s.into_inner()))
    }

    /// Estimates token usage for the given content
    ///
    /// Useful for cost estimation and validation before full generation
    ///
    /// # Arguments
    /// * `contents` - Content input that can be converted to parts
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// # let model = client.generative_model("gemini-pro");
    /// # let content = "";
    /// let token_count = model.count_tokens(content).await?;
    /// # const COST_PER_TOKEN: f64 = 1.0;
    /// println!("Estimated cost: ${}", token_count.total() * COST_PER_TOKEN);
    /// # Ok(())
    /// # }
    /// ```
    pub async fn count_tokens<T>(&self, contents: T) -> Result<CountTokensResponse, Error>
    where
        T: TryIntoContents,
    {
        let mut gc = self.client.gc.clone();

        // Build the token-counting request around a full generation request.
        let request = CountTokensRequest {
            model: self.model_name.to_string(),
            contents: vec![],
            generate_content_request: Some(self.clone().build_request(contents)?),
        };

        gc.count_tokens(request)
            .await
            .map_err(status_into_error)
            .map(|r| r.into_inner())
    }

    /// Returns information about the model.
    ///
    /// Yields `Info::Tuned` if the current model is a fine-tuned one,
    /// otherwise `Info::Model`.
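    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Info` is importable from the crate root):
    /// ```rust,ignore
    /// # use google_ai_rs::{Client, Info};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// # let model = client.generative_model("gemini-pro");
    /// match model.info().await? {
    ///     Info::Model(m) => println!("base model: {m:?}"),
    ///     Info::Tuned(t) => println!("tuned model: {t:?}"),
    /// }
    /// # Ok(()) }
    /// ```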
    pub async fn info(&self) -> Result<Info, Error> {
        if self.model_name.starts_with("tunedModels") {
            Ok(Info::Tuned(
                self.client.get_tuned_model(&self.model_name).await?,
            ))
        } else {
            Ok(Info::Model(self.client.get_model(&self.model_name).await?))
        }
    }

    /// Changes the model identifier for this instance in place.
    pub fn change_model(&mut self, to: &str) {
        self.model_name = full_model_name(to).into()
    }

    /// Returns the full identifier of the model, including any `models/` prefix.
    pub fn full_name(&self) -> &str {
        &self.model_name
    }

    // Builder pattern methods
    // -----------------------------------------------------------------

    /// Sets system-level behavior instructions
    pub fn with_system_instruction<I: IntoContent>(mut self, instruction: I) -> Self {
        self.system_instruction = Some(instruction.into_content());
        self
    }

    /// Changes the model identifier, returning the modified instance.
    pub fn to_model(mut self, to: &str) -> Self {
        self.change_model(to);
        self
    }

    /// Sets cached content for persisted context
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// use google_ai_rs::content::IntoContents as _;
    ///
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// let content = "You are a helpful assistant".into_cached_content_for("gemini-1.0-pro");
    ///
    /// let cached_content = client.create_cached_content(content).await?;
    /// let model = client.generative_model("gemini-pro")
    ///             .with_cached_content(&cached_content)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_cached_content(mut self, c: &CachedContent) -> Result<Self, Error> {
        self.cached_content = Some(
            c.name
                .as_deref()
                .ok_or(Error::InvalidArgument(
                    "cached content name is empty".into(),
                ))?
                .into(),
        );
        Ok(self)
    }

    /// Specifies the expected response format (e.g., "application/json")
    pub fn with_response_format(mut self, mime_type: &str) -> Self {
        self.generation_config
            .get_or_insert_with(Default::default)
            .response_mime_type = mime_type.into();
        self
    }

    /// Configures the model to respond with a schema matching the type `T`.
    ///
    /// This is a convenient way to get structured JSON output.
    ///
    /// # Example
    /// ```rust
    /// use google_ai_rs::AsSchema;
    ///
    /// #[derive(Debug, AsSchema)]
    /// #[schema(description = "A primary colour")]
    /// struct PrimaryColor {
    ///     #[schema(description = "The name of the colour")]
    ///     name: String,
    ///
    ///     #[schema(description = "The RGB value of the colour, in hex")]
    ///     #[schema(rename = "RGB")]
    ///     rgb: String
    /// }
    ///
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .as_response_schema::<Vec<PrimaryColor>>();
    /// # Ok(())
    /// # }
    /// ```
    pub fn as_response_schema<T: AsSchema>(self) -> Self {
        self.with_response_schema(T::as_schema())
    }

    /// Sets the response schema with an explicit `Schema` object
    ///
    /// Use this when you need full control over schema details. Automatically
    /// sets the response format to JSON if not already specified.
    ///
    /// # Example
    ///
    /// ```rust
    /// use google_ai_rs::Schema;
    /// use google_ai_rs::SchemaType;
    ///
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .with_response_schema(Schema {
    ///         r#type: SchemaType::String as i32,
    ///         format: "enum".into(),
    ///         ..Default::default()
    ///     });
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_response_schema(mut self, schema: Schema) -> Self {
        let c = self.generation_config.get_or_insert_with(Default::default);
        if c.response_mime_type.is_empty() {
            c.response_mime_type = "application/json".into();
        }
        c.response_schema = Some(schema);
        self
    }

    /// Adds a collection of tools to the model.
    ///
    /// Tools define external functions that the model can call.
    ///
    /// # Arguments
    /// * `tools` - An iterator of `Tool` instances.
    pub fn tools<I>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = Tool>,
    {
        self.tools = Some(tools.into_iter().collect());
        self
    }

    /// Configures how the model uses tools.
    ///
    /// # Arguments
    /// * `tool_config` - The configuration for tool usage.
    pub fn tool_config(mut self, tool_config: impl Into<ToolConfig>) -> Self {
        self.tool_config = Some(tool_config.into());
        self
    }

    /// Applies content safety filters to the model.
    ///
    /// Safety settings control the probability thresholds for filtering
    /// potentially harmful content.
    ///
    /// # Arguments
    /// * `safety_settings` - An iterator of `SafetySetting` instances.
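    ///
    /// # Example
    ///
    /// A sketch using the [`SafetySetting`] builder (the exact category and
    /// threshold variant names are illustrative):
    /// ```rust,ignore
    /// # use google_ai_rs::{Client, HarmCategory, SafetySetting};
    /// # use google_ai_rs::HarmBlockThreshold;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = client.generative_model("gemini-pro").safety_settings([
    ///     SafetySetting::new()
    ///         .harm_category(HarmCategory::Harassment)
    ///         .harm_threshold(HarmBlockThreshold::BlockLowAndAbove),
    /// ]);
    /// # Ok(()) }
    /// ```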
    pub fn safety_settings<I>(mut self, safety_settings: I) -> Self
    where
        I: IntoIterator<Item = SafetySetting>,
    {
        self.safety_settings = Some(safety_settings.into_iter().collect());
        self
    }

    /// Sets the generation parameters for the model.
    ///
    /// This includes settings like `temperature`, `top_k`, and `top_p`
    /// to control the creativity and randomness of the model's output.
    ///
    /// # Arguments
    /// * `generation_config` - The configuration for generation.
    pub fn generation_config(mut self, generation_config: impl Into<GenerationConfig>) -> Self {
        self.generation_config = Some(generation_config.into());
        self
    }

    /// Creates a copy with new system instructions
    pub fn with_cloned_instruction<I: IntoContent>(&self, instruction: I) -> Self {
        let mut clone = self.clone();
        clone.system_instruction = Some(instruction.into_content());
        clone
    }

    /// Sets the number of candidates to generate.
    ///
    /// This parameter specifies how many different response candidates the model should generate
    /// for a given prompt. The model will then select the best one based on its internal
    /// evaluation.
    pub fn candidate_count(mut self, x: i32) -> Self {
        self.set_candidate_count(x);
        self
    }

    /// Sets the maximum number of output tokens.
    ///
    /// This parameter caps the length of the generated response, measured in tokens.
    /// It's useful for controlling response size and preventing excessively long outputs.
    pub fn max_output_tokens(mut self, x: i32) -> Self {
        self.set_max_output_tokens(x);
        self
    }

    /// Sets the temperature for generation.
    ///
    /// Temperature controls the randomness of the output. Higher values, like 1.0,
    /// make the output more creative and unpredictable, while lower values, like 0.1,
    /// make it more deterministic and focused.
    pub fn temperature(mut self, x: f32) -> Self {
        self.set_temperature(x);
        self
    }

    /// Sets the top-p sampling parameter.
    ///
    /// Top-p (also known as nucleus sampling) chooses the smallest set of most likely
    /// tokens whose cumulative probability exceeds the value of `x`. This technique
    /// helps to prevent low-probability, nonsensical tokens from being chosen.
    pub fn top_p(mut self, x: f32) -> Self {
        self.set_top_p(x);
        self
    }

    /// Sets the top-k sampling parameter.
    ///
    /// Top-k restricts the model's token selection to the `k` most likely tokens at
    /// each step. It's a method for controlling the model's creativity and focus.
    pub fn top_k(mut self, x: i32) -> Self {
        self.set_top_k(x);
        self
    }

    /// Sets the number of candidates to generate, in place.
    ///
    /// This parameter specifies how many different response candidates the model should generate
    /// for a given prompt. The model will then select the best one based on its internal
    /// evaluation.
    pub fn set_candidate_count(&mut self, x: i32) {
        self.generation_config
            .get_or_insert_default()
            .candidate_count = Some(x)
    }

    /// Sets the maximum number of output tokens, in place.
    ///
    /// This parameter caps the length of the generated response, measured in tokens.
    /// It's useful for controlling response size and preventing excessively long outputs.
    pub fn set_max_output_tokens(&mut self, x: i32) {
        self.generation_config
            .get_or_insert_default()
            .max_output_tokens = Some(x)
    }

    /// Sets the temperature for generation, in place.
    ///
    /// Temperature controls the randomness of the output. Higher values, like 1.0,
    /// make the output more creative and unpredictable, while lower values, like 0.1,
    /// make it more deterministic and focused.
    pub fn set_temperature(&mut self, x: f32) {
        self.generation_config.get_or_insert_default().temperature = Some(x)
    }

    /// Sets the top-p sampling parameter, in place.
    ///
    /// Top-p (also known as nucleus sampling) chooses the smallest set of most likely
    /// tokens whose cumulative probability exceeds the value of `x`. This technique
    /// helps to prevent low-probability, nonsensical tokens from being chosen.
    pub fn set_top_p(&mut self, x: f32) {
        self.generation_config.get_or_insert_default().top_p = Some(x)
    }

    /// Sets the top-k sampling parameter, in place.
    ///
    /// Top-k restricts the model's token selection to the `k` most likely tokens at
    /// each step. It's a method for controlling the model's creativity and focus.
    pub fn set_top_k(&mut self, x: i32) {
        self.generation_config.get_or_insert_default().top_k = Some(x)
    }

    #[inline(always)]
    fn build_request(
        self,
        contents: impl TryIntoContents,
    ) -> Result<GenerateContentRequest, Error> {
        let contents = contents.try_into_contents()?;
        Ok(GenerateContentRequest {
            model: self.model_name.into(),
            contents,
            system_instruction: self.system_instruction,
            tools: self.tools.unwrap_or_default(),
            tool_config: self.tool_config,
            safety_settings: self.safety_settings.unwrap_or_default(),
            generation_config: self.generation_config,
            cached_content: self.cached_content.map(|c| c.into()),
        })
    }

    // Avoids cloning the Arc-backed SharedClient when only a borrow is
    // needed; the cost would be insignificant but is still unnecessary.
    fn cloned(&self) -> GenerativeModel<'_> {
        GenerativeModel {
            client: self.client.cloned(),
            ..Clone::clone(self)
        }
    }
}

impl SafetySetting {
    /// Creates a new [`SafetySetting`] with default values
    pub fn new() -> Self {
        Self {
            category: 0,
            threshold: 0,
        }
    }

    /// Sets the category for this setting
    pub fn harm_category(mut self, category: HarmCategory) -> Self {
        self.category = category.into();
        self
    }

    /// Controls the probability threshold at which harm is blocked
    pub fn harm_threshold(mut self, threshold: HarmBlockThreshold) -> Self {
        self.threshold = threshold.into();
        self
    }
}

/// Generation response containing model output and metadata
pub type Response = GenerateContentResponse;

impl Response {
    /// Total tokens used in the request/response cycle
    pub fn total_tokens(&self) -> f64 {
        // FIXME: verify whether cached_content_token_count is already
        // included in total_token_count before summing the two.
        self.usage_metadata.as_ref().map_or(0.0, |meta| {
            meta.total_token_count as f64 + meta.cached_content_token_count as f64
        })
    }
}

/// Streaming response handler implementing async iteration
pub struct ResponseStream(Streaming<GenerateContentResponse>);

impl ResponseStream {
    /// Streams content chunks to any `Write` implementer
    ///
    /// # Arguments
    /// * `writer` - Target for streaming output
    ///
    /// # Returns
    /// Total bytes written
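    ///
    /// # Example
    ///
    /// A minimal sketch (writing chunks to stdout):
    /// ```rust,ignore
    /// # use google_ai_rs::Client;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// # let model = client.generative_model("gemini-pro");
    /// let mut stream = model.stream_generate_content("Tell me a story.").await?;
    /// let mut out = std::io::stdout();
    /// let bytes = stream.write_to(&mut out).await?;
    /// println!("{bytes} bytes written");
    /// # Ok(()) }
    /// ```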
    pub async fn write_to<W: Write>(&mut self, writer: &mut W) -> Result<usize, Error> {
        let mut total = 0;

        while let Some(response) = self
            .next()
            .await
            .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
        {
            let bytes = response.try_into_bytes()?;
            let written = writer
                .write(&bytes)
                .map_err(|e| Error::Stream(ActionError::Action(e)))?;
            total += written;
        }

        Ok(total)
    }

    /// Streams content chunks to any `AsyncWrite` implementer
    ///
    /// Despite the `_sync` suffix, this is the asynchronous-writer
    /// counterpart of [`Self::write_to`].
    ///
    /// # Returns
    /// Total bytes written
    pub async fn write_to_sync<W: AsyncWrite + std::marker::Unpin>(
        &mut self,
        dst: &mut W,
    ) -> Result<usize, Error> {
        use tokio::io::AsyncWriteExt;

        let mut total = 0;

        while let Some(response) = self
            .next()
            .await
            .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
        {
            let bytes = response.try_into_bytes()?;
            let written = dst
                .write(&bytes)
                .await
                .map_err(|e| Error::Stream(ActionError::Action(e)))?;
            total += written;
        }

        Ok(total)
    }

    /// Fetches the next response chunk
    pub async fn next(&mut self) -> Result<Option<GenerateContentResponse>, Error> {
        self.0.message().await.map_err(status_into_error)
    }
}

impl Client {
    /// Creates a new generative model interface
    ///
    /// Shorthand for `GenerativeModel::new()`
    pub fn generative_model<'c>(&'c self, name: &str) -> GenerativeModel<'c> {
        GenerativeModel::new_inner(self, name)
    }

    /// Creates a new typed generative model interface
    ///
    /// Shorthand for `TypedModel::new()`
    pub fn typed_model<'c, T: AsSchema>(&'c self, name: &str) -> TypedModel<'c, T> {
        TypedModel::<T>::new_inner(self, name)
    }
}

impl SharedClient {
    /// Creates a new generative model interface
    pub fn generative_model(&self, name: &str) -> GenerativeModel<'static> {
        GenerativeModel::new_inner(self.clone(), name)
    }

    /// Creates a new typed generative model interface
    pub fn typed_model<T: AsSchema>(&self, name: &str) -> TypedModel<'static, T> {
        TypedModel::<T>::new_inner(self.clone(), name)
    }
}

impl CountTokensResponse {
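    /// Total billable tokens: `total_tokens` plus `cached_content_token_count`.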
    pub fn total(&self) -> f64 {
        self.total_tokens as f64 + self.cached_content_token_count as f64
    }
}

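/// Information about a model, as returned by [`GenerativeModel::info`].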
#[derive(Debug)]
pub enum Info {
    Tuned(TunedModel),
    Model(Model),
}