potato_agent/agents/
embed.rs1use potato_type::google::EmbeddingConfigTrait;
2use potato_type::Provider;
3
4use crate::agents::client::GenAiClient;
5use crate::agents::provider::gemini::GeminiClient;
6use crate::agents::provider::openai::OpenAIClient;
7use crate::AgentError;
8use potato_type::google::GeminiEmbeddingConfig;
9use potato_type::google::GeminiEmbeddingResponse;
10use potato_type::openai::embedding::{OpenAIEmbeddingConfig, OpenAIEmbeddingResponse};
11use pyo3::prelude::*;
12use serde::Serialize;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, PartialEq, Serialize)]
16#[serde(untagged)]
17pub enum EmbeddingConfig {
18 OpenAI(OpenAIEmbeddingConfig),
19 Gemini(GeminiEmbeddingConfig),
20}
21
22impl EmbeddingConfig {
23 pub fn extract_config(
24 config: Option<&Bound<'_, PyAny>>,
25 provider: &Provider,
26 ) -> Result<Self, AgentError> {
27 match provider {
28 Provider::OpenAI => {
29 let config = if config.is_none() {
30 OpenAIEmbeddingConfig::default()
31 } else {
32 config
33 .unwrap()
34 .extract::<OpenAIEmbeddingConfig>()
35 .map_err(|e| {
36 AgentError::EmbeddingConfigExtractionError(format!(
37 "Failed to extract OpenAIEmbeddingConfig: {}",
38 e
39 ))
40 })?
41 };
42
43 Ok(EmbeddingConfig::OpenAI(config))
44 }
45 Provider::Gemini => {
46 let config = if config.is_none() {
47 GeminiEmbeddingConfig::default()
48 } else {
49 config
50 .unwrap()
51 .extract::<GeminiEmbeddingConfig>()
52 .map_err(|e| {
53 AgentError::EmbeddingConfigExtractionError(format!(
54 "Failed to extract GeminiEmbeddingConfig: {}",
55 e
56 ))
57 })?
58 };
59
60 Ok(EmbeddingConfig::Gemini(config))
61 }
62 _ => Err(AgentError::ProviderNotSupportedError(provider.to_string())),
63 }
64 }
65}
66
67impl EmbeddingConfigTrait for EmbeddingConfig {
68 fn get_model(&self) -> &str {
69 match self {
70 EmbeddingConfig::OpenAI(config) => config.model.as_str(),
71 EmbeddingConfig::Gemini(config) => config.get_model(),
72 }
73 }
74}
75
76use tracing::error;
77#[derive(Debug, Clone, PartialEq)]
78pub struct Embedder {
79 client: GenAiClient,
80 config: EmbeddingConfig,
81}
82
83impl Embedder {
84 pub fn new(provider: Provider, config: EmbeddingConfig) -> Result<Self, AgentError> {
89 let client = match provider {
90 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
91 Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
92 _ => {
93 let msg = "No provider specified in ModelSettings";
94 error!("{}", msg);
95 return Err(AgentError::UndefinedError(msg.to_string()));
96 } };
98
99 Ok(Self { client, config })
100 }
101
102 pub async fn embed(&self, inputs: Vec<String>) -> Result<EmbeddingResponse, AgentError> {
103 self.client.create_embedding(inputs, &self.config).await
105 }
106}
107
108pub enum EmbeddingResponse {
109 OpenAI(OpenAIEmbeddingResponse),
110 Gemini(GeminiEmbeddingResponse),
111}
112
113impl EmbeddingResponse {
114 pub fn to_openai_response(&self) -> Result<&OpenAIEmbeddingResponse, AgentError> {
115 match self {
116 EmbeddingResponse::OpenAI(response) => Ok(response),
117 _ => Err(AgentError::InvalidResponseType("OpenAI".to_string())),
118 }
119 }
120
121 pub fn to_gemini_response(&self) -> Result<&GeminiEmbeddingResponse, AgentError> {
122 match self {
123 EmbeddingResponse::Gemini(response) => Ok(response),
124 _ => Err(AgentError::InvalidResponseType("Gemini".to_string())),
125 }
126 }
127
128 pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
129 match self {
130 EmbeddingResponse::OpenAI(response) => Ok(response.into_py_bound_any(py)?),
131 EmbeddingResponse::Gemini(response) => Ok(response.into_py_bound_any(py)?),
132 }
133 }
134
135 pub fn values(&self) -> Result<&Vec<f32>, AgentError> {
136 match self {
137 EmbeddingResponse::OpenAI(response) => {
138 let first = response
139 .data
140 .first()
141 .ok_or_else(|| AgentError::NoEmbeddingsFound)?;
142 Ok(&first.embedding)
143 }
144
145 EmbeddingResponse::Gemini(response) => Ok(&response.embedding.values),
146 }
147 }
148}
149
150#[pyclass(name = "Embedder")]
151#[derive(Debug, Clone)]
152pub struct PyEmbedder {
153 pub embedder: Arc<Embedder>,
154 pub runtime: Arc<tokio::runtime::Runtime>,
155}
156
157#[pymethods]
158impl PyEmbedder {
159 #[new]
160 #[pyo3(signature = (provider, config=None))]
161 fn new(
162 provider: &Bound<'_, PyAny>,
163 config: Option<&Bound<'_, PyAny>>,
164 ) -> Result<Self, AgentError> {
165 let provider = Provider::extract_provider(provider)?;
166 let config = EmbeddingConfig::extract_config(config, &provider)?;
167 let embedder = Arc::new(Embedder::new(provider, config).unwrap());
168 Ok(Self {
169 embedder,
170 runtime: Arc::new(
171 tokio::runtime::Runtime::new()
172 .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
173 ),
174 })
175 }
176
177 #[pyo3(signature = (input))]
182 pub fn embed<'py>(
183 &self,
184 py: Python<'py>,
185 input: String,
186 ) -> Result<Bound<'py, PyAny>, AgentError> {
187 let embedder = self.embedder.clone();
188 let embeddings = self
189 .runtime
190 .block_on(async { embedder.embed(vec![input]).await })?;
191 embeddings.into_py_bound_any(py)
192 }
193}