Skip to main content

offline_intelligence/api/
search_api.rs

//! Search API endpoints — hybrid semantic + keyword search across conversations.
//!
//! When a query comes in:
//! 1. Generate an embedding for the query via llama-server /v1/embeddings
//! 2. Search the HNSW index (or its linear fallback) for semantically similar messages
//!    (cosine similarity)
//! 3. Run a keyword search as well: it supplements the semantic hits, and it is the only
//!    source of results when embeddings are unavailable
//! 4. Merge the two result sets (deduplicated by message id) and rank them by relevance score
8
9use axum::{
10    extract::State,
11    http::StatusCode,
12    response::IntoResponse,
13    Json,
14};
15use std::sync::Arc;
16use serde::{Deserialize, Serialize};
17use tracing::{info, debug};
18
19use crate::shared_state::SharedState;
20
21/// Search request payload
22#[derive(Debug, Deserialize)]
23pub struct SearchRequest {
24    pub query: String,
25    pub session_id: Option<String>,
26    pub limit: Option<i32>,
27    /// Minimum similarity threshold for semantic results (0.0 - 1.0, default 0.3)
28    pub similarity_threshold: Option<f32>,
29}
30
31/// Search response
32#[derive(Debug, Serialize)]
33pub struct SearchResponse {
34    pub results: Vec<SearchResult>,
35    pub total: usize,
36    pub search_type: String, // "semantic", "keyword", or "hybrid"
37}
38
39/// Individual search result
40#[derive(Debug, Serialize, Clone)]
41pub struct SearchResult {
42    pub session_id: String,
43    pub message_id: i64,
44    pub content: String,
45    pub role: String,
46    pub relevance_score: f32,
47    pub search_source: String, // "semantic" or "keyword"
48}
49
50/// Search endpoint handler — hybrid semantic + keyword search
51pub async fn search(
52    State(shared_state): State<Arc<SharedState>>,
53    Json(payload): Json<SearchRequest>,
54) -> Result<impl IntoResponse, (StatusCode, String)> {
55    info!("Search request: query='{}', session={:?}, limit={:?}",
56          payload.query, payload.session_id, payload.limit);
57
58    // Validate input
59    if payload.query.trim().is_empty() {
60        return Err((StatusCode::BAD_REQUEST, "Query cannot be empty".to_string()));
61    }
62
63    let limit = payload.limit.unwrap_or(10).clamp(1, 100) as usize;
64    let similarity_threshold = payload.similarity_threshold.unwrap_or(0.3);
65
66    let mut all_results: Vec<SearchResult> = Vec::new();
67    let mut search_type = String::from("keyword"); // default
68
69    // ── Phase 1: Semantic search via embeddings ──
70    let llm_worker = &shared_state.llm_worker;
71    let db = &shared_state.database_pool;
72
73    // Try to generate query embedding
74    match llm_worker.generate_embeddings(vec![payload.query.clone()]).await {
75        Ok(query_embeddings) if !query_embeddings.is_empty() => {
76            let query_vec = &query_embeddings[0];
77
78            // Search HNSW index (or linear fallback) for similar message embeddings
79            match db.embeddings.find_similar_embeddings(
80                query_vec,
81                "llama-server",
82                (limit * 2) as i32, // fetch extra, we'll filter
83                similarity_threshold,
84            ) {
85                Ok(similar_ids) if !similar_ids.is_empty() => {
86                    search_type = "semantic".to_string();
87                    debug!("Semantic search found {} candidates", similar_ids.len());
88
89                    // Fetch the actual messages for each matching embedding
90                    for (message_id, similarity) in &similar_ids {
91                        // Get the message content from DB
92                        if let Ok(Some(session_id_filter)) = get_message_session_id(db, *message_id) {
93                            // If session filter is set, skip messages from other sessions
94                            if let Some(ref filter_sid) = payload.session_id {
95                                if &session_id_filter != filter_sid {
96                                    continue;
97                                }
98                            }
99                            if let Ok(msg) = get_message_by_id(db, *message_id) {
100                                all_results.push(SearchResult {
101                                    session_id: session_id_filter,
102                                    message_id: *message_id,
103                                    content: msg.content,
104                                    role: msg.role,
105                                    relevance_score: *similarity,
106                                    search_source: "semantic".to_string(),
107                                });
108                            }
109                        }
110                    }
111                }
112                Ok(_) => {
113                    debug!("Semantic search returned no results above threshold {}", similarity_threshold);
114                }
115                Err(e) => {
116                    debug!("Semantic search failed (falling back to keyword): {}", e);
117                }
118            }
119        }
120        Ok(_) => {
121            debug!("Empty embedding response, falling back to keyword search");
122        }
123        Err(e) => {
124            debug!("Embedding generation unavailable ({}), using keyword search only", e);
125        }
126    }
127
128    // ── Phase 2: Keyword search (always runs as fallback/supplement) ──
129    let keywords: Vec<String> = payload.query
130        .split_whitespace()
131        .filter(|word| word.len() > 2)
132        .map(|s| s.to_lowercase())
133        .collect();
134
135    if !keywords.is_empty() {
136        let orchestrator_guard = shared_state.context_orchestrator.read().await;
137        if let Some(orchestrator) = &*orchestrator_guard {
138            if let Ok(stored_messages) = orchestrator.search_messages(
139                payload.session_id.as_deref(),
140                &keywords,
141                limit,
142            ).await {
143                let semantic_ids: std::collections::HashSet<i64> = all_results.iter()
144                    .map(|r| r.message_id)
145                    .collect();
146
147                for msg in stored_messages {
148                    // Skip duplicates already found by semantic search
149                    if semantic_ids.contains(&msg.id) {
150                        continue;
151                    }
152
153                    let keyword_score = calculate_relevance(&msg.content, &keywords);
154                    all_results.push(SearchResult {
155                        session_id: msg.session_id,
156                        message_id: msg.id,
157                        content: msg.content,
158                        role: msg.role,
159                        relevance_score: keyword_score,
160                        search_source: "keyword".to_string(),
161                    });
162                }
163
164                if search_type == "semantic" && all_results.iter().any(|r| r.search_source == "keyword") {
165                    search_type = "hybrid".to_string();
166                }
167            }
168        }
169    }
170
171    // ── Phase 3: Sort by relevance and truncate ──
172    all_results.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap_or(std::cmp::Ordering::Equal));
173    all_results.truncate(limit);
174
175    let total = all_results.len();
176    info!("Search completed: {} results ({})", total, search_type);
177
178    Ok(Json(SearchResponse {
179        results: all_results,
180        total,
181        search_type,
182    }))
183}
184
/// Calculate a keyword relevance score in `[0.0, 1.0]`.
///
/// For each keyword, counts its non-overlapping occurrences in the lowercased
/// `content` and adds `occurrences * keyword_len / content_len`. This is the fraction
/// of the lowercased text covered by that keyword, with lengths measured in bytes.
/// The sum over all keywords is capped at `1.0`. Keywords are matched as given, so
/// they should already be lowercase (the `search` handler lowercases them).
/// Empty content or an empty keyword list scores `0.0`.
fn calculate_relevance(content: &str, keywords: &[String]) -> f32 {
    let haystack = content.to_lowercase();
    // Measure against the string that is actually searched: `to_lowercase` can change
    // the byte length of non-ASCII text, so using `content.len()` would skew the ratio.
    let total_len = haystack.len().max(1) as f32;

    let score = keywords
        .iter()
        // An empty keyword would contribute 0 anyway; skip the pointless scan.
        .filter(|keyword| !keyword.is_empty())
        .map(|keyword| {
            let hits = haystack.matches(keyword.as_str()).count();
            hits as f32 * (keyword.len() as f32 / total_len)
        })
        .fold(0.0_f32, |acc, part| acc + part);

    score.min(1.0)
}
199
200/// Helper: get the session_id for a message by its ID
201fn get_message_session_id(
202    db: &crate::memory_db::MemoryDatabase,
203    message_id: i64,
204) -> anyhow::Result<Option<String>> {
205    let pool_conn = db.conversations.get_conn_public()?;
206    let mut stmt = pool_conn.prepare(
207        "SELECT session_id FROM messages WHERE id = ?1"
208    )?;
209    let mut rows = stmt.query([message_id])?;
210    if let Some(row) = rows.next()? {
211        let sid: String = row.get(0)?;
212        Ok(Some(sid))
213    } else {
214        Ok(None)
215    }
216}
217
/// Content and role of a single `messages` row, as loaded by `get_message_by_id`.
struct MinimalMessage {
    /// Value of the `content` column.
    content: String,
    /// Value of the `role` column.
    role: String,
}
223
224fn get_message_by_id(
225    db: &crate::memory_db::MemoryDatabase,
226    message_id: i64,
227) -> anyhow::Result<MinimalMessage> {
228    let conn = db.conversations.get_conn_public()?;
229    let mut stmt = conn.prepare(
230        "SELECT content, role FROM messages WHERE id = ?1"
231    )?;
232    let mut rows = stmt.query([message_id])?;
233    if let Some(row) = rows.next()? {
234        Ok(MinimalMessage {
235            content: row.get(0)?,
236            role: row.get(1)?,
237        })
238    } else {
239        Err(anyhow::anyhow!("Message {} not found", message_id))
240    }
241}