1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::time::timeout;
11
12use crate::ir::ActivityIR;
13
14#[async_trait]
16pub trait Activity: Send + Sync {
17 async fn execute(&self, inputs: HashMap<String, serde_json::Value>) -> Result<HashMap<String, serde_json::Value>, ActivityError>;
19
20 fn name(&self) -> &str;
22
23 fn timeout(&self) -> Option<Duration> { None }
25
26 fn retry_policy(&self) -> Option<RetryPolicy> { None }
28}
29
30#[derive(Debug, thiserror::Error)]
32pub enum ActivityError {
33 #[error("Activity execution failed: {0}")]
34 ExecutionFailed(String),
35 #[error("Invalid input: {0}")]
36 InvalidInput(String),
37 #[error("Timeout exceeded")]
38 Timeout,
39 #[error("Activity not found: {0}")]
40 NotFound(String),
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RetryPolicy {
46 pub initial_interval: Duration,
47 pub backoff_coefficient: f64,
48 pub maximum_interval: Option<Duration>,
49 pub maximum_attempts: u32,
50 pub non_retryable_errors: Vec<String>,
51}
52
53#[derive(Debug, Clone)]
55pub struct ActivityResult {
56 pub activity_name: String,
57 pub status: ActivityStatus,
58 pub outputs: Option<HashMap<String, serde_json::Value>>,
59 pub error: Option<String>,
60 pub execution_time: Duration,
61 pub attempt_count: u32,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum ActivityStatus {
67 Scheduled,
68 Started,
69 Completed,
70 Failed,
71 Cancelled,
72 TimedOut,
73}
74
75pub struct ActivityRegistry {
77 activities: tokio::sync::RwLock<HashMap<String, Arc<dyn Activity>>>,
78}
79
80impl ActivityRegistry {
81 pub fn new() -> Self {
82 Self {
83 activities: tokio::sync::RwLock::new(HashMap::new()),
84 }
85 }
86
87 pub async fn register(&self, activity: Arc<dyn Activity>) {
89 let mut activities = self.activities.write().await;
90 activities.insert(activity.name().to_string(), activity);
91 }
92
93 pub async fn get(&self, name: &str) -> Option<Arc<dyn Activity>> {
95 let activities = self.activities.read().await;
96 activities.get(name).cloned()
97 }
98
99 pub async fn execute(
101 &self,
102 name: &str,
103 inputs: HashMap<String, serde_json::Value>,
104 ) -> Result<ActivityResult, ActivityError> {
105 let start_time = std::time::Instant::now();
106 let activity = self.get(name).await
107 .ok_or(ActivityError::NotFound(name.to_string()))?;
108
109 let mut attempt_count = 0;
110 let retry_policy = activity.retry_policy();
111
112 if let Some(retry_policy) = retry_policy {
114 self.execute_with_retry(&*activity, inputs, retry_policy, start_time).await
115 } else {
116 self.execute_once(&*activity, inputs, start_time, 1).await
117 }
118 }
119
120 async fn execute_with_retry(
121 &self,
122 activity: &dyn Activity,
123 inputs: HashMap<String, serde_json::Value>,
124 retry_policy: RetryPolicy,
125 start_time: std::time::Instant,
126 ) -> Result<ActivityResult, ActivityError> {
127 let mut attempt_count = 0;
128 let mut current_interval = retry_policy.initial_interval;
129
130 loop {
131 attempt_count += 1;
132
133 match self.execute_once(activity, inputs.clone(), start_time, attempt_count).await {
134 Ok(result) => return Ok(result),
135 Err(e) => {
136 if retry_policy.non_retryable_errors.iter().any(|err| e.to_string().contains(err)) {
138 return Err(e);
139 }
140
141 if attempt_count >= retry_policy.maximum_attempts {
143 return Err(e);
144 }
145
146 tokio::time::sleep(current_interval).await;
148
149 current_interval = std::cmp::min(
151 current_interval.mul_f64(retry_policy.backoff_coefficient),
152 retry_policy.maximum_interval.unwrap_or(Duration::from_secs(300)),
153 );
154 }
155 }
156 }
157 }
158
159 async fn execute_once(
160 &self,
161 activity: &dyn Activity,
162 inputs: HashMap<String, serde_json::Value>,
163 start_time: std::time::Instant,
164 attempt_count: u32,
165 ) -> Result<ActivityResult, ActivityError> {
166 let result = if let Some(timeout_duration) = activity.timeout() {
168 match timeout(timeout_duration, activity.execute(inputs)).await {
169 Ok(result) => result,
170 Err(_) => return Err(ActivityError::Timeout),
171 }
172 } else {
173 activity.execute(inputs).await
174 };
175
176 let execution_time = start_time.elapsed();
177
178 match result {
179 Ok(outputs) => Ok(ActivityResult {
180 activity_name: activity.name().to_string(),
181 status: ActivityStatus::Completed,
182 outputs: Some(outputs),
183 error: None,
184 execution_time,
185 attempt_count,
186 }),
187 Err(e) => Ok(ActivityResult {
188 activity_name: activity.name().to_string(),
189 status: ActivityStatus::Failed,
190 outputs: None,
191 error: Some(e.to_string()),
192 execution_time,
193 attempt_count,
194 }),
195 }
196 }
197
198 pub async fn list_activities(&self) -> Vec<String> {
200 let activities = self.activities.read().await;
201 activities.keys().cloned().collect()
202 }
203}
204
/// Activity that issues an HTTP request (`method` against `url`).
pub struct HttpActivity {
    /// Name returned by `Activity::name` and used as the registry key.
    name: String,
    /// Target URL of the request.
    url: String,
    /// HTTP method, e.g. `"GET"`.
    method: String,
    /// Extra request headers added through `with_header`.
    headers: HashMap<String, String>,
    /// Per-attempt time limit returned by `Activity::timeout`.
    timeout: Option<Duration>,
}
215
216impl HttpActivity {
217 pub fn new(name: &str, url: &str, method: &str) -> Self {
218 Self {
219 name: name.to_string(),
220 url: url.to_string(),
221 method: method.to_string(),
222 headers: HashMap::new(),
223 timeout: Some(Duration::from_secs(30)),
224 }
225 }
226
227 pub fn with_header(mut self, key: &str, value: &str) -> Self {
228 self.headers.insert(key.to_string(), value.to_string());
229 self
230 }
231
232 pub fn with_timeout(mut self, timeout: Duration) -> Self {
233 self.timeout = Some(timeout);
234 self
235 }
236}
237
238#[async_trait]
239impl Activity for HttpActivity {
240 async fn execute(&self, inputs: HashMap<String, serde_json::Value>) -> Result<HashMap<String, serde_json::Value>, ActivityError> {
241 println!("Executing HTTP activity: {} {} -> {}", self.method, self.url, self.name);
244
245 let mut outputs = HashMap::new();
247 outputs.insert("status".to_string(), serde_json::json!(200));
248 outputs.insert("response".to_string(), serde_json::json!({"ok": true}));
249
250 Ok(outputs)
251 }
252
253 fn name(&self) -> &str {
254 &self.name
255 }
256
257 fn timeout(&self) -> Option<Duration> {
258 self.timeout
259 }
260}
261
/// Activity that runs a database query.
pub struct DatabaseActivity {
    /// Name returned by `Activity::name` and used as the registry key.
    name: String,
    /// Query text to execute.
    query: String,
    /// Per-attempt time limit returned by `Activity::timeout`.
    timeout: Option<Duration>,
}
268
269impl DatabaseActivity {
270 pub fn new(name: &str, query: &str) -> Self {
271 Self {
272 name: name.to_string(),
273 query: query.to_string(),
274 timeout: Some(Duration::from_secs(30)),
275 }
276 }
277}
278
279#[async_trait]
280impl Activity for DatabaseActivity {
281 async fn execute(&self, inputs: HashMap<String, serde_json::Value>) -> Result<HashMap<String, serde_json::Value>, ActivityError> {
282 println!("Executing DB activity: {} -> {}", self.query, self.name);
284
285 let mut outputs = HashMap::new();
287 outputs.insert("rows_affected".to_string(), serde_json::json!(1));
288 outputs.insert("result".to_string(), serde_json::json!({"success": true}));
289
290 Ok(outputs)
291 }
292
293 fn name(&self) -> &str {
294 &self.name
295 }
296
297 fn timeout(&self) -> Option<Duration> {
298 self.timeout
299 }
300}
301
302pub struct FunctionActivity {
304 name: String,
305 function: Arc<dyn Fn(HashMap<String, serde_json::Value>) -> Result<HashMap<String, serde_json::Value>, ActivityError> + Send + Sync>,
306 timeout: Option<Duration>,
307}
308
309impl FunctionActivity {
310 pub fn new<F>(name: &str, function: F) -> Self
311 where
312 F: Fn(HashMap<String, serde_json::Value>) -> Result<HashMap<String, serde_json::Value>, ActivityError> + Send + Sync + 'static,
313 {
314 Self {
315 name: name.to_string(),
316 function: Arc::new(function),
317 timeout: None,
318 }
319 }
320
321 pub fn with_timeout(mut self, timeout: Duration) -> Self {
322 self.timeout = Some(timeout);
323 self
324 }
325}
326
327#[async_trait]
328impl Activity for FunctionActivity {
329 async fn execute(&self, inputs: HashMap<String, serde_json::Value>) -> Result<HashMap<String, serde_json::Value>, ActivityError> {
330 let function = Arc::clone(&self.function);
332 let inputs_clone = inputs.clone();
333 tokio::task::spawn_blocking(move || {
334 function(inputs_clone)
335 })
336 .await
337 .map_err(|e| ActivityError::ExecutionFailed(format!("Task join error: {}", e)))?
338 }
339
340 fn name(&self) -> &str {
341 &self.name
342 }
343
344 fn timeout(&self) -> Option<Duration> {
345 self.timeout
346 }
347}
348
349pub struct ActivityBuilder {
351 name: String,
352 activity_type: ActivityType,
353}
354
/// The kinds of activity `ActivityBuilder` can be configured to build.
#[derive(Debug)]
pub enum ActivityType {
    /// An HTTP request to `url` using `method`.
    Http {
        url: String,
        method: String,
    },
    /// A database query.
    Database {
        query: String,
    },
    /// A function activity identified by name.
    Function {
        function_name: String,
    },
}
361
362impl ActivityBuilder {
363 pub fn new(name: &str) -> Self {
364 Self {
365 name: name.to_string(),
366 activity_type: ActivityType::Http { url: String::new(), method: "GET".to_string() },
367 }
368 }
369
370 pub fn http(mut self, url: &str, method: &str) -> Self {
371 self.activity_type = ActivityType::Http {
372 url: url.to_string(),
373 method: method.to_string(),
374 };
375 self
376 }
377
378 pub fn database(mut self, query: &str) -> Self {
379 self.activity_type = ActivityType::Database {
380 query: query.to_string(),
381 };
382 self
383 }
384
385 pub fn function(mut self, function_name: &str) -> Self {
386 self.activity_type = ActivityType::Function {
387 function_name: function_name.to_string(),
388 };
389 self
390 }
391
392 pub fn build(self) -> Arc<dyn Activity> {
393 match self.activity_type {
394 ActivityType::Http { url, method } => Arc::new(HttpActivity::new(&self.name, &url, &method)),
395 ActivityType::Database { query } => Arc::new(DatabaseActivity::new(&self.name, &query)),
396 ActivityType::Function { function_name: _ } => {
397 panic!("Function activities not yet implemented")
399 }
400 }
401 }
402}
403
404pub mod prelude {
406 pub use super::{
407 Activity, ActivityRegistry, ActivityResult, ActivityStatus, ActivityError,
408 HttpActivity, DatabaseActivity, FunctionActivity, ActivityBuilder,
409 RetryPolicy,
410 };
411}