Skip to main content

mlua_flow_ir/
lib.rs

1#![deny(unsafe_code)]
2//! flow.ir async runtime + mlua binding.
3//!
4//! Layer 3 of the 4-layer flow.ir stack:
5//!
6//! 1. `flow-ir-lua` — Pure Lua DSL (separate repo, ecosystem-neutral)
7//! 2. `flow-ir-core` — Pure Rust schema + sync interpreter (no mlua, no async)
//! 3. `mlua-flow-ir` — **this crate**: re-export of `flow-ir-core` +
//!    `AsyncDispatcher` + `eval_async` / `eval_async_with_storage` (including
//!    Fanout join evaluation) + Lua `module()` binding
10//! 4. `mlua-swarm-engine` — host concerns (Spawner / Worker / Loop /
11//!    AuthzPolicy / cp_state persist)
12//!
13//! All schema types (`Node` / `Expr` / `JoinMode` / `EvalError` / `Dispatcher`)
14//! are re-exported verbatim from `flow-ir-core` so callers can keep a single
15//! import path:
16//!
17//! ```
18//! use mlua_flow_ir::{eval, eval_async, AsyncDispatcher, Dispatcher, EvalError, Expr, Node};
19//! ```
20
21// ──────────────────────────────────────────────────────────────────────────
22// Re-export Pure Rust core (flow-ir-core)
23// ──────────────────────────────────────────────────────────────────────────
24
25pub use flow_ir_core::{
26    eval, eval_expr, eval_with_storage, is_truthy, read_path, write_path, CtxStorage, Dispatcher,
27    EvalError, Expr, JoinMode, MemoryCtx, Node,
28};
29
30use serde_json::Value;
31use std::sync::Arc;
32
33// ══════════════════════════════════════════════════════════════════════════
34// v0.0.2 — Async core (eval_async + AsyncDispatcher trait)
35// ══════════════════════════════════════════════════════════════════════════
36
37use async_recursion::async_recursion;
38use async_trait::async_trait;
39
/// Async counterpart of [`Dispatcher`]: resolves a `Step` node's `ref_` to an
/// output value.
///
/// Declared through the `async_trait` macro so it works on Rust 2021 and stays
/// dyn-safe (`dyn AsyncDispatcher`). Host crates implement it (e.g. the
/// `AsyncSpawner` in mlua-swarm-engine). This substrate crate takes no tokio
/// dependency, to stay pure. Providing the executor is the caller's (host's)
/// responsibility.
#[async_trait]
pub trait AsyncDispatcher: Send + Sync {
    /// Handle the step identified by `ref_`.
    ///
    /// `input` is the value the step's `in` expression evaluated to. The
    /// returned value is written to the step's `out` path. An `Err` returned
    /// here is re-wrapped by `eval_async_with_storage` as
    /// `EvalError::DispatcherError { ref_, msg }`, with `msg` set to the error's
    /// display text.
    async fn dispatch(&self, ref_: &str, input: Value) -> Result<Value, EvalError>;
}
49
50/// Evaluate a `Node` against a context value asynchronously,
51/// using the given async dispatcher for `Step` resolution.
52///
53/// `eval` (sync) と同型 logic、 dispatch を `.await` に置き換え。 Seq / Branch
54/// は recursive async fn (= `async_recursion` macro で `Pin<Box>` wrap)。
55///
56/// # Quick start
57///
58/// ```
59/// use async_trait::async_trait;
60/// use mlua_flow_ir::{eval_async, AsyncDispatcher, EvalError, Expr, Node};
61/// use serde_json::{json, Value};
62///
63/// struct Fixture;
64///
65/// #[async_trait]
66/// impl AsyncDispatcher for Fixture {
67///     async fn dispatch(&self, _r: &str, input: Value) -> Result<Value, EvalError> {
68///         if let Value::String(s) = input {
69///             Ok(Value::String(s.to_uppercase()))
70///         } else {
71///             Ok(input)
72///         }
73///     }
74/// }
75///
76/// let rt = tokio::runtime::Runtime::new().unwrap();
77/// rt.block_on(async {
78///     let node = Node::Step {
79///         ref_: "up".into(),
80///         in_: Expr::Path { at: "$.input".into() },
81///         out: Expr::Path { at: "$.output".into() },
82///     };
83///     let out = eval_async(&node, json!({ "input": "hello" }), &Fixture).await.unwrap();
84///     assert_eq!(out, json!({ "input": "hello", "output": "HELLO" }));
85/// });
86/// ```
87/// Storage-backed async evaluator — canonical entry.
88///
89/// `Arc<dyn CtxStorage>` 経由で ctx を共有することで、 dispatch().await suspend
90/// 中に外部 task が同じ ctx に `write` できる (= dynamic State injection 経路)。
91/// Step 評価の境界で `ctx.snapshot()` を取って Expr eval に渡す。
92#[async_recursion]
93pub async fn eval_async_with_storage<D>(
94    node: &Node,
95    ctx: Arc<dyn CtxStorage>,
96    dispatcher: &D,
97) -> Result<(), EvalError>
98where
99    D: AsyncDispatcher + ?Sized,
100{
101    match node {
102        Node::Step { ref_, in_, out } => {
103            // snap は dispatch() **呼出し前** の view。 dispatch().await 中に
104            // 外部 task が ctx.write しても、 ここで取った snap は影響を受けず
105            // input の値は確定。 write_target の `out` path への write は
106            // dispatch 完了後に共有 ctx を直接更新。
107            let snap = ctx.snapshot();
108            let input = eval_expr(in_, &snap)?;
109            let output =
110                dispatcher
111                    .dispatch(ref_, input)
112                    .await
113                    .map_err(|e| EvalError::DispatcherError {
114                        ref_: ref_.clone(),
115                        msg: e.to_string(),
116                    })?;
117            ctx.write(path_str_async(out)?, output)
118        }
119        Node::Seq { children } => {
120            for child in children {
121                eval_async_with_storage(child, ctx.clone(), dispatcher).await?;
122            }
123            Ok(())
124        }
125        Node::Branch { cond, then_, else_ } => {
126            let snap = ctx.snapshot();
127            match eval_expr(cond, &snap)? {
128                Value::Bool(true) => eval_async_with_storage(then_, ctx, dispatcher).await,
129                Value::Bool(false) => eval_async_with_storage(else_, ctx, dispatcher).await,
130                other => Err(EvalError::NonBoolCond(other)),
131            }
132        }
133        Node::Fanout {
134            items,
135            bind,
136            body,
137            join,
138            out,
139        } => fanout_eval(items, bind, body, *join, out, ctx, dispatcher).await,
140        Node::Loop {
141            counter,
142            cond,
143            body,
144            max,
145        } => {
146            let counter_path = path_str_async(counter)?.to_string();
147            ctx.write(&counter_path, Value::Number(serde_json::Number::from(0u32)))?;
148            let mut n: u32 = 0;
149            loop {
150                if n >= *max {
151                    break;
152                }
153                let snap = ctx.snapshot();
154                if !is_truthy(&eval_expr(cond, &snap)?) {
155                    break;
156                }
157                eval_async_with_storage(body, ctx.clone(), dispatcher).await?;
158                n += 1;
159                ctx.write(&counter_path, Value::Number(serde_json::Number::from(n)))?;
160            }
161            Ok(())
162        }
163        Node::Try {
164            body,
165            catch,
166            err_at,
167        } => {
168            let snap_before = ctx.snapshot();
169            match eval_async_with_storage(body, ctx.clone(), dispatcher).await {
170                Ok(()) => Ok(()),
171                Err(e) => {
172                    ctx.replace(snap_before);
173                    if let Some(at) = err_at {
174                        ctx.write(path_str_async(at)?, Value::String(e.to_string()))?;
175                    }
176                    eval_async_with_storage(catch, ctx, dispatcher).await
177                }
178            }
179        }
180        Node::Assign { at, value } => {
181            let snap = ctx.snapshot();
182            let v = eval_expr(value, &snap)?;
183            ctx.write(path_str_async(at)?, v)
184        }
185    }
186}
187
188/// Legacy Value-passing async evaluator — backward compat wrapper around
189/// `eval_async_with_storage` + `MemoryCtx`. 既存 caller (= dynamic injection
190/// を要求しない用途) は引き続きこの API で OK。
191pub async fn eval_async<D>(node: &Node, ctx: Value, dispatcher: &D) -> Result<Value, EvalError>
192where
193    D: AsyncDispatcher + ?Sized,
194{
195    let storage: Arc<dyn CtxStorage> = MemoryCtx::shared(ctx);
196    eval_async_with_storage(node, storage.clone(), dispatcher).await?;
197    Ok(storage.snapshot())
198}
199
200/// Resolve `Path` Expr to its literal `$.a.b.c` string (async eval 側 helper).
201fn path_str_async(expr: &Expr) -> Result<&str, EvalError> {
202    match expr {
203        Expr::Path { at } => Ok(at.as_str()),
204        _ => Err(EvalError::InvalidPath(
205            "expected Path expr for write target".into(),
206        )),
207    }
208}
209
210/// Fanout 並列 evaluator (storage-backed)。 各 branch は disjoint MemoryCtx
211/// を持ち、 branch 内で write しても共有 ctx には影響しない (= snapshot 切り出し
212/// semantic)。 集約結果は最後に共有 ctx の `out` path に write。
213#[async_recursion]
214async fn fanout_eval<D>(
215    items: &Expr,
216    bind: &Expr,
217    body: &Node,
218    join: JoinMode,
219    out: &Expr,
220    ctx: Arc<dyn CtxStorage>,
221    dispatcher: &D,
222) -> Result<(), EvalError>
223where
224    D: AsyncDispatcher + ?Sized,
225{
226    use futures::future::{join_all, select_ok, FutureExt};
227
228    let snap = ctx.snapshot();
229    let items_val = eval_expr(items, &snap)?;
230    let items_arr = match items_val {
231        Value::Array(a) => a,
232        other => {
233            return Err(EvalError::DispatcherError {
234                ref_: "fanout.items".into(),
235                msg: format!("expected array, got {other:?}"),
236            })
237        }
238    };
239
240    // branch storage を pre-allocate して、 各 branch future と pair で持つ。
241    // 集約時に同じ storage の snapshot を取って結果にする。
242    let branches: Vec<Arc<dyn CtxStorage>> = items_arr
243        .into_iter()
244        .map(|item| -> Result<Arc<dyn CtxStorage>, EvalError> {
245            let branch_ctx = write_path(bind, snap.clone(), item)?;
246            Ok(MemoryCtx::shared(branch_ctx))
247        })
248        .collect::<Result<_, _>>()?;
249
250    // 各 branch を `(idx, future)` で wrap。 future は branch storage と body を
251    // 共有して走る。
252    let branch_futs: Vec<_> = branches
253        .iter()
254        .map(|b| eval_async_with_storage(body, b.clone(), dispatcher))
255        .collect();
256
257    let joined: Value = match join {
258        JoinMode::All => {
259            futures::future::try_join_all(branch_futs).await?;
260            Value::Array(branches.iter().map(|b| b.snapshot()).collect())
261        }
262        JoinMode::Any => {
263            if branch_futs.is_empty() {
264                Value::Array(vec![])
265            } else {
266                let mapped: Vec<_> = branch_futs
267                    .into_iter()
268                    .enumerate()
269                    .map(|(i, f)| f.map(move |r| r.map(|()| i)).boxed())
270                    .collect();
271                let (winner_idx, _rest) = select_ok(mapped).await?;
272                branches[winner_idx].snapshot()
273            }
274        }
275        JoinMode::Race => {
276            if branch_futs.is_empty() {
277                Value::Array(vec![])
278            } else {
279                let mapped: Vec<_> = branch_futs
280                    .into_iter()
281                    .enumerate()
282                    .map(|(i, f)| f.map(move |r| r.map(|()| i)).boxed())
283                    .collect();
284                let (first, _idx, _rest) = futures::future::select_all(mapped).await;
285                let winner_idx = first?;
286                branches[winner_idx].snapshot()
287            }
288        }
289        JoinMode::AllSettled => {
290            let results = join_all(branch_futs).await;
291            let records: Vec<Value> = results
292                .into_iter()
293                .zip(branches.iter())
294                .map(|(r, b)| match r {
295                    Ok(()) => serde_json::json!({"status": "fulfilled", "value": b.snapshot()}),
296                    Err(e) => serde_json::json!({"status": "rejected", "reason": e.to_string()}),
297                })
298                .collect();
299            Value::Array(records)
300        }
301    };
302
303    ctx.write(path_str_async(out)?, joined)
304}
305
306// ══════════════════════════════════════════════════════════════════════════
307// v0.0.3 — mlua bridge full
308// ══════════════════════════════════════════════════════════════════════════
309
310use mlua::LuaSerdeExt;
311
312/// Lua function を Rust `Dispatcher` trait に wrap した adapter。
313///
314/// Lua 側 dispatcher function `function(ref, input) return ... end` を受けて、
315/// Rust `eval(node, ctx, &lua_dispatcher)` から呼び出せるようにする。
316/// 内部で serde Value ↔ Lua value 変換 (= mlua serde feature) を経由。
317struct LuaDispatcher<'a> {
318    lua: &'a mlua::Lua,
319    func: mlua::Function,
320}
321
322impl<'a> Dispatcher for LuaDispatcher<'a> {
323    fn dispatch(&self, ref_: &str, input: Value) -> Result<Value, EvalError> {
324        let lua_input = self
325            .lua
326            .to_value(&input)
327            .map_err(|e| EvalError::DispatcherError {
328                ref_: ref_.into(),
329                msg: format!("to_value: {}", e),
330            })?;
331        let result: mlua::Value = self.func.call((ref_.to_string(), lua_input)).map_err(|e| {
332            EvalError::DispatcherError {
333                ref_: ref_.into(),
334                msg: format!("lua call: {}", e),
335            }
336        })?;
337        let value: Value = self
338            .lua
339            .from_value(result)
340            .map_err(|e| EvalError::DispatcherError {
341                ref_: ref_.into(),
342                msg: format!("from_value: {}", e),
343            })?;
344        Ok(value)
345    }
346}
347
348/// Register the flow module table with Lua.
349///
350/// v0.0.3 full impl — exposes:
351///
352/// - `flow.version` (= string): crate version
353/// - `flow.eval(node_table, ctx_table, dispatcher_fn) -> result_table`:
354///   Lua-side entry to evaluate a flow.ir BluePrint with a Lua dispatcher fn
355///
356/// # Lua usage
357///
358/// ```lua
359/// local flow = require("flow")  -- or set via lua.globals():set("flow", module(lua))
360///
361/// local node = {
362///   kind = "step",
363///   ref = "uppercase",
364///   ["in"] = { op = "path", at = "$.input" },
365///   out = { op = "path", at = "$.output" },
366/// }
367///
368/// local function dispatcher(ref, input)
369///   if ref == "uppercase" then
370///     return string.upper(input)
371///   end
372/// end
373///
374/// local result = flow.eval(node, { input = "hello" }, dispatcher)
375/// assert(result.output == "HELLO")
376/// ```
377pub fn module(lua: &mlua::Lua) -> mlua::Result<mlua::Table> {
378    let t = lua.create_table()?;
379    t.set("version", env!("CARGO_PKG_VERSION"))?;
380
381    let eval_fn = lua.create_function(
382        |lua_inner: &mlua::Lua,
383         (node_val, ctx_val, dispatcher_fn): (mlua::Value, mlua::Value, mlua::Function)| {
384            let node: Node = lua_inner
385                .from_value(node_val)
386                .map_err(|e| mlua::Error::external(format!("node parse: {}", e)))?;
387            let ctx: Value = lua_inner
388                .from_value(ctx_val)
389                .map_err(|e| mlua::Error::external(format!("ctx parse: {}", e)))?;
390
391            let dispatcher = LuaDispatcher {
392                lua: lua_inner,
393                func: dispatcher_fn,
394            };
395            let result = eval(&node, ctx, &dispatcher)
396                .map_err(|e| mlua::Error::external(format!("eval: {}", e)))?;
397            lua_inner.to_value(&result)
398        },
399    )?;
400    t.set("eval", eval_fn)?;
401
402    Ok(t)
403}