use super::workflow_context::PyWorkflowContext;
use async_trait::async_trait;
use parking_lot::Mutex;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::sync::Arc;
use std::time::Duration;
/// Python-visible `TaskHandle`, passed as the second argument to task
/// functions whose second parameter is named `handle` or `task_handle`.
#[pyclass(name = "TaskHandle")]
pub struct PyTaskHandle {
    /// The wrapped Rust handle. Becomes `None` once the executor reclaims it
    /// after the Python task function returns; every method then raises
    /// `ValueError("TaskHandle has already been consumed")`.
    inner: Option<crate::TaskHandle>,
}

#[pymethods]
impl PyTaskHandle {
    /// Block until `condition()` is truthy, re-checking every `poll_interval_ms`
    /// milliseconds (default 1000).
    ///
    /// The GIL is released while waiting and re-acquired only to evaluate
    /// `condition`. If `condition` raises (or its result's `__bool__` raises),
    /// the error is printed to stderr and the check counts as "not ready yet".
    ///
    /// Raises `ValueError` if the handle was already consumed, if the calling
    /// thread has no Tokio runtime context, or if the underlying
    /// `TaskHandle::defer_until` returns an error.
    #[pyo3(signature = (condition, poll_interval_ms = 1000))]
    pub fn defer_until(
        &mut self,
        py: Python,
        condition: PyObject,
        poll_interval_ms: u64,
    ) -> PyResult<()> {
        let handle = self
            .inner
            .as_mut()
            .ok_or_else(|| PyValueError::new_err("TaskHandle has already been consumed"))?;
        let poll_interval = Duration::from_millis(poll_interval_ms);
        // `Handle::current()` panics outside a runtime; report a Python error instead.
        let rt_handle = tokio::runtime::Handle::try_current().map_err(|e| {
            PyValueError::new_err(format!("defer_until requires a Tokio runtime: {}", e))
        })?;
        py.allow_threads(|| {
            rt_handle.block_on(async {
                handle
                    .defer_until(
                        move || {
                            // Use Python truthiness so conditions returning e.g. `1`
                            // or a non-empty list are honoured, not treated as False.
                            let ready = Python::with_gil(|py| {
                                match condition
                                    .call0(py)
                                    .and_then(|value| value.bind(py).is_truthy())
                                {
                                    Ok(ready) => ready,
                                    Err(e) => {
                                        eprintln!("[cloaca] defer_until condition error: {}", e);
                                        false
                                    }
                                }
                            });
                            async move { ready }
                        },
                        poll_interval,
                    )
                    .await
            })
        })
        .map_err(|e| PyValueError::new_err(format!("defer_until failed: {}", e)))
    }

    /// Return the underlying handle's `is_slot_held()` status.
    ///
    /// Raises `ValueError` if the handle was already consumed.
    pub fn is_slot_held(&self) -> PyResult<bool> {
        let handle = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("TaskHandle has already been consumed"))?;
        Ok(handle.is_slot_held())
    }
}
/// One entry on the workflow-context stack.
#[derive(Clone)]
pub struct WorkflowBuilderRef {
    pub context: PyWorkflowContext,
}

/// Process-global stack of active workflow contexts, shared by all threads.
/// The innermost (most recently pushed) context is the "current" one.
// NOTE(review): presumably pushed/popped by WorkflowBuilder's context-manager
// enter/exit — confirm against the caller.
static WORKFLOW_CONTEXT_STACK: Mutex<Vec<WorkflowBuilderRef>> = Mutex::new(Vec::new());

/// Make `context` the current workflow context.
pub fn push_workflow_context(context: PyWorkflowContext) {
    let entry = WorkflowBuilderRef { context };
    let mut stack = WORKFLOW_CONTEXT_STACK.lock();
    stack.push(entry);
}

/// Remove and return the current workflow context, or `None` if the stack is empty.
pub fn pop_workflow_context() -> Option<WorkflowBuilderRef> {
    let mut stack = WORKFLOW_CONTEXT_STACK.lock();
    stack.pop()
}

