use crate::error::{IoError, Result};
use chrono::{DateTime, Datelike, Duration, Utc};
use std::collections::{HashMap, HashSet};
use super::types::{Task, Workflow, WorkflowStatus};
pub mod engines {
use super::*;
pub trait WorkflowEngineAdapter: Send + Sync {
fn exportworkflow(&self, workflow: &Workflow) -> Result<String>;
fn importworkflow(&self, definition: &str) -> Result<Workflow>;
fn submit(&self, workflow: &Workflow) -> Result<String>;
fn get_status(&self, executionid: &str) -> Result<WorkflowStatus>;
fn cancel(&self, executionid: &str) -> Result<()>;
}
pub struct AirflowAdapter {
api_url: String,
auth_token: Option<String>,
}
impl AirflowAdapter {
pub fn new(api_url: impl Into<String>) -> Self {
Self {
api_url: api_url.into(),
auth_token: None,
}
}
pub fn with_auth(mut self, token: impl Into<String>) -> Self {
self.auth_token = Some(token.into());
self
}
}
impl WorkflowEngineAdapter for AirflowAdapter {
fn exportworkflow(&self, workflow: &Workflow) -> Result<String> {
let mut dag_code = String::new();
dag_code.push_str("from airflow import DAG\n");
dag_code.push_str("from airflow.operators.python import PythonOperator\n");
dag_code.push_str("from datetime import datetime, timedelta\n\n");
dag_code.push_str("dag = DAG(\n");
dag_code.push_str(&format!(" '{}',\n", workflow.id));
dag_code.push_str(&format!(
" description='{}',\n",
workflow.description.as_deref().unwrap_or("")
));
dag_code.push_str(" default_args={\n");
dag_code.push_str(" 'owner': 'scirs2',\n");
dag_code.push_str(" 'retries': 3,\n");
dag_code.push_str(" 'retry_delay': timedelta(minutes=5),\n");
dag_code.push_str(" },\n");
dag_code.push_str(" schedule_interval=None,\n");
dag_code.push_str(" start_date=datetime(2024, 1, 1),\n");
dag_code.push_str(" catchup=False,\n");
dag_code.push_str(")\n\n");
for task in &workflow.tasks {
dag_code.push_str(&format!("{} = PythonOperator(\n", task.id));
dag_code.push_str(&format!(" task_id='{}',\n", task.id));
dag_code.push_str(&format!(
" python_callable=lambda: print('{}'),\n",
task.name
));
dag_code.push_str(" dag=dag,\n");
dag_code.push_str(")\n\n");
}
for (task_id, deps) in &workflow.dependencies {
for dep in deps {
dag_code.push_str(&format!("{dep} >> {task_id}\n"));
}
}
Ok(dag_code)
}
fn importworkflow(&self, definition: &str) -> Result<Workflow> {
Err(IoError::UnsupportedFormat(
"Airflow import not yet implemented".to_string(),
))
}
fn submit(&self, workflow: &Workflow) -> Result<String> {
let executionid = format!("{}_run_{}", workflow.id, Utc::now().timestamp());
Ok(executionid)
}
fn get_status(&self, _executionid: &str) -> Result<WorkflowStatus> {
Ok(WorkflowStatus::Running)
}
fn cancel(&self, _executionid: &str) -> Result<()> {
Ok(())
}
}
pub struct PrefectAdapter {
api_url: String,
project_name: String,
}
impl PrefectAdapter {
pub fn new(api_url: impl Into<String>, project: impl Into<String>) -> Self {
Self {
api_url: api_url.into(),
project_name: project.into(),
}
}
}
impl WorkflowEngineAdapter for PrefectAdapter {
fn exportworkflow(&self, workflow: &Workflow) -> Result<String> {
let mut flow_code = String::new();
flow_code.push_str("from prefect import flow, task\n");
flow_code.push_str("from prefect.task_runners import SequentialTaskRunner\n\n");
for task in &workflow.tasks {
flow_code.push_str(&format!("@task(name='{}')\n", task.name));
flow_code.push_str(&format!("def {}():\n", task.id));
flow_code.push_str(&format!(" print('Executing {}')\n", task.name));
flow_code.push_str(" return True\n\n");
}
flow_code.push_str(&format!(
"@flow(name='{}', task_runner=SequentialTaskRunner())\n",
workflow.name
));
flow_code.push_str("def workflow_flow():\n");
let mut executed = HashSet::new();
let mut to_execute: Vec<_> = workflow.tasks.iter().map(|t| &t.id).collect();
while !to_execute.is_empty() {
let mut progress = false;
to_execute.retain(|task_id| {
let deps = workflow.dependencies.get(*task_id);
let can_execute =
deps.is_none_or(|d| d.iter().all(|dep| executed.contains(dep)));
if can_execute {
flow_code.push_str(&format!(" {task_id}()\n"));
executed.insert((*task_id).clone());
progress = true;
false
} else {
true
}
});
if !progress && !to_execute.is_empty() {
return Err(IoError::Other("Circular dependency detected".to_string()));
}
}
flow_code.push_str("\nif __name__ == '__main__':\n");
flow_code.push_str(" workflow_flow()\n");
Ok(flow_code)
}
fn importworkflow(&self, definition: &str) -> Result<Workflow> {
Err(IoError::UnsupportedFormat(
"Prefect import not yet implemented".to_string(),
))
}
fn submit(&self, workflow: &Workflow) -> Result<String> {
let flow_run_id = uuid::Uuid::new_v4().to_string();
Ok(flow_run_id)
}
fn get_status(&self, _executionid: &str) -> Result<WorkflowStatus> {
Ok(WorkflowStatus::Running)
}
fn cancel(&self, _executionid: &str) -> Result<()> {
Ok(())
}
}
pub struct DagsterAdapter {
repository_url: String,
}
impl WorkflowEngineAdapter for DagsterAdapter {
fn exportworkflow(&self, workflow: &Workflow) -> Result<String> {
let mut job_code = String::new();
job_code.push_str("from dagster import job, op, Config\n\n");
for task in &workflow.tasks {
job_code.push_str(&format!("@op(name='{}')\n", task.id));
job_code.push_str(&format!("def {}(context):\n", task.id));
job_code.push_str(&format!(
" context.log.info('Executing {}')\n",
task.name
));
job_code.push_str(" return True\n\n");
}
job_code.push_str(&format!("@job(name='{}')\n", workflow.id));
job_code.push_str("def workflow_job():\n");
for task in &workflow.tasks {
if let Some(deps) = workflow.dependencies.get(&task.id) {
let deps_str = deps.join(", ");
job_code.push_str(&format!(" {}({}())\n", task.id, deps_str));
} else {
job_code.push_str(&format!(" {}()\n", task.id));
}
}
Ok(job_code)
}
fn importworkflow(&self, definition: &str) -> Result<Workflow> {
Err(IoError::UnsupportedFormat(
"Dagster import not yet implemented".to_string(),
))
}
fn submit(&self, workflow: &Workflow) -> Result<String> {
Ok(uuid::Uuid::new_v4().to_string())
}
fn get_status(&self, _executionid: &str) -> Result<WorkflowStatus> {
Ok(WorkflowStatus::Running)
}
fn cancel(&self, _executionid: &str) -> Result<()> {
Ok(())
}
}
}
pub mod dynamic {
use super::*;
pub struct DynamicWorkflowGenerator {
templates: HashMap<String, WorkflowTemplate>,
}
#[derive(Debug, Clone)]
pub struct WorkflowTemplate {
pub baseworkflow: Workflow,
pub parameters: Vec<ParameterDef>,
pub generators: Vec<TaskGenerator>,
}
#[derive(Debug, Clone)]
pub struct ParameterDef {
pub name: String,
pub param_type: ParameterType,
pub required: bool,
pub default: Option<serde_json::Value>,
}
#[derive(Debug, Clone)]
pub enum ParameterType {
String,
Integer,
Float,
Boolean,
List(Box<ParameterType>),
Object,
}
#[derive(Debug, Clone)]
pub enum TaskGenerator {
ForEach {
parameter: String,
task_template: Task,
},
Conditional {
condition: String,
true_tasks: Vec<Task>,
false_tasks: Vec<Task>,
},
Repeat {
count_param: String,
task_template: Task,
},
}
impl Default for DynamicWorkflowGenerator {
fn default() -> Self {
Self::new()
}
}
impl DynamicWorkflowGenerator {
pub fn new() -> Self {
Self {
templates: HashMap::new(),
}
}
pub fn register_template(&mut self, name: impl Into<String>, template: WorkflowTemplate) {
self.templates.insert(name.into(), template);
}
pub fn generate(
&self,
template_name: &str,
params: HashMap<String, serde_json::Value>,
) -> Result<Workflow> {
let template = self.templates.get(template_name).ok_or_else(|| {
IoError::NotFound(format!("Template '{template_name}' not found"))
})?;
for param_def in &template.parameters {
if param_def.required && !params.contains_key(¶m_def.name) {
return Err(IoError::ValidationError(format!(
"Required parameter '{}' not provided",
param_def.name
)));
}
}
let mut workflow = template.baseworkflow.clone();
workflow.id = format!("{}_{}", workflow.id, Utc::now().timestamp());
for generator in &template.generators {
self.apply_generator(&mut workflow, generator, ¶ms)?;
}
Ok(workflow)
}
fn apply_generator(
&self,
workflow: &mut Workflow,
generator: &TaskGenerator,
params: &HashMap<String, serde_json::Value>,
) -> Result<()> {
match generator {
TaskGenerator::ForEach {
parameter,
task_template,
} => {
if let Some(serde_json::Value::Array(items)) = params.get(parameter) {
for (i, item) in items.iter().enumerate() {
let mut task = task_template.clone();
task.id = format!("{}_{}", task.id, i);
task.name = format!("{} [{}]", task.name, i);
if let serde_json::Value::Object(mut config) = task.config.clone() {
config.insert("item".to_string(), item.clone());
task.config = serde_json::Value::Object(config);
}
workflow.tasks.push(task);
}
}
}
TaskGenerator::Conditional {
condition,
true_tasks,
false_tasks,
} => {
let condition_result = self.evaluate_condition(condition, params)?;
if condition_result {
workflow.tasks.extend(true_tasks.iter().cloned());
} else {
workflow.tasks.extend(false_tasks.iter().cloned());
}
}
TaskGenerator::Repeat {
count_param,
task_template,
} => {
if let Some(serde_json::Value::Number(n)) = params.get(count_param) {
if let Some(count) = n.as_u64() {
for i in 0..count {
let mut task = task_template.clone();
task.id = format!("{}_{}", task.id, i);
task.name = format!("{} [{}]", task.name, i);
workflow.tasks.push(task);
}
}
}
}
}
Ok(())
}
fn evaluate_condition(
&self,
condition: &str,
params: &HashMap<String, serde_json::Value>,
) -> Result<bool> {
if let Some((param, value)) = condition.split_once("==") {
let param = param.trim();
let value = value.trim().trim_matches('"');
if let Some(serde_json::Value::String(s)) = params.get(param) {
return Ok(s == value);
}
}
Ok(false)
}
}
}