1use std::io::Write;
6
7use alopex_embedded::{Database, TxnMode};
8use serde::{Deserialize, Serialize};
9
10use crate::batch::BatchMode;
11use crate::cli::VectorCommand;
12use crate::client::http::{ClientError, HttpClient};
13use crate::error::{CliError, Result};
14use crate::models::{Column, DataType, Row, Value};
15use crate::output::formatter::Formatter;
16use crate::progress::ProgressIndicator;
17use crate::streaming::{StreamingWriter, WriteStatus};
18
/// Request body serialized by `HttpClient::post_json` for the `hnsw/search` endpoint.
#[derive(Debug, Serialize)]
struct RemoteVectorSearchRequest {
    /// Name of the HNSW index to query.
    index: String,
    /// Query vector, parsed from the user's JSON input.
    query: Vec<f32>,
    /// Number of nearest neighbours requested.
    k: usize,
}
25
/// Request body serialized by `HttpClient::post_json` for the `hnsw/upsert` endpoint.
#[derive(Debug, Serialize)]
struct RemoteVectorUpsertRequest {
    /// Name of the target HNSW index.
    index: String,
    /// Vector key as raw bytes (the CLI sends the UTF-8 bytes of the key string).
    // NOTE(review): serde's default encoding of `Vec<u8>` is a sequence of numbers,
    // not a string — confirm this is the shape the server expects.
    key: Vec<u8>,
    /// Vector components, parsed from the user's JSON input.
    vector: Vec<f32>,
}
32
/// Request body serialized by `HttpClient::post_json` for the `hnsw/delete` endpoint.
#[derive(Debug, Serialize)]
struct RemoteVectorDeleteRequest {
    /// Name of the HNSW index to delete from.
    index: String,
    /// Key of the vector to delete, as raw bytes (UTF-8 bytes of the CLI argument).
    key: Vec<u8>,
}
38
/// One hit inside a `RemoteVectorSearchResponse`.
#[derive(Debug, Deserialize)]
struct RemoteVectorSearchResult {
    /// Key of the matched vector, as raw bytes; shown as text when it is valid UTF-8.
    key: Vec<u8>,
    /// Distance between this vector and the query, as reported by the server.
    distance: f32,
    /// Per-vector metadata sent by the server. It is deserialized but the CLI does
    /// not display it, hence the `dead_code` allowance.
    #[allow(dead_code)]
    metadata: Vec<u8>,
}
46
/// Response body returned by the `hnsw/search` endpoint.
#[derive(Debug, Deserialize)]
struct RemoteVectorSearchResponse {
    /// Matching vectors; the CLI prints them in the order the server returns them.
    results: Vec<RemoteVectorSearchResult>,
}
51
/// Response body returned by the `hnsw/upsert` and `hnsw/delete` endpoints.
#[derive(Debug, Deserialize)]
struct RemoteVectorStatusResponse {
    /// Whether the server applied the change; callers turn `false` into a CLI error.
    success: bool,
}
56
57pub fn execute<W: Write>(
65 db: &Database,
66 cmd: VectorCommand,
67 batch_mode: &BatchMode,
68 writer: &mut StreamingWriter<W>,
69) -> Result<()> {
70 match cmd {
71 VectorCommand::Search {
72 index,
73 query,
74 k,
75 progress,
76 } => execute_search(db, &index, &query, k, progress, batch_mode, writer),
77 VectorCommand::Upsert { index, key, vector } => {
78 execute_upsert(db, &index, &key, &vector, writer)
79 }
80 VectorCommand::Delete { index, key } => execute_delete(db, &index, &key, writer),
81 }
82}
83
84pub async fn execute_remote_with_formatter<W: Write>(
86 client: &HttpClient,
87 cmd: &VectorCommand,
88 batch_mode: &BatchMode,
89 writer: &mut W,
90 formatter: Box<dyn Formatter>,
91 limit: Option<usize>,
92 quiet: bool,
93) -> Result<()> {
94 match cmd {
95 VectorCommand::Search {
96 index,
97 query,
98 k,
99 progress,
100 } => {
101 execute_remote_search(
102 client, index, query, *k, *progress, batch_mode, writer, formatter, limit, quiet,
103 )
104 .await
105 }
106 VectorCommand::Upsert { index, key, vector } => {
107 execute_remote_upsert(client, index, key, vector, writer, formatter, limit, quiet).await
108 }
109 VectorCommand::Delete { index, key } => {
110 execute_remote_delete(client, index, key, writer, formatter, limit, quiet).await
111 }
112 }
113}
114
115#[allow(clippy::too_many_arguments)]
116async fn execute_remote_search<W: Write>(
117 client: &HttpClient,
118 index: &str,
119 query_json: &str,
120 k: usize,
121 progress: bool,
122 batch_mode: &BatchMode,
123 writer: &mut W,
124 formatter: Box<dyn Formatter>,
125 limit: Option<usize>,
126 quiet: bool,
127) -> Result<()> {
128 let query_vector: Vec<f32> = serde_json::from_str(query_json)
129 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
130
131 let mut progress_indicator = ProgressIndicator::new(
132 batch_mode,
133 progress,
134 quiet,
135 format!("Searching index '{}' for {} nearest neighbors...", index, k),
136 );
137
138 let request = RemoteVectorSearchRequest {
139 index: index.to_string(),
140 query: query_vector,
141 k,
142 };
143 let response: RemoteVectorSearchResponse = client
144 .post_json("hnsw/search", &request)
145 .await
146 .map_err(map_client_error)?;
147
148 progress_indicator.finish_with_message(format!("found {} results.", response.results.len()));
149
150 let columns = vector_search_columns();
151 let mut streaming_writer =
152 StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
153 streaming_writer.prepare(Some(response.results.len()))?;
154 for result in response.results {
155 let key_display = match std::str::from_utf8(&result.key) {
156 Ok(s) => Value::Text(s.to_string()),
157 Err(_) => Value::Bytes(result.key),
158 };
159 let row = Row::new(vec![key_display, Value::Float(result.distance as f64)]);
160 match streaming_writer.write_row(row)? {
161 WriteStatus::LimitReached => break,
162 WriteStatus::Continue => {}
163 }
164 }
165 streaming_writer.finish()
166}
167
168#[allow(clippy::too_many_arguments)]
169async fn execute_remote_upsert<W: Write>(
170 client: &HttpClient,
171 index: &str,
172 key: &str,
173 vector_json: &str,
174 writer: &mut W,
175 formatter: Box<dyn Formatter>,
176 limit: Option<usize>,
177 quiet: bool,
178) -> Result<()> {
179 let vector: Vec<f32> = serde_json::from_str(vector_json)
180 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
181 let request = RemoteVectorUpsertRequest {
182 index: index.to_string(),
183 key: key.as_bytes().to_vec(),
184 vector,
185 };
186 let response: RemoteVectorStatusResponse = client
187 .post_json("hnsw/upsert", &request)
188 .await
189 .map_err(map_client_error)?;
190 if response.success {
191 if quiet {
192 return Ok(());
193 }
194 let columns = vector_status_columns();
195 let mut streaming_writer =
196 StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
197 streaming_writer.prepare(Some(1))?;
198 let row = Row::new(vec![
199 Value::Text("OK".to_string()),
200 Value::Text(format!("Vector '{}' upserted", key)),
201 ]);
202 streaming_writer.write_row(row)?;
203 streaming_writer.finish()
204 } else {
205 Err(CliError::InvalidArgument(
206 "Failed to upsert vector".to_string(),
207 ))
208 }
209}
210
211#[allow(clippy::too_many_arguments)]
212async fn execute_remote_delete<W: Write>(
213 client: &HttpClient,
214 index: &str,
215 key: &str,
216 writer: &mut W,
217 formatter: Box<dyn Formatter>,
218 limit: Option<usize>,
219 quiet: bool,
220) -> Result<()> {
221 let request = RemoteVectorDeleteRequest {
222 index: index.to_string(),
223 key: key.as_bytes().to_vec(),
224 };
225 let response: RemoteVectorStatusResponse = client
226 .post_json("hnsw/delete", &request)
227 .await
228 .map_err(map_client_error)?;
229 if response.success {
230 if quiet {
231 return Ok(());
232 }
233 let columns = vector_status_columns();
234 let mut streaming_writer =
235 StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
236 streaming_writer.prepare(Some(1))?;
237 let row = Row::new(vec![
238 Value::Text("OK".to_string()),
239 Value::Text(format!("Vector '{}' deleted", key)),
240 ]);
241 streaming_writer.write_row(row)?;
242 streaming_writer.finish()
243 } else {
244 Err(CliError::InvalidArgument(
245 "Failed to delete vector".to_string(),
246 ))
247 }
248}
249
250fn map_client_error(err: ClientError) -> CliError {
251 match err {
252 ClientError::Request { source, .. } => {
253 CliError::ServerConnection(format!("request failed: {source}"))
254 }
255 ClientError::InvalidUrl(message) => CliError::InvalidArgument(message),
256 ClientError::Build(message) => CliError::InvalidArgument(message),
257 ClientError::Auth(err) => CliError::InvalidArgument(err.to_string()),
258 ClientError::HttpStatus { status, body } => {
259 CliError::InvalidArgument(format!("Server error: HTTP {} - {}", status.as_u16(), body))
260 }
261 }
262}
263
264fn execute_search<W: Write>(
266 db: &Database,
267 index: &str,
268 query_json: &str,
269 k: usize,
270 progress: bool,
271 batch_mode: &BatchMode,
272 writer: &mut StreamingWriter<W>,
273) -> Result<()> {
274 let query_vector: Vec<f32> = serde_json::from_str(query_json)
276 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
277
278 let mut progress_indicator = ProgressIndicator::new(
279 batch_mode,
280 progress,
281 writer.is_quiet(),
282 format!("Searching index '{}' for {} nearest neighbors...", index, k),
283 );
284
285 let (results, _stats) = db.search_hnsw(index, &query_vector, k, None)?;
287
288 progress_indicator.finish_with_message(format!("found {} results.", results.len()));
289
290 writer.prepare(Some(results.len()))?;
292
293 for result in results {
294 let key_display = match std::str::from_utf8(&result.key) {
296 Ok(s) => Value::Text(s.to_string()),
297 Err(_) => Value::Bytes(result.key),
298 };
299
300 let row = Row::new(vec![key_display, Value::Float(result.distance as f64)]);
301
302 match writer.write_row(row)? {
303 WriteStatus::LimitReached => break,
304 WriteStatus::Continue => {}
305 }
306 }
307
308 writer.finish()?;
309 Ok(())
310}
311
312fn execute_upsert<W: Write>(
314 db: &Database,
315 index: &str,
316 key: &str,
317 vector_json: &str,
318 writer: &mut StreamingWriter<W>,
319) -> Result<()> {
320 let vector: Vec<f32> = serde_json::from_str(vector_json)
322 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
323
324 let mut txn = db.begin(TxnMode::ReadWrite)?;
326
327 txn.upsert_to_hnsw(index, key.as_bytes(), &vector, b"")?;
328
329 txn.commit()?;
330
331 if !writer.is_quiet() {
333 writer.prepare(Some(1))?;
334 let row = Row::new(vec![
335 Value::Text("OK".to_string()),
336 Value::Text(format!("Vector '{}' upserted", key)),
337 ]);
338 writer.write_row(row)?;
339 writer.finish()?;
340 }
341
342 Ok(())
343}
344
345fn execute_delete<W: Write>(
347 db: &Database,
348 index: &str,
349 key: &str,
350 writer: &mut StreamingWriter<W>,
351) -> Result<()> {
352 let mut txn = db.begin(TxnMode::ReadWrite)?;
354
355 txn.delete_from_hnsw(index, key.as_bytes())?;
356
357 txn.commit()?;
358
359 if !writer.is_quiet() {
361 writer.prepare(Some(1))?;
362 let row = Row::new(vec![
363 Value::Text("OK".to_string()),
364 Value::Text(format!("Vector '{}' deleted", key)),
365 ]);
366 writer.write_row(row)?;
367 writer.finish()?;
368 }
369
370 Ok(())
371}
372
373pub fn vector_search_columns() -> Vec<Column> {
375 vec![
376 Column::new("id", DataType::Text),
377 Column::new("distance", DataType::Float),
378 ]
379}
380
381pub fn vector_status_columns() -> Vec<Column> {
383 vec![
384 Column::new("status", DataType::Text),
385 Column::new("message", DataType::Text),
386 ]
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::BatchModeSource;
    use crate::output::jsonl::JsonlFormatter;
    use alopex_embedded::HnswConfig;

    /// Creates a fresh database via `Database::new()` (no backing path given).
    fn create_test_db() -> Database {
        Database::new()
    }

    /// Streaming writer emitting JSONL rows with the `vector search` schema.
    fn create_search_writer(output: &mut Vec<u8>) -> StreamingWriter<&mut Vec<u8>> {
        let formatter = Box::new(JsonlFormatter::new());
        let columns = vector_search_columns();
        StreamingWriter::new(output, formatter, columns, None)
    }

    /// Interactive (non-batch, TTY) mode used by the search tests.
    fn default_batch_mode() -> BatchMode {
        BatchMode {
            is_batch: false,
            is_tty: true,
            source: BatchModeSource::Default,
        }
    }

    /// Streaming writer emitting JSONL rows with the upsert/delete status schema.
    fn create_status_writer(output: &mut Vec<u8>) -> StreamingWriter<&mut Vec<u8>> {
        let formatter = Box::new(JsonlFormatter::new());
        let columns = vector_status_columns();
        StreamingWriter::new(output, formatter, columns, None)
    }

    /// Creates a 3-dimensional L2 HNSW index named `name`.
    fn setup_hnsw_index(db: &Database, name: &str) {
        let config = HnswConfig::default()
            .with_dimension(3)
            .with_metric(alopex_embedded::Metric::L2)
            .with_m(8)
            .with_ef_construction(32);
        db.create_hnsw_index(name, config).unwrap();
    }

    #[test]
    fn test_upsert_single_vector() {
        let db = create_test_db();
        setup_hnsw_index(&db, "test_index");

        let mut output = Vec::new();
        {
            let mut writer = create_status_writer(&mut output);
            execute_upsert(&db, "test_index", "v1", "[1.0, 0.0, 0.0]", &mut writer).unwrap();
        }

        let result = String::from_utf8(output).unwrap();
        assert!(result.contains("OK"));
        assert!(result.contains("Vector 'v1' upserted"));
    }

    #[test]
    fn test_upsert_quiet_writes_nothing_but_persists() {
        let db = create_test_db();
        setup_hnsw_index(&db, "quiet_test");

        let mut output = Vec::new();
        {
            let mut writer = create_status_writer(&mut output).with_quiet(true);
            execute_upsert(&db, "quiet_test", "v1", "[1.0, 0.0, 0.0]", &mut writer).unwrap();
        }
        assert!(output.is_empty(), "quiet upsert must not print a confirmation");

        let (results, _) = db
            .search_hnsw("quiet_test", &[1.0_f32, 0.0, 0.0], 1, None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, b"v1");
    }

    #[test]
    fn test_search_vectors() {
        use alopex_embedded::TxnMode;

        let db = create_test_db();
        setup_hnsw_index(&db, "search_test");

        {
            let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
            txn.upsert_to_hnsw("search_test", b"v1", &[1.0_f32, 0.0, 0.0], b"")
                .unwrap();
            txn.upsert_to_hnsw("search_test", b"v2", &[0.0_f32, 1.0, 0.0], b"")
                .unwrap();
            txn.upsert_to_hnsw("search_test", b"v3", &[0.0_f32, 0.0, 1.0], b"")
                .unwrap();
            txn.commit().unwrap();
        }

        let query = "[1.0, 0.0, 0.0]";
        let mut output = Vec::new();
        {
            let mut writer = create_search_writer(&mut output);
            execute_search(
                &db,
                "search_test",
                query,
                2,
                false,
                &default_batch_mode(),
                &mut writer,
            )
            .unwrap();
        }

        let result = String::from_utf8(output).unwrap();
        // v1 is at distance 0 and must come first; v2 and v3 tie, so exactly one
        // of them fills the second slot for k = 2.
        let v1_pos = result
            .find("v1")
            .expect("nearest neighbour 'v1' missing from output");
        let others: Vec<usize> = ["v2", "v3"]
            .iter()
            .filter_map(|key| result.find(*key))
            .collect();
        assert_eq!(others.len(), 1, "k = 2 should return exactly one of v2/v3");
        assert!(v1_pos < others[0], "'v1' should be listed before farther hits");
    }

    #[test]
    fn test_search_invalid_query_json() {
        let db = create_test_db();
        setup_hnsw_index(&db, "invalid_search");

        let mut output = Vec::new();
        let mut writer = create_search_writer(&mut output);
        let result = execute_search(
            &db,
            "invalid_search",
            "[1.0, oops]",
            1,
            false,
            &default_batch_mode(),
            &mut writer,
        );
        assert!(matches!(result, Err(CliError::InvalidArgument(_))));
    }

    #[test]
    fn test_delete_single_vector() {
        use alopex_embedded::TxnMode;

        let db = create_test_db();
        setup_hnsw_index(&db, "delete_test");

        {
            let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
            txn.upsert_to_hnsw("delete_test", b"v1", &[1.0_f32, 0.0, 0.0], b"")
                .unwrap();
            txn.upsert_to_hnsw("delete_test", b"v2", &[0.0_f32, 1.0, 0.0], b"")
                .unwrap();
            txn.commit().unwrap();
        }

        let mut output = Vec::new();
        {
            let mut writer = create_status_writer(&mut output);
            execute_delete(&db, "delete_test", "v1", &mut writer).unwrap();
        }

        let result = String::from_utf8(output).unwrap();
        assert!(result.contains("OK"));
        assert!(result.contains("Vector 'v1' deleted"));

        // The confirmation alone proves nothing; make sure the vector is really gone.
        let (remaining, _) = db
            .search_hnsw("delete_test", &[1.0_f32, 0.0, 0.0], 2, None)
            .unwrap();
        assert!(
            remaining.iter().all(|r| r.key != b"v1"),
            "deleted vector 'v1' was still returned by search"
        );
    }

    #[test]
    fn test_invalid_vector_json() {
        let db = create_test_db();
        setup_hnsw_index(&db, "invalid_test");

        let mut output = Vec::new();
        let mut writer = create_status_writer(&mut output);

        let result = execute_upsert(&db, "invalid_test", "v1", "not valid json", &mut writer);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::InvalidArgument(_)));
    }

    #[test]
    fn test_direct_multi_txn_hnsw() {
        use alopex_embedded::TxnMode;

        let db = Database::new();
        let config = HnswConfig::default()
            .with_dimension(2)
            .with_metric(alopex_embedded::Metric::L2)
            .with_m(8)
            .with_ef_construction(32);
        db.create_hnsw_index("direct_test", config).unwrap();

        let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
        txn.upsert_to_hnsw("direct_test", b"a", &[0.0_f32, 0.0], b"ma")
            .unwrap();
        txn.upsert_to_hnsw("direct_test", b"b", &[1.0_f32, 0.0], b"mb")
            .unwrap();
        txn.commit().unwrap();

        let (results, _) = db
            .search_hnsw("direct_test", &[0.1_f32, 0.0], 1, None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, b"a");
    }
}