/// Return a clone of the current (innermost) workflow context.
///
/// Raises `ValueError` when no context has been pushed.
pub fn current_workflow_context() -> PyResult<PyWorkflowContext> {
    let stack = WORKFLOW_CONTEXT_STACK.lock();
    match stack.last() {
        Some(top) => Ok(top.context.clone()),
        None => Err(PyValueError::new_err(
            "No workflow context available. Tasks must be defined within a WorkflowBuilder context manager."
        )),
    }
}
pub struct PythonTaskWrapper {
id: String,
dependencies: Vec<crate::TaskNamespace>,
retry_policy: crate::retry::RetryPolicy,
python_function: PyObject,
on_success_callback: Option<PyObject>,
on_failure_callback: Option<PyObject>,
requires_handle: bool,
}
unsafe impl Send for PythonTaskWrapper {}
unsafe impl Sync for PythonTaskWrapper {}
#[async_trait]
impl crate::Task for PythonTaskWrapper {
async fn execute(
&self,
context: crate::Context<serde_json::Value>,
) -> Result<crate::Context<serde_json::Value>, crate::TaskError> {
use super::context::PyContext;
let function = Python::with_gil(|py| self.python_function.clone_ref(py));
let on_success =
Python::with_gil(|py| self.on_success_callback.as_ref().map(|f| f.clone_ref(py)));
let on_failure =
Python::with_gil(|py| self.on_failure_callback.as_ref().map(|f| f.clone_ref(py)));
let task_id = self.id.clone();
let task_id_for_error = self.id.clone();
let needs_handle = self.requires_handle;
let task_handle = if needs_handle {
Some(crate::take_task_handle())
} else {
None
};
let (context_result, returned_handle) = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let original_data = context.data().clone();
let py_context = PyContext::from_rust_context(context);
let (result, recovered_handle) = if let Some(handle) = task_handle {
let py_handle = Py::new(
py,
PyTaskHandle {
inner: Some(handle),
},
)
.map_err(|e| crate::TaskError::ExecutionFailed {
message: format!("Failed to create PyTaskHandle: {}", e),
task_id: task_id.clone(),
timestamp: chrono::Utc::now(),
})?;
let call_result =
function.call1(py, (py_context.clone(), py_handle.clone_ref(py)));
let recovered = py_handle.borrow_mut(py).inner.take();
(call_result, recovered)
} else {
let call_result = function.call1(py, (py_context.clone(),));
(call_result, None)
};
match result {
Ok(returned) => {
let final_context = if returned.is_none(py) {
let mut new_context = crate::Context::new();
for (key, value) in original_data.iter() {
new_context.insert(key.clone(), value.clone()).unwrap();
}
new_context
} else {
let returned_context: PyContext =
returned.extract(py).map_err(|e| {
crate::TaskError::ExecutionFailed {
message: format!("Python task execution failed: {}", e),
task_id: task_id.clone(),
timestamp: chrono::Utc::now(),
}
})?;
returned_context.into_inner()
};
if let Some(callback) = on_success {
let cloned_data = final_context.data().clone();
let mut callback_ctx = crate::Context::new();
for (key, value) in cloned_data.iter() {
callback_ctx.insert(key.clone(), value.clone()).ok();
}
let callback_context = PyContext::from_rust_context(callback_ctx);
if let Err(e) = callback.call1(py, (&task_id, callback_context)) {
eprintln!(
"[cloaca] on_success callback failed for task '{}': {}",
task_id, e
);
}
}
Ok((final_context, recovered_handle))
}
Err(e) => {
let error_message = format!("Python task execution failed: {}", e);
if let Some(callback) = on_failure {
if let Err(callback_err) =
callback.call1(py, (&task_id, &error_message, py_context))
{
eprintln!(
"[cloaca] on_failure callback failed for task '{}': {}",
task_id, callback_err
);
}
}
Err(crate::TaskError::ExecutionFailed {
message: error_message,
task_id: task_id.clone(),
timestamp: chrono::Utc::now(),
})
}
}
})
})
.await
.map_err(|e| crate::TaskError::ExecutionFailed {
message: format!("Task execution panicked: {}", e),
task_id: task_id_for_error.clone(),
timestamp: chrono::Utc::now(),
})??;
if let Some(handle) = returned_handle {
crate::return_task_handle(handle);
}
Ok(context_result)
}
fn id(&self) -> &str {
&self.id
}
fn dependencies(&self) -> &[crate::TaskNamespace] {
&self.dependencies
}
fn retry_policy(&self) -> crate::retry::RetryPolicy {
self.retry_policy.clone()
}
fn requires_handle(&self) -> bool {
self.requires_handle
}
fn checkpoint(
&self,
_context: &crate::Context<serde_json::Value>,
) -> Result<(), crate::CheckpointError> {
Ok(())
}
fn trigger_rules(&self) -> serde_json::Value {
serde_json::json!({"type": "Always"})
}
fn code_fingerprint(&self) -> Option<String> {
None
}
}
fn build_retry_policy(
retry_attempts: Option<usize>,
retry_backoff: Option<String>,
retry_delay_ms: Option<u64>,
retry_max_delay_ms: Option<u64>,
retry_condition: Option<String>,
retry_jitter: Option<bool>,
) -> crate::retry::RetryPolicy {
use crate::retry::*;
use std::time::Duration;
let mut builder = RetryPolicy::builder();
if let Some(attempts) = retry_attempts {
builder = builder.max_attempts(attempts as i32);
}
if let Some(backoff) = retry_backoff {
let strategy = match backoff.as_str() {
"fixed" => BackoffStrategy::Fixed,
"linear" => BackoffStrategy::Linear { multiplier: 1.0 },
"exponential" => BackoffStrategy::Exponential {
base: 2.0,
multiplier: 1.0,
},
_ => BackoffStrategy::Fixed,
};
builder = builder.backoff_strategy(strategy);
}
if let Some(delay) = retry_delay_ms {
builder = builder.initial_delay(Duration::from_millis(delay));
}
if let Some(max_delay) = retry_max_delay_ms {
builder = builder.max_delay(Duration::from_millis(max_delay));
}
if let Some(condition) = retry_condition {
let retry_cond = match condition.as_str() {
"never" => RetryCondition::Never,
"transient" => RetryCondition::TransientOnly,
"all" => RetryCondition::AllErrors,
_ => RetryCondition::AllErrors,
};
builder = builder.retry_condition(retry_cond);
}
if let Some(jitter) = retry_jitter {
builder = builder.with_jitter(jitter);
}
builder.build()
}
/// Decorator object returned by `task(...)`; applying it to a function
/// registers that function as a task in the current workflow context.
#[pyclass]
pub struct TaskDecorator {
    /// Explicit task id; when `None`, the decorated function's `__name__` is used.
    id: Option<String>,
    /// Dependencies given as task-id strings or objects with a `__name__`.
    dependencies: Vec<PyObject>,
    retry_policy: crate::retry::RetryPolicy,
    on_success: Option<PyObject>,
    on_failure: Option<PyObject>,
}

