Skip to main content

simple_agents_ffi/
lib.rs

1//! C-compatible FFI bindings for SimpleAgents.
2
3use simple_agent_type::message::Message;
4use simple_agent_type::prelude::{ApiKey, CompletionRequest, Provider, Result, SimpleAgentsError};
5use simple_agents_core::{
6    CompletionOptions, CompletionOutcome, SimpleAgentsClient, SimpleAgentsClientBuilder,
7};
8use simple_agents_providers::anthropic::AnthropicProvider;
9use simple_agents_providers::openai::OpenAIProvider;
10use simple_agents_providers::openrouter::OpenRouterProvider;
11use std::cell::RefCell;
12use std::ffi::{CStr, CString};
13use std::os::raw::c_char;
14use std::panic::{catch_unwind, AssertUnwindSafe};
15use std::sync::{Arc, Mutex};
16
/// Tokio runtime owned by each FFI client; `sa_complete` uses it to run the
/// async completion call to completion from a synchronous C entry point.
type Runtime = tokio::runtime::Runtime;
18
/// Rust-side state stored behind an opaque `SAClient` handle.
///
/// Fields are dropped in declaration order, so the runtime is torn down before
/// the client when the handle is freed.
struct FfiClient {
    /// Runtime used to block on async client calls. It sits behind a `Mutex`,
    /// so calls that lock it on the same handle run one at a time.
    runtime: Mutex<Runtime>,
    /// Provider-backed client that performs the completion requests.
    client: SimpleAgentsClient,
}
23
/// Opaque client handle exposed to C.
///
/// C code only holds a `*mut SAClient` obtained from `sa_client_new_from_env`
/// and releases it with `sa_client_free`. The `inner` field is private and is
/// not a C-compatible type, so C code must never access it directly.
#[repr(C)]
pub struct SAClient {
    inner: FfiClient,
}
28
thread_local! {
    /// Most recent error message recorded by an FFI call on this thread.
    /// `None` means no error is pending; each thread has its own slot.
    static LAST_ERROR: RefCell<Option<String>> = const { RefCell::new(None) };
}
32
33fn set_last_error(message: impl Into<String>) {
34    LAST_ERROR.with(|slot| {
35        *slot.borrow_mut() = Some(message.into());
36    });
37}
38
39fn clear_last_error() {
40    LAST_ERROR.with(|slot| {
41        *slot.borrow_mut() = None;
42    });
43}
44
45fn take_last_error() -> Option<String> {
46    LAST_ERROR.with(|slot| slot.borrow_mut().take())
47}
48
49fn build_runtime() -> Result<Runtime> {
50    Runtime::new().map_err(|e| SimpleAgentsError::Config(format!("Failed to build runtime: {e}")))
51}
52
53fn provider_from_env(provider_name: &str) -> Result<Arc<dyn Provider>> {
54    match provider_name {
55        "openai" => Ok(Arc::new(OpenAIProvider::from_env()?)),
56        "anthropic" => Ok(Arc::new(AnthropicProvider::from_env()?)),
57        "openrouter" => Ok(Arc::new(openrouter_from_env()?)),
58        _ => Err(SimpleAgentsError::Config(format!(
59            "Unknown provider '{provider_name}'"
60        ))),
61    }
62}
63
64fn openrouter_from_env() -> Result<OpenRouterProvider> {
65    let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
66        SimpleAgentsError::Config("OPENROUTER_API_KEY environment variable is required".to_string())
67    })?;
68    let api_key = ApiKey::new(api_key)?;
69    let base_url = std::env::var("OPENROUTER_API_BASE")
70        .unwrap_or_else(|_| OpenRouterProvider::DEFAULT_BASE_URL.to_string());
71    OpenRouterProvider::with_base_url(api_key, base_url)
72}
73
74unsafe fn cstr_to_string(ptr: *const c_char, field: &str) -> Result<String> {
75    if ptr.is_null() {
76        return Err(SimpleAgentsError::Config(format!("{field} cannot be null")));
77    }
78
79    let c_str = CStr::from_ptr(ptr);
80    let value = c_str
81        .to_str()
82        .map_err(|_| SimpleAgentsError::Config(format!("{field} must be valid UTF-8")))?;
83    if value.is_empty() {
84        return Err(SimpleAgentsError::Config(format!(
85            "{field} cannot be empty"
86        )));
87    }
88
89    Ok(value.to_string())
90}
91
92fn build_client(provider: Arc<dyn Provider>) -> Result<SimpleAgentsClient> {
93    SimpleAgentsClientBuilder::new()
94        .with_provider(provider)
95        .build()
96}
97
98fn build_request(
99    model: &str,
100    prompt: &str,
101    max_tokens: i32,
102    temperature: f32,
103) -> Result<CompletionRequest> {
104    let mut builder = CompletionRequest::builder()
105        .model(model)
106        .message(Message::user(prompt));
107
108    if max_tokens > 0 {
109        builder = builder.max_tokens(max_tokens as u32);
110    }
111
112    if temperature >= 0.0 {
113        builder = builder.temperature(temperature);
114    }
115
116    builder.build()
117}
118
119fn ffi_result_string(result: Result<String>) -> *mut c_char {
120    match result {
121        Ok(value) => match CString::new(value) {
122            Ok(c_string) => {
123                clear_last_error();
124                c_string.into_raw()
125            }
126            Err(_) => {
127                set_last_error("Response contained an interior null byte".to_string());
128                std::ptr::null_mut()
129            }
130        },
131        Err(error) => {
132            set_last_error(error.to_string());
133            std::ptr::null_mut()
134        }
135    }
136}
137
138fn ffi_guard<T>(action: impl FnOnce() -> Result<T>) -> *mut c_char
139where
140    T: Into<String>,
141{
142    let result = catch_unwind(AssertUnwindSafe(action));
143    match result {
144        Ok(inner) => ffi_result_string(inner.map(Into::into)),
145        Err(_) => {
146            set_last_error("Panic occurred in FFI call".to_string());
147            std::ptr::null_mut()
148        }
149    }
150}
151
152/// Create a client from environment variables for a provider.
153///
154/// `provider_name` must be one of: "openai", "anthropic", "openrouter".
155///
156/// # Safety
157///
158/// The `provider_name` pointer must be a valid null-terminated C string or null.
159/// The returned pointer must be freed with `sa_client_free`.
160#[no_mangle]
161pub unsafe extern "C" fn sa_client_new_from_env(provider_name: *const c_char) -> *mut SAClient {
162    let result = catch_unwind(AssertUnwindSafe(|| -> Result<Box<SAClient>> {
163        let provider = cstr_to_string(provider_name, "provider_name")?;
164        let provider = provider_from_env(&provider)?;
165        let client = build_client(provider)?;
166        let runtime = build_runtime()?;
167
168        Ok(Box::new(SAClient {
169            inner: FfiClient {
170                runtime: Mutex::new(runtime),
171                client,
172            },
173        }))
174    }));
175
176    match result {
177        Ok(Ok(client)) => {
178            clear_last_error();
179            Box::into_raw(client)
180        }
181        Ok(Err(error)) => {
182            set_last_error(error.to_string());
183            std::ptr::null_mut()
184        }
185        Err(_) => {
186            set_last_error("Panic occurred in sa_client_new_from_env".to_string());
187            std::ptr::null_mut()
188        }
189    }
190}
191
192/// Free a client created by `sa_client_new_from_env`.
193///
194/// # Safety
195///
196/// The `client` pointer must be null or a valid pointer returned by `sa_client_new_from_env`.
197/// After calling this function, the pointer is no longer valid and must not be used.
198#[no_mangle]
199pub unsafe extern "C" fn sa_client_free(client: *mut SAClient) {
200    if client.is_null() {
201        return;
202    }
203
204    drop(Box::from_raw(client));
205}
206
207/// Execute a completion request with a single user prompt.
208///
209/// Use `max_tokens <= 0` to omit, and `temperature < 0.0` to omit.
210///
211/// # Safety
212///
213/// The `client` pointer must be a valid pointer returned by `sa_client_new_from_env`.
214/// The `model` and `prompt` pointers must be valid null-terminated C strings.
215/// The returned pointer must be freed with `sa_string_free`.
216#[no_mangle]
217pub unsafe extern "C" fn sa_complete(
218    client: *mut SAClient,
219    model: *const c_char,
220    prompt: *const c_char,
221    max_tokens: i32,
222    temperature: f32,
223) -> *mut c_char {
224    if client.is_null() {
225        set_last_error("client cannot be null".to_string());
226        return std::ptr::null_mut();
227    }
228
229    ffi_guard(|| {
230        let model = cstr_to_string(model, "model")?;
231        let prompt = cstr_to_string(prompt, "prompt")?;
232        let request = build_request(&model, &prompt, max_tokens, temperature)?;
233
234        let client = &(*client).inner;
235        let runtime = client
236            .runtime
237            .lock()
238            .map_err(|_| SimpleAgentsError::Config("runtime lock poisoned".to_string()))?;
239        let outcome = runtime.block_on(
240            client
241                .client
242                .complete(&request, CompletionOptions::default()),
243        )?;
244        let response = match outcome {
245            CompletionOutcome::Response(response) => response,
246            CompletionOutcome::Stream(_) => {
247                return Err(SimpleAgentsError::Config(
248                    "streaming response returned from complete".to_string(),
249                ))
250            }
251            CompletionOutcome::HealedJson(_) => {
252                return Err(SimpleAgentsError::Config(
253                    "healed json response returned from complete".to_string(),
254                ))
255            }
256            CompletionOutcome::CoercedSchema(_) => {
257                return Err(SimpleAgentsError::Config(
258                    "schema response returned from complete".to_string(),
259                ))
260            }
261        };
262
263        Ok(response.content().unwrap_or_default().to_string())
264    })
265}
266
267/// Get the last error message for the current thread.
268///
269/// Returns null if there is no error. Caller must free the string.
270#[no_mangle]
271pub extern "C" fn sa_last_error_message() -> *mut c_char {
272    match take_last_error() {
273        Some(message) => match CString::new(message) {
274            Ok(c_string) => c_string.into_raw(),
275            Err(_) => std::ptr::null_mut(),
276        },
277        None => std::ptr::null_mut(),
278    }
279}
280
/// Release a string previously returned by a SimpleAgents FFI function.
///
/// A null pointer is accepted and ignored.
///
/// # Safety
///
/// `value` must be null or a pointer returned by a SimpleAgents FFI function
/// that has not already been freed. After this call the pointer is dangling
/// and must not be used.
#[no_mangle]
pub unsafe extern "C" fn sa_string_free(value: *mut c_char) {
    if !value.is_null() {
        // SAFETY: `value` is non-null and, per the contract above, was produced
        // by `CString::into_raw` in this crate and has not been freed yet.
        let owned = unsafe { CString::from_raw(value) };
        drop(owned);
    }
}