google_ai_rs/genai.rs
use std::{
    fmt::Debug,
    io::Write,
    ops::{Deref, DerefMut},
};
use tokio::io::AsyncWrite;
use tonic::Streaming;

use crate::{
    client::{CClient, Client, SharedClient},
    content::{IntoContent, TryFromCandidates, TryIntoContents},
    error::{status_into_error, ActionError, Error},
    full_model_name,
    schema::AsSchema,
};

pub use crate::proto::{
    safety_setting::HarmBlockThreshold, CachedContent, Content, CountTokensRequest,
    CountTokensResponse, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
    HarmCategory, Model, SafetySetting, Schema, Tool, ToolConfig, TunedModel,
};

/// Type-safe wrapper for [`GenerativeModel`] guaranteeing response type `T`.
///
/// This type enforces schema contracts through Rust's type system while maintaining
/// compatibility with Google's Generative AI API. Use it when:
/// - You need structured output from the model
/// - Response schema stability is critical
/// - You want compile-time validation of response handling
///
/// # Example
/// ```
/// use google_ai_rs::{Client, GenerativeModel, AsSchema};
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// # let auth = "YOUR-API-KEY";
///
/// #[derive(AsSchema)]
/// struct Recipe {
///     name: String,
///     ingredients: Vec<String>,
/// }
///
/// let client = Client::new(auth).await?;
/// let model = client.typed_model::<Recipe>("gemini-pro");
/// # Ok(())
/// # }
/// ```
pub struct TypedModel<'c, T> {
    inner: GenerativeModel<'c>,
    _marker: PhantomInvariant<T>,
}

impl<T> Debug for TypedModel<'_, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.inner.fmt(f)
    }
}

impl<T> Clone for TypedModel<'_, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }
}

// Keeps `T` invariant via `fn(T) -> T`. Hand-rolled because std's
// `PhantomInvariant` is still unstable.
struct PhantomInvariant<T>(std::marker::PhantomData<fn(T) -> T>);

impl<'c, T> TypedModel<'c, T>
where
    T: AsSchema,
{
    /// Creates a new typed model configured to return responses of type `T`.
    ///
    /// # Arguments
    /// - `client`: Authenticated API client.
    /// - `name`: Model name (e.g., "gemini-pro").
    pub fn new(client: &'c Client, name: &str) -> Self {
        let inner = GenerativeModel::new(client, name).as_response_schema::<T>();
        Self {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }

    fn new_inner(client: impl Into<CClient<'c>>, name: &str) -> Self {
        let inner = GenerativeModel::new_inner(client, name).as_response_schema::<T>();
        Self {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }

    /// Generates content with full response metadata.
    ///
    /// This method clones the model configuration and returns a `TypedResponse`
    /// containing both the parsed `T` and the raw API response.
    ///
    /// # Example
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel, TypedResponse};
    /// # #[derive(AsSchema, serde::Deserialize, Debug)] struct StockAnalysis;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = TypedModel::<StockAnalysis>::new(&client, "gemini-pro");
    /// let analysis: TypedResponse<StockAnalysis> = model.generate_typed_content((
    ///     "Analyze NVDA stock performance",
    ///     "Consider PE ratio and recent earnings",
    /// )).await?;
    /// println!("Analysis: {:?}", analysis.t);
    /// println!("Token Usage: {:?}", analysis.raw.usage_metadata);
    /// # Ok(()) }
    /// ```
    #[inline]
    pub async fn generate_typed_content<I>(&self, contents: I) -> Result<TypedResponse<T>, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        self.cloned()
            .generate_typed_content_consuming(contents)
            .await
    }

    /// Generates content with metadata, consuming the model instance.
    ///
    /// An efficient alternative to `generate_typed_content` that avoids cloning
    /// the model configuration, useful for one-shot requests.
    #[inline]
    pub async fn generate_typed_content_consuming<I>(
        self,
        contents: I,
    ) -> Result<TypedResponse<T>, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        let response = self.inner.generate_content_consuming(contents).await?;
        let t = T::try_from_candidates(&response.candidates)?;
        Ok(TypedResponse { t, raw: response })
    }

    /// Generates content and parses it directly into type `T`.
    ///
    /// This is the primary method for most users wanting type-safe responses.
    /// It handles all the details of requesting structured JSON and deserializing
    /// it into your specified Rust type. It clones the model configuration to allow reuse.
    ///
    /// # Serde Integration
    /// When the `serde` feature is enabled, any type implementing `serde::Deserialize`
    /// automatically works with this method. Just define your response structure and
    /// let the library handle parsing.
    ///
    /// # Example: Simple JSON Response
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel};
    /// # use serde::Deserialize;
    /// #[derive(AsSchema, Deserialize)]
    /// struct StoryResponse {
    ///     title: String,
    ///     length: usize,
    ///     tags: Vec<String>,
    /// }
    ///
    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("key").await?;
    /// let model = TypedModel::<StoryResponse>::new(&client, "gemini-pro");
    /// let story = model.generate_content("Write a short story about a robot astronaut").await?;
    ///
    /// println!("{} ({} words)", story.title, story.length);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Example: Multi-part Input
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel, Part};
    /// # use serde::Deserialize;
    /// #[derive(AsSchema, Deserialize)]
    /// struct Analysis { safety_rating: u8 }
    ///
    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("key").await?;
    /// # let image_data = vec![];
    /// let model = TypedModel::<Analysis>::new(&client, "gemini-pro-vision");
    /// let analysis = model.generate_content((
    ///     "Analyze this scene safety:",
    ///     Part::blob("image/jpeg", image_data),
    ///     "Consider vehicles, pedestrians, and weather",
    /// )).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Returns an error if API communication fails or if the response cannot be
    /// parsed into type `T`.
    #[inline]
    pub async fn generate_content<I>(&self, contents: I) -> Result<T, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        self.cloned().generate_content_consuming(contents).await
    }

    /// Generates content and parses it into `T`, consuming the model instance.
    ///
    /// An efficient alternative to `generate_content` that avoids cloning
    /// the model configuration, useful for one-shot requests.
    #[inline]
    pub async fn generate_content_consuming<I>(self, contents: I) -> Result<T, Error>
    where
        I: TryIntoContents + Send,
        T: TryFromCandidates + Send,
    {
        let response = self.inner.generate_content_consuming(contents).await?;
        let t = T::try_from_candidates(&response.candidates)?;
        Ok(t)
    }

    /// Consumes the `TypedModel`, returning the underlying `GenerativeModel`.
    ///
    /// The returned `GenerativeModel` retains the response schema configuration
    /// that was set for type `T`.
    pub fn into_inner(self) -> GenerativeModel<'c> {
        self.inner
    }

    /// Creates a `TypedModel` from a `GenerativeModel` without validation.
    ///
    /// This is an advanced-use method that assumes the provided `GenerativeModel`
    /// has already been configured with a response schema that is compatible with `T`.
    ///
    /// # Safety
    /// The caller must ensure that `inner.generation_config.response_schema` is `Some`
    /// and that its schema corresponds exactly to the schema of type `T`. Failure to
    /// uphold this invariant will likely result in API errors or deserialization failures.
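    ///
    /// # Example
    /// A minimal sketch of the intended round trip. The `GenerativeModel` below
    /// came from [`Self::into_inner`], so its response schema is already the one
    /// generated for `T`.
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client, TypedModel};
    /// # #[derive(AsSchema)] struct Recipe;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let typed = TypedModel::<Recipe>::new(&client, "gemini-pro");
    /// let raw = typed.into_inner(); // schema for `Recipe` still set
    /// // SAFETY: `raw` came from a `TypedModel<Recipe>`, so its response
    /// // schema matches `Recipe`.
    /// let typed = unsafe { TypedModel::<Recipe>::from_inner_unchecked(raw) };
    /// # Ok(())
    /// # }
    /// ```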
    pub unsafe fn from_inner_unchecked(inner: GenerativeModel<'c>) -> Self {
        Self {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }

    fn cloned(&self) -> TypedModel<'_, T> {
        TypedModel {
            inner: self.inner.cloned(),
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }
}

impl<'c, T> Deref for TypedModel<'c, T> {
    type Target = GenerativeModel<'c>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<'c, T> From<GenerativeModel<'c>> for TypedModel<'c, T>
where
    T: AsSchema,
{
    fn from(value: GenerativeModel<'c>) -> Self {
        let inner = value.as_response_schema::<T>();
        TypedModel {
            inner,
            _marker: PhantomInvariant(std::marker::PhantomData),
        }
    }
}

/// Container for typed responses with raw API data.
///
/// Preserves full response details while providing parsed content.
pub struct TypedResponse<T> {
    /// Parsed content of type `T`
    pub t: T,
    /// Raw API response structure
    pub raw: GenerateContentResponse,
}

impl<T> Debug for TypedResponse<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.t.fmt(f)
    }
}

impl<T> Deref for TypedResponse<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.t
    }
}

impl<T> DerefMut for TypedResponse<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.t
    }
}

/// Configured interface for a specific generative AI model.
///
/// # Example
/// ```
/// use google_ai_rs::{Client, GenerativeModel};
///
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// # let auth = "YOUR-API-KEY";
/// let client = Client::new(auth).await?;
/// let model = client.generative_model("gemini-pro")
///     .with_system_instruction("You are a helpful assistant")
///     .with_response_format("application/json");
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct GenerativeModel<'c> {
    /// Backing API client
    pub(super) client: CClient<'c>,
    /// Fully qualified model name (e.g., "models/gemini-1.0-pro")
    model_name: Box<str>,
    /// System prompt guiding model behavior
    pub system_instruction: Option<Content>,
    /// Available functions/tools the model can use
    pub tools: Option<Vec<Tool>>,
    /// Configuration for tool usage
    pub tool_config: Option<ToolConfig>,
    /// Content safety filters
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// Generation parameters (temperature, top-k, etc.)
    pub generation_config: Option<GenerationConfig>,
    /// Full name of the cached content to use as context
    /// (e.g., "cachedContents/NAME")
    pub cached_content: Option<Box<str>>,
}

impl<'c> GenerativeModel<'c> {
    /// Creates a new model interface with default configuration
    ///
    /// # Arguments
    /// * `client` - Authenticated API client
    /// * `name` - Model identifier (e.g., "gemini-pro")
    ///
    /// To access a tuned model named NAME, pass "tunedModels/NAME".
    pub fn new(client: &'c Client, name: &str) -> Self {
        Self::new_inner(client, name)
    }

    fn new_inner(client: impl Into<CClient<'c>>, name: &str) -> Self {
        Self {
            client: client.into(),
            model_name: full_model_name(name).into(),
            system_instruction: None,
            tools: None,
            tool_config: None,
            safety_settings: None,
            generation_config: None,
            cached_content: None,
        }
    }

    /// Converts this `GenerativeModel` into a `TypedModel`.
    ///
    /// This prepares the model to return responses that are automatically
    /// parsed into the specified type `T`.
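    ///
    /// # Example
    /// A minimal sketch; `Recipe` stands in for any type implementing the
    /// response traits.
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client};
    /// # #[derive(AsSchema, serde::Deserialize)] struct Recipe;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let typed = client
    ///     .generative_model("gemini-pro")
    ///     .with_system_instruction("You are a chef")
    ///     .to_typed::<Recipe>();
    /// let recipe: Recipe = typed.generate_content("A quick pasta recipe").await?;
    /// # Ok(())
    /// # }
    /// ```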
    pub fn to_typed<T: AsSchema>(self) -> TypedModel<'c, T> {
        self.into()
    }

    /// Generates content from flexible input types.
    ///
    /// This method clones the model's configuration for the request, allowing the original
    /// `GenerativeModel` instance to be reused.
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// use google_ai_rs::Part;
    ///
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// # let model = client.generative_model("gemini-pro");
    /// // Simple text generation
    /// let response = model.generate_content("Hello world!").await?;
    ///
    /// // Multi-part content
    /// # let image_data = vec![];
    /// model.generate_content((
    ///     "What's in this image?",
    ///     Part::blob("image/jpeg", image_data),
    /// )).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Returns [`Error::Service`] for model errors or [`Error::Net`] for transport failures.
    pub async fn generate_content<T>(&self, contents: T) -> Result<GenerateContentResponse, Error>
    where
        T: TryIntoContents,
    {
        self.cloned().generate_content_consuming(contents).await
    }

    /// Generates content by consuming the model instance.
    ///
    /// This is an efficient alternative to `generate_content` if you don't need to reuse the
    /// model instance, as it avoids cloning the model's configuration. This is useful
    /// for one-shot requests where the model is built, used, and then discarded.
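    ///
    /// # Example
    /// A minimal build-use-discard sketch:
    /// ```rust,ignore
    /// # use google_ai_rs::Client;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let response = client
    ///     .generative_model("gemini-pro")
    ///     .max_output_tokens(128)
    ///     .generate_content_consuming("Summarize Rust's ownership model")
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```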
    pub async fn generate_content_consuming<T>(
        self,
        contents: T,
    ) -> Result<GenerateContentResponse, Error>
    where
        T: TryIntoContents,
    {
        let mut gc = self.client.gc.clone();
        let request = self.build_request(contents)?;
        gc.generate_content(request)
            .await
            .map_err(status_into_error)
            .map(|r| r.into_inner())
    }

    /// A convenience method to generate a structured response of type `T`.
    ///
    /// This method internally converts the `GenerativeModel` to a `TypedModel<T>`,
    /// makes the request, and returns the parsed result directly. It clones the model
    /// configuration for the request.
    ///
    /// For repeated calls with the same target type, it may be more efficient to create a
    /// `TypedModel` instance directly.
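    ///
    /// # Example
    /// A minimal sketch; `Summary` is a placeholder type deriving `AsSchema`
    /// and `serde::Deserialize`.
    /// ```rust,ignore
    /// # use google_ai_rs::{AsSchema, Client};
    /// # #[derive(AsSchema, serde::Deserialize)]
    /// # struct Summary { title: String, bullet_points: Vec<String> }
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = client.generative_model("gemini-pro");
    /// let summary: Summary = model
    ///     .typed_generate_content("Summarize the history of Rust")
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```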
    pub async fn typed_generate_content<I, T>(&self, contents: I) -> Result<T, Error>
    where
        I: TryIntoContents + Send,
        T: AsSchema + TryFromCandidates + Send,
    {
        // Clone once up front; the consuming call below then reuses that copy.
        self.cloned()
            .to_typed()
            .generate_content_consuming(contents)
            .await
    }

    /// A convenience method to generate a structured response with metadata.
    ///
    /// Similar to `typed_generate_content`, but returns a `TypedResponse<T>` which includes
    /// both the parsed data and the raw API response metadata.
    pub async fn generate_typed_content<I, T>(&self, contents: I) -> Result<TypedResponse<T>, Error>
    where
        I: TryIntoContents + Send,
        T: AsSchema + TryFromCandidates + Send,
    {
        self.cloned()
            .to_typed()
            .generate_typed_content_consuming(contents)
            .await
    }

    /// Generates a streaming response from flexible input.
    ///
    /// This method clones the model's configuration for the request, allowing the original
    /// `GenerativeModel` instance to be reused.
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// # let model = client.generative_model("gemini-pro");
    /// let mut stream = model.stream_generate_content("Tell me a story.").await?;
    /// while let Some(chunk) = stream.next().await? {
    ///     // Process streaming response
    /// }
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Returns [`Error::Service`] for model errors or [`Error::Net`] for transport failures.
    pub async fn stream_generate_content<T>(&self, contents: T) -> Result<ResponseStream, Error>
    where
        T: TryIntoContents,
    {
        self.cloned()
            .stream_generate_content_consuming(contents)
            .await
    }

    /// Generates a streaming response by consuming the model instance.
    ///
    /// This is an efficient alternative to `stream_generate_content` if you don't need to
    /// reuse the model instance, as it avoids cloning the model's configuration.
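    ///
    /// # Example
    /// A minimal one-shot streaming sketch:
    /// ```rust,ignore
    /// # use google_ai_rs::Client;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let mut stream = client
    ///     .generative_model("gemini-pro")
    ///     .stream_generate_content_consuming("Tell me a story.")
    ///     .await?;
    /// while let Some(chunk) = stream.next().await? {
    ///     // Process each chunk as it arrives
    /// }
    /// # Ok(())
    /// # }
    /// ```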
    pub async fn stream_generate_content_consuming<T>(
        self,
        contents: T,
    ) -> Result<ResponseStream, Error>
    where
        T: TryIntoContents,
    {
        let mut gc = self.client.gc.clone();
        let request = self.build_request(contents)?;
        gc.stream_generate_content(request)
            .await
            .map_err(status_into_error)
            .map(|s| ResponseStream(s.into_inner()))
    }

    /// Estimates token usage for the given content
    ///
    /// Useful for cost estimation and validation before full generation
    ///
    /// # Arguments
    /// * `contents` - Content input that can be converted to parts
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// # let model = client.generative_model("gemini-pro");
    /// # let content = "";
    /// let token_count = model.count_tokens(content).await?;
    /// # const COST_PER_TOKEN: f64 = 1.0;
    /// println!("Estimated cost: ${}", token_count.total() * COST_PER_TOKEN);
    /// # Ok(())
    /// # }
    /// ```
    pub async fn count_tokens<T>(&self, contents: T) -> Result<CountTokensResponse, Error>
    where
        T: TryIntoContents,
    {
        let mut gc = self.client.gc.clone();

        // Build the counting request around the full generation request so the
        // count also reflects the system instruction, tools, and cached content.
        let request = CountTokensRequest {
            model: self.model_name.to_string(),
            contents: vec![],
            generate_content_request: Some(self.clone().build_request(contents)?),
        };

        gc.count_tokens(request)
            .await
            .map_err(status_into_error)
            .map(|r| r.into_inner())
    }

    /// Returns information about the model.
    ///
    /// Yields `Info::Tuned` if the current model is a fine-tuned one,
    /// otherwise `Info::Model`.
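    ///
    /// # Example
    /// A minimal sketch of dispatching on the result (import path illustrative):
    /// ```rust,ignore
    /// # use google_ai_rs::{Client, Info};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// # let model = client.generative_model("gemini-pro");
    /// match model.info().await? {
    ///     Info::Model(m) => println!("base model: {m:?}"),
    ///     Info::Tuned(t) => println!("tuned model: {t:?}"),
    /// }
    /// # Ok(())
    /// # }
    /// ```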
    pub async fn info(&self) -> Result<Info, Error> {
        if self.model_name.starts_with("tunedModels/") {
            Ok(Info::Tuned(
                self.client.get_tuned_model(&self.model_name).await?,
            ))
        } else {
            Ok(Info::Model(self.client.get_model(&self.model_name).await?))
        }
    }

    /// Changes the model identifier for this instance in place.
    pub fn change_model(&mut self, to: &str) {
        self.model_name = full_model_name(to).into()
    }

    /// Returns the full identifier of the model, including any `models/` prefix.
    pub fn full_name(&self) -> &str {
        &self.model_name
    }

    // Builder pattern methods
    // -----------------------------------------------------------------

    /// Sets system-level behavior instructions
    pub fn with_system_instruction<I: IntoContent>(mut self, instruction: I) -> Self {
        self.system_instruction = Some(instruction.into_content());
        self
    }

    /// Changes the model identifier, returning the modified instance.
    pub fn to_model(mut self, to: &str) -> Self {
        self.change_model(to);
        self
    }

    /// Sets cached content for persisted context
    ///
    /// # Example
    /// ```
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// use google_ai_rs::content::IntoContents as _;
    ///
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// let content = "You are a helpful assistant".into_cached_content_for("gemini-1.0-pro");
    ///
    /// let cached_content = client.create_cached_content(content).await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .with_cached_content(&cached_content)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_cached_content(mut self, c: &CachedContent) -> Result<Self, Error> {
        self.cached_content = Some(
            c.name
                .as_deref()
                .ok_or(Error::InvalidArgument(
                    "cached content name is empty".into(),
                ))?
                .into(),
        );
        Ok(self)
    }

    /// Specifies the expected response format (e.g., "application/json")
    pub fn with_response_format(mut self, mime_type: &str) -> Self {
        self.generation_config
            .get_or_insert_with(Default::default)
            .response_mime_type = mime_type.into();
        self
    }

    /// Configures the model to respond with a schema matching the type `T`.
    ///
    /// This is a convenient way to get structured JSON output.
    ///
    /// # Example
    /// ```rust
    /// use google_ai_rs::AsSchema;
    ///
    /// #[derive(Debug, AsSchema)]
    /// #[schema(description = "A primary colour")]
    /// struct PrimaryColor {
    ///     #[schema(description = "The name of the colour")]
    ///     name: String,
    ///
    ///     #[schema(description = "The RGB value of the colour, in hex")]
    ///     #[schema(rename = "RGB")]
    ///     rgb: String,
    /// }
    ///
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .as_response_schema::<Vec<PrimaryColor>>();
    /// # Ok(())
    /// # }
    /// ```
    pub fn as_response_schema<T: AsSchema>(self) -> Self {
        self.with_response_schema(T::as_schema())
    }

    /// Sets the response schema with an explicit [`Schema`] object.
    ///
    /// Use when you need full control over schema details. Automatically
    /// sets the response format to JSON if not specified.
    ///
    /// # Example
    ///
    /// ```rust
    /// use google_ai_rs::Schema;
    /// use google_ai_rs::SchemaType;
    ///
    /// # use google_ai_rs::{Client, GenerativeModel};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let auth = "YOUR-API-KEY";
    /// # let client = Client::new(auth).await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .with_response_schema(Schema {
    ///         r#type: SchemaType::String as i32,
    ///         format: "enum".into(),
    ///         ..Default::default()
    ///     });
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_response_schema(mut self, schema: Schema) -> Self {
        let c = self.generation_config.get_or_insert_with(Default::default);
        if c.response_mime_type.is_empty() {
            c.response_mime_type = "application/json".into();
        }
        c.response_schema = Some(schema);
        self
    }

    /// Adds a collection of tools to the model.
    ///
    /// Tools define external functions that the model can call.
    ///
    /// # Arguments
    /// * `tools` - An iterator of `Tool` instances.
    pub fn tools<I>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = Tool>,
    {
        self.tools = Some(tools.into_iter().collect());
        self
    }

    /// Configures how the model uses tools.
    ///
    /// # Arguments
    /// * `tool_config` - The configuration for tool usage.
    pub fn tool_config(mut self, tool_config: impl Into<ToolConfig>) -> Self {
        self.tool_config = Some(tool_config.into());
        self
    }

    /// Applies content safety filters to the model.
    ///
    /// Safety settings control the probability thresholds for filtering
    /// potentially harmful content.
    ///
    /// # Arguments
    /// * `safety_settings` - An iterator of `SafetySetting` instances.
    pub fn safety_settings<I>(mut self, safety_settings: I) -> Self
    where
        I: IntoIterator<Item = SafetySetting>,
    {
        self.safety_settings = Some(safety_settings.into_iter().collect());
        self
    }

    /// Sets the generation parameters for the model.
    ///
    /// This includes settings like `temperature`, `top_k`, and `top_p`
    /// to control the creativity and randomness of the model's output.
    ///
    /// # Arguments
    /// * `generation_config` - The configuration for generation.
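    ///
    /// # Example
    /// A minimal sketch, assuming the remaining `GenerationConfig` fields can
    /// be left at their defaults:
    /// ```rust,ignore
    /// # use google_ai_rs::{Client, GenerationConfig};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .generation_config(GenerationConfig {
    ///         temperature: Some(0.2),
    ///         max_output_tokens: Some(256),
    ///         ..Default::default()
    ///     });
    /// # Ok(())
    /// # }
    /// ```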
    pub fn generation_config(mut self, generation_config: impl Into<GenerationConfig>) -> Self {
        self.generation_config = Some(generation_config.into());
        self
    }

    /// Creates a copy with new system instructions
    pub fn with_cloned_instruction<I: IntoContent>(&self, instruction: I) -> Self {
        let mut clone = self.clone();

        clone.system_instruction = Some(instruction.into_content());
        clone
    }

    /// Sets the number of candidates to generate.
    ///
    /// This parameter specifies how many response candidates the model should generate
    /// for a given prompt; the response may contain up to this many candidates for you
    /// to choose from.
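    ///
    /// # Example
    /// A sketch combining several of the chainable generation settings:
    /// ```rust,ignore
    /// # use google_ai_rs::Client;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let model = client.generative_model("gemini-pro")
    ///     .candidate_count(1)
    ///     .max_output_tokens(256)
    ///     .temperature(0.7)
    ///     .top_p(0.95)
    ///     .top_k(40);
    /// # Ok(())
    /// # }
    /// ```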
    pub fn candidate_count(mut self, x: i32) -> Self {
        self.set_candidate_count(x);
        self
    }

    /// Sets the maximum number of output tokens.
    ///
    /// This parameter caps the length of the generated response, measured in tokens.
    /// It's useful for controlling response size and preventing excessively long outputs.
    pub fn max_output_tokens(mut self, x: i32) -> Self {
        self.set_max_output_tokens(x);
        self
    }

    /// Sets the temperature for generation.
    ///
    /// Temperature controls the randomness of the output. Higher values, like 1.0,
    /// make the output more creative and unpredictable, while lower values, like 0.1,
    /// make it more deterministic and focused.
    pub fn temperature(mut self, x: f32) -> Self {
        self.set_temperature(x);
        self
    }

    /// Sets the top-p sampling parameter.
    ///
    /// Top-p (also known as nucleus sampling) chooses the smallest set of most likely
    /// tokens whose cumulative probability exceeds the value of `x`. This technique
    /// helps to prevent low-probability, nonsensical tokens from being chosen.
    pub fn top_p(mut self, x: f32) -> Self {
        self.set_top_p(x);
        self
    }

    /// Sets the top-k sampling parameter.
    ///
    /// Top-k restricts the model's token selection to the `k` most likely tokens at
    /// each step. It's a method for controlling the model's creativity and focus.
    pub fn top_k(mut self, x: i32) -> Self {
        self.set_top_k(x);
        self
    }

    /// Sets the number of candidates to generate.
    ///
    /// This parameter specifies how many response candidates the model should generate
    /// for a given prompt; the response may contain up to this many candidates for you
    /// to choose from.
    pub fn set_candidate_count(&mut self, x: i32) {
        self.generation_config
            .get_or_insert_default()
            .candidate_count = Some(x)
    }

    /// Sets the maximum number of output tokens.
    ///
    /// This parameter caps the length of the generated response, measured in tokens.
    /// It's useful for controlling response size and preventing excessively long outputs.
    pub fn set_max_output_tokens(&mut self, x: i32) {
        self.generation_config
            .get_or_insert_default()
            .max_output_tokens = Some(x)
    }

    /// Sets the temperature for generation.
    ///
    /// Temperature controls the randomness of the output. Higher values, like 1.0,
    /// make the output more creative and unpredictable, while lower values, like 0.1,
    /// make it more deterministic and focused.
    pub fn set_temperature(&mut self, x: f32) {
        self.generation_config.get_or_insert_default().temperature = Some(x)
    }

    /// Sets the top-p sampling parameter.
    ///
    /// Top-p (also known as nucleus sampling) chooses the smallest set of most likely
    /// tokens whose cumulative probability exceeds the value of `x`. This technique
    /// helps to prevent low-probability, nonsensical tokens from being chosen.
    pub fn set_top_p(&mut self, x: f32) {
        self.generation_config.get_or_insert_default().top_p = Some(x)
    }

    /// Sets the top-k sampling parameter.
    ///
    /// Top-k restricts the model's token selection to the `k` most likely tokens at
    /// each step. It's a method for controlling the model's creativity and focus.
    pub fn set_top_k(&mut self, x: i32) {
        self.generation_config.get_or_insert_default().top_k = Some(x)
    }

    #[inline(always)]
    fn build_request(
        self,
        contents: impl TryIntoContents,
    ) -> Result<GenerateContentRequest, Error> {
        let contents = contents.try_into_contents()?;
        Ok(GenerateContentRequest {
            model: self.model_name.into(),
            contents,
            system_instruction: self.system_instruction,
            tools: self.tools.unwrap_or_default(),
            tool_config: self.tool_config,
            safety_settings: self.safety_settings.unwrap_or_default(),
            generation_config: self.generation_config,
            cached_content: self.cached_content.map(|c| c.into()),
        })
    }

    // Clones the configuration but re-borrows the client: `SharedClient` is
    // Arc-backed, so cloning it would be cheap, but it is still unnecessary.
    fn cloned(&self) -> GenerativeModel<'_> {
        GenerativeModel {
            client: self.client.cloned(),
            ..Clone::clone(self)
        }
    }
}

impl SafetySetting {
    /// Creates a new [`SafetySetting`] with default values.
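    ///
    /// # Example
    /// A minimal sketch pairing the builder with [`GenerativeModel::safety_settings`];
    /// the enum variant names follow the proto definitions and may differ.
    /// ```rust,ignore
    /// # use google_ai_rs::{Client, HarmBlockThreshold, HarmCategory, SafetySetting};
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// let setting = SafetySetting::new()
    ///     .harm_category(HarmCategory::Harassment)
    ///     .harm_threshold(HarmBlockThreshold::BlockLowAndAbove);
    /// let model = client.generative_model("gemini-pro")
    ///     .safety_settings([setting]);
    /// # Ok(())
    /// # }
    /// ```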
    pub fn new() -> Self {
        Self {
            category: 0,
            threshold: 0,
        }
    }

    /// Sets the category for this setting
    pub fn harm_category(mut self, category: HarmCategory) -> Self {
        self.category = category.into();
        self
    }

    /// Controls the probability threshold at which harm is blocked
    pub fn harm_threshold(mut self, threshold: HarmBlockThreshold) -> Self {
        self.threshold = threshold.into();
        self
    }
}

/// Generation response containing model output and metadata
pub type Response = GenerateContentResponse;

impl Response {
    /// Total tokens used in the request/response cycle
    pub fn total_tokens(&self) -> f64 {
        // FIXME: verify whether `total_token_count` already includes
        // `cached_content_token_count`; if it does, this double-counts.
        self.usage_metadata.as_ref().map_or(0.0, |meta| {
            meta.total_token_count as f64 + meta.cached_content_token_count as f64
        })
    }
}

/// Streaming response handler implementing async iteration
pub struct ResponseStream(Streaming<GenerateContentResponse>);

impl ResponseStream {
    /// Streams content chunks to any `Write` implementer
    ///
    /// # Arguments
    /// * `writer` - Target for streaming output
    ///
    /// # Returns
    /// Total bytes written
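    ///
    /// # Example
    /// A minimal sketch that collects the stream into a buffer:
    /// ```rust,ignore
    /// # use google_ai_rs::Client;
    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = Client::new("api-key").await?;
    /// # let model = client.generative_model("gemini-pro");
    /// let mut stream = model.stream_generate_content("Tell me a story.").await?;
    /// let mut buf = Vec::new();
    /// let written = stream.write_to(&mut buf).await?;
    /// println!("received {written} bytes");
    /// # Ok(())
    /// # }
    /// ```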
    pub async fn write_to<W: Write>(&mut self, writer: &mut W) -> Result<usize, Error> {
        let mut total = 0;

        while let Some(response) = self
            .next()
            .await
            .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
        {
            let bytes = response.try_into_bytes()?;
            // Use `write_all` rather than `write`: a single `write` call may
            // accept only part of the buffer, silently dropping the rest.
            writer
                .write_all(&bytes)
                .map_err(|e| Error::Stream(ActionError::Action(e)))?;
            total += bytes.len();
        }

        Ok(total)
    }

    /// Streams content chunks to any `AsyncWrite` implementer
    ///
    /// # Returns
    /// Total bytes written
    pub async fn write_to_sync<W: AsyncWrite + std::marker::Unpin>(
        &mut self,
        dst: &mut W,
    ) -> Result<usize, Error> {
        use tokio::io::AsyncWriteExt;

        let mut total = 0;

        while let Some(response) = self
            .next()
            .await
            .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
        {
            let bytes = response.try_into_bytes()?;
            // Use `write_all` rather than `write`: a single `write` call may
            // accept only part of the buffer, silently dropping the rest.
            dst.write_all(&bytes)
                .await
                .map_err(|e| Error::Stream(ActionError::Action(e)))?;
            total += bytes.len();
        }

        Ok(total)
    }

    /// Fetches the next response chunk
    pub async fn next(&mut self) -> Result<Option<GenerateContentResponse>, Error> {
        self.0.message().await.map_err(status_into_error)
    }
}

impl Client {
    /// Creates a new generative model interface
    ///
    /// Shorthand for `GenerativeModel::new()`
    pub fn generative_model<'c>(&'c self, name: &str) -> GenerativeModel<'c> {
        GenerativeModel::new_inner(self, name)
    }

    /// Creates a new typed generative model interface
    ///
    /// Shorthand for `TypedModel::new()`
    pub fn typed_model<'c, T: AsSchema>(&'c self, name: &str) -> TypedModel<'c, T> {
        TypedModel::<T>::new_inner(self, name)
    }
}

impl SharedClient {
    /// Creates a new generative model interface
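    ///
    /// Unlike [`Client::generative_model`], the returned model is `'static`
    /// because the shared client is cloned into it.
    ///
    /// # Example
    /// A minimal sketch of handing the model to a spawned task:
    /// ```rust,ignore
    /// # use google_ai_rs::SharedClient;
    /// # async fn f(shared: SharedClient) -> Result<(), Box<dyn std::error::Error>> {
    /// let model = shared.generative_model("gemini-pro");
    /// tokio::spawn(async move {
    ///     // `model` is 'static, so it can outlive the caller's scope.
    ///     let _ = model.generate_content("Hello world!").await;
    /// });
    /// # Ok(())
    /// # }
    /// ```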
    pub fn generative_model(&self, name: &str) -> GenerativeModel<'static> {
        GenerativeModel::new_inner(self.clone(), name)
    }

    /// Creates a new typed generative model interface
    pub fn typed_model<T: AsSchema>(&self, name: &str) -> TypedModel<'static, T> {
        TypedModel::<T>::new_inner(self.clone(), name)
    }
}

impl CountTokensResponse {
    /// Total token count, including cached content tokens.
    pub fn total(&self) -> f64 {
        self.total_tokens as f64 + self.cached_content_token_count as f64
    }
}

/// Model information, as returned by [`GenerativeModel::info`].
#[derive(Debug)]
pub enum Info {
    /// A fine-tuned model.
    Tuned(TunedModel),
    /// A base model.
    Model(Model),
}