// potato_agent/agents/embed.rs

1use potato_type::google::EmbeddingConfigTrait;
2use potato_type::Provider;
3
4use crate::agents::client::GenAiClient;
5use crate::agents::provider::gemini::GeminiClient;
6use crate::agents::provider::openai::OpenAIClient;
7use crate::AgentError;
8use potato_type::google::GeminiEmbeddingConfig;
9use potato_type::google::GeminiEmbeddingResponse;
10use potato_type::openai::embedding::{OpenAIEmbeddingConfig, OpenAIEmbeddingResponse};
11use pyo3::prelude::*;
12use serde::Serialize;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, PartialEq, Serialize)]
16#[serde(untagged)]
17pub enum EmbeddingConfig {
18    OpenAI(OpenAIEmbeddingConfig),
19    Gemini(GeminiEmbeddingConfig),
20}
21
22impl EmbeddingConfig {
23    pub fn extract_config(
24        config: Option<&Bound<'_, PyAny>>,
25        provider: &Provider,
26    ) -> Result<Self, AgentError> {
27        match provider {
28            Provider::OpenAI => {
29                let config = if config.is_none() {
30                    OpenAIEmbeddingConfig::default()
31                } else {
32                    config
33                        .unwrap()
34                        .extract::<OpenAIEmbeddingConfig>()
35                        .map_err(|e| {
36                            AgentError::EmbeddingConfigExtractionError(format!(
37                                "Failed to extract OpenAIEmbeddingConfig: {}",
38                                e
39                            ))
40                        })?
41                };
42
43                Ok(EmbeddingConfig::OpenAI(config))
44            }
45            Provider::Gemini => {
46                let config = if config.is_none() {
47                    GeminiEmbeddingConfig::default()
48                } else {
49                    config
50                        .unwrap()
51                        .extract::<GeminiEmbeddingConfig>()
52                        .map_err(|e| {
53                            AgentError::EmbeddingConfigExtractionError(format!(
54                                "Failed to extract GeminiEmbeddingConfig: {}",
55                                e
56                            ))
57                        })?
58                };
59
60                Ok(EmbeddingConfig::Gemini(config))
61            }
62            _ => Err(AgentError::ProviderNotSupportedError(provider.to_string())),
63        }
64    }
65}
66
67impl EmbeddingConfigTrait for EmbeddingConfig {
68    fn get_model(&self) -> &str {
69        match self {
70            EmbeddingConfig::OpenAI(config) => config.model.as_str(),
71            EmbeddingConfig::Gemini(config) => config.get_model(),
72        }
73    }
74}
75
76use tracing::error;
77#[derive(Debug, Clone, PartialEq)]
78pub struct Embedder {
79    client: GenAiClient,
80}
81
82impl Embedder {
83    pub fn new(provider: Provider) -> Result<Self, AgentError> {
84        let client = match provider {
85            Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
86            Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
87            _ => {
88                let msg = "No provider specified in ModelSettings";
89                error!("{}", msg);
90                return Err(AgentError::UndefinedError(msg.to_string()));
91            } // Add other providers here as needed
92        };
93
94        Ok(Self { client })
95    }
96
97    pub async fn embed(
98        &self,
99        inputs: Vec<String>,
100        config: EmbeddingConfig,
101    ) -> Result<EmbeddingResponse, AgentError> {
102        // Implementation for creating an embedding
103        self.client.create_embedding(inputs, config).await
104    }
105}
106
107pub enum EmbeddingResponse {
108    OpenAI(OpenAIEmbeddingResponse),
109    Gemini(GeminiEmbeddingResponse),
110}
111
112impl EmbeddingResponse {
113    pub fn to_openai_response(&self) -> Result<&OpenAIEmbeddingResponse, AgentError> {
114        match self {
115            EmbeddingResponse::OpenAI(response) => Ok(response),
116            _ => Err(AgentError::InvalidResponseType("OpenAI".to_string())),
117        }
118    }
119
120    pub fn to_gemini_response(&self) -> Result<&GeminiEmbeddingResponse, AgentError> {
121        match self {
122            EmbeddingResponse::Gemini(response) => Ok(response),
123            _ => Err(AgentError::InvalidResponseType("Gemini".to_string())),
124        }
125    }
126
127    pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
128        match self {
129            EmbeddingResponse::OpenAI(response) => Ok(response.into_py_bound_any(py)?),
130            EmbeddingResponse::Gemini(response) => Ok(response.into_py_bound_any(py)?),
131        }
132    }
133}
134
135#[pyclass(name = "Embedder")]
136#[derive(Debug, Clone)]
137pub struct PyEmbedder {
138    pub embedder: Arc<Embedder>,
139    pub runtime: Arc<tokio::runtime::Runtime>,
140}
141
142#[pymethods]
143impl PyEmbedder {
144    #[new]
145    fn new(provider: &Bound<'_, PyAny>) -> Result<Self, AgentError> {
146        let provider = Provider::extract_provider(provider)?;
147        let embedder = Arc::new(Embedder::new(provider).unwrap());
148        Ok(Self {
149            embedder,
150            runtime: Arc::new(
151                tokio::runtime::Runtime::new()
152                    .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
153            ),
154        })
155    }
156
157    /// Create a new embedding from a single input string
158    /// # Arguments
159    /// * `inputs`: The input string to embed.
160    /// * `config`: The configuration for the embedding.
161    #[pyo3(signature = (input, config=None))]
162    pub fn embed<'py>(
163        &self,
164        py: Python<'py>,
165        input: String,
166        config: Option<&Bound<'py, PyAny>>,
167    ) -> Result<Bound<'py, PyAny>, AgentError> {
168        let config = EmbeddingConfig::extract_config(config, self.embedder.client.provider())?;
169        let embedder = self.embedder.clone();
170        let embeddings = self
171            .runtime
172            .block_on(async { embedder.embed(vec![input], config).await })?;
173        embeddings.into_py_bound_any(py)
174    }
175}