1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum Role {
7 User,
9 Model,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum Part {
17 Text {
19 text: String,
21 },
22 InlineData {
23 #[serde(rename = "inlineData")]
25 inline_data: Blob,
26 },
27 FunctionCall {
29 #[serde(rename = "functionCall")]
31 function_call: super::tools::FunctionCall,
32 },
33 FunctionResponse {
35 #[serde(rename = "functionResponse")]
37 function_response: super::tools::FunctionResponse,
38 },
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct Blob {
45 pub mime_type: String,
46 pub data: String, }
48
49impl Blob {
50 pub fn new(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
52 Self {
53 mime_type: mime_type.into(),
54 data: data.into(),
55 }
56 }
57}
58
59#[derive(Debug, Default, Clone, Serialize, Deserialize)]
61pub struct Content {
62 pub parts: Vec<Part>,
64 #[serde(skip_serializing_if = "Option::is_none")]
66 pub role: Option<Role>,
67}
68
69impl Content {
70 pub fn text(text: impl Into<String>) -> Self {
72 Self {
73 parts: vec![Part::Text { text: text.into() }],
74 role: None,
75 }
76 }
77
78 pub fn function_call(function_call: super::tools::FunctionCall) -> Self {
80 Self {
81 parts: vec![Part::FunctionCall { function_call }],
82 role: None,
83 }
84 }
85
86 pub fn function_response(function_response: super::tools::FunctionResponse) -> Self {
88 Self {
89 parts: vec![Part::FunctionResponse { function_response }],
90 role: None,
91 }
92 }
93
94 pub fn function_response_json(name: impl Into<String>, response: serde_json::Value) -> Self {
96 Self {
97 parts: vec![Part::FunctionResponse {
98 function_response: super::tools::FunctionResponse::new(name, response),
99 }],
100 role: None,
101 }
102 }
103
104 pub fn inline_data(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
106 Self {
107 parts: vec![Part::InlineData {
108 inline_data: Blob::new(mime_type, data),
109 }],
110 role: None,
111 }
112 }
113
114 pub fn with_role(mut self, role: Role) -> Self {
116 self.role = Some(role);
117 self
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct Message {
124 pub content: Content,
126 pub role: Role,
128}
129
130impl Message {
131 pub fn user(text: impl Into<String>) -> Self {
133 Self {
134 content: Content::text(text).with_role(Role::User),
135 role: Role::User,
136 }
137 }
138
139 pub fn model(text: impl Into<String>) -> Self {
141 Self {
142 content: Content::text(text).with_role(Role::Model),
143 role: Role::Model,
144 }
145 }
146
147 pub fn embed(text: impl Into<String>) -> Self {
148 Self {
149 content: Content::text(text),
150 role: Role::Model,
151 }
152 }
153
154 pub fn function(name: impl Into<String>, response: serde_json::Value) -> Self {
156 Self {
157 content: Content::function_response_json(name, response).with_role(Role::Model),
158 role: Role::Model,
159 }
160 }
161
162 pub fn function_str(
164 name: impl Into<String>,
165 response: impl Into<String>,
166 ) -> Result<Self, serde_json::Error> {
167 let response_str = response.into();
168 let json = serde_json::from_str(&response_str)?;
169 Ok(Self {
170 content: Content::function_response_json(name, json).with_role(Role::Model),
171 role: Role::Model,
172 })
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct SafetyRating {
179 pub category: String,
181 pub probability: String,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187#[serde(rename_all = "camelCase")]
188pub struct CitationMetadata {
189 pub citation_sources: Vec<CitationSource>,
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195#[serde(rename_all = "camelCase")]
196pub struct CitationSource {
197 pub uri: Option<String>,
199 pub title: Option<String>,
201 pub start_index: Option<i32>,
203 pub end_index: Option<i32>,
205 pub license: Option<String>,
207 pub publication_date: Option<String>,
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
213#[serde(rename_all = "camelCase")]
214pub struct Candidate {
215 pub content: Content,
217 #[serde(skip_serializing_if = "Option::is_none")]
219 pub safety_ratings: Option<Vec<SafetyRating>>,
220 #[serde(skip_serializing_if = "Option::is_none")]
222 pub citation_metadata: Option<CitationMetadata>,
223 #[serde(skip_serializing_if = "Option::is_none")]
225 pub finish_reason: Option<String>,
226 #[serde(skip_serializing_if = "Option::is_none")]
228 pub usage_metadata: Option<UsageMetadata>,
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233#[serde(rename_all = "camelCase")]
234pub struct UsageMetadata {
235 pub prompt_token_count: i32,
237 pub candidates_token_count: i32,
239 pub total_token_count: i32,
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
245#[serde(rename_all = "camelCase")]
246pub struct GenerationResponse {
247 pub candidates: Vec<Candidate>,
249 #[serde(skip_serializing_if = "Option::is_none")]
251 pub prompt_feedback: Option<PromptFeedback>,
252 #[serde(skip_serializing_if = "Option::is_none")]
254 pub usage_metadata: Option<UsageMetadata>,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct ContentEmbedding {
260 pub values: Vec<f32>, }
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct ContentEmbeddingResponse {
267 pub embedding: ContentEmbedding,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct BatchContentEmbeddingResponse {
274 pub embeddings: Vec<ContentEmbedding>,
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280#[serde(rename_all = "camelCase")]
281pub struct PromptFeedback {
282 pub safety_ratings: Vec<SafetyRating>,
284 #[serde(skip_serializing_if = "Option::is_none")]
286 pub block_reason: Option<String>,
287}
288
289impl GenerationResponse {
290 pub fn text(&self) -> String {
292 self.candidates
293 .first()
294 .and_then(|c| {
295 c.content.parts.first().and_then(|p| match p {
296 Part::Text { text } => Some(text.clone()),
297 _ => None,
298 })
299 })
300 .unwrap_or_default()
301 }
302
303 pub fn function_calls(&self) -> Vec<&super::tools::FunctionCall> {
305 self.candidates
306 .iter()
307 .flat_map(|c| {
308 c.content.parts.iter().filter_map(|p| match p {
309 Part::FunctionCall { function_call } => Some(function_call),
310 _ => None,
311 })
312 })
313 .collect()
314 }
315}
316
317#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct GenerateContentRequest {
320 pub contents: Vec<Content>,
322 #[serde(skip_serializing_if = "Option::is_none")]
324 pub generation_config: Option<GenerationConfig>,
325 #[serde(skip_serializing_if = "Option::is_none")]
327 pub safety_settings: Option<Vec<SafetySetting>>,
328 #[serde(skip_serializing_if = "Option::is_none")]
330 pub tools: Option<Vec<super::tools::Tool>>,
331 #[serde(skip_serializing_if = "Option::is_none")]
333 pub tool_config: Option<ToolConfig>,
334 #[serde(skip_serializing_if = "Option::is_none")]
336 pub system_instruction: Option<Content>,
337}
338
339#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct EmbedContentRequest {
342 pub model: String,
344 pub content: Content,
346 #[serde(skip_serializing_if = "Option::is_none")]
348 pub task_type: Option<TaskType>,
349 #[serde(skip_serializing_if = "Option::is_none")]
351 pub title: Option<String>,
352 #[serde(skip_serializing_if = "Option::is_none")]
354 pub output_dimensionality: Option<i32>,
355}
356
357#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct BatchEmbedContentsRequest {
360 pub requests: Vec<EmbedContentRequest>,
362}
363
364#[derive(Debug, Clone, Serialize, Deserialize)]
366pub struct GenerationConfig {
367 #[serde(skip_serializing_if = "Option::is_none")]
372 pub temperature: Option<f32>,
373
374 #[serde(skip_serializing_if = "Option::is_none")]
380 pub top_p: Option<f32>,
381
382 #[serde(skip_serializing_if = "Option::is_none")]
387 pub top_k: Option<i32>,
388
389 #[serde(skip_serializing_if = "Option::is_none")]
393 pub max_output_tokens: Option<i32>,
394
395 #[serde(skip_serializing_if = "Option::is_none")]
399 pub candidate_count: Option<i32>,
400
401 #[serde(skip_serializing_if = "Option::is_none")]
405 pub stop_sequences: Option<Vec<String>>,
406
407 #[serde(skip_serializing_if = "Option::is_none")]
411 pub response_mime_type: Option<String>,
412
413 #[serde(skip_serializing_if = "Option::is_none")]
417 pub response_schema: Option<serde_json::Value>,
418}
419
420impl Default for GenerationConfig {
421 fn default() -> Self {
422 Self {
423 temperature: Some(0.7),
424 top_p: Some(0.95),
425 top_k: Some(40),
426 max_output_tokens: Some(1024),
427 candidate_count: Some(1),
428 stop_sequences: None,
429 response_mime_type: None,
430 response_schema: None,
431 }
432 }
433}
434
435#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct ToolConfig {
438 #[serde(skip_serializing_if = "Option::is_none")]
440 pub function_calling_config: Option<FunctionCallingConfig>,
441}
442
443#[derive(Debug, Clone, Serialize, Deserialize)]
445pub struct FunctionCallingConfig {
446 pub mode: FunctionCallingMode,
448}
449
450#[derive(Debug, Clone, Serialize, Deserialize)]
452#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
453pub enum FunctionCallingMode {
454 Auto,
456 Any,
458 None,
460}
461
462#[derive(Debug, Clone, Serialize, Deserialize)]
464pub struct SafetySetting {
465 pub category: HarmCategory,
467 pub threshold: HarmBlockThreshold,
469}
470
471#[derive(Debug, Clone, Serialize, Deserialize)]
473#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
474pub enum HarmCategory {
475 Dangerous,
477 Harassment,
479 HateSpeech,
481 SexuallyExplicit,
483}
484
485#[allow(clippy::enum_variant_names)]
487#[derive(Debug, Clone, Serialize, Deserialize)]
488#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
489pub enum HarmBlockThreshold {
490 BlockLowAndAbove,
492 BlockMediumAndAbove,
494 BlockHighAndAbove,
496 BlockOnlyHigh,
498 BlockNone,
500}
501
502#[derive(Debug, Clone, Serialize, Deserialize)]
504#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
505pub enum TaskType {
506 SemanticSimilarity,
508 Classification,
510 Clustering,
512
513 RetrievalDocument,
515 RetrievalQuery,
516 QuestionAnswering,
517 FactVerification,
518
519 CodeRetrievalQuery,
522}