inference_runtime_gemini/
wire.rs1use serde::{Deserialize, Serialize};
2
3use inference_core::batch::{ContentPart, ExecuteBatch, MessageContent, Role};
4
5#[derive(Debug, Serialize)]
6pub struct GenerateContentRequest<'a> {
7 pub contents: Vec<Content>,
8 #[serde(skip_serializing_if = "Option::is_none")]
9 pub system_instruction: Option<Content>,
10 #[serde(skip_serializing_if = "Option::is_none")]
11 pub generation_config: Option<GenerationConfig>,
12 #[serde(skip_serializing_if = "Vec::is_empty", rename = "safetySettings")]
13 pub safety_settings: Vec<crate::config::SafetySetting>,
14 #[serde(skip)]
15 _model_lifetime: std::marker::PhantomData<&'a ()>,
16}
17
18#[derive(Debug, Serialize)]
19pub struct Content {
20 pub role: String,
21 pub parts: Vec<Part>,
22}
23
24#[derive(Debug, Serialize)]
25#[serde(untagged)]
26pub enum Part {
27 Text {
28 text: String,
29 },
30 InlineData {
31 #[serde(rename = "inlineData")]
32 inline_data: InlineData,
33 },
34 FileData {
35 #[serde(rename = "fileData")]
36 file_data: FileData,
37 },
38}
39
40#[derive(Debug, Serialize)]
41pub struct InlineData {
42 #[serde(rename = "mimeType")]
43 pub mime_type: String,
44 pub data: String,
45}
46
47#[derive(Debug, Serialize)]
48pub struct FileData {
49 #[serde(rename = "mimeType")]
50 pub mime_type: String,
51 #[serde(rename = "fileUri")]
52 pub file_uri: String,
53}
54
55#[derive(Debug, Serialize, Default)]
56pub struct GenerationConfig {
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub temperature: Option<f32>,
59 #[serde(skip_serializing_if = "Option::is_none", rename = "topP")]
60 pub top_p: Option<f32>,
61 #[serde(skip_serializing_if = "Option::is_none", rename = "topK")]
62 pub top_k: Option<u32>,
63 #[serde(skip_serializing_if = "Option::is_none", rename = "maxOutputTokens")]
64 pub max_output_tokens: Option<u32>,
65 #[serde(skip_serializing_if = "Vec::is_empty", rename = "stopSequences")]
66 pub stop_sequences: Vec<String>,
67}
68
69impl GenerateContentRequest<'_> {
70 pub fn from_batch<'b>(
71 b: &'b ExecuteBatch,
72 safety: Vec<crate::config::SafetySetting>,
73 ) -> GenerateContentRequest<'b> {
74 let mut system: Option<String> = None;
75 let mut contents = Vec::with_capacity(b.messages.len());
76 for m in &b.messages {
77 if matches!(m.role, Role::System) {
78 if let MessageContent::Text(t) = &m.content {
79 system = Some(system.map(|s| format!("{s}\n{t}")).unwrap_or_else(|| t.clone()));
80 }
81 continue;
82 }
83 let role = match m.role {
84 Role::User | Role::Tool => "user",
85 Role::Assistant => "model",
86 Role::System => unreachable!(),
87 }
88 .to_string();
89 let parts = match &m.content {
90 MessageContent::Text(t) => vec![Part::Text { text: t.clone() }],
91 MessageContent::Parts(parts) => parts.iter().map(serialize_part).collect(),
92 };
93 contents.push(Content { role, parts });
94 }
95 let system_instruction = system.map(|t| Content {
96 role: "system".into(),
97 parts: vec![Part::Text { text: t }],
98 });
99 GenerateContentRequest {
100 contents,
101 system_instruction,
102 generation_config: Some(GenerationConfig {
103 temperature: b.sampling.temperature,
104 top_p: b.sampling.top_p,
105 top_k: b.sampling.top_k,
106 max_output_tokens: b.sampling.max_tokens,
107 stop_sequences: b.sampling.stop.clone(),
108 }),
109 safety_settings: safety,
110 _model_lifetime: std::marker::PhantomData,
111 }
112 }
113}
114
115fn serialize_part(p: &ContentPart) -> Part {
116 match p {
117 ContentPart::Text { text } => Part::Text { text: text.clone() },
118 ContentPart::ImageBase64 { mime, data } => Part::InlineData {
119 inline_data: InlineData {
120 mime_type: mime.clone(),
121 data: data.clone(),
122 },
123 },
124 ContentPart::ImageUrl { url } => Part::FileData {
125 file_data: FileData {
126 mime_type: "image/jpeg".into(),
127 file_uri: url.clone(),
128 },
129 },
130 }
131}
132
133#[derive(Debug, Deserialize)]
136pub struct GenerateContentResponse {
137 #[serde(default)]
138 pub candidates: Vec<Candidate>,
139 #[serde(default, rename = "usageMetadata")]
140 pub usage_metadata: Option<UsageMetadata>,
141}
142
143#[derive(Debug, Deserialize)]
144pub struct Candidate {
145 #[serde(default)]
146 pub content: Option<ResponseContent>,
147 #[serde(default, rename = "finishReason")]
148 pub finish_reason: Option<String>,
149}
150
151#[derive(Debug, Deserialize)]
152pub struct ResponseContent {
153 #[serde(default)]
154 pub parts: Vec<ResponsePart>,
155}
156
157#[derive(Debug, Deserialize)]
158pub struct ResponsePart {
159 #[serde(default)]
160 pub text: Option<String>,
161}
162
163#[derive(Debug, Deserialize, Default, Clone, Copy)]
164pub struct UsageMetadata {
165 #[serde(default, rename = "promptTokenCount")]
166 pub prompt_token_count: u32,
167 #[serde(default, rename = "candidatesTokenCount")]
168 pub candidates_token_count: u32,
169 #[serde(default, rename = "cachedContentTokenCount")]
170 pub cached_content_token_count: u32,
171}