#[pymethods]
impl TaskDecorator {
    /// Register `func` as a task in the current workflow context and return
    /// `func` unchanged.
    ///
    /// `func` receives a `TaskHandle` as its second argument when that
    /// parameter is named `handle` or `task_handle`.
    ///
    /// Raises `ValueError` when there is no current workflow context or a
    /// dependency cannot be resolved. Raises `AttributeError` when no `id` was
    /// given and `func` has no `__name__`.
    pub fn __call__(&self, py: Python, func: PyObject) -> PyResult<PyObject> {
        let context = current_workflow_context()?;
        let task_id = match &self.id {
            Some(id) => id.clone(),
            None => func.getattr(py, "__name__")?.extract::<String>(py)?,
        };
        let has_handle = accepts_task_handle(py, &func)?;
        // The error already carries the details; printing it as well only duplicated it.
        let deps = self.convert_dependencies_to_namespaces(py, &context)?;
        let policy = self.retry_policy.clone();
        let shared_function = Arc::new(func.clone_ref(py));
        let shared_on_success = self.on_success.as_ref().map(|f| Arc::new(f.clone_ref(py)));
        let shared_on_failure = self.on_failure.as_ref().map(|f| Arc::new(f.clone_ref(py)));
        let (tenant_id, package_name, workflow_id) = context.as_components();
        let namespace = crate::TaskNamespace::new(tenant_id, package_name, workflow_id, &task_id);
        // Registration runs with the GIL released. The constructor closure
        // acquires the GIL itself. NOTE(review): presumably this avoids a
        // GIL/registry-lock ordering deadlock — confirm.
        py.allow_threads(move || {
            crate::register_task_constructor(namespace, move || {
                // Clone all Python references under one GIL acquisition.
                let (python_function, on_success_callback, on_failure_callback) =
                    Python::with_gil(|py| {
                        (
                            shared_function.clone_ref(py),
                            shared_on_success.as_ref().map(|f| f.clone_ref(py)),
                            shared_on_failure.as_ref().map(|f| f.clone_ref(py)),
                        )
                    });
                Arc::new(PythonTaskWrapper {
                    id: task_id.clone(),
                    dependencies: deps.clone(),
                    retry_policy: policy.clone(),
                    python_function,
                    on_success_callback,
                    on_failure_callback,
                    requires_handle: has_handle,
                }) as Arc<dyn crate::Task>
            });
        });
        Ok(func)
    }
}

