Skip to main content

vm/
lib.rs

1//使用 cranelift 作为后端 直接 jit 解释脚本
2mod binary;
3mod native;
4pub use native::{ANY, STD};
5
6mod fns;
7use anyhow::{Result, anyhow};
8pub use fns::{FnInfo, FnVariant};
9mod context;
10use context::BuildContext;
11
12mod rt;
13use cranelift::prelude::types;
14use dynamic::Type;
15pub use rt::JITRunTime;
16use smol_str::SmolStr;
17mod http_module;
18mod llm_module;
19mod root_module;
20
21use std::sync::{OnceLock, RwLock};
22static PTR_TYPE: OnceLock<types::Type> = OnceLock::new();
23pub fn ptr_type() -> types::Type {
24    PTR_TYPE.get().cloned().unwrap()
25}
26
27pub fn get_type(ty: &Type) -> Result<types::Type> {
28    if ty.is_f64() {
29        Ok(types::F64)
30    } else if ty.is_f32() {
31        Ok(types::F32)
32    } else if ty.is_int() | ty.is_uint() {
33        match ty.width() {
34            1 => Ok(types::I8),
35            2 => Ok(types::I16),
36            4 => Ok(types::I32),
37            8 => Ok(types::I64),
38            _ => Err(anyhow!("非法类型 {:?}", ty)),
39        }
40    } else if let Type::Bool = ty {
41        Ok(types::I8)
42    } else {
43        Ok(ptr_type())
44    }
45}
46
47use compiler::Symbol;
48use cranelift::prelude::*;
49
50pub fn init_jit(mut jit: JITRunTime) -> Result<JITRunTime> {
51    jit.compiler.symbols.add_module("std".into()); //开始导入标准库,可以直接访问
52    for std in STD {
53        jit.add_native(std.0, std.0, std.1, std.2)?;
54    }
55
56    let mut fields = Vec::new();
57    for (name, arg_tys, ret_ty, _) in ANY {
58        let id = jit.add_native(name, name, arg_tys, ret_ty)?;
59        let (_, field_name) = name.split_once("::").unwrap();
60        fields.push((field_name.into(), Type::Symbol { id, params: Vec::new() }));
61    }
62    jit.compiler.add_symbol("Any", Symbol::Struct(Type::Struct { params: Vec::new(), fields }, true));
63
64    jit.compiler.add_symbol("Vec", Symbol::Struct(Type::Struct { params: Vec::new(), fields: Vec::new() }, true));
65    let vec_def = Type::Symbol { id: jit.get_id("Vec")?, params: Vec::new() };
66    jit.add_inline("Vec::swap", vec![vec_def.clone(), Type::I64, Type::I64], Type::Void, |ctx: Option<&mut BuildContext>, args: Vec<Value>| {
67        if let Some(ctx) = ctx {
68            let width = ctx.builder.ins().iconst(types::I64, 4);
69            let offset_val = ctx.builder.ins().imul(args[1], width); // i * 4 i32大小四字节
70            let final_addr = ctx.builder.ins().iadd(args[0], offset_val); // base + (i*4)
71            let dest = ctx.builder.ins().imul(args[2], width);
72            let dest_addr = ctx.builder.ins().iadd(args[0], dest); // base + (i*4)
73            let dest_val = ctx.builder.ins().load(types::I32, MemFlags::trusted(), dest_addr, 0);
74            let v = ctx.builder.ins().load(types::I32, MemFlags::trusted(), final_addr, 0);
75            ctx.builder.ins().store(MemFlags::trusted(), v, dest_addr, 0);
76            ctx.builder.ins().store(MemFlags::trusted(), dest_val, final_addr, 0);
77        }
78        Err(anyhow!("无返回值"))
79    })?;
80
81    jit.add_inline("Vec::get_idx", vec![vec_def.clone(), Type::I64], Type::I32, |ctx: Option<&mut BuildContext>, args: Vec<Value>| {
82        if let Some(ctx) = ctx {
83            let width = ctx.builder.ins().iconst(types::I64, 4);
84            let offset_val = ctx.builder.ins().imul(args[1], width); // i * 4 i32大小四字节
85            let final_addr = ctx.builder.ins().iadd(args[0], offset_val);
86            Ok((Some(ctx.builder.ins().load(types::I32, MemFlags::trusted(), final_addr, 0)), Type::I32))
87        } else {
88            Ok((None, Type::I32))
89        }
90    })?;
91    Ok(jit)
92}
93
94use std::sync::Arc;
95
96use std::sync::LazyLock;
97unsafe impl Send for JITRunTime {}
98unsafe impl Sync for JITRunTime {}
99
100//直接在这里增加一行 就可以导入一个模块
101static mut MODULES: &[(&str, &[(&str, &[Type], Type, *const u8)])] = &[("llm", &llm_module::LLM_NATIVE), ("root", &root_module::ROOT_NATIVE), ("http", &http_module::HTTP_NATIVE)];
102
103pub static JIT: LazyLock<Arc<RwLock<JITRunTime>>> = LazyLock::new(|| {
104    let jit = JITRunTime::new(|b| {
105        //这里注册所有的外部符号
106        for (name, _, _, fn_ptr) in STD {
107            b.symbol(name, fn_ptr);
108        }
109        for (name, _, _, fn_ptr) in ANY {
110            b.symbol(name, fn_ptr);
111        }
112        for (name, fns) in unsafe { MODULES.into_iter() } {
113            for (fn_name, _, _, fn_ptr) in *fns {
114                let full_name = format!("{}::{}", *name, *fn_name);
115                b.symbol(&full_name, *fn_ptr);
116            }
117        }
118    });
119    let mut jit = init_jit(jit).unwrap();
120    for (name, fns) in unsafe { MODULES.into_iter() } {
121        jit.compiler.symbols.add_module((*name).into());
122        for r in fns.into_iter() {
123            let full_name = format!("{}::{}", *name, r.0);
124            jit.add_native(&&full_name, r.0, r.1, r.2.clone()).unwrap();
125        }
126        jit.compiler.symbols.pop_module();
127    }
128    Arc::new(RwLock::new(jit))
129});
130
131pub fn import_code(name: &str, code: Vec<u8>) -> Result<()> {
132    JIT.write().unwrap().import_code(name, code)
133}
134
135pub fn import(name: &str, path: &str) -> Result<()> {
136    if root::contains(path) {
137        //优先从 root 文件系统里面 import
138        let code = root::get(path).unwrap();
139        if code.is_str() {
140            JIT.write().unwrap().import_code(name, code.as_str().as_bytes().to_vec())
141        } else {
142            JIT.write().unwrap().import_code(name, code.get_dynamic("code").ok_or(anyhow!("{:?} 没有 code 成员", code))?.as_str().as_bytes().to_vec())
143        }
144    } else {
145        JIT.write().unwrap().compiler.import_file(name, path)?;
146        Ok(())
147    }
148}
149
150pub fn infer(name: &str, arg_tys: &[Type]) -> Result<Type> {
151    JIT.write().unwrap().get_type(name, arg_tys)
152}
153
154pub fn get_fn(name: &str, arg_tys: &[Type]) -> Result<(*const u8, Type)> {
155    JIT.write().unwrap().get_fn_ptr(name, arg_tys)
156}
157
158pub fn load(code: Vec<u8>, arg_name: SmolStr) -> Result<(i64, Type)> {
159    JIT.write().unwrap().load(code, arg_name)
160}
161
162pub fn get_symbol(name: &str, params: Vec<Type>) -> Result<Type> {
163    Ok(Type::Symbol { id: JIT.read().unwrap().get_id(name)?, params })
164}
165
166pub fn disassemble(name: &str) -> Result<String> {
167    JIT.read().unwrap().compiler.symbols.disassemble(name)
168}
169
/// Returns the text that the shared [`JIT`] runtime's `disassemble_ir` produces for `name`.
///
/// Only compiled when the `ir-disassembly` feature is enabled. The runtime's write lock is held
/// for the duration of the call.
///
/// # Panics
///
/// Panics if the runtime lock is poisoned.
#[cfg(feature = "ir-disassembly")]
pub fn disassemble_ir(name: &str) -> Result<String> {
    let mut runtime = JIT.write().unwrap();
    runtime.disassemble_ir(name)
}
174
#[cfg(test)]
mod tests {
    use super::{get_fn, import_code};
    use dynamic::{Dynamic, ToJson, Type};

