// gproxy_protocol/transform/openai/embeddings/gemini/request.rs

1use crate::gemini::count_tokens::types::GeminiPart;
2use crate::gemini::embeddings::request::{
3    GeminiEmbedContentRequest, PathParameters, QueryParameters, RequestBody, RequestHeaders,
4};
5use crate::gemini::embeddings::types as gt;
6use crate::openai::embeddings::request::OpenAiEmbeddingsRequest;
7use crate::openai::embeddings::types as ot;
8use crate::transform::gemini::model_get::utils::ensure_models_prefix;
9use crate::transform::utils::TransformError;
10
11impl TryFrom<OpenAiEmbeddingsRequest> for GeminiEmbedContentRequest {
12    type Error = TransformError;
13
14    fn try_from(value: OpenAiEmbeddingsRequest) -> Result<Self, TransformError> {
15        let input_parts = match value.body.input {
16            ot::OpenAiEmbeddingInput::String(text) => vec![GeminiPart {
17                text: Some(text),
18                ..GeminiPart::default()
19            }],
20            ot::OpenAiEmbeddingInput::StringArray(texts) => {
21                if texts.is_empty() {
22                    vec![GeminiPart {
23                        text: Some(String::new()),
24                        ..GeminiPart::default()
25                    }]
26                } else {
27                    texts
28                        .into_iter()
29                        .map(|text| GeminiPart {
30                            text: Some(text),
31                            ..GeminiPart::default()
32                        })
33                        .collect::<Vec<_>>()
34                }
35            }
36            ot::OpenAiEmbeddingInput::TokenArray(tokens) => vec![GeminiPart {
37                text: Some(
38                    tokens
39                        .into_iter()
40                        .map(|token| token.to_string())
41                        .collect::<Vec<_>>()
42                        .join(" "),
43                ),
44                ..GeminiPart::default()
45            }],
46            ot::OpenAiEmbeddingInput::TokenArrayArray(token_batches) => {
47                if token_batches.is_empty() {
48                    vec![GeminiPart {
49                        text: Some(String::new()),
50                        ..GeminiPart::default()
51                    }]
52                } else {
53                    token_batches
54                        .into_iter()
55                        .map(|tokens| GeminiPart {
56                            text: Some(
57                                tokens
58                                    .into_iter()
59                                    .map(|token| token.to_string())
60                                    .collect::<Vec<_>>()
61                                    .join(" "),
62                            ),
63                            ..GeminiPart::default()
64                        })
65                        .collect::<Vec<_>>()
66                }
67            }
68        };
69
70        let model_name = match value.body.model {
71            ot::OpenAiEmbeddingModel::Known(ot::OpenAiEmbeddingModelKnown::TextEmbeddingAda002) => {
72                "text-embedding-ada-002".to_string()
73            }
74            ot::OpenAiEmbeddingModel::Known(ot::OpenAiEmbeddingModelKnown::TextEmbedding3Small) => {
75                "text-embedding-3-small".to_string()
76            }
77            ot::OpenAiEmbeddingModel::Known(ot::OpenAiEmbeddingModelKnown::TextEmbedding3Large) => {
78                "text-embedding-3-large".to_string()
79            }
80            ot::OpenAiEmbeddingModel::Custom(model) => model,
81        };
82        let model = ensure_models_prefix(&model_name);
83
84        Ok(GeminiEmbedContentRequest {
85            method: gt::HttpMethod::Post,
86            path: PathParameters { model },
87            query: QueryParameters::default(),
88            headers: RequestHeaders::default(),
89            body: RequestBody {
90                content: gt::GeminiContent {
91                    parts: input_parts,
92                    role: None,
93                },
94                task_type: None,
95                title: None,
96                output_dimensionality: value.body.dimensions,
97            },
98        })
99    }
100}