potato_type/google/embedding.rs

1use crate::TypeError;
2use pyo3::prelude::*;
3use pyo3::IntoPyObjectExt;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7#[pyclass(eq, eq_int)]
8#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
9pub enum EmbeddingTaskType {
10    TaskTypeUnspecified,
11    RetrievalQuery,
12    RetrievalDocument,
13    SemanticSimilarity,
14    Classification,
15    Clustering,
16    QuestionAnswering,
17    FactVerification,
18    CodeRetrievalQuery,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
22#[pyclass]
23pub struct GeminiEmbeddingConfig {
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub model: Option<String>,
26
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub output_dimensionality: Option<i32>,
29
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub task_type: Option<EmbeddingTaskType>,
32
33    #[serde(skip_serializing)]
34    pub is_configured: bool,
35}
36
37#[pymethods]
38impl GeminiEmbeddingConfig {
39    #[new]
40    #[pyo3(signature = (model=None, output_dimensionality=None, task_type=None))]
41    pub fn new(
42        model: Option<String>,
43        output_dimensionality: Option<i32>,
44        task_type: Option<EmbeddingTaskType>,
45    ) -> Result<Self, TypeError> {
46        if model.is_none() && task_type.is_none() {
47            return Err(TypeError::GeminiEmbeddingConfigError(
48                "Either 'model' or 'task_type' must be provided.".to_string(),
49            ));
50        }
51
52        let is_configured = output_dimensionality.is_some() || task_type.is_some();
53
54        Ok(Self {
55            model,
56            output_dimensionality,
57            task_type,
58            is_configured,
59        })
60    }
61}
62
63impl GeminiEmbeddingConfig {
64    pub fn get_parameters_for_predict(&self) -> serde_json::Value {
65        let mut params = serde_json::Map::new();
66        if let Some(dim) = self.output_dimensionality {
67            params.insert("outputDimensionality".to_string(), serde_json::json!(dim));
68        }
69        if let Some(task) = &self.task_type {
70            params.insert("task_type".to_string(), serde_json::json!(task));
71        }
72        if params.is_empty() {
73            serde_json::Value::Null
74        } else {
75            serde_json::Value::Object(params)
76        }
77    }
78}
79
/// Common interface for embedding configurations that can report their target model.
pub trait EmbeddingConfigTrait {
    /// Returns the model identifier this configuration resolves to, borrowed from `self`
    /// (or a static fallback chosen by the implementor).
    fn get_model(&self) -> &str;
}
83
84impl EmbeddingConfigTrait for GeminiEmbeddingConfig {
85    fn get_model(&self) -> &str {
86        self.model.as_deref().unwrap_or("embedding-001")
87    }
88}
89
90#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
91#[pyclass]
92pub struct ContentEmbedding {
93    pub values: Vec<f32>,
94}
95
96#[pymethods]
97impl ContentEmbedding {
98    #[getter]
99    pub fn values(&self) -> &Vec<f32> {
100        &self.values
101    }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
105#[pyclass]
106pub struct GeminiEmbeddingResponse {
107    #[pyo3(get)]
108    pub embedding: ContentEmbedding,
109}
110
111impl GeminiEmbeddingResponse {
112    pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
113        let bound = Py::new(py, self.clone())?;
114        Ok(bound.into_bound_py_any(py)?)
115    }
116}