1use crate::agents::provider::gemini::GeminiClient;
2use crate::agents::provider::openai::OpenAIClient;
3use crate::{
4 agents::client::GenAiClient,
5 agents::error::AgentError,
6 agents::task::Task,
7 agents::types::{AgentResponse, PyAgentResponse},
8};
9use potato_prompt::{
10 parse_response_to_json, prompt::parse_prompt, prompt::types::Message, ModelSettings, Prompt,
11 Role,
12};
13use potato_type::Model;
14use potato_type::Provider;
15use potato_util::create_uuid7;
16use pyo3::{prelude::*, IntoPyObjectExt};
17use serde::{
18 de::{self, MapAccess, Visitor},
19 ser::SerializeStruct,
20 Deserializer, Serializer,
21};
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24use std::collections::HashMap;
25use std::sync::Arc;
26use std::sync::RwLock;
27use tracing::{debug, error, instrument, warn};
28
/// An agent that pairs a provider-specific client with a reusable set of system
/// instructions.
#[derive(Debug, Clone, PartialEq)]
pub struct Agent {
    /// Identifier of this agent; the constructors in this file generate it with
    /// `create_uuid7`.
    pub id: String,

    /// Provider-specific client used to execute prompts.
    client: GenAiClient,

    /// Messages placed ahead of a prompt's own system instructions when a prompt
    /// is executed through this agent.
    pub system_instruction: Vec<Message>,
}
37
38impl Agent {
40 pub fn new(
41 provider: Provider,
42 system_instruction: Option<Vec<Message>>,
43 ) -> Result<Self, AgentError> {
44 let client = match provider {
45 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
46 Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
47 _ => {
48 let msg = "No provider specified in ModelSettings";
49 error!("{}", msg);
50 return Err(AgentError::UndefinedError(msg.to_string()));
51 } };
53
54 let system_instruction = system_instruction.unwrap_or_default();
55
56 Ok(Self {
57 client,
58 id: create_uuid7(),
59 system_instruction,
60 })
61 }
62
63 #[instrument(skip_all)]
64 fn append_task_with_message_context(
65 &self,
66 task: &mut Task,
67 context_messages: &HashMap<String, Vec<Message>>,
68 ) {
69 debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
71 if !task.dependencies.is_empty() {
72 for dep in &task.dependencies {
73 if let Some(messages) = context_messages.get(dep) {
74 for message in messages {
75 task.prompt.message.insert(0, message.clone());
77 }
78 }
79 }
80 }
81 }
82
83 #[instrument(skip_all)]
93 fn bind_context(
94 &self,
95 prompt: &mut Prompt,
96 parameter_context: &Value,
97 global_context: &Option<Value>,
98 ) -> Result<(), AgentError> {
99 if !prompt.parameters.is_empty() {
101 for param in &prompt.parameters {
102 if let Some(value) = parameter_context.get(param) {
104 for message in &mut prompt.message {
105 if message.role == "user" {
106 debug!("Binding parameter: {} with value: {}", param, value);
107 message.bind_mut(param, &value.to_string())?;
108 }
109 }
110 }
111
112 if let Some(global_value) = global_context {
114 if let Some(value) = global_value.get(param) {
115 for message in &mut prompt.message {
116 if message.role == "user" {
117 debug!("Binding global parameter: {} with value: {}", param, value);
118 message.bind_mut(param, &value.to_string())?;
119 }
120 }
121 }
122 }
123 }
124 }
125 Ok(())
126 }
127
128 fn append_system_instructions(&self, prompt: &mut Prompt) {
129 if !self.system_instruction.is_empty() {
130 let mut combined_messages = self.system_instruction.clone();
131 combined_messages.extend(prompt.system_instruction.clone());
132 prompt.system_instruction = combined_messages;
133 }
134 }
135 pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
136 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
138 let mut prompt = task.prompt.clone();
139 self.append_system_instructions(&mut prompt);
140
141 let chat_response = self.client.execute(&prompt).await?;
143
144 Ok(AgentResponse::new(task.id.clone(), chat_response))
145 }
146
147 #[instrument(skip_all)]
148 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
149 debug!("Executing prompt");
151 let mut prompt = prompt.clone();
152 self.append_system_instructions(&mut prompt);
153
154 let chat_response = self.client.execute(&prompt).await?;
156
157 Ok(AgentResponse::new(chat_response.id(), chat_response))
158 }
159
160 pub async fn execute_task_with_context(
161 &self,
162 task: &Arc<RwLock<Task>>,
163 context_messages: HashMap<String, Vec<Message>>,
164 parameter_context: Value,
165 global_context: Option<Value>,
166 ) -> Result<AgentResponse, AgentError> {
167 let (prompt, task_id) = {
169 let mut task = task.write().unwrap();
170 self.append_task_with_message_context(&mut task, &context_messages);
171 self.bind_context(&mut task.prompt, ¶meter_context, &global_context)?;
172
173 self.append_system_instructions(&mut task.prompt);
174 (task.prompt.clone(), task.id.clone())
175 };
176
177 let chat_response = self.client.execute(&prompt).await?;
179
180 Ok(AgentResponse::new(task_id, chat_response))
181 }
182
/// Returns the provider reported by the agent's underlying client.
pub fn provider(&self) -> &Provider {
    self.client.provider()
}
186
187 pub fn from_model_settings(model_settings: &ModelSettings) -> Result<Self, AgentError> {
188 let provider = Provider::from_string(&model_settings.provider)?;
189 let client = match provider {
190 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
191 Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
192 Provider::Undefined => {
193 let msg = "No provider specified in ModelSettings";
194 error!("{}", msg);
195 return Err(AgentError::UndefinedError(msg.to_string()));
196 }
197 };
198
199 Ok(Self {
200 client,
201 id: create_uuid7(),
202 system_instruction: Vec::new(),
203 })
204 }
205}
206
207impl Serialize for Agent {
208 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
209 where
210 S: Serializer,
211 {
212 let mut state = serializer.serialize_struct("Agent", 3)?;
213 state.serialize_field("id", &self.id)?;
214 state.serialize_field("provider", &self.client.provider())?;
215 state.serialize_field("system_instruction", &self.system_instruction)?;
216 state.end()
217 }
218}
219
220impl<'de> Deserialize<'de> for Agent {
222 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
223 where
224 D: Deserializer<'de>,
225 {
226 #[derive(Deserialize)]
227 #[serde(field_identifier, rename_all = "snake_case")]
228 enum Field {
229 Id,
230 Provider,
231 SystemInstruction,
232 }
233
234 struct AgentVisitor;
235
236 impl<'de> Visitor<'de> for AgentVisitor {
237 type Value = Agent;
238
239 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
240 formatter.write_str("struct Agent")
241 }
242
243 fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
244 where
245 V: MapAccess<'de>,
246 {
247 let mut id = None;
248 let mut provider = None;
249 let mut system_instruction = None;
250
251 while let Some(key) = map.next_key()? {
252 match key {
253 Field::Id => {
254 id = Some(map.next_value()?);
255 }
256 Field::Provider => {
257 provider = Some(map.next_value()?);
258 }
259 Field::SystemInstruction => {
260 system_instruction = Some(map.next_value()?);
261 }
262 }
263 }
264
265 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
266 let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
267 let system_instruction = system_instruction
268 .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
269
270 let client = match provider {
272 Provider::OpenAI => {
273 GenAiClient::OpenAI(OpenAIClient::new(None, None, None).map_err(|e| {
274 de::Error::custom(format!("Failed to initialize OpenAIClient: {e}"))
275 })?)
276 }
277 Provider::Gemini => {
278 GenAiClient::Gemini(GeminiClient::new(None, None, None).map_err(|e| {
279 de::Error::custom(format!("Failed to initialize GeminiClient: {e}"))
280 })?)
281 }
282
283 Provider::Undefined => {
284 let msg = "No provider specified in ModelSettings";
285 error!("{}", msg);
286 return Err(de::Error::custom(msg));
287 }
288 };
289
290 Ok(Agent {
291 id,
292 client,
293 system_instruction,
294 })
295 }
296 }
297
298 const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
299 deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
300 }
301}
302
/// Python-facing wrapper around [`Agent`], exported to Python under the name
/// `Agent`.
#[pyclass(name = "Agent")]
#[derive(Debug, Clone)]
pub struct PyAgent {
    /// The wrapped agent, shared behind an `Arc`.
    pub agent: Arc<Agent>,
    /// Tokio runtime on which the agent's async methods are driven to
    /// completion (`block_on`) by the synchronous Python-facing methods.
    pub runtime: Arc<tokio::runtime::Runtime>,
}
309
310#[pymethods]
311impl PyAgent {
312 #[new]
313 #[pyo3(signature = (provider, system_instruction = None))]
314 pub fn new(
321 provider: &Bound<'_, PyAny>,
322 system_instruction: Option<&Bound<'_, PyAny>>,
323 ) -> Result<Self, AgentError> {
324 let provider = Provider::extract_provider(provider)?;
325
326 let system_instruction = if let Some(system_instruction) = system_instruction {
327 Some(
328 parse_prompt(system_instruction)?
329 .into_iter()
330 .map(|mut msg| {
331 msg.role = Role::Developer.to_string();
332 msg
333 })
334 .collect::<Vec<Message>>(),
335 )
336 } else {
337 None
338 };
339
340 let agent = Agent::new(provider, system_instruction)?;
341
342 Ok(Self {
343 agent: Arc::new(agent),
344 runtime: Arc::new(
345 tokio::runtime::Runtime::new()
346 .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
347 ),
348 })
349 }
350
351 #[pyo3(signature = (task, output_type=None, model=None))]
352 pub fn execute_task(
353 &self,
354 py: Python<'_>,
355 task: &mut Task,
356 output_type: Option<Bound<'_, PyAny>>,
357 model: Option<&str>,
358 ) -> Result<PyAgentResponse, AgentError> {
359 debug!("Executing task");
361
362 if let Some(output_type) = &output_type {
364 match parse_response_to_json(py, output_type) {
365 Ok((response_type, response_format)) => {
366 task.prompt.response_type = response_type;
367 task.prompt.response_json_schema = response_format;
368 }
369 Err(_) => {
370 return Err(AgentError::InvalidOutputType(output_type.to_string()));
371 }
372 }
373 }
374
375 if model.is_none() && task.prompt.model() == Model::Undefined.as_str() {
377 return Err(AgentError::UndefinedError(
378 "Model must be specified either as an argument or in the Task prompt".to_string(),
379 ));
380 }
381
382 if let Some(model) = model {
384 task.prompt.set_model(model);
385 }
386
387 if task.prompt.provider() == Model::Undefined.as_str() {
389 task.prompt.set_provider(self.agent.provider().as_str());
390 }
391
392 println!("Task prompt model: {}", task.prompt.model());
393 println!("Task prompt provider: {}", task.prompt.provider());
394
395 let chat_response = self
396 .runtime
397 .block_on(async { self.agent.execute_task(task).await })?;
398
399 debug!("Task executed successfully");
400 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
401 let response = PyAgentResponse::new(chat_response, output);
402
403 Ok(response)
404 }
405
406 #[pyo3(signature = (prompt, output_type=None, model=None))]
407 pub fn execute_prompt(
408 &self,
409 py: Python<'_>,
410 prompt: &mut Prompt,
411 output_type: Option<Bound<'_, PyAny>>,
412 model: Option<&str>,
413 ) -> Result<PyAgentResponse, AgentError> {
414 debug!("Executing task");
416 if let Some(output_type) = &output_type {
418 match parse_response_to_json(py, output_type) {
419 Ok((response_type, response_format)) => {
420 prompt.response_type = response_type;
421 prompt.response_json_schema = response_format;
422 }
423 Err(_) => {
424 return Err(AgentError::InvalidOutputType(output_type.to_string()));
425 }
426 }
427 }
428
429 if model.is_none() && prompt.model() == Model::Undefined.as_str() {
431 return Err(AgentError::UndefinedError(
432 "Model must be specified either as an argument or in the Prompt".to_string(),
433 ));
434 }
435
436 if let Some(model) = model {
438 prompt.set_model(model);
439 }
440
441 if prompt.provider() == Model::Undefined.as_str() {
442 prompt.set_provider(self.agent.provider().as_str());
443 }
444
445 let chat_response = self
446 .runtime
447 .block_on(async { self.agent.execute_prompt(prompt).await })?;
448
449 debug!("Task executed successfully");
450 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
451 let response = PyAgentResponse::new(chat_response, output);
452
453 Ok(response)
454 }
455
456 #[getter]
457 pub fn system_instruction<'py>(
458 &self,
459 py: Python<'py>,
460 ) -> Result<Bound<'py, PyAny>, AgentError> {
461 Ok(self
462 .agent
463 .system_instruction
464 .clone()
465 .into_bound_py_any(py)?)
466 }
467
/// Returns the id of the wrapped agent.
#[getter]
pub fn id(&self) -> &str {
    self.agent.id.as_str()
}
472}