Skip to main content

cortexai_wasm/
lib.rs

1//! # Cortex WebAssembly
2//!
3//! This crate provides two complementary capabilities:
4//!
5//! ## Browser feature (default)
6//! WebAssembly bindings for running AI agents in the browser.
7//!
8//! ## Sandbox feature
9//! WASM sandbox for executing untrusted tool code on the host using wasmtime.
10
11// Browser-target modules (wasm-bindgen based)
12#[cfg(feature = "browser")]
13mod error;
14#[cfg(feature = "browser")]
15mod provider;
16#[cfg(feature = "browser")]
17mod streaming;
18#[cfg(feature = "browser")]
19mod types;
20
21#[cfg(feature = "browser")]
22pub use error::*;
23#[cfg(feature = "browser")]
24pub use provider::*;
25#[cfg(feature = "browser")]
26pub use streaming::*;
27#[cfg(feature = "browser")]
28pub use types::*;
29
30// Sandbox modules (wasmtime-based, host-only)
31#[cfg(feature = "sandbox")]
32pub mod sandbox;
33#[cfg(feature = "sandbox")]
34pub mod sandbox_error;
35
36// === Browser-only code below ===
37
38#[cfg(feature = "browser")]
39mod browser {
40    use serde::{Deserialize, Serialize};
41    use wasm_bindgen::prelude::*;
42    use wasm_bindgen_futures::JsFuture;
43    use web_sys::{Request, RequestInit, RequestMode, Response};
44
45    use crate::{WasmResponse, WasmToolCall, WasmUsage};
46
    /// Components of an outgoing HTTP request, in order: `(url, headers, body)`.
    ///
    /// `headers` is a list of `(name, value)` pairs and `body` is the request payload
    /// already serialized to a string.
    type HttpRequestParts = (String, Vec<(String, String)>, String);
49
    /// Initialize the WASM module.
    ///
    /// Registered with `#[wasm_bindgen(start)]`, so it is meant to be run as the module's
    /// start function rather than being called by hand.
    #[wasm_bindgen(start)]
    pub fn init() {
        // The panic hook is only compiled in when the `console_error_panic_hook`
        // feature is enabled.
        #[cfg(feature = "console_error_panic_hook")]
        console_error_panic_hook::set_once();
        // Install the `tracing_wasm` subscriber as the global default.
        tracing_wasm::set_as_global_default();
    }
57
58    /// Agent configuration for WASM
59    #[derive(Debug, Clone, Serialize, Deserialize)]
60    #[wasm_bindgen]
61    pub struct WasmAgentConfig {
62        #[wasm_bindgen(skip)]
63        pub provider: String,
64        #[wasm_bindgen(skip)]
65        pub api_key: String,
66        #[wasm_bindgen(skip)]
67        pub model: String,
68        #[wasm_bindgen(skip)]
69        pub system_prompt: Option<String>,
70        #[wasm_bindgen(skip)]
71        pub temperature: f32,
72        #[wasm_bindgen(skip)]
73        pub max_tokens: u32,
74    }
75
76    #[wasm_bindgen]
77    impl WasmAgentConfig {
78        #[wasm_bindgen(constructor)]
79        pub fn new(provider: String, api_key: String, model: String) -> Self {
80            Self {
81                provider,
82                api_key,
83                model,
84                system_prompt: None,
85                temperature: 0.7,
86                max_tokens: 4096,
87            }
88        }
89
90        #[wasm_bindgen(setter)]
91        pub fn set_system_prompt(&mut self, prompt: String) {
92            self.system_prompt = Some(prompt);
93        }
94
95        #[wasm_bindgen(setter)]
96        pub fn set_temperature(&mut self, temp: f32) {
97            self.temperature = temp;
98        }
99
100        #[wasm_bindgen(setter)]
101        pub fn set_max_tokens(&mut self, tokens: u32) {
102            self.max_tokens = tokens;
103        }
104    }
105
    /// WASM-compatible AI Agent.
    ///
    /// Holds the provider configuration together with the conversation history that is
    /// replayed to the provider on every request.
    #[wasm_bindgen]
    pub struct WasmAgent {
        // Provider, credentials and sampling settings used to build each request.
        config: WasmAgentConfig,
        // Conversation turns in the order they were appended.
        messages: Vec<WasmMessage>,
    }
112
113    #[wasm_bindgen]
114    impl WasmAgent {
115        #[wasm_bindgen(constructor)]
116        pub fn new(config: WasmAgentConfig) -> Self {
117            Self {
118                config,
119                messages: Vec::new(),
120            }
121        }
122
123        #[wasm_bindgen(js_name = fromObject)]
124        pub fn from_object(obj: JsValue) -> Result<WasmAgent, JsValue> {
125            let config: WasmAgentConfig = serde_wasm_bindgen::from_value(obj)
126                .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
127            Ok(Self::new(config))
128        }
129
130        #[wasm_bindgen]
131        pub async fn chat(&mut self, message: &str) -> Result<JsValue, JsValue> {
132            self.messages.push(WasmMessage::user(message.to_string()));
133            let response = self.call_llm(false).await?;
134            let content = self.extract_content(&response)?;
135            self.messages.push(WasmMessage::assistant(content.clone()));
136
137            Ok(serde_wasm_bindgen::to_value(&WasmResponse {
138                content,
139                tool_calls: None,
140                usage: None,
141            })?)
142        }
143
144        #[wasm_bindgen(js_name = chatStream)]
145        pub async fn chat_stream(&mut self, message: &str) -> Result<JsValue, JsValue> {
146            self.messages.push(WasmMessage::user(message.to_string()));
147            let response = self.call_llm(true).await?;
148            let body = Response::from(response)
149                .body()
150                .ok_or_else(|| JsValue::from_str("No response body"))?;
151            Ok(body.into())
152        }
153
154        #[wasm_bindgen]
155        pub fn clear(&mut self) {
156            self.messages.clear();
157        }
158
159        #[wasm_bindgen(js_name = getHistory)]
160        pub fn get_history(&self) -> Result<JsValue, JsValue> {
161            serde_wasm_bindgen::to_value(&self.messages)
162                .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
163        }
164
165        async fn call_llm(&self, stream: bool) -> Result<JsValue, JsValue> {
166            let (url, headers, body) = self.build_request(stream)?;
167            let opts = RequestInit::new();
168            opts.set_method("POST");
169            opts.set_mode(RequestMode::Cors);
170
171            let js_headers = web_sys::Headers::new()?;
172            for (key, value) in headers {
173                js_headers.set(&key, &value)?;
174            }
175            opts.set_headers(&js_headers);
176            opts.set_body(&JsValue::from_str(&body));
177
178            let request = Request::new_with_str_and_init(&url, &opts)?;
179            let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window"))?;
180            let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
181            let resp: Response = resp_value.dyn_into()?;
182
183            if !resp.ok() {
184                let status = resp.status();
185                let text = JsFuture::from(resp.text()?).await?;
186                return Err(JsValue::from_str(&format!(
187                    "HTTP {}: {}",
188                    status,
189                    text.as_string().unwrap_or_default()
190                )));
191            }
192
193            if stream {
194                Ok(resp.into())
195            } else {
196                let json = JsFuture::from(resp.json()?).await?;
197                Ok(json)
198            }
199        }
200
201        fn build_request(&self, stream: bool) -> Result<HttpRequestParts, JsValue> {
202            match self.config.provider.as_str() {
203                "openai" => self.build_openai_request(stream),
204                "anthropic" => self.build_anthropic_request(stream),
205                "openrouter" => self.build_openrouter_request(stream),
206                _ => Err(JsValue::from_str(&format!(
207                    "Unknown provider: {}",
208                    self.config.provider
209                ))),
210            }
211        }
212
213        fn build_openai_request(&self, stream: bool) -> Result<HttpRequestParts, JsValue> {
214            let url = "https://api.openai.com/v1/chat/completions".to_string();
215            let headers = vec![
216                ("Content-Type".to_string(), "application/json".to_string()),
217                (
218                    "Authorization".to_string(),
219                    format!("Bearer {}", self.config.api_key),
220                ),
221            ];
222            let messages: Vec<serde_json::Value> = self.build_messages();
223            let body = serde_json::json!({
224                "model": self.config.model,
225                "messages": messages,
226                "temperature": self.config.temperature,
227                "max_tokens": self.config.max_tokens,
228                "stream": stream
229            });
230            Ok((url, headers, body.to_string()))
231        }
232
233        fn build_anthropic_request(&self, stream: bool) -> Result<HttpRequestParts, JsValue> {
234            let url = "https://api.anthropic.com/v1/messages".to_string();
235            let headers = vec![
236                ("Content-Type".to_string(), "application/json".to_string()),
237                ("x-api-key".to_string(), self.config.api_key.clone()),
238                ("anthropic-version".to_string(), "2023-06-01".to_string()),
239                (
240                    "anthropic-dangerous-direct-browser-access".to_string(),
241                    "true".to_string(),
242                ),
243            ];
244            let (system, messages) = self.build_anthropic_messages();
245            let mut body = serde_json::json!({
246                "model": self.config.model,
247                "messages": messages,
248                "max_tokens": self.config.max_tokens,
249                "stream": stream
250            });
251            if let Some(sys) = system {
252                body["system"] = serde_json::json!(sys);
253            }
254            Ok((url, headers, body.to_string()))
255        }
256
257        fn build_openrouter_request(&self, stream: bool) -> Result<HttpRequestParts, JsValue> {
258            let url = "https://openrouter.ai/api/v1/chat/completions".to_string();
259            let headers = vec![
260                ("Content-Type".to_string(), "application/json".to_string()),
261                (
262                    "Authorization".to_string(),
263                    format!("Bearer {}", self.config.api_key),
264                ),
265                (
266                    "HTTP-Referer".to_string(),
267                    "https://cortex.dev".to_string(),
268                ),
269            ];
270            let messages: Vec<serde_json::Value> = self.build_messages();
271            let body = serde_json::json!({
272                "model": self.config.model,
273                "messages": messages,
274                "temperature": self.config.temperature,
275                "max_tokens": self.config.max_tokens,
276                "stream": stream
277            });
278            Ok((url, headers, body.to_string()))
279        }
280
281        fn build_messages(&self) -> Vec<serde_json::Value> {
282            let mut messages = Vec::new();
283            if let Some(ref system) = self.config.system_prompt {
284                messages.push(serde_json::json!({
285                    "role": "system",
286                    "content": system
287                }));
288            }
289            for msg in &self.messages {
290                messages.push(serde_json::json!({
291                    "role": msg.role,
292                    "content": msg.content
293                }));
294            }
295            messages
296        }
297
298        fn build_anthropic_messages(&self) -> (Option<String>, Vec<serde_json::Value>) {
299            let system = self.config.system_prompt.clone();
300            let messages: Vec<serde_json::Value> = self
301                .messages
302                .iter()
303                .map(|msg| {
304                    serde_json::json!({
305                        "role": if msg.role == "user" { "user" } else { "assistant" },
306                        "content": msg.content
307                    })
308                })
309                .collect();
310            (system, messages)
311        }
312
313        fn extract_content(&self, response: &JsValue) -> Result<String, JsValue> {
314            let obj: serde_json::Value = serde_wasm_bindgen::from_value(response.clone())
315                .map_err(|e| JsValue::from_str(&format!("Parse error: {}", e)))?;
316            match self.config.provider.as_str() {
317                "openai" | "openrouter" => obj["choices"][0]["message"]["content"]
318                    .as_str()
319                    .map(|s| s.to_string())
320                    .ok_or_else(|| JsValue::from_str("No content in response")),
321                "anthropic" => obj["content"][0]["text"]
322                    .as_str()
323                    .map(|s| s.to_string())
324                    .ok_or_else(|| JsValue::from_str("No content in response")),
325                _ => Err(JsValue::from_str("Unknown provider")),
326            }
327        }
328    }
329
330    /// Message in conversation
331    #[derive(Debug, Clone, Serialize, Deserialize)]
332    #[wasm_bindgen]
333    pub struct WasmMessage {
334        #[wasm_bindgen(skip)]
335        pub role: String,
336        #[wasm_bindgen(skip)]
337        pub content: String,
338    }
339
340    #[wasm_bindgen]
341    impl WasmMessage {
342        #[wasm_bindgen(constructor)]
343        pub fn new(role: String, content: String) -> Self {
344            Self { role, content }
345        }
346
347        #[wasm_bindgen(getter)]
348        pub fn role(&self) -> String {
349            self.role.clone()
350        }
351
352        #[wasm_bindgen(getter)]
353        pub fn content(&self) -> String {
354            self.content.clone()
355        }
356
357        pub fn user(content: String) -> Self {
358            Self {
359                role: "user".to_string(),
360                content,
361            }
362        }
363
364        pub fn assistant(content: String) -> Self {
365            Self {
366                role: "assistant".to_string(),
367                content,
368            }
369        }
370    }
371
372    #[cfg(test)]
373    mod tests {
374        use super::*;
375        use wasm_bindgen_test::*;
376
377        wasm_bindgen_test_configure!(run_in_browser);
378
379        #[wasm_bindgen_test]
380        fn test_config_creation() {
381            let config = WasmAgentConfig::new(
382                "openai".to_string(),
383                "test-key".to_string(),
384                "gpt-4".to_string(),
385            );
386            assert_eq!(config.provider, "openai");
387            assert_eq!(config.model, "gpt-4");
388            assert_eq!(config.temperature, 0.7);
389        }
390
391        #[wasm_bindgen_test]
392        fn test_agent_creation() {
393            let config = WasmAgentConfig::new(
394                "openai".to_string(),
395                "test-key".to_string(),
396                "gpt-4".to_string(),
397            );
398            let agent = WasmAgent::new(config);
399            assert!(agent.messages.is_empty());
400        }
401
402        #[wasm_bindgen_test]
403        fn test_message_creation() {
404            let msg = WasmMessage::user("Hello".to_string());
405            assert_eq!(msg.role, "user");
406            assert_eq!(msg.content, "Hello");
407            let msg = WasmMessage::assistant("Hi there!".to_string());
408            assert_eq!(msg.role, "assistant");
409        }
410    }
411}
412
413#[cfg(feature = "browser")]
414pub use browser::*;