potato_type/google/
embedding.rs1use crate::TypeError;
2use pyo3::prelude::*;
3use pyo3::IntoPyObjectExt;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7#[pyclass(eq, eq_int)]
8#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
9pub enum EmbeddingTaskType {
10 TaskTypeUnspecified,
11 RetrievalQuery,
12 RetrievalDocument,
13 SemanticSimilarity,
14 Classification,
15 Clustering,
16 QuestionAnswering,
17 FactVerification,
18 CodeRetrievalQuery,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
22#[pyclass]
23pub struct GeminiEmbeddingConfig {
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub model: Option<String>,
26
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub output_dimensionality: Option<i32>,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub task_type: Option<EmbeddingTaskType>,
32}
33
34#[pymethods]
35impl GeminiEmbeddingConfig {
36 #[new]
37 #[pyo3(signature = (model=None, output_dimensionality=None, task_type=None))]
38 pub fn new(
39 model: Option<String>,
40 output_dimensionality: Option<i32>,
41 task_type: Option<EmbeddingTaskType>,
42 ) -> Result<Self, TypeError> {
43 if model.is_none() && task_type.is_none() {
44 return Err(TypeError::GeminiEmbeddingConfigError(
45 "Either 'model' or 'task_type' must be provided.".to_string(),
46 ));
47 }
48 Ok(Self {
49 model,
50 output_dimensionality,
51 task_type,
52 })
53 }
54}
55
/// Shared interface for embedding configurations that can name the model they target.
pub trait EmbeddingConfigTrait {
    /// Returns the model identifier this configuration resolves to.
    ///
    /// The returned string borrows from `self`.
    fn get_model(&self) -> &str;
}
59
60impl EmbeddingConfigTrait for GeminiEmbeddingConfig {
61 fn get_model(&self) -> &str {
62 self.model.as_deref().unwrap_or("embedding-001")
63 }
64}
65
66#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
67#[pyclass]
68pub struct ContentEmbedding {
69 pub values: Vec<f32>,
70}
71
72#[pymethods]
73impl ContentEmbedding {
74 #[getter]
75 pub fn values(&self) -> &Vec<f32> {
76 &self.values
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
81#[pyclass]
82pub struct GeminiEmbeddingResponse {
83 #[pyo3(get)]
84 pub embedding: ContentEmbedding,
85}
86
87impl GeminiEmbeddingResponse {
88 pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
89 let bound = Py::new(py, self.clone())?;
90 Ok(bound.into_bound_py_any(py)?)
91 }
92}