Skip to main content

offline_intelligence/api/
stream_api.rs

//! Streaming chat endpoint — the core 1-hop architecture handler.
//!
//! Flow: Client POST → SharedState (session + cache lookup) → LLM Worker (HTTP to llama-server) → SSE stream back.
//! All state access is in-process via Arc/shared memory. The only network hop is to localhost llama-server.
5
6use axum::{
7    extract::State,
8    response::{
9        sse::{Event, Sse},
10        IntoResponse, Response,
11    },
12    http::StatusCode,
13    Json,
14};
15use futures_util::StreamExt;
16use serde::Deserialize;
17use std::convert::Infallible;
18use tracing::{info, error, debug};
19
20use crate::memory::Message;
21use crate::memory_db::schema::Embedding;
22use crate::shared_state::UnifiedAppState;
23
24/// Request body matching what the frontend sends
25#[derive(Debug, Deserialize)]
26pub struct StreamChatRequest {
27    pub model: Option<String>,
28    pub messages: Vec<Message>,
29    pub session_id: String,
30    #[serde(default = "default_max_tokens")]
31    pub max_tokens: u32,
32    #[serde(default = "default_temperature")]
33    pub temperature: f32,
34    #[serde(default = "default_stream")]
35    pub stream: bool,
36}
37
/// Serde default for `StreamChatRequest::max_tokens` when the field is omitted.
fn default_max_tokens() -> u32 {
    2000
}
/// Serde default for `StreamChatRequest::temperature` when the field is omitted.
fn default_temperature() -> f32 {
    0.7
}
/// Serde default for `StreamChatRequest::stream` when the field is omitted.
fn default_stream() -> bool {
    true
}
41
42/// POST /generate/stream — Main streaming chat endpoint
43///
44/// 1. Validates request and gets/creates session in shared memory
45/// 2. Persists user message to database
46/// 3. Streams LLM response back via SSE
47/// 4. Persists assistant response to database after completion
48pub async fn generate_stream(
49    State(state): State<UnifiedAppState>,
50    Json(req): Json<StreamChatRequest>,
51) -> Response {
52    let request_num = state.shared_state.counters.inc_total_requests();
53    info!("Stream request #{} for session: {}", request_num, req.session_id);
54
55    if req.messages.is_empty() {
56        return (StatusCode::BAD_REQUEST, "Messages array cannot be empty").into_response();
57    }
58
59    let session_id = req.session_id.clone();
60
61    // 1. Get or create session in shared memory (zero-cost Arc lookup)
62    let session = state.shared_state.get_or_create_session(&session_id).await;
63
64    // 2. Update in-memory session with the incoming messages
65    {
66        if let Ok(mut session_data) = session.write() {
67            session_data.last_accessed = std::time::Instant::now();
68            session_data.messages = req.messages.clone();
69        }
70    }
71
72    // 3. Ensure session exists in database and persist user message
73    //    Also capture the user message content for embedding generation later
74    let user_msg_content = req.messages.iter().rev().find(|m| m.role == "user").map(|m| m.content.clone());
75    if let Some(ref content) = user_msg_content {
76        let db = state.shared_state.database_pool.clone();
77        let sid = session_id.clone();
78        let content = content.clone();
79        let msg_count = req.messages.len() as i32;
80        tokio::spawn(async move {
81            // Ensure session exists in DB (ignore error if already exists)
82            let _ = db.conversations.create_session_with_id(&sid, None);
83            // Persist user message via batch API
84            if let Err(e) = db.conversations.store_messages_batch(
85                &sid,
86                &[("user".to_string(), content, msg_count - 1, 0, 0.5)],
87            ) {
88                error!("Failed to persist user message: {}", e);
89            }
90        });
91    }
92
93    // 4. Context Engine: Retrieve past context via semantic search when KV cache misses.
94    //    Always let the retrieval planner decide — even a brand-new session can trigger
95    //    cross-session search if the user asks "what did we discuss yesterday?".
96    //    The planner + orchestrator handle the "nothing to search" case internally
97    //    (checks has_embeddings > 0 before hitting llama-server, returns early if no past refs).
98    let context_messages = {
99        let orchestrator_guard = state.context_orchestrator.read().await;
100        if let Some(ref orchestrator) = *orchestrator_guard {
101            let user_query = user_msg_content.as_deref();
102            match orchestrator.process_conversation(&session_id, &req.messages, user_query).await {
103                Ok(optimized) => {
104                    if optimized.len() != req.messages.len() {
105                        info!("Context engine optimized: {} → {} messages (retrieved past context)",
106                            req.messages.len(), optimized.len());
107                    }
108                    optimized
109                }
110                Err(e) => {
111                    error!("Context engine error (falling back to raw messages): {}", e);
112                    req.messages.clone()
113                }
114            }
115        } else {
116            debug!("Context orchestrator not initialized, using raw messages");
117            req.messages.clone()
118        }
119    };
120
121    // 5. Stream from LLM worker (single network hop to localhost llama-server)
122    let llm_worker = state.llm_worker.clone();
123    let max_tokens = req.max_tokens;
124    let temperature = req.temperature;
125    let db_for_persist = state.shared_state.database_pool.clone();
126    let session_id_for_persist = session_id.clone();
127    let msg_index = req.messages.len() as i32;
128
129    // Clones for background embedding generation after stream completes
130    let llm_worker_for_embed = state.llm_worker.clone();
131    let db_for_embed_persist = state.shared_state.database_pool.clone();
132    let session_id_for_embed = session_id.clone();
133    let user_msg_for_embed = user_msg_content.clone();
134
135    match llm_worker.stream_response(context_messages, max_tokens, temperature).await {
136        Ok(llm_stream) => {
137            // Wrap the LLM stream to collect the full response for DB persistence
138            let output_stream = async_stream::stream! {
139                let mut full_response = String::new();
140
141                futures_util::pin_mut!(llm_stream);
142
143                while let Some(item) = llm_stream.next().await {
144                    match item {
145                        Ok(sse_line) => {
146                            // Extract content from SSE data for persistence
147                            if sse_line.starts_with("data: ") && !sse_line.contains("[DONE]") {
148                                if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(&sse_line[6..].trim()) {
149                                    if let Some(content) = chunk
150                                        .get("choices")
151                                        .and_then(|c| c.get(0))
152                                        .and_then(|c| c.get("delta"))
153                                        .and_then(|d| d.get("content"))
154                                        .and_then(|c| c.as_str())
155                                    {
156                                        full_response.push_str(content);
157                                    }
158                                }
159                            }
160
161                            // Yield SSE event to client
162                            let data = sse_line.trim_start_matches("data: ").trim_end().to_string();
163                            yield Ok::<_, Infallible>(Event::default().data(data));
164                        }
165                        Err(e) => {
166                            error!("Stream error: {}", e);
167                            yield Ok(Event::default().data(
168                                format!("{{\"error\": \"{}\"}}", e)
169                            ));
170                            break;
171                        }
172                    }
173                }
174
175                // Persist assistant response to database after stream completes
176                if !full_response.is_empty() {
177                    match db_for_persist.conversations.store_messages_batch(
178                        &session_id_for_persist,
179                        &[("assistant".to_string(), full_response.clone(), msg_index, 0, 0.5)],
180                    ) {
181                        Ok(stored_msgs) => {
182                            debug!("Persisted assistant response ({} chars) for session {}",
183                                full_response.len(), session_id_for_persist);
184
185                            // Background: Generate and store embeddings for the new messages
186                            // This captures the vectors llama.cpp computes via /v1/embeddings
187                            // enabling semantic search for future KV cache misses.
188                            let llm_for_embed = llm_worker_for_embed.clone();
189                            let db_for_embed = db_for_embed_persist.clone();
190                            let assistant_content = full_response.clone();
191                            let user_content_for_embed = user_msg_for_embed.clone();
192                            let stored = stored_msgs;
193
194                            tokio::spawn(async move {
195                                // Collect texts + their message IDs for embedding
196                                let mut texts = Vec::new();
197                                let mut message_ids = Vec::new();
198
199                                // User message embedding (get ID from DB)
200                                if let Some(ref user_text) = user_content_for_embed {
201                                    // The user message was stored one index before the assistant
202                                    // We need its DB ID — query by session + content
203                                    if let Ok(msgs) = db_for_embed.search_messages_by_keywords(
204                                        &session_id_for_embed,
205                                        &[user_text.clone()],
206                                        1,
207                                    ).await {
208                                        if let Some(user_stored) = msgs.first() {
209                                            texts.push(user_text.clone());
210                                            message_ids.push(user_stored.id);
211                                        }
212                                    }
213                                }
214
215                                // Assistant message embedding
216                                if let Some(assistant_stored) = stored.first() {
217                                    texts.push(assistant_content);
218                                    message_ids.push(assistant_stored.id);
219                                }
220
221                                if texts.is_empty() {
222                                    return;
223                                }
224
225                                // Call llama-server /v1/embeddings
226                                match llm_for_embed.generate_embeddings(texts).await {
227                                    Ok(embeddings) => {
228                                        let now = chrono::Utc::now();
229                                        for (embedding_vec, msg_id) in embeddings.into_iter().zip(message_ids.iter()) {
230                                            let emb = Embedding {
231                                                id: 0, // auto-assigned by DB
232                                                message_id: *msg_id,
233                                                embedding: embedding_vec,
234                                                embedding_model: "llama-server".to_string(),
235                                                generated_at: now,
236                                            };
237                                            if let Err(e) = db_for_embed.embeddings.store_embedding(&emb) {
238                                                debug!("Failed to store embedding for msg {}: {}", msg_id, e);
239                                            }
240                                        }
241                                        // Mark messages as having embeddings
242                                        for msg_id in &message_ids {
243                                            let _ = db_for_embed.conversations.mark_embedding_generated(*msg_id);
244                                        }
245                                        debug!("Stored {} embeddings for session {}", message_ids.len(), session_id_for_embed);
246                                    }
247                                    Err(e) => {
248                                        debug!("Embedding generation skipped (llama-server may not support /v1/embeddings): {}", e);
249                                    }
250                                }
251                            });
252                        }
253                        Err(e) => {
254                            error!("Failed to persist assistant message: {}", e);
255                        }
256                    }
257                }
258            };
259
260            Sse::new(output_stream)
261                .keep_alive(
262                    axum::response::sse::KeepAlive::new()
263                        .interval(std::time::Duration::from_secs(15))
264                )
265                .into_response()
266        }
267        Err(e) => {
268            error!("Failed to start LLM stream: {}", e);
269            (StatusCode::BAD_GATEWAY, format!("LLM backend error: {}", e)).into_response()
270        }
271    }
272}