    /// Comparing an untyped (`Any`) argument with a string literal compares it as a string.
    #[test]
    fn compares_any_with_string_literal_as_string() -> anyhow::Result<()> {
        type AnyPredicate = extern "C" fn(*const Dynamic) -> bool;

        let script = br#"
            pub fn any_ne_empty(chat_path) {
                chat_path != ""
            }
            "#
        .to_vec();
        import_code("vm_string_compare_any", script)?;

        let (raw_fn, return_type) = get_fn("vm_string_compare_any::any_ne_empty", &[Type::Any])?;
        assert_eq!(return_type, Type::Bool);

        // SAFETY: assumes the compiled function uses the C ABI with this signature — the
        // signature mirrors the argument and return types requested above.
        let any_ne_empty = unsafe { std::mem::transmute::<*const u8, AnyPredicate>(raw_fn) };
        let empty_text = Dynamic::from("");
        let some_text = Dynamic::from("chat");

        assert!(!any_ne_empty(&empty_text));
        assert!(any_ne_empty(&some_text));
        Ok(())
    }

    /// A concrete `i64` compared with, or added to, a string literal is handled as a string.
    #[test]
    fn compares_concrete_value_with_string_literal_as_string() -> anyhow::Result<()> {
        type IntPredicate = extern "C" fn(i64) -> bool;
        type IntToDynamic = extern "C" fn(i64) -> *const Dynamic;

        let script = br#"
            pub fn int_eq_str(value: i64) {
                value == "42"
            }

            pub fn int_to_str(value: i64) {
                value + ""
            }
            "#
        .to_vec();
        import_code("vm_string_compare_imm", script)?;

        let (eq_raw, eq_return_type) = get_fn("vm_string_compare_imm::int_eq_str", &[Type::I64])?;
        assert_eq!(eq_return_type, Type::Bool);

        // SAFETY: assumes the compiled function uses the C ABI with this signature.
        let int_eq_str = unsafe { std::mem::transmute::<*const u8, IntPredicate>(eq_raw) };

        let (to_str_raw, to_str_return_type) = get_fn("vm_string_compare_imm::int_to_str", &[Type::I64])?;
        assert_eq!(to_str_return_type, Type::Any);
        // SAFETY: assumes the compiled function uses the C ABI with this signature.
        let int_to_str = unsafe { std::mem::transmute::<*const u8, IntToDynamic>(to_str_raw) };
        let rendered = int_to_str(42);
        // SAFETY: assumes the returned pointer refers to a live `Dynamic`.
        let rendered = unsafe { &*rendered };
        assert_eq!(rendered.as_str(), "42");

        assert!(int_eq_str(42));
        assert!(!int_eq_str(7));
        Ok(())
    }

