use reqwest::Url;
use serde::{Deserialize, Serialize, de};
use time::OffsetDateTime;
use crate::{
Content, Modality, Part,
safety::{SafetyRating, SafetySetting},
};
/// Reason a candidate stopped generating (carried in `Candidate::finish_reason`).
///
/// Serializes as the SCREAMING_SNAKE_CASE variant name (e.g. `MAX_TOKENS`). Deserialization is
/// hand-written: it accepts either the string name or a numeric enum code, and maps any
/// unrecognized value to [`FinishReason::Other`] instead of failing.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    FinishReasonUnspecified,
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    Language,
    Other,
    Blocklist,
    ProhibitedContent,
    Spii,
    MalformedFunctionCall,
    ModelArmor,
    ImageSafety,
    UnexpectedToolCall,
    TooManyToolCalls,
}
impl FinishReason {
    /// Decodes the string form of a finish reason (e.g. `"MAX_TOKENS"`).
    ///
    /// Matching is exact and case-sensitive. `"OTHER"` and every unrecognized name decode to
    /// [`FinishReason::Other`], so a new server-side reason never breaks deserialization.
    fn from_wire_str(value: &str) -> Self {
        match value {
            "BLOCKLIST" => Self::Blocklist,
            "FINISH_REASON_UNSPECIFIED" => Self::FinishReasonUnspecified,
            "IMAGE_SAFETY" => Self::ImageSafety,
            "LANGUAGE" => Self::Language,
            "MALFORMED_FUNCTION_CALL" => Self::MalformedFunctionCall,
            "MAX_TOKENS" => Self::MaxTokens,
            "MODEL_ARMOR" => Self::ModelArmor,
            "PROHIBITED_CONTENT" => Self::ProhibitedContent,
            "RECITATION" => Self::Recitation,
            "SAFETY" => Self::Safety,
            "SPII" => Self::Spii,
            "STOP" => Self::Stop,
            "TOO_MANY_TOOL_CALLS" => Self::TooManyToolCalls,
            "UNEXPECTED_TOOL_CALL" => Self::UnexpectedToolCall,
            // "OTHER" and anything unknown share the same fallback.
            _ => Self::Other,
        }
    }

