Skip to main content

leptos_ai/
hooks.rs

1//! Reactive hooks for Leptos AI
2
3use leptos::prelude::*;
4use cortexai_llm_client::{Message, Provider, RequestBuilder, ResponseParser};
5use wasm_bindgen::prelude::*;
6use wasm_bindgen_futures::JsFuture;
7use web_sys::{Request, RequestInit, RequestMode, Response};
8
9use crate::{ChatMessage, ChatOptions, CompletionOptions, LeptosAiError, Result};
10
11/// Return type for use_chat hook
12#[derive(Clone, Copy)]
13pub struct UseChat {
14    /// Reactive list of messages
15    pub messages: RwSignal<Vec<ChatMessage>>,
16    /// Whether a request is in progress
17    pub is_loading: RwSignal<bool>,
18    /// Current error, if any
19    pub error: RwSignal<Option<String>>,
20    /// Current streaming content (updated in real-time)
21    pub streaming_content: RwSignal<String>,
22    /// Internal: options stored
23    options: RwSignal<ChatOptions>,
24    /// Internal: stop flag
25    should_stop: RwSignal<bool>,
26}
27
28impl UseChat {
29    /// Send a message to the chat
30    pub fn send(&self, message: &str) {
31        let opts = self.options.get_untracked();
32        let messages_signal = self.messages;
33        let is_loading = self.is_loading;
34        let error = self.error;
35        let streaming_content = self.streaming_content;
36        let should_stop = self.should_stop;
37
38        // Add user message
39        let user_msg = ChatMessage::user(message);
40        messages_signal.update(|msgs| msgs.push(user_msg));
41
42        // Reset state
43        is_loading.set(true);
44        error.set(None);
45        streaming_content.set(String::new());
46        should_stop.set(false);
47
48        let current_messages = messages_signal.get_untracked();
49
50        // Spawn the async request
51        leptos::task::spawn_local(async move {
52            let result = if opts.stream {
53                send_streaming_request(&opts, current_messages, streaming_content, should_stop)
54                    .await
55            } else {
56                send_request(&opts, current_messages).await
57            };
58
59            match result {
60                Ok(content) => {
61                    let assistant_msg = ChatMessage::assistant(content);
62                    messages_signal.update(|msgs| msgs.push(assistant_msg));
63                }
64                Err(e) => {
65                    error.set(Some(e.to_string()));
66                }
67            }
68
69            is_loading.set(false);
70            streaming_content.set(String::new());
71        });
72    }
73
74    /// Clear all messages
75    pub fn clear(&self) {
76        self.messages.set(Vec::new());
77        self.error.set(None);
78        self.streaming_content.set(String::new());
79    }
80
81    /// Stop the current generation
82    pub fn stop(&self) {
83        self.should_stop.set(true);
84    }
85}
86
87/// Create a reactive chat interface
88///
89/// # Example
90///
91/// ```rust,ignore
92/// let chat = use_chat(ChatOptions {
93///     provider: "openai".to_string(),
94///     api_key: "sk-...".to_string(),
95///     model: "gpt-4o-mini".to_string(),
96///     ..Default::default()
97/// });
98///
99/// // Send a message
100/// chat.send("Hello!");
101///
102/// // Access messages reactively
103/// let messages = chat.messages.get();
104/// ```
105pub fn use_chat(options: ChatOptions) -> UseChat {
106    let messages = RwSignal::new(options.initial_messages.clone());
107    let is_loading = RwSignal::new(false);
108    let error = RwSignal::new(None::<String>);
109    let streaming_content = RwSignal::new(String::new());
110    let should_stop = RwSignal::new(false);
111    let options_signal = RwSignal::new(options);
112
113    UseChat {
114        messages,
115        is_loading,
116        error,
117        streaming_content,
118        options: options_signal,
119        should_stop,
120    }
121}
122
123/// Return type for use_completion hook
124#[derive(Clone, Copy)]
125pub struct UseCompletion {
126    /// The completion result
127    pub completion: RwSignal<Option<String>>,
128    /// Whether a request is in progress
129    pub is_loading: RwSignal<bool>,
130    /// Current error, if any
131    pub error: RwSignal<Option<String>>,
132    /// Internal: options
133    options: RwSignal<CompletionOptions>,
134}
135
136impl UseCompletion {
137    /// Request a completion
138    pub fn complete(&self, prompt: &str) {
139        let opts = self.options.get_untracked();
140        let completion = self.completion;
141        let is_loading = self.is_loading;
142        let error = self.error;
143        let prompt = prompt.to_string();
144
145        is_loading.set(true);
146        error.set(None);
147        completion.set(None);
148
149        leptos::task::spawn_local(async move {
150            let messages = vec![ChatMessage::user(&prompt)];
151
152            let chat_opts = ChatOptions {
153                provider: opts.provider,
154                api_key: opts.api_key,
155                model: opts.model,
156                system_prompt: opts.system_prompt,
157                temperature: opts.temperature,
158                max_tokens: opts.max_tokens,
159                stream: false,
160                initial_messages: Vec::new(),
161            };
162
163            match send_request(&chat_opts, messages).await {
164                Ok(content) => {
165                    completion.set(Some(content));
166                }
167                Err(e) => {
168                    error.set(Some(e.to_string()));
169                }
170            }
171
172            is_loading.set(false);
173        });
174    }
175}
176
177/// Create a completion interface for single requests
178///
179/// # Example
180///
181/// ```rust,ignore
182/// let completion = use_completion(CompletionOptions {
183///     provider: "openai".to_string(),
184///     api_key: "sk-...".to_string(),
185///     model: "gpt-4o-mini".to_string(),
186///     ..Default::default()
187/// });
188///
189/// // Request a completion
190/// completion.complete("Translate 'hello' to Spanish");
191///
192/// // Access result reactively
193/// if let Some(result) = completion.completion.get() {
194///     println!("{}", result);
195/// }
196/// ```
197pub fn use_completion(options: CompletionOptions) -> UseCompletion {
198    let completion = RwSignal::new(None::<String>);
199    let is_loading = RwSignal::new(false);
200    let error = RwSignal::new(None::<String>);
201    let options_signal = RwSignal::new(options);
202
203    UseCompletion {
204        completion,
205        is_loading,
206        error,
207        options: options_signal,
208    }
209}
210
211/// Send a non-streaming request
212async fn send_request(options: &ChatOptions, messages: Vec<ChatMessage>) -> Result<String> {
213    let provider: Provider = options
214        .provider
215        .parse()
216        .map_err(|_| LeptosAiError::InvalidProvider(options.provider.clone()))?;
217
218    // Convert ChatMessage to llm-client Message
219    let mut llm_messages: Vec<Message> = Vec::new();
220
221    if let Some(ref system) = options.system_prompt {
222        llm_messages.push(Message::system(system));
223    }
224
225    for msg in &messages {
226        match msg.role.as_str() {
227            "user" => llm_messages.push(Message::user(&msg.content)),
228            "assistant" => llm_messages.push(Message::assistant(&msg.content)),
229            "system" => llm_messages.push(Message::system(&msg.content)),
230            _ => {}
231        }
232    }
233
234    // Build request using llm-client
235    let http_request = RequestBuilder::new(provider)
236        .model(&options.model)
237        .api_key(&options.api_key)
238        .messages(&llm_messages)
239        .temperature(options.temperature)
240        .max_tokens(options.max_tokens)
241        .stream(false)
242        .build()?;
243
244    // Send via web-sys fetch
245    let response = fetch(&http_request.url, &http_request.headers, &http_request.body).await?;
246
247    // Parse response using llm-client
248    let llm_response = ResponseParser::parse(provider, &response)?;
249
250    Ok(llm_response.content)
251}
252
253/// Send a streaming request
254async fn send_streaming_request(
255    options: &ChatOptions,
256    messages: Vec<ChatMessage>,
257    streaming_content: RwSignal<String>,
258    should_stop: RwSignal<bool>,
259) -> Result<String> {
260    let provider: Provider = options
261        .provider
262        .parse()
263        .map_err(|_| LeptosAiError::InvalidProvider(options.provider.clone()))?;
264
265    // Convert ChatMessage to llm-client Message
266    let mut llm_messages: Vec<Message> = Vec::new();
267
268    if let Some(ref system) = options.system_prompt {
269        llm_messages.push(Message::system(system));
270    }
271
272    for msg in &messages {
273        match msg.role.as_str() {
274            "user" => llm_messages.push(Message::user(&msg.content)),
275            "assistant" => llm_messages.push(Message::assistant(&msg.content)),
276            "system" => llm_messages.push(Message::system(&msg.content)),
277            _ => {}
278        }
279    }
280
281    // Build request using llm-client
282    let http_request = RequestBuilder::new(provider)
283        .model(&options.model)
284        .api_key(&options.api_key)
285        .messages(&llm_messages)
286        .temperature(options.temperature)
287        .max_tokens(options.max_tokens)
288        .stream(true)
289        .build()?;
290
291    // Send via web-sys fetch and get streaming response
292    let response =
293        fetch_stream(&http_request.url, &http_request.headers, &http_request.body).await?;
294
295    // Process SSE stream
296    let mut full_content = String::new();
297    let reader = response
298        .body()
299        .ok_or_else(|| LeptosAiError::StreamError("No response body".to_string()))?
300        .get_reader();
301
302    let reader: web_sys::ReadableStreamDefaultReader = reader.unchecked_into();
303
304    loop {
305        if should_stop.get_untracked() {
306            break;
307        }
308
309        let result = JsFuture::from(reader.read()).await;
310        let result = result.map_err(|e| LeptosAiError::StreamError(format!("{:?}", e)))?;
311
312        let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
313            .map_err(|e| LeptosAiError::StreamError(format!("{:?}", e)))?
314            .as_bool()
315            .unwrap_or(true);
316
317        if done {
318            break;
319        }
320
321        let value = js_sys::Reflect::get(&result, &JsValue::from_str("value"))
322            .map_err(|e| LeptosAiError::StreamError(format!("{:?}", e)))?;
323
324        let array = js_sys::Uint8Array::new(&value);
325        let bytes = array.to_vec();
326        let text = String::from_utf8_lossy(&bytes);
327
328        // Parse SSE lines
329        for line in text.lines() {
330            if let Ok(Some(chunk)) = ResponseParser::parse_stream_line(provider, line) {
331                if let Some(content) = chunk.content {
332                    full_content.push_str(&content);
333                    streaming_content.set(full_content.clone());
334                }
335                if chunk.done {
336                    break;
337                }
338            }
339        }
340    }
341
342    Ok(full_content)
343}
344
345/// Fetch helper using web-sys
346async fn fetch(url: &str, headers: &[(String, String)], body: &str) -> Result<String> {
347    let opts = RequestInit::new();
348    opts.set_method("POST");
349    opts.set_mode(RequestMode::Cors);
350    opts.set_body(&JsValue::from_str(body));
351
352    let js_headers =
353        web_sys::Headers::new().map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
354
355    for (key, value) in headers {
356        js_headers
357            .set(key, value)
358            .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
359    }
360    opts.set_headers(&js_headers);
361
362    let request = Request::new_with_str_and_init(url, &opts)
363        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
364
365    let window =
366        web_sys::window().ok_or_else(|| LeptosAiError::RequestFailed("No window".to_string()))?;
367
368    let resp_value = JsFuture::from(window.fetch_with_request(&request))
369        .await
370        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
371
372    let resp: Response = resp_value
373        .dyn_into()
374        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
375
376    if !resp.ok() {
377        let status = resp.status();
378        let text = JsFuture::from(
379            resp.text()
380                .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?,
381        )
382        .await
383        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?
384        .as_string()
385        .unwrap_or_default();
386        return Err(LeptosAiError::ApiError(format!(
387            "HTTP {}: {}",
388            status, text
389        )));
390    }
391
392    let text = JsFuture::from(
393        resp.text()
394            .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?,
395    )
396    .await
397    .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?
398    .as_string()
399    .unwrap_or_default();
400
401    Ok(text)
402}
403
404/// Fetch helper for streaming responses
405async fn fetch_stream(url: &str, headers: &[(String, String)], body: &str) -> Result<Response> {
406    let opts = RequestInit::new();
407    opts.set_method("POST");
408    opts.set_mode(RequestMode::Cors);
409    opts.set_body(&JsValue::from_str(body));
410
411    let js_headers =
412        web_sys::Headers::new().map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
413
414    for (key, value) in headers {
415        js_headers
416            .set(key, value)
417            .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
418    }
419    opts.set_headers(&js_headers);
420
421    let request = Request::new_with_str_and_init(url, &opts)
422        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
423
424    let window =
425        web_sys::window().ok_or_else(|| LeptosAiError::RequestFailed("No window".to_string()))?;
426
427    let resp_value = JsFuture::from(window.fetch_with_request(&request))
428        .await
429        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
430
431    let resp: Response = resp_value
432        .dyn_into()
433        .map_err(|e| LeptosAiError::RequestFailed(format!("{:?}", e)))?;
434
435    if !resp.ok() {
436        let status = resp.status();
437        return Err(LeptosAiError::ApiError(format!("HTTP {}", status)));
438    }
439
440    Ok(resp)
441}