use wasm_bindgen::prelude::*;
use serde::{Serialize, Deserialize};
use serde_json::Value;
use reqwest::Client;
use futures::StreamExt;
use serde_json::json;
/// Module start hook run by wasm-bindgen when the Wasm module is instantiated.
#[wasm_bindgen(start)]
pub fn init() {
    // Only when the `console_error_panic_hook` feature is compiled in: route Rust
    // panic messages to the browser console via that crate's hook. Without the
    // feature, this function body is empty.
    #[cfg(feature = "console_error_panic_hook")]
    {
        console_error_panic_hook::set_once();
    }
}
/// Client configuration passed from JavaScript to `RathCore::new`.
///
/// Deserialized with `serde_wasm_bindgen`. Field names are used verbatim as
/// object keys (no renaming attribute is applied). Missing `Option` fields
/// become `None`.
#[derive(Serialize, Deserialize)]
pub struct LLMConfig {
    // API credential; sent in the request headers of every completion call.
    api_key: String,
    // Model identifier written into the request body as `model`. Its prefix
    // (`claude-`, `gemini-`) also selects the default endpoint and auth style.
    model: String,
    // Overrides the default API root; `/chat/completions` is appended to it.
    base_url: Option<String>,
    // NOTE(review): not read anywhere in this file.
    provider: Option<String>,
    // NOTE(review): not read anywhere in this file.
    api_version: Option<String>,
    // NOTE(review): not read anywhere in this file.
    organization: Option<String>,
    // NOTE(review): not read anywhere in this file; no retry logic exists here.
    max_retries: Option<i32>,
    // NOTE(review): not read anywhere in this file; unit unknown (seconds? ms?) — confirm.
    timeout: Option<f64>,
    // Enables console logging of request progress; treated as `false` when absent.
    debug: Option<bool>,
    // When set, added to the request body as `max_tokens`.
    max_tokens: Option<i32>,
    // Initial capacity (bytes) for buffering a streamed response; 16384 when absent.
    buffer_size: Option<usize>,
}
/// JS-facing handle that issues chat-completion requests for one [`LLMConfig`].
#[wasm_bindgen]
pub struct RathCore {
    /// Copy of `config.debug` (defaulting to `false`) that gates console logging.
    debug: bool,
    /// HTTP client shared by every request sent through this handle.
    client: Client,
    /// Configuration deserialized by the constructor.
    config: LLMConfig,
}
#[wasm_bindgen]
impl RathCore {
    /// Builds a `RathCore` from a JavaScript configuration object.
    ///
    /// `config_js` is deserialized into [`LLMConfig`]; a missing `debug` flag
    /// is treated as `false`.
    ///
    /// # Errors
    /// Returns the `serde_wasm_bindgen` deserialization error (as a `JsValue`)
    /// when `config_js` does not match the `LLMConfig` shape.
    #[wasm_bindgen(constructor)]
    pub fn new(config_js: JsValue) -> Result<RathCore, JsValue> {
        let config: LLMConfig = serde_wasm_bindgen::from_value(config_js)?;
        let debug = config.debug.unwrap_or(false);
        Ok(RathCore {
            config,
            client: Client::new(),
            debug,
        })
    }

