1use std::sync::Arc;
8use std::time::Instant;
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::sync::Mutex;
11use tracing::{debug, info, warn};
12
13use cp_graph::GraphStore;
14
15use crate::error::{Result, TorError};
16use crate::keys::verify_signature;
17use crate::rate_limit::RateLimiter;
18use crate::types::{
19 RemoteSearchResult, SearchRequest, SearchResponse, SearchStatus, MAX_RESULTS, RRF_K,
20};
21use crate::wire;
22
/// Settings the request handlers in this file consult while serving a peer.
pub struct ServerConfig {
    /// Secret key bytes. In this file they are turned into the Ed25519 signing
    /// key used to sign every outgoing `SearchResponse`.
    pub identity_secret: [u8; 32],
    /// Not read anywhere in this file; presumably the public key matching
    /// `identity_secret` -- TODO confirm against the caller.
    pub identity_public: [u8; 32],
    /// Model hash a request must carry; a differing hash is answered with
    /// `SearchStatus::ModelMismatch`.
    pub model_hash: [u8; 32],
    /// Not read anywhere in this file; presumably a cap on simultaneous
    /// connections enforced by the accept loop -- TODO confirm.
    pub max_concurrent: u8,
    /// Exact number of components a request's query embedding must have;
    /// any other length is answered with `SearchStatus::InvalidRequest`.
    pub embedding_dim: usize,
}
36
37pub async fn handle_connection<S>(
46 stream: &mut S,
47 graph: &Arc<Mutex<GraphStore>>,
48 rate_limiter: &Arc<Mutex<RateLimiter>>,
49 config: &ServerConfig,
50) -> Result<()>
51where
52 S: AsyncRead + AsyncWrite + Unpin,
53{
54 let (mut reader, mut writer) = tokio::io::split(stream);
55
56 let request: SearchRequest = match wire::read_message(&mut reader).await {
58 Ok(req) => req,
59 Err(TorError::Keepalive) => {
60 debug!("Received keepalive probe");
61 return Ok(());
62 }
63 Err(e) => return Err(e),
64 };
65
66 debug!(
67 "Received search request {} ({} dims, max_results={})",
68 hex::encode(&request.request_id[..4]),
69 request.query_embedding.len(),
70 request.max_results
71 );
72
73 let signing_bytes = request.signing_bytes();
75 if let Err(e) = verify_signature(&request.public_key, &signing_bytes, &request.signature) {
76 warn!("Invalid request signature: {}", e);
77 let response = build_error_response(&request, SearchStatus::InvalidRequest, config);
78 wire::write_message(&mut writer, &response).await?;
79 return Ok(());
80 }
81
82 let now_ms = std::time::SystemTime::now()
84 .duration_since(std::time::UNIX_EPOCH)
85 .unwrap()
86 .as_millis() as i64;
87 if (now_ms - request.timestamp).abs() > 30_000 {
88 warn!("Request timestamp outside 30s window");
89 let response = build_error_response(&request, SearchStatus::InvalidRequest, config);
90 wire::write_message(&mut writer, &response).await?;
91 return Ok(());
92 }
93
94 let rate_allowed = {
96 let mut rl = rate_limiter.lock().await;
97 rl.cleanup(now_ms, 120_000);
100 rl.check(&request.public_key, now_ms)
101 };
102
103 if !rate_allowed {
104 debug!(
105 "Rate limited requester {}",
106 hex::encode(&request.public_key[..4])
107 );
108 let response = build_error_response(&request, SearchStatus::Overloaded, config);
109 wire::write_message(&mut writer, &response).await?;
110 return Ok(());
111 }
112
113 if request.model_hash != config.model_hash {
115 debug!("Model mismatch from requester");
116 let response = build_error_response(&request, SearchStatus::ModelMismatch, config);
117 wire::write_message(&mut writer, &response).await?;
118 return Ok(());
119 }
120
121 if request.query_embedding.len() != config.embedding_dim {
123 warn!(
124 "Embedding dimension mismatch: got {}, expected {}",
125 request.query_embedding.len(),
126 config.embedding_dim
127 );
128 let response = build_error_response(&request, SearchStatus::InvalidRequest, config);
129 wire::write_message(&mut writer, &response).await?;
130 return Ok(());
131 }
132
133 let search_start = Instant::now();
135 let max_results = request.max_results.min(MAX_RESULTS) as usize;
136
137 let (results, state_root) = {
138 let store = graph.lock().await;
139
140 let query_f32: Vec<f32> = request
144 .query_embedding
145 .iter()
146 .map(|&v| f32::from(v) / 32767.0)
147 .collect();
148
149 let search_hits = store
153 .search(&query_f32, max_results)
154 .map_err(|e| TorError::InvalidRequest(format!("Search failed: {e}")))?;
155
156 let all_chunk_ids = store
160 .get_all_chunk_ids()
161 .map_err(|e| TorError::InvalidRequest(format!("Failed to get chunk IDs: {e}")))?;
162
163 let sorted_leaves: Vec<[u8; 32]> = {
164 let mut ids = all_chunk_ids.clone();
165 ids.sort();
166 ids.iter()
167 .map(|id| *blake3::hash(id.as_bytes()).as_bytes())
168 .collect()
169 };
170 let chunk_tree_root = canonical_merkle_root(&sorted_leaves);
171
172 let mut chunk_id_to_leaf_index: std::collections::HashMap<uuid::Uuid, usize> =
174 std::collections::HashMap::new();
175 {
176 let mut sorted_ids = all_chunk_ids;
177 sorted_ids.sort();
178 for (i, id) in sorted_ids.iter().enumerate() {
179 chunk_id_to_leaf_index.insert(*id, i);
180 }
181 }
182
183 let mut remote_results = Vec::with_capacity(search_hits.len());
185 for (rank, (embedding_id, _similarity)) in search_hits.into_iter().enumerate() {
186 let Ok(Some(chunk_id)) = store.get_chunk_id_for_embedding(embedding_id) else {
188 continue;
189 };
190
191 let Ok(Some(chunk)) = store.get_chunk(chunk_id) else {
192 continue;
193 };
194
195 let doc_path = match store.get_document(chunk.doc_id) {
196 Ok(Some(doc)) => doc.path.to_string_lossy().to_string(),
197 _ => String::new(),
198 };
199
200 let proof = chunk_id_to_leaf_index
202 .get(&chunk_id)
203 .and_then(|&idx| canonical_merkle_proof(&sorted_leaves, idx));
204
205 remote_results.push(RemoteSearchResult {
206 chunk_id: *chunk_id.as_bytes(),
207 chunk_text: chunk.text,
208 document_path: doc_path,
209 score: rrf_score(rank),
210 merkle_proof: proof,
211 });
212 }
213
214 (remote_results, chunk_tree_root)
215 };
216
217 let search_latency = search_start.elapsed().as_millis() as u16;
218
219 let response = build_ok_response(&request, results, state_root, search_latency, config);
221
222 wire::write_message(&mut writer, &response).await?;
223
224 info!(
225 "Responded to search {} with {} results in {}ms",
226 hex::encode(&request.request_id[..4]),
227 response.results.len(),
228 search_latency
229 );
230
231 Ok(())
232}
233
234pub async fn handle_connection_loop<S>(
240 stream: &mut S,
241 graph: &Arc<Mutex<GraphStore>>,
242 rate_limiter: &Arc<Mutex<RateLimiter>>,
243 config: &ServerConfig,
244) -> Result<()>
245where
246 S: AsyncRead + AsyncWrite + Unpin,
247{
248 loop {
249 match handle_connection(stream, graph, rate_limiter, config).await {
250 Ok(()) => {}
251 Err(TorError::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
252 debug!("Peer disconnected");
253 return Ok(());
254 }
255 Err(TorError::Io(ref e)) if e.kind() == std::io::ErrorKind::ConnectionReset => {
256 debug!("Peer connection reset");
257 return Ok(());
258 }
259 Err(e) => return Err(e),
260 }
261 }
262}
263
264fn rrf_score(rank: usize) -> u32 {
266 (1_000_000.0 / (f64::from(RRF_K) + rank as f64 + 1.0)) as u32
269}
270
271fn canonical_merkle_root(leaves: &[[u8; 32]]) -> [u8; 32] {
276 if leaves.is_empty() {
277 return [0u8; 32];
278 }
279 if leaves.len() == 1 {
280 return leaves[0];
281 }
282 canonical_merkle_root_recursive(leaves)
283}
284
285fn canonical_merkle_root_recursive(hashes: &[[u8; 32]]) -> [u8; 32] {
286 if hashes.len() == 1 {
287 return hashes[0];
288 }
289
290 let mut next_level = Vec::with_capacity(hashes.len().div_ceil(2));
291 for chunk in hashes.chunks(2) {
292 let mut hasher = blake3::Hasher::new();
293 if chunk.len() > 1 {
294 if chunk[0] <= chunk[1] {
295 hasher.update(&chunk[0]);
296 hasher.update(&chunk[1]);
297 } else {
298 hasher.update(&chunk[1]);
299 hasher.update(&chunk[0]);
300 }
301 } else {
302 hasher.update(&chunk[0]);
303 hasher.update(&chunk[0]);
304 }
305 next_level.push(*hasher.finalize().as_bytes());
306 }
307
308 canonical_merkle_root_recursive(&next_level)
309}
310
311fn canonical_merkle_proof(leaves: &[[u8; 32]], leaf_index: usize) -> Option<Vec<[u8; 32]>> {
316 if leaf_index >= leaves.len() || leaves.is_empty() {
317 return None;
318 }
319 if leaves.len() == 1 {
320 return Some(Vec::new());
321 }
322
323 let mut proof = Vec::new();
324 let mut level = leaves.to_vec();
325 let mut index = leaf_index;
326
327 while level.len() > 1 {
328 let sibling_index = if index.is_multiple_of(2) {
329 if index + 1 < level.len() {
330 index + 1
331 } else {
332 index
333 }
334 } else {
335 index - 1
336 };
337
338 proof.push(level[sibling_index]);
339
340 let mut next_level = Vec::with_capacity(level.len().div_ceil(2));
342 for chunk in level.chunks(2) {
343 let mut hasher = blake3::Hasher::new();
344 if chunk.len() > 1 {
345 if chunk[0] <= chunk[1] {
346 hasher.update(&chunk[0]);
347 hasher.update(&chunk[1]);
348 } else {
349 hasher.update(&chunk[1]);
350 hasher.update(&chunk[0]);
351 }
352 } else {
353 hasher.update(&chunk[0]);
354 hasher.update(&chunk[0]);
355 }
356 next_level.push(*hasher.finalize().as_bytes());
357 }
358
359 index /= 2;
360 level = next_level;
361 }
362
363 Some(proof)
364}
365
366fn build_error_response(
367 request: &SearchRequest,
368 status: SearchStatus,
369 config: &ServerConfig,
370) -> SearchResponse {
371 let now_ms = std::time::SystemTime::now()
372 .duration_since(std::time::UNIX_EPOCH)
373 .unwrap()
374 .as_millis() as i64;
375
376 let mut response = SearchResponse {
377 request_id: request.request_id,
378 status,
379 results: Vec::new(),
380 peer_state_root: [0u8; 32],
381 search_latency_ms: 0,
382 timestamp: now_ms,
383 signature: [0u8; 64],
384 };
385
386 let signing_bytes = response.signing_bytes();
387 let signing_key = ed25519_dalek::SigningKey::from_bytes(&config.identity_secret);
388 response.signature = ed25519_dalek::Signer::sign(&signing_key, &signing_bytes).to_bytes();
389
390 response
391}
392
393fn build_ok_response(
394 request: &SearchRequest,
395 results: Vec<RemoteSearchResult>,
396 state_root: [u8; 32],
397 search_latency_ms: u16,
398 config: &ServerConfig,
399) -> SearchResponse {
400 let now_ms = std::time::SystemTime::now()
401 .duration_since(std::time::UNIX_EPOCH)
402 .unwrap()
403 .as_millis() as i64;
404
405 let mut response = SearchResponse {
406 request_id: request.request_id,
407 status: SearchStatus::Ok,
408 results,
409 peer_state_root: state_root,
410 search_latency_ms,
411 timestamp: now_ms,
412 signature: [0u8; 64],
413 };
414
415 let signing_bytes = response.signing_bytes();
416 let signing_key = ed25519_dalek::SigningKey::from_bytes(&config.identity_secret);
417 response.signature = ed25519_dalek::Signer::sign(&signing_key, &signing_bytes).to_bytes();
418
419 response
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Each of the first ten ranks must score strictly lower than the one before it.
    #[test]
    fn test_rrf_score_decreasing() {
        let mut previous = rrf_score(0);
        for rank in 1..10 {
            let current = rrf_score(rank);
            assert!(previous > current, "RRF scores should decrease with rank");
            previous = current;
        }
    }

    /// Pins the best possible score (consistent with `RRF_K` being 60:
    /// 1_000_000 / 61 truncates to 16393).
    #[test]
    fn test_rrf_score_rank_0() {
        assert_eq!(rrf_score(0), 16393);
    }

    /// Pins the score of the twentieth result (consistent with `RRF_K` being 60:
    /// 1_000_000 / 80 is exactly 12500).
    #[test]
    fn test_rrf_score_rank_19() {
        assert_eq!(rrf_score(19), 12500);
    }
}