1use leptos::prelude::*;
4use cortexai_llm_client::{Message, Provider, RequestBuilder, ResponseParser};
5use wasm_bindgen::prelude::*;
6use wasm_bindgen_futures::JsFuture;
7use web_sys::{Request, RequestInit, RequestMode, Response};
8
9use crate::{ChatMessage, ChatOptions, CompletionOptions, LeptosAiError, Result};
10
11#[derive(Clone, Copy)]
13pub struct UseChat {
14 pub messages: RwSignal<Vec<ChatMessage>>,
16 pub is_loading: RwSignal<bool>,
18 pub error: RwSignal<Option<String>>,
20 pub streaming_content: RwSignal<String>,
22 options: RwSignal<ChatOptions>,
24 should_stop: RwSignal<bool>,
26}
27
28impl UseChat {
29 pub fn send(&self, message: &str) {
31 let opts = self.options.get_untracked();
32 let messages_signal = self.messages;
33 let is_loading = self.is_loading;
34 let error = self.error;
35 let streaming_content = self.streaming_content;
36 let should_stop = self.should_stop;
37
38 let user_msg = ChatMessage::user(message);
40 messages_signal.update(|msgs| msgs.push(user_msg));
41
42 is_loading.set(true);
44 error.set(None);
45 streaming_content.set(String::new());
46 should_stop.set(false);
47
48 let current_messages = messages_signal.get_untracked();
49
50 leptos::task::spawn_local(async move {
52 let result = if opts.stream {
53 send_streaming_request(&opts, current_messages, streaming_content, should_stop)
54 .await
55 } else {
56 send_request(&opts, current_messages).await
57 };
58
59 match result {
60 Ok(content) => {
61 let assistant_msg = ChatMessage::assistant(content);
62 messages_signal.update(|msgs| msgs.push(assistant_msg));
63 }
64 Err(e) => {
65 error.set(Some(e.to_string()));
66 }
67 }
68
69 is_loading.set(false);
70 streaming_content.set(String::new());
71 });
72 }
73
74 pub fn clear(&self) {
76 self.messages.set(Vec::new());
77 self.error.set(None);
78 self.streaming_content.set(String::new());
79 }
80
81 pub fn stop(&self) {
83 self.should_stop.set(true);
84 }
85}
86
87pub fn use_chat(options: ChatOptions) -> UseChat {
106 let messages = RwSignal::new(options.initial_messages.clone());
107 let is_loading = RwSignal::new(false);
108 let error = RwSignal::new(None::<String>);
109 let streaming_content = RwSignal::new(String::new());
110 let should_stop = RwSignal::new(false);
111 let options_signal = RwSignal::new(options);
112
113 UseChat {
114 messages,
115 is_loading,
116 error,
117 streaming_content,
118 options: options_signal,
119 should_stop,
120 }
121}
122
123#[derive(Clone, Copy)]
125pub struct UseCompletion {
126 pub completion: RwSignal<Option<String>>,
128 pub is_loading: RwSignal<bool>,
130 pub error: RwSignal<Option<String>>,
132 options: RwSignal<CompletionOptions>,
134}
135
136impl UseCompletion {
137 pub fn complete(&self, prompt: &str) {
139 let opts = self.options.get_untracked();
140 let completion = self.completion;
141 let is_loading = self.is_loading;
142 let error = self.error;
143 let prompt = prompt.to_string();
144
145 is_loading.set(true);
146 error.set(None);
147 completion.set(None);
148
149 leptos::task::spawn_local(async move {
150 let messages = vec![ChatMessage::user(&prompt)];
151
152 let chat_opts = ChatOptions {
153 provider: opts.provider,
154 api_key: opts.api_key,
155 model: opts.model,
156 system_prompt: opts.system_prompt,
157 temperature: opts.temperature,
158 max_tokens: opts.max_tokens,
159 stream: false,
160 initial_messages: Vec::new(),
161 };
162
163 match send_request(&chat_opts, messages).await {
164 Ok(content) => {
165 completion.set(Some(content));
166 }
167 Err(e) => {
168 error.set(Some(e.to_string()));
169 }
170 }
171
172 is_loading.set(false);
173 });
174 }
175}
176
177pub fn use_completion(options: CompletionOptions) -> UseCompletion {
198 let completion = RwSignal::new(None::<String>);
199 let is_loading = RwSignal::new(false);
200 let error = RwSignal::new(None::<String>);
201 let options_signal = RwSignal::new(options);
202
203 UseCompletion {
204 completion,
205 is_loading,
206 error,
207 options: options_signal,
208 }
209}
210
211async fn send_request(options: &ChatOptions, messages: Vec<ChatMessage>) -> Result<String> {
213 let provider: Provider = options
214 .provider
215 .parse()
216 .map_err(|_| LeptosAiError::InvalidProvider(options.provider.clone()))?;
217
218 let mut llm_messages: Vec<Message> = Vec::new();
220
221 if let Some(ref system) = options.system_prompt {
222 llm_messages.push(Message::system(system));
223 }
224
225 for msg in &messages {
226 match msg.role.as_str() {
227 "user" => llm_messages.push(Message::user(&msg.content)),
228 "assistant" => llm_messages.push(Message::assistant(&msg.content)),
229 "system" => llm_messages.push(Message::system(&msg.content)),
230 _ => {}
231 }
232 }
233
234 let http_request = RequestBuilder::new(provider)
236 .model(&options.model)
237 .api_key(&options.api_key)
238 .messages(&llm_messages)
239 .temperature(options.temperature)
240 .max_tokens(options.max_tokens)
241 .stream(false)
242 .build()?;
243
244 let response = fetch(&http_request.url, &http_request.headers, &http_request.body).await?;
246
247 let llm_response = ResponseParser::parse(provider, &response)?;
249
250 Ok(llm_response.content)
251}
252
253async fn send_streaming_request(
255 options: &ChatOptions,
256 messages: Vec<ChatMessage>,
257 streaming_content: RwSignal<String>,
258 should_stop: RwSignal<bool>,
259) -> Result<String> {
260 let provider: Provider = options
261 .provider
262 .parse()
263 .map_err(|_| LeptosAiError::InvalidProvider(options.provider.clone()))?;
264
265 let mut llm_messages: Vec<Message> = Vec::new();
267
268 if let Some(ref system) = options.system_prompt {
269 llm_messages.push(Message::system(system));
270 }
271
272 for msg in &messages {
273 match msg.role.as_str() {
274 "user" => llm_messages.push(Message::user(&msg.content)),
275 "assistant" => llm_messages.push(Message::assistant(&msg.content)),
276 "system" => llm_messages.push(Message::system(&msg.content)),
277 _ => {}
278 }
279 }
280
281 let http_request = RequestBuilder::new(provider)
283 .model(&options.model)
284 .api_key(&options.api_key)
285 .messages(&llm_messages)
286 .temperature(options.temperature)
287 .max_tokens(options.max_tokens)
288 .stream(true)
289 .build()?;
290
291 let response =
293 fetch_stream(&http_request.url, &http_request.headers, &http_request.body).await?;
294
295 let mut full_content = String::new();
297 let reader = response
298 .body()
299 .ok_or_else(|| LeptosAiError::StreamError("No response body".to_string()))?
300 .get_reader();
301
302 let reader: web_sys::ReadableStreamDefaultReader = reader.unchecked_into();
303
304 loop {
305 if should_stop.get_untracked() {
306 break;
307 }
308
309 let result = JsFuture::from(reader.read()).await;
310 let result = result.map_err(|e| LeptosAiError::StreamError(format!("{:?}", e)))?;
311
312 let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
313 .map_err(|e| LeptosAiError::StreamError(format!("{:?}", e)))?
314 .as_bool()
315 .unwrap_or(true);
316
317 if done {
318 break;
319 }
320
321 let value = js_sys::Reflect::get(&result, &JsValue::from_str("value"))
322 .map_err(|e| LeptosAiError::StreamError(format!("{:?}", e)))?;
323
324 let array = js_sys::Uint8Array::new(&value);
325 let bytes = array.to_vec();
326 let text = String::from_utf8_lossy(&bytes);
327
328 for line in text.lines() {
330 if let Ok(Some(chunk)) = ResponseParser::parse_stream_line(provider, line) {
331 if let Some(content) = chunk.content {
332 full_content.push_str(&content);
333 streaming_content.set(full_content.clone());
334 }
335 if chunk.done {
336 break;
337 }
338 }
339 }
340 }
341
342 Ok(full_content)
343}
344
345async fn fetch(url: &str, headers: &[(String, String)], body: &str) -> Result<String> {
347 let opts = RequestInit::new();
348 opts.set_method("POST");
349 opts.set_mode(RequestMode::Cors);
350 opts.set_body(&JsValue::from_str(body));
351
352 let js_headers =
353 web_sys::Headers::new().map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
354
355 for (key, value) in headers {
356 js_headers
357 .set(key, value)
358 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
359 }
360 opts.set_headers(&js_headers);
361
362 let request = Request::new_with_str_and_init(url, &opts)
363 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
364
365 let window =
366 web_sys::window().ok_or_else(|| LeptosAiError::RequestFailed("No window".to_string()))?;
367
368 let resp_value = JsFuture::from(window.fetch_with_request(&request))
369 .await
370 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
371
372 let resp: Response = resp_value
373 .dyn_into()
374 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
375
376 if !resp.ok() {
377 let status = resp.status();
378 let text = JsFuture::from(
379 resp.text()
380 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?,
381 )
382 .await
383 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?
384 .as_string()
385 .unwrap_or_default();
386 return Err(LeptosAiError::ApiError(format!(
387 "HTTP {}: {}",
388 status, text
389 )));
390 }
391
392 let text = JsFuture::from(
393 resp.text()
394 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?,
395 )
396 .await
397 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?
398 .as_string()
399 .unwrap_or_default();
400
401 Ok(text)
402}
403
404async fn fetch_stream(url: &str, headers: &[(String, String)], body: &str) -> Result<Response> {
406 let opts = RequestInit::new();
407 opts.set_method("POST");
408 opts.set_mode(RequestMode::Cors);
409 opts.set_body(&JsValue::from_str(body));
410
411 let js_headers =
412 web_sys::Headers::new().map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
413
414 for (key, value) in headers {
415 js_headers
416 .set(key, value)
417 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
418 }
419 opts.set_headers(&js_headers);
420
421 let request = Request::new_with_str_and_init(url, &opts)
422 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
423
424 let window =
425 web_sys::window().ok_or_else(|| LeptosAiError::RequestFailed("No window".to_string()))?;
426
427 let resp_value = JsFuture::from(window.fetch_with_request(&request))
428 .await
429 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
430
431 let resp: Response = resp_value
432 .dyn_into()
433 .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
434
435 if !resp.ok() {
436 let status = resp.status();
437 return Err(LeptosAiError::ApiError(format!("HTTP {}", status)));
438 }
439
440 Ok(resp)
441}