potato_type/google/
embedding.rs1use crate::TypeError;
2use pyo3::prelude::*;
3use pyo3::IntoPyObjectExt;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7#[pyclass(eq, eq_int)]
8#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
9pub enum EmbeddingTaskType {
10 TaskTypeUnspecified,
11 RetrievalQuery,
12 RetrievalDocument,
13 SemanticSimilarity,
14 Classification,
15 Clustering,
16 QuestionAnswering,
17 FactVerification,
18 CodeRetrievalQuery,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
22#[pyclass]
23pub struct GeminiEmbeddingConfig {
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub model: Option<String>,
26
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub output_dimensionality: Option<i32>,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub task_type: Option<EmbeddingTaskType>,
32
33 #[serde(skip_serializing)]
34 pub is_configured: bool,
35}
36
37#[pymethods]
38impl GeminiEmbeddingConfig {
39 #[new]
40 #[pyo3(signature = (model=None, output_dimensionality=None, task_type=None))]
41 pub fn new(
42 model: Option<String>,
43 output_dimensionality: Option<i32>,
44 task_type: Option<EmbeddingTaskType>,
45 ) -> Result<Self, TypeError> {
46 if model.is_none() && task_type.is_none() {
47 return Err(TypeError::GeminiEmbeddingConfigError(
48 "Either 'model' or 'task_type' must be provided.".to_string(),
49 ));
50 }
51
52 let is_configured = output_dimensionality.is_some() || task_type.is_some();
53
54 Ok(Self {
55 model,
56 output_dimensionality,
57 task_type,
58 is_configured,
59 })
60 }
61}
62
63impl GeminiEmbeddingConfig {
64 pub fn get_parameters_for_predict(&self) -> serde_json::Value {
65 let mut params = serde_json::Map::new();
66 if let Some(dim) = self.output_dimensionality {
67 params.insert("outputDimensionality".to_string(), serde_json::json!(dim));
68 }
69 if let Some(task) = &self.task_type {
70 params.insert("task_type".to_string(), serde_json::json!(task));
71 }
72 if params.is_empty() {
73 serde_json::Value::Null
74 } else {
75 serde_json::Value::Object(params)
76 }
77 }
78}
79
/// Shared interface for embedding configurations.
pub trait EmbeddingConfigTrait {
    /// Returns the name of the embedding model this configuration targets.
    fn get_model(&self) -> &str;
}
83
84impl EmbeddingConfigTrait for GeminiEmbeddingConfig {
85 fn get_model(&self) -> &str {
86 self.model.as_deref().unwrap_or("embedding-001")
87 }
88}
89
90#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
91#[pyclass]
92pub struct ContentEmbedding {
93 pub values: Vec<f32>,
94}
95
96#[pymethods]
97impl ContentEmbedding {
98 #[getter]
99 pub fn values(&self) -> &Vec<f32> {
100 &self.values
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
105#[pyclass]
106pub struct GeminiEmbeddingResponse {
107 #[pyo3(get)]
108 pub embedding: ContentEmbedding,
109}
110
111impl GeminiEmbeddingResponse {
112 pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
113 let bound = Py::new(py, self.clone())?;
114 Ok(bound.into_bound_py_any(py)?)
115 }
116}