1use serde::{Serialize, Deserialize, de::DeserializeOwned};
2use serde_json::{Value, Error as SerdeError};
3use std::collections::HashMap;
4use std::alloc::{alloc as std_alloc, dealloc as std_dealloc, Layout};
5
6
7pub type FnResult<T> = Result<T, ModuleFnErr>;
8
/// Allocates `len` bytes, aligned to `usize`, for the host to write into.
///
/// Returns a null pointer if `len` cannot form a valid [`Layout`] or if the
/// allocator fails. A zero-byte request returns a non-null, `usize`-aligned
/// dangling pointer without touching the allocator. Allocating a zero-sized
/// layout through `std::alloc::alloc` is undefined behaviour.
///
/// # Safety
///
/// The returned memory is uninitialised. It must be released with
/// [`guest_dealloc`] using the same `len`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn guest_alloc(len: usize) -> *mut u8 {
    if len == 0 {
        return std::ptr::NonNull::<usize>::dangling().as_ptr().cast::<u8>();
    }
    // An invalid size used to `unwrap`-panic. A panic cannot unwind out of an
    // `extern "C"` function (the process aborts), so report failure as null instead.
    let Ok(layout) = Layout::from_size_align(len, std::mem::align_of::<usize>()) else {
        return std::ptr::null_mut();
    };
    // SAFETY: `layout` has a non-zero size (the zero case returned above).
    unsafe { std_alloc(layout) }
}

