use crate::python::context::PyContext;
use crate::trigger::{Trigger, TriggerError, TriggerResult};
use crate::Context;
use async_trait::async_trait;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
/// Outcome of one poll of a Python-defined trigger, exposed to Python as `TriggerResult`.
///
/// Python code builds instances with `TriggerResult.skip()` or `TriggerResult.fire(context)`;
/// Rust converts them into the core trigger result with `into_rust`.
#[pyclass(name = "TriggerResult")]
#[derive(Debug)]
pub struct PyTriggerResult {
    /// `true` for a "fire" result, `false` for a "skip" result.
    is_fire: bool,
    /// Snapshot of the context data passed to `fire`; `None` for `skip()` and for
    /// `fire()` called without a context.
    data: Option<std::collections::HashMap<String, Value>>,
}

#[pymethods]
impl PyTriggerResult {
    /// Create a "skip" result (maps to `TriggerResult::Skip` on the Rust side).
    #[staticmethod]
    fn skip() -> Self {
        PyTriggerResult {
            is_fire: false,
            data: None,
        }
    }

    /// Create a "fire" result, optionally carrying a snapshot of `context`'s data.
    ///
    /// The data is captured with `get_data_clone` at the time of this call.
    #[staticmethod]
    #[pyo3(signature = (context=None))]
    fn fire(context: Option<&PyContext>) -> Self {
        PyTriggerResult {
            is_fire: true,
            data: context.map(|c| c.get_data_clone()),
        }
    }

    /// Debug representation: `TriggerResult.Skip`, `TriggerResult.Fire(None)`,
    /// or `TriggerResult.Fire(<context>)`.
    fn __repr__(&self) -> String {
        match (self.is_fire, &self.data) {
            (false, _) => "TriggerResult.Skip",
            (true, None) => "TriggerResult.Fire(None)",
            (true, Some(_)) => "TriggerResult.Fire(<context>)",
        }
        .to_string()
    }

    /// Return True if this is a "fire" result.
    fn is_fire_result(&self) -> bool {
        self.is_fire
    }

    /// Return True if this is a "skip" result.
    fn is_skip_result(&self) -> bool {
        !self.is_fire
    }
}

impl PyTriggerResult {
    /// Convert into the core [`TriggerResult`], consuming `self`.
    ///
    /// A fire result with data becomes `TriggerResult::Fire(Some(context))`, where the
    /// context holds every captured key/value pair. Without data it becomes
    /// `TriggerResult::Fire(None)`.
    pub fn into_rust(self) -> TriggerResult {
        if !self.is_fire {
            return TriggerResult::Skip;
        }
        let context = self.data.map(|entries| {
            let mut context = Context::new();
            for (key, value) in entries {
                // Keys come from a HashMap, so they are distinct and are inserted into a
                // fresh context. The insert result is deliberately discarded.
                // NOTE(review): confirm `Context::insert` cannot fail for other reasons.
                let _ = context.insert(key, value);
            }
            context
        });
        TriggerResult::Fire(context)
    }
}
/// Adapts a Python callable into the Rust [`Trigger`] trait.
///
/// The callable is invoked with no arguments on every poll and is expected to return a
/// `TriggerResult` (the Python-visible `PyTriggerResult`).
pub struct PythonTriggerWrapper {
    /// Trigger name used for registration and error messages.
    name: String,
    /// Workflow the trigger was registered for via the `@trigger(workflow=...)` decorator.
    workflow_name: String,
    /// How often the scheduler should call `poll`.
    poll_interval: Duration,
    /// Value reported by `Trigger::allow_concurrent`.
    allow_concurrent: bool,
    /// The decorated Python function.
    python_function: PyObject,
}

impl std::fmt::Debug for PythonTriggerWrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The Python callable is rendered as a fixed placeholder rather than its repr.
        f.debug_struct("PythonTriggerWrapper")
            .field("name", &self.name)
            .field("workflow_name", &self.workflow_name)
            .field("poll_interval", &self.poll_interval)
            .field("allow_concurrent", &self.allow_concurrent)
            .field("python_function", &"<PyObject>")
            .finish()
    }
}

