// a3s_flow/service.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use serde_json::Value;
5use tokio::sync::broadcast;
6use uuid::Uuid;
7
8use crate::{
9    ExecutionState, FlowCapabilities, FlowEngine, FlowError, FlowEvent, Node, NodeDescriptor,
10    ValidationIssue,
11};
12
/// Shared, thread-safe closure that returns an `Arc<dyn Node>` each time it is called.
///
/// The closure is wrapped in an [`Arc`] and bounded by `Send + Sync`, so a factory can
/// be cloned cheaply and handed to other threads.
pub type NodeFactory = Arc<dyn Fn() -> Arc<dyn Node> + Send + Sync>;
14
/// Facade over a shared [`FlowEngine`] plus a table of named node factories.
///
/// Both fields are held behind [`Arc`], so cloning a `FlowService` produces another
/// handle to the same engine and the same factory table.
#[derive(Clone)]
pub struct FlowService {
    /// Engine that every service method delegates to.
    engine: Arc<FlowEngine>,
    /// Factories keyed by factory name; consulted when registering a node type by name.
    node_factories: Arc<HashMap<String, NodeFactory>>,
}
20
21impl FlowService {
22    pub fn new(engine: Arc<FlowEngine>) -> Self {
23        Self::with_factories(engine, HashMap::new())
24    }
25
26    pub fn with_factories(
27        engine: Arc<FlowEngine>,
28        node_factories: HashMap<String, NodeFactory>,
29    ) -> Self {
30        Self {
31            engine,
32            node_factories: Arc::new(node_factories),
33        }
34    }
35
36    pub fn engine(&self) -> Arc<FlowEngine> {
37        Arc::clone(&self.engine)
38    }
39
40    pub fn capabilities(&self) -> FlowCapabilities {
41        self.engine.capabilities()
42    }
43
44    pub fn node_types(&self) -> Vec<String> {
45        self.engine.node_types()
46    }
47
48    pub fn node_descriptors(&self) -> Vec<NodeDescriptor> {
49        self.engine.node_descriptors()
50    }
51
52    pub fn validate(&self, definition: &Value) -> Vec<ValidationIssue> {
53        self.engine.validate(definition)
54    }
55
56    pub async fn start_execution(
57        &self,
58        definition: &Value,
59        variables: HashMap<String, Value>,
60    ) -> crate::Result<Uuid> {
61        self.engine.start(definition, variables).await
62    }
63
64    pub async fn get_execution(&self, id: Uuid) -> crate::Result<ExecutionState> {
65        self.engine.state(id).await
66    }
67
68    pub async fn subscribe(&self, id: Uuid) -> crate::Result<broadcast::Receiver<FlowEvent>> {
69        self.engine.subscribe(id).await
70    }
71
72    pub async fn pause_execution(&self, id: Uuid) -> crate::Result<ExecutionState> {
73        self.engine.pause(id).await?;
74        self.engine.state(id).await
75    }
76
77    pub async fn resume_execution(&self, id: Uuid) -> crate::Result<ExecutionState> {
78        self.engine.resume(id).await?;
79        self.engine.state(id).await
80    }
81
82    pub async fn terminate_execution(&self, id: Uuid) -> crate::Result<()> {
83        self.engine.terminate(id).await
84    }
85
86    pub async fn get_context(&self, id: Uuid) -> crate::Result<HashMap<String, Value>> {
87        self.engine.get_context(id).await
88    }
89
90    pub async fn set_context_entry(
91        &self,
92        id: Uuid,
93        key: String,
94        value: Value,
95    ) -> crate::Result<()> {
96        self.engine.set_context_entry(id, key, value).await
97    }
98
99    pub async fn delete_context_entry(&self, id: Uuid, key: &str) -> crate::Result<bool> {
100        self.engine.delete_context_entry(id, key).await
101    }
102
103    pub async fn run_named_flow(
104        &self,
105        name: &str,
106        variables: HashMap<String, Value>,
107    ) -> crate::Result<Uuid> {
108        self.engine.start_named(name, variables).await
109    }
110
111    pub fn register_node_type(
112        &self,
113        factory_name: &str,
114        descriptor: Option<NodeDescriptor>,
115    ) -> crate::Result<(String, bool)> {
116        let factory = self
117            .node_factories
118            .get(factory_name)
119            .cloned()
120            .ok_or_else(|| {
121                FlowError::InvalidDefinition(format!("unknown node factory: {factory_name}"))
122            })?;
123        let node = factory();
124        let node_type = node.node_type().to_string();
125        let replaced = self.engine.node_types().contains(&node_type);
126        match descriptor {
127            Some(descriptor) => self
128                .engine
129                .register_node_type_with_descriptor(node, descriptor),
130            None => self.engine.register_node_type(node),
131        }
132        Ok((node_type, replaced))
133    }
134
135    pub fn unregister_node_type(&self, node_type: &str) -> crate::Result<bool> {
136        self.engine.unregister_node_type(node_type)
137    }
138}
139
/// Compile-time helper: instantiating it with a type `T` only type-checks when
/// `T: Send + Sync`. It has no runtime effect.
// Nothing in this part of the file calls it, hence the lint exemption.
#[allow(dead_code)]
fn _assert_send_sync<T>()
where
    T: Send + Sync,
{
}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FlowStore, MemoryFlowStore, NodeRegistry};
    use async_trait::async_trait;
    use serde_json::json;

    /// Minimal [`Node`] used as the product of a test factory; it reports the type
    /// name `"slow"` and its `execute` returns `{"ok": true}` without doing any work.
    struct SlowNode;

    #[async_trait]
    impl Node for SlowNode {
        fn node_type(&self) -> &str {
            "slow"
        }

        async fn execute(&self, _ctx: crate::ExecContext) -> crate::Result<Value> {
            Ok(json!({ "ok": true }))
        }
    }

    /// Registering by factory name must build the node through the factory table and
    /// make its type visible through the service.
    #[tokio::test]
    async fn register_node_type_uses_factory_registry() {
        let mut table: HashMap<String, NodeFactory> = HashMap::new();
        table.insert(
            String::from("slow-test-node"),
            Arc::new(|| Arc::new(SlowNode)),
        );
        let service = FlowService::with_factories(
            Arc::new(FlowEngine::new(NodeRegistry::with_defaults())),
            table,
        );

        let outcome = service.register_node_type("slow-test-node", None).unwrap();

        // First registration of "slow": the type name comes back and nothing was replaced.
        assert_eq!(outcome, ("slow".to_string(), false));
        assert!(service.node_types().iter().any(|known| known == "slow"));
    }

    /// A flow saved in the engine's flow store must be startable by name.
    #[tokio::test]
    async fn run_named_flow_uses_engine_store() {
        let definition = json!({
            "nodes": [{ "id": "a", "type": "noop" }],
            "edges": []
        });
        let memory_store = Arc::new(MemoryFlowStore::new());
        memory_store.save("hello", &definition).await.unwrap();

        let shared_store: Arc<dyn FlowStore> = memory_store;
        let engine =
            FlowEngine::new(NodeRegistry::with_defaults()).with_flow_store(shared_store);
        let service = FlowService::new(Arc::new(engine));

        let execution_id = service
            .run_named_flow("hello", HashMap::new())
            .await
            .unwrap();

        // The execution may or may not have finished by the time its state is read.
        let state = service.get_execution(execution_id).await.unwrap();
        let running_or_completed = matches!(
            state,
            ExecutionState::Running | ExecutionState::Completed(_)
        );
        assert!(running_or_completed);
    }
}