1use axum::{
7 extract::State,
8 response::{
9 sse::{Event, Sse},
10 IntoResponse, Response,
11 },
12 http::StatusCode,
13 Json,
14};
15use serde::Deserialize;
16use std::convert::Infallible;
17use tracing::{info, error, debug, warn};
18use serde_json::Value;
19use reqwest;
20use std::sync::Arc;
21
22use crate::memory::Message;
23use crate::memory_db::schema::Embedding;
24use crate::shared_state::UnifiedAppState;
25use crate::utils::{extract_content_from_bytes, estimate_tokens, truncate_to_budget, is_extraction_sentinel};
26use crate::cache_management::cache_scorer::score_message_importance;
27use regex::Regex;
28
lazy_static::lazy_static! {
    // Matches the chat UI's "[Attached: <filename>]" marker; capture group 1
    // is the filename inside the brackets (anything except ']').
    static ref ATTACHED_RE: Regex = Regex::new(r"\[Attached: ([^\]]+)\]").unwrap();
    // Matches "@file.ext"-style mentions; capture group 1 is the filename.
    // Requires a dot + word-char extension, so bare "@user" mentions don't match.
    static ref AT_FILE_RE: Regex = Regex::new(r"@(\S+\.\w+)").unwrap();
}
36
/// One file attached to a chat request, as sent by the frontend.
///
/// `source` selects how the bytes are obtained (see `try_extract_attachment`):
/// "inline" reads from `file_path` on disk; "local_storage" loads by
/// `all_files_id` from the database-backed file store.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatAttachment {
    // Display name of the file (also used to pick the extraction strategy by extension).
    pub name: String,
    // Attachment source discriminator: "inline" or "local_storage".
    pub source: String,
    // Absolute path on disk; expected when `source == "inline"`.
    #[serde(default)]
    pub file_path: Option<String>,
    // Row id in the `all_files` store; expected when `source == "local_storage"`.
    #[serde(default)]
    pub all_files_id: Option<i64>,
    // Size as reported by the client; persisted with session file contexts.
    #[serde(default)]
    pub size_bytes: Option<i64>,
}
62
/// Request body for the streaming chat endpoint (`generate_stream`).
#[derive(Debug, Deserialize)]
pub struct StreamChatRequest {
    // Model identifier; for OpenRouter, defaults to "openrouter/auto" when absent.
    pub model: Option<String>,
    // "openrouter" routes to the hosted API; anything else uses the local LLM worker.
    pub model_source: Option<String>,
    // Full conversation history; must contain at least one message.
    pub messages: Vec<Message>,
    // Client-side session identifier used for persistence and context retrieval.
    pub session_id: String,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_stream")]
    pub stream: bool,
    // Files attached to this turn; extracted and injected into the last user message.
    #[serde(default)]
    pub attachments: Option<Vec<ChatAttachment>>,
    // Optional per-request OpenRouter key; falls back to env var, then server config.
    #[serde(default)]
    pub api_key: Option<String>,
}
83
/// Serde default for `StreamChatRequest::max_tokens`.
fn default_max_tokens() -> u32 {
    2000
}

/// Serde default for `StreamChatRequest::temperature`.
fn default_temperature() -> f32 {
    0.7
}