    /// Decodes the numeric enum code of a finish reason.
    ///
    /// Codes 0-10 are recognized. Code 5 and every other value decode to
    /// [`FinishReason::Other`]. `Language`, `ImageSafety`, `UnexpectedToolCall` and
    /// `TooManyToolCalls` have no numeric code here.
    ///
    /// NOTE(review): this numbering (`OTHER` = 5, `BLOCKLIST` = 6, ...) appears to follow the
    /// Vertex AI proto. The Gemini Developer API proto numbers several of these differently.
    /// Confirm which backend can actually send numeric codes.
    fn from_wire_number(value: i64) -> Self {
        match value {
            0 => Self::FinishReasonUnspecified,
            1 => Self::Stop,
            2 => Self::MaxTokens,
            3 => Self::Safety,
            4 => Self::Recitation,
            6 => Self::Blocklist,
            7 => Self::ProhibitedContent,
            8 => Self::Spii,
            9 => Self::MalformedFunctionCall,
            10 => Self::ModelArmor,
            // 5 is OTHER; unknown codes fall back to the same variant.
            _ => Self::Other,
        }
    }
}
impl<'de> Deserialize<'de> for FinishReason {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(Self::from_wire_str(&s)),
serde_json::Value::Number(n) => {
n.as_i64().map(Self::from_wire_number).ok_or_else(|| {
de::Error::custom("finishReason must be an integer-compatible number")
})
}
_ => Err(de::Error::custom("finishReason must be a string or integer")),
}
}
}
/// Citation information attached to a candidate (`Candidate::citation_metadata`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// Cited sources; a missing `citationSources` key deserializes as an empty list.
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}
/// A single source cited by a candidate's content.
///
/// Every field is optional. None of them carry `skip_serializing_if`, so `None` values are
/// written as explicit JSON `null`s when serialized.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// URI of the cited source, if reported.
    pub uri: Option<String>,
    /// Title of the cited source, if reported.
    pub title: Option<String>,
    /// Start of the attributed segment of the response.
    /// NOTE(review): units (bytes vs. characters) are not established here; confirm.
    pub start_index: Option<i32>,
    /// End of the attributed segment; presumably exclusive. Confirm against the API docs.
    pub end_index: Option<i32>,
    /// License string of the cited source, if reported.
    pub license: Option<String>,
    /// Publication date, parsed as an RFC 3339 timestamp; a missing key yields `None`.
    ///
    /// NOTE(review): Vertex AI documents `publicationDate` as a `google.type.Date` object
    /// (`{year, month, day}`), not an RFC 3339 string. If such an object arrives, this field
    /// would fail to deserialize and take the whole response down with it. Confirm the actual
    /// wire format.
    #[serde(default, with = "time::serde::rfc3339::option")]
    pub publication_date: Option<OffsetDateTime>,
}
/// One generated response alternative within a [`GenerationResponse`].
///
/// Optional fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
    /// Generated content; falls back to `Content::default()` when the key is absent.
    #[serde(default)]
    pub content: Content,
    /// Safety ratings for this candidate, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_ratings: Option<Vec<SafetyRating>>,
    /// Citation sources for this candidate, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    /// Grounding metadata, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grounding_metadata: Option<GroundingMetadata>,
    /// Why generation stopped. Unknown wire values decode to `FinishReason::Other`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
    /// Position of this candidate in the response, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<i32>,
}
/// Token accounting reported for a generate-content call.
///
/// Every field is optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    /// Tokens in the prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_token_count: Option<i32>,
    /// Tokens across the generated candidates.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Total tokens as reported by the server (not recomputed here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_token_count: Option<i32>,
    /// Tokens spent on model thinking, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Per-modality breakdown of the prompt tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<Vec<PromptTokenDetails>>,
    /// Tokens served from cached content, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Per-modality breakdown of the cached tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_tokens_details: Option<Vec<PromptTokenDetails>>,
}
/// Token count attributed to a single modality (used in `UsageMetadata` breakdowns).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct PromptTokenDetails {
    /// Modality these tokens belong to.
    pub modality: Modality,
    /// Number of tokens for `modality`.
    ///
    /// Defaults to 0 when the key is absent. The proto3 JSON encoding omits fields holding
    /// their default value, so a zero count can arrive without a `tokenCount` key. Previously
    /// that made the whole response fail to deserialize.
    #[serde(default)]
    pub token_count: i32,
}
/// Grounding information attached to a candidate (`Candidate::grounding_metadata`).
///
/// All fields are optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GroundingMetadata {
    /// Sources used for grounding; each is a maps or web chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grounding_chunks: Option<Vec<GroundingChunk>>,
    /// Links between response segments and entries of `grounding_chunks`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grounding_supports: Option<Vec<GroundingSupport>>,
    /// Search queries reported for web grounding.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub web_search_queries: Option<Vec<String>>,
    /// Opaque token related to Google Maps grounding; passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google_maps_widget_context_token: Option<String>,
}
/// One grounding source. In practice presumably exactly one of `maps` / `web` is set, but
/// nothing here enforces that.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GroundingChunk {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maps: Option<MapsGroundingChunk>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub web: Option<WebGroundingChunk>,
}
/// A Google Maps grounding source.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct MapsGroundingChunk {
    /// Source URI, parsed as a `Url`. A string that is not a valid URL fails deserialization.
    /// Serialized as `null` when `None`.
    #[serde(default)]
    pub uri: Option<Url>,
    /// Display title; serialized as `null` when `None`.
    #[serde(default)]
    pub title: Option<String>,
    /// Place identifier; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub place_id: Option<String>,
}
/// A web search grounding source.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct WebGroundingChunk {
    /// Source URI, parsed as a `Url`. A string that is not a valid URL fails deserialization.
    /// Serialized as `null` when `None`.
    #[serde(default)]
    pub uri: Option<Url>,
    /// Page title; serialized as `null` when `None`.
    #[serde(default)]
    pub title: Option<String>,
}
/// Links a segment of the response text to the grounding chunks that support it.
///
/// Both fields fall back to their defaults when absent. The proto3 JSON encoding omits empty
/// repeated fields and unset messages. Treating them as required made an otherwise-valid
/// response fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GroundingSupport {
    /// Segment of the response this support applies to; an empty segment when absent.
    #[serde(default)]
    pub segment: GroundingSegment,
    /// Indices of the supporting chunks (presumably into `GroundingMetadata::grounding_chunks`);
    /// empty when absent.
    #[serde(default)]
    pub grounding_chunk_indices: Vec<u32>,
}
/// A span of the response text referenced by a `GroundingSupport`.
///
/// `Default` yields a segment with every field `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GroundingSegment {
    /// Start offset of the span. NOTE(review): unit (bytes vs. characters) not established here.
    #[serde(default)]
    pub start_index: Option<u32>,
    /// End offset of the span; presumably exclusive. Confirm against the API docs.
    #[serde(default)]
    pub end_index: Option<u32>,
    /// The text of the span, if reported.
    #[serde(default)]
    pub text: Option<String>,
}
/// Top-level response of a generate-content call (or one chunk of a streamed call).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GenerationResponse {
    /// Generated candidates. A missing key deserializes as empty, and an empty list is omitted
    /// when serialized.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub candidates: Vec<Candidate>,
    /// Feedback about the prompt (e.g. why it was blocked), if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,
    /// Token accounting, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
    /// Model version string reported by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
    /// Server-assigned response identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_id: Option<String>,
}
/// Reason a prompt was blocked (carried in `PromptFeedback::block_reason`).
///
/// Serializes as the SCREAMING_SNAKE_CASE variant name. Deserialization is hand-written: it
/// accepts either the string name or a numeric enum code, and maps any unrecognized value to
/// [`BlockReason::Other`].
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockReason {
    BlockReasonUnspecified,
    Safety,
    Other,
    Blocklist,
    ProhibitedContent,
    ModelArmor,
    Jailbreak,
    ImageSafety,
}
impl BlockReason {
    /// Decodes the string form of a prompt block reason.
    ///
    /// Both `BLOCK_REASON_UNSPECIFIED` and `BLOCKED_REASON_UNSPECIFIED` map to the unspecified
    /// variant. `"OTHER"` and any unrecognized name decode to [`BlockReason::Other`].
    fn from_wire_str(value: &str) -> Self {
        match value {
            "BLOCKLIST" => Self::Blocklist,
            "BLOCK_REASON_UNSPECIFIED" | "BLOCKED_REASON_UNSPECIFIED" => {
                Self::BlockReasonUnspecified
            }
            "IMAGE_SAFETY" => Self::ImageSafety,
            "JAILBREAK" => Self::Jailbreak,
            "MODEL_ARMOR" => Self::ModelArmor,
            "PROHIBITED_CONTENT" => Self::ProhibitedContent,
            "SAFETY" => Self::Safety,
            // "OTHER" and anything unknown share the same fallback.
            _ => Self::Other,
        }
    }

