alopex_cli/commands/
vector.rs

1//! Vector Command - Vector similarity operations
2//!
3//! Supports: search, upsert, delete (single key/vector operations)
4
5use std::io::Write;
6
7use alopex_embedded::{Database, TxnMode};
8use serde::{Deserialize, Serialize};
9
10use crate::batch::BatchMode;
11use crate::cli::VectorCommand;
12use crate::client::http::{ClientError, HttpClient};
13use crate::error::{CliError, Result};
14use crate::models::{Column, DataType, Row, Value};
15use crate::output::formatter::Formatter;
16use crate::progress::ProgressIndicator;
17use crate::streaming::{StreamingWriter, WriteStatus};
18
19#[derive(Debug, Serialize)]
20struct RemoteVectorSearchRequest {
21    index: String,
22    query: Vec<f32>,
23    k: usize,
24}
25
26#[derive(Debug, Serialize)]
27struct RemoteVectorUpsertRequest {
28    index: String,
29    key: Vec<u8>,
30    vector: Vec<f32>,
31}
32
33#[derive(Debug, Serialize)]
34struct RemoteVectorDeleteRequest {
35    index: String,
36    key: Vec<u8>,
37}
38
39#[derive(Debug, Deserialize)]
40struct RemoteVectorSearchResult {
41    key: Vec<u8>,
42    distance: f32,
43    #[allow(dead_code)]
44    metadata: Vec<u8>,
45}
46
47#[derive(Debug, Deserialize)]
48struct RemoteVectorSearchResponse {
49    results: Vec<RemoteVectorSearchResult>,
50}
51
52#[derive(Debug, Deserialize)]
53struct RemoteVectorStatusResponse {
54    success: bool,
55}
56
57/// Execute a Vector command.
58///
59/// # Arguments
60///
61/// * `db` - The database instance.
62/// * `cmd` - The Vector subcommand to execute.
63/// * `writer` - The streaming writer for output.
64pub fn execute<W: Write>(
65    db: &Database,
66    cmd: VectorCommand,
67    batch_mode: &BatchMode,
68    writer: &mut StreamingWriter<W>,
69) -> Result<()> {
70    match cmd {
71        VectorCommand::Search {
72            index,
73            query,
74            k,
75            progress,
76        } => execute_search(db, &index, &query, k, progress, batch_mode, writer),
77        VectorCommand::Upsert { index, key, vector } => {
78            execute_upsert(db, &index, &key, &vector, writer)
79        }
80        VectorCommand::Delete { index, key } => execute_delete(db, &index, &key, writer),
81    }
82}
83
84/// Execute a Vector command against a remote server.
85pub async fn execute_remote_with_formatter<W: Write>(
86    client: &HttpClient,
87    cmd: &VectorCommand,
88    batch_mode: &BatchMode,
89    writer: &mut W,
90    formatter: Box<dyn Formatter>,
91    limit: Option<usize>,
92    quiet: bool,
93) -> Result<()> {
94    match cmd {
95        VectorCommand::Search {
96            index,
97            query,
98            k,
99            progress,
100        } => {
101            execute_remote_search(
102                client, index, query, *k, *progress, batch_mode, writer, formatter, limit, quiet,
103            )
104            .await
105        }
106        VectorCommand::Upsert { index, key, vector } => {
107            execute_remote_upsert(client, index, key, vector, writer, formatter, limit, quiet).await
108        }
109        VectorCommand::Delete { index, key } => {
110            execute_remote_delete(client, index, key, writer, formatter, limit, quiet).await
111        }
112    }
113}
114
115#[allow(clippy::too_many_arguments)]
116async fn execute_remote_search<W: Write>(
117    client: &HttpClient,
118    index: &str,
119    query_json: &str,
120    k: usize,
121    progress: bool,
122    batch_mode: &BatchMode,
123    writer: &mut W,
124    formatter: Box<dyn Formatter>,
125    limit: Option<usize>,
126    quiet: bool,
127) -> Result<()> {
128    let query_vector: Vec<f32> = serde_json::from_str(query_json)
129        .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
130
131    let mut progress_indicator = ProgressIndicator::new(
132        batch_mode,
133        progress,
134        quiet,
135        format!("Searching index '{}' for {} nearest neighbors...", index, k),
136    );
137
138    let request = RemoteVectorSearchRequest {
139        index: index.to_string(),
140        query: query_vector,
141        k,
142    };
143    let response: RemoteVectorSearchResponse = client
144        .post_json("hnsw/search", &request)
145        .await
146        .map_err(map_client_error)?;
147
148    progress_indicator.finish_with_message(format!("found {} results.", response.results.len()));
149
150    let columns = vector_search_columns();
151    let mut streaming_writer =
152        StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
153    streaming_writer.prepare(Some(response.results.len()))?;
154    for result in response.results {
155        let key_display = match std::str::from_utf8(&result.key) {
156            Ok(s) => Value::Text(s.to_string()),
157            Err(_) => Value::Bytes(result.key),
158        };
159        let row = Row::new(vec![key_display, Value::Float(result.distance as f64)]);
160        match streaming_writer.write_row(row)? {
161            WriteStatus::LimitReached => break,
162            WriteStatus::Continue => {}
163        }
164    }
165    streaming_writer.finish()
166}
167
168#[allow(clippy::too_many_arguments)]
169async fn execute_remote_upsert<W: Write>(
170    client: &HttpClient,
171    index: &str,
172    key: &str,
173    vector_json: &str,
174    writer: &mut W,
175    formatter: Box<dyn Formatter>,
176    limit: Option<usize>,
177    quiet: bool,
178) -> Result<()> {
179    let vector: Vec<f32> = serde_json::from_str(vector_json)
180        .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
181    let request = RemoteVectorUpsertRequest {
182        index: index.to_string(),
183        key: key.as_bytes().to_vec(),
184        vector,
185    };
186    let response: RemoteVectorStatusResponse = client
187        .post_json("hnsw/upsert", &request)
188        .await
189        .map_err(map_client_error)?;
190    if response.success {
191        if quiet {
192            return Ok(());
193        }
194        let columns = vector_status_columns();
195        let mut streaming_writer =
196            StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
197        streaming_writer.prepare(Some(1))?;
198        let row = Row::new(vec![
199            Value::Text("OK".to_string()),
200            Value::Text(format!("Vector '{}' upserted", key)),
201        ]);
202        streaming_writer.write_row(row)?;
203        streaming_writer.finish()
204    } else {
205        Err(CliError::InvalidArgument(
206            "Failed to upsert vector".to_string(),
207        ))
208    }
209}
210
211#[allow(clippy::too_many_arguments)]
212async fn execute_remote_delete<W: Write>(
213    client: &HttpClient,
214    index: &str,
215    key: &str,
216    writer: &mut W,
217    formatter: Box<dyn Formatter>,
218    limit: Option<usize>,
219    quiet: bool,
220) -> Result<()> {
221    let request = RemoteVectorDeleteRequest {
222        index: index.to_string(),
223        key: key.as_bytes().to_vec(),
224    };
225    let response: RemoteVectorStatusResponse = client
226        .post_json("hnsw/delete", &request)
227        .await
228        .map_err(map_client_error)?;
229    if response.success {
230        if quiet {
231            return Ok(());
232        }
233        let columns = vector_status_columns();
234        let mut streaming_writer =
235            StreamingWriter::new(writer, formatter, columns, limit).with_quiet(quiet);
236        streaming_writer.prepare(Some(1))?;
237        let row = Row::new(vec![
238            Value::Text("OK".to_string()),
239            Value::Text(format!("Vector '{}' deleted", key)),
240        ]);
241        streaming_writer.write_row(row)?;
242        streaming_writer.finish()
243    } else {
244        Err(CliError::InvalidArgument(
245            "Failed to delete vector".to_string(),
246        ))
247    }
248}
249
250fn map_client_error(err: ClientError) -> CliError {
251    match err {
252        ClientError::Request { source, .. } => {
253            CliError::ServerConnection(format!("request failed: {source}"))
254        }
255        ClientError::InvalidUrl(message) => CliError::InvalidArgument(message),
256        ClientError::Build(message) => CliError::InvalidArgument(message),
257        ClientError::Auth(err) => CliError::InvalidArgument(err.to_string()),
258        ClientError::HttpStatus { status, body } => {
259            CliError::InvalidArgument(format!("Server error: HTTP {} - {}", status.as_u16(), body))
260        }
261    }
262}
263
264/// Execute a vector search command.
265fn execute_search<W: Write>(
266    db: &Database,
267    index: &str,
268    query_json: &str,
269    k: usize,
270    progress: bool,
271    batch_mode: &BatchMode,
272    writer: &mut StreamingWriter<W>,
273) -> Result<()> {
274    // Parse vector from JSON array
275    let query_vector: Vec<f32> = serde_json::from_str(query_json)
276        .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
277
278    let mut progress_indicator = ProgressIndicator::new(
279        batch_mode,
280        progress,
281        writer.is_quiet(),
282        format!("Searching index '{}' for {} nearest neighbors...", index, k),
283    );
284
285    // Perform search
286    let (results, _stats) = db.search_hnsw(index, &query_vector, k, None)?;
287
288    progress_indicator.finish_with_message(format!("found {} results.", results.len()));
289
290    // Prepare writer with hint
291    writer.prepare(Some(results.len()))?;
292
293    for result in results {
294        // Convert key to displayable string
295        let key_display = match std::str::from_utf8(&result.key) {
296            Ok(s) => Value::Text(s.to_string()),
297            Err(_) => Value::Bytes(result.key),
298        };
299
300        let row = Row::new(vec![key_display, Value::Float(result.distance as f64)]);
301
302        match writer.write_row(row)? {
303            WriteStatus::LimitReached => break,
304            WriteStatus::Continue => {}
305        }
306    }
307
308    writer.finish()?;
309    Ok(())
310}
311
312/// Execute a single vector upsert command.
313fn execute_upsert<W: Write>(
314    db: &Database,
315    index: &str,
316    key: &str,
317    vector_json: &str,
318    writer: &mut StreamingWriter<W>,
319) -> Result<()> {
320    // Parse vector from JSON array
321    let vector: Vec<f32> = serde_json::from_str(vector_json)
322        .map_err(|e| CliError::InvalidArgument(format!("Invalid vector JSON: {}", e)))?;
323
324    // Begin transaction
325    let mut txn = db.begin(TxnMode::ReadWrite)?;
326
327    txn.upsert_to_hnsw(index, key.as_bytes(), &vector, b"")?;
328
329    txn.commit()?;
330
331    // Suppress status output in quiet mode
332    if !writer.is_quiet() {
333        writer.prepare(Some(1))?;
334        let row = Row::new(vec![
335            Value::Text("OK".to_string()),
336            Value::Text(format!("Vector '{}' upserted", key)),
337        ]);
338        writer.write_row(row)?;
339        writer.finish()?;
340    }
341
342    Ok(())
343}
344
345/// Execute a single vector delete command.
346fn execute_delete<W: Write>(
347    db: &Database,
348    index: &str,
349    key: &str,
350    writer: &mut StreamingWriter<W>,
351) -> Result<()> {
352    // Begin transaction
353    let mut txn = db.begin(TxnMode::ReadWrite)?;
354
355    txn.delete_from_hnsw(index, key.as_bytes())?;
356
357    txn.commit()?;
358
359    // Suppress status output in quiet mode
360    if !writer.is_quiet() {
361        writer.prepare(Some(1))?;
362        let row = Row::new(vec![
363            Value::Text("OK".to_string()),
364            Value::Text(format!("Vector '{}' deleted", key)),
365        ]);
366        writer.write_row(row)?;
367        writer.finish()?;
368    }
369
370    Ok(())
371}
372
373/// Create columns for vector search output.
374pub fn vector_search_columns() -> Vec<Column> {
375    vec![
376        Column::new("id", DataType::Text),
377        Column::new("distance", DataType::Float),
378    ]
379}
380
381/// Create columns for vector status output.
382pub fn vector_status_columns() -> Vec<Column> {
383    vec![
384        Column::new("status", DataType::Text),
385        Column::new("message", DataType::Text),
386    ]
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::BatchModeSource;
    use crate::output::jsonl::JsonlFormatter;
    use alopex_embedded::{HnswConfig, Metric};

    /// Fresh database for one test.
    fn create_test_db() -> Database {
        Database::new()
    }

    /// JSONL writer over `sink` with the given column layout and no row limit.
    fn jsonl_writer(sink: &mut Vec<u8>, columns: Vec<Column>) -> StreamingWriter<&mut Vec<u8>> {
        StreamingWriter::new(sink, Box::new(JsonlFormatter::new()), columns, None)
    }

    /// Writer configured with the search-result columns.
    fn create_search_writer(output: &mut Vec<u8>) -> StreamingWriter<&mut Vec<u8>> {
        jsonl_writer(output, vector_search_columns())
    }

    /// Writer configured with the status columns.
    fn create_status_writer(output: &mut Vec<u8>) -> StreamingWriter<&mut Vec<u8>> {
        jsonl_writer(output, vector_status_columns())
    }

    /// Interactive (non-batch, TTY) mode with the default source.
    fn default_batch_mode() -> BatchMode {
        BatchMode {
            is_batch: false,
            is_tty: true,
            source: BatchModeSource::Default,
        }
    }

    /// Create a 3-dimensional L2 index named `name`.
    fn setup_hnsw_index(db: &Database, name: &str) {
        let config = HnswConfig::default()
            .with_dimension(3)
            .with_metric(Metric::L2)
            .with_m(8)
            .with_ef_construction(32);
        db.create_hnsw_index(name, config).unwrap();
    }

    /// Upsert every entry inside a single transaction, with empty metadata.
    /// (multiple sequential transactions have a known bug with checksum mismatch)
    fn seed_vectors(db: &Database, index: &str, entries: &[(&str, [f32; 3])]) {
        let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
        for (key, vector) in entries {
            txn.upsert_to_hnsw(index, key.as_bytes(), &vector[..], b"")
                .unwrap();
        }
        txn.commit().unwrap();
    }

    #[test]
    fn test_upsert_single_vector() {
        let db = create_test_db();
        setup_hnsw_index(&db, "test_index");

        let mut sink = Vec::new();
        {
            let mut writer = create_status_writer(&mut sink);
            execute_upsert(&db, "test_index", "v1", "[1.0, 0.0, 0.0]", &mut writer).unwrap();
        }

        let text = String::from_utf8(sink).unwrap();
        assert!(text.contains("OK"));
        assert!(text.contains("Vector 'v1' upserted"));
    }

    #[test]
    fn test_search_vectors() {
        let db = create_test_db();
        setup_hnsw_index(&db, "search_test");
        seed_vectors(
            &db,
            "search_test",
            &[
                ("v1", [1.0, 0.0, 0.0]),
                ("v2", [0.0, 1.0, 0.0]),
                ("v3", [0.0, 0.0, 1.0]),
            ],
        );

        // Search for the two vectors closest to v1's coordinates.
        let mut sink = Vec::new();
        {
            let mut writer = create_search_writer(&mut sink);
            let batch_mode = default_batch_mode();
            execute_search(
                &db,
                "search_test",
                "[1.0, 0.0, 0.0]",
                2,
                false,
                &batch_mode,
                &mut writer,
            )
            .unwrap();
        }

        let text = String::from_utf8(sink).unwrap();
        assert!(text.contains("v1")); // Should find v1 as most similar
    }

    #[test]
    fn test_delete_single_vector() {
        let db = create_test_db();
        setup_hnsw_index(&db, "delete_test");
        seed_vectors(
            &db,
            "delete_test",
            &[("v1", [1.0, 0.0, 0.0]), ("v2", [0.0, 1.0, 0.0])],
        );

        // Delete one vector via CLI command
        let mut sink = Vec::new();
        {
            let mut writer = create_status_writer(&mut sink);
            execute_delete(&db, "delete_test", "v1", &mut writer).unwrap();
        }

        let text = String::from_utf8(sink).unwrap();
        assert!(text.contains("OK"));
        assert!(text.contains("Vector 'v1' deleted"));
    }

    #[test]
    fn test_invalid_vector_json() {
        let db = create_test_db();
        setup_hnsw_index(&db, "invalid_test");

        let mut sink = Vec::new();
        let mut writer = create_status_writer(&mut sink);

        let outcome = execute_upsert(&db, "invalid_test", "v1", "not valid json", &mut writer);
        assert!(matches!(outcome, Err(CliError::InvalidArgument(_))));
    }

    /// Direct test using EXACTLY the same pattern as hnsw_integration_tests.rs
    /// to verify that multiple upserts in a single transaction work correctly.
    #[test]
    fn test_direct_multi_txn_hnsw() {
        let db = Database::new();
        // Use same config as hnsw_integration_tests.rs: dimension 2, L2, m=8, ef=32
        let config = HnswConfig::default()
            .with_dimension(2)
            .with_metric(Metric::L2)
            .with_m(8)
            .with_ef_construction(32);
        db.create_hnsw_index("direct_test", config).unwrap();

        // Single transaction with multiple upserts (exactly like hnsw_integration_tests.rs)
        let mut txn = db.begin(TxnMode::ReadWrite).unwrap();
        txn.upsert_to_hnsw("direct_test", b"a", &[0.0_f32, 0.0], b"ma")
            .unwrap();
        txn.upsert_to_hnsw("direct_test", b"b", &[1.0_f32, 0.0], b"mb")
            .unwrap();
        txn.commit().unwrap();

        // Search should find "a" as nearest to [0.1, 0]
        let (nearest, _) = db
            .search_hnsw("direct_test", &[0.1_f32, 0.0], 1, None)
            .unwrap();
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0].key, b"a");
    }
}
559}