1use crate::agents::provider::gemini::GeminiClient;
2use crate::agents::provider::openai::OpenAIClient;
3use crate::{
4 agents::client::GenAiClient,
5 agents::error::AgentError,
6 agents::task::Task,
7 agents::types::{AgentResponse, PyAgentResponse},
8};
9use potato_prompt::prompt::settings::ModelSettings;
10use potato_prompt::{
11 parse_response_to_json, prompt::parse_prompt, prompt::types::Message, Prompt, Role,
12};
13use potato_type::Provider;
14use potato_util::create_uuid7;
15use pyo3::{prelude::*, IntoPyObjectExt};
16use serde::{
17 de::{self, MapAccess, Visitor},
18 ser::SerializeStruct,
19 Deserializer, Serializer,
20};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25use std::sync::RwLock;
26use tracing::{debug, error, instrument, warn};
27
28#[derive(Debug, Clone, PartialEq)]
29pub struct Agent {
30 pub id: String,
31
32 client: GenAiClient,
33
34 pub system_instruction: Vec<Message>,
35}
36
37impl Agent {
39 pub fn new(
40 provider: Provider,
41 system_instruction: Option<Vec<Message>>,
42 ) -> Result<Self, AgentError> {
43 let client = match provider {
44 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
45 Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
46 _ => {
47 let msg = "No provider specified in ModelSettings";
48 error!("{}", msg);
49 return Err(AgentError::UndefinedError(msg.to_string()));
50 } };
52
53 let system_instruction = system_instruction.unwrap_or_default();
54
55 Ok(Self {
56 client,
57 id: create_uuid7(),
58 system_instruction,
59 })
60 }
61
62 #[instrument(skip_all)]
63 fn append_task_with_message_context(
64 &self,
65 task: &mut Task,
66 context_messages: &HashMap<String, Vec<Message>>,
67 ) {
68 debug!(task.id = %task.id, task.dependencies = ?task.dependencies, context_messages = ?context_messages, "Appending messages");
70 if !task.dependencies.is_empty() {
71 for dep in &task.dependencies {
72 if let Some(messages) = context_messages.get(dep) {
73 for message in messages {
74 task.prompt.message.insert(0, message.clone());
76 }
77 }
78 }
79 }
80 }
81
82 #[instrument(skip_all)]
92 fn bind_context(
93 &self,
94 prompt: &mut Prompt,
95 parameter_context: &Value,
96 global_context: &Option<Value>,
97 ) -> Result<(), AgentError> {
98 if !prompt.parameters.is_empty() {
100 for param in &prompt.parameters {
101 if let Some(value) = parameter_context.get(param) {
103 for message in &mut prompt.message {
104 if message.role == "user" {
105 debug!("Binding parameter: {} with value: {}", param, value);
106 message.bind_mut(param, &value.to_string())?;
107 }
108 }
109 }
110
111 if let Some(global_value) = global_context {
113 if let Some(value) = global_value.get(param) {
114 for message in &mut prompt.message {
115 if message.role == "user" {
116 debug!("Binding global parameter: {} with value: {}", param, value);
117 message.bind_mut(param, &value.to_string())?;
118 }
119 }
120 }
121 }
122 }
123 }
124 Ok(())
125 }
126
127 fn append_system_instructions(&self, prompt: &mut Prompt) {
128 if !self.system_instruction.is_empty() {
129 let mut combined_messages = self.system_instruction.clone();
130 combined_messages.extend(prompt.system_instruction.clone());
131 prompt.system_instruction = combined_messages;
132 }
133 }
134 pub async fn execute_task(&self, task: &Task) -> Result<AgentResponse, AgentError> {
135 debug!("Executing task: {}, count: {}", task.id, task.retry_count);
137 let mut prompt = task.prompt.clone();
138 self.append_system_instructions(&mut prompt);
139
140 let chat_response = self.client.execute(&prompt).await?;
142
143 Ok(AgentResponse::new(task.id.clone(), chat_response))
144 }
145
146 #[instrument(skip_all)]
147 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<AgentResponse, AgentError> {
148 debug!("Executing prompt");
150 let mut prompt = prompt.clone();
151 self.append_system_instructions(&mut prompt);
152
153 let chat_response = self.client.execute(&prompt).await?;
155
156 Ok(AgentResponse::new(chat_response.id(), chat_response))
157 }
158
159 pub async fn execute_task_with_context(
160 &self,
161 task: &Arc<RwLock<Task>>,
162 context_messages: HashMap<String, Vec<Message>>,
163 parameter_context: Value,
164 global_context: Option<Value>,
165 ) -> Result<AgentResponse, AgentError> {
166 let (prompt, task_id) = {
168 let mut task = task.write().unwrap();
169 self.append_task_with_message_context(&mut task, &context_messages);
170 self.bind_context(&mut task.prompt, ¶meter_context, &global_context)?;
171
172 self.append_system_instructions(&mut task.prompt);
173 (task.prompt.clone(), task.id.clone())
174 };
175
176 let chat_response = self.client.execute(&prompt).await?;
178
179 Ok(AgentResponse::new(task_id, chat_response))
180 }
181
182 pub fn provider(&self) -> &Provider {
183 self.client.provider()
184 }
185
186 pub fn from_model_settings(model_settings: &ModelSettings) -> Result<Self, AgentError> {
187 let provider = model_settings.provider();
188 let client = match provider {
189 Provider::OpenAI => GenAiClient::OpenAI(OpenAIClient::new(None, None, None)?),
190 Provider::Gemini => GenAiClient::Gemini(GeminiClient::new(None, None, None)?),
191 Provider::Undefined => {
192 let msg = "No provider specified in ModelSettings";
193 error!("{}", msg);
194 return Err(AgentError::UndefinedError(msg.to_string()));
195 }
196 };
197
198 Ok(Self {
199 client,
200 id: create_uuid7(),
201 system_instruction: Vec::new(),
202 })
203 }
204}
205
206impl Serialize for Agent {
207 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208 where
209 S: Serializer,
210 {
211 let mut state = serializer.serialize_struct("Agent", 3)?;
212 state.serialize_field("id", &self.id)?;
213 state.serialize_field("provider", &self.client.provider())?;
214 state.serialize_field("system_instruction", &self.system_instruction)?;
215 state.end()
216 }
217}
218
219impl<'de> Deserialize<'de> for Agent {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 #[derive(Deserialize)]
226 #[serde(field_identifier, rename_all = "snake_case")]
227 enum Field {
228 Id,
229 Provider,
230 SystemInstruction,
231 }
232
233 struct AgentVisitor;
234
235 impl<'de> Visitor<'de> for AgentVisitor {
236 type Value = Agent;
237
238 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
239 formatter.write_str("struct Agent")
240 }
241
242 fn visit_map<V>(self, mut map: V) -> Result<Agent, V::Error>
243 where
244 V: MapAccess<'de>,
245 {
246 let mut id = None;
247 let mut provider = None;
248 let mut system_instruction = None;
249
250 while let Some(key) = map.next_key()? {
251 match key {
252 Field::Id => {
253 id = Some(map.next_value()?);
254 }
255 Field::Provider => {
256 provider = Some(map.next_value()?);
257 }
258 Field::SystemInstruction => {
259 system_instruction = Some(map.next_value()?);
260 }
261 }
262 }
263
264 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
265 let provider = provider.ok_or_else(|| de::Error::missing_field("provider"))?;
266 let system_instruction = system_instruction
267 .ok_or_else(|| de::Error::missing_field("system_instruction"))?;
268
269 let client = match provider {
271 Provider::OpenAI => {
272 GenAiClient::OpenAI(OpenAIClient::new(None, None, None).map_err(|e| {
273 de::Error::custom(format!("Failed to initialize OpenAIClient: {e}"))
274 })?)
275 }
276 Provider::Gemini => {
277 GenAiClient::Gemini(GeminiClient::new(None, None, None).map_err(|e| {
278 de::Error::custom(format!("Failed to initialize GeminiClient: {e}"))
279 })?)
280 }
281
282 Provider::Undefined => {
283 let msg = "No provider specified in ModelSettings";
284 error!("{}", msg);
285 return Err(de::Error::custom(msg));
286 }
287 };
288
289 Ok(Agent {
290 id,
291 client,
292 system_instruction,
293 })
294 }
295 }
296
297 const FIELDS: &[&str] = &["id", "provider", "system_instruction"];
298 deserializer.deserialize_struct("Agent", FIELDS, AgentVisitor)
299 }
300}
301
302#[pyclass(name = "Agent")]
303#[derive(Debug, Clone)]
304pub struct PyAgent {
305 pub agent: Arc<Agent>,
306 pub runtime: Arc<tokio::runtime::Runtime>,
307}
308
309#[pymethods]
310impl PyAgent {
311 #[new]
312 #[pyo3(signature = (provider, system_instruction = None))]
313 pub fn new(
320 provider: &Bound<'_, PyAny>,
321 system_instruction: Option<&Bound<'_, PyAny>>,
322 ) -> Result<Self, AgentError> {
323 let provider = Provider::extract_provider(provider)?;
324
325 let system_instruction = if let Some(system_instruction) = system_instruction {
326 Some(
327 parse_prompt(system_instruction)?
328 .into_iter()
329 .map(|mut msg| {
330 msg.role = Role::Developer.to_string();
331 msg
332 })
333 .collect::<Vec<Message>>(),
334 )
335 } else {
336 None
337 };
338
339 let agent = Agent::new(provider, system_instruction)?;
340
341 Ok(Self {
342 agent: Arc::new(agent),
343 runtime: Arc::new(
344 tokio::runtime::Runtime::new()
345 .map_err(|e| AgentError::RuntimeError(e.to_string()))?,
346 ),
347 })
348 }
349
350 #[pyo3(signature = (task, output_type=None, model=None))]
351 pub fn execute_task(
352 &self,
353 py: Python<'_>,
354 task: &mut Task,
355 output_type: Option<Bound<'_, PyAny>>,
356 model: Option<&str>,
357 ) -> Result<PyAgentResponse, AgentError> {
358 debug!("Executing task");
360
361 if let Some(output_type) = &output_type {
363 match parse_response_to_json(py, output_type) {
364 Ok((response_type, response_format)) => {
365 task.prompt.response_type = response_type;
366 task.prompt.response_json_schema = response_format;
367 }
368 Err(_) => {
369 return Err(AgentError::InvalidOutputType(output_type.to_string()));
370 }
371 }
372 }
373
374 if let Some(model) = model {
376 task.prompt.model = model.to_string();
377 }
378
379 if task.prompt.provider != *self.agent.provider() {
381 return Err(AgentError::ProviderMismatch(
382 task.prompt.provider.to_string(),
383 self.agent.provider().as_str().to_string(),
384 ));
385 }
386
387 debug!(
388 "Task prompt model identifier: {}",
389 task.prompt.model_identifier()
390 );
391
392 let chat_response = self
393 .runtime
394 .block_on(async { self.agent.execute_task(task).await })?;
395
396 debug!("Task executed successfully");
397 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
398 let response = PyAgentResponse::new(chat_response, output);
399
400 Ok(response)
401 }
402
403 #[pyo3(signature = (prompt, output_type=None, model=None))]
404 pub fn execute_prompt(
405 &self,
406 py: Python<'_>,
407 prompt: &mut Prompt,
408 output_type: Option<Bound<'_, PyAny>>,
409 model: Option<&str>,
410 ) -> Result<PyAgentResponse, AgentError> {
411 debug!("Executing task");
413 if let Some(output_type) = &output_type {
415 match parse_response_to_json(py, output_type) {
416 Ok((response_type, response_format)) => {
417 prompt.response_type = response_type;
418 prompt.response_json_schema = response_format;
419 }
420 Err(_) => {
421 return Err(AgentError::InvalidOutputType(output_type.to_string()));
422 }
423 }
424 }
425
426 if let Some(model) = model {
428 prompt.model = model.to_string();
429 }
430
431 if prompt.provider != *self.agent.provider() {
433 return Err(AgentError::ProviderMismatch(
434 prompt.provider.to_string(),
435 self.agent.provider().as_str().to_string(),
436 ));
437 }
438
439 let chat_response = self
440 .runtime
441 .block_on(async { self.agent.execute_prompt(prompt).await })?;
442
443 debug!("Task executed successfully");
444 let output = output_type.as_ref().map(|obj| obj.clone().unbind());
445 let response = PyAgentResponse::new(chat_response, output);
446
447 Ok(response)
448 }
449
450 #[getter]
451 pub fn system_instruction<'py>(
452 &self,
453 py: Python<'py>,
454 ) -> Result<Bound<'py, PyAny>, AgentError> {
455 Ok(self
456 .agent
457 .system_instruction
458 .clone()
459 .into_bound_py_any(py)?)
460 }
461
462 #[getter]
463 pub fn id(&self) -> &str {
464 self.agent.id.as_str()
465 }
466}