    /// Decodes the numeric enum code of a prompt block reason.
    ///
    /// Codes 0-7 are recognized. Code 2 and every other value decode to [`BlockReason::Other`].
    fn from_wire_number(value: i64) -> Self {
        match value {
            0 => Self::BlockReasonUnspecified,
            1 => Self::Safety,
            3 => Self::Blocklist,
            4 => Self::ProhibitedContent,
            5 => Self::ModelArmor,
            6 => Self::Jailbreak,
            7 => Self::ImageSafety,
            // 2 is OTHER; unknown codes fall back to the same variant.
            _ => Self::Other,
        }
    }
}
impl<'de> Deserialize<'de> for BlockReason {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(Self::from_wire_str(&s)),
serde_json::Value::Number(n) => {
n.as_i64().map(Self::from_wire_number).ok_or_else(|| {
de::Error::custom("blockReason must be an integer-compatible number")
})
}
_ => Err(de::Error::custom("blockReason must be a string or integer")),
}
}
}
/// Feedback about the prompt itself, carried in `GenerationResponse::prompt_feedback`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    /// Safety ratings for the prompt. A missing key deserializes as empty, and an empty list is
    /// omitted when serialized.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub safety_ratings: Vec<SafetyRating>,
    /// Why the prompt was blocked; `None` when no block reason was reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<BlockReason>,
}
impl GenerationResponse {
pub fn text(&self) -> String {
self.candidates
.first()
.and_then(|c| {
c.content.parts.as_ref().and_then(|parts| {
parts.first().and_then(|p| match p {
Part::Text { text, thought: _, thought_signature: _ } => Some(text.clone()),
_ => None,
})
})
})
.unwrap_or_default()
}
pub fn function_calls(&self) -> Vec<&crate::tools::FunctionCall> {
self.candidates
.iter()
.flat_map(|c| {
c.content
.parts
.as_ref()
.map(|parts| {
parts
.iter()
.filter_map(|p| match p {
Part::FunctionCall { function_call, thought_signature: _ } => {
Some(function_call)
}
_ => None,
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
})
.collect()
}
pub fn function_calls_with_thoughts(
&self,
) -> Vec<(&crate::tools::FunctionCall, Option<&String>)> {
self.candidates
.iter()
.flat_map(|c| {
c.content
.parts
.as_ref()
.map(|parts| {
parts
.iter()
.filter_map(|p| match p {
Part::FunctionCall { function_call, thought_signature } => {
Some((function_call, thought_signature.as_ref()))
}
_ => None,
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
})
.collect()
}
pub fn thoughts(&self) -> Vec<String> {
self.candidates
.iter()
.flat_map(|c| {
c.content
.parts
.as_ref()
.map(|parts| {
parts
.iter()
.filter_map(|p| match p {
Part::Text { text, thought: Some(true), thought_signature: _ } => {
Some(text.clone())
}
_ => None,
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
})
.collect()
}
pub fn all_text(&self) -> Vec<(String, bool)> {
self.candidates
.iter()
.flat_map(|c| {
c.content
.parts
.as_ref()
.map(|parts| {
parts
.iter()
.filter_map(|p| match p {
Part::Text { text, thought, thought_signature: _ } => {
Some((text.clone(), thought.unwrap_or(false)))
}
_ => None,
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
})
.collect()
}
pub fn text_with_thoughts(&self) -> Vec<(String, bool, Option<&String>)> {
self.candidates
.iter()
.flat_map(|c| {
c.content
.parts
.as_ref()
.map(|parts| {
parts
.iter()
.filter_map(|p| match p {
Part::Text { text, thought, thought_signature } => Some((
text.clone(),
thought.unwrap_or(false),
thought_signature.as_ref(),
)),
_ => None,
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
})
.collect()
}
}
/// Request body for a generate-content call.
///
/// `contents` is always serialized; every optional field is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Conversation turns sent to the model.
    pub contents: Vec<Content>,
    /// Sampling and output settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    /// Per-category safety thresholds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<crate::tools::Tool>>,
    /// Tool configuration; see `strip_vertex_unsupported_fields` for Vertex-specific pruning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<crate::tools::ToolConfig>,
    /// System instruction content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<Content>,
    /// Cached-content reference; presumably a resource name. Confirm with callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content: Option<String>,
}
impl GenerateContentRequest {
    /// Clears request fields that the Vertex AI backend does not accept.
    ///
    /// Currently this only unsets `tool_config.include_server_side_tool_invocations`. When there
    /// is no tool config, the request is left untouched.
    pub fn strip_vertex_unsupported_fields(&mut self) {
        let Some(tool_config) = self.tool_config.as_mut() else {
            return;
        };
        tool_config.include_server_side_tool_invocations = None;
    }
}
/// Coarse thinking-effort level; serialized in lowercase (`"minimal"`, `"low"`, `"medium"`,
/// `"high"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingLevel {
    Minimal,
    Low,
    Medium,
    High,
}
/// Model "thinking" settings, nested in `GenerationConfig::thinking_config`.
///
/// Every field is omitted from serialized output when `None`. `validate` rejects configs that
/// set both `thinking_budget` and `thinking_level`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Explicit thinking token budget; the builders in this file use `-1` for dynamic thinking.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<i32>,
    /// Whether thought parts should be returned in the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_thoughts: Option<bool>,
    /// Coarse thinking level, as an alternative to `thinking_budget`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_level: Option<ThinkingLevel>,
}
impl ThinkingConfig {
pub fn validate(&self) -> Result<(), String> {
if self.thinking_budget.is_some() && self.thinking_level.is_some() {
return Err(
"thinking_budget and thinking_level are mutually exclusive; use one or the other"
.to_string(),
);
}
Ok(())
}
pub fn new() -> Self {
Self { thinking_budget: None, include_thoughts: None, thinking_level: None }
}
pub fn with_thinking_budget(mut self, budget: i32) -> Self {
self.thinking_budget = Some(budget);
self
}
pub fn with_dynamic_thinking(mut self) -> Self {
self.thinking_budget = Some(-1);
self
}
pub fn with_thoughts_included(mut self, include: bool) -> Self {
self.include_thoughts = Some(include);
self
}
pub fn with_thinking_level(mut self, level: ThinkingLevel) -> Self {
self.thinking_level = Some(level);
self
}
pub fn dynamic_thinking() -> Self {
Self { thinking_budget: Some(-1), include_thoughts: Some(true), thinking_level: None }
}
}
impl Default for ThinkingConfig {
fn default() -> Self {
Self::new()
}
}
/// Sampling and output settings for a generate-content request.
///
/// Every field is optional and omitted from serialized output when `None`. `validate` performs
/// client-side range checks on some of these fields.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    /// Sampling temperature; `validate` requires 0.0..=2.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus sampling probability mass; `validate` requires 0.0..=1.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff; `validate` requires a positive value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    /// Output token limit; `validate` requires a positive value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<i32>,
    /// Number of candidates to generate (not checked by `validate`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<i32>,
    /// Sequences that stop generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Requested response MIME type (e.g. for JSON output).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mime_type: Option<String>,
    /// Response schema, passed through as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,
    /// Requested response modalities, as raw strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_modalities: Option<Vec<String>>,
    /// Speech (voice) output settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speech_config: Option<SpeechConfig>,
    /// Thinking settings; validated as part of `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}
impl GenerationConfig {
pub fn validate(&self) -> Result<(), String> {
if let Some(t) = self.temperature
&& !(0.0..=2.0).contains(&t)
{
return Err("temperature must be between 0.0 and 2.0".to_string());
}
if let Some(p) = self.top_p
&& !(0.0..=1.0).contains(&p)
{
return Err("top_p must be between 0.0 and 1.0".to_string());
}
if let Some(k) = self.top_k
&& k <= 0
{
return Err("top_k must be positive".to_string());
}
if let Some(m) = self.max_output_tokens
&& m <= 0
{
return Err("max_output_tokens must be positive".to_string());
}
if let Some(ref tc) = self.thinking_config {
tc.validate()?;
}
Ok(())
}
}
/// Speech output settings: either a single voice or a multi-speaker mapping.
///
/// Both fields are optional and omitted from serialized output when `None`. The constructors
/// in this file set exactly one of them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct SpeechConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice_config: Option<VoiceConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub multi_speaker_voice_config: Option<MultiSpeakerVoiceConfig>,
}
/// Voice selection; currently only a prebuilt voice can be chosen.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct VoiceConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prebuilt_voice_config: Option<PrebuiltVoiceConfig>,
}
/// A prebuilt voice identified by name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct PrebuiltVoiceConfig {
    pub voice_name: String,
}
/// Voice assignments for a multi-speaker output.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct MultiSpeakerVoiceConfig {
    pub speaker_voice_configs: Vec<SpeakerVoiceConfig>,
}
/// Maps one named speaker to a voice.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct SpeakerVoiceConfig {
    pub speaker: String,
    pub voice_config: VoiceConfig,
}
impl SpeechConfig {
    /// Builds a config that speaks with one prebuilt voice.
    pub fn single_voice(voice_name: impl Into<String>) -> Self {
        let prebuilt = PrebuiltVoiceConfig { voice_name: voice_name.into() };
        let voice = VoiceConfig { prebuilt_voice_config: Some(prebuilt) };
        Self { multi_speaker_voice_config: None, voice_config: Some(voice) }
    }

    /// Builds a multi-speaker config from per-speaker voice assignments (kept in order).
    pub fn multi_speaker(speakers: Vec<SpeakerVoiceConfig>) -> Self {
        let multi = MultiSpeakerVoiceConfig { speaker_voice_configs: speakers };
        Self { voice_config: None, multi_speaker_voice_config: Some(multi) }
    }
}
impl SpeakerVoiceConfig {
    /// Assigns the prebuilt voice `voice_name` to the speaker named `speaker`.
    pub fn new(speaker: impl Into<String>, voice_name: impl Into<String>) -> Self {
        // Convert in the same order as the arguments are declared.
        let speaker = speaker.into();
        let voice_name = voice_name.into();
        let prebuilt_voice_config = Some(PrebuiltVoiceConfig { voice_name });
        Self { speaker, voice_config: VoiceConfig { prebuilt_voice_config } }
    }
}