potato_agent/agents/
task.rs1use crate::AgentError;
2use crate::{AgentResponse, PyAgentResponse};
3use potato_type::prompt::Prompt;
4use potato_util::PyHelperFuncs;
5use pyo3::prelude::*;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{error, instrument};
9#[pyclass(eq)]
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11pub enum TaskStatus {
12 Pending,
13 Running,
14 Completed,
15 Failed,
16}
17
18#[pyclass]
19#[derive(Debug, Serialize)]
20pub struct WorkflowTask {
21 #[pyo3(get)]
22 pub id: String,
23 #[pyo3(get, set)]
24 pub prompt: Prompt,
25 #[pyo3(get, set)]
26 pub dependencies: Vec<String>,
27 #[pyo3(get)]
28 pub status: TaskStatus,
29 #[pyo3(get)]
30 pub agent_id: String,
31 #[pyo3(get)]
32 pub max_retries: u32,
33 pub result: Option<PyAgentResponse>,
34 pub retry_count: u32,
35}
36
37#[pymethods]
38impl WorkflowTask {
39 #[getter]
40 pub fn result<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
41 if let Some(resp) = &self.result {
42 let output = resp.structured_output(py)?;
43 Ok(output)
44 } else {
45 Ok(py.None().bind(py).clone())
46 }
47 }
48
49 pub fn __str__(&self) -> String {
50 PyHelperFuncs::__str__(self)
51 }
52}
53
54#[pyclass]
55#[derive(Debug, Serialize, Deserialize, Clone)]
56pub struct Task {
57 #[pyo3(get)]
58 pub id: String,
59 #[pyo3(get, set)]
60 pub prompt: Prompt,
61 #[pyo3(get, set)]
62 pub dependencies: Vec<String>,
63 #[pyo3(get)]
64 pub status: TaskStatus,
65 #[pyo3(get, set)]
66 pub agent_id: String,
67 pub result: Option<AgentResponse>,
68 #[pyo3(get)]
69 pub max_retries: u32,
70 pub retry_count: u32,
71
72 #[serde(skip)]
73 output_validator: Option<jsonschema::Validator>,
74}
75
76impl PartialEq for Task {
77 fn eq(&self, other: &Self) -> bool {
78 self.id == other.id
79 && self.prompt == other.prompt
80 && self.dependencies == other.dependencies
81 && self.status == other.status
82 && self.agent_id == other.agent_id
83 && self.max_retries == other.max_retries
84 && self.retry_count == other.retry_count
85 }
86}
87
88#[pymethods]
89impl Task {
90 #[new]
91 #[pyo3(signature = (agent_id, prompt, id, dependencies = None, max_retries=None))]
92 pub fn new(
93 agent_id: &str,
94 prompt: Prompt,
95 id: &str,
96 dependencies: Option<Vec<String>>,
97 max_retries: Option<u32>,
98 ) -> Result<Self, AgentError> {
99 let validator = match prompt.response_json_schema() {
100 Some(schema) => {
101 let compiled_validator = jsonschema::validator_for(schema).map_err(|e| {
102 error!(
103 "Failed to compile JSON schema validator for task {}: {}",
104 id, e
105 );
106 AgentError::ValidationError(format!(
107 "Failed to compile JSON schema validator: {}",
108 e
109 ))
110 })?;
111 Some(compiled_validator)
112 }
113 None => None,
114 };
115
116 Ok(Self {
117 prompt,
118 dependencies: dependencies.unwrap_or_default(),
119 status: TaskStatus::Pending,
120 result: None,
121 id: id.to_string(),
122 agent_id: agent_id.to_string(),
123 max_retries: max_retries.unwrap_or(3),
124 retry_count: 0,
125 output_validator: validator,
126 })
127 }
128
129 pub fn add_dependency(&mut self, dependency: String) {
130 self.dependencies.push(dependency);
131 }
132
133 pub fn __str__(&self) -> String {
134 PyHelperFuncs::__str__(self)
135 }
136}
137
138impl Task {
139 pub fn increment_retry(&mut self) {
140 self.retry_count += 1;
141 }
142
143 pub fn set_status(&mut self, status: TaskStatus) {
144 self.status = status;
145 }
146
147 pub fn set_result(&mut self, result: AgentResponse) {
148 self.result = Some(result);
149 }
150
151 pub fn rebuild_validator(&mut self) -> Result<(), AgentError> {
153 if let Some(schema) = self.prompt.response_json_schema() {
154 let compiled_validator = jsonschema::validator_for(schema).map_err(|e| {
155 error!(
156 "Failed to compile JSON schema validator for task {}: {}",
157 self.id, e
158 );
159 AgentError::ValidationError(format!(
160 "Failed to compile JSON schema validator: {}",
161 e
162 ))
163 })?;
164 self.output_validator = Some(compiled_validator);
165 } else {
166 self.output_validator = None;
167 }
168
169 Ok(())
170 }
171
172 #[instrument(skip_all)]
175 pub fn validate_output(&self, output: &Value) -> Result<(), AgentError> {
176 if let Some(validator) = &self.output_validator {
177 validator.validate(output).map_err(|e| {
178 error!(
179 "Failed to validate output: {}, Received output: {:?}",
180 e, output
181 );
182 AgentError::ValidationError(e.to_string())
183 })
184 } else {
185 Ok(())
186 }
187 }
188}