Skip to main content

inference_runtime_gemini/
wire.rs

1use serde::{Deserialize, Serialize};
2
3use inference_core::batch::{ContentPart, ExecuteBatch, MessageContent, Role};
4
5#[derive(Debug, Serialize)]
6pub struct GenerateContentRequest<'a> {
7    pub contents: Vec<Content>,
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub system_instruction: Option<Content>,
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub generation_config: Option<GenerationConfig>,
12    #[serde(skip_serializing_if = "Vec::is_empty", rename = "safetySettings")]
13    pub safety_settings: Vec<crate::config::SafetySetting>,
14    #[serde(skip)]
15    _model_lifetime: std::marker::PhantomData<&'a ()>,
16}
17
18#[derive(Debug, Serialize)]
19pub struct Content {
20    pub role: String,
21    pub parts: Vec<Part>,
22}
23
24#[derive(Debug, Serialize)]
25#[serde(untagged)]
26pub enum Part {
27    Text {
28        text: String,
29    },
30    InlineData {
31        #[serde(rename = "inlineData")]
32        inline_data: InlineData,
33    },
34    FileData {
35        #[serde(rename = "fileData")]
36        file_data: FileData,
37    },
38}
39
40#[derive(Debug, Serialize)]
41pub struct InlineData {
42    #[serde(rename = "mimeType")]
43    pub mime_type: String,
44    pub data: String,
45}
46
47#[derive(Debug, Serialize)]
48pub struct FileData {
49    #[serde(rename = "mimeType")]
50    pub mime_type: String,
51    #[serde(rename = "fileUri")]
52    pub file_uri: String,
53}
54
55#[derive(Debug, Serialize, Default)]
56pub struct GenerationConfig {
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub temperature: Option<f32>,
59    #[serde(skip_serializing_if = "Option::is_none", rename = "topP")]
60    pub top_p: Option<f32>,
61    #[serde(skip_serializing_if = "Option::is_none", rename = "topK")]
62    pub top_k: Option<u32>,
63    #[serde(skip_serializing_if = "Option::is_none", rename = "maxOutputTokens")]
64    pub max_output_tokens: Option<u32>,
65    #[serde(skip_serializing_if = "Vec::is_empty", rename = "stopSequences")]
66    pub stop_sequences: Vec<String>,
67}
68
69impl GenerateContentRequest<'_> {
70    pub fn from_batch<'b>(
71        b: &'b ExecuteBatch,
72        safety: Vec<crate::config::SafetySetting>,
73    ) -> GenerateContentRequest<'b> {
74        let mut system: Option<String> = None;
75        let mut contents = Vec::with_capacity(b.messages.len());
76        for m in &b.messages {
77            if matches!(m.role, Role::System) {
78                if let MessageContent::Text(t) = &m.content {
79                    system = Some(system.map(|s| format!("{s}\n{t}")).unwrap_or_else(|| t.clone()));
80                }
81                continue;
82            }
83            let role = match m.role {
84                Role::User | Role::Tool => "user",
85                Role::Assistant => "model",
86                Role::System => unreachable!(),
87            }
88            .to_string();
89            let parts = match &m.content {
90                MessageContent::Text(t) => vec![Part::Text { text: t.clone() }],
91                MessageContent::Parts(parts) => parts.iter().map(serialize_part).collect(),
92            };
93            contents.push(Content { role, parts });
94        }
95        let system_instruction = system.map(|t| Content {
96            role: "system".into(),
97            parts: vec![Part::Text { text: t }],
98        });
99        GenerateContentRequest {
100            contents,
101            system_instruction,
102            generation_config: Some(GenerationConfig {
103                temperature: b.sampling.temperature,
104                top_p: b.sampling.top_p,
105                top_k: b.sampling.top_k,
106                max_output_tokens: b.sampling.max_tokens,
107                stop_sequences: b.sampling.stop.clone(),
108            }),
109            safety_settings: safety,
110            _model_lifetime: std::marker::PhantomData,
111        }
112    }
113}
114
115fn serialize_part(p: &ContentPart) -> Part {
116    match p {
117        ContentPart::Text { text } => Part::Text { text: text.clone() },
118        ContentPart::ImageBase64 { mime, data } => Part::InlineData {
119            inline_data: InlineData {
120                mime_type: mime.clone(),
121                data: data.clone(),
122            },
123        },
124        ContentPart::ImageUrl { url } => Part::FileData {
125            file_data: FileData {
126                mime_type: "image/jpeg".into(),
127                file_uri: url.clone(),
128            },
129        },
130    }
131}
132
133// ---- response -------------------------------------------------------------
134
135#[derive(Debug, Deserialize)]
136pub struct GenerateContentResponse {
137    #[serde(default)]
138    pub candidates: Vec<Candidate>,
139    #[serde(default, rename = "usageMetadata")]
140    pub usage_metadata: Option<UsageMetadata>,
141}
142
143#[derive(Debug, Deserialize)]
144pub struct Candidate {
145    #[serde(default)]
146    pub content: Option<ResponseContent>,
147    #[serde(default, rename = "finishReason")]
148    pub finish_reason: Option<String>,
149}
150
151#[derive(Debug, Deserialize)]
152pub struct ResponseContent {
153    #[serde(default)]
154    pub parts: Vec<ResponsePart>,
155}
156
157#[derive(Debug, Deserialize)]
158pub struct ResponsePart {
159    #[serde(default)]
160    pub text: Option<String>,
161}
162
163#[derive(Debug, Deserialize, Default, Clone, Copy)]
164pub struct UsageMetadata {
165    #[serde(default, rename = "promptTokenCount")]
166    pub prompt_token_count: u32,
167    #[serde(default, rename = "candidatesTokenCount")]
168    pub candidates_token_count: u32,
169    #[serde(default, rename = "cachedContentTokenCount")]
170    pub cached_content_token_count: u32,
171}