/// `gemini-2.0-flash` completion model
pub const GEMINI_2_0_FLASH: &str = "gemini-2.0-flash";
/// `gemini-1.5-flash` completion model
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro` completion model
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-pro-8b` completion model.
///
/// NOTE(review): Google's published 8B Gemini 1.5 model id is `gemini-1.5-flash-8b`
/// (see [`GEMINI_1_5_FLASH_8B`]); `gemini-1.5-pro-8b` does not appear in the model list.
/// The value is kept unchanged for backward compatibility — prefer the flash-8b constant.
pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
/// `gemini-1.5-flash-8b` completion model
pub const GEMINI_1_5_FLASH_8B: &str = "gemini-1.5-flash-8b";
/// `gemini-1.0-pro` completion model
pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
16
17use gemini_api_types::{
18 Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse,
19 GenerationConfig, Part, Role, Tool,
20};
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23
24use crate::{
25 completion::{self, CompletionError, CompletionRequest},
26 OneOrMany,
27};
28
29use super::Client;
30
31#[derive(Clone)]
36pub struct CompletionModel {
37 pub(crate) client: Client,
38 pub model: String,
39}
40
41impl CompletionModel {
42 pub fn new(client: Client, model: &str) -> Self {
43 Self {
44 client,
45 model: model.to_string(),
46 }
47 }
48}
49
50impl completion::CompletionModel for CompletionModel {
51 type Response = GenerateContentResponse;
52
53 #[cfg_attr(feature = "worker", worker::send)]
54 async fn completion(
55 &self,
56 completion_request: CompletionRequest,
57 ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
58 let request = create_request_body(completion_request)?;
59
60 tracing::debug!(
61 "Sending completion request to Gemini API {}",
62 serde_json::to_string_pretty(&request)?
63 );
64
65 let response = self
66 .client
67 .post(&format!("/v1beta/models/{}:generateContent", self.model))
68 .json(&request)
69 .send()
70 .await?;
71
72 if response.status().is_success() {
73 let response = response.json::<GenerateContentResponse>().await?;
74 match response.usage_metadata {
75 Some(ref usage) => tracing::info!(target: "rig",
76 "Gemini completion token usage: {}",
77 usage
78 ),
79 None => tracing::info!(target: "rig",
80 "Gemini completion token usage: n/a",
81 ),
82 }
83
84 tracing::debug!("Received response");
85
86 Ok(completion::CompletionResponse::try_from(response))
87 } else {
88 Err(CompletionError::ProviderError(response.text().await?))
89 }?
90 }
91}
92
93pub(crate) fn create_request_body(
94 mut completion_request: CompletionRequest,
95) -> Result<GenerateContentRequest, CompletionError> {
96 let mut full_history = Vec::new();
97 full_history.append(&mut completion_request.chat_history);
98 full_history.push(completion_request.prompt_with_context());
99
100 let additional_params = completion_request
101 .additional_params
102 .unwrap_or_else(|| Value::Object(Map::new()));
103
104 let mut generation_config = serde_json::from_value::<GenerationConfig>(additional_params)?;
105
106 if let Some(temp) = completion_request.temperature {
107 generation_config.temperature = Some(temp);
108 }
109
110 if let Some(max_tokens) = completion_request.max_tokens {
111 generation_config.max_output_tokens = Some(max_tokens);
112 }
113
114 let system_instruction = completion_request.preamble.clone().map(|preamble| Content {
115 parts: OneOrMany::one(preamble.into()),
116 role: Some(Role::Model),
117 });
118
119 let request = GenerateContentRequest {
120 contents: full_history
121 .into_iter()
122 .map(|msg| {
123 msg.try_into()
124 .map_err(|e| CompletionError::RequestError(Box::new(e)))
125 })
126 .collect::<Result<Vec<_>, _>>()?,
127 generation_config: Some(generation_config),
128 safety_settings: None,
129 tools: Some(
130 completion_request
131 .tools
132 .into_iter()
133 .map(Tool::try_from)
134 .collect::<Result<Vec<_>, _>>()?,
135 ),
136 tool_config: None,
137 system_instruction,
138 };
139
140 Ok(request)
141}
142
143impl TryFrom<completion::ToolDefinition> for Tool {
144 type Error = CompletionError;
145
146 fn try_from(tool: completion::ToolDefinition) -> Result<Self, Self::Error> {
147 Ok(Self {
148 function_declarations: FunctionDeclaration {
149 name: tool.name,
150 description: tool.description,
151 parameters: Some(tool.parameters.try_into()?),
152 },
153 code_execution: None,
154 })
155 }
156}
157
158impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
159 type Error = CompletionError;
160
161 fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
162 let candidate = response.candidates.first().ok_or_else(|| {
163 CompletionError::ResponseError("No response candidates in response".into())
164 })?;
165
166 let content = candidate
167 .content
168 .parts
169 .iter()
170 .map(|part| {
171 Ok(match part {
172 Part::Text(text) => completion::AssistantContent::text(text),
173 Part::FunctionCall(function_call) => completion::AssistantContent::tool_call(
174 &function_call.name,
175 &function_call.name,
176 function_call.args.clone(),
177 ),
178 _ => {
179 return Err(CompletionError::ResponseError(
180 "Response did not contain a message or tool call".into(),
181 ))
182 }
183 })
184 })
185 .collect::<Result<Vec<_>, _>>()?;
186
187 let choice = OneOrMany::many(content).map_err(|_| {
188 CompletionError::ResponseError(
189 "Response contained no message or tool call (empty)".to_owned(),
190 )
191 })?;
192
193 Ok(completion::CompletionResponse {
194 choice,
195 raw_response: response,
196 })
197 }
198}
199
200pub mod gemini_api_types {
201 use std::{collections::HashMap, convert::Infallible, str::FromStr};
202
203 use serde::{Deserialize, Serialize};
207 use serde_json::Value;
208
209 use crate::{
210 completion::CompletionError,
211 message::{self, MimeType as _},
212 one_or_many::string_or_one_or_many,
213 providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
214 OneOrMany,
215 };
216
217 #[derive(Debug, Deserialize)]
225 #[serde(rename_all = "camelCase")]
226 pub struct GenerateContentResponse {
227 pub candidates: Vec<ContentCandidate>,
229 pub prompt_feedback: Option<PromptFeedback>,
231 pub usage_metadata: Option<UsageMetadata>,
233 pub model_version: Option<String>,
234 }
235
236 #[derive(Debug, Deserialize)]
238 #[serde(rename_all = "camelCase")]
239 pub struct ContentCandidate {
240 pub content: Content,
242 pub finish_reason: Option<FinishReason>,
245 pub safety_ratings: Option<Vec<SafetyRating>>,
248 pub citation_metadata: Option<CitationMetadata>,
252 pub token_count: Option<i32>,
254 pub avg_logprobs: Option<f64>,
256 pub logprobs_result: Option<LogprobsResult>,
258 pub index: Option<i32>,
260 }
261 #[derive(Debug, Deserialize, Serialize)]
262 pub struct Content {
263 #[serde(deserialize_with = "string_or_one_or_many")]
265 pub parts: OneOrMany<Part>,
266 pub role: Option<Role>,
269 }
270
271 impl TryFrom<message::Message> for Content {
272 type Error = message::MessageError;
273
274 fn try_from(msg: message::Message) -> Result<Self, Self::Error> {
275 Ok(match msg {
276 message::Message::User { content } => Content {
277 parts: content.try_map(|c| c.try_into())?,
278 role: Some(Role::User),
279 },
280 message::Message::Assistant { content } => Content {
281 role: Some(Role::Model),
282 parts: content.map(|content| content.into()),
283 },
284 })
285 }
286 }
287
288 impl TryFrom<Content> for message::Message {
289 type Error = message::MessageError;
290
291 fn try_from(content: Content) -> Result<Self, Self::Error> {
292 match content.role {
293 Some(Role::User) | None => Ok(message::Message::User {
294 content: content.parts.try_map(|part| {
295 Ok(match part {
296 Part::Text(text) => message::UserContent::text(text),
297 Part::InlineData(inline_data) => {
298 let mime_type =
299 message::MediaType::from_mime_type(&inline_data.mime_type);
300
301 match mime_type {
302 Some(message::MediaType::Image(media_type)) => {
303 message::UserContent::image(
304 inline_data.data,
305 Some(message::ContentFormat::default()),
306 Some(media_type),
307 Some(message::ImageDetail::default()),
308 )
309 }
310 Some(message::MediaType::Document(media_type)) => {
311 message::UserContent::document(
312 inline_data.data,
313 Some(message::ContentFormat::default()),
314 Some(media_type),
315 )
316 }
317 Some(message::MediaType::Audio(media_type)) => {
318 message::UserContent::audio(
319 inline_data.data,
320 Some(message::ContentFormat::default()),
321 Some(media_type),
322 )
323 }
324 _ => {
325 return Err(message::MessageError::ConversionError(
326 format!("Unsupported media type {:?}", mime_type),
327 ))
328 }
329 }
330 }
331 _ => {
332 return Err(message::MessageError::ConversionError(format!(
333 "Unsupported gemini content part type: {:?}",
334 part
335 )))
336 }
337 })
338 })?,
339 }),
340 Some(Role::Model) => Ok(message::Message::Assistant {
341 content: content.parts.try_map(|part| {
342 Ok(match part {
343 Part::Text(text) => message::AssistantContent::text(text),
344 Part::FunctionCall(function_call) => {
345 message::AssistantContent::ToolCall(function_call.into())
346 }
347 _ => {
348 return Err(message::MessageError::ConversionError(format!(
349 "Unsupported part type: {:?}",
350 part
351 )))
352 }
353 })
354 })?,
355 }),
356 }
357 }
358 }
359
360 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
361 #[serde(rename_all = "lowercase")]
362 pub enum Role {
363 User,
364 Model,
365 }
366
367 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
371 #[serde(rename_all = "camelCase")]
372 pub enum Part {
373 Text(String),
374 InlineData(Blob),
375 FunctionCall(FunctionCall),
376 FunctionResponse(FunctionResponse),
377 FileData(FileData),
378 ExecutableCode(ExecutableCode),
379 CodeExecutionResult(CodeExecutionResult),
380 }
381
382 impl From<String> for Part {
383 fn from(text: String) -> Self {
384 Self::Text(text)
385 }
386 }
387
388 impl From<&str> for Part {
389 fn from(text: &str) -> Self {
390 Self::Text(text.to_string())
391 }
392 }
393
394 impl FromStr for Part {
395 type Err = Infallible;
396
397 fn from_str(s: &str) -> Result<Self, Self::Err> {
398 Ok(s.into())
399 }
400 }
401
402 impl TryFrom<message::UserContent> for Part {
403 type Error = message::MessageError;
404
405 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
406 match content {
407 message::UserContent::Text(message::Text { text }) => Ok(Self::Text(text)),
408 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
409 let content = match content.first() {
410 message::ToolResultContent::Text(text) => text.text,
411 message::ToolResultContent::Image(_) => {
412 return Err(message::MessageError::ConversionError(
413 "Tool result content must be text".to_string(),
414 ))
415 }
416 };
417 Ok(Part::FunctionResponse(FunctionResponse {
418 name: id,
419 response: Some(serde_json::from_str(&content).map_err(|e| {
420 message::MessageError::ConversionError(format!(
421 "Failed to parse tool response: {}",
422 e
423 ))
424 })?),
425 }))
426 }
427 message::UserContent::Image(message::Image {
428 data, media_type, ..
429 }) => match media_type {
430 Some(media_type) => match media_type {
431 message::ImageMediaType::JPEG
432 | message::ImageMediaType::PNG
433 | message::ImageMediaType::WEBP
434 | message::ImageMediaType::HEIC
435 | message::ImageMediaType::HEIF => Ok(Self::InlineData(Blob {
436 mime_type: media_type.to_mime_type().to_owned(),
437 data,
438 })),
439 _ => Err(message::MessageError::ConversionError(format!(
440 "Unsupported image media type {:?}",
441 media_type
442 ))),
443 },
444 None => Err(message::MessageError::ConversionError(
445 "Media type for image is required for Anthropic".to_string(),
446 )),
447 },
448 message::UserContent::Document(message::Document {
449 data, media_type, ..
450 }) => match media_type {
451 Some(media_type) => match media_type {
452 message::DocumentMediaType::PDF
453 | message::DocumentMediaType::TXT
454 | message::DocumentMediaType::RTF
455 | message::DocumentMediaType::HTML
456 | message::DocumentMediaType::CSS
457 | message::DocumentMediaType::MARKDOWN
458 | message::DocumentMediaType::CSV
459 | message::DocumentMediaType::XML => Ok(Self::InlineData(Blob {
460 mime_type: media_type.to_mime_type().to_owned(),
461 data,
462 })),
463 _ => Err(message::MessageError::ConversionError(format!(
464 "Unsupported document media type {:?}",
465 media_type
466 ))),
467 },
468 None => Err(message::MessageError::ConversionError(
469 "Media type for document is required for Anthropic".to_string(),
470 )),
471 },
472 message::UserContent::Audio(message::Audio {
473 data, media_type, ..
474 }) => match media_type {
475 Some(media_type) => Ok(Self::InlineData(Blob {
476 mime_type: media_type.to_mime_type().to_owned(),
477 data,
478 })),
479 None => Err(message::MessageError::ConversionError(
480 "Media type for audio is required for Anthropic".to_string(),
481 )),
482 },
483 }
484 }
485 }
486
487 impl From<message::AssistantContent> for Part {
488 fn from(content: message::AssistantContent) -> Self {
489 match content {
490 message::AssistantContent::Text(message::Text { text }) => text.into(),
491 message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
492 }
493 }
494 }
495
496 impl From<message::ToolCall> for Part {
497 fn from(tool_call: message::ToolCall) -> Self {
498 Self::FunctionCall(FunctionCall {
499 name: tool_call.function.name,
500 args: tool_call.function.arguments,
501 })
502 }
503 }
504
505 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
508 #[serde(rename_all = "camelCase")]
509 pub struct Blob {
510 pub mime_type: String,
513 pub data: String,
515 }
516
517 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
520 pub struct FunctionCall {
521 pub name: String,
524 pub args: serde_json::Value,
526 }
527
528 impl From<FunctionCall> for message::ToolCall {
529 fn from(function_call: FunctionCall) -> Self {
530 Self {
531 id: function_call.name.clone(),
532 function: message::ToolFunction {
533 name: function_call.name,
534 arguments: function_call.args,
535 },
536 }
537 }
538 }
539
540 impl From<message::ToolCall> for FunctionCall {
541 fn from(tool_call: message::ToolCall) -> Self {
542 Self {
543 name: tool_call.function.name,
544 args: tool_call.function.arguments,
545 }
546 }
547 }
548
549 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
553 pub struct FunctionResponse {
554 pub name: String,
557 pub response: Option<HashMap<String, serde_json::Value>>,
559 }
560
561 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
563 #[serde(rename_all = "camelCase")]
564 pub struct FileData {
565 pub mime_type: Option<String>,
567 pub file_uri: String,
569 }
570
571 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
572 pub struct SafetyRating {
573 pub category: HarmCategory,
574 pub probability: HarmProbability,
575 }
576
577 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
578 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
579 pub enum HarmProbability {
580 HarmProbabilityUnspecified,
581 Negligible,
582 Low,
583 Medium,
584 High,
585 }
586
587 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
588 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
589 pub enum HarmCategory {
590 HarmCategoryUnspecified,
591 HarmCategoryDerogatory,
592 HarmCategoryToxicity,
593 HarmCategoryViolence,
594 HarmCategorySexually,
595 HarmCategoryMedical,
596 HarmCategoryDangerous,
597 HarmCategoryHarassment,
598 HarmCategoryHateSpeech,
599 HarmCategorySexuallyExplicit,
600 HarmCategoryDangerousContent,
601 HarmCategoryCivicIntegrity,
602 }
603
604 #[derive(Debug, Deserialize)]
605 #[serde(rename_all = "camelCase")]
606 pub struct UsageMetadata {
607 pub prompt_token_count: i32,
608 #[serde(skip_serializing_if = "Option::is_none")]
609 pub cached_content_token_count: Option<i32>,
610 pub candidates_token_count: i32,
611 pub total_token_count: i32,
612 }
613
614 impl std::fmt::Display for UsageMetadata {
615 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616 write!(
617 f,
618 "Prompt token count: {}\nCached content token count: {}\nCandidates token count: {}\nTotal token count: {}",
619 self.prompt_token_count,
620 match self.cached_content_token_count {
621 Some(count) => count.to_string(),
622 None => "n/a".to_string(),
623 },
624 self.candidates_token_count,
625 self.total_token_count
626 )
627 }
628 }
629
630 #[derive(Debug, Deserialize)]
632 #[serde(rename_all = "camelCase")]
633 pub struct PromptFeedback {
634 pub block_reason: Option<BlockReason>,
636 pub safety_ratings: Option<Vec<SafetyRating>>,
638 }
639
640 #[derive(Debug, Deserialize)]
642 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
643 pub enum BlockReason {
644 BlockReasonUnspecified,
646 Safety,
648 Other,
650 Blocklist,
652 ProhibitedContent,
654 }
655
656 #[derive(Debug, Deserialize)]
657 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
658 pub enum FinishReason {
659 FinishReasonUnspecified,
661 Stop,
663 MaxTokens,
665 Safety,
667 Recitation,
669 Language,
671 Other,
673 Blocklist,
675 ProhibitedContent,
677 Spii,
679 MalformedFunctionCall,
681 }
682
683 #[derive(Debug, Deserialize)]
684 #[serde(rename_all = "camelCase")]
685 pub struct CitationMetadata {
686 pub citation_sources: Vec<CitationSource>,
687 }
688
689 #[derive(Debug, Deserialize)]
690 #[serde(rename_all = "camelCase")]
691 pub struct CitationSource {
692 #[serde(skip_serializing_if = "Option::is_none")]
693 pub uri: Option<String>,
694 #[serde(skip_serializing_if = "Option::is_none")]
695 pub start_index: Option<i32>,
696 #[serde(skip_serializing_if = "Option::is_none")]
697 pub end_index: Option<i32>,
698 #[serde(skip_serializing_if = "Option::is_none")]
699 pub license: Option<String>,
700 }
701
702 #[derive(Debug, Deserialize)]
703 #[serde(rename_all = "camelCase")]
704 pub struct LogprobsResult {
705 pub top_candidate: Vec<TopCandidate>,
706 pub chosen_candidate: Vec<LogProbCandidate>,
707 }
708
709 #[derive(Debug, Deserialize)]
710 pub struct TopCandidate {
711 pub candidates: Vec<LogProbCandidate>,
712 }
713
714 #[derive(Debug, Deserialize)]
715 #[serde(rename_all = "camelCase")]
716 pub struct LogProbCandidate {
717 pub token: String,
718 pub token_id: String,
719 pub log_probability: f64,
720 }
721
722 #[derive(Debug, Deserialize, Serialize)]
727 #[serde(rename_all = "camelCase")]
728 pub struct GenerationConfig {
729 #[serde(skip_serializing_if = "Option::is_none")]
732 pub stop_sequences: Option<Vec<String>>,
733 #[serde(skip_serializing_if = "Option::is_none")]
739 pub response_mime_type: Option<String>,
740 #[serde(skip_serializing_if = "Option::is_none")]
744 pub response_schema: Option<Schema>,
745 #[serde(skip_serializing_if = "Option::is_none")]
748 pub candidate_count: Option<i32>,
749 #[serde(skip_serializing_if = "Option::is_none")]
752 pub max_output_tokens: Option<u64>,
753 #[serde(skip_serializing_if = "Option::is_none")]
756 pub temperature: Option<f64>,
757 #[serde(skip_serializing_if = "Option::is_none")]
764 pub top_p: Option<f64>,
765 #[serde(skip_serializing_if = "Option::is_none")]
771 pub top_k: Option<i32>,
772 #[serde(skip_serializing_if = "Option::is_none")]
778 pub presence_penalty: Option<f64>,
779 #[serde(skip_serializing_if = "Option::is_none")]
787 pub frequency_penalty: Option<f64>,
788 #[serde(skip_serializing_if = "Option::is_none")]
790 pub response_logprobs: Option<bool>,
791 #[serde(skip_serializing_if = "Option::is_none")]
794 pub logprobs: Option<i32>,
795 }
796
797 impl Default for GenerationConfig {
798 fn default() -> Self {
799 Self {
800 temperature: Some(1.0),
801 max_output_tokens: Some(4096),
802 stop_sequences: None,
803 response_mime_type: None,
804 response_schema: None,
805 candidate_count: None,
806 top_p: None,
807 top_k: None,
808 presence_penalty: None,
809 frequency_penalty: None,
810 response_logprobs: None,
811 logprobs: None,
812 }
813 }
814 }
815 #[derive(Debug, Deserialize, Serialize)]
819 pub struct Schema {
820 pub r#type: String,
821 #[serde(skip_serializing_if = "Option::is_none")]
822 pub format: Option<String>,
823 #[serde(skip_serializing_if = "Option::is_none")]
824 pub description: Option<String>,
825 #[serde(skip_serializing_if = "Option::is_none")]
826 pub nullable: Option<bool>,
827 #[serde(skip_serializing_if = "Option::is_none")]
828 pub r#enum: Option<Vec<String>>,
829 #[serde(skip_serializing_if = "Option::is_none")]
830 pub max_items: Option<i32>,
831 #[serde(skip_serializing_if = "Option::is_none")]
832 pub min_items: Option<i32>,
833 #[serde(skip_serializing_if = "Option::is_none")]
834 pub properties: Option<HashMap<String, Schema>>,
835 #[serde(skip_serializing_if = "Option::is_none")]
836 pub required: Option<Vec<String>>,
837 #[serde(skip_serializing_if = "Option::is_none")]
838 pub items: Option<Box<Schema>>,
839 }
840
841 impl TryFrom<Value> for Schema {
842 type Error = CompletionError;
843
844 fn try_from(value: Value) -> Result<Self, Self::Error> {
845 if let Some(obj) = value.as_object() {
846 Ok(Schema {
847 r#type: obj
848 .get("type")
849 .and_then(|v| {
850 if v.is_string() {
851 v.as_str().map(String::from)
852 } else if v.is_array() {
853 v.as_array()
854 .and_then(|arr| arr.first())
855 .and_then(|v| v.as_str().map(String::from))
856 } else {
857 None
858 }
859 })
860 .unwrap_or_default(),
861 format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
862 description: obj
863 .get("description")
864 .and_then(|v| v.as_str())
865 .map(String::from),
866 nullable: obj.get("nullable").and_then(|v| v.as_bool()),
867 r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
868 arr.iter()
869 .filter_map(|v| v.as_str().map(String::from))
870 .collect()
871 }),
872 max_items: obj
873 .get("maxItems")
874 .and_then(|v| v.as_i64())
875 .map(|v| v as i32),
876 min_items: obj
877 .get("minItems")
878 .and_then(|v| v.as_i64())
879 .map(|v| v as i32),
880 properties: obj
881 .get("properties")
882 .and_then(|v| v.as_object())
883 .map(|map| {
884 map.iter()
885 .filter_map(|(k, v)| {
886 v.clone().try_into().ok().map(|schema| (k.clone(), schema))
887 })
888 .collect()
889 }),
890 required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
891 arr.iter()
892 .filter_map(|v| v.as_str().map(String::from))
893 .collect()
894 }),
895 items: obj
896 .get("items")
897 .map(|v| Box::new(v.clone().try_into().unwrap())),
898 })
899 } else {
900 Err(CompletionError::ResponseError(
901 "Expected a JSON object for Schema".into(),
902 ))
903 }
904 }
905 }
906
907 #[derive(Debug, Serialize)]
908 #[serde(rename_all = "camelCase")]
909 pub struct GenerateContentRequest {
910 pub contents: Vec<Content>,
911 pub tools: Option<Vec<Tool>>,
912 pub tool_config: Option<ToolConfig>,
913 pub generation_config: Option<GenerationConfig>,
915 pub safety_settings: Option<Vec<SafetySetting>>,
929 pub system_instruction: Option<Content>,
932 }
934
935 #[derive(Debug, Serialize)]
936 #[serde(rename_all = "camelCase")]
937 pub struct Tool {
938 pub function_declarations: FunctionDeclaration,
939 pub code_execution: Option<CodeExecution>,
940 }
941
942 #[derive(Debug, Serialize)]
943 #[serde(rename_all = "camelCase")]
944 pub struct FunctionDeclaration {
945 pub name: String,
946 pub description: String,
947 #[serde(skip_serializing_if = "Option::is_none")]
948 pub parameters: Option<Schema>,
949 }
950
951 #[derive(Debug, Serialize)]
952 #[serde(rename_all = "camelCase")]
953 pub struct ToolConfig {
954 pub schema: Option<Schema>,
955 }
956
957 #[derive(Debug, Serialize)]
958 #[serde(rename_all = "camelCase")]
959 pub struct CodeExecution {}
960
961 #[derive(Debug, Serialize)]
962 #[serde(rename_all = "camelCase")]
963 pub struct SafetySetting {
964 pub category: HarmCategory,
965 pub threshold: HarmBlockThreshold,
966 }
967
968 #[derive(Debug, Serialize)]
969 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
970 pub enum HarmBlockThreshold {
971 HarmBlockThresholdUnspecified,
972 BlockLowAndAbove,
973 BlockMediumAndAbove,
974 BlockOnlyHigh,
975 BlockNone,
976 Off,
977 }
978}
979
#[cfg(test)]
mod tests {
    use crate::message;

