custom_function/
custom_function.rs

1use async_trait::async_trait;
2use dataflow_rs::{
3    engine::{
4        error::{DataflowError, Result},
5        message::{Change, Message},
6        AsyncFunctionHandler,
7    },
8    Engine, Workflow,
9};
10use serde_json::{json, Value};
11use std::collections::HashMap;
12
/// Custom task function that computes summary statistics (count, sum, mean,
/// median, min, max, population variance and standard deviation) over the
/// numbers found at a configurable message path.
///
/// The type carries no state, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct StatisticsFunction;
15
16#[async_trait]
17impl AsyncFunctionHandler for StatisticsFunction {
18    async fn execute(&self, message: &mut Message, input: &Value) -> Result<(usize, Vec<Change>)> {
19        // Extract the data path to analyze
20        let data_path = input
21            .get("data_path")
22            .and_then(Value::as_str)
23            .unwrap_or("data.numbers");
24
25        // Extract the output path where to store results
26        let output_path = input
27            .get("output_path")
28            .and_then(Value::as_str)
29            .unwrap_or("data.statistics");
30
31        // Get the numbers from the specified path
32        let numbers = self.extract_numbers_from_path(message, data_path)?;
33
34        if numbers.is_empty() {
35            return Err(DataflowError::Validation(
36                "No numeric data found to analyze".to_string(),
37            ));
38        }
39
40        // Calculate statistics
41        let stats = self.calculate_statistics(&numbers);
42
43        // Store the results in the message
44        self.set_value_at_path(message, output_path, stats.clone())?;
45
46        // Return success with changes
47        Ok((
48            200,
49            vec![Change {
50                path: output_path.to_string(),
51                old_value: Value::Null,
52                new_value: stats,
53            }],
54        ))
55    }
56}
57
58impl Default for StatisticsFunction {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl StatisticsFunction {
65    pub fn new() -> Self {
66        Self
67    }
68
69    fn extract_numbers_from_path(&self, message: &Message, path: &str) -> Result<Vec<f64>> {
70        let parts: Vec<&str> = path.split('.').collect();
71        let mut current = if parts[0] == "data" {
72            &message.data
73        } else if parts[0] == "temp_data" {
74            &message.temp_data
75        } else if parts[0] == "metadata" {
76            &message.metadata
77        } else {
78            &message.data
79        };
80
81        // Navigate to the target location
82        for part in &parts[1..] {
83            current = current.get(part).unwrap_or(&Value::Null);
84        }
85
86        // Extract numbers from the value
87        match current {
88            Value::Array(arr) => {
89                let mut numbers = Vec::new();
90                for item in arr {
91                    if let Some(num) = item.as_f64() {
92                        numbers.push(num);
93                    } else if let Some(num) = item.as_i64() {
94                        numbers.push(num as f64);
95                    }
96                }
97                Ok(numbers)
98            }
99            Value::Number(num) => {
100                if let Some(f) = num.as_f64() {
101                    Ok(vec![f])
102                } else {
103                    Ok(vec![])
104                }
105            }
106            _ => Ok(vec![]),
107        }
108    }
109
110    fn calculate_statistics(&self, numbers: &[f64]) -> Value {
111        let count = numbers.len();
112        let sum: f64 = numbers.iter().sum();
113        let mean = sum / count as f64;
114
115        let mut sorted = numbers.to_vec();
116        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
117
118        let median = if count % 2 == 0 {
119            (sorted[count / 2 - 1] + sorted[count / 2]) / 2.0
120        } else {
121            sorted[count / 2]
122        };
123
124        let variance: f64 = numbers.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count as f64;
125        let std_dev = variance.sqrt();
126
127        json!({
128            "count": count,
129            "sum": sum,
130            "mean": mean,
131            "median": median,
132            "min": sorted[0],
133            "max": sorted[count - 1],
134            "variance": variance,
135            "std_dev": std_dev
136        })
137    }
138
139    fn set_value_at_path(&self, message: &mut Message, path: &str, value: Value) -> Result<()> {
140        let parts: Vec<&str> = path.split('.').collect();
141        let target = if parts[0] == "data" {
142            &mut message.data
143        } else if parts[0] == "temp_data" {
144            &mut message.temp_data
145        } else if parts[0] == "metadata" {
146            &mut message.metadata
147        } else {
148            &mut message.data
149        };
150
151        // Navigate and set the value
152        let mut current = target;
153        for (i, part) in parts[1..].iter().enumerate() {
154            if i == parts.len() - 2 {
155                // Last part, set the value
156                if current.is_null() {
157                    *current = json!({});
158                }
159                if let Value::Object(map) = current {
160                    map.insert(part.to_string(), value.clone());
161                }
162                break;
163            } else {
164                // Navigate deeper
165                if current.is_null() {
166                    *current = json!({});
167                }
168                if let Value::Object(map) = current {
169                    if !map.contains_key(*part) {
170                        map.insert(part.to_string(), json!({}));
171                    }
172                    current = map.get_mut(*part).unwrap();
173                }
174            }
175        }
176
177        Ok(())
178    }
179}
180
181/// Custom function that enriches data with external information
182pub struct DataEnrichmentFunction {
183    enrichment_data: HashMap<String, Value>,
184}
185
186#[async_trait]
187impl AsyncFunctionHandler for DataEnrichmentFunction {
188    async fn execute(&self, message: &mut Message, input: &Value) -> Result<(usize, Vec<Change>)> {
189        // Extract lookup key and field
190        let lookup_field = input
191            .get("lookup_field")
192            .and_then(Value::as_str)
193            .ok_or_else(|| DataflowError::Validation("Missing lookup_field".to_string()))?;
194
195        let lookup_value = input
196            .get("lookup_value")
197            .and_then(Value::as_str)
198            .ok_or_else(|| DataflowError::Validation("Missing lookup_value".to_string()))?;
199
200        let output_path = input
201            .get("output_path")
202            .and_then(Value::as_str)
203            .unwrap_or("data.enrichment");
204
205        // Simulate async operation (e.g., database lookup, API call)
206        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
207
208        // Look up enrichment data
209        let enrichment = if let Some(data) = self.enrichment_data.get(lookup_value) {
210            data.clone()
211        } else {
212            json!({
213                "status": "not_found",
214                "message": format!("No enrichment data found for {}: {}", lookup_field, lookup_value)
215            })
216        };
217
218        // Store enrichment data
219        self.set_value_at_path(message, output_path, enrichment.clone())?;
220
221        Ok((
222            200,
223            vec![Change {
224                path: output_path.to_string(),
225                old_value: Value::Null,
226                new_value: enrichment,
227            }],
228        ))
229    }
230}
231
232impl Default for DataEnrichmentFunction {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
238impl DataEnrichmentFunction {
239    pub fn new() -> Self {
240        let mut enrichment_data = HashMap::new();
241
242        // Sample enrichment data
243        enrichment_data.insert(
244            "user_123".to_string(),
245            json!({
246                "department": "Engineering",
247                "location": "San Francisco",
248                "manager": "Alice Johnson",
249                "start_date": "2022-01-15",
250                "security_clearance": "Level 2"
251            }),
252        );
253
254        enrichment_data.insert(
255            "user_456".to_string(),
256            json!({
257                "department": "Marketing",
258                "location": "New York",
259                "manager": "Bob Smith",
260                "start_date": "2021-06-01",
261                "security_clearance": "Level 1"
262            }),
263        );
264
265        Self { enrichment_data }
266    }
267
268    fn set_value_at_path(&self, message: &mut Message, path: &str, value: Value) -> Result<()> {
269        let parts: Vec<&str> = path.split('.').collect();
270        let target = if parts[0] == "data" {
271            &mut message.data
272        } else if parts[0] == "temp_data" {
273            &mut message.temp_data
274        } else if parts[0] == "metadata" {
275            &mut message.metadata
276        } else {
277            &mut message.data
278        };
279
280        let mut current = target;
281        for (i, part) in parts[1..].iter().enumerate() {
282            if i == parts.len() - 2 {
283                if current.is_null() {
284                    *current = json!({});
285                }
286                if let Value::Object(map) = current {
287                    map.insert(part.to_string(), value.clone());
288                }
289                break;
290            } else {
291                if current.is_null() {
292                    *current = json!({});
293                }
294                if let Value::Object(map) = current {
295                    if !map.contains_key(*part) {
296                        map.insert(part.to_string(), json!({}));
297                    }
298                    current = map.get_mut(*part).unwrap();
299                }
300            }
301        }
302        Ok(())
303    }
304}
305
306#[tokio::main]
307async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
308    println!("=== Custom Function Example ===\n");
309
310    // Create engine without built-in functions to demonstrate custom ones
311    let mut engine = Engine::new_empty();
312
313    // Register our custom functions
314    engine.register_task_function(
315        "statistics".to_string(),
316        Box::new(StatisticsFunction::new()),
317    );
318
319    engine.register_task_function(
320        "enrich_data".to_string(),
321        Box::new(DataEnrichmentFunction::new()),
322    );
323
324    // Also register built-in map function for data preparation
325    engine.register_task_function(
326        "map".to_string(),
327        Box::new(dataflow_rs::engine::functions::MapFunction::new()),
328    );
329
330    // Define a workflow that uses our custom functions
331    let workflow_json = r#"
332    {
333        "id": "custom_function_demo",
334        "name": "Custom Function Demo",
335        "description": "Demonstrates custom async functions in workflow",
336        "condition": { "==": [true, true] },
337        "tasks": [
338            {
339                "id": "prepare_data",
340                "name": "Prepare Data",
341                "description": "Extract and prepare data for analysis",
342                "function": {
343                    "name": "map",
344                    "input": {
345                        "mappings": [
346                            {
347                                "path": "data.numbers",
348                                "logic": { "var": "temp_data.measurements" }
349                            },
350                            {
351                                "path": "data.user_id",
352                                "logic": { "var": "temp_data.user_id" }
353                            }
354                        ]
355                    }
356                }
357            },
358            {
359                "id": "calculate_stats",
360                "name": "Calculate Statistics",
361                "description": "Calculate statistical measures from numeric data",
362                "function": {
363                    "name": "statistics",
364                    "input": {
365                        "data_path": "data.numbers",
366                        "output_path": "data.stats"
367                    }
368                }
369            },
370            {
371                "id": "enrich_user_data",
372                "name": "Enrich User Data",
373                "description": "Add additional user information",
374                "function": {
375                    "name": "enrich_data",
376                    "input": {
377                        "lookup_field": "user_id",
378                        "lookup_value": "user_123",
379                        "output_path": "data.user_info"
380                    }
381                }
382            }
383        ]
384    }
385    "#;
386
387    // Parse and add the workflow
388    let workflow = Workflow::from_json(workflow_json)?;
389    engine.add_workflow(&workflow);
390
391    // Create sample data
392    let sample_data = json!({
393        "measurements": [10.5, 15.2, 8.7, 22.1, 18.9, 12.3, 25.6, 14.8, 19.4, 16.7],
394        "user_id": "user_123",
395        "timestamp": "2024-01-15T10:30:00Z"
396    });
397
398    // Create and process message
399    let mut message = dataflow_rs::engine::message::Message::new(&json!({}));
400    message.temp_data = sample_data;
401    message.data = json!({});
402
403    println!("Processing message with custom functions...\n");
404
405    // Process the message through our custom workflow
406    match engine.process_message(&mut message).await {
407        Ok(_) => {
408            println!("āœ… Message processed successfully!\n");
409
410            println!("šŸ“Š Final Results:");
411            println!("{}\n", serde_json::to_string_pretty(&message.data)?);
412
413            println!("šŸ“‹ Audit Trail:");
414            for (i, audit) in message.audit_trail.iter().enumerate() {
415                println!(
416                    "{}. Task: {} (Status: {})",
417                    i + 1,
418                    audit.task_id,
419                    audit.status_code
420                );
421                println!("   Timestamp: {}", audit.timestamp);
422                println!("   Changes: {} field(s) modified", audit.changes.len());
423            }
424
425            if message.has_errors() {
426                println!("\nāš ļø  Errors encountered:");
427                for error in &message.errors {
428                    println!(
429                        "   - {}: {:?}",
430                        error.task_id.as_ref().unwrap_or(&"unknown".to_string()),
431                        error.error_message
432                    );
433                }
434            }
435        }
436        Err(e) => {
437            println!("āŒ Error processing message: {e:?}");
438        }
439    }
440
441    // Demonstrate another example with different data
442    let separator = "=".repeat(50);
443    println!("\n{separator}");
444    println!("=== Second Example with Different User ===\n");
445
446    let mut message2 = dataflow_rs::engine::message::Message::new(&json!({}));
447    message2.temp_data = json!({
448        "measurements": [5.1, 7.3, 9.8, 6.2, 8.5],
449        "user_id": "user_456",
450        "timestamp": "2024-01-15T11:00:00Z"
451    });
452    message2.data = json!({});
453
454    // Create a workflow for the second user
455    let workflow2_json = r#"
456    {
457        "id": "custom_function_demo_2",
458        "name": "Custom Function Demo 2",
459        "description": "Second demo with different user",
460        "condition": { "==": [true, true] },
461        "tasks": [
462            {
463                "id": "prepare_data",
464                "name": "Prepare Data",
465                "function": {
466                    "name": "map",
467                    "input": {
468                        "mappings": [
469                            {
470                                "path": "data.numbers",
471                                "logic": { "var": "temp_data.measurements" }
472                            },
473                            {
474                                "path": "data.user_id",
475                                "logic": { "var": "temp_data.user_id" }
476                            }
477                        ]
478                    }
479                }
480            },
481            {
482                "id": "calculate_stats",
483                "name": "Calculate Statistics",
484                "function": {
485                    "name": "statistics",
486                    "input": {
487                        "data_path": "data.numbers",
488                        "output_path": "data.analysis"
489                    }
490                }
491            },
492            {
493                "id": "enrich_user_data",
494                "name": "Enrich User Data",
495                "function": {
496                    "name": "enrich_data",
497                    "input": {
498                        "lookup_field": "user_id",
499                        "lookup_value": "user_456",
500                        "output_path": "data.employee_details"
501                    }
502                }
503            }
504        ]
505    }
506    "#;
507
508    let workflow2 = Workflow::from_json(workflow2_json)?;
509    engine.add_workflow(&workflow2);
510
511    match engine.process_message(&mut message2).await {
512        Ok(_) => {
513            println!("āœ… Second message processed successfully!\n");
514            println!("šŸ“Š Results for user_456:");
515            println!("{}", serde_json::to_string_pretty(&message2.data)?);
516        }
517        Err(e) => {
518            println!("āŒ Error processing second message: {e:?}");
519        }
520    }
521
522    println!("\nšŸŽ‰ Custom function examples completed!");
523
524    Ok(())
525}