Skip to main content

potato_agent/agents/
task.rs

1use crate::AgentError;
2use crate::{AgentResponse, PyAgentResponse};
3use potato_type::prompt::Prompt;
4use potato_util::PyHelperFuncs;
5use pyo3::prelude::*;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{error, instrument};
9#[pyclass(eq)]
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11pub enum TaskStatus {
12    Pending,
13    Running,
14    Completed,
15    Failed,
16}
17
18#[pyclass]
19#[derive(Debug, Serialize)]
20pub struct WorkflowTask {
21    #[pyo3(get)]
22    pub id: String,
23    #[pyo3(get, set)]
24    pub prompt: Prompt,
25    #[pyo3(get, set)]
26    pub dependencies: Vec<String>,
27    #[pyo3(get)]
28    pub status: TaskStatus,
29    #[pyo3(get)]
30    pub agent_id: String,
31    #[pyo3(get)]
32    pub max_retries: u32,
33    pub result: Option<PyAgentResponse>,
34    pub retry_count: u32,
35}
36
37#[pymethods]
38impl WorkflowTask {
39    #[getter]
40    pub fn result<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
41        if let Some(resp) = &self.result {
42            let output = resp.structured_output(py)?;
43            Ok(output)
44        } else {
45            Ok(py.None().bind(py).clone())
46        }
47    }
48
49    pub fn __str__(&self) -> String {
50        PyHelperFuncs::__str__(self)
51    }
52}
53
54#[pyclass]
55#[derive(Debug, Serialize, Deserialize, Clone)]
56pub struct Task {
57    #[pyo3(get)]
58    pub id: String,
59    #[pyo3(get, set)]
60    pub prompt: Prompt,
61    #[pyo3(get, set)]
62    pub dependencies: Vec<String>,
63    #[pyo3(get)]
64    pub status: TaskStatus,
65    #[pyo3(get, set)]
66    pub agent_id: String,
67    pub result: Option<AgentResponse>,
68    #[pyo3(get)]
69    pub max_retries: u32,
70    pub retry_count: u32,
71
72    #[serde(skip)]
73    output_validator: Option<jsonschema::Validator>,
74}
75
76impl PartialEq for Task {
77    fn eq(&self, other: &Self) -> bool {
78        self.id == other.id
79            && self.prompt == other.prompt
80            && self.dependencies == other.dependencies
81            && self.status == other.status
82            && self.agent_id == other.agent_id
83            && self.max_retries == other.max_retries
84            && self.retry_count == other.retry_count
85    }
86}
87
88#[pymethods]
89impl Task {
90    #[new]
91    #[pyo3(signature = (agent_id, prompt, id, dependencies = None, max_retries=None))]
92    pub fn new(
93        agent_id: &str,
94        prompt: Prompt,
95        id: &str,
96        dependencies: Option<Vec<String>>,
97        max_retries: Option<u32>,
98    ) -> Result<Self, AgentError> {
99        let validator = match prompt.response_json_schema() {
100            Some(schema) => {
101                let compiled_validator = jsonschema::validator_for(schema).map_err(|e| {
102                    error!(
103                        "Failed to compile JSON schema validator for task {}: {}",
104                        id, e
105                    );
106                    AgentError::ValidationError(format!(
107                        "Failed to compile JSON schema validator: {}",
108                        e
109                    ))
110                })?;
111                Some(compiled_validator)
112            }
113            None => None,
114        };
115
116        Ok(Self {
117            prompt,
118            dependencies: dependencies.unwrap_or_default(),
119            status: TaskStatus::Pending,
120            result: None,
121            id: id.to_string(),
122            agent_id: agent_id.to_string(),
123            max_retries: max_retries.unwrap_or(3),
124            retry_count: 0,
125            output_validator: validator,
126        })
127    }
128
129    pub fn add_dependency(&mut self, dependency: String) {
130        self.dependencies.push(dependency);
131    }
132
133    pub fn __str__(&self) -> String {
134        PyHelperFuncs::__str__(self)
135    }
136}
137
138impl Task {
139    pub fn increment_retry(&mut self) {
140        self.retry_count += 1;
141    }
142
143    pub fn set_status(&mut self, status: TaskStatus) {
144        self.status = status;
145    }
146
147    pub fn set_result(&mut self, result: AgentResponse) {
148        self.result = Some(result);
149    }
150
151    /// Helper to rebuild the validator when workflow is deserialized
152    pub fn rebuild_validator(&mut self) -> Result<(), AgentError> {
153        if let Some(schema) = self.prompt.response_json_schema() {
154            let compiled_validator = jsonschema::validator_for(schema).map_err(|e| {
155                error!(
156                    "Failed to compile JSON schema validator for task {}: {}",
157                    self.id, e
158                );
159                AgentError::ValidationError(format!(
160                    "Failed to compile JSON schema validator: {}",
161                    e
162                ))
163            })?;
164            self.output_validator = Some(compiled_validator);
165        } else {
166            self.output_validator = None;
167        }
168
169        Ok(())
170    }
171
172    /// Validate the output against the task's output schema, if defined.
173    /// Make come back to this later and change. Still unsure if this is the right place
174    #[instrument(skip_all)]
175    pub fn validate_output(&self, output: &Value) -> Result<(), AgentError> {
176        if let Some(validator) = &self.output_validator {
177            validator.validate(output).map_err(|e| {
178                error!(
179                    "Failed to validate output: {}, Received output: {:?}",
180                    e, output
181                );
182                AgentError::ValidationError(e.to_string())
183            })
184        } else {
185            Ok(())
186        }
187    }
188}