    use super::*;
    use serde_json::json;

    /// Returns the text of a [`Part::Text`], panicking with "Expected text part" otherwise.
    fn text_of(part: &Part) -> &str {
        match part {
            Part::Text(text) => text,
            _ => panic!("Expected text part"),
        }
    }

    /// A user turn with one part of every kind deserializes part by part, in order.
    #[test]
    fn test_deserialize_message_user() {
        let raw_message = r#"{
            "parts": [
                {"text": "Hello, world!"},
                {"inlineData": {"mimeType": "image/png", "data": "base64encodeddata"}},
                {"functionCall": {"name": "test_function", "args": {"arg1": "value1"}}},
                {"functionResponse": {"name": "test_function", "response": {"result": "success"}}},
                {"fileData": {"mimeType": "application/pdf", "fileUri": "http://example.com/file.pdf"}},
                {"executableCode": {"code": "print('Hello, world!')", "language": "PYTHON"}},
                {"codeExecutionResult": {"output": "Hello, world!", "outcome": "OUTCOME_OK"}}
            ],
            "role": "user"
        }"#;

        let mut deserializer = serde_json::Deserializer::from_str(raw_message);
        let content: Content = serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));

        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 7);

        let mut parts = content.parts.into_iter();

        assert_eq!(text_of(&parts.next().unwrap()), "Hello, world!");

        match parts.next() {
            Some(Part::InlineData(inline_data)) => {
                assert_eq!(inline_data.mime_type, "image/png");
                assert_eq!(inline_data.data, "base64encodeddata");
            }
            _ => panic!("Expected inline data part"),
        }

        match parts.next() {
            Some(Part::FunctionCall(function_call)) => {
                assert_eq!(function_call.name, "test_function");
                assert_eq!(
                    function_call.args.as_object().unwrap().get("arg1").unwrap(),
                    "value1"
                );
            }
            _ => panic!("Expected function call part"),
        }

        match parts.next() {
            Some(Part::FunctionResponse(function_response)) => {
                assert_eq!(function_response.name, "test_function");
                assert_eq!(
                    function_response
                        .response
                        .as_ref()
                        .unwrap()
                        .get("result")
                        .unwrap(),
                    "success"
                );
            }
            _ => panic!("Expected function response part"),
        }

        match parts.next() {
            Some(Part::FileData(file_data)) => {
                assert_eq!(file_data.mime_type.as_deref(), Some("application/pdf"));
                assert_eq!(file_data.file_uri, "http://example.com/file.pdf");
            }
            _ => panic!("Expected file data part"),
        }

        match parts.next() {
            Some(Part::ExecutableCode(executable_code)) => {
                assert_eq!(executable_code.code, "print('Hello, world!')");
            }
            _ => panic!("Expected executable code part"),
        }

        match parts.next() {
            Some(Part::CodeExecutionResult(code_execution_result)) => {
                assert_eq!(code_execution_result.output.as_deref(), Some("Hello, world!"));
            }
            _ => panic!("Expected code execution result part"),
        }
    }

    /// A model turn with a single text part deserializes with the `model` role.
    #[test]
    fn test_deserialize_message_model() {
        let content: Content = serde_json::from_value(json!({
            "parts": [{"text": "Hello, user!"}],
            "role": "model"
        }))
        .unwrap();

        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        assert_eq!(text_of(&content.parts.first()), "Hello, user!");
    }

    /// A rig user message becomes `user` content with one text part.
    #[test]
    fn test_message_conversion_user() {
        let content = Content::try_from(message::Message::user("Hello, world!")).unwrap();

        assert_eq!(content.role, Some(Role::User));
        assert_eq!(content.parts.len(), 1);
        assert_eq!(text_of(&content.parts.first()), "Hello, world!");
    }

    /// A rig assistant message becomes `model` content with one text part.
    #[test]
    fn test_message_conversion_model() {
        let content = Content::try_from(message::Message::assistant("Hello, user!")).unwrap();

        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);
        assert_eq!(text_of(&content.parts.first()), "Hello, user!");
    }

    /// An assistant tool call becomes a `functionCall` part carrying the function name
    /// and arguments.
    #[test]
    fn test_message_conversion_tool_call() {
        let msg = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::ToolCall(message::ToolCall {
                id: "test_tool".to_string(),
                function: message::ToolFunction {
                    name: "test_function".to_string(),
                    arguments: json!({"arg1": "value1"}),
                },
            })),
        };

        let content = Content::try_from(msg).unwrap();
        assert_eq!(content.role, Some(Role::Model));
        assert_eq!(content.parts.len(), 1);

        match content.parts.first() {
            Part::FunctionCall(function_call) => {
                assert_eq!(function_call.name, "test_function");
                assert_eq!(
                    function_call.args.as_object().unwrap().get("arg1").unwrap(),
                    "value1"
                );
            }
            _ => panic!("Expected function call part"),
        }
    }
}