dioxus_ai/hooks.rs

//! Reactive hooks for Dioxus AI

use dioxus::prelude::*;
use cortexai_llm_client::{Message, Provider, RequestBuilder, ResponseParser};
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast; // needed in scope for `dyn_into` / `unchecked_into`
use wasm_bindgen_futures::JsFuture;
use web_sys::{Request, RequestInit, RequestMode, Response};

use crate::{ChatMessage, ChatOptions, CompletionOptions, DioxusAiError, Result};

/// State for the chat hook
#[derive(Clone)]
pub struct UseChatState {
    messages: Signal<Vec<ChatMessage>>,
    is_loading: Signal<bool>,
    error: Signal<Option<String>>,
    streaming_content: Signal<String>,
    should_stop: Signal<bool>,
    options: ChatOptions,
}

impl UseChatState {
    /// Get current messages
    pub fn messages(&self) -> Vec<ChatMessage> {
        (self.messages)()
    }

    /// Check if loading
    pub fn is_loading(&self) -> bool {
        (self.is_loading)()
    }

    /// Get current error
    pub fn error(&self) -> Option<String> {
        (self.error)()
    }

    /// Get streaming content
    pub fn streaming_content(&self) -> String {
        (self.streaming_content)()
    }

    /// Send a message
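    ///
    /// Appends the user message to `messages` immediately, then spawns a
    /// task; the assistant reply (or an error) lands in the signals above
    /// once the request finishes. A minimal sketch inside a component
    /// (the surrounding `rsx!` structure is illustrative only):
    ///
    /// ```rust,ignore
    /// let mut chat = use_chat(ChatOptions::default());
    /// rsx! {
    ///     button {
    ///         disabled: chat.is_loading(),
    ///         onclick: move |_| chat.send("Hello!"),
    ///         "Send"
    ///     }
    /// }
    /// ```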
    pub fn send(&mut self, message: &str) {
        let user_msg = ChatMessage::user(message);
        let mut messages = self.messages;
        let mut is_loading = self.is_loading;
        let mut error = self.error;
        let mut streaming_content = self.streaming_content;
        let mut should_stop = self.should_stop;
        let options = self.options.clone();

        // Add user message
        messages.write().push(user_msg);

        // Reset state
        is_loading.set(true);
        error.set(None);
        streaming_content.set(String::new());
        should_stop.set(false);

        let current_messages = messages();

        // Spawn async request
        spawn(async move {
            let result = if options.stream {
                send_streaming_request(&options, current_messages, streaming_content, should_stop)
                    .await
            } else {
                send_request(&options, current_messages).await
            };

            match result {
                Ok(content) => {
                    let assistant_msg = ChatMessage::assistant(content);
                    messages.write().push(assistant_msg);
                }
                Err(e) => {
                    error.set(Some(e.to_string()));
                }
            }

            is_loading.set(false);
            streaming_content.set(String::new());
        });
    }

    /// Clear all messages
    pub fn clear(&mut self) {
        self.messages.write().clear();
        self.error.set(None);
        self.streaming_content.set(String::new());
    }

    /// Stop generation
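    ///
    /// Only affects an in-flight streaming request; content already
    /// streamed is kept and pushed as the assistant message.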
    pub fn stop(&mut self) {
        self.should_stop.set(true);
    }
}

/// Create a reactive chat interface
///
/// # Example
///
/// ```rust,ignore
/// let mut chat = use_chat(ChatOptions {
///     provider: "openai".to_string(),
///     api_key: "sk-...".to_string(),
///     model: "gpt-4o-mini".to_string(),
///     ..Default::default()
/// });
///
/// chat.send("Hello!");
/// ```
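///
/// While a streaming reply is in flight, `streaming_content()` holds the
/// partial text and is cleared once the finished message lands in
/// `messages()`; render both for a live-updating transcript.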
pub fn use_chat(options: ChatOptions) -> UseChatState {
    let messages = use_signal(|| options.initial_messages.clone());
    let is_loading = use_signal(|| false);
    let error = use_signal(|| None::<String>);
    let streaming_content = use_signal(String::new);
    let should_stop = use_signal(|| false);

    UseChatState {
        messages,
        is_loading,
        error,
        streaming_content,
        should_stop,
        options,
    }
}

/// State for the completion hook
#[derive(Clone)]
pub struct UseCompletionState {
    completion: Signal<Option<String>>,
    is_loading: Signal<bool>,
    error: Signal<Option<String>>,
    options: CompletionOptions,
}

impl UseCompletionState {
    /// Get completion result
    pub fn completion(&self) -> Option<String> {
        (self.completion)()
    }

    /// Check if loading
    pub fn is_loading(&self) -> bool {
        (self.is_loading)()
    }

    /// Get current error
    pub fn error(&self) -> Option<String> {
        (self.error)()
    }

    /// Request a completion
    pub fn complete(&mut self, prompt: &str) {
        let mut completion = self.completion;
        let mut is_loading = self.is_loading;
        let mut error = self.error;
        let options = self.options.clone();
        let prompt = prompt.to_string();

        is_loading.set(true);
        error.set(None);
        completion.set(None);

        spawn(async move {
            let messages = vec![ChatMessage::user(&prompt)];

            let chat_opts = ChatOptions {
                provider: options.provider,
                api_key: options.api_key,
                model: options.model,
                system_prompt: options.system_prompt,
                temperature: options.temperature,
                max_tokens: options.max_tokens,
                stream: false,
                initial_messages: Vec::new(),
            };

            match send_request(&chat_opts, messages).await {
                Ok(content) => {
                    completion.set(Some(content));
                }
                Err(e) => {
                    error.set(Some(e.to_string()));
                }
            }

            is_loading.set(false);
        });
    }
}

/// Create a completion interface
///
/// # Example
///
/// ```rust,ignore
/// let mut completion = use_completion(CompletionOptions {
///     provider: "openai".to_string(),
///     api_key: "sk-...".to_string(),
///     model: "gpt-4o-mini".to_string(),
///     ..Default::default()
/// });
///
/// completion.complete("Translate 'hello' to Spanish");
/// ```
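///
/// `complete` always issues a non-streaming request (`stream: false`
/// internally); use `use_chat` when incremental output is needed.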
pub fn use_completion(options: CompletionOptions) -> UseCompletionState {
    let completion = use_signal(|| None::<String>);
    let is_loading = use_signal(|| false);
    let error = use_signal(|| None::<String>);

    UseCompletionState {
        completion,
        is_loading,
        error,
        options,
    }
}

/// Send a non-streaming request
async fn send_request(options: &ChatOptions, messages: Vec<ChatMessage>) -> Result<String> {
    let provider: Provider = options
        .provider
        .parse()
        .map_err(|_| DioxusAiError::InvalidProvider(options.provider.clone()))?;

    let mut llm_messages: Vec<Message> = Vec::new();

    if let Some(ref system) = options.system_prompt {
        llm_messages.push(Message::system(system));
    }

    for msg in &messages {
        match msg.role.as_str() {
            "user" => llm_messages.push(Message::user(&msg.content)),
            "assistant" => llm_messages.push(Message::assistant(&msg.content)),
            "system" => llm_messages.push(Message::system(&msg.content)),
            _ => {}
        }
    }

    let http_request = RequestBuilder::new(provider)
        .model(&options.model)
        .api_key(&options.api_key)
        .messages(&llm_messages)
        .temperature(options.temperature)
        .max_tokens(options.max_tokens)
        .stream(false)
        .build()?;

    let response = fetch(&http_request.url, &http_request.headers, &http_request.body).await?;
    let llm_response = ResponseParser::parse(provider, &response)?;

    Ok(llm_response.content)
}

/// Send a streaming request
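///
/// Assumes the provider emits line-delimited streaming chunks (e.g. SSE
/// `data: {...}` lines), which is what `ResponseParser::parse_stream_line`
/// expects. Raw bytes are buffered so a line split across two reads is
/// still parsed whole.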
async fn send_streaming_request(
    options: &ChatOptions,
    messages: Vec<ChatMessage>,
    mut streaming_content: Signal<String>,
    should_stop: Signal<bool>,
) -> Result<String> {
    let provider: Provider = options
        .provider
        .parse()
        .map_err(|_| DioxusAiError::InvalidProvider(options.provider.clone()))?;

    let mut llm_messages: Vec<Message> = Vec::new();

    if let Some(ref system) = options.system_prompt {
        llm_messages.push(Message::system(system));
    }

    for msg in &messages {
        match msg.role.as_str() {
            "user" => llm_messages.push(Message::user(&msg.content)),
            "assistant" => llm_messages.push(Message::assistant(&msg.content)),
            "system" => llm_messages.push(Message::system(&msg.content)),
            _ => {}
        }
    }

    let http_request = RequestBuilder::new(provider)
        .model(&options.model)
        .api_key(&options.api_key)
        .messages(&llm_messages)
        .temperature(options.temperature)
        .max_tokens(options.max_tokens)
        .stream(true)
        .build()?;

    let response =
        fetch_stream(&http_request.url, &http_request.headers, &http_request.body).await?;

    let mut full_content = String::new();
    // Bytes are buffered across reads: a network chunk can end mid-line
    // (or mid-UTF-8-character), so only complete, newline-terminated
    // lines are handed to the parser.
    let mut buffer: Vec<u8> = Vec::new();

    let reader = response
        .body()
        .ok_or_else(|| DioxusAiError::StreamError("No response body".to_string()))?
        .get_reader();

    let reader: web_sys::ReadableStreamDefaultReader = reader.unchecked_into();

    'read: loop {
        if should_stop() {
            break;
        }

        let result = JsFuture::from(reader.read()).await;
        let result = result.map_err(|e| DioxusAiError::StreamError(format!("{:?}", e)))?;

        let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
            .map_err(|e| DioxusAiError::StreamError(format!("{:?}", e)))?
            .as_bool()
            .unwrap_or(true);

        if done {
            break;
        }

        let value = js_sys::Reflect::get(&result, &JsValue::from_str("value"))
            .map_err(|e| DioxusAiError::StreamError(format!("{:?}", e)))?;

        buffer.extend(js_sys::Uint8Array::new(&value).to_vec());

        // Drain every complete line; any trailing partial line stays in
        // the buffer for the next read.
        while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
            let line_bytes: Vec<u8> = buffer.drain(..=pos).collect();
            let line = String::from_utf8_lossy(&line_bytes);

            if let Ok(Some(chunk)) = ResponseParser::parse_stream_line(provider, line.trim_end()) {
                if let Some(content) = chunk.content {
                    full_content.push_str(&content);
                    streaming_content.set(full_content.clone());
                }
                if chunk.done {
                    // The provider signalled end-of-stream; stop reading
                    // entirely rather than just ending this batch of lines.
                    break 'read;
                }
            }
        }
    }

    Ok(full_content)
}

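/// POST `body` to `url` with the given `headers` via the browser `fetch`
/// API and return the response text.
///
/// Requests go straight from the page to the provider (CORS mode), so the
/// API key in the headers is visible to the client; route through a
/// backend proxy if that matters for your deployment.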
async fn fetch(url: &str, headers: &[(String, String)], body: &str) -> Result<String> {
    let opts = RequestInit::new();
    opts.set_method("POST");
    opts.set_mode(RequestMode::Cors);
    opts.set_body(&JsValue::from_str(body));

    let js_headers =
        web_sys::Headers::new().map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    for (key, value) in headers {
        js_headers
            .set(key, value)
            .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
    }
    opts.set_headers(&js_headers);

    let request = Request::new_with_str_and_init(url, &opts)
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    let window =
        web_sys::window().ok_or_else(|| DioxusAiError::RequestFailed("No window".to_string()))?;

    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    if !resp.ok() {
        let status = resp.status();
        let text = JsFuture::from(
            resp.text()
                .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?,
        )
        .await
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?
        .as_string()
        .unwrap_or_default();
        return Err(DioxusAiError::ApiError(format!(
            "HTTP {}: {}",
            status, text
        )));
    }

    let text = JsFuture::from(
        resp.text()
            .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?,
    )
    .await
    .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?
    .as_string()
    .unwrap_or_default();

    Ok(text)
}

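/// Like `fetch`, but returns the raw `Response` so the caller can consume
/// the body incrementally as a stream.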
async fn fetch_stream(url: &str, headers: &[(String, String)], body: &str) -> Result<Response> {
    let opts = RequestInit::new();
    opts.set_method("POST");
    opts.set_mode(RequestMode::Cors);
    opts.set_body(&JsValue::from_str(body));

    let js_headers =
        web_sys::Headers::new().map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    for (key, value) in headers {
        js_headers
            .set(key, value)
            .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;
    }
    opts.set_headers(&js_headers);

    let request = Request::new_with_str_and_init(url, &opts)
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    let window =
        web_sys::window().ok_or_else(|| DioxusAiError::RequestFailed("No window".to_string()))?;

    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|e| DioxusAiError::RequestFailed(format!("{:?}", e)))?;

    if !resp.ok() {
        let status = resp.status();
        return Err(DioxusAiError::ApiError(format!("HTTP {}", status)));
    }

    Ok(resp)
}