1use crate::{
2 agents::error::AgentError,
3 agents::task::Task,
4 agents::types::{AgentResponse, PyAgentResponse},
5};
6use potato_prompt::prompt::settings::ModelSettings;
7use potato_prompt::{
8 parse_response_to_json, prompt::parse_prompt, prompt::types::Message, Prompt, Role,
9};
10
11use potato_provider::{providers::google::VertexClient, GenAiClient, OpenAIClient};
12
13use potato_provider::providers::types::ServiceType;
14use potato_provider::GeminiClient;
15use potato_type::Provider;
16use potato_util::create_uuid7;
17use pyo3::{prelude::*, IntoPyObjectExt};
18use serde::{
19 de::{self, MapAccess, Visitor},
20 ser::SerializeStruct,
21 Deserializer, Serializer,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::HashMap;
26use std::sync::Arc;
27use std::sync::RwLock;
28use tracing::{debug, instrument, warn};
29
30#[derive(Debug, PartialEq, Clone)]
31pub struct Agent {
32 pub id: String,
33 client: Arc<GenAiClient>,
34 pub provider: Provider,
35 pub system_instruction: Vec<Message>,
36}
37
38impl Agent {
40 #[instrument(skip_all)]
42 pub async fn rebuild_client(&self) -> Result<Self, AgentError> {
43 let client = match self.provider {
44 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
45 Provider::Gemini => {
46 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
47 }
48 Provider::Vertex => {
49 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
50 }
51 _ => {
52 return Err(AgentError::MissingProviderError);
53 } };
55
56 Ok(Self {
57 id: self.id.clone(),
58 client: Arc::new(client),
59 system_instruction: self.system_instruction.clone(),
60 provider: self.provider.clone(),
61 })
62 }
63 pub async fn new(
64 provider: Provider,
65 system_instruction: Option<Vec<Message>>,
66 ) -> Result<Self, AgentError> {
67 let client = match provider {
68 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
69 Provider::Gemini => {
70 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
71 }
72 Provider::Vertex => {
73 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
74 }
75 _ => {
76 return Err(AgentError::MissingProviderError);
77 } };
79
80 let system_instruction = system_instruction.unwrap_or_default();
81
82 Ok(Self {
83 client: Arc::new(client),
84 id: create_uuid7(),
85 system_instruction,
86 provider,
87 })
88 }
89
90 #[instrument(skip_all)]
91 fn append_task_with_message_context(
92 &self,
93 task: &mut Task,
94 context_messages: &HashMap<String, Vec<Message>>,
95 ) {
96 debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
98 if !task.dependencies.is_empty() {
99 for dep in &task.dependencies {
100 if let Some(messages) = context_messages.get(dep) {
101 for message in messages {
102 task.prompt.message.insert(0, message.clone());
104 }
105 }
106 }
107 }
108 }
109
110 #[instrument(skip_all)]
120 fn bind_context(
121 &self,
122 prompt: &mut Prompt,
123 parameter_context: &Value,
124 global_context: &Option<Value>,
125 ) -> Result<(), AgentError> {
126 if !prompt.parameters.is_empty() {
128 for param in &prompt.parameters {
129 if let Some(value) = parameter_context.get(param) {
131 for message in &mut prompt.message {
132 if message.role == "user" {
133 debug!("Binding parameter: {} with value: {}", param, value);
134 message.bind_mut(param, &value.to_string())?;
135 }
136 }
137 }
138
139 if let Some(global_value) = global_context {
141 if let Some(value) = global_value.get(param) {
142 for message in &mut prompt.message {
143 if message.role == "user" {
144 debug!("Binding global parameter: {} with value: {}", param, value);
145 message.bind_mut(param, &value.to_string())?;
146 }
147 }
148 }
149 }
150 }
151 }
152 Ok(())
153 }
154
155 fn append_system_instructions(&self, prompt: &mut Prompt) {
156 if !self.system_instruction.is_empty() {
157 let mut combined_messages = self.system_instruction.clone();
158 combined_messages.extend(prompt.system_instruction.clone());
159 prompt.system_instruction = combined_messages;
160 }
161 }
162 pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
163 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
165 let mut prompt = task.prompt.clone();
166 self.append_system_instructions(&mut prompt);
167
168 let chat_response = self.client.generate_content(&prompt).await?;
170
171 Ok(AgentResponse::new(task.id.clone(), chat_response))
172 }
173
174 #[instrument(skip_all)]
175 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
176 debug!("Executing prompt");
178 let mut prompt = prompt.clone();
179 self.append_system_instructions(&mut prompt);
180
181 let chat_response = self.client.generate_content(&prompt).await?;
183
184 Ok(AgentResponse::new(chat_response.id(), chat_response))
185 }
186
187 pub async fn execute_task_with_context(
188 &self,
189 task: &Arc<RwLock<Task>>,
190 context_messages: HashMap<String, Vec<Message>>,
191 parameter_context: Value,
192 global_context: Option<Value>,
193 ) -> Result<AgentResponse, AgentError> {
194 let (prompt, task_id) = {
196 let mut task = task.write().unwrap();
197 self.append_task_with_message_context(&mut task, &context_messages);
198 self.bind_context(&mut task.prompt, ¶meter_context, &global_context)?;
199
200 self.append_system_instructions(&mut task.prompt);
201 (task.prompt.clone(), task.id.clone())
202 };
203
204 let chat_response = self.client.generate_content(&prompt).await?;
206
207 Ok(AgentResponse::new(task_id, chat_response))
208 }
209
210 pub fn client_provider(&self) -> &Provider {
211 self.client.provider()
212 }
213
214 pub async fn from_model_settings(model_settings: &ModelSettings) -> Result<Self, AgentError> {
215 let provider = model_settings.provider();
216 let client = match provider {
217 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(ServiceType::Generate)?),
218 Provider::Gemini => {
219 GenAiClient::Gemini(GeminiClient::new(ServiceType::Generate).await?)
220 }
221 Provider::Vertex => {
222 GenAiClient::Vertex(VertexClient::new(ServiceType::Generate).await?)
223 }
224 Provider::Undefined => {
225 return Err(AgentError::MissingProviderError);
226 }
227 };
228
229 Ok(Self {
230 client: Arc::new(client),
231 id: create_uuid7(),
232 system_instruction: Vec::new(),
233 provider,
234 })
235 }
236}
237
238impl Serialize for Agent {
239 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
240 where
241 S: Serializer,
242 {
243 let mut state = serializer.serialize_struct("Agent", 3)?;
244 state.serialize_field("id", &self.id)?;
245 state.serialize_field("provider", &self.provider)?;
246 state.serialize_field("system_instruction", &self.system_instruction)?;
247 state.end()
248 }
249}
250
251impl<'de> Deserialize<'de> for Agent {
253 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
254 where
255 D: Deserializer<'de>,
256 {
257 #[derive(Deserialize)]
258 #[serde(field_identifier, rename_all = "snake_case")]
259 enum Field {
260 Id,
261 Provider,
262 SystemInstruction,
263 }
264
265 struct AgentVisitor;
266
267 impl<'de> Visitor<'de> for AgentVisitor {
268 type Value = Agent;
269
270 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
271 formatter.write_str("struct Agent")
272 }
273
274 fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
275 where
276 V: MapAccess<'de>,
277 {
278 let mut id = None;
279 let mut provider = None;
280 let mut system_instruction = None;
281
282 while let Some(key) = map.next_key()? {
283 match key {
284 Field::Id => {
285 id = Some(map.next_value()?);
286 }
287 Field::Provider => {
288 provider = Some(map.next_value()?);
289 }
290 Field::SystemInstruction => {
291 system_instruction = Some(map.next_value()?);
292 }
293 }
294 }
295
296 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
297 let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
298 let system_instruction = system_instruction
299 .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
300
301 let client = GenAiClient::Undefined;
304 Ok(Agent {
305 id,
306 client: Arc::new(client),
307 system_instruction,
308 provider,
309 })
310 }
311 }
312
313 const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
314 deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
315 }
316}
317
318#[pyclass(name = "Agent")]
319#[derive(Debug, Clone)]
320pub struct PyAgent {
321 pub agent: Arc<Agent>,
322 pub runtime: Arc<tokio::runtime::Runtime>,
323}
324
325#[pymethods]
326impl PyAgent {
327 #[new]
328 #[pyo3(signature = (provider, system_instruction = None))]
329 pub fn new(
336 provider: &Bound<'_, PyAny>,
337 system_instruction: Option<&Bound<'_, PyAny>>,
338 ) -> Result<Self, AgentError> {
339 let provider = Provider::extract_provider(provider)?;
340
341 let system_instruction = if let Some(system_instruction) = system_instruction {
342 Some(
343 parse_prompt(system_instruction)?
344 .into_iter()
345 .map(|mut msg| {
346 msg.role = Role::Developer.to_string();
347 msg
348 })
349 .collect::<Vec<Message>>(),
350 )
351 } else {
352 None
353 };
354
355 let runtime = tokio::runtime::Runtime::new()?;
356
357 let agent = runtime.block_on(async { Agent::new(provider, system_instruction).await })?;
358
359 Ok(Self {
360 agent: Arc::new(agent),
361 runtime: Arc::new(runtime),
362 })
363 }
364
365 #[pyo3(signature = (task, output_type=None, model=None))]
366 pub fn execute_task(
367 &self,
368 py: Python<'_>,
369 task: &mut Task,
370 output_type: Option<Bound<'_, PyAny>>,
371 model: Option<&str>,
372 ) -> Result<PyAgentResponse, AgentError> {
373 debug!("Executing task");
375
376 if let Some(output_type) = &output_type {
378 match parse_response_to_json(py, output_type) {
379 Ok((response_type, response_format)) => {
380 task.prompt.response_type = response_type;
381 task.prompt.response_json_schema = response_format;
382 }
383 Err(_) => {
384 return Err(AgentError::InvalidOutputType(output_type.to_string()));
385 }
386 }
387 }
388
389 if let Some(model) = model {
391 task.prompt.model = model.to_string();
392 }
393
394 if task.prompt.provider != *self.agent.client_provider() {
396 return Err(AgentError::ProviderMismatch(
397 task.prompt.provider.to_string(),
398 self.agent.client_provider().as_str().to_string(),
399 ));
400 }
401
402 debug!(
403 "Task prompt model identifier: {}",
404 task.prompt.model_identifier()
405 );
406
407 let chat_response = self
408 .runtime
409 .block_on(async { self.agent.execute_task(task).await })?;
410
411 debug!("Task executed successfully");
412 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
413 let response = PyAgentResponse::new(chat_response, output);
414
415 Ok(response)
416 }
417
418 #[pyo3(signature = (prompt, output_type=None, model=None))]
419 pub fn execute_prompt(
420 &self,
421 py: Python<'_>,
422 prompt: &mut Prompt,
423 output_type: Option<Bound<'_, PyAny>>,
424 model: Option<&str>,
425 ) -> Result<PyAgentResponse, AgentError> {
426 debug!("Executing task");
428 if let Some(output_type) = &output_type {
430 match parse_response_to_json(py, output_type) {
431 Ok((response_type, response_format)) => {
432 prompt.response_type = response_type;
433 prompt.response_json_schema = response_format;
434 }
435 Err(_) => {
436 return Err(AgentError::InvalidOutputType(output_type.to_string()));
437 }
438 }
439 }
440
441 if let Some(model) = model {
443 prompt.model = model.to_string();
444 }
445
446 if prompt.provider != *self.agent.client_provider() {
448 return Err(AgentError::ProviderMismatch(
449 prompt.provider.to_string(),
450 self.agent.client_provider().as_str().to_string(),
451 ));
452 }
453
454 let chat_response = self
455 .runtime
456 .block_on(async { self.agent.execute_prompt(prompt).await })?;
457
458 debug!("Task executed successfully");
459 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
460 let response = PyAgentResponse::new(chat_response, output);
461
462 Ok(response)
463 }
464
465 #[getter]
466 pub fn system_instruction<'py>(
467 &self,
468 py: Python<'py>,
469 ) -> Result<Bound<'py, PyAny>, AgentError> {
470 Ok(self
471 .agent
472 .system_instruction
473 .clone()
474 .into_bound_py_any(py)?)
475 }
476
477 #[getter]
478 pub fn id(&self) -> &str {
479 self.agent.id.as_str()
480 }
481}