1use once_cell::sync::OnceCell;
7use pforge_runtime::HandlerRegistry;
8use std::ffi::{CStr, CString};
9use std::os::raw::{c_char, c_int};
10use std::slice;
11use std::sync::Arc;
12use tokio::runtime::Runtime;
13use tokio::sync::RwLock;
14
/// Process-wide handler registry shared by the FFI entry points.
///
/// Installed exactly once by `pforge_init`; read (via `GLOBAL_REGISTRY.get()`) by
/// `pforge_register_handler`, `pforge_execute_handler` and `pforge_is_initialized`.
static GLOBAL_REGISTRY: OnceCell<Arc<RwLock<HandlerRegistry>>> = OnceCell::new();

/// Lazily created Tokio runtime used to drive the async registry API from synchronous
/// C callers through `block_on`; see `get_runtime`.
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
20
21fn get_runtime() -> &'static Runtime {
23 RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create tokio runtime"))
24}
25
/// Opaque handle type exposed to C.
///
/// The only field is a private zero-length array, so the type is zero-sized and cannot be
/// constructed outside this module. It is presumably meant to be handled only behind a
/// pointer by foreign code (no visible code in this module uses it yet).
#[repr(C)]
pub struct HandlerContext {
    _opaque: [u8; 0],
}
31
/// Result of an FFI call, returned by value to C callers.
///
/// On success `code` is `0`, `data`/`data_len` describe an owned (possibly empty) byte
/// buffer and `error` is null. On failure `code` is negative, `data` is null with
/// `data_len == 0`, and `error` is either null or an owned NUL-terminated message.
///
/// Values produced by this module own heap memory and must be released exactly once with
/// `pforge_free_result`. `Clone`/`Copy` are deliberately not derived because a copy would
/// make a double free easy.
#[repr(C)]
#[derive(Debug)]
#[must_use = "an FfiResult owns heap memory and must be released with pforge_free_result"]
pub struct FfiResult {
    /// Status code: `0` on success, negative on failure.
    pub code: c_int,
    /// Owned output bytes; null on failure.
    pub data: *mut u8,
    /// Number of bytes pointed to by `data`.
    pub data_len: usize,
    /// Owned NUL-terminated error message, or null.
    pub error: *const c_char,
}
44
45#[no_mangle]
51pub unsafe extern "C" fn pforge_init() -> c_int {
52 if GLOBAL_REGISTRY.get().is_some() {
53 return -1; }
55
56 let registry = Arc::new(RwLock::new(HandlerRegistry::new()));
57 if GLOBAL_REGISTRY.set(registry).is_err() {
58 return -2; }
60
61 0 }
63
64#[no_mangle]
70pub unsafe extern "C" fn pforge_register_handler(
71 name: *const c_char,
72 _handler_ptr: *mut std::ffi::c_void,
73) -> c_int {
74 if name.is_null() {
75 return -1;
76 }
77
78 let registry = match GLOBAL_REGISTRY.get() {
79 Some(r) => r,
80 None => return -2, };
82
83 let name_str = match unsafe { CStr::from_ptr(name) }.to_str() {
85 Ok(s) => s,
86 Err(_) => return -3, };
88
89 let rt = get_runtime();
92 let _ = rt.block_on(async { registry.read().await });
93
94 eprintln!("Handler '{}' registered via FFI", name_str);
95 0
96}
97
98#[no_mangle]
106pub unsafe extern "C" fn pforge_execute_handler(
107 handler_name: *const c_char,
108 input_json: *const u8,
109 input_len: usize,
110) -> FfiResult {
111 if handler_name.is_null() || input_json.is_null() {
113 return FfiResult {
114 code: -1,
115 data: std::ptr::null_mut(),
116 data_len: 0,
117 error: create_error_string("Null pointer provided"),
118 };
119 }
120
121 let name = match unsafe { CStr::from_ptr(handler_name) }.to_str() {
124 Ok(s) => s,
125 Err(_) => {
126 return FfiResult {
127 code: -2,
128 data: std::ptr::null_mut(),
129 data_len: 0,
130 error: create_error_string("Invalid UTF-8 in handler name"),
131 }
132 }
133 };
134
135 let input = unsafe { slice::from_raw_parts(input_json, input_len) };
138
139 if let Some(registry) = GLOBAL_REGISTRY.get() {
141 let rt = get_runtime();
142 let result = rt.block_on(async {
143 let reg = registry.read().await;
144 reg.dispatch(name, input).await
145 });
146
147 match result {
148 Ok(output) => {
149 let mut boxed = output.into_boxed_slice();
150 let data_ptr = boxed.as_mut_ptr();
151 let data_len = boxed.len();
152 #[allow(clippy::mem_forget)]
154 std::mem::forget(boxed);
155
156 return FfiResult {
157 code: 0,
158 data: data_ptr,
159 data_len,
160 error: std::ptr::null(),
161 };
162 }
163 Err(e) => {
164 let err_str = e.to_string();
166 if err_str.contains("not found") || err_str.contains("ToolNotFound") {
167 } else {
169 return FfiResult {
170 code: -4,
171 data: std::ptr::null_mut(),
172 data_len: 0,
173 error: create_error_string(&format!("Handler error: {}", e)),
174 };
175 }
176 }
177 }
178 }
179
180 let response = serde_json::json!({
182 "handler": name,
183 "input_size": input_len,
184 "status": "ok",
185 "note": "No global registry - using echo fallback"
186 });
187
188 match serde_json::to_vec(&response) {
189 Ok(data) => {
190 let mut boxed = data.into_boxed_slice();
191 let data_ptr = boxed.as_mut_ptr();
192 let data_len = boxed.len();
193 #[allow(clippy::mem_forget)]
195 std::mem::forget(boxed);
196
197 FfiResult {
198 code: 0,
199 data: data_ptr,
200 data_len,
201 error: std::ptr::null(),
202 }
203 }
204 Err(e) => FfiResult {
205 code: -3,
206 data: std::ptr::null_mut(),
207 data_len: 0,
208 error: create_error_string(&format!("Serialization error: {}", e)),
209 },
210 }
211}
212
213#[no_mangle]
219pub unsafe extern "C" fn pforge_free_result(result: FfiResult) {
220 if !result.data.is_null() && result.data_len > 0 {
221 let _ = unsafe { Vec::from_raw_parts(result.data, result.data_len, result.data_len) };
223 }
224 if !result.error.is_null() {
225 let _ = unsafe { CString::from_raw(result.error as *mut c_char) };
227 }
228}
229
230#[no_mangle]
235pub unsafe extern "C" fn pforge_version() -> *const c_char {
236 static VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "\0");
237 VERSION.as_ptr() as *const c_char
238}
239
240#[no_mangle]
245pub extern "C" fn pforge_is_initialized() -> c_int {
246 if GLOBAL_REGISTRY.get().is_some() {
247 1
248 } else {
249 0
250 }
251}
252
/// Copies `msg` into a newly allocated NUL-terminated C string and leaks it to the caller.
///
/// Returns null when `msg` contains an interior NUL byte, which a C string cannot
/// represent. A non-null result must be reclaimed with `CString::from_raw`.
fn create_error_string(msg: &str) -> *const c_char {
    CString::new(msg)
        .map(|owned| owned.into_raw() as *const c_char)
        .unwrap_or(std::ptr::null())
}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version() {
        unsafe {
            let version = pforge_version();
            assert!(!version.is_null());
            let version_str = CStr::from_ptr(version).to_str().unwrap();
            // Compare against the crate's own version rather than a hard-coded prefix so
            // the test keeps passing across version bumps.
            assert_eq!(version_str, env!("CARGO_PKG_VERSION"));
        }
    }

    #[test]
    fn test_init() {
        // Tests share one process and run in parallel, so another test may have installed
        // the registry already (-1) or concurrently between check and set (-2).
        let result = unsafe { pforge_init() };
        assert!(
            matches!(result, 0 | -1 | -2),
            "unexpected pforge_init result {}",
            result
        );
    }

    #[test]
    fn test_is_initialized() {
        let status = pforge_is_initialized();
        assert!(
            status == 0 || status == 1,
            "Expected 0 or 1, got {}",
            status
        );

        if GLOBAL_REGISTRY.get().is_some() {
            assert_eq!(
                pforge_is_initialized(),
                1,
                "Should return 1 when initialized"
            );
        }
    }

    #[test]
    fn test_is_initialized_returns_one_after_init() {
        let _ = unsafe { pforge_init() };
        assert_eq!(pforge_is_initialized(), 1);
    }

    #[test]
    fn test_create_error_string() {
        let msg = "Test error message";
        let ptr = create_error_string(msg);
        assert!(
            !ptr.is_null(),
            "create_error_string should return non-null pointer"
        );

        unsafe {
            let c_str = CStr::from_ptr(ptr);
            let str_slice = c_str.to_str().unwrap();
            assert_eq!(str_slice, msg);
            // Reclaim the leaked CString so the test does not leak.
            let _ = CString::from_raw(ptr as *mut c_char);
        }
    }

    #[test]
    fn test_create_error_string_with_null_byte() {
        // Interior NUL bytes cannot be represented in a C string.
        let msg = "Error\0with null";
        let ptr = create_error_string(msg);
        assert!(
            ptr.is_null(),
            "Should return null for string with embedded null byte"
        );
    }

    #[test]
    fn test_execute_handler_null_safety() {
        unsafe {
            let result = pforge_execute_handler(std::ptr::null(), std::ptr::null(), 0);
            assert_eq!(result.code, -1);
            pforge_free_result(result);
        }
    }

    #[test]
    fn test_execute_handler_fallback() {
        unsafe {
            let handler_name = CString::new("test_handler").unwrap();
            let input = b"{}";

            let result = pforge_execute_handler(handler_name.as_ptr(), input.as_ptr(), input.len());

            assert_eq!(result.code, 0);
            assert!(!result.data.is_null());
            assert!(result.data_len > 0);

            let data_slice = slice::from_raw_parts(result.data, result.data_len);
            let response: serde_json::Value = serde_json::from_slice(data_slice).unwrap();
            assert_eq!(response["handler"], "test_handler");
            assert_eq!(response["status"], "ok");

            pforge_free_result(result);
        }
    }

    #[test]
    fn test_register_handler_null_name() {
        unsafe {
            let result = pforge_register_handler(std::ptr::null(), std::ptr::null_mut());
            assert_eq!(result, -1);
        }
    }

    #[test]
    fn test_free_result_accepts_empty_result() {
        // Null data/error pointers must be skipped rather than dereferenced.
        let empty = FfiResult {
            code: 0,
            data: std::ptr::null_mut(),
            data_len: 0,
            error: std::ptr::null(),
        };
        unsafe { pforge_free_result(empty) };
    }
}