1use dioxus::prelude::*;
4use cortexai_llm_client::{Message, Provider, RequestBuilder, ResponseParser};
5use wasm_bindgen::prelude::*;
6use wasm_bindgen_futures::JsFuture;
7use web_sys::{Request, RequestInit, RequestMode, Response};
8
9use crate::{ChatMessage, ChatOptions, CompletionOptions, DioxusAiError, Result};
10
11#[derive(Clone)]
13pub struct UseChatState {
14 messages: Signal<Vec<ChatMessage>>,
15 is_loading: Signal<bool>,
16 error: Signal<Option<String>>,
17 streaming_content: Signal<String>,
18 should_stop: Signal<bool>,
19 options: ChatOptions,
20}
21
22impl UseChatState {
23 pub fn messages(&self) -> Vec<ChatMessage> {
25 (self.messages)()
26 }
27
28 pub fn is_loading(&self) -> bool {
30 (self.is_loading)()
31 }
32
33 pub fn error(&self) -> Option<String> {
35 (self.error)()
36 }
37
38 pub fn streaming_content(&self) -> String {
40 (self.streaming_content)()
41 }
42
43 pub fn send(&mut self, message: &str) {
45 let user_msg = ChatMessage::user(message);
46 let mut messages = self.messages;
47 let mut is_loading = self.is_loading;
48 let mut error = self.error;
49 let mut streaming_content = self.streaming_content;
50 let mut should_stop = self.should_stop;
51 let options = self.options.clone();
52
53 messages.write().push(user_msg);
55
56 is_loading.set(true);
58 error.set(None);
59 streaming_content.set(String::new());
60 should_stop.set(false);
61
62 let current_messages = messages();
63
64 spawn(async move {
66 let result = if options.stream {
67 send_streaming_request(&options, current_messages, streaming_content, should_stop)
68 .await
69 } else {
70 send_request(&options, current_messages).await
71 };
72
73 match result {
74 Ok(content) => {
75 let assistant_msg = ChatMessage::assistant(content);
76 messages.write().push(assistant_msg);
77 }
78 Err(e) => {
79 error.set(Some(e.to_string()));
80 }
81 }
82
83 is_loading.set(false);
84 streaming_content.set(String::new());
85 });
86 }
87
88 pub fn clear(&mut self) {
90 self.messages.write().clear();
91 self.error.set(None);
92 self.streaming_content.set(String::new());
93 }
94
95 pub fn stop(&mut self) {
97 self.should_stop.set(true);
98 }
99}
100
101pub fn use_chat(options: ChatOptions) -> UseChatState {
116 let messages = use_signal(|| options.initial_messages.clone());
117 let is_loading = use_signal(|| false);
118 let error = use_signal(|| None::<String>);
119 let streaming_content = use_signal(String::new);
120 let should_stop = use_signal(|| false);
121
122 UseChatState {
123 messages,
124 is_loading,
125 error,
126 streaming_content,
127 should_stop,
128 options,
129 }
130}
131
132#[derive(Clone)]
134pub struct UseCompletionState {
135 completion: Signal<Option<String>>,
136 is_loading: Signal<bool>,
137 error: Signal<Option<String>>,
138 options: CompletionOptions,
139}
140
141impl UseCompletionState {
142 pub fn completion(&self) -> Option<String> {
144 (self.completion)()
145 }
146
147 pub fn is_loading(&self) -> bool {
149 (self.is_loading)()
150 }
151
152 pub fn error(&self) -> Option<String> {
154 (self.error)()
155 }
156
157 pub fn complete(&mut self, prompt: &str) {
159 let mut completion = self.completion;
160 let mut is_loading = self.is_loading;
161 let mut error = self.error;
162 let options = self.options.clone();
163 let prompt = prompt.to_string();
164
165 is_loading.set(true);
166 error.set(None);
167 completion.set(None);
168
169 spawn(async move {
170 let messages = vec![ChatMessage::user(&prompt)];
171
172 let chat_opts = ChatOptions {
173 provider: options.provider,
174 api_key: options.api_key,
175 model: options.model,
176 system_prompt: options.system_prompt,
177 temperature: options.temperature,
178 max_tokens: options.max_tokens,
179 stream: false,
180 initial_messages: Vec::new(),
181 };
182
183 match send_request(&chat_opts, messages).await {
184 Ok(content) => {
185 completion.set(Some(content));
186 }
187 Err(e) => {
188 error.set(Some(e.to_string()));
189 }
190 }
191
192 is_loading.set(false);
193 });
194 }
195}
196
197pub fn use_completion(options: CompletionOptions) -> UseCompletionState {
212 let completion = use_signal(|| None::<String>);
213 let is_loading = use_signal(|| false);
214 let error = use_signal(|| None::<String>);
215
216 UseCompletionState {
217 completion,
218 is_loading,
219 error,
220 options,
221 }
222}
223
224async fn send_request(options: &ChatOptions, messages: Vec<ChatMessage>) -> Result<String> {
226 let provider: Provider = options
227 .provider
228 .parse()
229 .map_err(|_| DioxusAiError::InvalidProvider(options.provider.clone()))?;
230
231 let mut llm_messages: Vec<Message> = Vec::new();
232
233 if let Some(ref system) = options.system_prompt {
234 llm_messages.push(Message::system(system));
235 }
236
237 for msg in &messages {
238 match msg.role.as_str() {
239 "user" => llm_messages.push(Message::user(&msg.content)),
240 "assistant" => llm_messages.push(Message::assistant(&msg.content)),
241 "system" => llm_messages.push(Message::system(&msg.content)),
242 _ => {}
243 }
244 }
245
246 let http_request = RequestBuilder::new(provider)
247 .model(&options.model)
248 .api_key(&options.api_key)
249 .messages(&llm_messages)
250 .temperature(options.temperature)
251 .max_tokens(options.max_tokens)
252 .stream(false)
253 .build()?;
254
255 let response = fetch(&http_request.url, &http_request.headers, &http_request.body).await?;
256 let llm_response = ResponseParser::parse(provider, &response)?;
257
258 Ok(llm_response.content)
259}
260
261async fn send_streaming_request(
263 options: &ChatOptions,
264 messages: Vec<ChatMessage>,
265 mut streaming_content: Signal<String>,
266 should_stop: Signal<bool>,
267) -> Result<String> {
268 let provider: Provider = options
269 .provider
270 .parse()
271 .map_err(|_| DioxusAiError::InvalidProvider(options.provider.clone()))?;
272
273 let mut llm_messages: Vec<Message> = Vec::new();
274
275 if let Some(ref system) = options.system_prompt {
276 llm_messages.push(Message::system(system));
277 }
278
279 for msg in &messages {
280 match msg.role.as_str() {
281 "user" => llm_messages.push(Message::user(&msg.content)),
282 "assistant" => llm_messages.push(Message::assistant(&msg.content)),
283 "system" => llm_messages.push(Message::system(&msg.content)),
284 _ => {}
285 }
286 }
287
288 let http_request = RequestBuilder::new(provider)
289 .model(&options.model)
290 .api_key(&options.api_key)
291 .messages(&llm_messages)
292 .temperature(options.temperature)
293 .max_tokens(options.max_tokens)
294 .stream(true)
295 .build()?;
296
297 let response =
298 fetch_stream(&http_request.url, &http_request.headers, &http_request.body).await?;
299
300 let mut full_content = String::new();
301 let reader = response
302 .body()
303 .ok_or_else(|| DioxusAiError::StreamError("No response body".to_string()))?
304 .get_reader();
305
306 let reader: web_sys::ReadableStreamDefaultReader = reader.unchecked_into();
307
308 loop {
309 if should_stop() {
310 break;
311 }
312
313 let result = JsFuture::from(reader.read()).await;
314 let result = result.map_err(|e| DioxusAiError::StreamError(format!("{:?}", e)))?;
315
316 let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
317 .map_err(|e| DioxusAiError::StreamError(format!("{:?}", e)))?
318 .as_bool()
319 .unwrap_or(true);
320
321 if done {
322 break;
323 }
324
325 let value = js_sys::Reflect::get(&result, &JsValue::from_str("value"))
326 .map_err(|e| DioxusAiError::StreamError(format!("{:?}", e)))?;
327
328 let array = js_sys::Uint8Array::new(&value);
329 let bytes = array.to_vec();
330 let text = String::from_utf8_lossy(&bytes);
331
332 for line in text.lines() {
333 if let Ok(Some(chunk)) = ResponseParser::parse_stream_line(provider, line) {
334 if let Some(content) = chunk.content {
335 full_content.push_str(&content);
336 streaming_content.set(full_content.clone());
337 }
338 if chunk.done {
339 break;
340 }
341 }
342 }
343 }
344
345 Ok(full_content)
346}
347
348async fn fetch(url: &str, headers: &[(String, String)], body: &str) -> Result<String> {
349 let opts = RequestInit::new();
350 opts.set_method("POST");
351 opts.set_mode(RequestMode::Cors);
352 opts.set_body(&JsValue::from_str(body));
353
354 let js_headers =
355 web_sys::Headers::new().map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
356
357 for (key, value) in headers {
358 js_headers
359 .set(key, value)
360 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
361 }
362 opts.set_headers(&js_headers);
363
364 let request = Request::new_with_str_and_init(url, &opts)
365 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
366
367 let window =
368 web_sys::window().ok_or_else(|| DioxusAiError::RequestFailed("No window".to_string()))?;
369
370 let resp_value = JsFuture::from(window.fetch_with_request(&request))
371 .await
372 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
373
374 let resp: Response = resp_value
375 .dyn_into()
376 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
377
378 if !resp.ok() {
379 let status = resp.status();
380 let text = JsFuture::from(
381 resp.text()
382 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?,
383 )
384 .await
385 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?
386 .as_string()
387 .unwrap_or_default();
388 return Err(DioxusAiError::ApiError(format!(
389 "HTTP {}: {}",
390 status, text
391 )));
392 }
393
394 let text = JsFuture::from(
395 resp.text()
396 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?,
397 )
398 .await
399 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?
400 .as_string()
401 .unwrap_or_default();
402
403 Ok(text)
404}
405
406async fn fetch_stream(url: &str, headers: &[(String, String)], body: &str) -> Result<Response> {
407 let opts = RequestInit::new();
408 opts.set_method("POST");
409 opts.set_mode(RequestMode::Cors);
410 opts.set_body(&JsValue::from_str(body));
411
412 let js_headers =
413 web_sys::Headers::new().map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
414
415 for (key, value) in headers {
416 js_headers
417 .set(key, value)
418 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
419 }
420 opts.set_headers(&js_headers);
421
422 let request = Request::new_with_str_and_init(url, &opts)
423 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
424
425 let window =
426 web_sys::window().ok_or_else(|| DioxusAiError::RequestFailed("No window".to_string()))?;
427
428 let resp_value = JsFuture::from(window.fetch_with_request(&request))
429 .await
430 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
431
432 let resp: Response = resp_value
433 .dyn_into()
434 .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
435
436 if !resp.ok() {
437 let status = resp.status();
438 return Err(DioxusAiError::ApiError(format!("HTTP {}", status)));
439 }
440
441 Ok(resp)
442}