1use std::io::Write;
6use std::path::PathBuf;
7
8use alopex_embedded::{Database, TxnMode};
9use serde::{Deserialize, Serialize};
10
11use crate::batch::BatchMode;
12use crate::cli::{OutputFormat, VectorCommand};
13use crate::client::http::{ClientError, HttpClient};
14use crate::error::{CliError, Result};
15use crate::models::{Column, DataType, Row, Value};
16use crate::output::formatter::Formatter;
17use crate::output::RowCollector;
18use crate::progress::ProgressIndicator;
19use crate::streaming::{StreamingWriter, WriteStatus};
20use crate::tui::admin::{AdminBackend, AdminContext, AdminTarget, AuthCapabilities};
21use crate::tui::renderer::render_output;
22
/// JSON request body sent to the server's `hnsw/search` endpoint.
#[derive(Debug, Serialize)]
struct RemoteVectorSearchRequest {
    /// Name of the index to search.
    index: String,
    /// Query vector, parsed from the JSON array supplied on the command line.
    query: Vec<f32>,
    /// Number of nearest neighbors requested.
    k: usize,
}
29
/// JSON request body sent to the server's `hnsw/upsert` endpoint.
#[derive(Debug, Serialize)]
struct RemoteVectorUpsertRequest {
    /// Name of the index to write to.
    index: String,
    /// Vector key; callers in this file fill it with the UTF-8 bytes of the key string.
    key: Vec<u8>,
    /// Vector components, parsed from the JSON array supplied on the command line.
    vector: Vec<f32>,
}
36
/// JSON request body sent to the server's `hnsw/delete` endpoint.
#[derive(Debug, Serialize)]
struct RemoteVectorDeleteRequest {
    /// Name of the index to delete from.
    index: String,
    /// Vector key; callers in this file fill it with the UTF-8 bytes of the key string.
    key: Vec<u8>,
}
42
/// One hit inside a [`RemoteVectorSearchResponse`].
#[derive(Debug, Deserialize)]
struct RemoteVectorSearchResult {
    /// Raw key bytes; rendered as text when they are valid UTF-8, otherwise as bytes.
    key: Vec<u8>,
    /// Distance reported by the server for this hit.
    distance: f32,
    // Deserialized from the response but never read in this file, hence the lint allowance.
    #[allow(dead_code)]
    metadata: Vec<u8>,
}
50
/// Response body deserialized from the server's `hnsw/search` endpoint.
#[derive(Debug, Deserialize)]
struct RemoteVectorSearchResponse {
    /// Search hits, written out in the order the server returned them.
    results: Vec<RemoteVectorSearchResult>,
}
55
/// Response body deserialized from the `hnsw/upsert` and `hnsw/delete` endpoints.
#[derive(Debug, Deserialize)]
struct RemoteVectorStatusResponse {
    /// `true` when the server reports the operation succeeded; `false` is turned into an error.
    success: bool,
}
60
61pub fn execute<W: Write>(
69 db: &Database,
70 cmd: VectorCommand,
71 batch_mode: &BatchMode,
72 writer: &mut StreamingWriter<W>,
73) -> Result<()> {
74 match cmd {
75 VectorCommand::Search {
76 index,
77 query,
78 k,
79 progress,
80 } => execute_search(db, &index, &query, k, progress, batch_mode, writer),
81 VectorCommand::Upsert { index, key, vector } => {
82 execute_upsert(db, &index, &key, &vector, writer)
83 }
84 VectorCommand::Delete { index, key } => execute_delete(db, &index, &key, writer),
85 }
86}
87
88#[allow(clippy::too_many_arguments)]
89pub fn execute_tui(
90 db: &Database,
91 cmd: VectorCommand,
92 batch_mode: &BatchMode,
93 output_format: OutputFormat,
94 columns: Vec<Column>,
95 limit: Option<usize>,
96 quiet: bool,
97 connection_label: impl Into<String>,
98 data_dir: Option<PathBuf>,
99) -> Result<()> {
100 let connection_label = connection_label.into();
101 let context_message = Some(vector_command_context(&cmd));
102 let admin_label = connection_label.clone();
103 let admin_data_dir = data_dir.clone();
104 let admin_launcher: Option<Box<dyn FnMut() -> Result<()> + '_>> = Some(Box::new(move || {
105 let connection_label = admin_label.clone();
106 let data_dir = admin_data_dir.clone();
107 crate::tui::admin::run_admin_ui(AdminContext {
108 connection_label,
109 auth: AuthCapabilities::full(),
110 backend: AdminBackend::Local {
111 db,
112 batch_mode,
113 output_format,
114 limit,
115 quiet,
116 data_dir,
117 },
118 initial_target: Some(AdminTarget::Vector),
119 })
120 }));
121 let collector = RowCollector::new();
122 let formatter = Box::new(collector.formatter());
123 let mut sink = std::io::sink();
124 let mut writer =
125 StreamingWriter::new(&mut sink, formatter, columns.clone(), limit).with_quiet(quiet);
126 execute(db, cmd, batch_mode, &mut writer)?;
127 let warning = collector.truncation_warning();
128 render_output(
129 columns,
130 collector.rows(),
131 connection_label,
132 context_message,
133 true,
134 warning,
135 output_format,
136 admin_launcher,
137 )
138}
139
140pub async fn execute_remote_with_formatter<W: Write>(
142 client: &HttpClient,
143 cmd: &VectorCommand,
144 batch_mode: &BatchMode,
145 writer: &mut W,
146 formatter: Box<dyn Formatter>,
147 limit: Option<usize>,
148 quiet: bool,
149) -> Result<()> {
150 match cmd {
151 VectorCommand::Search {
152 index,
153 query,
154 k,
155 progress,
156 } => {
157 execute_remote_search(
158 client, index, query, *k, *progress, batch_mode, writer, formatter, limit, quiet,
159 )
160 .await
161 }
162 VectorCommand::Upsert { index, key, vector } => {
163 execute_remote_upsert(client, index, key, vector, writer, formatter, limit, quiet).await
164 }
165 VectorCommand::Delete { index, key } => {
166 execute_remote_delete(client, index, key, writer, formatter, limit, quiet).await
167 }
168 }
169}
170
171#[allow(clippy::too_many_arguments)]
172pub async fn execute_remote_tui<'a>(
173 client: &HttpClient,
174 cmd: &VectorCommand,
175 batch_mode: &BatchMode,
176 columns: Vec<Column>,
177 output_format: OutputFormat,
178 limit: Option<usize>,
179 quiet: bool,
180 connection_label: impl Into<String>,
181 admin_launcher: Option<Box<dyn FnMut() -> Result<()> + 'a>>,
182) -> Result<()> {
183 let collector = RowCollector::new();
184 let formatter = Box::new(collector.formatter());
185 let mut sink = std::io::sink();
186 execute_remote_with_formatter(client, cmd, batch_mode, &mut sink, formatter, limit, quiet)
187 .await?;
188 let warning = collector.truncation_warning();
189 render_output(
190 columns,
191 collector.rows(),
192 connection_label,
193 Some(vector_command_context(cmd)),
194 true,
195 warning,
196 output_format,
197 admin_launcher,
198 )
199}
200
201#[allow(clippy::too_many_arguments)]
202async fn execute_remote_search<W: Write>(
203 client: &HttpClient,
204 index: &str,
205 query_json: &str,
206 k: usize,
207 progress: bool,
208 batch_mode: &BatchMode,
209 writer: &mut W,
210 formatter: Box<dyn Formatter>,
211 limit: Option<usize>,
212 quiet: bool,
213) -> Result<()> {
214 let query_vector: Vec<f32> = serde_json::from_str(query_json)
215 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
216
217 let mut progress_indicator = ProgressIndicator::new(
218 batch_mode,
219 progress,
220 quiet,
221 format!("Searching index '{}' for {} nearest neighbors...", index, k),
222 );
223
224 let request = RemoteVectorSearchRequest {
225 index: index.to_string(),
226 query: query_vector,
227 k,
228 };
229 let response: RemoteVectorSearchResponse = client
230 .post_json("hnsw/search", &request)
231 .await
232 .map_err(map_client_error)?;
233
234 progress_indicator.finish_with_message(format!("found {} results.", response.results.len()));
235
236 let columns = vector_search_columns();
237 let mut streaming_writer =
238 StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
239 streaming_writer.prepare(Some(response.results.len()))?;
240 for result in response.results {
241 let key_display = match std::str::from_utf8(&result.key) {
242 Ok(s) => Value::Text(s.to_string()),
243 Err(_) => Value::Bytes(result.key),
244 };
245 let row = Row::new(vec![key_display, Value::Float(result.distance as f64)]);
246 match streaming_writer.write_row(row)? {
247 WriteStatus::LimitReached => break,
248 WriteStatus::Continue => {}
249 }
250 }
251 streaming_writer.finish()
252}
253
254#[allow(clippy::too_many_arguments)]
255async fn execute_remote_upsert<W: Write>(
256 client: &HttpClient,
257 index: &str,
258 key: &str,
259 vector_json: &str,
260 writer: &mut W,
261 formatter: Box<dyn Formatter>,
262 limit: Option<usize>,
263 quiet: bool,
264) -> Result<()> {
265 let vector: Vec<f32> = serde_json::from_str(vector_json)
266 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
267 let request = RemoteVectorUpsertRequest {
268 index: index.to_string(),
269 key: key.as_bytes().to_vec(),
270 vector,
271 };
272 let response: RemoteVectorStatusResponse = client
273 .post_json("hnsw/upsert", &request)
274 .await
275 .map_err(map_client_error)?;
276 if response.success {
277 if quiet {
278 return Ok(());
279 }
280 let columns = vector_status_columns();
281 let mut streaming_writer =
282 StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
283 streaming_writer.prepare(Some(1))?;
284 let row = Row::new(vec![
285 Value::Text("OK".to_string()),
286 Value::Text(format!("Vector '{}' upserted", key)),
287 ]);
288 streaming_writer.write_row(row)?;
289 streaming_writer.finish()
290 } else {
291 Err(CliError::InvalidArgument(
292 "Failed to upsert vector".to_string(),
293 ))
294 }
295}
296
297#[allow(clippy::too_many_arguments)]
298async fn execute_remote_delete<W: Write>(
299 client: &HttpClient,
300 index: &str,
301 key: &str,
302 writer: &mut W,
303 formatter: Box<dyn Formatter>,
304 limit: Option<usize>,
305 quiet: bool,
306) -> Result<()> {
307 let request = RemoteVectorDeleteRequest {
308 index: index.to_string(),
309 key: key.as_bytes().to_vec(),
310 };
311 let response: RemoteVectorStatusResponse = client
312 .post_json("hnsw/delete", &request)
313 .await
314 .map_err(map_client_error)?;
315 if response.success {
316 if quiet {
317 return Ok(());
318 }
319 let columns = vector_status_columns();
320 let mut streaming_writer =
321 StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
322 streaming_writer.prepare(Some(1))?;
323 let row = Row::new(vec![
324 Value::Text("OK".to_string()),
325 Value::Text(format!("Vector '{}' deleted", key)),
326 ]);
327 streaming_writer.write_row(row)?;
328 streaming_writer.finish()
329 } else {
330 Err(CliError::InvalidArgument(
331 "Failed to delete vector".to_string(),
332 ))
333 }
334}
335
336fn map_client_error(err: ClientError) -> CliError {
337 match err {
338 ClientError::Request { source, .. } => {
339 CliError::ServerConnection(format!("request failed: {source}"))
340 }
341 ClientError::InvalidUrl(message) => CliError::InvalidArgument(message),
342 ClientError::Build(message) => CliError::InvalidArgument(message),
343 ClientError::Auth(err) => CliError::InvalidArgument(err.to_string()),
344 ClientError::HttpStatus { status, body } => {
345 CliError::InvalidArgument(format!("Server error: HTTP {} - {}", status.as_u16(), body))
346 }
347 }
348}
349
350fn vector_command_context(cmd: &VectorCommand) -> String {
351 match cmd {
352 VectorCommand::Search { index, k, .. } => format!("vector search --index {index} -k {k}"),
353 VectorCommand::Upsert { index, key, .. } => {
354 format!("vector upsert --index {index} --key {key}")
355 }
356 VectorCommand::Delete { index, key } => {
357 format!("vector delete --index {index} --key {key}")
358 }
359 }
360}
361
362fn execute_search<W: Write>(
364 db: &Database,
365 index: &str,
366 query_json: &str,
367 k: usize,
368 progress: bool,
369 batch_mode: &BatchMode,
370 writer: &mut StreamingWriter<W>,
371) -> Result<()> {
372 let query_vector: Vec<f32> = serde_json::from_str(query_json)
374 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
375
376 let mut progress_indicator = ProgressIndicator::new(
377 batch_mode,
378 progress,
379 writer.is_quiet(),
380 format!("Searching index '{}' for {} nearest neighbors...", index, k),
381 );
382
383 let (results, _stats) = db.search_hnsw(index, &query_vector, k, None)?;
385
386 progress_indicator.finish_with_message(format!("found {} results.", results.len()));
387
388 writer.prepare(Some(results.len()))?;
390
391 for result in results {
392 let key_display = match std::str::from_utf8(&result.key) {
394 Ok(s) => Value::Text(s.to_string()),
395 Err(_) => Value::Bytes(result.key),
396 };
397
398 let row = Row::new(vec![key_display, Value::Float(result.distance as f64)]);
399
400 match writer.write_row(row)? {
401 WriteStatus::LimitReached => break,
402 WriteStatus::Continue => {}
403 }
404 }
405
406 writer.finish()?;
407 Ok(())
408}
409
410fn execute_upsert<W: Write>(
412 db: &Database,
413 index: &str,
414 key: &str,
415 vector_json: &str,
416 writer: &mut StreamingWriter<W>,
417) -> Result<()> {
418 let vector: Vec<f32> = serde_json::from_str(vector_json)
420 .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
421
422 let mut txn = db.begin(TxnMode::ReadWrite)?;
424
425 txn.upsert_to_hnsw(index, key.as_bytes(), &vector, b"")?;
426
427 txn.commit()?;
428
429 if !writer.is_quiet() {
431 writer.prepare(Some(1))?;
432 let row = Row::new(vec![
433 Value::Text("OK".to_string()),
434 Value::Text(format!("Vector '{}' upserted", key)),
435 ]);
436 writer.write_row(row)?;
437 writer.finish()?;
438 }
439
440 Ok(())
441}
442
443fn execute_delete<W: Write>(
445 db: &Database,
446 index: &str,
447 key: &str,
448 writer: &mut StreamingWriter<W>,
449) -> Result<()> {
450 let mut txn = db.begin(TxnMode::ReadWrite)?;
452
453 txn.delete_from_hnsw(index, key.as_bytes())?;
454
455 txn.commit()?;
456
457 if !writer.is_quiet() {
459 writer.prepare(Some(1))?;
460 let row = Row::new(vec![
461 Value::Text("OK".to_string()),
462 Value::Text(format!("Vector '{}' deleted", key)),
463 ]);
464 writer.write_row(row)?;
465 writer.finish()?;
466 }
467
468 Ok(())
469}
470
471pub fn vector_search_columns() -> Vec<Column> {
473 vec![
474 Column::new("id", DataType::Text),
475 Column::new("distance", DataType::Float),
476 ]
477}
478
479pub fn vector_status_columns() -> Vec<Column> {
481 vec![
482 Column::new("status", DataType::Text),
483 Column::new("message", DataType::Text),
484 ]
485}
486
// Unit tests for the local (embedded database) vector command handlers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::BatchModeSource;
    use crate::output::jsonl::JsonlFormatter;
    use alopex_embedded::HnswConfig;

    /// Creates a fresh database via `Database::new()` for a single test.
    fn create_test_db() -> Database {
        Database::new()
    }

    /// Builds a JSONL streaming writer over `output` using the search column layout and no row limit.
    fn create_search_writer(output: &mut Vec<u8>) -> StreamingWriter<&mut Vec<u8>> {
        let formatter = Box::new(JsonlFormatter::new());
        let columns = vector_search_columns();
        StreamingWriter::new(output, formatter, columns, None)
    }

    /// Batch-mode settings used by the tests: not batch, attached to a TTY, default source.
    fn default_batch_mode() -> BatchMode {
        BatchMode {
            is_batch: false,
            is_tty: true,
            source: BatchModeSource::Default,
        }
    }

    /// Builds a JSONL streaming writer over `output` using the status column layout and no row limit.
    fn create_status_writer(output: &mut Vec<u8>) -> StreamingWriter<&mut Vec<u8>> {
        let formatter = Box::new(JsonlFormatter::new());
        let columns = vector_status_columns();
        StreamingWriter::new(output, formatter, columns, None)
    }

    /// Creates a 3-dimensional L2 HNSW index called `name` with small build parameters.
    fn setup_hnsw_index(db: &Database, name: &str) {
        let config = HnswConfig::default()
            .with_dimension(3)
            .with_metric(alopex_embedded::Metric::L2)
            .with_m(8)
            .with_ef_construction(32);
        db.create_hnsw_index(name, config).unwrap();
    }

    /// Upserting one vector writes an "OK" status row that mentions the key.
    #[test]
    fn test_upsert_single_vector() {
        let db = create_test_db();
        setup_hnsw_index(&db, "test_index");

        let mut output = Vec::new();
        {
            // Scoped so the writer's mutable borrow of `output` ends before the buffer is read.
            let mut writer = create_status_writer(&mut output);
            execute_upsert(&db, "test_index", "v1", "[1.0, 0.0, 0.0]", &mut writer).unwrap();
        }

        let result = String::from_utf8(output).unwrap();
        assert!(result.contains("OK"));
        assert!(result.contains("Vector 'v1' upserted"));
    }

    /// Searching with a query equal to `v1` includes `v1` in the output.
    #[test]
    fn test_search_vectors() {
        use alopex_embedded::TxnMode;

        let db = create_test_db();
        setup_hnsw_index(&db, "search_test");

        {
            // Seed three unit vectors, one per axis, in a single transaction.
            let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
            txn.upsert_to_hnsw("search_test", b"v1", &[1.0_f32, 0.0, 0.0], b"")
                .unwrap();
            txn.upsert_to_hnsw("search_test", b"v2", &[0.0_f32, 1.0, 0.0], b"")
                .unwrap();
            txn.upsert_to_hnsw("search_test", b"v3", &[0.0_f32, 0.0, 1.0], b"")
                .unwrap();
            txn.commit().unwrap();
        }

        let query = "[1.0, 0.0, 0.0]";
        let mut output = Vec::new();
        {
            // Scoped so the writer's mutable borrow of `output` ends before the buffer is read.
            let mut writer = create_search_writer(&mut output);
            execute_search(
                &db,
                "search_test",
                query,
                2,
                false,
                &default_batch_mode(),
                &mut writer,
            )
            .unwrap();
        }

        let result = String::from_utf8(output).unwrap();
        // Only the presence of the exact-match key is checked.
        assert!(result.contains("v1"));
    }

    /// Deleting an existing vector writes an "OK" status row that mentions the key.
    #[test]
    fn test_delete_single_vector() {
        use alopex_embedded::TxnMode;

        let db = create_test_db();
        setup_hnsw_index(&db, "delete_test");

        {
            // Seed two vectors so the index still holds one after the delete.
            let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
            txn.upsert_to_hnsw("delete_test", b"v1", &[1.0_f32, 0.0, 0.0], b"")
                .unwrap();
            txn.upsert_to_hnsw("delete_test", b"v2", &[0.0_f32, 1.0, 0.0], b"")
                .unwrap();
            txn.commit().unwrap();
        }

        let mut output = Vec::new();
        {
            // Scoped so the writer's mutable borrow of `output` ends before the buffer is read.
            let mut writer = create_status_writer(&mut output);
            execute_delete(&db, "delete_test", "v1", &mut writer).unwrap();
        }

        let result = String::from_utf8(output).unwrap();
        assert!(result.contains("OK"));
        assert!(result.contains("Vector 'v1' deleted"));
    }

    /// A vector argument that is not valid JSON is rejected as `CliError::InvalidArgument`.
    #[test]
    fn test_invalid_vector_json() {
        let db = create_test_db();
        setup_hnsw_index(&db, "invalid_test");

        let mut output = Vec::new();
        let mut writer = create_status_writer(&mut output);

        let result = execute_upsert(&db, "invalid_test", "v1", "not valid json", &mut writer);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::InvalidArgument(_)));
    }

    /// Calls the embedded HNSW API directly, without the CLI handlers: two
    /// vectors are upserted in one transaction and the nearest one is found.
    #[test]
    fn test_direct_multi_txn_hnsw() {
        use alopex_embedded::TxnMode;

        let db = Database::new();
        let config = HnswConfig::default()
            .with_dimension(2)
            .with_metric(alopex_embedded::Metric::L2)
            .with_m(8)
            .with_ef_construction(32);
        db.create_hnsw_index("direct_test", config).unwrap();

        let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
        txn.upsert_to_hnsw("direct_test", b"a", &[0.0_f32, 0.0], b"ma")
            .unwrap();
        txn.upsert_to_hnsw("direct_test", b"b", &[1.0_f32, 0.0], b"mb")
            .unwrap();
        txn.commit().unwrap();

        // The query (0.1, 0.0) is closer to "a" at the origin than to "b" at (1.0, 0.0).
        let (results, _) = db
            .search_hnsw("direct_test", &[0.1_f32, 0.0], 1, None)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, b"a");
    }
}