    /// Passing a dynamic value to a `root` native must leave the script free to keep using it.
    #[test]
    fn root_native_calls_do_not_take_ownership_of_dynamic_args() -> anyhow::Result<()> {
        type DynamicFn = extern "C" fn(*const Dynamic) -> *const Dynamic;

        let script = br#"
            pub fn add_then_reuse(arg) {
                let user = {
                    address: "test-wallet",
                    points: 20
                };
                root::add("local/root-clone-bridge-user", user);
                user.points = user.points - 7;
                root::add("local/root-clone-bridge-user", user);
                {
                    user: user,
                    points: user.points
                }
            }
            "#
        .to_vec();
        import_code("vm_root_clone_bridge", script)?;

        let (raw_fn, return_type) = get_fn("vm_root_clone_bridge::add_then_reuse", &[Type::Any])?;
        assert_eq!(return_type, Type::Any);
        // SAFETY: assumes the compiled function uses the C ABI with this signature.
        let add_then_reuse = unsafe { std::mem::transmute::<*const u8, DynamicFn>(raw_fn) };
        let input = Dynamic::Null;
        let output_ptr = add_then_reuse(&input);
        // SAFETY: assumes the returned pointer refers to a live `Dynamic`.
        let output = unsafe { &*output_ptr };

        let points = output.get_dynamic("points").and_then(|value| value.as_int());
        assert_eq!(points, Some(13));
        let mut rendered = String::new();
        output.to_json(&mut rendered);
        assert!(rendered.contains("\"points\": 13"));
        Ok(())
    }
}