    /// Sends an OpenAI-style chat-completion request and returns the response body.
    ///
    /// The JSON body is `{ model, messages, stream, max_tokens? }`. The request
    /// goes to `<base_url>/chat/completions`. `base_url` comes from the config
    /// when set; otherwise it is chosen from the model-name prefix.
    ///
    /// * `messages` — a JS array deserialized into a `Vec` of arbitrary JSON values;
    ///   each element is forwarded unchanged.
    /// * `stream` — sets the `stream` flag in the body (defaults to `false`).
    ///
    /// # Returns
    /// * non-streaming: the response body as a JS string.
    /// * streaming: the raw response bytes as a `Uint8Array`. The whole stream
    ///   is read into memory first, so the caller does NOT receive chunks
    ///   incrementally.
    ///
    /// # Errors
    /// Returns a string `JsValue` when:
    /// * `messages` cannot be deserialized;
    /// * the request fails to send;
    /// * the server answers with a non-2xx status (`"HTTP error <status>: <body>"`);
    /// * the body or stream cannot be read.
    pub async fn completion(&self, messages: JsValue, stream: Option<bool>) -> Result<JsValue, JsValue> {
        if self.debug {
            web_sys::console::log_1(&"Starting completion request".into());
        }

        let messages: Vec<Value> = serde_wasm_bindgen::from_value(messages)?;
        let stream = stream.unwrap_or(false);

        let mut request_body = json!({
            "model": self.config.model,
            "messages": messages,
            "stream": stream,
        });
        if let Some(max_tokens) = self.config.max_tokens {
            request_body["max_tokens"] = json!(max_tokens);
        }

        let base_url = self
            .config
            .base_url
            .as_deref()
            .unwrap_or_else(|| default_base_url(&self.config.model));
        let url = chat_completions_url(base_url);

        // Every provider's OpenAI-compatible `/chat/completions` endpoint takes a
        // Bearer token. The original sent Claude keys with no auth scheme, which
        // no Anthropic endpoint accepts. Anthropic's documented native header
        // `x-api-key` is sent in addition for `claude-` models.
        let mut request = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body);
        if is_anthropic_model(&self.config.model) {
            request = request.header("x-api-key", self.config.api_key.as_str());
        }

        let response = request.send().await.map_err(to_js_err)?;

        let status = response.status();
        if !status.is_success() {
            // Keep the HTTP status in the error even if the body itself can't be
            // read, instead of replacing it with the body-read error.
            let text = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read error body: {:?}>", e));
            return Err(JsValue::from_str(&format!(
                "HTTP error {}: {}",
                status,
                text
            )));
        }

        if stream {
            let mut chunks = response.bytes_stream();
            let buffer_size = self.config.buffer_size.unwrap_or(DEFAULT_STREAM_BUFFER_SIZE);
            // `buffer_size` is only an initial capacity hint: `Vec` grows
            // geometrically on its own when chunks exceed it.
            let mut buffer: Vec<u8> = Vec::with_capacity(buffer_size);
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk.map_err(to_js_err)?;
                buffer.extend_from_slice(&chunk);
            }
            if self.debug {
                web_sys::console::log_1(
                    &format!(
                        "Total bytes received: {}, Buffer size: {}",
                        buffer.len(),
                        buffer_size
                    )
                    .into(),
                );
            }
            Ok(js_sys::Uint8Array::from(&buffer[..]).into())
        } else {
            let text = response.text().await.map_err(to_js_err)?;
            if self.debug {
                web_sys::console::log_1(&format!("Raw response: {}", text).into());
            }
            Ok(JsValue::from_str(&text))
        }
    }
}

/// Initial capacity (bytes) for a buffered streaming response when
/// `LLMConfig::buffer_size` is not set.
const DEFAULT_STREAM_BUFFER_SIZE: usize = 16384;

/// Converts any `Debug`-printable error into a string `JsValue` for the JS caller.
fn to_js_err<E: std::fmt::Debug>(err: E) -> JsValue {
    JsValue::from_str(&format!("{:?}", err))
}

/// Returns `true` for model names routed to Anthropic (the `claude-` prefix).
fn is_anthropic_model(model: &str) -> bool {
    model.starts_with("claude-")
}

/// Picks a default OpenAI-compatible API root from the model-name prefix.
fn default_base_url(model: &str) -> &'static str {
    if is_anthropic_model(model) {
        "https://api.anthropic.com/v1"
    } else if model.starts_with("gemini-") {
        // Gemini's OpenAI-compatibility root; the previous `/v1/models` root has
        // no `chat/completions` route.
        "https://generativelanguage.googleapis.com/v1beta/openai"
    } else {
        "https://api.openai.com/v1"
    }
}

/// Appends the `chat/completions` path to `base_url`, dropping any trailing
/// slashes so a user-supplied `".../v1/"` does not produce `//`.
fn chat_completions_url(base_url: &str) -> String {
    format!("{}/chat/completions", base_url.trim_end_matches('/'))
}