// offline_intelligence/api/search_api.rs
use axum::{
10 extract::State,
11 http::StatusCode,
12 response::IntoResponse,
13 Json,
14};
15use std::sync::Arc;
16use serde::{Deserialize, Serialize};
17use tracing::{info, debug};
18
19use crate::shared_state::SharedState;
20
/// Deserialized JSON body of a search request.
#[derive(Debug, Deserialize)]
pub struct SearchRequest {
    /// Free-text query; rejected by the handler if empty after trimming.
    pub query: String,
    /// Optional session filter; when set, only that session's messages match.
    pub session_id: Option<String>,
    /// Maximum number of results; the handler defaults to 10 and clamps to 1..=100.
    pub limit: Option<i32>,
    /// Minimum similarity for semantic matches; the handler defaults to 0.3.
    pub similarity_threshold: Option<f32>,
}
30
/// Serialized JSON response for a search request.
#[derive(Debug, Serialize)]
pub struct SearchResponse {
    /// Ranked results, highest relevance first, at most `limit` entries.
    pub results: Vec<SearchResult>,
    /// Number of entries in `results`.
    pub total: usize,
    /// Which path produced the results: "keyword", "semantic", or "hybrid".
    pub search_type: String,
}
38
/// One matched message in a search response.
#[derive(Debug, Serialize, Clone)]
pub struct SearchResult {
    /// Session the message belongs to.
    pub session_id: String,
    /// Database id of the matched message.
    pub message_id: i64,
    /// Full message text.
    pub content: String,
    /// Message role (as stored in the messages table).
    pub role: String,
    /// Similarity score for semantic hits, keyword score otherwise.
    /// NOTE(review): the two scores are on different scales — see handler.
    pub relevance_score: f32,
    /// Origin of this hit: "semantic" or "keyword".
    pub search_source: String,
}
49
/// POST handler: hybrid search over stored conversation messages.
///
/// Strategy (best-effort, degrades gracefully):
/// 1. Semantic phase: embed the query via the LLM worker and look up similar
///    message embeddings in the database. Any failure or empty response here
///    is debug-logged and we fall through to keyword search only.
/// 2. Keyword phase (always attempted): split the query into lowercased words
///    longer than 2 chars and search via the context orchestrator, skipping
///    messages already found semantically.
/// 3. Merge, sort by relevance descending, truncate to `limit`.
///
/// Errors: 400 BAD_REQUEST when the query is empty/whitespace-only. Backend
/// failures are swallowed (debug-logged) rather than surfaced to the client.
pub async fn search(
    State(shared_state): State<Arc<SharedState>>,
    Json(payload): Json<SearchRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    info!("Search request: query='{}', session={:?}, limit={:?}",
        payload.query, payload.session_id, payload.limit);

    // Reject empty/whitespace-only queries up front.
    if payload.query.trim().is_empty() {
        return Err((StatusCode::BAD_REQUEST, "Query cannot be empty".to_string()));
    }

    // Clamp client-supplied limit into a sane range; default 10.
    let limit = payload.limit.unwrap_or(10).clamp(1, 100) as usize;
    // NOTE(review): default 0.3 — assumes similarity scores are roughly in
    // [0, 1]; confirm against find_similar_embeddings' scoring.
    let similarity_threshold = payload.similarity_threshold.unwrap_or(0.3);

    let mut all_results: Vec<SearchResult> = Vec::new();
    // Reported mode: "keyword" unless semantic hits land ("semantic"), and
    // "hybrid" when both sources contribute.
    let mut search_type = String::from("keyword"); let llm_worker = &shared_state.llm_worker;
    let db = &shared_state.database_pool;

    // --- Phase 1: semantic search (optional; falls back silently) ---
    match llm_worker.generate_embeddings(vec![payload.query.clone()]).await {
        Ok(query_embeddings) if !query_embeddings.is_empty() => {
            let query_vec = &query_embeddings[0];

            // Over-fetch (limit * 2) so post-hoc session filtering still
            // leaves enough candidates.
            match db.embeddings.find_similar_embeddings(
                query_vec,
                "llama-server",
                (limit * 2) as i32, similarity_threshold,
            ) {
                Ok(similar_ids) if !similar_ids.is_empty() => {
                    search_type = "semantic".to_string();
                    debug!("Semantic search found {} candidates", similar_ids.len());

                    for (message_id, similarity) in &similar_ids {
                        // Resolve the owning session; messages that fail to
                        // resolve are silently skipped (best-effort).
                        if let Ok(Some(session_id_filter)) = get_message_session_id(db, *message_id) {
                            // Honor the optional per-session filter.
                            if let Some(ref filter_sid) = payload.session_id {
                                if &session_id_filter != filter_sid {
                                    continue;
                                }
                            }
                            if let Ok(msg) = get_message_by_id(db, *message_id) {
                                all_results.push(SearchResult {
                                    session_id: session_id_filter,
                                    message_id: *message_id,
                                    content: msg.content,
                                    role: msg.role,
                                    relevance_score: *similarity,
                                    search_source: "semantic".to_string(),
                                });
                            }
                        }
                    }
                }
                Ok(_) => {
                    debug!("Semantic search returned no results above threshold {}", similarity_threshold);
                }
                Err(e) => {
                    debug!("Semantic search failed (falling back to keyword): {}", e);
                }
            }
        }
        Ok(_) => {
            debug!("Empty embedding response, falling back to keyword search");
        }
        Err(e) => {
            debug!("Embedding generation unavailable ({}), using keyword search only", e);
        }
    }

    // --- Phase 2: keyword search (always attempted) ---
    // Tokenize: whitespace-split, drop words of <= 2 chars, lowercase.
    let keywords: Vec<String> = payload.query
        .split_whitespace()
        .filter(|word| word.len() > 2)
        .map(|s| s.to_lowercase())
        .collect();

    if !keywords.is_empty() {
        let orchestrator_guard = shared_state.context_orchestrator.read().await;
        if let Some(orchestrator) = &*orchestrator_guard {
            if let Ok(stored_messages) = orchestrator.search_messages(
                payload.session_id.as_deref(),
                &keywords,
                limit,
            ).await {
                // De-duplicate against semantic hits by message id.
                let semantic_ids: std::collections::HashSet<i64> = all_results.iter()
                    .map(|r| r.message_id)
                    .collect();

                for msg in stored_messages {
                    if semantic_ids.contains(&msg.id) {
                        continue;
                    }

                    let keyword_score = calculate_relevance(&msg.content, &keywords);
                    all_results.push(SearchResult {
                        session_id: msg.session_id,
                        message_id: msg.id,
                        content: msg.content,
                        role: msg.role,
                        relevance_score: keyword_score,
                        search_source: "keyword".to_string(),
                    });
                }

                // Both sources contributed -> report "hybrid".
                if search_type == "semantic" && all_results.iter().any(|r| r.search_source == "keyword") {
                    search_type = "hybrid".to_string();
                }
            }
        }
    }

    // --- Phase 3: rank and truncate ---
    // NOTE(review): semantic similarity and keyword relevance are on different
    // scales but sorted together; NaN-safe via the Ordering::Equal fallback.
    all_results.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap_or(std::cmp::Ordering::Equal));
    all_results.truncate(limit);

    let total = all_results.len();
    info!("Search completed: {} results ({})", total, search_type);

    Ok(Json(SearchResponse {
        results: all_results,
        total,
        search_type,
    }))
}
184
/// Score how well `content` matches `keywords`, case-insensitively.
///
/// Each keyword contributes (non-overlapping occurrence count) multiplied by
/// (keyword length / content byte length); the sum is capped at 1.0.
/// Returns 0.0 when nothing matches or `keywords` is empty.
fn calculate_relevance(content: &str, keywords: &[String]) -> f32 {
    let haystack = content.to_lowercase();
    // Guard against division by zero on empty content.
    let denom = content.len().max(1) as f32;

    let raw: f32 = keywords
        .iter()
        .map(|kw| {
            let hits = haystack.matches(kw.as_str()).count();
            hits as f32 * (kw.len() as f32 / denom)
        })
        .sum();

    raw.min(1.0)
}
199
200fn get_message_session_id(
202 db: &crate::memory_db::MemoryDatabase,
203 message_id: i64,
204) -> anyhow::Result<Option<String>> {
205 let pool_conn = db.conversations.get_conn_public()?;
206 let mut stmt = pool_conn.prepare(
207 "SELECT session_id FROM messages WHERE id = ?1"
208 )?;
209 let mut rows = stmt.query([message_id])?;
210 if let Some(row) = rows.next()? {
211 let sid: String = row.get(0)?;
212 Ok(Some(sid))
213 } else {
214 Ok(None)
215 }
216}
217
/// Minimal projection of a message row: just the columns the search
/// handler needs to build a result.
struct MinimalMessage {
    // Message text (messages.content).
    content: String,
    // Message role (messages.role).
    role: String,
}
223
224fn get_message_by_id(
225 db: &crate::memory_db::MemoryDatabase,
226 message_id: i64,
227) -> anyhow::Result<MinimalMessage> {
228 let conn = db.conversations.get_conn_public()?;
229 let mut stmt = conn.prepare(
230 "SELECT content, role FROM messages WHERE id = ?1"
231 )?;
232 let mut rows = stmt.query([message_id])?;
233 if let Some(row) = rows.next()? {
234 Ok(MinimalMessage {
235 content: row.get(0)?,
236 role: row.get(1)?,
237 })
238 } else {
239 Err(anyhow::anyhow!("Message {} not found", message_id))
240 }
241}