/// Releases memory previously obtained from [`guest_alloc`].
///
/// A null `ptr` or a zero `len` is ignored. `guest_alloc` never hands out real
/// allocations for those cases.
///
/// # Safety
///
/// Unless `ptr` is null or `len` is zero, `ptr` must have been returned by
/// [`guest_alloc`] called with the same `len`, and must not be freed twice.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn guest_dealloc(ptr: *mut u8, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    let Ok(layout) = Layout::from_size_align(len, std::mem::align_of::<usize>()) else {
        // `guest_alloc` returns null for such sizes, so no live allocation can match.
        return;
    };
    // SAFETY: the caller guarantees `ptr` came from `guest_alloc(len)`, which
    // allocated with exactly this size and alignment.
    unsafe { std_dealloc(ptr, layout) };
}
22
/// Packs a pointer and a byte length into a single `u64`.
///
/// The pointer's address occupies the upper 32 bits and the length the lower
/// 32 bits. Both are truncated to 32 bits, so the encoding is only lossless
/// when each value fits in a `u32` (as on 32-bit targets such as `wasm32`).
pub fn pack_ptr(ptr: *const u8, len: usize) -> u64 {
    let high = u64::from(ptr as usize as u32);
    let low = u64::from(len as u32);
    (high << 32) | low
}
28
/// Splits a packed `u64` back into `(pointer, length)`.
///
/// This is the inverse of [`pack_ptr`]: the upper 32 bits become the address
/// and the lower 32 bits the length.
pub fn unpack_ptr(packed: u64) -> (*const u8, usize) {
    let address = (packed >> 32) as usize;
    let length = packed as u32 as usize;
    (address as *const u8, length)
}
34
35pub fn serialize_to_ptr<T: Serialize>(result: T) -> Result<u64, SerdeError> {
36 let bytes = serde_json::to_vec(&result)?;
37 let ptr = pack_ptr(bytes.as_ptr(), bytes.len());
38 std::mem::forget(bytes);
40
41 Ok(ptr)
42}
43
44pub fn deserialize_from_ptr<T: DeserializeOwned>(input_ptr: u32, input_len: u32) -> Result<T, SerdeError> {
45 let input_bytes = unsafe { std::slice::from_raw_parts(input_ptr as *const u8, input_len as usize) };
46 serde_json::from_slice(input_bytes)
47}
48
49#[derive(Serialize, Deserialize, Debug, Clone)]
50pub struct ModuleFnInput {
51 pub args: Option<Vec<Value>>,
52 pub kwargs: Option<HashMap<String, Value>>,
53}
54
55impl ModuleFnInput {
56 pub fn new() -> Self {
57 Self {
58 args: None,
59 kwargs: None,
60 }
61 }
62
63 pub fn get_arg<T: DeserializeOwned>(&self, index: usize, name: &str) -> Result<T, String> {
64 if let Some(kwargs) = &self.kwargs {
65 if let Some(value) = kwargs.get(name) {
66 return serde_json::from_value(value.clone())
67 .map_err(|e| format!("Failed to parse argument {}: {}", name, e));
68 }
69 }
70
71 if let Some(args) = &self.args {
72 if index < args.len() {
73 return serde_json::from_value(args[index].clone())
74 .map_err(|e| format!("Failed to parse argument {}: {}", index, e));
75 }
76 }
77
78 Err(format!("Missing argument: {} in position {}", name, index))
79 }
80
81 pub fn get_args<T: DeserializeOwned>(&self) -> Result<Vec<T>, String> {
82 if let Some(args) = &self.args {
83 return serde_json::from_value(Value::Array(args.clone()))
84 .map_err(|e| format!("Failed to parse arguments: {}", e));
85 }
86 Err("No arguments provided".into())
87 }
88
89 pub fn get_kwargs<T: DeserializeOwned>(&self) -> Result<HashMap<String, T>, String> {
90 if let Some(kwargs) = &self.kwargs {
91 return serde_json::from_value(Value::Object(
92 kwargs
93 .iter()
94 .map(|(k, v)| (k.clone(), v.clone()))
95 .collect(),
96 ))
97 .map_err(|e| format!("Failed to parse keyword arguments: {}", e));
98 }
99 Err("No keyword arguments provided".into())
100 }
101
102 pub fn add_arg<T: Serialize>(&mut self, arg: T) -> Result<(), String> {
103 if let Some(args) = &mut self.args {
104 args.push(
105 serde_json::to_value(arg)
106 .map_err(|e| format!("Failed to serialize argument: {}", e))?,
107 );
108 } else {
109 self.args = Some(vec![
110 serde_json::to_value(arg)
111 .map_err(|e| format!("Failed to serialize argument: {}", e))?,
112 ]);
113 }
114
115 Ok(())
116 }
117
118 pub fn add_kwarg<T: Serialize>(&mut self, name: String, arg: T) -> Result<(), String> {
119 if let Some(kwargs) = &mut self.kwargs {
120 kwargs.insert(
121 name,
122 serde_json::to_value(arg)
123 .map_err(|e| format!("Failed to serialize keyword argument: {}", e))?,
124 );
125 } else {
126 let mut kwargs = HashMap::new();
127
128 kwargs.insert(
129 name,
130 serde_json::to_value(arg)
131 .map_err(|e| format!("Failed to serialize keyword argument: {}", e))?,
132 );
133 self.kwargs = Some(kwargs);
134 }
135
136 Ok(())
137 }
138}
139
140impl Default for ModuleFnInput {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146#[derive(Serialize, Deserialize, Debug, Clone)]
147pub struct ModuleFnErr {
148 #[serde(rename = "type")]
149 pub error_type: String,
150 pub message: String,
151}
152
153#[derive(Serialize, Deserialize, Debug, Clone)]
154pub struct ModuleFnReturn<T: Serialize> {
155 pub value: Option<T>,
156}
157
158impl<T: Serialize> ModuleFnReturn<T> {
159 pub fn new(value: T) -> Self {
160 Self { value: Some(value) }
161 }
162
163 pub fn empty() -> Self {
164 Self { value: None }
165 }
166}
167
168impl ModuleFnReturn<serde_json::Value> {
169 pub fn new_serialized<T: Serialize>(value: T) -> Result<Self, SerdeError> {
170 Ok(ModuleFnReturn::new(serde_json::to_value(value)?))
171 }
172}
173
174
175#[derive(Serialize, Deserialize, Debug, Clone)]
176#[serde(tag = "object")]
177pub enum ModuleFnResult<T: Serialize> {
178 #[serde(rename = "data")]
179 Data(ModuleFnReturn<T>),
180 #[serde(rename = "error")]
181 Error(ModuleFnErr),
182}
183