offline_intelligence/api/stream_api.rs

use axum::{
    extract::State,
    response::{
        sse::{Event, Sse},
        IntoResponse, Response,
    },
    http::StatusCode,
    Json,
};
use futures_util::StreamExt;
use serde::Deserialize;
use std::convert::Infallible;
use tracing::{info, error, debug};

use crate::memory::Message;
use crate::memory_db::schema::Embedding;
use crate::shared_state::UnifiedAppState;

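/// Request body for the streaming chat endpoint. `max_tokens`, `temperature`,
/// and `stream` fall back to serde defaults when omitted; `model` is optional
/// and, like `stream`, is not read anywhere in `generate_stream` below.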
#[derive(Debug, Deserialize)]
pub struct StreamChatRequest {
    pub model: Option<String>,
    pub messages: Vec<Message>,
    pub session_id: String,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_stream")]
    pub stream: bool,
}

fn default_max_tokens() -> u32 { 2000 }
fn default_temperature() -> f32 { 0.7 }
fn default_stream() -> bool { true }

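/// SSE chat handler: validates the request, updates the in-memory session,
/// persists the latest user message in the background, asks the context
/// orchestrator (when initialized) for an optimized message window, then
/// proxies the LLM worker's SSE stream while accumulating the full reply so
/// it can be persisted and embedded once streaming completes.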
pub async fn generate_stream(
    State(state): State<UnifiedAppState>,
    Json(req): Json<StreamChatRequest>,
) -> Response {
    let request_num = state.shared_state.counters.inc_total_requests();
    info!("Stream request #{} for session: {}", request_num, req.session_id);

    if req.messages.is_empty() {
        return (StatusCode::BAD_REQUEST, "Messages array cannot be empty").into_response();
    }

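    // Touch (or create) the in-memory session and mirror the incoming messages into it.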
    let session_id = req.session_id.clone();

    let session = state.shared_state.get_or_create_session(&session_id).await;

    {
        if let Ok(mut session_data) = session.write() {
            session_data.last_accessed = std::time::Instant::now();
            session_data.messages = req.messages.clone();
        }
    }

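    // Persist the most recent user message off the request path; database errors
    // are logged but never block streaming.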
    let user_msg_content = req
        .messages
        .iter()
        .rev()
        .find(|m| m.role == "user")
        .map(|m| m.content.clone());
    if let Some(ref content) = user_msg_content {
        let db = state.shared_state.database_pool.clone();
        let sid = session_id.clone();
        let content = content.clone();
        let msg_count = req.messages.len() as i32;
        tokio::spawn(async move {
            let _ = db.conversations.create_session_with_id(&sid, None);
            if let Err(e) = db.conversations.store_messages_batch(
                &sid,
                &[("user".to_string(), content, msg_count - 1, 0, 0.5)],
            ) {
                error!("Failed to persist user message: {}", e);
            }
        });
    }

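    // Let the context orchestrator build an optimized message window (summaries
    // plus retrieved past context). Any failure falls back to the raw messages.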
    let context_messages = {
        let orchestrator_guard = state.context_orchestrator.read().await;
        if let Some(ref orchestrator) = *orchestrator_guard {
            let user_query = user_msg_content.as_deref();
            match orchestrator.process_conversation(&session_id, &req.messages, user_query).await {
                Ok(optimized) => {
                    if optimized.len() != req.messages.len() {
                        info!("Context engine optimized: {} → {} messages (retrieved past context)",
                              req.messages.len(), optimized.len());
                    }
                    optimized
                }
                Err(e) => {
                    error!("Context engine error (falling back to raw messages): {}", e);
                    req.messages.clone()
                }
            }
        } else {
            debug!("Context orchestrator not initialized, using raw messages");
            req.messages.clone()
        }
    };

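    // Clone everything the stream closure and its background tasks need, since
    // they outlive this handler's scope.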
    let llm_worker = state.llm_worker.clone();
    let max_tokens = req.max_tokens;
    let temperature = req.temperature;
    let db_for_persist = state.shared_state.database_pool.clone();
    let session_id_for_persist = session_id.clone();
    let msg_index = req.messages.len() as i32;

    let llm_worker_for_embed = state.llm_worker.clone();
    let db_for_embed_persist = state.shared_state.database_pool.clone();
    let session_id_for_embed = session_id.clone();
    let user_msg_for_embed = user_msg_content.clone();

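    // Start the upstream LLM stream and wrap it in our own SSE stream that both
    // forwards chunks to the client and accumulates the full assistant reply.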
    match llm_worker.stream_response(context_messages, max_tokens, temperature).await {
        Ok(llm_stream) => {
            let output_stream = async_stream::stream! {
                let mut full_response = String::new();

                futures_util::pin_mut!(llm_stream);

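                // Forward each upstream SSE line to the client while extracting
                // delta content to rebuild the complete reply for persistence.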
                while let Some(item) = llm_stream.next().await {
                    match item {
                        Ok(sse_line) => {
                            if sse_line.starts_with("data: ") && !sse_line.contains("[DONE]") {
                                if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(sse_line[6..].trim()) {
                                    if let Some(content) = chunk
                                        .get("choices")
                                        .and_then(|c| c.get(0))
                                        .and_then(|c| c.get("delta"))
                                        .and_then(|d| d.get("content"))
                                        .and_then(|c| c.as_str())
                                    {
                                        full_response.push_str(content);
                                    }
                                }
                            }

                            let data = sse_line.trim_start_matches("data: ").trim_end().to_string();
                            yield Ok::<_, Infallible>(Event::default().data(data));
                        }
                        Err(e) => {
                            error!("Stream error: {}", e);
                            yield Ok(Event::default().data(
                                format!("{{\"error\": \"{}\"}}", e)
                            ));
                            break;
                        }
                    }
                }

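                // Stream finished: persist the accumulated assistant reply. Note
                // that this block only runs if the stream is polled to completion,
                // i.e. not when the client disconnects mid-stream.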
                if !full_response.is_empty() {
                    match db_for_persist.conversations.store_messages_batch(
                        &session_id_for_persist,
                        &[("assistant".to_string(), full_response.clone(), msg_index, 0, 0.5)],
                    ) {
                        Ok(stored_msgs) => {
                            debug!("Persisted assistant response ({} chars) for session {}",
                                   full_response.len(), session_id_for_persist);

                            let llm_for_embed = llm_worker_for_embed.clone();
                            let db_for_embed = db_for_embed_persist.clone();
                            let assistant_content = full_response.clone();
                            let user_content_for_embed = user_msg_for_embed.clone();
                            let stored = stored_msgs;

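                            // Generate embeddings for the user/assistant pair in a
                            // detached task so vector indexing never delays the response.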
                            tokio::spawn(async move {
                                let mut texts = Vec::new();
                                let mut message_ids = Vec::new();

                                // Look up the stored user message so its row id can be
                                // paired with an embedding.
                                if let Some(ref user_text) = user_content_for_embed {
                                    if let Ok(msgs) = db_for_embed.search_messages_by_keywords(
                                        &session_id_for_embed,
                                        &[user_text.clone()],
                                        1,
                                    ).await {
                                        if let Some(user_stored) = msgs.first() {
                                            texts.push(user_text.clone());
                                            message_ids.push(user_stored.id);
                                        }
                                    }
                                }

                                if let Some(assistant_stored) = stored.first() {
                                    texts.push(assistant_content);
                                    message_ids.push(assistant_stored.id);
                                }

                                if texts.is_empty() {
                                    return;
                                }

                                match llm_for_embed.generate_embeddings(texts).await {
                                    Ok(embeddings) => {
                                        let now = chrono::Utc::now();
                                        for (embedding_vec, msg_id) in embeddings.into_iter().zip(message_ids.iter()) {
                                            let emb = Embedding {
                                                id: 0,
                                                message_id: *msg_id,
                                                embedding: embedding_vec,
                                                embedding_model: "llama-server".to_string(),
                                                generated_at: now,
                                            };
                                            if let Err(e) = db_for_embed.embeddings.store_embedding(&emb) {
                                                debug!("Failed to store embedding for msg {}: {}", msg_id, e);
                                            }
                                        }
                                        for msg_id in &message_ids {
                                            let _ = db_for_embed.conversations.mark_embedding_generated(*msg_id);
                                        }
                                        debug!("Stored {} embeddings for session {}", message_ids.len(), session_id_for_embed);
                                    }
                                    Err(e) => {
                                        debug!("Embedding generation skipped (llama-server may not support /v1/embeddings): {}", e);
                                    }
                                }
                            });
                        }
                        Err(e) => {
                            error!("Failed to persist assistant message: {}", e);
                        }
                    }
                }
            };

            // Emit SSE keep-alive comments every 15 seconds so idle proxies do
            // not drop the connection between tokens.
            Sse::new(output_stream)
                .keep_alive(
                    axum::response::sse::KeepAlive::new()
                        .interval(std::time::Duration::from_secs(15))
                )
                .into_response()
        }
        Err(e) => {
            error!("Failed to start LLM stream: {}", e);
            (StatusCode::BAD_GATEWAY, format!("LLM backend error: {}", e)).into_response()
        }
    }
}
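
// ---------------------------------------------------------------------------
// A minimal wiring sketch (not part of the original module): one way this
// handler could be mounted on an axum Router. The route path
// "/v1/chat/stream" and the `stream_routes` helper name are illustrative
// assumptions; the file above does not show how routing is configured.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
pub fn stream_routes(state: UnifiedAppState) -> axum::Router {
    use axum::routing::post;

    axum::Router::new()
        // Example POST body: {"session_id": "abc", "messages": [{"role": "user", "content": "hi"}]}
        .route("/v1/chat/stream", post(generate_stream))
        .with_state(state)
}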