offline_intelligence/api/
search_api.rs1use axum::{
10 extract::State,
11 http::StatusCode,
12 response::IntoResponse,
13 Json,
14};
15use std::sync::Arc;
16use serde::{Deserialize, Serialize};
17use tracing::{info, warn, debug};
18
19use crate::shared_state::SharedState;
20use crate::worker_threads::LLMWorker;
21
22#[derive(Debug, Deserialize)]
24pub struct SearchRequest {
25 pub query: String,
26 pub session_id: Option<String>,
27 pub limit: Option<i32>,
28 pub similarity_threshold: Option<f32>,
30}
31
32#[derive(Debug, Serialize)]
34pub struct SearchResponse {
35 pub results: Vec<SearchResult>,
36 pub total: usize,
37 pub search_type: String, }
39
40#[derive(Debug, Serialize, Clone)]
42pub struct SearchResult {
43 pub session_id: String,
44 pub message_id: i64,
45 pub content: String,
46 pub role: String,
47 pub relevance_score: f32,
48 pub search_source: String, }
50
51pub async fn search(
53 State(shared_state): State<Arc<SharedState>>,
54 Json(payload): Json<SearchRequest>,
55) -> Result<impl IntoResponse, (StatusCode, String)> {
56 info!("Search request: query='{}', session={:?}, limit={:?}",
57 payload.query, payload.session_id, payload.limit);
58
59 if payload.query.trim().is_empty() {
61 return Err((StatusCode::BAD_REQUEST, "Query cannot be empty".to_string()));
62 }
63
64 let limit = payload.limit.unwrap_or(10).clamp(1, 100) as usize;
65 let similarity_threshold = payload.similarity_threshold.unwrap_or(0.3);
66
67 let mut all_results: Vec<SearchResult> = Vec::new();
68 let mut search_type = String::from("keyword"); let llm_worker = &shared_state.llm_worker;
72 let db = &shared_state.database_pool;
73
74 match llm_worker.generate_embeddings(vec![payload.query.clone()]).await {
76 Ok(query_embeddings) if !query_embeddings.is_empty() => {
77 let query_vec = &query_embeddings[0];
78
79 match db.embeddings.find_similar_embeddings(
81 query_vec,
82 "llama-server",
83 (limit * 2) as i32, similarity_threshold,
85 ) {
86 Ok(similar_ids) if !similar_ids.is_empty() => {
87 let similar_ids: Vec<(i64, f32)> = similar_ids;
88 search_type = "semantic".to_string();
89 debug!("Semantic search found {} candidates", similar_ids.len());
90
91 for (message_id, similarity) in &similar_ids {
93 if let Ok(Some(session_id_filter)) = get_message_session_id(db, *message_id) {
95 if let Some(ref filter_sid) = payload.session_id {
97 if &session_id_filter != filter_sid {
98 continue;
99 }
100 }
101 if let Ok(msg) = get_message_by_id(db, *message_id) {
102 all_results.push(SearchResult {
103 session_id: session_id_filter,
104 message_id: *message_id,
105 content: msg.content,
106 role: msg.role,
107 relevance_score: *similarity,
108 search_source: "semantic".to_string(),
109 });
110 }
111 }
112 }
113 }
114 Ok(_) => {
115 debug!("Semantic search returned no results above threshold {}", similarity_threshold);
116 }
117 Err(e) => {
118 debug!("Semantic search failed (falling back to keyword): {}", e);
119 }
120 }
121 }
122 Ok(_) => {
123 debug!("Empty embedding response, falling back to keyword search");
124 }
125 Err(e) => {
126 debug!("Embedding generation unavailable ({}), using keyword search only", e);
127 }
128 }
129
130 let keywords: Vec<String> = payload.query
132 .split_whitespace()
133 .filter(|word| word.len() > 2)
134 .map(|s| s.to_lowercase())
135 .collect();
136
137 if !keywords.is_empty() {
138 let orchestrator_guard = shared_state.context_orchestrator.read().await;
139 if let Some(orchestrator) = &*orchestrator_guard {
140 if let Ok(stored_messages) = orchestrator.search_messages(
141 payload.session_id.as_deref(),
142 &keywords,
143 limit,
144 ).await {
145 let stored_messages: Vec<crate::memory_db::StoredMessage> = stored_messages;
146 let semantic_ids: std::collections::HashSet<i64> = all_results.iter()
147 .map(|r| r.message_id)
148 .collect();
149
150 for msg in stored_messages {
151 if semantic_ids.contains(&msg.id) {
153 continue;
154 }
155
156 let keyword_score = calculate_relevance(&msg.content, &keywords);
157 all_results.push(SearchResult {
158 session_id: msg.session_id,
159 message_id: msg.id,
160 content: msg.content,
161 role: msg.role,
162 relevance_score: keyword_score,
163 search_source: "keyword".to_string(),
164 });
165 }
166
167 if search_type == "semantic" && all_results.iter().any(|r| r.search_source == "keyword") {
168 search_type = "hybrid".to_string();
169 }
170 }
171 }
172 }
173
174 all_results.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap_or(std::cmp::Ordering::Equal));
176 all_results.truncate(limit);
177
178 let total = all_results.len();
179 info!("Search completed: {} results ({})", total, search_type);
180
181 Ok(Json(SearchResponse {
182 results: all_results,
183 total,
184 search_type,
185 }))
186}
187
/// Scores how strongly `content` matches `keywords`, in the range `0.0..=1.0`.
///
/// `content` is lower-cased before matching; `keywords` are used as given, so
/// callers are expected to pass them already lower-cased. Each keyword that
/// occurs contributes `occurrences * keyword_len / content_len` (byte lengths),
/// and the sum is capped at 1.0.
fn calculate_relevance(content: &str, keywords: &[String]) -> f32 {
    let haystack = content.to_lowercase();
    // `max(1)` guards the division for empty content.
    let content_len = content.len().max(1) as f32;

    let total = keywords
        .iter()
        .filter_map(|keyword| {
            let hits = haystack.matches(keyword.as_str()).count();
            (hits > 0).then(|| hits as f32 * (keyword.len() as f32 / content_len))
        })
        // Explicit +0.0 seed so "no match" yields positive zero.
        .fold(0.0_f32, |acc, contribution| acc + contribution);

    total.min(1.0)
}
202
203fn get_message_session_id(
205 db: &crate::memory_db::MemoryDatabase,
206 message_id: i64,
207) -> anyhow::Result<Option<String>> {
208 let pool_conn = db.conversations.get_conn_public()?;
209 let mut stmt = pool_conn.prepare(
210 "SELECT session_id FROM messages WHERE id = ?1"
211 )?;
212 let mut rows = stmt.query([message_id])?;
213 if let Some(row) = rows.next()? {
214 let sid: String = row.get::<usize, String>(0)?;
215 Ok(Some(sid))
216 } else {
217 Ok(None)
218 }
219}
220
/// The subset of a stored message that semantic search results need.
struct MinimalMessage {
    /// Text of the message.
    content: String,
    /// Role recorded for the message.
    role: String,
}
226
227fn get_message_by_id(
228 db: &crate::memory_db::MemoryDatabase,
229 message_id: i64,
230) -> anyhow::Result<MinimalMessage> {
231 let conn = db.conversations.get_conn_public()?;
232 let mut stmt = conn.prepare(
233 "SELECT content, role FROM messages WHERE id = ?1"
234 )?;
235 let mut rows = stmt.query([message_id])?;
236 if let Some(row) = rows.next()? {
237 Ok(MinimalMessage {
238 content: row.get::<usize, String>(0)?,
239 role: row.get::<usize, String>(1)?,
240 })
241 } else {
242 Err(anyhow::anyhow!("Message {} not found", message_id))
243 }
244}