use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::types::UsageMetadata;
/// Request body for a content-generation call.
///
/// Keys are emitted in camelCase. Optional sections (and an empty `tools`
/// list) are left out of the JSON entirely rather than sent as `null` / `[]`.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Optional system prompt, sent as `systemInstruction` when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<Content>,
    /// Conversation turns, sent as `contents`; always serialized, even when empty.
    pub contents: Vec<Content>,
    /// Tool declarations; the `tools` key is omitted when this list is empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDecl>,
    /// Function-calling settings, sent as `toolConfig` when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
    /// Sampling / output settings, sent as `generationConfig` when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
}
/// One conversation turn: who produced it and the ordered parts it contains.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Author of this turn (`"user"` or `"model"` on the wire).
    pub role: ContentRole,
    /// Ordered payload of the turn.
    ///
    /// Defaults to an empty list when the `parts` key is absent. The API may
    /// return a candidate whose `content` has a `role` but no `parts` (for
    /// example, when generation stops before emitting output). Requiring the key
    /// would make the entire response chunk fail to deserialize in that case.
    #[serde(default)]
    pub parts: Vec<Part>,
}
/// Author of a [`Content`] turn, serialized in lowercase (`"user"` / `"model"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ContentRole {
    /// Turn supplied by the caller; wire value `"user"`.
    User,
    /// Turn produced by the model; wire value `"model"`.
    Model,
}
/// A single piece of a [`Content`] turn.
///
/// The enum is `untagged`: on deserialization serde tries each variant in
/// declaration order and keeps the first one whose required keys are present
/// (unknown keys are ignored). Order therefore matters. `Thought` must come
/// before `Text`, or `{"thought": true, "text": "..."}` would be captured by
/// the plain `Text` variant. `Text` is last as the most permissive shape.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Part {
    /// A tool invocation requested by the model (`functionCall` key).
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: FunctionCall,
    },
    /// The result of a tool invocation (`functionResponse` key).
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: FunctionResponse,
    },
    /// Inline binary payload with its MIME type (`inlineData` key).
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: InlineData,
    },
    /// A reasoning part. The presence of the `thought` key (with any boolean
    /// value) selects this variant.
    Thought {
        thought: bool,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        text: Option<String>,
        /// Opaque signature attached by the API.
        ///
        /// Untagged variants get no automatic camelCase renaming, so the wire key
        /// `thoughtSignature` is spelled out explicitly. Without it, the signature
        /// was read and written as `thought_signature`: signatures from the API
        /// were silently dropped and never sent back. The snake_case alias keeps
        /// JSON written under the old key readable.
        #[serde(
            default,
            rename = "thoughtSignature",
            alias = "thought_signature",
            skip_serializing_if = "Option::is_none"
        )]
        thought_signature: Option<String>,
    },
    /// Plain text (`text` key).
    Text { text: String },
}
/// A function invocation emitted by the model.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCall {
    /// Name of the function being invoked.
    pub name: String,
    /// Call arguments as arbitrary JSON; `Value::Null` when the `args` key is missing.
    #[serde(default)]
    pub args: Value,
}
/// The result of running a function, sent back to the model.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponse {
    /// Name of the function that produced this response.
    pub name: String,
    /// Function output as arbitrary JSON (required on the wire).
    pub response: Value,
}
/// Binary content carried inline in a [`Part`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineData {
    /// MIME type of the payload, sent as `mimeType`.
    pub mime_type: String,
    /// Encoded payload. NOTE(review): presumably base64 — confirm against callers.
    pub data: String,
}
/// One entry of the request's `tools` array.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolDecl {
    /// Declared functions; the `functionDeclarations` key is omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub function_declarations: Vec<FunctionDeclaration>,
}
/// Describes one function the model is allowed to call.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
    /// Function name the model uses when calling it.
    pub name: String,
    /// Human-readable description of what the function does.
    pub description: String,
    /// Argument description as raw JSON. NOTE(review): presumably a JSON-Schema-style
    /// object — confirm the expected shape with callers.
    pub parameters: Value,
}
/// Request-level tool settings, serialized as `toolConfig`.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    /// Always serialized as `functionCallingConfig`.
    pub function_calling_config: FunctionCallingConfig,
}
/// Controls how the model may call declared functions.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    /// Calling mode; the `mode` key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<FunctionCallingMode>,
}
/// Function-calling mode, serialized in uppercase.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum FunctionCallingMode {
    /// Wire value `"AUTO"`.
    Auto,
    /// Wire value `"ANY"`.
    Any,
    /// Wire value `"NONE"`.
    None,
}
/// Sampling and output settings, serialized as `generationConfig`.
///
/// Every field is optional and left out of the JSON when `None`.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    /// Reasoning settings (`thinkingConfig`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
    /// Requested output MIME type (`responseMimeType`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mime_type: Option<String>,
    /// Raw JSON schema for structured output (`responseSchema`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<Value>,
    /// Sampling temperature (`temperature`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens (`maxOutputTokens`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
}
/// Reasoning ("thinking") settings nested inside [`GenerationConfig`].
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Token budget for reasoning; always serialized as `thinkingBudget`.
    pub thinking_budget: u32,
    /// Whether reasoning parts should be returned (`includeThoughts`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_thoughts: Option<bool>,
}
/// One response object from the server (a whole response or a single streamed chunk).
///
/// Every key is optional on the wire. A missing key takes its value from
/// `GenerateChunk::default()`: an empty `candidates` list and `None` for the rest.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct GenerateChunk {
    /// Candidate completions carried by this chunk.
    pub candidates: Vec<Candidate>,
    /// Token accounting, read from `usageMetadata` when present.
    pub usage_metadata: Option<WireUsage>,
    /// Model version string reported by the server (`modelVersion`).
    pub model_version: Option<String>,
    /// Server-assigned response identifier (`responseId`).
    pub response_id: Option<String>,
}
/// A single candidate completion inside a [`GenerateChunk`].
///
/// Missing keys fall back to `Candidate::default()`, i.e. every field becomes `None`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Candidate {
    /// Generated content, if the server included any.
    pub content: Option<Content>,
    /// Why generation stopped (`finishReason`), if reported.
    pub finish_reason: Option<FinishReason>,
    /// Position of this candidate (`index`), if reported.
    pub index: Option<u32>,
}
/// Why a candidate stopped generating.
///
/// Wire values are SCREAMING_SNAKE_CASE (e.g. `"STOP"`, `"MAX_TOKENS"`,
/// `"MALFORMED_FUNCTION_CALL"`). Any value not listed here deserializes to
/// [`FinishReason::Unknown`] instead of failing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    ToolUse,
    Language,
    Other,
    Blocklist,
    ProhibitedContent,
    Spii,
    MalformedFunctionCall,
    FinishReasonUnspecified,
    /// Catch-all for reasons this enum does not list yet.
    #[serde(other)]
    Unknown,
}
/// Token counters exactly as they appear in a response's `usageMetadata`.
///
/// Every counter is optional; missing keys default to `None`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct WireUsage {
    /// `promptTokenCount`
    pub prompt_token_count: Option<i32>,
    /// `cachedContentTokenCount`
    pub cached_content_token_count: Option<i32>,
    /// `candidatesTokenCount`
    pub candidates_token_count: Option<i32>,
    /// `thoughtsTokenCount`
    pub thoughts_token_count: Option<i32>,
    /// `totalTokenCount`
    pub total_token_count: Option<i32>,
}
impl From<WireUsage> for UsageMetadata {
fn from(w: WireUsage) -> Self {
UsageMetadata {
prompt_token_count: w.prompt_token_count,
cached_content_token_count: w.cached_content_token_count,
candidates_token_count: w.candidates_token_count,
thoughts_token_count: w.thoughts_token_count,
total_token_count: w.total_token_count,
}
}
}
impl Content {
pub fn user_text(text: impl Into<String>) -> Self {
Self {
role: ContentRole::User,
parts: vec![Part::Text { text: text.into() }],
}
}
pub fn model_text(text: impl Into<String>) -> Self {
Self {
role: ContentRole::Model,
parts: vec![Part::Text { text: text.into() }],
}
}
pub fn system_text(text: impl Into<String>) -> Self {
Self {
role: ContentRole::User,
parts: vec![Part::Text { text: text.into() }],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a single [`Part`] from a JSON literal, failing the test on error.
    fn parse_part(json: &str) -> Part {
        serde_json::from_str(json).expect("part JSON should deserialize")
    }

    #[test]
    fn deserialize_text_part() {
        let part = parse_part(r#"{"text":"hello"}"#);
        assert_eq!(part, Part::Text { text: "hello".into() });
    }

    #[test]
    fn deserialize_thought_part() {
        let part = parse_part(r#"{"thought":true,"text":"reasoning..."}"#);
        if let Part::Thought { thought, text, .. } = &part {
            assert!(*thought, "thought flag should be set");
            assert_eq!(text.as_deref(), Some("reasoning..."));
        } else {
            panic!("expected Thought, got {part:?}");
        }
    }

    #[test]
    fn deserialize_function_call_part() {
        let json = r#"{"functionCall":{"name":"view_file","args":{"path":"x.txt"}}}"#;
        let call = match parse_part(json) {
            Part::FunctionCall { function_call } => function_call,
            other => panic!("expected FunctionCall, got {other:?}"),
        };
        assert_eq!(call.name, "view_file");
        assert_eq!(call.args["path"], "x.txt");
    }

    #[test]
    fn round_trip_chunk() {
        let json = r#"{
            "candidates": [{
                "content": {"role":"model","parts":[{"text":"hi"}]},
                "finishReason": "STOP"
            }],
            "usageMetadata": {"promptTokenCount":3,"candidatesTokenCount":1,"totalTokenCount":4}
        }"#;
        let chunk: GenerateChunk =
            serde_json::from_str(json).expect("chunk JSON should deserialize");

        let [candidate] = chunk.candidates.as_slice() else {
            panic!("expected exactly one candidate, got {:?}", chunk.candidates);
        };
        assert_eq!(candidate.finish_reason, Some(FinishReason::Stop));

        let wire = chunk.usage_metadata.expect("usageMetadata should be present");
        let usage = UsageMetadata::from(wire);
        assert_eq!(usage.total_token_count, Some(4));
    }
}