// cp_tor/server.rs
//! Search server: handles incoming search requests from remote peers.
//!
//! Reads length-prefixed CBOR `SearchRequests` from any AsyncRead+AsyncWrite
//! stream, performs vector search via the local graph store, and returns
//! signed `SearchResponses`.

use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::Mutex;
use tracing::{debug, info, warn};

use cp_graph::GraphStore;

use crate::error::{Result, TorError};
use crate::keys::verify_signature;
use crate::rate_limit::RateLimiter;
use crate::types::{
    RemoteSearchResult, SearchRequest, SearchResponse, SearchStatus, MAX_RESULTS, RRF_K,
};
use crate::wire;
/// Configuration for the search server.
///
/// Bundles the node identity used to sign responses with the
/// compatibility parameters checked against every incoming request.
pub struct ServerConfig {
    /// The node's Ed25519 signing key seed (for signing responses).
    pub identity_secret: [u8; 32],
    /// The node's Ed25519 public key.
    pub identity_public: [u8; 32],
    /// BLAKE3 hash of the embedding model in use; requests carrying a
    /// different hash are answered with `SearchStatus::ModelMismatch`.
    pub model_hash: [u8; 32],
    /// Maximum concurrent queries.
    // NOTE(review): not enforced anywhere in this module — confirm the
    // accept loop (or another layer) applies this limit.
    pub max_concurrent: u8,
    /// Expected embedding dimension (e.g. 1536); requests whose query
    /// vector has a different length are rejected as invalid.
    pub embedding_dim: usize,
}
36
37/// Handle a single incoming search connection.
38///
39/// Reads a `SearchRequest`, verifies it, performs the search, builds a
40/// `SearchResponse`, signs it, and sends it back. This function works
41/// with any `AsyncRead` + `AsyncWrite` stream (TCP, Tor `DataStream`, etc.).
42///
43/// Handles exactly one request-response cycle. For persistent connections
44/// that handle multiple messages, use `handle_connection_loop`.
45pub async fn handle_connection<S>(
46    stream: &mut S,
47    graph: &Arc<Mutex<GraphStore>>,
48    rate_limiter: &Arc<Mutex<RateLimiter>>,
49    config: &ServerConfig,
50) -> Result<()>
51where
52    S: AsyncRead + AsyncWrite + Unpin,
53{
54    let (mut reader, mut writer) = tokio::io::split(stream);
55
56    // Read the request
57    let request: SearchRequest = match wire::read_message(&mut reader).await {
58        Ok(req) => req,
59        Err(TorError::Keepalive) => {
60            debug!("Received keepalive probe");
61            return Ok(());
62        }
63        Err(e) => return Err(e),
64    };
65
66    debug!(
67        "Received search request {} ({} dims, max_results={})",
68        hex::encode(&request.request_id[..4]),
69        request.query_embedding.len(),
70        request.max_results
71    );
72
73    // Verify request signature
74    let signing_bytes = request.signing_bytes();
75    if let Err(e) = verify_signature(&request.public_key, &signing_bytes, &request.signature) {
76        warn!("Invalid request signature: {}", e);
77        let response = build_error_response(&request, SearchStatus::InvalidRequest, config);
78        wire::write_message(&mut writer, &response).await?;
79        return Ok(());
80    }
81
82    // Validate request timestamp (replay protection: reject if outside 30s window)
83    let now_ms = std::time::SystemTime::now()
84        .duration_since(std::time::UNIX_EPOCH)
85        .unwrap()
86        .as_millis() as i64;
87    if (now_ms - request.timestamp).abs() > 30_000 {
88        warn!("Request timestamp outside 30s window");
89        let response = build_error_response(&request, SearchStatus::InvalidRequest, config);
90        wire::write_message(&mut writer, &response).await?;
91        return Ok(());
92    }
93
94    // Check rate limit
95    let rate_allowed = {
96        let mut rl = rate_limiter.lock().await;
97        // Periodic cleanup of stale rate limiter entries to prevent unbounded growth
98        // from ephemeral session keys. Runs every time we process a request.
99        rl.cleanup(now_ms, 120_000);
100        rl.check(&request.public_key, now_ms)
101    };
102
103    if !rate_allowed {
104        debug!(
105            "Rate limited requester {}",
106            hex::encode(&request.public_key[..4])
107        );
108        let response = build_error_response(&request, SearchStatus::Overloaded, config);
109        wire::write_message(&mut writer, &response).await?;
110        return Ok(());
111    }
112
113    // Check model compatibility
114    if request.model_hash != config.model_hash {
115        debug!("Model mismatch from requester");
116        let response = build_error_response(&request, SearchStatus::ModelMismatch, config);
117        wire::write_message(&mut writer, &response).await?;
118        return Ok(());
119    }
120
121    // Validate embedding dimensions
122    if request.query_embedding.len() != config.embedding_dim {
123        warn!(
124            "Embedding dimension mismatch: got {}, expected {}",
125            request.query_embedding.len(),
126            config.embedding_dim
127        );
128        let response = build_error_response(&request, SearchStatus::InvalidRequest, config);
129        wire::write_message(&mut writer, &response).await?;
130        return Ok(());
131    }
132
133    // Perform search
134    let search_start = Instant::now();
135    let max_results = request.max_results.min(MAX_RESULTS) as usize;
136
137    let (results, state_root) = {
138        let store = graph.lock().await;
139
140        // Convert i16 query embedding to f32 for the HNSW index.
141        // Embeddings are quantized with scale 32767 (per CP-010), so we
142        // reverse that to get the unit-normalized f32 vector back.
143        let query_f32: Vec<f32> = request
144            .query_embedding
145            .iter()
146            .map(|&v| f32::from(v) / 32767.0)
147            .collect();
148
149        // Vector search returns Vec<(Uuid, f32)> of (embedding_id, score).
150        // The HNSW index stores embedding IDs, so we need to resolve each
151        // embedding ID → chunk ID → chunk text and document path.
152        let search_hits = store
153            .search(&query_f32, max_results)
154            .map_err(|e| TorError::InvalidRequest(format!("Search failed: {e}")))?;
155
156        // Build a canonical Merkle tree from all chunk IDs for proof generation.
157        // Uses canonical ordering (smaller hash first) so proofs can be verified
158        // without directional flags.
159        let all_chunk_ids = store
160            .get_all_chunk_ids()
161            .map_err(|e| TorError::InvalidRequest(format!("Failed to get chunk IDs: {e}")))?;
162
163        let sorted_leaves: Vec<[u8; 32]> = {
164            let mut ids = all_chunk_ids.clone();
165            ids.sort();
166            ids.iter()
167                .map(|id| *blake3::hash(id.as_bytes()).as_bytes())
168                .collect()
169        };
170        let chunk_tree_root = canonical_merkle_root(&sorted_leaves);
171
172        // Build index for fast lookup: chunk_id → leaf position
173        let mut chunk_id_to_leaf_index: std::collections::HashMap<uuid::Uuid, usize> =
174            std::collections::HashMap::new();
175        {
176            let mut sorted_ids = all_chunk_ids;
177            sorted_ids.sort();
178            for (i, id) in sorted_ids.iter().enumerate() {
179                chunk_id_to_leaf_index.insert(*id, i);
180            }
181        }
182
183        // Resolve embedding IDs to chunks and document paths
184        let mut remote_results = Vec::with_capacity(search_hits.len());
185        for (rank, (embedding_id, _similarity)) in search_hits.into_iter().enumerate() {
186            // Resolve embedding → chunk
187            let Ok(Some(chunk_id)) = store.get_chunk_id_for_embedding(embedding_id) else {
188                continue;
189            };
190
191            let Ok(Some(chunk)) = store.get_chunk(chunk_id) else {
192                continue;
193            };
194
195            let doc_path = match store.get_document(chunk.doc_id) {
196                Ok(Some(doc)) => doc.path.to_string_lossy().to_string(),
197                _ => String::new(),
198            };
199
200            // Generate Merkle proof for this chunk
201            let proof = chunk_id_to_leaf_index
202                .get(&chunk_id)
203                .and_then(|&idx| canonical_merkle_proof(&sorted_leaves, idx));
204
205            remote_results.push(RemoteSearchResult {
206                chunk_id: *chunk_id.as_bytes(),
207                chunk_text: chunk.text,
208                document_path: doc_path,
209                score: rrf_score(rank),
210                merkle_proof: proof,
211            });
212        }
213
214        (remote_results, chunk_tree_root)
215    };
216
217    let search_latency = search_start.elapsed().as_millis() as u16;
218
219    // Build and sign response
220    let response = build_ok_response(&request, results, state_root, search_latency, config);
221
222    wire::write_message(&mut writer, &response).await?;
223
224    info!(
225        "Responded to search {} with {} results in {}ms",
226        hex::encode(&request.request_id[..4]),
227        response.results.len(),
228        search_latency
229    );
230
231    Ok(())
232}
233
234/// Handle a persistent connection that may carry multiple messages.
235///
236/// Loops reading messages from the stream until the remote side disconnects
237/// or an unrecoverable error occurs. Keepalive probes are silently handled.
238/// This is used by the onion service accept loop for long-lived connections.
239pub async fn handle_connection_loop<S>(
240    stream: &mut S,
241    graph: &Arc<Mutex<GraphStore>>,
242    rate_limiter: &Arc<Mutex<RateLimiter>>,
243    config: &ServerConfig,
244) -> Result<()>
245where
246    S: AsyncRead + AsyncWrite + Unpin,
247{
248    loop {
249        match handle_connection(stream, graph, rate_limiter, config).await {
250            Ok(()) => {}
251            Err(TorError::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
252                debug!("Peer disconnected");
253                return Ok(());
254            }
255            Err(TorError::Io(ref e)) if e.kind() == std::io::ErrorKind::ConnectionReset => {
256                debug!("Peer connection reset");
257                return Ok(());
258            }
259            Err(e) => return Err(e),
260        }
261    }
262}
263
264/// Compute RRF score for a given rank (0-indexed).
265fn rrf_score(rank: usize) -> u32 {
266    // RRF formula: 1_000_000 / (K + rank + 1)
267    // Scale up by 1M for integer precision
268    (1_000_000.0 / (f64::from(RRF_K) + rank as f64 + 1.0)) as u32
269}
270
271/// Compute a canonical Merkle root from leaf hashes.
272///
273/// Uses canonical ordering: at each level, the smaller hash goes first.
274/// This allows proof verification without directional flags.
275fn canonical_merkle_root(leaves: &[[u8; 32]]) -> [u8; 32] {
276    if leaves.is_empty() {
277        return [0u8; 32];
278    }
279    if leaves.len() == 1 {
280        return leaves[0];
281    }
282    canonical_merkle_root_recursive(leaves)
283}
284
285fn canonical_merkle_root_recursive(hashes: &[[u8; 32]]) -> [u8; 32] {
286    if hashes.len() == 1 {
287        return hashes[0];
288    }
289
290    let mut next_level = Vec::with_capacity(hashes.len().div_ceil(2));
291    for chunk in hashes.chunks(2) {
292        let mut hasher = blake3::Hasher::new();
293        if chunk.len() > 1 {
294            if chunk[0] <= chunk[1] {
295                hasher.update(&chunk[0]);
296                hasher.update(&chunk[1]);
297            } else {
298                hasher.update(&chunk[1]);
299                hasher.update(&chunk[0]);
300            }
301        } else {
302            hasher.update(&chunk[0]);
303            hasher.update(&chunk[0]);
304        }
305        next_level.push(*hasher.finalize().as_bytes());
306    }
307
308    canonical_merkle_root_recursive(&next_level)
309}
310
311/// Generate a canonical Merkle proof for the leaf at `leaf_index`.
312///
313/// Returns sibling hashes from leaf to root. The verifier uses canonical
314/// ordering (min first) so direction flags are not needed.
315fn canonical_merkle_proof(leaves: &[[u8; 32]], leaf_index: usize) -> Option<Vec<[u8; 32]>> {
316    if leaf_index >= leaves.len() || leaves.is_empty() {
317        return None;
318    }
319    if leaves.len() == 1 {
320        return Some(Vec::new());
321    }
322
323    let mut proof = Vec::new();
324    let mut level = leaves.to_vec();
325    let mut index = leaf_index;
326
327    while level.len() > 1 {
328        let sibling_index = if index.is_multiple_of(2) {
329            if index + 1 < level.len() {
330                index + 1
331            } else {
332                index
333            }
334        } else {
335            index - 1
336        };
337
338        proof.push(level[sibling_index]);
339
340        // Build next level with canonical ordering
341        let mut next_level = Vec::with_capacity(level.len().div_ceil(2));
342        for chunk in level.chunks(2) {
343            let mut hasher = blake3::Hasher::new();
344            if chunk.len() > 1 {
345                if chunk[0] <= chunk[1] {
346                    hasher.update(&chunk[0]);
347                    hasher.update(&chunk[1]);
348                } else {
349                    hasher.update(&chunk[1]);
350                    hasher.update(&chunk[0]);
351                }
352            } else {
353                hasher.update(&chunk[0]);
354                hasher.update(&chunk[0]);
355            }
356            next_level.push(*hasher.finalize().as_bytes());
357        }
358
359        index /= 2;
360        level = next_level;
361    }
362
363    Some(proof)
364}
365
366fn build_error_response(
367    request: &SearchRequest,
368    status: SearchStatus,
369    config: &ServerConfig,
370) -> SearchResponse {
371    let now_ms = std::time::SystemTime::now()
372        .duration_since(std::time::UNIX_EPOCH)
373        .unwrap()
374        .as_millis() as i64;
375
376    let mut response = SearchResponse {
377        request_id: request.request_id,
378        status,
379        results: Vec::new(),
380        peer_state_root: [0u8; 32],
381        search_latency_ms: 0,
382        timestamp: now_ms,
383        signature: [0u8; 64],
384    };
385
386    let signing_bytes = response.signing_bytes();
387    let signing_key = ed25519_dalek::SigningKey::from_bytes(&config.identity_secret);
388    response.signature = ed25519_dalek::Signer::sign(&signing_key, &signing_bytes).to_bytes();
389
390    response
391}
392
393fn build_ok_response(
394    request: &SearchRequest,
395    results: Vec<RemoteSearchResult>,
396    state_root: [u8; 32],
397    search_latency_ms: u16,
398    config: &ServerConfig,
399) -> SearchResponse {
400    let now_ms = std::time::SystemTime::now()
401        .duration_since(std::time::UNIX_EPOCH)
402        .unwrap()
403        .as_millis() as i64;
404
405    let mut response = SearchResponse {
406        request_id: request.request_id,
407        status: SearchStatus::Ok,
408        results,
409        peer_state_root: state_root,
410        search_latency_ms,
411        timestamp: now_ms,
412        signature: [0u8; 64],
413    };
414
415    let signing_bytes = response.signing_bytes();
416    let signing_key = ed25519_dalek::SigningKey::from_bytes(&config.identity_secret);
417    response.signature = ed25519_dalek::Signer::sign(&signing_key, &signing_bytes).to_bytes();
418
419    response
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rrf_score_decreasing() {
        // RRF must be strictly monotonically decreasing in rank.
        let scores: Vec<u32> = (0..10).map(rrf_score).collect();
        for (higher, lower) in scores.iter().zip(scores.iter().skip(1)) {
            assert!(higher > lower, "RRF scores should decrease with rank");
        }
    }

    #[test]
    fn test_rrf_score_rank_0() {
        // 1_000_000 / (60 + 0 + 1) = 1_000_000 / 61 ≈ 16393 (truncated)
        assert_eq!(rrf_score(0), 16393);
    }

    #[test]
    fn test_rrf_score_rank_19() {
        // 1_000_000 / (60 + 19 + 1) = 1_000_000 / 80 = 12500 (exact)
        assert_eq!(rrf_score(19), 12500);
    }
}