// Every field, including the `Py<PyAny>` handle, is already `Send + Sync`, so the
// wrapper gets both auto traits without any `unsafe impl`. This compile-time assertion
// keeps that guarantee checked: adding a non-thread-safe field now fails the build
// instead of being silently declared safe.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<PythonTriggerWrapper>();
};
#[async_trait]
impl Trigger for PythonTriggerWrapper {
fn name(&self) -> &str {
&self.name
}
fn poll_interval(&self) -> Duration {
self.poll_interval
}
fn allow_concurrent(&self) -> bool {
self.allow_concurrent
}
async fn poll(&self) -> Result<TriggerResult, TriggerError> {
let function = Python::with_gil(|py| self.python_function.clone_ref(py));
let trigger_name = self.name.clone();
tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let result = function.call0(py).map_err(|e| TriggerError::PollError {
message: format!("Python trigger poll failed: {}", e),
})?;
let bound_result = result.bind(py);
let py_result_bound = bound_result.downcast::<PyTriggerResult>().map_err(|e| {
TriggerError::PollError {
message: format!(
"Trigger '{}' must return TriggerResult, got: {}",
trigger_name, e
),
}
})?;
let py_result = py_result_bound.borrow();
let rust_result = if !py_result.is_fire {
TriggerResult::Skip
} else {
let ctx = py_result.data.clone().map(|d| {
let mut context = Context::new();
for (key, value) in d {
context.insert(key, value).ok();
}
context
});
TriggerResult::Fire(ctx)
};
Ok(rust_result)
})
})
.await
.map_err(|e| TriggerError::PollError {
message: format!("Trigger execution panicked: {}", e),
})?
}
}
impl PythonTriggerWrapper {
    /// Name of the workflow this trigger was registered for (the `workflow` argument
    /// given to the `@trigger` decorator).
    pub fn workflow_name(&self) -> &str {
        self.workflow_name.as_str()
    }
}
/// Parse a human-friendly duration string into a [`Duration`].
///
/// Accepted forms, where `<n>` is a non-negative integer:
/// - `"<n>ms"`: milliseconds
/// - `"<n>s"`: seconds
/// - `"<n>m"`: minutes
/// - `"<n>"`: seconds
///
/// Surrounding whitespace, and whitespace between the number and its unit, is ignored.
///
/// # Errors
/// Returns `"Invalid duration: <input>"` when the number does not parse as `u64`, or when
/// a minute count is too large to express in seconds.
fn parse_duration(s: &str) -> Result<Duration, String> {
    let s = s.trim();
    let invalid = || format!("Invalid duration: {}", s);
    let parse_num = |digits: &str| digits.trim().parse::<u64>().map_err(|_| invalid());

    // "ms" must be tested before the single-character 's' suffix it ends with.
    if let Some(digits) = s.strip_suffix("ms") {
        parse_num(digits).map(Duration::from_millis)
    } else if let Some(digits) = s.strip_suffix('s') {
        parse_num(digits).map(Duration::from_secs)
    } else if let Some(digits) = s.strip_suffix('m') {
        // Guard the minutes -> seconds conversion against u64 overflow.
        parse_num(digits)?
            .checked_mul(60)
            .map(Duration::from_secs)
            .ok_or_else(&invalid)
    } else {
        parse_num(s).map(Duration::from_secs)
    }
}
/// Callable returned by `trigger(...)`; applying it to a function registers that
/// function as a polling trigger.
#[pyclass]
pub struct TriggerDecorator {
    /// Explicit trigger name; falls back to the function's `__name__` when `None`.
    name: Option<String>,
    /// Workflow the trigger is associated with.
    workflow: String,
    /// Parsed polling interval.
    poll_interval: Duration,
    /// Whether concurrent executions are allowed.
    allow_concurrent: bool,
}

#[pymethods]
impl TriggerDecorator {
    /// Register `func` as a trigger constructor and return `func` unchanged.
    ///
    /// Raises if no explicit name was given and `func.__name__` is missing or not a str.
    pub fn __call__(&self, py: Python, func: PyObject) -> PyResult<PyObject> {
        let trigger_name: String = match self.name.as_ref() {
            Some(explicit) => explicit.clone(),
            None => func.getattr(py, "__name__")?.extract(py)?,
        };

        // Values moved into the constructor closure, which may be invoked repeatedly.
        let constructor_name = trigger_name.clone();
        let workflow_name = self.workflow.clone();
        let poll_interval = self.poll_interval;
        let allow_concurrent = self.allow_concurrent;
        let handle = Arc::new(func.clone_ref(py));

        crate::trigger::register_trigger_constructor(trigger_name.clone(), move || {
            // Each constructed trigger owns its own reference to the Python function.
            let python_function = Python::with_gil(|gil| handle.clone_ref(gil));
            let wrapper = PythonTriggerWrapper {
                name: constructor_name.clone(),
                workflow_name: workflow_name.clone(),
                poll_interval,
                allow_concurrent,
                python_function,
            };
            Arc::new(wrapper) as Arc<dyn Trigger>
        });

        tracing::info!(
            trigger_name = %trigger_name,
            workflow = %self.workflow,
            poll_interval_ms = %self.poll_interval.as_millis(),
            "Registered Python trigger"
        );

        Ok(func)
    }
}
/// Create a decorator that registers a function as a polling trigger for `workflow`.
///
/// poll_interval accepts "<n>ms", "<n>s", "<n>m", or a bare number of seconds;
/// an unparseable value raises ValueError.
#[pyfunction]
#[pyo3(signature = (
    workflow,
    *,
    name = None,
    poll_interval = "5s",
    allow_concurrent = false
))]
// NOTE(review): presumably silences the lint when this is only reached through Python
// module registration; confirm it is still required.
#[allow(dead_code)]
pub fn trigger(
    workflow: String,
    name: Option<String>,
    poll_interval: &str,
    allow_concurrent: bool,
) -> PyResult<TriggerDecorator> {
    match parse_duration(poll_interval) {
        Ok(interval) => Ok(TriggerDecorator {
            name,
            workflow,
            poll_interval: interval,
            allow_concurrent,
        }),
        Err(reason) => Err(PyValueError::new_err(format!(
            "Invalid poll_interval: {}",
            reason
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported unit, plus surrounding whitespace.
    #[test]
    fn test_parse_duration() {
        let cases = [
            ("5s", Duration::from_secs(5)),
            ("100ms", Duration::from_millis(100)),
            ("2m", Duration::from_secs(120)),
            ("10", Duration::from_secs(10)),
            ("  7s  ", Duration::from_secs(7)),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_duration(input).unwrap(), expected, "input: {:?}", input);
        }
    }

    /// Malformed inputs must be reported as errors rather than panicking or defaulting.
    #[test]
    fn test_parse_duration_rejects_invalid_input() {
        for input in ["", "s", "ms", "abc", "5h", "-5s", "1.5s"] {
            assert!(parse_duration(input).is_err(), "expected error for {:?}", input);
        }
    }
}