/// Return `true` when `func`'s second positional parameter is named `handle`
/// or `task_handle`.
///
/// Callables with no `__code__` attribute (e.g. callable instances or
/// `functools.partial`) are treated as not requesting a handle.
fn accepts_task_handle(py: Python<'_>, func: &PyObject) -> PyResult<bool> {
    let func = func.bind(py);
    if !func.hasattr("__code__")? {
        return Ok(false);
    }
    let code = func.getattr("__code__")?;
    let argcount: usize = code.getattr("co_argcount")?.extract()?;
    if argcount < 2 {
        return Ok(false);
    }
    let varnames: Vec<String> = code.getattr("co_varnames")?.extract()?;
    Ok(matches!(
        varnames.get(1).map(String::as_str),
        Some("handle" | "task_handle")
    ))
}
impl TaskDecorator {
fn convert_dependencies_to_namespaces(
&self,
py: Python,
context: &PyWorkflowContext,
) -> PyResult<Vec<crate::TaskNamespace>> {
let mut namespace_deps = Vec::new();
for (i, dep) in self.dependencies.iter().enumerate() {
let task_name = if let Ok(string_dep) = dep.extract::<String>(py) {
string_dep
} else {
match dep.bind(py).hasattr("__name__") {
Ok(true) => match dep.getattr(py, "__name__") {
Ok(name_obj) => match name_obj.extract::<String>(py) {
Ok(func_name) => func_name,
Err(e) => {
return Err(PyValueError::new_err(format!(
"Dependency {} has __name__ but it's not a string: {}",
i, e
)));
}
},
Err(e) => {
return Err(PyValueError::new_err(format!(
"Failed to get __name__ from dependency {}: {}",
i, e
)));
}
},
Ok(false) => {
return Err(PyValueError::new_err(format!(
"Dependency {} must be either a string or a function object with __name__ attribute",
i
)));
}
Err(e) => {
return Err(PyValueError::new_err(format!(
"Failed to check if dependency {} has __name__ attribute: {}",
i, e
)));
}
}
};
let (tenant_id, package_name, workflow_id) = context.as_components();
namespace_deps.push(crate::TaskNamespace::new(
tenant_id,
package_name,
workflow_id,
&task_name,
));
}
Ok(namespace_deps)
}
}
/// Python `@task(...)` decorator factory (all arguments keyword-only).
///
/// The `retry_*` arguments are turned into a retry policy right away, and a
/// missing `dependencies` list becomes empty. Registration happens later,
/// when the returned [`TaskDecorator`] is applied to a function.
#[pyfunction]
#[pyo3(signature = (
    *,
    id = None,
    dependencies = None,
    retry_attempts = None,
    retry_backoff = None,
    retry_delay_ms = None,
    retry_max_delay_ms = None,
    retry_condition = None,
    retry_jitter = None,
    on_success = None,
    on_failure = None
))]
// One Rust parameter per Python keyword argument; bundling them would change the Python API.
#[allow(clippy::too_many_arguments)]
pub fn task(
    id: Option<String>,
    dependencies: Option<Vec<PyObject>>,
    retry_attempts: Option<usize>,
    retry_backoff: Option<String>,
    retry_delay_ms: Option<u64>,
    retry_max_delay_ms: Option<u64>,
    retry_condition: Option<String>,
    retry_jitter: Option<bool>,
    on_success: Option<PyObject>,
    on_failure: Option<PyObject>,
) -> PyResult<TaskDecorator> {
    let dependencies = dependencies.unwrap_or_default();
    let retry_policy = build_retry_policy(
        retry_attempts,
        retry_backoff,
        retry_delay_ms,
        retry_max_delay_ms,
        retry_condition,
        retry_jitter,
    );
    let decorator = TaskDecorator {
        id,
        dependencies,
        retry_policy,
        on_success,
        on_failure,
    };
    Ok(decorator)
}