/// Serde default for `StreamChatRequest::stream`.
fn default_stream() -> bool {
    true
}
87
// Upper bound on the combined (estimated) token count injected from all
// attachments in one request; when exceeded, the budget is split evenly per
// file and each file is truncated (see `inject_attachment_contents`).
const MAX_ATTACHMENT_TOKENS: usize = 64_000;
91
92
93async fn try_extract_attachment(
100 attach: &ChatAttachment,
101 state: &UnifiedAppState,
102) -> Result<(String, String), String> {
103 match attach.source.as_str() {
104 "inline" => {
105 let path = attach.file_path.as_deref().ok_or_else(|| {
106 format!("'{}': no file path provided. Use the paperclip button to attach files.", attach.name)
107 })?;
108
109 info!("Reading inline file: {} ({})", attach.name, path);
110 let bytes = tokio::fs::read(path).await.map_err(|e| {
111 format!(
112 "Could not read '{}': {}.\n\nMake sure the file exists and is not stored only in the cloud (OneDrive, iCloud, etc.).",
113 attach.name, e
114 )
115 })?;
116
117 info!("Read {} bytes from '{}'", bytes.len(), attach.name);
118 let content = extract_content_from_bytes(&bytes, &attach.name)
119 .await
120 .map_err(|e| format!("Could not parse '{}': {}", attach.name, e))?;
121
122 if is_extraction_sentinel(&content) {
125 return Err(if attach.name.to_lowercase().ends_with(".pdf") {
126 format!(
127 "Could not extract text from '{}'.\n\nThe PDF is likely image-based (scanned) or password-protected. \
128 Try one of:\n • Export/re-save as a text-based PDF\n • Attach a DOCX version\n • Paste the text directly into the chat",
129 attach.name
130 )
131 } else {
132 format!(
133 "Could not extract text from '{}'.\n\nThe file may be corrupted or in an unsupported format. \
134 Try a different format, or paste the content directly into the chat.",
135 attach.name
136 )
137 });
138 }
139
140 if content.trim().is_empty() {
141 return Err(format!("'{}' appears to be empty — no text content found.", attach.name));
142 }
143
144 info!("Extracted {} chars from '{}'", content.len(), attach.name);
145 Ok((attach.name.clone(), content))
146 }
147
148 "local_storage" => {
149 let id = attach.all_files_id.ok_or_else(|| {
150 format!("'{}': no database ID provided for local storage attachment.", attach.name)
151 })?;
152
153 info!("Reading local_storage file: {} (id={})", attach.name, id);
154 let all_files = &state.shared_state.database_pool.all_files;
155
156 let bytes = all_files.get_file_bytes(id).map_err(|e| {
159 format!(
160 "Could not read '{}' from local storage: {}.\n\nTry re-adding the file through the local storage panel.",
161 attach.name, e
162 )
163 })?;
164
165 info!("Read {} bytes from local_storage '{}'", bytes.len(), attach.name);
166
167 let content = extract_content_from_bytes(&bytes, &attach.name)
168 .await
169 .map_err(|e| format!("Could not parse '{}': {}", attach.name, e))?;
170
171 if is_extraction_sentinel(&content) {
173 return Err(if attach.name.to_lowercase().ends_with(".pdf") {
174 format!(
175 "Could not extract text from '{}'.\n\nThe PDF is likely image-based (scanned) or password-protected. \
176 Try one of:\n • Export/re-save as a text-based PDF\n • Attach a DOCX version\n • Paste the text directly into the chat",
177 attach.name
178 )
179 } else {
180 format!(
181 "Could not extract text from '{}'.\n\nThe file may be corrupted or in an unsupported format. \
182 Try a different format, or paste the content directly into the chat.",
183 attach.name
184 )
185 });
186 }
187
188 let _ = all_files.record_access(id);
189
190 if content.trim().is_empty() {
191 return Err(format!("'{}' from local storage appears to be empty.", attach.name));
192 }
193
194 info!("Extracted {} chars from local_storage '{}'", content.len(), attach.name);
195 Ok((attach.name.clone(), content))
196 }
197
198 other => Err(format!(
199 "'{}': unknown attachment source '{}'. Use the paperclip button (inline) or the local storage panel to attach files.",
200 attach.name, other
201 )),
202 }
203}
204
205fn inject_attachment_contents(messages: &mut Vec<Message>, contents: Vec<(String, String)>) {
208 let total_tokens: usize = contents.iter().map(|(_, c)| estimate_tokens(c)).sum();
209 info!("Attachment total: {} tokens across {} file(s)", total_tokens, contents.len());
210
211 let final_contents: Vec<(String, String)> = if total_tokens > MAX_ATTACHMENT_TOKENS {
212 let budget_per_file = MAX_ATTACHMENT_TOKENS / contents.len().max(1);
213 info!("Applying 64k budget: {} tokens/file", budget_per_file);
214 contents.into_iter().map(|(name, content)| {
215 let (truncated, was_cut) = truncate_to_budget(&content, budget_per_file);
216 let final_content = if was_cut {
217 let original_tokens = estimate_tokens(&content);
218 format!(
219 "{}\n[File truncated: showing first ~{} tokens of ~{} total]",
220 truncated, budget_per_file, original_tokens
221 )
222 } else {
223 truncated
224 };
225 (name, final_content)
226 }).collect()
227 } else {
228 contents
229 };
230
231 let mut block = String::new();
232 for (name, content) in &final_contents {
233 block.push_str(&format!(
234 "\n--- Content of attached file: {} ---\n{}\n--- End of file ---\n",
235 name, content
236 ));
237 }
238
239 if let Some(last_user) = messages.iter_mut().rev().find(|m| m.role == "user") {
240 info!("Injecting {} chars of attachment content into user message", block.len());
241 last_user.content = format!("{}\n{}", last_user.content, block);
242 } else {
243 error!("No user message found to inject attachment content into!");
244 }
245}
246
247async fn process_file_attachments(
250 messages: &mut Vec<Message>,
251 state: &UnifiedAppState,
252) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
253 let local_files = &state.shared_state.database_pool.local_files;
254
255 for msg in messages.iter_mut() {
256 if msg.role == "user" {
257 let original = msg.content.clone();
260 let mut updated_content = original.clone();
261
262 for cap in ATTACHED_RE.captures_iter(&original) {
264 if let Some(m) = cap.get(1) {
265 let filename = m.as_str();
266 updated_content = replace_file_reference(
267 &updated_content,
268 &format!("[Attached: {}]", filename),
269 filename,
270 local_files,
271 ).await;
272 }
273 }
274
275 for cap in AT_FILE_RE.captures_iter(&original) {
277 if let Some(m) = cap.get(1) {
278 let filename = m.as_str();
279 updated_content = replace_file_reference(
280 &updated_content,
281 &format!("@{}", filename),
282 filename,
283 local_files,
284 ).await;
285 }
286 }
287
288 msg.content = updated_content;
289 }
290 }
291
292 Ok(())
293}
294
295async fn replace_file_reference(
300 content: &str,
301 marker: &str,
302 filename: &str,
303 local_files: &crate::memory_db::LocalFilesStore,
304) -> String {
305 match local_files.get_file_by_name(filename) {
307 Ok(file) => {
308 match local_files.get_file_content(file.id) {
310 Ok(bytes) => {
311 match extract_content_from_bytes(&bytes, filename).await {
312 Ok(file_content) if !file_content.trim().is_empty() => {
313 let attachment_text = format!(
314 "\n--- Content of file: {} ---\n{}\n--- End of file ---\n",
315 filename, file_content
316 );
317 content.replace(marker, &attachment_text)
318 }
319 _ => {
320 let error_text = format!(
321 "\n[Note: Could not extract text from '{}'. The file may be in an unsupported format.]",
322 filename
323 );
324 content.replace(marker, &error_text)
325 }
326 }
327 }
328 Err(_) => {
329 let error_text = format!(
330 "\n[Note: Could not read file '{}'. File may be missing or corrupted.]",
331 filename
332 );
333 content.replace(marker, &error_text)
334 }
335 }
336 }
337 Err(_) => {
338 let app_data_dir = dirs::data_dir()
340 .unwrap_or_else(|| std::path::PathBuf::from("."))
341 .join("Aud.io");
342 let file_path = app_data_dir.join(filename);
343
344 match crate::utils::extract_file_content(&file_path).await {
345 Ok(file_content) => {
346 let attachment_text = format!(
347 "\n--- Content of file: {} ---\n{}\n--- End of file ---\n",
348 filename, file_content
349 );
350 content.replace(marker, &attachment_text)
351 }
352 Err(_) => {
353 let error_text = format!(
354 "\n[Note: File '{}' not found in local files. Upload it first or check the filename.]",
355 filename
356 );
357 content.replace(marker, &error_text)
358 }
359 }
360 }
361 }
362}
363
/// Streaming chat completion handler (SSE).
///
/// Pipeline: validate request → expand legacy in-text file references →
/// extract/inject this turn's attachments (with cache) → persist and
/// re-inject historical session file context → persist the user message →
/// optimize context via the orchestrator → stream from either OpenRouter
/// ("openrouter" model_source) or the local LLM worker, persisting the
/// assembled assistant reply after the stream ends.
pub async fn generate_stream(
    State(state): State<UnifiedAppState>,
    Json(req): Json<StreamChatRequest>,
) -> Response {
    let request_num = state.shared_state.counters.inc_total_requests();
    info!("Stream request #{} for session: {}", request_num, req.session_id);

    // --- Diagnostics: log attachment metadata up front. ---
    if let Some(ref attachments) = req.attachments {
        info!("Request has {} attachment(s)", attachments.len());
        for (i, att) in attachments.iter().enumerate() {
            info!(
                " Attachment {}: name={}, source={}, file_path={}, all_files_id={}",
                i, att.name,
                att.source,
                att.file_path.as_deref().unwrap_or("(none)"),
                att.all_files_id.map(|id| id.to_string()).unwrap_or_else(|| "(none)".to_string()),
            );
        }
    } else {
        debug!("Request has no attachments");
    }

    if req.messages.is_empty() {
        return (StatusCode::BAD_REQUEST, "Messages array cannot be empty").into_response();
    }

    let session_id = req.session_id.clone();

    // Working copy of the conversation; attachment text is injected into this.
    let mut processed_messages = req.messages.clone();

    // Legacy "[Attached: ...]"/"@file" markers are expanded best-effort.
    if let Err(e) = process_file_attachments(&mut processed_messages, &state).await {
        error!("Error processing legacy file text references: {}", e);
    }

    // --- Current-turn attachments: cache lookup, extraction, injection. ---
    if let Some(ref attachments) = req.attachments {
        if !attachments.is_empty() {
            let mut extracted: Vec<(String, String)> = Vec::with_capacity(attachments.len());
            let mut errors: Vec<String> = Vec::new();

            for attach in attachments {
                let cache_key = crate::api::attachment_api::attachment_cache_key(attach);
                // NOTE: `remove` takes the entry out of the cache; a fresh hit is
                // consumed here rather than left for later requests.
                if let Some((_, cached)) = state.shared_state.attachment_cache.remove(&cache_key) {
                    if !cached.is_stale(crate::api::attachment_api::CACHE_TTL_SECS) {
                        if is_extraction_sentinel(&cached.text) {
                            // A cached sentinel means a previous extraction failed; retry.
                            info!("Cached sentinel for '{}' — treating as miss, re-extracting", attach.name);
                        } else {
                            info!("Attachment cache hit for '{}' — skipping extraction", attach.name);
                            extracted.push((attach.name.clone(), cached.text));
                            continue;
                        }
                    } else {
                        info!("Stale cache entry for '{}' — re-extracting", attach.name);
                    }
                }

                match try_extract_attachment(attach, &state).await {
                    Ok(content) => extracted.push(content),
                    Err(e) => {
                        warn!("Attachment extraction failed for '{}': {}", attach.name, e);
                        errors.push(e);
                    }
                }
            }

            // All-or-nothing: any failed attachment rejects the whole request.
            if !errors.is_empty() {
                let error_msg = errors.join("\n\n");
                error!("Rejecting request — {} attachment(s) could not be processed", errors.len());
                return (StatusCode::UNPROCESSABLE_ENTITY, error_msg).into_response();
            }

            inject_attachment_contents(&mut processed_messages, extracted);
        }
    }

    // --- Session file context: persist this turn's refs, re-inject prior files. ---
    {
        let db = &state.shared_state.database_pool;

        if let Some(ref attachments) = req.attachments {
            if !attachments.is_empty() {
                let refs: Vec<crate::memory_db::AttachmentRef<'_>> = attachments
                    .iter()
                    .map(|a| crate::memory_db::AttachmentRef {
                        name: &a.name,
                        source: &a.source,
                        file_path: a.file_path.as_deref(),
                        all_files_id: a.all_files_id,
                        size_bytes: a.size_bytes,
                    })
                    .collect();
                // Best-effort persistence: a failure is logged, not fatal.
                if let Err(e) = db.session_file_contexts.store_attachments(&session_id, &refs) {
                    warn!("Failed to persist session file context references: {}", e);
                }
            }
        }

        match db.session_file_contexts.get_for_session(&session_id) {
            Ok(historical) if !historical.is_empty() => {
                // Files attached on THIS turn are excluded — they were injected above.
                let current_names: std::collections::HashSet<&str> = req
                    .attachments
                    .as_ref()
                    .map(|a| a.iter().map(|att| att.name.as_str()).collect())
                    .unwrap_or_default();

                let prior: Vec<_> = historical
                    .iter()
                    .filter(|h| !current_names.contains(h.file_name.as_str()))
                    .collect();

                if !prior.is_empty() {
                    info!(
                        "Re-injecting {} historical file(s) as persistent context for session {}",
                        prior.len(),
                        session_id
                    );

                    // Historical files get a smaller budget than current-turn attachments.
                    const HIST_BUDGET: usize = 32_000;
                    let budget_per_file = HIST_BUDGET / prior.len().max(1);

                    let mut context_block = String::from(
                        "Files previously shared in this conversation (always available as context):\n",
                    );

                    for hist in &prior {
                        // Rebuild a ChatAttachment so the same extraction path is reused.
                        let chat_att = ChatAttachment {
                            name: hist.file_name.clone(),
                            source: hist.source.clone(),
                            file_path: hist.file_path.clone(),
                            all_files_id: hist.all_files_id,
                            size_bytes: hist.size_bytes,
                        };
                        match try_extract_attachment(&chat_att, &state).await {
                            Ok((name, content)) => {
                                let (truncated, was_cut) =
                                    truncate_to_budget(&content, budget_per_file);
                                context_block.push_str(&format!(
                                    "\n--- {} ---\n{}{}\n--- end of {} ---\n",
                                    name,
                                    truncated,
                                    if was_cut { "\n[... file truncated for context ...]" } else { "" },
                                    name
                                ));
                            }
                            Err(e) => {
                                // Unreadable historical files are skipped, not fatal.
                                warn!(
                                    "Could not re-read historical attachment '{}': {}",
                                    hist.file_name, e
                                );
                            }
                        }
                    }

                    // > 80 chars means at least one file actually contributed content
                    // beyond the fixed header line.
                    if context_block.len() > 80 {
                        if let Some(first) = processed_messages.first_mut() {
                            if first.role == "system" {
                                // Append to the existing system prompt.
                                first.content.push_str(&format!("\n\n{}", context_block));
                            } else {
                                // Prepend a new system message ahead of the conversation.
                                processed_messages.insert(
                                    0,
                                    crate::memory::Message {
                                        role: "system".to_string(),
                                        content: context_block.clone(),
                                    },
                                );
                            }
                        } else {
                            processed_messages.insert(
                                0,
                                crate::memory::Message {
                                    role: "system".to_string(),
                                    content: context_block.clone(),
                                },
                            );
                        }
                        info!(
                            "Injected {} chars of persistent file context for session {}",
                            context_block.len(),
                            session_id
                        );
                    }
                }
            }
            Ok(_) => {}
            Err(e) => {
                warn!("Could not retrieve session file contexts: {}", e);
            }
        }
    }

    // --- In-memory session bookkeeping. ---
    let session = state.shared_state.get_or_create_session(&session_id).await;

    {
        // Lock failure (poisoned lock) silently skips the update.
        if let Ok(mut session_data) = session.write() {
            session_data.last_accessed = std::time::Instant::now();
            session_data.messages = processed_messages.clone();
        }
    }

    // Persist the latest user message (taken from the RAW request, i.e. without
    // injected attachment text) in the background.
    let user_msg_content = req.messages.iter().rev().find(|m| m.role == "user").map(|m| m.content.clone());
    if let Some(ref content) = user_msg_content {
        let db = state.shared_state.database_pool.clone();
        let sid = session_id.clone();
        let content = content.clone();
        let msg_count = processed_messages.len() as i32;

        // Idempotent session creation; an "already exists" error is expected.
        if let Err(e) = db.conversations.create_session_with_id(&sid, None) {
            debug!("Session creation result (may already exist): {}", e);
        }

        tokio::spawn(async move {
            if let Err(e) = db.conversations.store_messages_batch(
                &sid,
                &[("user".to_string(), content.clone(), msg_count - 1, 0, score_message_importance("user", &content))],
            ) {
                error!("Failed to persist user message: {}", e);
            }
        });
    }

    // --- Context optimization (falls back to the raw messages on any failure). ---
    let context_messages = {
        let orchestrator_guard = state.context_orchestrator.read().await;
        if let Some(ref orchestrator) = *orchestrator_guard {
            let user_query = user_msg_content.as_deref();
            match orchestrator.process_conversation(&session_id, &processed_messages, user_query).await {
                Ok(optimized) => {
                    if optimized.len() != processed_messages.len() {
                        info!("Context engine optimized: {} → {} messages (retrieved past context)",
                            processed_messages.len(), optimized.len());
                    }
                    optimized
                }
                Err(e) => {
                    error!("Context engine error (falling back to raw messages): {}", e);
                    processed_messages.clone()
                }
            }
        } else {
            debug!("Context orchestrator not initialized, using raw messages");
            processed_messages.clone()
        }
    };

    // Clones captured by the streaming closures / background tasks below.
    let max_tokens = req.max_tokens;
    let temperature = req.temperature;
    let db_for_persist = state.shared_state.database_pool.clone();
    let session_id_for_persist = session_id.clone();
    let msg_index = req.messages.len() as i32;

    let db_for_embed_persist = state.shared_state.database_pool.clone();
    let session_id_for_embed = session_id.clone();
    let user_msg_for_embed = user_msg_content.clone();

    let is_online_model = req.model_source.as_deref() == Some("openrouter");

    if is_online_model {
        // Key precedence: request body → OPENROUTER_API_KEY env var → server config.
        let api_key = req.api_key.clone().unwrap_or_else(|| {
            std::env::var("OPENROUTER_API_KEY").unwrap_or_else(|_| {
                state.shared_state.config.openrouter_api_key.clone()
            })
        });

        if api_key.is_empty() {
            return (StatusCode::UNAUTHORIZED, "OpenRouter API key not configured").into_response();
        }

        let model_id = req.model.unwrap_or_else(|| "openrouter/auto".to_string());
        let openrouter_messages = context_messages.iter().map(|m| {
            serde_json::json!({
                "role": m.role,
                "content": m.content
            })
        }).collect::<Vec<_>>();

        let openrouter_request = serde_json::json!({
            "model": model_id,
            "messages": openrouter_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": true,
        });

        match stream_openrouter_response(api_key, openrouter_request, session_id_for_persist.clone(), db_for_persist.clone(), context_messages.clone(), user_msg_for_embed.clone(), db_for_embed_persist.clone(), session_id_for_embed.clone(), state.http_client.clone()).await {
            Ok(openrouter_stream) => {
                let output_stream = async_stream::stream! {
                    // Accumulates the assistant's full reply for post-stream persistence.
                    let mut full_response = String::new();

                    futures_util::pin_mut!(openrouter_stream);

                    while let Some(item) = tokio_stream::StreamExt::next(&mut openrouter_stream).await {
                        match item {
                            Ok(sse_line) => {
                                // Mirror delta content into `full_response` while forwarding.
                                if sse_line.starts_with("data: ") && !sse_line.contains("[DONE]") {
                                    if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(&sse_line[6..].trim()) {
                                        if let Some(content) = chunk
                                            .get("choices")
                                            .and_then(|c| c.get(0))
                                            .and_then(|c| c.get("delta"))
                                            .and_then(|d| d.get("content"))
                                            .and_then(|c| c.as_str())
                                        {
                                            full_response.push_str(content);
                                        }
                                    }
                                }

                                // Forward the payload (without the "data: " framing; axum re-adds it).
                                let data = sse_line.trim_start_matches("data: ").trim_end().to_string();
                                yield Ok::<_, Infallible>(Event::default().data(data));
                            }
                            Err(e) => {
                                error!("OpenRouter stream error: {}", e);
                                yield Ok(Event::default().data(
                                    format!("{{\"error\": \"{}\"}}", e)
                                ));
                                break;
                            }
                        }
                    }

                    // Persist the assembled assistant reply once the stream ends.
                    if !full_response.is_empty() {
                        let importance = score_message_importance("assistant", &full_response);
                        match db_for_persist.conversations.store_messages_batch(
                            &session_id_for_persist,
                            &[("assistant".to_string(), full_response.clone(), msg_index, 0, importance)],
                        ) {
                            // NOTE(review): `stored_msgs` is unused here (unlike the local
                            // branch, which uses it for embeddings) — consider `_stored_msgs`.
                            Ok(stored_msgs) => {
                                debug!("Persisted assistant response ({} chars) for session {}",
                                    full_response.len(), session_id_for_persist);
                            }
                            Err(e) => {
                                error!("Failed to persist assistant message: {}", e);
                            }
                        }
                    }
                };

                Sse::new(output_stream)
                    .keep_alive(
                        axum::response::sse::KeepAlive::new()
                            .interval(std::time::Duration::from_secs(15))
                    )
                    .into_response()
            }
            Err(e) => {
                error!("Failed to start OpenRouter stream: {}", e);
                let json_body = build_openrouter_error_json(&e.to_string());
                (StatusCode::BAD_GATEWAY, axum::Json(json_body)).into_response()
            }
        }
    } else {
        // --- Local model branch. ---
        let runtime_ready = state.llm_worker.is_runtime_ready().await;
        info!("Offline mode: runtime_ready check = {}", runtime_ready);

        if !runtime_ready {
            info!("Model not ready - returning error");
            return (StatusCode::SERVICE_UNAVAILABLE,
                "Model Not Ready: No local model is currently loaded. Please go to the Models page and activate a model by clicking \"Active Model\".").into_response();
        }

        let llm_worker = state.llm_worker.clone();
        let llm_worker_for_embed = state.llm_worker.clone();

        match llm_worker.stream_response(context_messages, max_tokens, temperature).await {
            Ok(llm_stream) => {
                let output_stream = async_stream::stream! {
                    let mut full_response = String::new();

                    futures_util::pin_mut!(llm_stream);

                    // Same forward-and-accumulate loop as the OpenRouter branch.
                    while let Some(item) = tokio_stream::StreamExt::next(&mut llm_stream).await {
                        match item {
                            Ok(sse_line) => {
                                if sse_line.starts_with("data: ") && !sse_line.contains("[DONE]") {
                                    if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(&sse_line[6..].trim()) {
                                        if let Some(content) = chunk
                                            .get("choices")
                                            .and_then(|c| c.get(0))
                                            .and_then(|c| c.get("delta"))
                                            .and_then(|d| d.get("content"))
                                            .and_then(|c| c.as_str())
                                        {
                                            full_response.push_str(content);
                                        }
                                    }
                                }

                                let data = sse_line.trim_start_matches("data: ").trim_end().to_string();
                                yield Ok::<_, Infallible>(Event::default().data(data));
                            }
                            Err(e) => {
                                error!("Stream error: {}", e);
                                yield Ok(Event::default().data(
                                    format!("{{\"error\": \"{}\"}}", e)
                                ));
                                break;
                            }
                        }
                    }

                    if !full_response.is_empty() {
                        let importance = score_message_importance("assistant", &full_response);
                        match db_for_persist.conversations.store_messages_batch(
                            &session_id_for_persist,
                            &[("assistant".to_string(), full_response.clone(), msg_index, 0, importance)],
                        ) {
                            Ok(_stored_msgs) => {
                                debug!("Persisted assistant response ({} chars) for session {}",
                                    full_response.len(), session_id_for_persist);

                                // Generate embeddings for the user+assistant pair in the background.
                                let llm_for_embed = llm_worker_for_embed.clone();
                                let db_for_embed = db_for_embed_persist.clone();
                                let assistant_content = full_response.clone();
                                let user_content_for_embed = user_msg_for_embed.clone();
                                let stored = _stored_msgs;

                                tokio::spawn(async move {
                                    let mut texts = Vec::new();
                                    let mut message_ids = Vec::new();

                                    // Look up the stored user message row to get its id.
                                    // NOTE(review): keyword search by full message text —
                                    // presumably returns the just-stored row; verify.
                                    if let Some(ref user_text) = user_content_for_embed {
                                        if let Ok(msgs) = db_for_embed.search_messages_by_keywords(
                                            &session_id_for_embed,
                                            &[user_text.clone()],
                                            1,
                                        ).await {
                                            if let Some(user_stored) = msgs.first() {
                                                texts.push(user_text.clone());
                                                message_ids.push(user_stored.id);
                                            }
                                        }
                                    }

                                    if let Some(assistant_stored) = stored.first() {
                                        texts.push(assistant_content);
                                        message_ids.push(assistant_stored.id);
                                    }

                                    if texts.is_empty() {
                                        return;
                                    }

                                    match llm_for_embed.generate_embeddings(texts).await {
                                        Ok(embeddings) => {
                                            let now = chrono::Utc::now();
                                            for (embedding_vec, msg_id) in embeddings.into_iter().zip(message_ids.iter()) {
                                                let emb = Embedding {
                                                    id: 0, message_id: *msg_id,
                                                    embedding: embedding_vec,
                                                    embedding_model: "llama-server".to_string(),
                                                    generated_at: now,
                                                };
                                                if let Err(e) = db_for_embed.embeddings.store_embedding(&emb) {
                                                    debug!("Failed to store embedding for msg {}: {}", msg_id, e);
                                                }
                                            }
                                            for msg_id in &message_ids {
                                                let _ = db_for_embed.conversations.mark_embedding_generated(*msg_id);
                                            }
                                            debug!("Stored {} embeddings for session {}", message_ids.len(), session_id_for_embed);
                                        }
                                        Err(e) => {
                                            // Embeddings are optional; absence of the endpoint is not an error.
                                            debug!("Embedding generation skipped (llama-server may not support /v1/embeddings): {}", e);
                                        }
                                    }
                                });
                            }
                            Err(e) => {
                                error!("Failed to persist assistant message: {}", e);
                            }
                        }
                    }
                };

                return Sse::new(output_stream)
                    .keep_alive(
                        axum::response::sse::KeepAlive::new()
                            .interval(std::time::Duration::from_secs(15))
                    )
                    .into_response();
            }
            Err(e) => {
                let error_msg = format!("{}", e);
                error!("Failed to start LLM stream: {}", error_msg);

                // Map common failure shapes onto actionable user-facing guidance.
                let (status_code, user_message) = if error_msg.contains("Cannot connect") || error_msg.contains("Connection refused") {
                    (
                        StatusCode::SERVICE_UNAVAILABLE,
                        "Local LLM server is not running. Please ensure:\\n\\n1. An engine is installed (Settings > Engines)\\n2. A model is downloaded and loaded (Settings > Models)\\n3. The engine has finished initializing".to_string()
                    )
                } else if error_msg.contains("not found") || error_msg.contains("No such file") {
                    (
                        StatusCode::NOT_FOUND,
                        "Model or engine binary not found. Please:\\n\\n1. Download an engine (Settings > Engines)\\n2. Download a model (Settings > Models)\\n3. Wait for initialization to complete".to_string()
                    )
                } else if error_msg.contains("timeout") || error_msg.contains("timed out") {
                    (
                        StatusCode::GATEWAY_TIMEOUT,
                        "LLM server connection timed out. The engine may be still initializing. Please wait a moment and try again.".to_string()
                    )
                } else {
                    (
                        StatusCode::BAD_GATEWAY,
                        format!("LLM backend error: {}\\n\\nPlease check that:\\n1. Engine is installed\\n2. Model is loaded\\n3. Engine is running", error_msg)
                    )
                };

                return (status_code, user_message).into_response();
            }
        }
    }
}
954
955async fn stream_openrouter_response(
957 api_key: String,
958 request_body: Value,
959 session_id: String,
960 _db_for_persist: Arc<crate::memory_db::MemoryDatabase>,
961 _context_messages: Vec<crate::memory::Message>,
962 _user_msg_for_embed: Option<String>,
963 _db_for_embed_persist: Arc<crate::memory_db::MemoryDatabase>,
964 _session_id_for_embed: String,
965 client: reqwest::Client,
966) -> Result<
967 std::pin::Pin<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send>>,
968 anyhow::Error
969> {
970
971 let response = client
972 .post("https://openrouter.ai/api/v1/chat/completions")
973 .header("Authorization", format!("Bearer {}", api_key))
974 .header("Content-Type", "application/json")
975 .header("HTTP-Referer", "https://aud.io")
976 .header("X-Title", "Aud.io")
977 .json(&request_body)
978 .send()
979 .await
980 .map_err(|e| anyhow::anyhow!("OpenRouter request failed: {}", e))?;
981
982 if !response.status().is_success() {
983 let status = response.status();
984 let body = response.text().await.unwrap_or_default();
985 return Err(anyhow::anyhow!("OpenRouter returned {}: {}", status, body));
986 }
987
988 let byte_stream = response.bytes_stream();
989
990 let sse_stream = async_stream::try_stream! {
991 let mut buffer = String::new();
992
993 futures_util::pin_mut!(byte_stream);
994
995 while let Some(chunk_result) = tokio_stream::StreamExt::next(&mut byte_stream).await {
996 let chunk: bytes::Bytes = chunk_result
997 .map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
998
999 buffer.push_str(&String::from_utf8_lossy(&chunk));
1000
1001 while let Some(newline_pos) = buffer.find('\n') {
1002 let line = buffer[..newline_pos].trim().to_string();
1003 buffer.drain(..=newline_pos);
1006
1007 if line.is_empty() {
1008 continue;
1009 }
1010
1011 if line.starts_with("data: ") {
1012 let data = &line[6..];
1013
1014 if data == "[DONE]" {
1015 yield "data: [DONE]\n\n".to_string();
1016 return;
1017 }
1018
1019 match serde_json::from_str::<Value>(data) {
1020 Ok(chunk) => {
1021 let finished = chunk
1026 .get("choices")
1027 .and_then(|c| c.as_array())
1028 .map(|arr| arr.iter().any(|choice| {
1029 choice.get("finish_reason")
1030 .and_then(|fr| fr.as_str())
1031 .map(|fr| !fr.is_empty())
1032 .unwrap_or(false)
1033 }))
1034 .unwrap_or(false);
1035
1036 yield format!("data: {}\n\n", data);
1037
1038 if finished {
1039 yield "data: [DONE]\n\n".to_string();
1040 return;
1041 }
1042 }
1043 Err(_) => {
1044 yield format!("data: {}\n\n", data);
1045 }
1046 }
1047 }
1048 }
1049 }
1050
1051 yield "data: [DONE]\n\n".to_string();
1052 };
1053
1054 Ok(Box::pin(sse_stream))
1055}
1056
1057fn build_openrouter_error_json(err_str: &str) -> serde_json::Value {
1061 if let Some(brace_pos) = err_str.find('{') {
1063 let raw_body = &err_str[brace_pos..];
1064 if let Ok(v) = serde_json::from_str::<serde_json::Value>(raw_body) {
1065 let msg = v.get("error")
1066 .and_then(|e| e.get("message"))
1067 .and_then(|m| m.as_str())
1068 .unwrap_or("OpenRouter returned an error");
1069 let code = v.get("error")
1070 .and_then(|e| e.get("code"))
1071 .and_then(|c| c.as_u64())
1072 .unwrap_or(0) as u16;
1073 let (error_type, user_message) = classify_openrouter_error(code, msg);
1074 return serde_json::json!({
1075 "error_type": error_type,
1076 "message": user_message,
1077 });
1078 }
1079 }
1080 serde_json::json!({
1082 "error_type": "generic",
1083 "message": "OpenRouter returned an error. Please try again or switch to a different model.",
1084 })
1085}
1086
/// Maps an OpenRouter error (HTTP-ish code + message) to a stable error-type
/// tag and a user-facing explanation.
///
/// Checks run in priority order — credits, context limit, rate limit, bad key,
/// model restriction — with a generic fallback that echoes the raw message.
/// Keyword matching is case-insensitive.
fn classify_openrouter_error(code: u16, msg: &str) -> (&'static str, String) {
    let lower = msg.to_lowercase();
    let has_any = |needles: &[&str]| needles.iter().any(|n| lower.contains(n));

    // Out of credits / billing problems.
    if code == 402 || has_any(&["credit", "insufficient", "balance", "billing", "quota"]) {
        return (
            "insufficient_credits",
            "Your OpenRouter account has insufficient credits to process this request.".to_string(),
        );
    }

    // Prompt too large for the model's context window.
    if (code == 400 || code == 413) && has_any(&["context", "too long", "token", "length"]) {
        return (
            "context_exceeded",
            "This conversation exceeds the model's context limit. Try a shorter message or switch to a model with a larger context window.".to_string(),
        );
    }

    // Throttled by the provider.
    if code == 429 || has_any(&["rate limit", "rate_limit", "too many request"]) {
        return (
            "rate_limit",
            "Rate limit exceeded for this model. Please wait a moment and try again, or switch to a different model.".to_string(),
        );
    }

    // Authentication problems.
    if code == 401
        || (lower.contains("invalid") && (lower.contains("key") || lower.contains("api")))
        || has_any(&["unauthorized", "authentication"])
    {
        return (
            "invalid_key",
            "Your OpenRouter API key is invalid or expired. Please update it in the Models page.".to_string(),
        );
    }

    // Model-specific feature/capability restrictions.
    if has_any(&["not enabled", "developer instruction", "not supported"])
        || (lower.contains("invalid request") && (lower.contains("model") || lower.contains("instruction")))
    {
        return (
            "model_restriction",
            "This model has a restriction that prevents it from being used with this request. Try switching to a different model.".to_string(),
        );
    }

    ("generic", format!("OpenRouter error: {}", msg))
}