1use std::pin::Pin;
4use std::sync::Arc;
5
6use base64::engine::general_purpose::STANDARD;
7use base64::Engine as _;
8use futures_util::{Stream, StreamExt};
9use rust_genai_types::content::{Content, FunctionCall, Part, Role};
10use rust_genai_types::converters;
11use rust_genai_types::models::{
12 ComputeTokensConfig, ComputeTokensRequest, ComputeTokensResponse, ContentEmbedding,
13 CountTokensConfig, CountTokensRequest, CountTokensResponse, DeleteModelConfig,
14 DeleteModelResponse, EditImageConfig, EditImageResponse, EmbedContentConfig,
15 EmbedContentMetadata, EmbedContentResponse, EntityLabel, GenerateContentConfig,
16 GenerateContentRequest, GenerateImagesConfig, GenerateImagesResponse, GenerateVideosConfig,
17 GenerateVideosSource, GeneratedImage, GeneratedImageMask, Image, ListModelsConfig,
18 ListModelsResponse, Model, RecontextImageConfig, RecontextImageResponse, RecontextImageSource,
19 ReferenceImage, SafetyAttributes, SegmentImageConfig, SegmentImageResponse, SegmentImageSource,
20 UpdateModelConfig, Video, VideoGenerationMask, VideoGenerationReferenceImage,
21};
22use rust_genai_types::response::GenerateContentResponse;
23
24use crate::afc::{
25 call_callable_tools, max_remote_calls, resolve_callable_tools, should_append_history,
26 should_disable_afc, validate_afc_config, validate_afc_tools, CallableTool,
27};
28use crate::client::{Backend, ClientInner};
29use crate::error::{Error, Result};
30use crate::model_capabilities::{
31 validate_code_execution_image_inputs, validate_function_response_media,
32};
33use crate::sse::parse_sse_stream;
34use crate::thinking::{validate_temperature, ThoughtSignatureValidator};
35use crate::tokenizer::TokenEstimator;
36use serde_json::{Map, Number, Value};
37
/// Entry point for model-scoped API operations: content generation (plain,
/// streaming, and automatic function calling), embeddings, token counting,
/// image/video generation, and model management (list/get/update/delete).
///
/// Cloning is cheap: only the `Arc` to the shared client state is cloned.
#[derive(Clone)]
pub struct Models {
    // Shared client internals (HTTP client, backend selection, credentials).
    pub(crate) inner: Arc<ClientInner>,
}
42
43impl Models {
44 pub(crate) fn new(inner: Arc<ClientInner>) -> Self {
45 Self { inner }
46 }
47
48 pub async fn generate_content(
50 &self,
51 model: impl Into<String>,
52 contents: Vec<Content>,
53 ) -> Result<GenerateContentResponse> {
54 self.generate_content_with_config(model, contents, GenerateContentConfig::default())
55 .await
56 }
57
58 pub async fn generate_content_with_config(
60 &self,
61 model: impl Into<String>,
62 contents: Vec<Content>,
63 config: GenerateContentConfig,
64 ) -> Result<GenerateContentResponse> {
65 let model = model.into();
66 validate_temperature(&model, &config)?;
67 ThoughtSignatureValidator::new(&model).validate(&contents)?;
68 validate_function_response_media(&model, &contents)?;
69 validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
70
71 let request = GenerateContentRequest {
72 contents,
73 system_instruction: config.system_instruction,
74 generation_config: config.generation_config,
75 safety_settings: config.safety_settings,
76 tools: config.tools,
77 tool_config: config.tool_config,
78 cached_content: config.cached_content,
79 labels: config.labels,
80 };
81
82 let backend = self.inner.config.backend;
83 let url = build_model_method_url(&self.inner, &model, "generateContent")?;
84 let body = match backend {
85 Backend::GeminiApi => converters::generate_content_request_to_mldev(&request)?,
86 Backend::VertexAi => converters::generate_content_request_to_vertex(&request)?,
87 };
88
89 let request = self.inner.http.post(url).json(&body);
90 let response = self.inner.send(request).await?;
91 if !response.status().is_success() {
92 return Err(Error::ApiError {
93 status: response.status().as_u16(),
94 message: response.text().await.unwrap_or_default(),
95 });
96 }
97 let value = response.json::<Value>().await?;
98 let result = match backend {
99 Backend::GeminiApi => converters::generate_content_response_from_mldev(value)?,
100 Backend::VertexAi => converters::generate_content_response_from_vertex(value)?,
101 };
102 Ok(result)
103 }
104
    /// Generates content with automatic function calling (AFC): whenever the
    /// model returns `FunctionCall`s, the matching `callable_tools` are
    /// executed locally and their results are appended to the conversation,
    /// repeating until the model stops requesting calls or the remote-call
    /// budget from the config is exhausted.
    pub async fn generate_content_with_callable_tools(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: GenerateContentConfig,
        mut callable_tools: Vec<Box<dyn CallableTool>>,
    ) -> Result<GenerateContentResponse> {
        let model = model.into();
        // No tools supplied: behave exactly like a plain generate call.
        if callable_tools.is_empty() {
            return self
                .generate_content_with_config(model, contents, config)
                .await;
        }

        validate_afc_config(&config)?;

        // Resolve tool declarations and merge them with any tools already
        // present on the config.
        let mut callable_info = resolve_callable_tools(&mut callable_tools).await?;
        let has_callable = !callable_info.function_map.is_empty();
        let mut merged_tools = config.tools.clone().unwrap_or_default();
        merged_tools.append(&mut callable_info.tools);

        let mut request_config = config.clone();
        request_config.tools = Some(merged_tools);

        // AFC disabled: single round-trip with the merged tool list.
        if should_disable_afc(&config, has_callable) {
            return self
                .generate_content_with_config(model, contents, request_config)
                .await;
        }

        validate_afc_tools(&callable_info.function_map, config.tools.as_deref())?;

        let max_calls = max_remote_calls(&config);
        let append_history = should_append_history(&config);
        let mut history: Vec<Content> = Vec::new();
        let mut conversation = contents.clone();
        let mut remaining_calls = max_calls;
        let mut response = self
            .generate_content_with_config(&model, conversation.clone(), request_config.clone())
            .await?;

        loop {
            let function_calls: Vec<FunctionCall> =
                response.function_calls().into_iter().cloned().collect();

            // Model made no calls: this is the final answer.
            if function_calls.is_empty() {
                if append_history && !history.is_empty() {
                    response.automatic_function_calling_history = Some(history);
                }
                return Ok(response);
            }

            // Budget exhausted: return the last response as-is.
            if remaining_calls == 0 {
                break;
            }

            // Execute the requested functions locally.
            let response_parts = call_callable_tools(
                &mut callable_tools,
                &callable_info.function_map,
                &function_calls,
            )
            .await?;
            if response_parts.is_empty() {
                break;
            }

            let call_content = build_function_call_content(&function_calls);
            let response_content = Content::from_parts(response_parts.clone(), Role::Function);

            if append_history {
                // Seed the history with the original prompt on the first turn.
                if history.is_empty() {
                    history.extend(conversation.clone());
                }
                history.push(call_content.clone());
                history.push(response_content.clone());
            }

            conversation.push(call_content);
            conversation.push(response_content);
            remaining_calls = remaining_calls.saturating_sub(1);

            // Ask the model again with the function results appended.
            response = self
                .generate_content_with_config(&model, conversation.clone(), request_config.clone())
                .await?;
        }

        if append_history && !history.is_empty() {
            response.automatic_function_calling_history = Some(history);
        }
        Ok(response)
    }
197
    /// Streaming variant of automatic function calling: spawns a background
    /// task that forwards every streamed chunk to the returned stream, runs
    /// requested functions between turns, and emits a synthetic response
    /// carrying each function-result turn.
    pub async fn generate_content_stream_with_callable_tools(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: GenerateContentConfig,
        mut callable_tools: Vec<Box<dyn CallableTool>>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
        let model = model.into();
        // No tools supplied: plain streaming call.
        if callable_tools.is_empty() {
            return self.generate_content_stream(model, contents, config).await;
        }

        validate_afc_config(&config)?;

        // Resolve tool declarations and merge them with any configured tools.
        let callable_info = resolve_callable_tools(&mut callable_tools).await?;
        let function_map = callable_info.function_map;
        let has_callable = !function_map.is_empty();
        let mut merged_tools = config.tools.clone().unwrap_or_default();
        merged_tools.extend(callable_info.tools);

        let mut request_config = config.clone();
        request_config.tools = Some(merged_tools);

        // AFC disabled: a single streamed round-trip with merged tools.
        if should_disable_afc(&config, has_callable) {
            return self
                .generate_content_stream(model, contents, request_config)
                .await;
        }

        validate_afc_tools(&function_map, config.tools.as_deref())?;

        let max_calls = max_remote_calls(&config);
        let append_history = should_append_history(&config);
        // Bounded channel bridges the background AFC loop to the caller.
        let (tx, rx) = tokio::sync::mpsc::channel(8);
        let models = self.clone();

        tokio::spawn(async move {
            let mut conversation = contents;
            let mut history: Vec<Content> = Vec::new();
            let mut remaining_calls = max_calls;
            let mut callable_tools = callable_tools;
            let request_config = request_config;

            loop {
                // Budget exhausted: stop initiating new turns.
                if remaining_calls == 0 {
                    break;
                }

                let stream = match models
                    .generate_content_stream(&model, conversation.clone(), request_config.clone())
                    .await
                {
                    Ok(stream) => stream,
                    Err(err) => {
                        // Surface the error to the caller, then stop.
                        let _ = tx.send(Err(err)).await;
                        break;
                    }
                };

                let mut stream = stream;
                let mut function_calls: Vec<FunctionCall> = Vec::new();
                let mut response_contents: Vec<Content> = Vec::new();

                // Forward each chunk while collecting any function calls and
                // the model's contents for the next turn.
                while let Some(item) = stream.next().await {
                    if let Ok(response) = &item {
                        if let Some(content) =
                            response.candidates.first().and_then(|c| c.content.clone())
                        {
                            for part in &content.parts {
                                if let Some(call) = part.function_call_ref() {
                                    function_calls.push(call.clone());
                                }
                            }
                            response_contents.push(content);
                        }
                    }

                    // Receiver dropped: abandon the background loop entirely.
                    if tx.send(item).await.is_err() {
                        return;
                    }
                }

                // No function calls in this turn: the stream is complete.
                if function_calls.is_empty() {
                    break;
                }

                let response_parts =
                    match call_callable_tools(&mut callable_tools, &function_map, &function_calls)
                        .await
                    {
                        Ok(parts) => parts,
                        Err(err) => {
                            let _ = tx.send(Err(err)).await;
                            break;
                        }
                    };

                if response_parts.is_empty() {
                    break;
                }

                let call_content = build_function_call_content(&function_calls);
                let response_content = Content::from_parts(response_parts.clone(), Role::Function);

                if append_history {
                    // Seed the history with the original prompt on first turn.
                    if history.is_empty() {
                        history.extend(conversation.clone());
                    }
                    history.push(call_content.clone());
                    history.push(response_content.clone());
                }

                conversation.extend(response_contents);
                conversation.push(call_content);
                conversation.push(response_content.clone());
                remaining_calls = remaining_calls.saturating_sub(1);

                // Emit a synthetic response so callers observe the function
                // results as part of the stream.
                let mut synthetic = GenerateContentResponse {
                    candidates: vec![rust_genai_types::response::Candidate {
                        content: Some(response_content),
                        citation_metadata: None,
                        finish_message: None,
                        token_count: None,
                        finish_reason: None,
                        avg_logprobs: None,
                        grounding_metadata: None,
                        index: None,
                        logprobs_result: None,
                        safety_ratings: Vec::new(),
                        url_context_metadata: None,
                    }],
                    create_time: None,
                    automatic_function_calling_history: None,
                    prompt_feedback: None,
                    usage_metadata: None,
                    model_version: None,
                    response_id: None,
                };

                if append_history && !history.is_empty() {
                    synthetic.automatic_function_calling_history = Some(history.clone());
                }

                if tx.send(Ok(synthetic)).await.is_err() {
                    return;
                }
            }
        });

        // Adapt the channel receiver into a `Stream` for the caller.
        let output = futures_util::stream::unfold(rx, |mut rx| async {
            rx.recv().await.map(|item| (item, rx))
        });

        Ok(Box::pin(output))
    }
354
355 pub async fn generate_content_stream(
357 &self,
358 model: impl Into<String>,
359 contents: Vec<Content>,
360 config: GenerateContentConfig,
361 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
362 let model = model.into();
363 validate_temperature(&model, &config)?;
364 ThoughtSignatureValidator::new(&model).validate(&contents)?;
365 validate_function_response_media(&model, &contents)?;
366 validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
367
368 let request = GenerateContentRequest {
369 contents,
370 system_instruction: config.system_instruction,
371 generation_config: config.generation_config,
372 safety_settings: config.safety_settings,
373 tools: config.tools,
374 tool_config: config.tool_config,
375 cached_content: config.cached_content,
376 labels: config.labels,
377 };
378
379 let mut url = build_model_method_url(&self.inner, &model, "streamGenerateContent")?;
380 url.push_str("?alt=sse");
381
382 let request = self.inner.http.post(url).json(&request);
383 let response = self.inner.send(request).await?;
384 if !response.status().is_success() {
385 return Err(Error::ApiError {
386 status: response.status().as_u16(),
387 message: response.text().await.unwrap_or_default(),
388 });
389 }
390
391 Ok(Box::pin(parse_sse_stream(response)))
392 }
393
394 pub async fn embed_content(
396 &self,
397 model: impl Into<String>,
398 contents: Vec<Content>,
399 ) -> Result<EmbedContentResponse> {
400 self.embed_content_with_config(model, contents, EmbedContentConfig::default())
401 .await
402 }
403
404 pub async fn embed_content_with_config(
406 &self,
407 model: impl Into<String>,
408 contents: Vec<Content>,
409 config: EmbedContentConfig,
410 ) -> Result<EmbedContentResponse> {
411 let model = model.into();
412 let url = match self.inner.config.backend {
413 Backend::GeminiApi => {
414 build_model_method_url(&self.inner, &model, "batchEmbedContents")?
415 }
416 Backend::VertexAi => build_model_method_url(&self.inner, &model, "predict")?,
417 };
418
419 let body = match self.inner.config.backend {
420 Backend::GeminiApi => build_embed_body_gemini(&model, &contents, &config)?,
421 Backend::VertexAi => build_embed_body_vertex(&contents, &config)?,
422 };
423
424 let request = self.inner.http.post(url).json(&body);
425 let response = self.inner.send(request).await?;
426 if !response.status().is_success() {
427 return Err(Error::ApiError {
428 status: response.status().as_u16(),
429 message: response.text().await.unwrap_or_default(),
430 });
431 }
432
433 match self.inner.config.backend {
434 Backend::GeminiApi => Ok(response.json::<EmbedContentResponse>().await?),
435 Backend::VertexAi => {
436 let value = response.json::<Value>().await?;
437 Ok(convert_vertex_embed_response(value)?)
438 }
439 }
440 }
441
442 pub async fn count_tokens(
444 &self,
445 model: impl Into<String>,
446 contents: Vec<Content>,
447 ) -> Result<CountTokensResponse> {
448 self.count_tokens_with_config(model, contents, CountTokensConfig::default())
449 .await
450 }
451
452 pub async fn count_tokens_with_config(
454 &self,
455 model: impl Into<String>,
456 contents: Vec<Content>,
457 config: CountTokensConfig,
458 ) -> Result<CountTokensResponse> {
459 let request = CountTokensRequest {
460 contents,
461 system_instruction: config.system_instruction,
462 tools: config.tools,
463 generation_config: config.generation_config,
464 };
465
466 let backend = self.inner.config.backend;
467 let url = build_model_method_url(&self.inner, &model.into(), "countTokens")?;
468 let body = match backend {
469 Backend::GeminiApi => converters::count_tokens_request_to_mldev(&request)?,
470 Backend::VertexAi => converters::count_tokens_request_to_vertex(&request)?,
471 };
472 let request = self.inner.http.post(url).json(&body);
473 let response = self.inner.send(request).await?;
474 if !response.status().is_success() {
475 return Err(Error::ApiError {
476 status: response.status().as_u16(),
477 message: response.text().await.unwrap_or_default(),
478 });
479 }
480 let value = response.json::<Value>().await?;
481 let result = match backend {
482 Backend::GeminiApi => converters::count_tokens_response_from_mldev(value)?,
483 Backend::VertexAi => converters::count_tokens_response_from_vertex(value)?,
484 };
485 Ok(result)
486 }
487
488 pub async fn compute_tokens(
490 &self,
491 model: impl Into<String>,
492 contents: Vec<Content>,
493 ) -> Result<ComputeTokensResponse> {
494 self.compute_tokens_with_config(model, contents, ComputeTokensConfig::default())
495 .await
496 }
497
498 pub async fn compute_tokens_with_config(
500 &self,
501 model: impl Into<String>,
502 contents: Vec<Content>,
503 config: ComputeTokensConfig,
504 ) -> Result<ComputeTokensResponse> {
505 if self.inner.config.backend != Backend::VertexAi {
506 return Err(Error::InvalidConfig {
507 message: "Compute tokens is only supported in Vertex AI backend".into(),
508 });
509 }
510
511 let request = ComputeTokensRequest { contents };
512 let url = build_model_method_url(&self.inner, &model.into(), "computeTokens")?;
513 let mut body = converters::compute_tokens_request_to_vertex(&request)?;
514 if let Some(options) = config.http_options.as_ref() {
515 merge_extra_body(&mut body, options)?;
516 }
517
518 let mut request = self.inner.http.post(url).json(&body);
519 request = apply_http_options(request, config.http_options.as_ref())?;
520
521 let response = self.inner.send(request).await?;
522 if !response.status().is_success() {
523 return Err(Error::ApiError {
524 status: response.status().as_u16(),
525 message: response.text().await.unwrap_or_default(),
526 });
527 }
528 let value = response.json::<Value>().await?;
529 let result = converters::compute_tokens_response_from_vertex(value)?;
530 Ok(result)
531 }
532
533 pub fn estimate_tokens_local(
535 &self,
536 contents: &[Content],
537 estimator: &dyn TokenEstimator,
538 ) -> CountTokensResponse {
539 let total = estimator.estimate_tokens(contents) as i32;
540 CountTokensResponse {
541 total_tokens: Some(total),
542 cached_content_token_count: None,
543 }
544 }
545
546 pub fn estimate_tokens_local_with_config(
548 &self,
549 contents: &[Content],
550 config: &CountTokensConfig,
551 estimator: &dyn TokenEstimator,
552 ) -> CountTokensResponse {
553 let estimation_contents = crate::tokenizer::build_estimation_contents(contents, config);
554 let total = estimator.estimate_tokens(&estimation_contents) as i32;
555 CountTokensResponse {
556 total_tokens: Some(total),
557 cached_content_token_count: None,
558 }
559 }
560
561 pub async fn count_tokens_or_estimate(
563 &self,
564 model: impl Into<String>,
565 contents: Vec<Content>,
566 config: CountTokensConfig,
567 estimator: Option<&dyn TokenEstimator>,
568 ) -> Result<CountTokensResponse> {
569 if let Some(estimator) = estimator {
570 return Ok(self.estimate_tokens_local_with_config(&contents, &config, estimator));
571 }
572 self.count_tokens_with_config(model, contents, config).await
573 }
574
575 pub async fn generate_images(
577 &self,
578 model: impl Into<String>,
579 prompt: impl Into<String>,
580 mut config: GenerateImagesConfig,
581 ) -> Result<GenerateImagesResponse> {
582 let http_options = config.http_options.take();
583 let model = model.into();
584 let prompt = prompt.into();
585 let mut body = build_generate_images_body(self.inner.config.backend, &prompt, &config)?;
586 if let Some(options) = http_options.as_ref() {
587 merge_extra_body(&mut body, options)?;
588 }
589 let url = build_model_method_url(&self.inner, &model, "predict")?;
590
591 let mut request = self.inner.http.post(url).json(&body);
592 request = apply_http_options(request, http_options.as_ref())?;
593
594 let response = self.inner.send(request).await?;
595 if !response.status().is_success() {
596 return Err(Error::ApiError {
597 status: response.status().as_u16(),
598 message: response.text().await.unwrap_or_default(),
599 });
600 }
601
602 let value = response.json::<Value>().await?;
603 parse_generate_images_response(value, self.inner.config.backend)
604 }
605
606 pub async fn edit_image(
608 &self,
609 model: impl Into<String>,
610 prompt: impl Into<String>,
611 reference_images: Vec<ReferenceImage>,
612 mut config: EditImageConfig,
613 ) -> Result<EditImageResponse> {
614 if self.inner.config.backend != Backend::VertexAi {
615 return Err(Error::InvalidConfig {
616 message: "Edit image is only supported in Vertex AI backend".into(),
617 });
618 }
619
620 let http_options = config.http_options.take();
621 let model = model.into();
622 let prompt = prompt.into();
623 let mut body = build_edit_image_body(&prompt, &reference_images, &config)?;
624 if let Some(options) = http_options.as_ref() {
625 merge_extra_body(&mut body, options)?;
626 }
627 let url = build_model_method_url(&self.inner, &model, "predict")?;
628
629 let mut request = self.inner.http.post(url).json(&body);
630 request = apply_http_options(request, http_options.as_ref())?;
631
632 let response = self.inner.send(request).await?;
633 if !response.status().is_success() {
634 return Err(Error::ApiError {
635 status: response.status().as_u16(),
636 message: response.text().await.unwrap_or_default(),
637 });
638 }
639
640 let value = response.json::<Value>().await?;
641 parse_edit_image_response(value)
642 }
643
644 pub async fn upscale_image(
646 &self,
647 model: impl Into<String>,
648 image: Image,
649 upscale_factor: impl Into<String>,
650 mut config: rust_genai_types::models::UpscaleImageConfig,
651 ) -> Result<rust_genai_types::models::UpscaleImageResponse> {
652 if self.inner.config.backend != Backend::VertexAi {
653 return Err(Error::InvalidConfig {
654 message: "Upscale image is only supported in Vertex AI backend".into(),
655 });
656 }
657
658 let http_options = config.http_options.take();
659 let model = model.into();
660 let upscale_factor = upscale_factor.into();
661 let mut body = build_upscale_image_body(&image, &upscale_factor, &config)?;
662 if let Some(options) = http_options.as_ref() {
663 merge_extra_body(&mut body, options)?;
664 }
665 let url = build_model_method_url(&self.inner, &model, "predict")?;
666
667 let mut request = self.inner.http.post(url).json(&body);
668 request = apply_http_options(request, http_options.as_ref())?;
669
670 let response = self.inner.send(request).await?;
671 if !response.status().is_success() {
672 return Err(Error::ApiError {
673 status: response.status().as_u16(),
674 message: response.text().await.unwrap_or_default(),
675 });
676 }
677
678 let value = response.json::<Value>().await?;
679 parse_upscale_image_response(value)
680 }
681
682 pub async fn recontext_image(
684 &self,
685 model: impl Into<String>,
686 source: RecontextImageSource,
687 mut config: RecontextImageConfig,
688 ) -> Result<RecontextImageResponse> {
689 if self.inner.config.backend != Backend::VertexAi {
690 return Err(Error::InvalidConfig {
691 message: "Recontext image is only supported in Vertex AI backend".into(),
692 });
693 }
694
695 let http_options = config.http_options.take();
696 let model = model.into();
697 let mut body = build_recontext_image_body(&source, &config)?;
698 if let Some(options) = http_options.as_ref() {
699 merge_extra_body(&mut body, options)?;
700 }
701 let url = build_model_method_url(&self.inner, &model, "predict")?;
702
703 let mut request = self.inner.http.post(url).json(&body);
704 request = apply_http_options(request, http_options.as_ref())?;
705
706 let response = self.inner.send(request).await?;
707 if !response.status().is_success() {
708 return Err(Error::ApiError {
709 status: response.status().as_u16(),
710 message: response.text().await.unwrap_or_default(),
711 });
712 }
713
714 let value = response.json::<Value>().await?;
715 parse_recontext_image_response(value)
716 }
717
718 pub async fn segment_image(
720 &self,
721 model: impl Into<String>,
722 source: SegmentImageSource,
723 mut config: SegmentImageConfig,
724 ) -> Result<SegmentImageResponse> {
725 if self.inner.config.backend != Backend::VertexAi {
726 return Err(Error::InvalidConfig {
727 message: "Segment image is only supported in Vertex AI backend".into(),
728 });
729 }
730
731 let http_options = config.http_options.take();
732 let model = model.into();
733 let mut body = build_segment_image_body(&source, &config)?;
734 if let Some(options) = http_options.as_ref() {
735 merge_extra_body(&mut body, options)?;
736 }
737 let url = build_model_method_url(&self.inner, &model, "predict")?;
738
739 let mut request = self.inner.http.post(url).json(&body);
740 request = apply_http_options(request, http_options.as_ref())?;
741
742 let response = self.inner.send(request).await?;
743 if !response.status().is_success() {
744 return Err(Error::ApiError {
745 status: response.status().as_u16(),
746 message: response.text().await.unwrap_or_default(),
747 });
748 }
749
750 let value = response.json::<Value>().await?;
751 parse_segment_image_response(value)
752 }
753
754 pub async fn generate_videos(
756 &self,
757 model: impl Into<String>,
758 source: GenerateVideosSource,
759 mut config: GenerateVideosConfig,
760 ) -> Result<rust_genai_types::operations::Operation> {
761 let http_options = config.http_options.take();
762 let model = model.into();
763 let mut body = build_generate_videos_body(self.inner.config.backend, &source, &config)?;
764 if let Some(options) = http_options.as_ref() {
765 merge_extra_body(&mut body, options)?;
766 }
767 let url = build_model_method_url(&self.inner, &model, "predictLongRunning")?;
768
769 let mut request = self.inner.http.post(url).json(&body);
770 request = apply_http_options(request, http_options.as_ref())?;
771
772 let response = self.inner.send(request).await?;
773 if !response.status().is_success() {
774 return Err(Error::ApiError {
775 status: response.status().as_u16(),
776 message: response.text().await.unwrap_or_default(),
777 });
778 }
779
780 let value = response.json::<Value>().await?;
781 parse_generate_videos_operation(value, self.inner.config.backend)
782 }
783
784 pub async fn generate_videos_with_prompt(
786 &self,
787 model: impl Into<String>,
788 prompt: impl Into<String>,
789 config: GenerateVideosConfig,
790 ) -> Result<rust_genai_types::operations::Operation> {
791 let source = GenerateVideosSource {
792 prompt: Some(prompt.into()),
793 ..GenerateVideosSource::default()
794 };
795 self.generate_videos(model, source, config).await
796 }
797
798 pub async fn list(&self) -> Result<ListModelsResponse> {
800 self.list_with_config(ListModelsConfig::default()).await
801 }
802
803 pub async fn list_with_config(&self, config: ListModelsConfig) -> Result<ListModelsResponse> {
805 let url = build_models_list_url(&self.inner, &config)?;
806 let request = self.inner.http.get(url);
807 let response = self.inner.send(request).await?;
808 if !response.status().is_success() {
809 return Err(Error::ApiError {
810 status: response.status().as_u16(),
811 message: response.text().await.unwrap_or_default(),
812 });
813 }
814 let result = response.json::<ListModelsResponse>().await?;
815 Ok(result)
816 }
817
818 pub async fn all(&self) -> Result<Vec<Model>> {
820 self.all_with_config(ListModelsConfig::default()).await
821 }
822
823 pub async fn all_with_config(&self, mut config: ListModelsConfig) -> Result<Vec<Model>> {
825 let mut models = Vec::new();
826 loop {
827 let response = self.list_with_config(config.clone()).await?;
828 if let Some(items) = response.models {
829 models.extend(items);
830 }
831 match response.next_page_token {
832 Some(token) if !token.is_empty() => {
833 config.page_token = Some(token);
834 }
835 _ => break,
836 }
837 }
838 Ok(models)
839 }
840
841 pub async fn get(&self, model: impl Into<String>) -> Result<Model> {
843 let url = build_model_get_url(&self.inner, &model.into())?;
844 let request = self.inner.http.get(url);
845 let response = self.inner.send(request).await?;
846 if !response.status().is_success() {
847 return Err(Error::ApiError {
848 status: response.status().as_u16(),
849 message: response.text().await.unwrap_or_default(),
850 });
851 }
852 let result = response.json::<Model>().await?;
853 Ok(result)
854 }
855
856 pub async fn update(
858 &self,
859 model: impl Into<String>,
860 mut config: UpdateModelConfig,
861 ) -> Result<Model> {
862 let http_options = config.http_options.take();
863 let url =
864 build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;
865
866 let mut body = serde_json::to_value(&config)?;
867 if let Some(options) = http_options.as_ref() {
868 merge_extra_body(&mut body, options)?;
869 }
870 let mut request = self.inner.http.patch(url).json(&body);
871 request = apply_http_options(request, http_options.as_ref())?;
872
873 let response = self.inner.send(request).await?;
874 if !response.status().is_success() {
875 return Err(Error::ApiError {
876 status: response.status().as_u16(),
877 message: response.text().await.unwrap_or_default(),
878 });
879 }
880 Ok(response.json::<Model>().await?)
881 }
882
883 pub async fn delete(
885 &self,
886 model: impl Into<String>,
887 mut config: DeleteModelConfig,
888 ) -> Result<DeleteModelResponse> {
889 let http_options = config.http_options.take();
890 let url =
891 build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;
892
893 let mut request = self.inner.http.delete(url);
894 request = apply_http_options(request, http_options.as_ref())?;
895
896 let response = self.inner.send(request).await?;
897 if !response.status().is_success() {
898 return Err(Error::ApiError {
899 status: response.status().as_u16(),
900 message: response.text().await.unwrap_or_default(),
901 });
902 }
903 if response.content_length().unwrap_or(0) == 0 {
904 return Ok(DeleteModelResponse::default());
905 }
906 Ok(response
907 .json::<DeleteModelResponse>()
908 .await
909 .unwrap_or_default())
910 }
911}
912
913fn transform_model_name(backend: Backend, model: &str) -> String {
914 match backend {
915 Backend::GeminiApi => {
916 if model.starts_with("models/") {
917 model.to_string()
918 } else {
919 format!("models/{model}")
920 }
921 }
922 Backend::VertexAi => {
923 if model.starts_with("projects/") || model.starts_with("publishers/") {
924 model.to_string()
925 } else {
926 format!("publishers/google/models/{model}")
927 }
928 }
929 }
930}
931
932fn build_model_method_url(inner: &ClientInner, model: &str, method: &str) -> Result<String> {
933 let model = transform_model_name(inner.config.backend, model);
934 let base = &inner.api_client.base_url;
935 let version = &inner.api_client.api_version;
936 let url = match inner.config.backend {
937 Backend::GeminiApi => format!("{base}{version}/{model}:{method}"),
938 Backend::VertexAi => {
939 let vertex =
940 inner
941 .config
942 .vertex_config
943 .as_ref()
944 .ok_or_else(|| Error::InvalidConfig {
945 message: "Vertex config missing".into(),
946 })?;
947 format!(
948 "{base}{version}/projects/{}/locations/{}/{}:{method}",
949 vertex.project, vertex.location, model
950 )
951 }
952 };
953 Ok(url)
954}
955
956fn build_model_get_url(inner: &ClientInner, model: &str) -> Result<String> {
957 let model = transform_model_name(inner.config.backend, model);
958 let base = &inner.api_client.base_url;
959 let version = &inner.api_client.api_version;
960 let url = match inner.config.backend {
961 Backend::GeminiApi => format!("{base}{version}/{model}"),
962 Backend::VertexAi => {
963 let vertex =
964 inner
965 .config
966 .vertex_config
967 .as_ref()
968 .ok_or_else(|| Error::InvalidConfig {
969 message: "Vertex config missing".into(),
970 })?;
971 format!(
972 "{base}{version}/projects/{}/locations/{}/{}",
973 vertex.project, vertex.location, model
974 )
975 }
976 };
977 Ok(url)
978}
979
980fn build_model_get_url_with_options(
981 inner: &ClientInner,
982 model: &str,
983 http_options: Option<&rust_genai_types::http::HttpOptions>,
984) -> Result<String> {
985 let model = transform_model_name(inner.config.backend, model);
986 let base = http_options
987 .and_then(|opts| opts.base_url.as_deref())
988 .unwrap_or(&inner.api_client.base_url);
989 let version = http_options
990 .and_then(|opts| opts.api_version.as_deref())
991 .unwrap_or(&inner.api_client.api_version);
992 let url = match inner.config.backend {
993 Backend::GeminiApi => format!("{base}{version}/{model}"),
994 Backend::VertexAi => {
995 let vertex =
996 inner
997 .config
998 .vertex_config
999 .as_ref()
1000 .ok_or_else(|| Error::InvalidConfig {
1001 message: "Vertex config missing".into(),
1002 })?;
1003 format!(
1004 "{base}{version}/projects/{}/locations/{}/{}",
1005 vertex.project, vertex.location, model
1006 )
1007 }
1008 };
1009 Ok(url)
1010}
1011
1012fn build_models_list_url(inner: &ClientInner, config: &ListModelsConfig) -> Result<String> {
1013 let base = &inner.api_client.base_url;
1014 let version = &inner.api_client.api_version;
1015 let url = match inner.config.backend {
1016 Backend::GeminiApi => format!("{base}{version}/models"),
1017 Backend::VertexAi => {
1018 let vertex =
1019 inner
1020 .config
1021 .vertex_config
1022 .as_ref()
1023 .ok_or_else(|| Error::InvalidConfig {
1024 message: "Vertex config missing".into(),
1025 })?;
1026 format!(
1027 "{base}{version}/projects/{}/locations/{}/publishers/google/models",
1028 vertex.project, vertex.location
1029 )
1030 }
1031 };
1032 add_list_query_params(url, config)
1033}
1034
1035fn add_list_query_params(url: String, config: &ListModelsConfig) -> Result<String> {
1036 let mut url = reqwest::Url::parse(&url).map_err(|err| Error::InvalidConfig {
1037 message: err.to_string(),
1038 })?;
1039 {
1040 let mut pairs = url.query_pairs_mut();
1041 if let Some(page_size) = config.page_size {
1042 pairs.append_pair("pageSize", &page_size.to_string());
1043 }
1044 if let Some(page_token) = &config.page_token {
1045 pairs.append_pair("pageToken", page_token);
1046 }
1047 if let Some(filter) = &config.filter {
1048 pairs.append_pair("filter", filter);
1049 }
1050 if let Some(query_base) = config.query_base {
1051 pairs.append_pair("queryBase", if query_base { "true" } else { "false" });
1052 }
1053 }
1054 Ok(url.to_string())
1055}
1056
1057fn build_embed_body_gemini(
1058 model: &str,
1059 contents: &[Content],
1060 config: &EmbedContentConfig,
1061) -> Result<Value> {
1062 if config.mime_type.is_some() || config.auto_truncate.is_some() {
1063 return Err(Error::InvalidConfig {
1064 message: "mime_type/auto_truncate not supported in Gemini API".into(),
1065 });
1066 }
1067
1068 let mut requests: Vec<Value> = Vec::new();
1069 for content in contents {
1070 let mut obj = Map::new();
1071 obj.insert(
1072 "model".to_string(),
1073 Value::String(transform_model_name(Backend::GeminiApi, model)),
1074 );
1075 obj.insert("content".to_string(), serde_json::to_value(content)?);
1076 if let Some(task_type) = &config.task_type {
1077 obj.insert("taskType".to_string(), Value::String(task_type.clone()));
1078 }
1079 if let Some(title) = &config.title {
1080 obj.insert("title".to_string(), Value::String(title.clone()));
1081 }
1082 if let Some(output_dimensionality) = config.output_dimensionality {
1083 obj.insert(
1084 "outputDimensionality".to_string(),
1085 Value::Number(Number::from(output_dimensionality as i64)),
1086 );
1087 }
1088 requests.push(Value::Object(obj));
1089 }
1090
1091 Ok(Value::Object({
1092 let mut root = Map::new();
1093 root.insert("requests".to_string(), Value::Array(requests));
1094 root
1095 }))
1096}
1097
1098fn build_embed_body_vertex(contents: &[Content], config: &EmbedContentConfig) -> Result<Value> {
1099 let mut instances: Vec<Value> = Vec::new();
1100 for content in contents {
1101 let mut obj = Map::new();
1102 obj.insert("content".to_string(), serde_json::to_value(content)?);
1103 if let Some(task_type) = &config.task_type {
1104 obj.insert("task_type".to_string(), Value::String(task_type.clone()));
1105 }
1106 if let Some(title) = &config.title {
1107 obj.insert("title".to_string(), Value::String(title.clone()));
1108 }
1109 if let Some(mime_type) = &config.mime_type {
1110 obj.insert("mimeType".to_string(), Value::String(mime_type.clone()));
1111 }
1112 instances.push(Value::Object(obj));
1113 }
1114
1115 let mut root = Map::new();
1116 root.insert("instances".to_string(), Value::Array(instances));
1117
1118 let mut parameters = Map::new();
1119 if let Some(output_dimensionality) = config.output_dimensionality {
1120 parameters.insert(
1121 "outputDimensionality".to_string(),
1122 Value::Number(Number::from(output_dimensionality as i64)),
1123 );
1124 }
1125 if let Some(auto_truncate) = config.auto_truncate {
1126 parameters.insert("autoTruncate".to_string(), Value::Bool(auto_truncate));
1127 }
1128 if !parameters.is_empty() {
1129 root.insert("parameters".to_string(), Value::Object(parameters));
1130 }
1131
1132 Ok(Value::Object(root))
1133}
1134
1135fn convert_vertex_embed_response(value: Value) -> Result<EmbedContentResponse> {
1136 let predictions = value
1137 .get("predictions")
1138 .and_then(|pred| pred.as_array())
1139 .cloned()
1140 .unwrap_or_default();
1141
1142 let mut embeddings: Vec<ContentEmbedding> = Vec::new();
1143 for item in predictions {
1144 if let Some(embedding_value) = item.get("embeddings") {
1145 let embedding: ContentEmbedding = serde_json::from_value(embedding_value.clone())?;
1146 embeddings.push(embedding);
1147 }
1148 }
1149
1150 let metadata: Option<EmbedContentMetadata> = value
1151 .get("metadata")
1152 .map(|meta| serde_json::from_value(meta.clone()))
1153 .transpose()?;
1154
1155 Ok(EmbedContentResponse {
1156 embeddings: Some(embeddings),
1157 metadata,
1158 })
1159}
1160
/// Builds the Imagen `predict` request body for image generation.
///
/// The prompt becomes the single instance; every populated `config` field
/// is mapped onto its wire-format name under `parameters` (with output
/// options nested under `parameters.outputOptions`, and `labels` riding at
/// the request root).
///
/// # Errors
/// Returns `Error::InvalidConfig` when a Vertex-only field
/// (`output_gcs_uri`, `negative_prompt`, `seed`, `add_watermark`,
/// `labels`, `enhance_prompt`) is set while targeting
/// `Backend::GeminiApi`, or if a value fails JSON serialization.
fn build_generate_images_body(
    backend: Backend,
    prompt: &str,
    config: &GenerateImagesConfig,
) -> Result<Value> {
    let mut instances = Vec::new();
    let mut instance = Map::new();
    instance.insert("prompt".to_string(), Value::String(prompt.to_string()));
    instances.push(Value::Object(instance));

    let mut root = Map::new();
    root.insert("instances".to_string(), Value::Array(instances));

    let mut parameters = Map::new();
    // Collected separately; only attached under `parameters.outputOptions`
    // if at least one output option is set (see below).
    let mut output_options = Map::new();

    if let Some(value) = &config.output_gcs_uri {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "output_gcs_uri is not supported in Gemini API".into(),
            });
        }
        parameters.insert("storageUri".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = &config.negative_prompt {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "negative_prompt is not supported in Gemini API".into(),
            });
        }
        parameters.insert("negativePrompt".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.number_of_images {
        parameters.insert(
            "sampleCount".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if let Some(value) = &config.aspect_ratio {
        parameters.insert("aspectRatio".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.guidance_scale {
        // Non-finite floats cannot be represented in JSON; they silently
        // fall back to 0 here.
        parameters.insert(
            "guidanceScale".to_string(),
            Value::Number(Number::from_f64(value as f64).unwrap_or_else(|| Number::from(0))),
        );
    }
    if let Some(value) = config.seed {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "seed is not supported in Gemini API".into(),
            });
        }
        parameters.insert("seed".to_string(), Value::Number(Number::from(value)));
    }
    if let Some(value) = config.safety_filter_level {
        parameters.insert("safetySetting".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.person_generation {
        parameters.insert("personGeneration".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.include_safety_attributes {
        parameters.insert("includeSafetyAttributes".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.include_rai_reason {
        parameters.insert("includeRaiReason".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.language {
        parameters.insert("language".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = &config.output_mime_type {
        output_options.insert("mimeType".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.output_compression_quality {
        output_options.insert(
            "compressionQuality".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if !output_options.is_empty() {
        parameters.insert("outputOptions".to_string(), Value::Object(output_options));
    }
    if let Some(value) = config.add_watermark {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "add_watermark is not supported in Gemini API".into(),
            });
        }
        parameters.insert("addWatermark".to_string(), Value::Bool(value));
    }
    if let Some(labels) = &config.labels {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "labels is not supported in Gemini API".into(),
            });
        }
        // Labels live at the request root, not under `parameters`.
        root.insert("labels".to_string(), serde_json::to_value(labels)?);
    }
    if let Some(value) = &config.image_size {
        parameters.insert("sampleImageSize".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.enhance_prompt {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "enhance_prompt is not supported in Gemini API".into(),
            });
        }
        parameters.insert("enhancePrompt".to_string(), Value::Bool(value));
    }

    if !parameters.is_empty() {
        root.insert("parameters".to_string(), Value::Object(parameters));
    }

    Ok(Value::Object(root))
}
1277
/// Builds the Vertex `predict` request body for image editing.
///
/// The prompt plus any reference images form the single instance; set
/// `config` fields map onto `parameters`, with output options nested under
/// `parameters.outputOptions`, edit steps under `parameters.editConfig`,
/// and `labels` at the request root. This is a Vertex-only endpoint, so no
/// backend gating is applied here.
///
/// # Errors
/// Fails if a reference image or config value fails JSON serialization.
fn build_edit_image_body(
    prompt: &str,
    reference_images: &[ReferenceImage],
    config: &EditImageConfig,
) -> Result<Value> {
    let mut instances = Vec::new();
    let mut instance = Map::new();
    instance.insert("prompt".to_string(), Value::String(prompt.to_string()));
    if !reference_images.is_empty() {
        let mut refs = Vec::new();
        for image in reference_images {
            refs.push(reference_image_to_vertex(image)?);
        }
        instance.insert("referenceImages".to_string(), Value::Array(refs));
    }
    instances.push(Value::Object(instance));

    let mut root = Map::new();
    root.insert("instances".to_string(), Value::Array(instances));

    let mut parameters = Map::new();
    // Sub-objects are only attached when non-empty (see below).
    let mut output_options = Map::new();
    let mut edit_config = Map::new();

    if let Some(value) = &config.output_gcs_uri {
        parameters.insert("storageUri".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = &config.negative_prompt {
        parameters.insert("negativePrompt".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.number_of_images {
        parameters.insert(
            "sampleCount".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if let Some(value) = &config.aspect_ratio {
        parameters.insert("aspectRatio".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.guidance_scale {
        // Non-finite floats cannot be represented in JSON; fall back to 0.
        parameters.insert(
            "guidanceScale".to_string(),
            Value::Number(Number::from_f64(value as f64).unwrap_or_else(|| Number::from(0))),
        );
    }
    if let Some(value) = config.seed {
        parameters.insert("seed".to_string(), Value::Number(Number::from(value)));
    }
    if let Some(value) = config.safety_filter_level {
        parameters.insert("safetySetting".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.person_generation {
        parameters.insert("personGeneration".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.include_safety_attributes {
        parameters.insert("includeSafetyAttributes".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.include_rai_reason {
        parameters.insert("includeRaiReason".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.language {
        parameters.insert("language".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = &config.output_mime_type {
        output_options.insert("mimeType".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.output_compression_quality {
        output_options.insert(
            "compressionQuality".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if !output_options.is_empty() {
        parameters.insert("outputOptions".to_string(), Value::Object(output_options));
    }
    if let Some(value) = config.add_watermark {
        parameters.insert("addWatermark".to_string(), Value::Bool(value));
    }
    if let Some(labels) = &config.labels {
        // Labels live at the request root, not under `parameters`.
        root.insert("labels".to_string(), serde_json::to_value(labels)?);
    }
    if let Some(value) = config.edit_mode {
        parameters.insert("editMode".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.base_steps {
        edit_config.insert("baseSteps".to_string(), Value::Number(Number::from(value)));
    }
    if !edit_config.is_empty() {
        parameters.insert("editConfig".to_string(), Value::Object(edit_config));
    }

    if !parameters.is_empty() {
        root.insert("parameters".to_string(), Value::Object(parameters));
    }

    Ok(Value::Object(root))
}
1375
/// Builds the Vertex `predict` request body for image upscaling.
///
/// Unlike the other builders, `parameters` is always attached: `mode`
/// (defaulting to `"upscale"`), `sampleCount` (defaulting to 1) and
/// `upscaleConfig.upscaleFactor` are set unconditionally.
///
/// # Errors
/// Fails if a config value fails JSON serialization.
fn build_upscale_image_body(
    image: &Image,
    upscale_factor: &str,
    config: &rust_genai_types::models::UpscaleImageConfig,
) -> Result<Value> {
    let mut instances = Vec::new();
    let mut instance = Map::new();
    instance.insert("image".to_string(), image_to_vertex(image)?);
    instances.push(Value::Object(instance));

    let mut root = Map::new();
    root.insert("instances".to_string(), Value::Array(instances));

    let mut parameters = Map::new();
    let mut output_options = Map::new();
    let mut upscale_config = Map::new();

    // `mode` is always sent; "upscale" is the documented default here.
    parameters.insert(
        "mode".to_string(),
        Value::String(config.mode.clone().unwrap_or_else(|| "upscale".to_string())),
    );

    // `sampleCount` is always sent, defaulting to a single image.
    if let Some(value) = config.number_of_images {
        parameters.insert(
            "sampleCount".to_string(),
            Value::Number(Number::from(value)),
        );
    } else {
        parameters.insert("sampleCount".to_string(), Value::Number(Number::from(1)));
    }

    if let Some(value) = &config.output_gcs_uri {
        parameters.insert("storageUri".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.safety_filter_level {
        parameters.insert("safetySetting".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.person_generation {
        parameters.insert("personGeneration".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.include_rai_reason {
        parameters.insert("includeRaiReason".to_string(), Value::Bool(value));
    }
    if let Some(value) = &config.output_mime_type {
        output_options.insert("mimeType".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.output_compression_quality {
        output_options.insert(
            "compressionQuality".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if !output_options.is_empty() {
        parameters.insert("outputOptions".to_string(), Value::Object(output_options));
    }
    if let Some(value) = config.enhance_input_image {
        upscale_config.insert("enhanceInputImage".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.image_preservation_factor {
        // Non-finite floats cannot be represented in JSON; fall back to 0.
        upscale_config.insert(
            "imagePreservationFactor".to_string(),
            Value::Number(Number::from_f64(value as f64).unwrap_or_else(|| Number::from(0))),
        );
    }
    // `upscaleConfig` always carries the mandatory upscale factor.
    upscale_config.insert(
        "upscaleFactor".to_string(),
        Value::String(upscale_factor.to_string()),
    );
    parameters.insert("upscaleConfig".to_string(), Value::Object(upscale_config));

    if let Some(labels) = &config.labels {
        // Labels live at the request root, not under `parameters`.
        root.insert("labels".to_string(), serde_json::to_value(labels)?);
    }

    root.insert("parameters".to_string(), Value::Object(parameters));

    Ok(Value::Object(root))
}
1454
/// Builds the Vertex `predict` request body for image recontextualization
/// (person/product placement).
///
/// The single instance carries the optional prompt, a `personImage`
/// wrapper and a `productImages` array (only product entries that actually
/// contain an image are included). Set config fields map onto
/// `parameters`, with output options nested and `labels` at the root.
///
/// # Errors
/// Fails if an image or config value fails JSON serialization.
fn build_recontext_image_body(
    source: &RecontextImageSource,
    config: &RecontextImageConfig,
) -> Result<Value> {
    let mut instance = Map::new();
    if let Some(prompt) = &source.prompt {
        instance.insert("prompt".to_string(), Value::String(prompt.clone()));
    }
    if let Some(person_image) = &source.person_image {
        // The wire format wraps the person image in an object with an
        // `image` field.
        let mut person = Map::new();
        person.insert("image".to_string(), image_to_vertex(person_image)?);
        instance.insert("personImage".to_string(), Value::Object(person));
    }
    if let Some(product_images) = &source.product_images {
        let mut products = Vec::new();
        for item in product_images {
            // Entries without an actual image are silently dropped.
            if let Some(image) = &item.product_image {
                let mut product = Map::new();
                product.insert("image".to_string(), image_to_vertex(image)?);
                products.push(Value::Object(product));
            }
        }
        if !products.is_empty() {
            instance.insert("productImages".to_string(), Value::Array(products));
        }
    }

    let mut root = Map::new();
    root.insert(
        "instances".to_string(),
        Value::Array(vec![Value::Object(instance)]),
    );

    let mut parameters = Map::new();
    let mut output_options = Map::new();

    if let Some(value) = config.number_of_images {
        parameters.insert(
            "sampleCount".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if let Some(value) = config.base_steps {
        parameters.insert("baseSteps".to_string(), Value::Number(Number::from(value)));
    }
    if let Some(value) = &config.output_gcs_uri {
        parameters.insert("storageUri".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.seed {
        parameters.insert("seed".to_string(), Value::Number(Number::from(value)));
    }
    if let Some(value) = config.safety_filter_level {
        parameters.insert("safetySetting".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.person_generation {
        parameters.insert("personGeneration".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.add_watermark {
        parameters.insert("addWatermark".to_string(), Value::Bool(value));
    }
    if let Some(value) = &config.output_mime_type {
        output_options.insert("mimeType".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.output_compression_quality {
        output_options.insert(
            "compressionQuality".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if !output_options.is_empty() {
        parameters.insert("outputOptions".to_string(), Value::Object(output_options));
    }
    if let Some(value) = config.enhance_prompt {
        parameters.insert("enhancePrompt".to_string(), Value::Bool(value));
    }
    if let Some(labels) = &config.labels {
        // Labels live at the request root, not under `parameters`.
        root.insert("labels".to_string(), serde_json::to_value(labels)?);
    }

    if !parameters.is_empty() {
        root.insert("parameters".to_string(), Value::Object(parameters));
    }

    Ok(Value::Object(root))
}
1540
/// Builds the Vertex `predict` request body for image segmentation.
///
/// The single instance carries the optional prompt, source image, and a
/// scribble image (wrapped in a `scribble.image` object). Numeric
/// thresholds from `config` map onto `parameters`; `labels` ride at the
/// request root.
///
/// # Errors
/// Fails if an image or config value fails JSON serialization.
fn build_segment_image_body(
    source: &SegmentImageSource,
    config: &SegmentImageConfig,
) -> Result<Value> {
    let mut instance = Map::new();
    if let Some(prompt) = &source.prompt {
        instance.insert("prompt".to_string(), Value::String(prompt.clone()));
    }
    if let Some(image) = &source.image {
        instance.insert("image".to_string(), image_to_vertex(image)?);
    }
    if let Some(scribble) = &source.scribble_image {
        // Only forwarded when the scribble wrapper actually has an image.
        if let Some(image) = &scribble.image {
            let mut scribble_map = Map::new();
            scribble_map.insert("image".to_string(), image_to_vertex(image)?);
            instance.insert("scribble".to_string(), Value::Object(scribble_map));
        }
    }

    let mut root = Map::new();
    root.insert(
        "instances".to_string(),
        Value::Array(vec![Value::Object(instance)]),
    );

    let mut parameters = Map::new();
    if let Some(value) = config.mode {
        parameters.insert("mode".to_string(), serde_json::to_value(value)?);
    }
    if let Some(value) = config.max_predictions {
        parameters.insert(
            "maxPredictions".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if let Some(value) = config.confidence_threshold {
        // Non-finite floats cannot be represented in JSON; fall back to 0.
        parameters.insert(
            "confidenceThreshold".to_string(),
            Value::Number(Number::from_f64(value as f64).unwrap_or_else(|| Number::from(0))),
        );
    }
    if let Some(value) = config.mask_dilation {
        parameters.insert(
            "maskDilation".to_string(),
            Value::Number(Number::from_f64(value as f64).unwrap_or_else(|| Number::from(0))),
        );
    }
    if let Some(value) = config.binary_color_threshold {
        parameters.insert(
            "binaryColorThreshold".to_string(),
            Value::Number(Number::from_f64(value as f64).unwrap_or_else(|| Number::from(0))),
        );
    }
    if !parameters.is_empty() {
        root.insert("parameters".to_string(), Value::Object(parameters));
    }

    if let Some(labels) = &config.labels {
        // Labels live at the request root, not under `parameters`.
        root.insert("labels".to_string(), serde_json::to_value(labels)?);
    }

    Ok(Value::Object(root))
}
1604
/// Builds the request body for video generation (Veo).
///
/// The instance collects the prompt, source image/video, last frame,
/// reference images and (Vertex only) a mask, serializing media in the
/// backend-specific wire format. Set config fields map onto `parameters`.
///
/// # Errors
/// Returns `Error::InvalidConfig` when a Vertex-only field (`mask`,
/// `output_gcs_uri`, `fps`, `seed`, `pubsub_topic`, `generate_audio`,
/// `compression_quality`) is set while targeting `Backend::GeminiApi`, or
/// when media/config serialization fails.
fn build_generate_videos_body(
    backend: Backend,
    source: &GenerateVideosSource,
    config: &GenerateVideosConfig,
) -> Result<Value> {
    let mut instance = Map::new();
    if let Some(prompt) = &source.prompt {
        instance.insert("prompt".to_string(), Value::String(prompt.clone()));
    }
    if let Some(image) = &source.image {
        // Image/video payloads use different field names per backend.
        let value = match backend {
            Backend::GeminiApi => image_to_mldev(image)?,
            Backend::VertexAi => image_to_vertex(image)?,
        };
        instance.insert("image".to_string(), value);
    }
    if let Some(video) = &source.video {
        let value = match backend {
            Backend::GeminiApi => video_to_mldev(video)?,
            Backend::VertexAi => video_to_vertex(video)?,
        };
        instance.insert("video".to_string(), value);
    }

    if let Some(last_frame) = &config.last_frame {
        let value = match backend {
            Backend::GeminiApi => image_to_mldev(last_frame)?,
            Backend::VertexAi => image_to_vertex(last_frame)?,
        };
        instance.insert("lastFrame".to_string(), value);
    }

    if let Some(reference_images) = &config.reference_images {
        let mut refs = Vec::new();
        for item in reference_images {
            refs.push(video_reference_image_to_value(backend, item)?);
        }
        instance.insert("referenceImages".to_string(), Value::Array(refs));
    }

    if let Some(mask) = &config.mask {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "mask is not supported in Gemini API".into(),
            });
        }
        instance.insert("mask".to_string(), video_mask_to_vertex(mask)?);
    }

    let mut root = Map::new();
    root.insert(
        "instances".to_string(),
        Value::Array(vec![Value::Object(instance)]),
    );

    let mut parameters = Map::new();

    if let Some(value) = config.number_of_videos {
        parameters.insert(
            "sampleCount".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if let Some(value) = &config.output_gcs_uri {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "output_gcs_uri is not supported in Gemini API".into(),
            });
        }
        parameters.insert("storageUri".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.fps {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "fps is not supported in Gemini API".into(),
            });
        }
        parameters.insert("fps".to_string(), Value::Number(Number::from(value)));
    }
    if let Some(value) = config.duration_seconds {
        parameters.insert(
            "durationSeconds".to_string(),
            Value::Number(Number::from(value)),
        );
    }
    if let Some(value) = config.seed {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "seed is not supported in Gemini API".into(),
            });
        }
        parameters.insert("seed".to_string(), Value::Number(Number::from(value)));
    }
    if let Some(value) = &config.aspect_ratio {
        parameters.insert("aspectRatio".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = &config.resolution {
        parameters.insert("resolution".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = &config.person_generation {
        parameters.insert("personGeneration".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = &config.pubsub_topic {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "pubsub_topic is not supported in Gemini API".into(),
            });
        }
        parameters.insert("pubsubTopic".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = &config.negative_prompt {
        parameters.insert("negativePrompt".to_string(), Value::String(value.clone()));
    }
    if let Some(value) = config.enhance_prompt {
        parameters.insert("enhancePrompt".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.generate_audio {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "generate_audio is not supported in Gemini API".into(),
            });
        }
        parameters.insert("generateAudio".to_string(), Value::Bool(value));
    }
    if let Some(value) = config.compression_quality {
        if backend == Backend::GeminiApi {
            return Err(Error::InvalidConfig {
                message: "compression_quality is not supported in Gemini API".into(),
            });
        }
        parameters.insert(
            "compressionQuality".to_string(),
            serde_json::to_value(value)?,
        );
    }

    if !parameters.is_empty() {
        root.insert("parameters".to_string(), Value::Object(parameters));
    }

    Ok(Value::Object(root))
}
1747
1748fn parse_generate_images_response(
1749 value: Value,
1750 backend: Backend,
1751) -> Result<GenerateImagesResponse> {
1752 let predictions = value
1753 .get("predictions")
1754 .and_then(|pred| pred.as_array())
1755 .cloned()
1756 .unwrap_or_default();
1757
1758 let mut generated_images = Vec::new();
1759 for item in predictions {
1760 generated_images.push(parse_generated_image(&item, backend)?);
1761 }
1762
1763 let positive_prompt_safety_attributes = value
1764 .get("positivePromptSafetyAttributes")
1765 .and_then(parse_safety_attributes);
1766
1767 Ok(GenerateImagesResponse {
1768 generated_images,
1769 positive_prompt_safety_attributes,
1770 })
1771}
1772
1773fn parse_edit_image_response(value: Value) -> Result<EditImageResponse> {
1774 let predictions = value
1775 .get("predictions")
1776 .and_then(|pred| pred.as_array())
1777 .cloned()
1778 .unwrap_or_default();
1779
1780 let mut generated_images = Vec::new();
1781 for item in predictions {
1782 generated_images.push(parse_generated_image(&item, Backend::VertexAi)?);
1783 }
1784
1785 Ok(EditImageResponse { generated_images })
1786}
1787
1788fn parse_upscale_image_response(
1789 value: Value,
1790) -> Result<rust_genai_types::models::UpscaleImageResponse> {
1791 let predictions = value
1792 .get("predictions")
1793 .and_then(|pred| pred.as_array())
1794 .cloned()
1795 .unwrap_or_default();
1796
1797 let mut generated_images = Vec::new();
1798 for item in predictions {
1799 generated_images.push(parse_generated_image(&item, Backend::VertexAi)?);
1800 }
1801
1802 Ok(rust_genai_types::models::UpscaleImageResponse { generated_images })
1803}
1804
1805fn parse_recontext_image_response(value: Value) -> Result<RecontextImageResponse> {
1806 let predictions = value
1807 .get("predictions")
1808 .and_then(|pred| pred.as_array())
1809 .cloned()
1810 .unwrap_or_default();
1811
1812 let mut generated_images = Vec::new();
1813 for item in predictions {
1814 generated_images.push(parse_generated_image(&item, Backend::VertexAi)?);
1815 }
1816
1817 Ok(RecontextImageResponse { generated_images })
1818}
1819
1820fn parse_segment_image_response(value: Value) -> Result<SegmentImageResponse> {
1821 let predictions = value
1822 .get("predictions")
1823 .and_then(|pred| pred.as_array())
1824 .cloned()
1825 .unwrap_or_default();
1826
1827 let mut generated_masks = Vec::new();
1828 for item in predictions {
1829 generated_masks.push(parse_generated_image_mask(&item)?);
1830 }
1831
1832 Ok(SegmentImageResponse { generated_masks })
1833}
1834
1835fn parse_generate_videos_operation(
1836 value: Value,
1837 backend: Backend,
1838) -> Result<rust_genai_types::operations::Operation> {
1839 let mut operation: rust_genai_types::operations::Operation = serde_json::from_value(value)?;
1840 if backend == Backend::GeminiApi {
1841 if let Some(response) = operation.response.take() {
1842 if let Some(inner) = response.get("generateVideoResponse") {
1843 operation.response = Some(inner.clone());
1844 } else {
1845 operation.response = Some(response);
1846 }
1847 }
1848 }
1849 Ok(operation)
1850}
1851
1852fn parse_generated_image(value: &Value, backend: Backend) -> Result<GeneratedImage> {
1853 let image = match backend {
1854 Backend::GeminiApi => serde_json::from_value::<Image>(value.clone()).ok(),
1855 Backend::VertexAi => serde_json::from_value::<Image>(value.clone()).ok(),
1856 };
1857
1858 let rai_filtered_reason = value
1859 .get("raiFilteredReason")
1860 .and_then(|v| v.as_str())
1861 .map(|v| v.to_string());
1862 let enhanced_prompt = value
1863 .get("enhancedPrompt")
1864 .and_then(|v| v.as_str())
1865 .map(|v| v.to_string());
1866
1867 let safety_attributes = parse_safety_attributes(value);
1868
1869 Ok(GeneratedImage {
1870 image,
1871 rai_filtered_reason,
1872 safety_attributes,
1873 enhanced_prompt,
1874 })
1875}
1876
1877fn parse_generated_image_mask(value: &Value) -> Result<GeneratedImageMask> {
1878 let mask = serde_json::from_value::<Image>(value.clone()).ok();
1879 let labels = value
1880 .get("labels")
1881 .and_then(|value| value.as_array())
1882 .map(|items| {
1883 items
1884 .iter()
1885 .filter_map(parse_entity_label)
1886 .collect::<Vec<EntityLabel>>()
1887 });
1888
1889 Ok(GeneratedImageMask { mask, labels })
1890}
1891
1892fn parse_entity_label(value: &Value) -> Option<EntityLabel> {
1893 let obj = value.as_object()?;
1894 let label = obj
1895 .get("label")
1896 .and_then(|value| value.as_str())
1897 .map(|value| value.to_string());
1898 let score = obj.get("score").and_then(|value| match value {
1899 Value::Number(num) => num.as_f64().map(|num| num as f32),
1900 Value::String(text) => text.parse::<f32>().ok(),
1901 _ => None,
1902 });
1903
1904 Some(EntityLabel { label, score })
1905}
1906
1907fn parse_safety_attributes(value: &Value) -> Option<SafetyAttributes> {
1908 let obj = value.as_object()?;
1909 let safety = obj.get("safetyAttributes").and_then(|v| v.as_object());
1910
1911 let categories = obj
1912 .get("categories")
1913 .or_else(|| safety.and_then(|s| s.get("categories")))
1914 .and_then(|v| v.as_array())
1915 .map(|items| {
1916 items
1917 .iter()
1918 .filter_map(|item| item.as_str().map(|s| s.to_string()))
1919 .collect::<Vec<_>>()
1920 });
1921
1922 let scores = obj
1923 .get("scores")
1924 .or_else(|| safety.and_then(|s| s.get("scores")))
1925 .and_then(|v| v.as_array())
1926 .map(|items| {
1927 items
1928 .iter()
1929 .filter_map(|item| item.as_f64().map(|score| score as f32))
1930 .collect::<Vec<_>>()
1931 });
1932
1933 let content_type = obj
1934 .get("contentType")
1935 .and_then(|v| v.as_str())
1936 .map(|v| v.to_string());
1937
1938 if categories.is_none() && scores.is_none() && content_type.is_none() {
1939 None
1940 } else {
1941 Some(SafetyAttributes {
1942 categories,
1943 scores,
1944 content_type,
1945 })
1946 }
1947}
1948
1949fn image_to_mldev(image: &Image) -> Result<Value> {
1950 if image.gcs_uri.is_some() {
1951 return Err(Error::InvalidConfig {
1952 message: "gcs_uri is not supported in Gemini API".into(),
1953 });
1954 }
1955 let mut map = Map::new();
1956 if let Some(bytes) = &image.image_bytes {
1957 map.insert(
1958 "bytesBase64Encoded".to_string(),
1959 Value::String(STANDARD.encode(bytes)),
1960 );
1961 }
1962 if let Some(mime) = &image.mime_type {
1963 map.insert("mimeType".to_string(), Value::String(mime.clone()));
1964 }
1965 Ok(Value::Object(map))
1966}
1967
1968fn image_to_vertex(image: &Image) -> Result<Value> {
1969 let mut map = Map::new();
1970 if let Some(gcs_uri) = &image.gcs_uri {
1971 map.insert("gcsUri".to_string(), Value::String(gcs_uri.clone()));
1972 }
1973 if let Some(bytes) = &image.image_bytes {
1974 map.insert(
1975 "bytesBase64Encoded".to_string(),
1976 Value::String(STANDARD.encode(bytes)),
1977 );
1978 }
1979 if let Some(mime) = &image.mime_type {
1980 map.insert("mimeType".to_string(), Value::String(mime.clone()));
1981 }
1982 Ok(Value::Object(map))
1983}
1984
1985fn video_to_mldev(video: &Video) -> Result<Value> {
1986 if let Some(uri) = &video.uri {
1987 let mut map = Map::new();
1988 map.insert("uri".to_string(), Value::String(uri.clone()));
1989 if let Some(bytes) = &video.video_bytes {
1990 map.insert(
1991 "encodedVideo".to_string(),
1992 Value::String(STANDARD.encode(bytes)),
1993 );
1994 }
1995 if let Some(mime) = &video.mime_type {
1996 map.insert("encoding".to_string(), Value::String(mime.clone()));
1997 }
1998 return Ok(Value::Object(map));
1999 }
2000
2001 let mut map = Map::new();
2002 if let Some(bytes) = &video.video_bytes {
2003 map.insert(
2004 "encodedVideo".to_string(),
2005 Value::String(STANDARD.encode(bytes)),
2006 );
2007 }
2008 if let Some(mime) = &video.mime_type {
2009 map.insert("encoding".to_string(), Value::String(mime.clone()));
2010 }
2011 Ok(Value::Object(map))
2012}
2013
2014fn video_to_vertex(video: &Video) -> Result<Value> {
2015 let mut map = Map::new();
2016 if let Some(uri) = &video.uri {
2017 map.insert("gcsUri".to_string(), Value::String(uri.clone()));
2018 }
2019 if let Some(bytes) = &video.video_bytes {
2020 map.insert(
2021 "bytesBase64Encoded".to_string(),
2022 Value::String(STANDARD.encode(bytes)),
2023 );
2024 }
2025 if let Some(mime) = &video.mime_type {
2026 map.insert("mimeType".to_string(), Value::String(mime.clone()));
2027 }
2028 Ok(Value::Object(map))
2029}
2030
2031fn reference_image_to_vertex(image: &ReferenceImage) -> Result<Value> {
2032 let mut map = Map::new();
2033 if let Some(reference_image) = &image.reference_image {
2034 map.insert(
2035 "referenceImage".to_string(),
2036 image_to_vertex(reference_image)?,
2037 );
2038 }
2039 if let Some(reference_id) = image.reference_id {
2040 map.insert(
2041 "referenceId".to_string(),
2042 Value::Number(Number::from(reference_id)),
2043 );
2044 }
2045 if let Some(reference_type) = image.reference_type {
2046 map.insert(
2047 "referenceType".to_string(),
2048 serde_json::to_value(reference_type)?,
2049 );
2050 }
2051 if let Some(config) = &image.mask_image_config {
2052 map.insert("maskImageConfig".to_string(), serde_json::to_value(config)?);
2053 }
2054 if let Some(config) = &image.control_image_config {
2055 map.insert(
2056 "controlImageConfig".to_string(),
2057 serde_json::to_value(config)?,
2058 );
2059 }
2060 if let Some(config) = &image.style_image_config {
2061 map.insert(
2062 "styleImageConfig".to_string(),
2063 serde_json::to_value(config)?,
2064 );
2065 }
2066 if let Some(config) = &image.subject_image_config {
2067 map.insert(
2068 "subjectImageConfig".to_string(),
2069 serde_json::to_value(config)?,
2070 );
2071 }
2072 Ok(Value::Object(map))
2073}
2074
2075fn video_reference_image_to_value(
2076 backend: Backend,
2077 reference: &VideoGenerationReferenceImage,
2078) -> Result<Value> {
2079 let mut map = Map::new();
2080 if let Some(image) = &reference.image {
2081 let value = match backend {
2082 Backend::GeminiApi => image_to_mldev(image)?,
2083 Backend::VertexAi => image_to_vertex(image)?,
2084 };
2085 map.insert("image".to_string(), value);
2086 }
2087 if let Some(reference_type) = reference.reference_type {
2088 map.insert(
2089 "referenceType".to_string(),
2090 serde_json::to_value(reference_type)?,
2091 );
2092 }
2093 Ok(Value::Object(map))
2094}
2095
2096fn video_mask_to_vertex(mask: &VideoGenerationMask) -> Result<Value> {
2097 let mut map = Map::new();
2098 if let Some(image) = &mask.image {
2099 map.insert("image".to_string(), image_to_vertex(image)?);
2100 }
2101 if let Some(mode) = mask.mask_mode {
2102 map.insert("maskMode".to_string(), serde_json::to_value(mode)?);
2103 }
2104 Ok(Value::Object(map))
2105}
2106
2107fn apply_http_options(
2108 mut request: reqwest::RequestBuilder,
2109 http_options: Option<&rust_genai_types::http::HttpOptions>,
2110) -> Result<reqwest::RequestBuilder> {
2111 if let Some(options) = http_options {
2112 if let Some(timeout) = options.timeout {
2113 request = request.timeout(std::time::Duration::from_millis(timeout));
2114 }
2115 if let Some(headers) = &options.headers {
2116 for (key, value) in headers {
2117 let name =
2118 reqwest::header::HeaderName::from_bytes(key.as_bytes()).map_err(|_| {
2119 Error::InvalidConfig {
2120 message: format!("Invalid header name: {key}"),
2121 }
2122 })?;
2123 let value = reqwest::header::HeaderValue::from_str(value).map_err(|_| {
2124 Error::InvalidConfig {
2125 message: format!("Invalid header value for {key}"),
2126 }
2127 })?;
2128 request = request.header(name, value);
2129 }
2130 }
2131 }
2132 Ok(request)
2133}
2134
2135fn build_function_call_content(function_calls: &[FunctionCall]) -> Content {
2136 let parts = function_calls
2137 .iter()
2138 .cloned()
2139 .map(Part::function_call)
2140 .collect();
2141 Content::from_parts(parts, Role::Model)
2142}
2143
2144fn merge_extra_body(
2145 body: &mut Value,
2146 http_options: &rust_genai_types::http::HttpOptions,
2147) -> Result<()> {
2148 if let Some(extra) = &http_options.extra_body {
2149 match (body, extra) {
2150 (Value::Object(body_map), Value::Object(extra_map)) => {
2151 for (key, value) in extra_map {
2152 body_map.insert(key.clone(), value.clone());
2153 }
2154 }
2155 (_, _) => {
2156 return Err(Error::InvalidConfig {
2157 message: "HttpOptions.extra_body must be an object".into(),
2158 });
2159 }
2160 }
2161 }
2162 Ok(())
2163}
2164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Client;

    // Model names should be prefixed per backend: `models/` for the Gemini
    // API and `publishers/google/models/` for Vertex AI.
    #[test]
    fn test_transform_model_name() {
        assert_eq!(
            transform_model_name(Backend::GeminiApi, "gemini-1.5-pro"),
            "models/gemini-1.5-pro"
        );
        assert_eq!(
            transform_model_name(Backend::VertexAi, "gemini-1.5-pro"),
            "publishers/google/models/gemini-1.5-pro"
        );
    }

    // A default (API-key) client should target the Gemini API v1beta endpoint
    // with the `<model>:<method>` URL shape.
    #[test]
    fn test_build_model_urls() {
        let client = Client::new("test-key").unwrap();
        let models = client.models();
        let url =
            build_model_method_url(&models.inner, "gemini-1.5-pro", "generateContent").unwrap();
        assert_eq!(
            url,
            "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:generateContent"
        );
    }

    // The recontext body should expose prompt, person image, and product
    // images inside the first `instances` entry.
    #[test]
    fn test_build_recontext_image_body() {
        let source = RecontextImageSource {
            prompt: Some("test".to_string()),
            person_image: Some(Image {
                gcs_uri: Some("gs://person.png".to_string()),
                ..Default::default()
            }),
            product_images: Some(vec![rust_genai_types::models::ProductImage {
                product_image: Some(Image {
                    gcs_uri: Some("gs://product.png".to_string()),
                    ..Default::default()
                }),
            }]),
        };
        let config = RecontextImageConfig {
            number_of_images: Some(2),
            ..Default::default()
        };

        let body = build_recontext_image_body(&source, &config).unwrap();
        let instances = body.get("instances").and_then(|v| v.as_array()).unwrap();
        let instance = instances[0].as_object().unwrap();
        assert!(instance.get("prompt").is_some());
        assert!(instance.get("personImage").is_some());
        assert!(instance.get("productImages").is_some());
    }

    // The segment body should carry the input image in `instances` and the
    // segmentation mode (here Foreground) under `parameters`.
    #[test]
    fn test_build_segment_image_body() {
        let source = SegmentImageSource {
            prompt: Some("foreground".to_string()),
            image: Some(Image {
                gcs_uri: Some("gs://input.png".to_string()),
                ..Default::default()
            }),
            scribble_image: None,
        };
        let config = SegmentImageConfig {
            mode: Some(rust_genai_types::enums::SegmentMode::Foreground),
            ..Default::default()
        };

        let body = build_segment_image_body(&source, &config).unwrap();
        let instances = body.get("instances").and_then(|v| v.as_array()).unwrap();
        let instance = instances[0].as_object().unwrap();
        assert!(instance.get("image").is_some());
        assert!(body.get("parameters").is_some());
    }
}