// potato_agent/agents/embed.rs

1use potato_type::google::EmbeddingConfigTrait;
2use potato_type::Provider;
3
4use crate::agents::client::GenAiClient;
5use crate::agents::provider::gemini::GeminiClient;
6use crate::agents::provider::openai::OpenAIClient;
7use crate::AgentError;
8use potato_type::google::GeminiEmbeddingConfig;
9use potato_type::google::GeminiEmbeddingResponse;
10use potato_type::openai::embedding::{OpenAIEmbeddingConfig, OpenAIEmbeddingResponse};
11use pyo3::prelude::*;
12use serde::Serialize;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, PartialEq, Serialize)]
16#[serde(untagged)]
17pub enum EmbeddingConfig {
18    OpenAI(OpenAIEmbeddingConfig),
19    Gemini(GeminiEmbeddingConfig),
20}
21
22impl EmbeddingConfig {
23    pub fn extract_config(
24        config: Option<&Bound<'_, PyAny>>,
25        provider: &Provider,
26    ) -> Result<Self, AgentError> {
27        match provider {
28            Provider::OpenAI => {
29                let config = if config.is_none() {
30                    OpenAIEmbeddingConfig::default()
31                } else {
32                    config
33                        .unwrap()
34                        .extract::<OpenAIEmbeddingConfig>()
35                        .map_err(|e| {
36                            AgentError::EmbeddingConfigExtractionError(format!(
37                                "Failed to extract OpenAIEmbeddingConfig: {}",
38                                e
39                            ))
40                        })?
41                };
42
43                Ok(EmbeddingConfig::OpenAI(config))
44            }
45            Provider::Gemini => {
46                let config = if config.is_none() {
47                    GeminiEmbeddingConfig::default()
48                } else {
49                    config
50                        .unwrap()
51                        .extract::<GeminiEmbeddingConfig>()
52                        .map_err(|e| {
53                            AgentError::EmbeddingConfigExtractionError(format!(
54                                "Failed to extract GeminiEmbeddingConfig: {}",
55                                e
56                            ))
57                        })?
58                };
59
60                Ok(EmbeddingConfig::Gemini(config))
61            }
62            _ => Err(AgentError::ProviderNotSupportedError(provider.to_string())),
63        }
64    }
65}
66
67impl EmbeddingConfigTrait for EmbeddingConfig {
68    fn get_model(&self) -> &str {
69        match self {
70            EmbeddingConfig::OpenAI(config) => config.model.as_str(),
71            EmbeddingConfig::Gemini(config) => config.get_model(),
72        }
73    }
74}
75
76use tracing::error;
77#[derive(Debug, Clone, PartialEq)]
78pub struct Embedder {
79    client: GenAiClient,
80    config: EmbeddingConfig,
81}
82
83impl Embedder {
84    /// Create a new Embedder instance that can be used to generate embeddings.
85    /// # Arguments
86    /// * `provider`: The provider to use for generating embeddings.
87    /// * `config`: The configuration for the embedding.
88    pub fn new(provider: Provider, config: EmbeddingConfig) -> Result<Self, AgentError> {
89        let client = match provider {
90            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
91            Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
92            _ => {
93                let msg = "No provider specified in ModelSettings";
94                error!("{}", msg);
95                return Err(AgentError::UndefinedError(msg.to_string()));
96            } // Add other providers here as needed
97        };
98
99        Ok(Self { client, config })
100    }
101
102    pub async fn embed(&self, inputs: Vec<String>) -> Result<EmbeddingResponse, AgentError> {
103        // Implementation for creating an embedding
104        self.client.create_embedding(inputs, &self.config).await
105    }
106}
107
108pub enum EmbeddingResponse {
109    OpenAI(OpenAIEmbeddingResponse),
110    Gemini(GeminiEmbeddingResponse),
111}
112
113impl EmbeddingResponse {
114    pub fn to_openai_response(&self) -> Result<&OpenAIEmbeddingResponse, AgentError> {
115        match self {
116            EmbeddingResponse::OpenAI(response) => Ok(response),
117            _ => Err(AgentError::InvalidResponseType("OpenAI".to_string())),
118        }
119    }
120
121    pub fn to_gemini_response(&self) -> Result<&GeminiEmbeddingResponse, AgentError> {
122        match self {
123            EmbeddingResponse::Gemini(response) => Ok(response),
124            _ => Err(AgentError::InvalidResponseType("Gemini".to_string())),
125        }
126    }
127
128    pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
129        match self {
130            EmbeddingResponse::OpenAI(response) => Ok(response.into_py_bound_any(py)?),
131            EmbeddingResponse::Gemini(response) => Ok(response.into_py_bound_any(py)?),
132        }
133    }
134
135    pub fn values(&self) -> Result<&Vec<f32>, AgentError> {
136        match self {
137            EmbeddingResponse::OpenAI(response) => {
138                let first = response
139                    .data
140                    .first()
141                    .ok_or_else(|| AgentError::NoEmbeddingsFound)?;
142                Ok(&first.embedding)
143            }
144
145            EmbeddingResponse::Gemini(response) => Ok(&response.embedding.values),
146        }
147    }
148}
149
150#[pyclass(name = "Embedder")]
151#[derive(Debug, Clone)]
152pub struct PyEmbedder {
153    pub embedder: Arc<Embedder>,
154    pub runtime: Arc<tokio::runtime::Runtime>,
155}
156
157#[pymethods]
158impl PyEmbedder {
159    #[new]
160    #[pyo3(signature = (provider, config=None))]
161    fn new(
162        provider: &Bound<'_, PyAny>,
163        config: Option<&Bound<'_, PyAny>>,
164    ) -> Result<Self, AgentError> {
165        let provider = Provider::extract_provider(provider)?;
166        let config = EmbeddingConfig::extract_config(config, &provider)?;
167        let embedder = Arc::new(Embedder::new(provider, config).unwrap());
168        Ok(Self {
169            embedder,
170            runtime: Arc::new(
171                tokio::runtime::Runtime::new()
172                    .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
173            ),
174        })
175    }
176
177    /// Create a new embedding from a single input string
178    /// # Arguments
179    /// * `inputs`: The input string to embed.
180    /// * `config`: The configuration for the embedding.
181    #[pyo3(signature = (input))]
182    pub fn embed<'py>(
183        &self,
184        py: Python<'py>,
185        input: String,
186    ) -> Result<Bound<'py, PyAny>, AgentError> {
187        let embedder = self.embedder.clone();
188        let embeddings = self
189            .runtime
190            .block_on(async { embedder.embed(vec![input]).await })?;
191        embeddings.into_py_bound_any(py)
192    }
193}