potato_agent/agents/
embed.rs1use potato_type::google::EmbeddingConfigTrait;
2use potato_type::Provider;
3
4use crate::agents::client::GenAiClient;
5use crate::agents::provider::gemini::GeminiClient;
6use crate::agents::provider::openai::OpenAIClient;
7use crate::AgentError;
8use potato_type::google::GeminiEmbeddingConfig;
9use potato_type::google::GeminiEmbeddingResponse;
10use potato_type::openai::embedding::{OpenAIEmbeddingConfig, OpenAIEmbeddingResponse};
11use pyo3::prelude::*;
12use serde::Serialize;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, PartialEq, Serialize)]
16#[serde(untagged)]
17pub enum EmbeddingConfig {
18 OpenAI(OpenAIEmbeddingConfig),
19 Gemini(GeminiEmbeddingConfig),
20}
21
22impl EmbeddingConfig {
23 pub fn extract_config(
24 config: Option<&Bound<'_, PyAny>>,
25 provider: &Provider,
26 ) -> Result<Self, AgentError> {
27 match provider {
28 Provider::OpenAI => {
29 let config = if config.is_none() {
30 OpenAIEmbeddingConfig::default()
31 } else {
32 config
33 .unwrap()
34 .extract::<OpenAIEmbeddingConfig>()
35 .map_err(|e| {
36 AgentError::EmbeddingConfigExtractionError(format!(
37 "Failed to extract OpenAIEmbeddingConfig: {}",
38 e
39 ))
40 })?
41 };
42
43 Ok(EmbeddingConfig::OpenAI(config))
44 }
45 Provider::Gemini => {
46 let config = if config.is_none() {
47 GeminiEmbeddingConfig::default()
48 } else {
49 config
50 .unwrap()
51 .extract::<GeminiEmbeddingConfig>()
52 .map_err(|e| {
53 AgentError::EmbeddingConfigExtractionError(format!(
54 "Failed to extract GeminiEmbeddingConfig: {}",
55 e
56 ))
57 })?
58 };
59
60 Ok(EmbeddingConfig::Gemini(config))
61 }
62 _ => Err(AgentError::ProviderNotSupportedError(provider.to_string())),
63 }
64 }
65}
66
67impl EmbeddingConfigTrait for EmbeddingConfig {
68 fn get_model(&self) -> &str {
69 match self {
70 EmbeddingConfig::OpenAI(config) => config.model.as_str(),
71 EmbeddingConfig::Gemini(config) => config.get_model(),
72 }
73 }
74}
75
76use tracing::error;
77#[derive(Debug, Clone, PartialEq)]
78pub struct Embedder {
79 client: GenAiClient,
80}
81
82impl Embedder {
83 pub fn new(provider: Provider) -> Result<Self, AgentError> {
84 let client = match provider {
85 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
86 Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
87 _ => {
88 let msg = "No provider specified in ModelSettings";
89 error!("{}", msg);
90 return Err(AgentError::UndefinedError(msg.to_string()));
91 } };
93
94 Ok(Self { client })
95 }
96
97 pub async fn embed(
98 &self,
99 inputs: Vec<String>,
100 config: EmbeddingConfig,
101 ) -> Result<EmbeddingResponse, AgentError> {
102 self.client.create_embedding(inputs, config).await
104 }
105}
106
107pub enum EmbeddingResponse {
108 OpenAI(OpenAIEmbeddingResponse),
109 Gemini(GeminiEmbeddingResponse),
110}
111
112impl EmbeddingResponse {
113 pub fn to_openai_response(&self) -> Result<&OpenAIEmbeddingResponse, AgentError> {
114 match self {
115 EmbeddingResponse::OpenAI(response) => Ok(response),
116 _ => Err(AgentError::InvalidResponseType("OpenAI".to_string())),
117 }
118 }
119
120 pub fn to_gemini_response(&self) -> Result<&GeminiEmbeddingResponse, AgentError> {
121 match self {
122 EmbeddingResponse::Gemini(response) => Ok(response),
123 _ => Err(AgentError::InvalidResponseType("Gemini".to_string())),
124 }
125 }
126
127 pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
128 match self {
129 EmbeddingResponse::OpenAI(response) => Ok(response.into_py_bound_any(py)?),
130 EmbeddingResponse::Gemini(response) => Ok(response.into_py_bound_any(py)?),
131 }
132 }
133}
134
135#[pyclass(name = "Embedder")]
136#[derive(Debug, Clone)]
137pub struct PyEmbedder {
138 pub embedder: Arc<Embedder>,
139 pub runtime: Arc<tokio::runtime::Runtime>,
140}
141
142#[pymethods]
143impl PyEmbedder {
144 #[new]
145 fn new(provider: &Bound<'_, PyAny>) -> Result<Self, AgentError> {
146 let provider = Provider::extract_provider(provider)?;
147 let embedder = Arc::new(Embedder::new(provider).unwrap());
148 Ok(Self {
149 embedder,
150 runtime: Arc::new(
151 tokio::runtime::Runtime::new()
152 .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
153 ),
154 })
155 }
156
157 #[pyo3(signature = (input, config=None))]
162 pub fn embed<'py>(
163 &self,
164 py: Python<'py>,
165 input: String,
166 config: Option<&Bound<'py, PyAny>>,
167 ) -> Result<Bound<'py, PyAny>, AgentError> {
168 let config = EmbeddingConfig::extract_config(config, self.embedder.client.provider())?;
169 let embedder = self.embedder.clone();
170 let embeddings = self
171 .runtime
172 .block_on(async { embedder.embed(vec![input], config).await })?;
173 embeddings.into_py_bound_any(py)
174 }
175}