Skip to main content

offline_intelligence/api/
search_api.rs

//! Search API endpoints — hybrid semantic + keyword search across conversations.
//!
//! When a query comes in:
//! 1. Generate an embedding for the query via llama-server /v1/embeddings
//! 2. Search the HNSW index (or a linear fallback) for semantically similar messages (cosine similarity)
//! 3. Run a keyword search as well — it supplements the semantic results and is the only source of results when embeddings are unavailable
//! 4. Merge the two result sets (dropping duplicates) and rank them by relevance score
8
9use axum::{
10    extract::State,
11    http::StatusCode,
12    response::IntoResponse,
13    Json,
14};
15use std::sync::Arc;
16use serde::{Deserialize, Serialize};
17use tracing::{info, warn, debug};
18
19use crate::shared_state::SharedState;
20use crate::worker_threads::LLMWorker;
21
22/// Search request payload
23#[derive(Debug, Deserialize)]
24pub struct SearchRequest {
25    pub query: String,
26    pub session_id: Option<String>,
27    pub limit: Option<i32>,
28    /// Minimum similarity threshold for semantic results (0.0 - 1.0, default 0.3)
29    pub similarity_threshold: Option<f32>,
30}
31
32/// Search response
33#[derive(Debug, Serialize)]
34pub struct SearchResponse {
35    pub results: Vec<SearchResult>,
36    pub total: usize,
37    pub search_type: String, // "semantic", "keyword", or "hybrid"
38}
39
40/// Individual search result
41#[derive(Debug, Serialize, Clone)]
42pub struct SearchResult {
43    pub session_id: String,
44    pub message_id: i64,
45    pub content: String,
46    pub role: String,
47    pub relevance_score: f32,
48    pub search_source: String, // "semantic" or "keyword"
49}
50
51/// Search endpoint handler — hybrid semantic + keyword search
52pub async fn search(
53    State(shared_state): State<Arc<SharedState>>,
54    Json(payload): Json<SearchRequest>,
55) -> Result<impl IntoResponse, (StatusCode, String)> {
56    info!("Search request: query='{}', session={:?}, limit={:?}",
57          payload.query, payload.session_id, payload.limit);
58
59    // Validate input
60    if payload.query.trim().is_empty() {
61        return Err((StatusCode::BAD_REQUEST, "Query cannot be empty".to_string()));
62    }
63
64    let limit = payload.limit.unwrap_or(10).clamp(1, 100) as usize;
65    let similarity_threshold = payload.similarity_threshold.unwrap_or(0.3);
66
67    let mut all_results: Vec<SearchResult> = Vec::new();
68    let mut search_type = String::from("keyword"); // default
69
70    // ── Phase 1: Semantic search via embeddings ──
71    let llm_worker = &shared_state.llm_worker;
72    let db = &shared_state.database_pool;
73
74    // Try to generate query embedding
75    match llm_worker.generate_embeddings(vec![payload.query.clone()]).await {
76        Ok(query_embeddings) if !query_embeddings.is_empty() => {
77            let query_vec = &query_embeddings[0];
78
79            // Search HNSW index (or linear fallback) for similar message embeddings
80            match db.embeddings.find_similar_embeddings(
81                query_vec,
82                "llama-server",
83                (limit * 2) as i32, // fetch extra, we'll filter
84                similarity_threshold,
85            ) {
86                Ok(similar_ids) if !similar_ids.is_empty() => {
87                    let similar_ids: Vec<(i64, f32)> = similar_ids;
88                    search_type = "semantic".to_string();
89                    debug!("Semantic search found {} candidates", similar_ids.len());
90
91                    // Fetch the actual messages for each matching embedding
92                    for (message_id, similarity) in &similar_ids {
93                        // Get the message content from DB
94                        if let Ok(Some(session_id_filter)) = get_message_session_id(db, *message_id) {
95                            // If session filter is set, skip messages from other sessions
96                            if let Some(ref filter_sid) = payload.session_id {
97                                if &session_id_filter != filter_sid {
98                                    continue;
99                                }
100                            }
101                            if let Ok(msg) = get_message_by_id(db, *message_id) {
102                                all_results.push(SearchResult {
103                                    session_id: session_id_filter,
104                                    message_id: *message_id,
105                                    content: msg.content,
106                                    role: msg.role,
107                                    relevance_score: *similarity,
108                                    search_source: "semantic".to_string(),
109                                });
110                            }
111                        }
112                    }
113                }
114                Ok(_) => {
115                    debug!("Semantic search returned no results above threshold {}", similarity_threshold);
116                }
117                Err(e) => {
118                    debug!("Semantic search failed (falling back to keyword): {}", e);
119                }
120            }
121        }
122        Ok(_) => {
123            debug!("Empty embedding response, falling back to keyword search");
124        }
125        Err(e) => {
126            debug!("Embedding generation unavailable ({}), using keyword search only", e);
127        }
128    }
129
130    // ── Phase 2: Keyword search (always runs as fallback/supplement) ──
131    let keywords: Vec<String> = payload.query
132        .split_whitespace()
133        .filter(|word| word.len() > 2)
134        .map(|s| s.to_lowercase())
135        .collect();
136
137    if !keywords.is_empty() {
138        let orchestrator_guard = shared_state.context_orchestrator.read().await;
139        if let Some(orchestrator) = &*orchestrator_guard {
140            if let Ok(stored_messages) = orchestrator.search_messages(
141                payload.session_id.as_deref(),
142                &keywords,
143                limit,
144            ).await {
145                let stored_messages: Vec<crate::memory_db::StoredMessage> = stored_messages;
146                let semantic_ids: std::collections::HashSet<i64> = all_results.iter()
147                    .map(|r| r.message_id)
148                    .collect();
149
150                for msg in stored_messages {
151                    // Skip duplicates already found by semantic search
152                    if semantic_ids.contains(&msg.id) {
153                        continue;
154                    }
155
156                    let keyword_score = calculate_relevance(&msg.content, &keywords);
157                    all_results.push(SearchResult {
158                        session_id: msg.session_id,
159                        message_id: msg.id,
160                        content: msg.content,
161                        role: msg.role,
162                        relevance_score: keyword_score,
163                        search_source: "keyword".to_string(),
164                    });
165                }
166
167                if search_type == "semantic" && all_results.iter().any(|r| r.search_source == "keyword") {
168                    search_type = "hybrid".to_string();
169                }
170            }
171        }
172    }
173
174    // ── Phase 3: Sort by relevance and truncate ──
175    all_results.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap_or(std::cmp::Ordering::Equal));
176    all_results.truncate(limit);
177
178    let total = all_results.len();
179    info!("Search completed: {} results ({})", total, search_type);
180
181    Ok(Json(SearchResponse {
182        results: all_results,
183        total,
184        search_type,
185    }))
186}
187
/// Calculate a keyword relevance score for `content`, in `0.0..=1.0`.
///
/// Matching is case-insensitive: `content` and every keyword are lowercased
/// before counting non-overlapping occurrences. Each occurrence contributes
/// `keyword_len / content_len`, using the byte lengths of the lowercased
/// strings. The score therefore approximates the fraction of the content
/// covered by keyword matches. The sum is capped at `1.0`. Empty keywords
/// contribute nothing, and empty content scores `0.0`.
fn calculate_relevance(content: &str, keywords: &[String]) -> f32 {
    let content_lower = content.to_lowercase();
    // Normalise by the length of the string actually searched. Lowercasing can
    // change UTF-8 byte length (e.g. U+212A KELVIN SIGN, 3 bytes -> "k", 1 byte).
    let content_len = content_lower.len().max(1) as f32;

    let mut score = 0.0_f32;
    for keyword in keywords {
        let keyword = keyword.to_lowercase();
        if keyword.is_empty() {
            continue;
        }
        let matches = content_lower.matches(keyword.as_str()).count();
        score += matches as f32 * (keyword.len() as f32 / content_len);
    }

    score.min(1.0)
}
202
203/// Helper: get the session_id for a message by its ID
204fn get_message_session_id(
205    db: &crate::memory_db::MemoryDatabase,
206    message_id: i64,
207) -> anyhow::Result<Option<String>> {
208    let pool_conn = db.conversations.get_conn_public()?;
209    let mut stmt = pool_conn.prepare(
210        "SELECT session_id FROM messages WHERE id = ?1"
211    )?;
212    let mut rows = stmt.query([message_id])?;
213    if let Some(row) = rows.next()? {
214        let sid: String = row.get::<usize, String>(0)?;
215        Ok(Some(sid))
216    } else {
217        Ok(None)
218    }
219}
220
221/// Helper: minimal message data by ID
222struct MinimalMessage {
223    content: String,
224    role: String,
225}
226
227fn get_message_by_id(
228    db: &crate::memory_db::MemoryDatabase,
229    message_id: i64,
230) -> anyhow::Result<MinimalMessage> {
231    let conn = db.conversations.get_conn_public()?;
232    let mut stmt = conn.prepare(
233        "SELECT content, role FROM messages WHERE id = ?1"
234    )?;
235    let mut rows = stmt.query([message_id])?;
236    if let Some(row) = rows.next()? {
237        Ok(MinimalMessage {
238            content: row.get::<usize, String>(0)?,
239            role: row.get::<usize, String>(1)?,
240        })
241    } else {
242        Err(anyhow::anyhow!("Message {} not found", message_id))
243    }
244}