1use crate::convert::{cipher_blob_to_proto, create_version, key_to_proto, query_from_proto};
7use crate::error::{NetError, NetResult};
8use crate::proto::{aql, query};
9use amaters_core::Query;
10use amaters_core::Update as UpdateOp;
11use amaters_core::traits::StorageEngine;
12use amaters_core::types::{CipherBlob, Key};
13use futures::StreamExt;
14use std::sync::Arc;
15use std::time::Instant;
16use tracing::{debug, error, info, warn};
17
18#[cfg(feature = "compute")]
19use amaters_core::compute::{FheExecutor, KeyManager, PredicateCompiler};
20#[cfg(feature = "compute")]
21use std::collections::HashMap;
22
/// Transport-agnostic implementation of the AQL service.
///
/// Translates proto requests into [`Query`] executions against the wrapped
/// storage engine; the gRPC layer (`AqlGrpcService`) delegates here.
pub struct AqlServiceImpl<S: StorageEngine> {
    /// Storage backend shared across all request handlers.
    storage: Arc<S>,
    /// Service start instant; reported as uptime by `get_server_info`.
    start_time: Instant,
    /// Shared FHE key registry (present only in `compute` builds).
    #[cfg(feature = "compute")]
    key_manager: Arc<KeyManager>,
}
35
36impl<S: StorageEngine> AqlServiceImpl<S> {
37 #[cfg(feature = "compute")]
39 pub fn new(storage: Arc<S>) -> Self {
40 Self {
41 storage,
42 start_time: Instant::now(),
43 key_manager: Arc::new(KeyManager::new()),
44 }
45 }
46
47 #[cfg(not(feature = "compute"))]
49 pub fn new(storage: Arc<S>) -> Self {
50 Self {
51 storage,
52 start_time: Instant::now(),
53 }
54 }
55
56 #[cfg(feature = "compute")]
58 pub fn with_key_manager(storage: Arc<S>, key_manager: Arc<KeyManager>) -> Self {
59 Self {
60 storage,
61 start_time: Instant::now(),
62 key_manager,
63 }
64 }
65
66 pub async fn execute_query(&self, request: aql::QueryRequest) -> aql::QueryResponse {
68 let start_time = Instant::now();
69
70 info!(
71 "ExecuteQuery request received: request_id={:?}",
72 request.request_id
73 );
74
75 let proto_query = match request.query {
77 Some(q) => q,
78 None => {
79 let execution_time_ms = start_time.elapsed().as_millis() as u64;
80 return aql::QueryResponse {
81 response: Some(aql::query_response::Response::Error(
82 crate::proto::errors::ErrorResponse {
83 code: crate::proto::errors::ErrorCode::ErrorProtocolMissingField as i32,
84 message: "Missing query in request".to_string(),
85 category: crate::proto::errors::ErrorCategory::CategoryClientError
86 as i32,
87 details: None,
88 retry_after: None,
89 },
90 )),
91 request_id: request.request_id,
92 execution_time_ms,
93 };
94 }
95 };
96
97 let query = match query_from_proto(proto_query) {
98 Ok(q) => q,
99 Err(e) => {
100 error!("Failed to parse query: {}", e);
101 let execution_time_ms = start_time.elapsed().as_millis() as u64;
102 return aql::QueryResponse {
103 response: Some(aql::query_response::Response::Error(
104 crate::proto::errors::ErrorResponse {
105 code: e.error_code() as i32,
106 message: e.to_string(),
107 category: e.error_category() as i32,
108 details: None,
109 retry_after: None,
110 },
111 )),
112 request_id: request.request_id,
113 execution_time_ms,
114 };
115 }
116 };
117
118 let result = self.execute_query_internal(query).await;
120
121 let execution_time_ms = start_time.elapsed().as_millis() as u64;
122
123 match result {
125 Ok(query_result) => aql::QueryResponse {
126 response: Some(aql::query_response::Response::Result(query_result)),
127 request_id: request.request_id,
128 execution_time_ms,
129 },
130 Err(e) => {
131 error!("Query execution failed: {}", e);
132 aql::QueryResponse {
133 response: Some(aql::query_response::Response::Error(
134 crate::proto::errors::ErrorResponse {
135 code: e.error_code() as i32,
136 message: e.to_string(),
137 category: e.error_category() as i32,
138 details: None,
139 retry_after: None,
140 },
141 )),
142 request_id: request.request_id,
143 execution_time_ms,
144 }
145 }
146 }
147 }
148
149 #[doc(hidden)]
154 #[tracing::instrument(skip(self), fields(trace_id = tracing::field::Empty, duration_us = tracing::field::Empty))]
155 pub async fn execute_query_internal(&self, query: Query) -> NetResult<query::QueryResult> {
156 match query {
157 Query::Get { collection, key } => {
158 debug!(
159 "Executing GET query: collection={}, key={:?}",
160 collection, key
161 );
162
163 let result = self.storage.get(&key).await?;
164
165 let result = match result {
166 Some(value) => query::QueryResult {
167 result: Some(query::query_result::Result::Single(query::SingleResult {
168 value: Some(cipher_blob_to_proto(&value)),
169 })),
170 },
171 None => query::QueryResult {
172 result: Some(query::query_result::Result::Single(query::SingleResult {
173 value: None,
174 })),
175 },
176 };
177
178 Ok(result)
179 }
180 Query::Set {
181 collection,
182 key,
183 value,
184 } => {
185 debug!(
186 "Executing SET query: collection={}, key={:?}",
187 collection, key
188 );
189
190 self.storage.put(&key, &value).await?;
191
192 Ok(query::QueryResult {
193 result: Some(query::query_result::Result::Success(query::SuccessResult {
194 affected_rows: 1,
195 })),
196 })
197 }
198 Query::Delete { collection, key } => {
199 debug!(
200 "Executing DELETE query: collection={}, key={:?}",
201 collection, key
202 );
203
204 self.storage.delete(&key).await?;
205
206 Ok(query::QueryResult {
207 result: Some(query::query_result::Result::Success(query::SuccessResult {
208 affected_rows: 1,
209 })),
210 })
211 }
212 Query::Range {
213 collection,
214 start,
215 end,
216 } => {
217 debug!(
218 "Executing RANGE query: collection={}, start={:?}, end={:?}",
219 collection, start, end
220 );
221
222 let results = self.storage.range(&start, &end).await?;
223
224 let values: Vec<query::KeyValue> = results
225 .into_iter()
226 .map(|(k, v)| query::KeyValue {
227 key: Some(key_to_proto(&k)),
228 value: Some(cipher_blob_to_proto(&v)),
229 encrypted_predicate_result: None,
230 })
231 .collect();
232
233 Ok(query::QueryResult {
234 result: Some(query::query_result::Result::Multi(query::MultiResult {
235 values,
236 })),
237 })
238 }
239 Query::Filter {
240 collection,
241 predicate,
242 } => {
243 let min_key = Key::from_slice(&[]);
245 let max_key = Key::from_slice(&[0xFF; 256]);
246
247 let all_rows = match self.storage.range(&min_key, &max_key).await {
248 Ok(rows) => rows,
249 Err(e) => {
250 error!("Failed to retrieve rows for filter: {}", e);
251 return Err(NetError::from(e));
252 }
253 };
254
255 debug!("Filter: retrieved {} candidate rows", all_rows.len());
256
257 if all_rows.len() > 1000 {
258 warn!(
259 "Filter query retrieved {} rows, which may cause performance issues",
260 all_rows.len()
261 );
262 }
263
264 let first_is_plaintext = all_rows
269 .first()
270 .map(|(_, v)| predicate.evaluate_plaintext(v).is_some())
271 .unwrap_or(true); if first_is_plaintext {
274 info!("Executing FILTER query with server-side plaintext predicate evaluation");
275
276 let mut results = Vec::new();
277 let mut excluded: usize = 0;
278
279 for (key, value_blob) in all_rows {
280 match predicate.evaluate_plaintext(&value_blob) {
281 Some(true) => {
282 results.push(query::KeyValue {
283 key: Some(key_to_proto(&key)),
284 value: Some(cipher_blob_to_proto(&value_blob)),
285 encrypted_predicate_result: None,
286 });
287 }
288 Some(false) => {
289 excluded += 1;
291 }
292 None => {
293 warn!(
296 "Plaintext evaluation returned None for key {:?} mid-scan; \
297 including row conservatively",
298 key
299 );
300 results.push(query::KeyValue {
301 key: Some(key_to_proto(&key)),
302 value: Some(cipher_blob_to_proto(&value_blob)),
303 encrypted_predicate_result: None,
304 });
305 }
306 }
307 }
308
309 info!(
310 "FILTER query completed: {} rows matched, {} rows excluded by plaintext predicate",
311 results.len(),
312 excluded
313 );
314
315 return Ok(query::QueryResult {
316 result: Some(query::query_result::Result::Multi(query::MultiResult {
317 values: results,
318 })),
319 });
320 }
321
322 #[cfg(feature = "compute")]
324 {
325 info!("Executing FILTER query with FHE predicate evaluation");
326
327 let mut compiler = PredicateCompiler::new();
335
336 let circuit = match compiler
339 .compile(&predicate, amaters_core::compute::EncryptedType::U8)
340 {
341 Ok(c) => c,
342 Err(e) => {
343 error!("Failed to compile predicate: {}", e);
344 return Err(NetError::ServerInternal(format!(
345 "Predicate compilation failed: {}",
346 e
347 )));
348 }
349 };
350
351 debug!(
352 "Compiled predicate circuit: depth={}, gates={}",
353 circuit.depth, circuit.gate_count
354 );
355
356 let rhs = match PredicateCompiler::extract_rhs_value(&predicate) {
358 Ok(r) => r,
359 Err(e) => {
360 error!("Failed to extract RHS value: {}", e);
361 return Err(NetError::ServerInternal(format!(
362 "RHS extraction failed: {}",
363 e
364 )));
365 }
366 };
367
368 let executor = FheExecutor::new();
370
371 let mut results = Vec::new();
374 let mut execution_errors = 0;
375
376 for (key, value_blob) in all_rows {
377 let mut inputs = HashMap::new();
379 inputs.insert("value".to_string(), value_blob.clone());
380 inputs.insert("rhs".to_string(), rhs.clone());
381
382 match executor.execute(&circuit, &inputs) {
385 Ok(result_blob) => {
386 let result_bytes = result_blob.as_bytes().to_vec();
387
388 debug!(
389 "Executed predicate on key {:?}, result blob size: {}",
390 key,
391 result_bytes.len()
392 );
393
394 results.push(query::KeyValue {
395 key: Some(key_to_proto(&key)),
396 value: Some(cipher_blob_to_proto(&value_blob)),
397 encrypted_predicate_result: Some(result_bytes),
398 });
399 }
400 Err(e) => {
401 execution_errors += 1;
402 warn!("FHE execution failed for key {:?}: {}", key, e);
403 }
405 }
406 }
407
408 if execution_errors > 0 {
409 warn!(
410 "Filter query had {} FHE execution errors out of {} total rows",
411 execution_errors,
412 execution_errors + results.len()
413 );
414 }
415
416 info!(
417 "FILTER query completed, processed {} rows successfully",
418 results.len()
419 );
420
421 Ok(query::QueryResult {
422 result: Some(query::query_result::Result::Multi(query::MultiResult {
423 values: results,
424 })),
425 })
426 }
427
428 #[cfg(not(feature = "compute"))]
429 {
430 let _ = (collection, predicate);
431 warn!("FILTER query reached FHE path but compute feature is disabled");
432 Err(NetError::ServerInternal(
433 "FILTER queries on encrypted values require the compute feature"
434 .to_string(),
435 ))
436 }
437 }
438 Query::Update {
439 collection,
440 predicate,
441 updates,
442 } => {
443 debug!(
444 "Executing UPDATE query: collection={}, updates_count={}",
445 collection,
446 updates.len()
447 );
448
449 #[cfg(feature = "compute")]
450 {
451 let mut compiler = PredicateCompiler::new();
455 let circuit = match compiler
456 .compile(&predicate, amaters_core::compute::EncryptedType::U8)
457 {
458 Ok(c) => c,
459 Err(e) => {
460 error!("Failed to compile update predicate: {}", e);
461 return Err(NetError::ServerInternal(format!(
462 "Update predicate compilation failed: {}",
463 e
464 )));
465 }
466 };
467
468 let rhs = match PredicateCompiler::extract_rhs_value(&predicate) {
469 Ok(r) => r,
470 Err(e) => {
471 error!("Failed to extract RHS value for update predicate: {}", e);
472 return Err(NetError::ServerInternal(format!(
473 "Update RHS extraction failed: {}",
474 e
475 )));
476 }
477 };
478
479 let executor = FheExecutor::new();
480
481 let min_key = Key::from_slice(&[]);
483 let max_key = Key::from_slice(&[0xFF; 256]);
484 let all_rows = self.storage.range(&min_key, &max_key).await?;
485
486 let mut affected_rows: u64 = 0;
487
488 for (key, value_blob) in &all_rows {
489 let mut inputs = HashMap::new();
491 inputs.insert("value".to_string(), value_blob.clone());
492 inputs.insert("rhs".to_string(), rhs.clone());
493
494 let matches = match executor.execute(&circuit, &inputs) {
496 Ok(result_blob) => {
497 result_blob.as_bytes().iter().any(|&b| b != 0)
499 }
500 Err(e) => {
501 warn!("FHE predicate evaluation failed for key {:?}: {}", key, e);
502 continue;
503 }
504 };
505
506 if !matches {
507 continue;
508 }
509
510 let mut current_value = value_blob.clone();
512 for update_op in &updates {
513 current_value = apply_update_operation(¤t_value, update_op);
514 }
515
516 self.storage.put(key, ¤t_value).await?;
517 affected_rows += 1;
518 }
519
520 info!(
521 "UPDATE query completed: {} rows affected out of {} total",
522 affected_rows,
523 all_rows.len()
524 );
525
526 Ok(query::QueryResult {
527 result: Some(query::query_result::Result::Success(query::SuccessResult {
528 affected_rows,
529 })),
530 })
531 }
532
533 #[cfg(not(feature = "compute"))]
534 {
535 let _ = predicate;
539
540 let all_keys = self.storage.keys().await?;
541
542 if all_keys.is_empty() {
543 info!(
544 "UPDATE query on collection '{}': no keys found, 0 rows affected",
545 collection
546 );
547 return Ok(query::QueryResult {
548 result: Some(query::query_result::Result::Success(
549 query::SuccessResult { affected_rows: 0 },
550 )),
551 });
552 }
553
554 let mut affected_rows: u64 = 0;
555
556 for key in &all_keys {
557 let value_opt = self.storage.get(key).await?;
558 let current_value = match value_opt {
559 Some(v) => v,
560 None => continue,
561 };
562
563 let mut updated_value = current_value;
564 for update_op in &updates {
565 updated_value = apply_update_operation(&updated_value, update_op);
566 }
567
568 self.storage.put(key, &updated_value).await?;
569 affected_rows += 1;
570 }
571
572 info!(
573 "UPDATE query completed: {} rows affected in collection '{}'",
574 affected_rows, collection
575 );
576
577 Ok(query::QueryResult {
578 result: Some(query::query_result::Result::Success(query::SuccessResult {
579 affected_rows,
580 })),
581 })
582 }
583 }
584 }
585 }
586
    /// Executes a batch of queries sequentially with best-effort atomicity:
    /// before each mutating query a rollback entry is captured, and if any
    /// query fails to parse or execute, all previously applied operations
    /// are rolled back in reverse order.
    ///
    /// NOTE(review): rollback is best-effort, not transactional — concurrent
    /// writers can interleave between execution and rollback.
    #[tracing::instrument(skip(self, request), fields(trace_id = tracing::field::Empty, query_count = request.queries.len(), duration_us = tracing::field::Empty))]
    pub async fn execute_batch(&self, request: aql::BatchRequest) -> aql::BatchResponse {
        let start_time = Instant::now();

        info!(
            "ExecuteBatch request received: request_id={:?}, query_count={}",
            request.request_id,
            request.queries.len()
        );

        // An empty batch succeeds trivially with zero results.
        if request.queries.is_empty() {
            let execution_time_ms = start_time.elapsed().as_millis() as u64;
            return aql::BatchResponse {
                response: Some(aql::batch_response::Response::Results(aql::BatchResult {
                    results: Vec::new(),
                })),
                request_id: request.request_id,
                execution_time_ms,
            };
        }

        let mut results = Vec::with_capacity(request.queries.len());
        // Inverse operations for every mutating query applied so far.
        let mut rollback_ops: Vec<RollbackOp> = Vec::new();

        for (idx, proto_query) in request.queries.into_iter().enumerate() {
            let core_query = match query_from_proto(proto_query) {
                Ok(q) => q,
                Err(e) => {
                    error!("Failed to parse query {} in batch: {}", idx, e);
                    // A parse failure aborts the batch; undo prior writes.
                    self.rollback_operations(&rollback_ops).await;
                    let execution_time_ms = start_time.elapsed().as_millis() as u64;
                    return aql::BatchResponse {
                        response: Some(aql::batch_response::Response::Error(
                            crate::proto::errors::ErrorResponse {
                                code: e.error_code() as i32,
                                message: format!("Query {} in batch failed to parse: {}", idx, e),
                                category: e.error_category() as i32,
                                details: None,
                                retry_after: None,
                            },
                        )),
                        request_id: request.request_id,
                        execution_time_ms,
                    };
                }
            };

            // Snapshot state BEFORE executing so the write can be undone.
            let rollback_op = self.build_rollback_op(&core_query).await;

            match self.execute_query_internal(core_query).await {
                Ok(query_result) => {
                    // Record the rollback only once the query actually applied.
                    if let Some(op) = rollback_op {
                        rollback_ops.push(op);
                    }
                    results.push(query_result);
                }
                Err(e) => {
                    error!("Query {} in batch failed: {}", idx, e);
                    self.rollback_operations(&rollback_ops).await;
                    let execution_time_ms = start_time.elapsed().as_millis() as u64;
                    return aql::BatchResponse {
                        response: Some(aql::batch_response::Response::Error(
                            crate::proto::errors::ErrorResponse {
                                code: e.error_code() as i32,
                                message: format!("Query {} in batch failed: {}", idx, e),
                                category: e.error_category() as i32,
                                details: None,
                                retry_after: None,
                            },
                        )),
                        request_id: request.request_id,
                        execution_time_ms,
                    };
                }
            }
        }

        let execution_time_ms = start_time.elapsed().as_millis() as u64;
        info!(
            "ExecuteBatch completed successfully: {} queries in {}ms",
            results.len(),
            execution_time_ms
        );

        aql::BatchResponse {
            response: Some(aql::batch_response::Response::Results(aql::BatchResult {
                results,
            })),
            request_id: request.request_id,
            execution_time_ms,
        }
    }
691
692 async fn build_rollback_op(&self, query: &Query) -> Option<RollbackOp> {
699 match query {
700 Query::Set { key, .. } => {
701 let old_value = match self.storage.get(key).await {
703 Ok(v) => v,
704 Err(e) => {
705 warn!("Failed to read old value for rollback tracking: {}", e);
706 None
707 }
708 };
709 Some(RollbackOp::UndoSet {
710 key: key.clone(),
711 old_value,
712 })
713 }
714 Query::Delete { key, .. } => {
715 let old_value = match self.storage.get(key).await {
717 Ok(v) => v,
718 Err(e) => {
719 warn!("Failed to read value for rollback tracking: {}", e);
720 None
721 }
722 };
723 Some(RollbackOp::UndoDelete {
724 key: key.clone(),
725 old_value,
726 })
727 }
728 Query::Update { .. } => {
729 let keys = match self.storage.keys().await {
731 Ok(k) => k,
732 Err(e) => {
733 warn!("Failed to list keys for update rollback tracking: {}", e);
734 return Some(RollbackOp::UndoUpdate {
735 snapshots: Vec::new(),
736 });
737 }
738 };
739 let mut snapshots = Vec::with_capacity(keys.len());
740 for key in &keys {
741 let value = match self.storage.get(key).await {
742 Ok(v) => v,
743 Err(e) => {
744 warn!(
745 "Failed to read value for key {:?} during update rollback tracking: {}",
746 key, e
747 );
748 None
749 }
750 };
751 snapshots.push((key.clone(), value));
752 }
753 Some(RollbackOp::UndoUpdate { snapshots })
754 }
755 Query::Get { .. } | Query::Range { .. } | Query::Filter { .. } => None,
757 }
758 }
759
760 async fn rollback_operations(&self, ops: &[RollbackOp]) {
765 if ops.is_empty() {
766 return;
767 }
768
769 warn!("Rolling back {} operations due to batch failure", ops.len());
770
771 for (idx, op) in ops.iter().rev().enumerate() {
772 match op {
773 RollbackOp::UndoSet { key, old_value } => {
774 match old_value {
775 Some(value) => {
776 if let Err(e) = self.storage.put(key, value).await {
778 error!(
779 "Rollback failed for UndoSet (restore) at index {}: {}",
780 idx, e
781 );
782 } else {
783 debug!("Rolled back Set: restored old value for key {:?}", key);
784 }
785 }
786 None => {
787 if let Err(e) = self.storage.delete(key).await {
789 error!(
790 "Rollback failed for UndoSet (delete) at index {}: {}",
791 idx, e
792 );
793 } else {
794 debug!("Rolled back Set: deleted new key {:?}", key);
795 }
796 }
797 }
798 }
799 RollbackOp::UndoDelete { key, old_value } => {
800 if let Some(value) = old_value {
801 if let Err(e) = self.storage.put(key, value).await {
803 error!("Rollback failed for UndoDelete at index {}: {}", idx, e);
804 } else {
805 debug!("Rolled back Delete: restored value for key {:?}", key);
806 }
807 }
808 }
811 RollbackOp::UndoUpdate { snapshots } => {
812 let current_keys = match self.storage.keys().await {
814 Ok(k) => k,
815 Err(e) => {
816 error!(
817 "Rollback failed for UndoUpdate at index {}: cannot list keys: {}",
818 idx, e
819 );
820 continue;
821 }
822 };
823
824 let snapshot_keys: std::collections::HashSet<&Key> =
826 snapshots.iter().map(|(k, _)| k).collect();
827
828 for key in ¤t_keys {
830 if !snapshot_keys.contains(key) {
831 if let Err(e) = self.storage.delete(key).await {
832 error!(
833 "Rollback failed for UndoUpdate (remove new key) at index {}: {}",
834 idx, e
835 );
836 } else {
837 debug!("Rolled back Update: removed new key {:?}", key);
838 }
839 }
840 }
841
842 for (key, old_value) in snapshots {
844 match old_value {
845 Some(value) => {
846 if let Err(e) = self.storage.put(key, value).await {
847 error!(
848 "Rollback failed for UndoUpdate (restore) at index {}: {}",
849 idx, e
850 );
851 } else {
852 debug!("Rolled back Update: restored value for key {:?}", key);
853 }
854 }
855 None => {
856 if let Err(e) = self.storage.delete(key).await {
858 error!(
859 "Rollback failed for UndoUpdate (delete) at index {}: {}",
860 idx, e
861 );
862 }
863 }
864 }
865 }
866 debug!("Rolled back Update operation at index {}", idx);
867 }
868 }
869 }
870
871 info!("Rollback completed");
872 }
873
    /// Streams query results back in chunks of `config.chunk_size`.
    ///
    /// Only `Range` and `Get` queries are streamable; anything else yields an
    /// `InvalidRequest` error. The timeout is checked both before streaming
    /// starts and before each chunk is emitted. The returned stream is
    /// `'static`: it owns a clone of the storage handle, so it outlives
    /// `self`.
    pub fn execute_stream(
        &self,
        request: aql::QueryRequest,
        config: StreamConfig,
    ) -> futures::stream::BoxStream<'static, Result<aql::StreamResponse, NetError>> {
        use futures::StreamExt;

        // Clone what the 'static stream body needs to own.
        let storage = self.storage.clone();
        let request_id = request.request_id.clone();

        let stream = async_stream::stream! {
            let start_time = Instant::now();

            info!(
                "ExecuteStream request received: request_id={:?}, chunk_size={}",
                request_id, config.chunk_size
            );

            let proto_query = match request.query {
                Some(q) => q,
                None => {
                    yield Err(NetError::MissingField("query".to_string()));
                    return;
                }
            };

            let core_query = match query_from_proto(proto_query) {
                Ok(q) => q,
                Err(e) => {
                    error!("Failed to parse stream query: {}", e);
                    yield Err(e);
                    return;
                }
            };

            // Materialize the full result set up front; chunking happens below.
            let results = match core_query {
                Query::Range { collection, start, end } => {
                    debug!(
                        "Executing streaming RANGE query: collection={}, start={:?}, end={:?}",
                        collection, start, end
                    );
                    match storage.range(&start, &end).await {
                        Ok(rows) => rows,
                        Err(e) => {
                            error!("Storage range query failed: {}", e);
                            yield Err(NetError::from(e));
                            return;
                        }
                    }
                }
                Query::Get { collection, key } => {
                    debug!(
                        "Executing streaming GET query: collection={}, key={:?}",
                        collection, key
                    );
                    match storage.get(&key).await {
                        Ok(Some(value)) => vec![(key, value)],
                        Ok(None) => Vec::new(),
                        Err(e) => {
                            error!("Storage get query failed: {}", e);
                            yield Err(NetError::from(e));
                            return;
                        }
                    }
                }
                _ => {
                    yield Err(NetError::InvalidRequest(
                        "Only Range and Get queries are supported for streaming".to_string(),
                    ));
                    return;
                }
            };

            // Apply the optional result cap before chunking.
            let results = if let Some(max) = config.max_results {
                if results.len() > max {
                    results.into_iter().take(max).collect::<Vec<_>>()
                } else {
                    results
                }
            } else {
                results
            };

            let total_count = results.len();

            // Abort early if the query itself already blew the time budget.
            if start_time.elapsed() > config.timeout {
                yield Err(NetError::Timeout(
                    "Query execution exceeded timeout before streaming began".to_string(),
                ));
                return;
            }

            let mut sequence: u64 = 0;
            // Pre-split into owned chunks so each iteration can move its data
            // into the yielded response.
            let chunks_iter: Vec<Vec<(Key, CipherBlob)>> = results
                .chunks(config.chunk_size)
                .map(|c| c.to_vec())
                .collect();
            let total_chunks = chunks_iter.len();

            for (chunk_idx, chunk) in chunks_iter.into_iter().enumerate() {
                // Re-check the timeout before emitting each chunk.
                if start_time.elapsed() > config.timeout {
                    yield Err(NetError::Timeout(
                        format!("Streaming timed out at chunk {}/{}", chunk_idx + 1, total_chunks)
                    ));
                    return;
                }

                let has_more = chunk_idx + 1 < total_chunks;
                let values: Vec<query::KeyValue> = chunk
                    .into_iter()
                    .map(|(k, v)| query::KeyValue {
                        key: Some(key_to_proto(&k)),
                        value: Some(cipher_blob_to_proto(&v)),
                        encrypted_predicate_result: None,
                    })
                    .collect();

                yield Ok(aql::StreamResponse {
                    chunk: Some(aql::stream_response::Chunk::Batch(aql::StreamBatch {
                        values,
                        has_more,
                    })),
                    sequence,
                });

                sequence += 1;
            }

            // Terminal marker carrying the total item count.
            yield Ok(aql::StreamResponse {
                chunk: Some(aql::stream_response::Chunk::End(aql::StreamEnd {
                    total_count: total_count as u64,
                })),
                sequence,
            });

            info!(
                "ExecuteStream completed: {} items in {} chunks, {}ms",
                total_count,
                total_chunks,
                start_time.elapsed().as_millis()
            );
        };

        stream.boxed()
    }
1038
1039 #[tracing::instrument(skip(self, _request))]
1041 pub async fn health_check(
1042 &self,
1043 _request: aql::HealthCheckRequest,
1044 ) -> aql::HealthCheckResponse {
1045 debug!("HealthCheck request received");
1046
1047 aql::HealthCheckResponse {
1048 status: aql::HealthStatus::HealthServing as i32,
1049 message: Some("Service is healthy".to_string()),
1050 }
1051 }
1052
1053 #[tracing::instrument(skip(self, _request))]
1055 pub async fn get_server_info(
1056 &self,
1057 _request: aql::ServerInfoRequest,
1058 ) -> aql::ServerInfoResponse {
1059 debug!("GetServerInfo request received");
1060
1061 let mut capabilities = vec![
1062 "query.get".to_string(),
1063 "query.set".to_string(),
1064 "query.delete".to_string(),
1065 "query.range".to_string(),
1066 "query.update".to_string(),
1067 ];
1068
1069 #[cfg(feature = "compute")]
1070 capabilities.push("query.filter".to_string());
1071
1072 aql::ServerInfoResponse {
1073 version: Some(create_version()),
1074 supported_versions: vec![create_version()],
1075 capabilities,
1076 uptime_seconds: self.start_time.elapsed().as_secs(),
1077 }
1078 }
1079}
1080
/// Builder that wires an [`AqlServiceImpl`] (and optionally a tonic gRPC
/// server wrapper) around a storage engine.
pub struct AqlServerBuilder<S: StorageEngine> {
    /// Storage backend handed to the built service.
    storage: Arc<S>,
}
1085
impl<S: StorageEngine + Send + Sync + 'static> AqlServerBuilder<S> {
    /// Creates a builder for the given storage backend.
    pub fn new(storage: Arc<S>) -> Self {
        Self { storage }
    }

    /// Builds the plain (transport-agnostic) service implementation.
    pub fn build(self) -> AqlServiceImpl<S> {
        AqlServiceImpl::new(self.storage)
    }

    /// Builds a tonic gRPC server wrapping the service.
    ///
    /// When the `compression` feature is enabled, the server both accepts
    /// and sends gzip-compressed messages.
    pub fn build_grpc_service(
        self,
    ) -> crate::proto::aql::aql_service_server::AqlServiceServer<
        crate::grpc_service::AqlGrpcService<S>,
    > {
        use crate::grpc_service::AqlGrpcService;
        use crate::proto::aql::aql_service_server::AqlServiceServer;

        let service_impl = Arc::new(AqlServiceImpl::new(self.storage));
        let grpc_service = AqlGrpcService::new(service_impl);

        // `mut` is only needed when the compression block below reassigns.
        #[allow(unused_mut)]
        let mut server = AqlServiceServer::new(grpc_service);

        #[cfg(feature = "compression")]
        {
            server = server
                .accept_compressed(tonic::codec::CompressionEncoding::Gzip)
                .send_compressed(tonic::codec::CompressionEncoding::Gzip);
        }

        server
    }
}
1125
/// Tuning knobs for `execute_stream`: chunking, result capping, and a
/// wall-clock timeout applied both before and during streaming.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Number of key/value pairs emitted per stream chunk.
    pub chunk_size: usize,
    /// Optional hard cap on the total number of streamed results.
    pub max_results: Option<usize>,
    /// Maximum wall-clock time the stream may run.
    pub timeout: std::time::Duration,
}

impl Default for StreamConfig {
    /// 100-item chunks, no result cap, 30-second timeout.
    fn default() -> Self {
        StreamConfig {
            chunk_size: 100,
            max_results: None,
            timeout: std::time::Duration::from_secs(30),
        }
    }
}

impl StreamConfig {
    /// Sets the chunk size; zero is clamped up to 1 so chunking can never
    /// divide by zero.
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size.max(1);
        self
    }

    /// Caps the total number of streamed results.
    pub fn with_max_results(mut self, max_results: usize) -> Self {
        self.max_results = Some(max_results);
        self
    }

    /// Overrides the stream timeout.
    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
1168
/// Inverse operation recorded before each mutating query in a batch so that
/// a later failure can roll the batch back (best-effort, not transactional).
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
enum RollbackOp {
    /// Undo a `Set`: restore `old_value`, or delete the key if it was newly
    /// created (`old_value` is `None`).
    UndoSet {
        key: Key,
        old_value: Option<CipherBlob>,
    },
    /// Undo a `Delete`: re-insert `old_value` if the key previously existed.
    UndoDelete {
        key: Key,
        old_value: Option<CipherBlob>,
    },
    /// Undo an `Update`: a full pre-update snapshot of every key's value.
    UndoUpdate {
        snapshots: Vec<(Key, Option<CipherBlob>)>,
    },
}
1195
1196fn apply_update_operation(current: &CipherBlob, op: &UpdateOp) -> CipherBlob {
1206 match op {
1207 UpdateOp::Set(_col, blob) => blob.clone(),
1208 UpdateOp::Add(_col, blob) => {
1209 let a = current.as_bytes();
1210 let b = blob.as_bytes();
1211 let len = a.len().max(b.len());
1212 let mut result = Vec::with_capacity(len);
1213 for i in 0..len {
1214 let va = if i < a.len() { a[i] } else { 0 };
1215 let vb = if i < b.len() { b[i] } else { 0 };
1216 result.push(va.wrapping_add(vb));
1217 }
1218 CipherBlob::new(result)
1219 }
1220 UpdateOp::Mul(_col, blob) => {
1221 let a = current.as_bytes();
1222 let b = blob.as_bytes();
1223 let len = a.len().max(b.len());
1224 let mut result = Vec::with_capacity(len);
1225 for i in 0..len {
1226 let va = if i < a.len() { a[i] } else { 1 };
1227 let vb = if i < b.len() { b[i] } else { 1 };
1228 result.push(va.wrapping_mul(vb));
1229 }
1230 CipherBlob::new(result)
1231 }
1232 }
1233}
1234
1235#[cfg(test)]
1236mod tests {
1237 use super::*;
1238 use amaters_core::storage::MemoryStorage;
1239 use amaters_core::types::{CipherBlob, Key};
1240
    // Smoke test: a freshly constructed service records a start time near "now".
    #[tokio::test]
    async fn test_service_creation() {
        let storage = Arc::new(MemoryStorage::new());
        let service = AqlServiceImpl::new(storage);
        assert!(service.start_time.elapsed().as_secs() < 1);
    }
1247
    // GET on an existing key returns a Single result carrying the stored value.
    #[tokio::test]
    async fn test_get_query_execution() {
        let storage = Arc::new(MemoryStorage::new());
        let key = Key::from_str("test_key");
        let value = CipherBlob::new(vec![1, 2, 3, 4, 5]);

        storage.put(&key, &value).await.expect("Failed to put");

        let service = AqlServiceImpl::new(storage);

        let query = Query::Get {
            collection: "test".to_string(),
            key: key.clone(),
        };

        let result = service.execute_query_internal(query).await;
        assert!(result.is_ok());

        let query_result = result.expect("Query failed");
        match query_result.result {
            Some(query::query_result::Result::Single(single)) => {
                assert!(single.value.is_some());
            }
            _ => panic!("Expected single result"),
        }
    }
1274
    // SET writes through to the underlying storage and the value round-trips.
    #[tokio::test]
    async fn test_set_query_execution() {
        let storage = Arc::new(MemoryStorage::new());
        let service = AqlServiceImpl::new(storage.clone());

        let key = Key::from_str("test_key");
        let value = CipherBlob::new(vec![1, 2, 3, 4, 5]);

        let query = Query::Set {
            collection: "test".to_string(),
            key: key.clone(),
            value: value.clone(),
        };

        let result = service.execute_query_internal(query).await;
        assert!(result.is_ok());

        // Verify via the storage handle shared with the service.
        let stored = storage.get(&key).await.expect("Failed to get");
        assert!(stored.is_some());
        assert_eq!(stored.expect("No value"), value);
    }
1297
    // DELETE removes a previously stored key from the backing storage.
    #[tokio::test]
    async fn test_delete_query_execution() {
        let storage = Arc::new(MemoryStorage::new());
        let key = Key::from_str("test_key");
        let value = CipherBlob::new(vec![1, 2, 3, 4, 5]);

        storage.put(&key, &value).await.expect("Failed to put");

        let service = AqlServiceImpl::new(storage.clone());

        let query = Query::Delete {
            collection: "test".to_string(),
            key: key.clone(),
        };

        let result = service.execute_query_internal(query).await;
        assert!(result.is_ok());

        // The key must be gone from storage afterwards.
        let stored = storage.get(&key).await.expect("Failed to get");
        assert!(stored.is_none());
    }
1320
    // RANGE over a populated keyspace returns a non-empty Multi result.
    #[tokio::test]
    async fn test_range_query_execution() {
        let storage = Arc::new(MemoryStorage::new());

        // Seed keys key_00 .. key_09 so the range below has matches.
        for i in 0..10 {
            let key = Key::from_str(&format!("key_{:02}", i));
            let value = CipherBlob::new(vec![i as u8]);
            storage.put(&key, &value).await.expect("Failed to put");
        }

        let service = AqlServiceImpl::new(storage);

        let query = Query::Range {
            collection: "test".to_string(),
            start: Key::from_str("key_03"),
            end: Key::from_str("key_07"),
        };

        let result = service.execute_query_internal(query).await;
        assert!(result.is_ok());

        let query_result = result.expect("Query failed");
        match query_result.result {
            Some(query::query_result::Result::Multi(multi)) => {
                assert!(!multi.values.is_empty());
            }
            _ => panic!("Expected multi result"),
        }
    }
1351
    // GET on a missing key succeeds with an empty Single result (not an error).
    #[tokio::test]
    async fn test_get_nonexistent_key() {
        let storage = Arc::new(MemoryStorage::new());
        let service = AqlServiceImpl::new(storage);

        let query = Query::Get {
            collection: "test".to_string(),
            key: Key::from_str("nonexistent"),
        };

        let result = service.execute_query_internal(query).await;
        assert!(result.is_ok());

        let query_result = result.expect("Query failed");
        match query_result.result {
            Some(query::query_result::Result::Single(single)) => {
                assert!(single.value.is_none());
            }
            _ => panic!("Expected single result"),
        }
    }
1373
    // Health check always reports HealthServing.
    #[tokio::test]
    async fn test_health_check() {
        let storage = Arc::new(MemoryStorage::new());
        let service = AqlServiceImpl::new(storage);

        let request = aql::HealthCheckRequest { service: None };
        let response = service.health_check(request).await;

        assert_eq!(response.status, aql::HealthStatus::HealthServing as i32);
    }
1384
1385 #[tokio::test]
1386 async fn test_server_info() {
1387 let storage = Arc::new(MemoryStorage::new());
1388 let service = AqlServiceImpl::new(storage);
1389
1390 let request = aql::ServerInfoRequest {};
1391 let response = service.get_server_info(request).await;
1392
1393 assert!(response.version.is_some());
1394 assert!(!response.capabilities.is_empty());
1395 assert!(response.capabilities.contains(&"query.get".to_string()));
1396 }
1397
1398 #[cfg(feature = "compute")]
1399 #[tokio::test]
1400 async fn test_server_info_advertises_filter() {
1401 let storage = Arc::new(MemoryStorage::new());
1402 let service = AqlServiceImpl::new(storage);
1403
1404 let request = aql::ServerInfoRequest {};
1405 let response = service.get_server_info(request).await;
1406
1407 assert!(
1408 response.capabilities.contains(&"query.filter".to_string()),
1409 "capabilities should advertise query.filter when compute feature is enabled"
1410 );
1411 }
1412
1413 #[cfg(feature = "compute")]
1414 #[tokio::test]
1415 async fn test_filter_query_execution() {
1416 use amaters_core::{ColumnRef, Predicate};
1417
1418 let storage = Arc::new(MemoryStorage::new());
1419
1420 for i in 0u8..5 {
1423 let key = Key::from_str(&format!("row_{:02}", i));
1424 let value = CipherBlob::new(vec![i]);
1425 storage
1426 .put(&key, &value)
1427 .await
1428 .expect("Failed to insert test data");
1429 }
1430
1431 let service = AqlServiceImpl::new(storage);
1432
1433 let rhs_blob = CipherBlob::new(vec![2]);
1435 let predicate = Predicate::Gt(ColumnRef::new("value".to_string()), rhs_blob);
1436
1437 let filter_query = Query::Filter {
1438 collection: "test".to_string(),
1439 predicate,
1440 };
1441
1442 let result = service
1443 .execute_query_internal(filter_query)
1444 .await
1445 .expect("plaintext filter query should succeed");
1446
1447 match result.result {
1448 Some(query::query_result::Result::Multi(multi)) => {
1449 assert_eq!(
1451 multi.values.len(),
1452 2,
1453 "expected 2 matching rows (values 3 and 4)"
1454 );
1455 for kv in &multi.values {
1457 assert!(
1458 kv.encrypted_predicate_result.is_none(),
1459 "plaintext filter results should not carry encrypted_predicate_result"
1460 );
1461 }
1462 }
1463 other => panic!("Expected Multi result from filter query, got {:?}", other),
1464 }
1465 }
1466
1467 #[cfg(not(feature = "compute"))]
1468 #[tokio::test]
1469 async fn test_filter_query_requires_compute_feature() {
1470 use amaters_core::{ColumnRef, Predicate};
1471
1472 let storage = Arc::new(MemoryStorage::new());
1473 let service = AqlServiceImpl::new(storage);
1474
1475 let rhs_blob = CipherBlob::new(vec![1]);
1476 let predicate = Predicate::Gt(ColumnRef::new("value".to_string()), rhs_blob);
1477
1478 let filter_query = Query::Filter {
1479 collection: "test".to_string(),
1480 predicate,
1481 };
1482
1483 let result = service.execute_query_internal(filter_query).await;
1484 assert!(
1485 result.is_err(),
1486 "Filter should fail without compute feature"
1487 );
1488 let err_msg = result
1489 .as_ref()
1490 .err()
1491 .map(|e| e.to_string())
1492 .unwrap_or_default();
1493 assert!(
1494 err_msg.contains("compute feature"),
1495 "Error should mention compute feature: {}",
1496 err_msg
1497 );
1498 }
1499
1500 #[cfg(not(feature = "compute"))]
1508 fn dummy_predicate() -> amaters_core::Predicate {
1509 amaters_core::Predicate::Eq(
1510 amaters_core::ColumnRef::new("col"),
1511 CipherBlob::new(vec![0]),
1512 )
1513 }
1514
1515 #[cfg(not(feature = "compute"))]
1516 #[tokio::test]
1517 async fn test_update_set_single_key() {
1518 let storage = Arc::new(MemoryStorage::new());
1519 let key = Key::from_str("row_00");
1520 let original = CipherBlob::new(vec![10, 20, 30]);
1521 storage.put(&key, &original).await.expect("Failed to put");
1522
1523 let service = AqlServiceImpl::new(storage.clone());
1524
1525 let new_blob = CipherBlob::new(vec![99, 88, 77]);
1526 let query = Query::Update {
1527 collection: "test".to_string(),
1528 predicate: dummy_predicate(),
1529 updates: vec![amaters_core::Update::Set(
1530 amaters_core::ColumnRef::new("val"),
1531 new_blob.clone(),
1532 )],
1533 };
1534
1535 let result = service
1536 .execute_query_internal(query)
1537 .await
1538 .expect("Update failed");
1539 match result.result {
1540 Some(query::query_result::Result::Success(s)) => {
1541 assert_eq!(s.affected_rows, 1);
1542 }
1543 other => panic!("Expected Success, got {:?}", other),
1544 }
1545
1546 let stored = storage
1547 .get(&key)
1548 .await
1549 .expect("Failed to get")
1550 .expect("Key missing after update");
1551 assert_eq!(stored, new_blob);
1552 }
1553
1554 #[cfg(not(feature = "compute"))]
1555 #[tokio::test]
1556 async fn test_update_set_multiple_keys() {
1557 let storage = Arc::new(MemoryStorage::new());
1558
1559 for i in 0u8..5 {
1560 let key = Key::from_str(&format!("row_{:02}", i));
1561 let value = CipherBlob::new(vec![i]);
1562 storage.put(&key, &value).await.expect("Failed to put");
1563 }
1564
1565 let service = AqlServiceImpl::new(storage.clone());
1566
1567 let replacement = CipherBlob::new(vec![255]);
1568 let query = Query::Update {
1569 collection: "data".to_string(),
1570 predicate: dummy_predicate(),
1571 updates: vec![amaters_core::Update::Set(
1572 amaters_core::ColumnRef::new("v"),
1573 replacement.clone(),
1574 )],
1575 };
1576
1577 let result = service
1578 .execute_query_internal(query)
1579 .await
1580 .expect("Update failed");
1581 match result.result {
1582 Some(query::query_result::Result::Success(s)) => {
1583 assert_eq!(s.affected_rows, 5);
1584 }
1585 other => panic!("Expected Success, got {:?}", other),
1586 }
1587
1588 for i in 0u8..5 {
1590 let key = Key::from_str(&format!("row_{:02}", i));
1591 let stored = storage
1592 .get(&key)
1593 .await
1594 .expect("Failed to get")
1595 .expect("Key missing");
1596 assert_eq!(stored, replacement);
1597 }
1598 }
1599
1600 #[cfg(not(feature = "compute"))]
1601 #[tokio::test]
1602 async fn test_update_nonexistent_collection() {
1603 let storage = Arc::new(MemoryStorage::new());
1605 let service = AqlServiceImpl::new(storage);
1606
1607 let query = Query::Update {
1608 collection: "ghost".to_string(),
1609 predicate: dummy_predicate(),
1610 updates: vec![amaters_core::Update::Set(
1611 amaters_core::ColumnRef::new("x"),
1612 CipherBlob::new(vec![1]),
1613 )],
1614 };
1615
1616 let result = service
1617 .execute_query_internal(query)
1618 .await
1619 .expect("Update on empty storage should not error");
1620 match result.result {
1621 Some(query::query_result::Result::Success(s)) => {
1622 assert_eq!(s.affected_rows, 0);
1623 }
1624 other => panic!("Expected Success with 0 rows, got {:?}", other),
1625 }
1626 }
1627
1628 #[cfg(not(feature = "compute"))]
1629 #[tokio::test]
1630 async fn test_update_add_operation() {
1631 let storage = Arc::new(MemoryStorage::new());
1632 let key = Key::from_str("counter");
1633 let original = CipherBlob::new(vec![10, 20]);
1634 storage.put(&key, &original).await.expect("Failed to put");
1635
1636 let service = AqlServiceImpl::new(storage.clone());
1637
1638 let addend = CipherBlob::new(vec![5, 3]);
1639 let query = Query::Update {
1640 collection: "c".to_string(),
1641 predicate: dummy_predicate(),
1642 updates: vec![amaters_core::Update::Add(
1643 amaters_core::ColumnRef::new("v"),
1644 addend,
1645 )],
1646 };
1647
1648 service
1649 .execute_query_internal(query)
1650 .await
1651 .expect("Update failed");
1652
1653 let stored = storage
1654 .get(&key)
1655 .await
1656 .expect("Failed to get")
1657 .expect("Key missing");
1658 assert_eq!(stored.as_bytes(), &[15, 23]);
1659 }
1660
1661 #[cfg(not(feature = "compute"))]
1662 #[tokio::test]
1663 async fn test_update_mul_operation() {
1664 let storage = Arc::new(MemoryStorage::new());
1665 let key = Key::from_str("product");
1666 let original = CipherBlob::new(vec![3, 4]);
1667 storage.put(&key, &original).await.expect("Failed to put");
1668
1669 let service = AqlServiceImpl::new(storage.clone());
1670
1671 let factor = CipherBlob::new(vec![2, 5]);
1672 let query = Query::Update {
1673 collection: "c".to_string(),
1674 predicate: dummy_predicate(),
1675 updates: vec![amaters_core::Update::Mul(
1676 amaters_core::ColumnRef::new("v"),
1677 factor,
1678 )],
1679 };
1680
1681 service
1682 .execute_query_internal(query)
1683 .await
1684 .expect("Update failed");
1685
1686 let stored = storage
1687 .get(&key)
1688 .await
1689 .expect("Failed to get")
1690 .expect("Key missing");
1691 assert_eq!(stored.as_bytes(), &[6, 20]);
1692 }
1693
1694 #[cfg(not(feature = "compute"))]
1695 #[tokio::test]
1696 async fn test_update_multiple_operations_per_key() {
1697 let storage = Arc::new(MemoryStorage::new());
1698 let key = Key::from_str("multi_op");
1699 let original = CipherBlob::new(vec![2]);
1700 storage.put(&key, &original).await.expect("Failed to put");
1701
1702 let service = AqlServiceImpl::new(storage.clone());
1703
1704 let query = Query::Update {
1706 collection: "c".to_string(),
1707 predicate: dummy_predicate(),
1708 updates: vec![
1709 amaters_core::Update::Add(
1710 amaters_core::ColumnRef::new("v"),
1711 CipherBlob::new(vec![3]),
1712 ),
1713 amaters_core::Update::Mul(
1714 amaters_core::ColumnRef::new("v"),
1715 CipherBlob::new(vec![10]),
1716 ),
1717 ],
1718 };
1719
1720 service
1721 .execute_query_internal(query)
1722 .await
1723 .expect("Update failed");
1724
1725 let stored = storage
1726 .get(&key)
1727 .await
1728 .expect("Failed to get")
1729 .expect("Key missing");
1730 assert_eq!(stored.as_bytes(), &[50]);
1731 }
1732
1733 #[cfg(not(feature = "compute"))]
1734 #[tokio::test]
1735 async fn test_update_returns_affected_count() {
1736 let storage = Arc::new(MemoryStorage::new());
1737
1738 for i in 0u8..7 {
1740 let key = Key::from_str(&format!("k{}", i));
1741 storage
1742 .put(&key, &CipherBlob::new(vec![i]))
1743 .await
1744 .expect("Failed to put");
1745 }
1746
1747 let service = AqlServiceImpl::new(storage);
1748
1749 let query = Query::Update {
1750 collection: "c".to_string(),
1751 predicate: dummy_predicate(),
1752 updates: vec![amaters_core::Update::Set(
1753 amaters_core::ColumnRef::new("v"),
1754 CipherBlob::new(vec![0]),
1755 )],
1756 };
1757
1758 let result = service
1759 .execute_query_internal(query)
1760 .await
1761 .expect("Update failed");
1762 match result.result {
1763 Some(query::query_result::Result::Success(s)) => {
1764 assert_eq!(s.affected_rows, 7);
1765 }
1766 other => panic!("Expected Success with 7 rows, got {:?}", other),
1767 }
1768 }
1769
1770 #[cfg(not(feature = "compute"))]
1771 #[tokio::test]
1772 async fn test_update_preserves_other_collections() {
1773 let storage = Arc::new(MemoryStorage::new());
1776
1777 let key_a = Key::from_str("collA_row1");
1778 let key_b = Key::from_str("collB_row1");
1779 let val_a = CipherBlob::new(vec![1, 2, 3]);
1780 let val_b = CipherBlob::new(vec![4, 5, 6]);
1781
1782 storage.put(&key_a, &val_a).await.expect("Failed to put A");
1783 storage.put(&key_b, &val_b).await.expect("Failed to put B");
1784
1785 let service = AqlServiceImpl::new(storage.clone());
1786
1787 let query = Query::Update {
1789 collection: "collA".to_string(),
1790 predicate: dummy_predicate(),
1791 updates: vec![amaters_core::Update::Set(
1792 amaters_core::ColumnRef::new("v"),
1793 CipherBlob::new(vec![99]),
1794 )],
1795 };
1796
1797 service
1798 .execute_query_internal(query)
1799 .await
1800 .expect("Update failed");
1801
1802 let stored_a = storage.get(&key_a).await.expect("Failed to get A");
1804 assert!(stored_a.is_some(), "key_a should still exist");
1805
1806 let stored_b = storage.get(&key_b).await.expect("Failed to get B");
1807 assert!(stored_b.is_some(), "key_b should still exist");
1808 }
1809
1810 #[cfg(not(feature = "compute"))]
1811 #[tokio::test]
1812 async fn test_update_empty_updates_vec() {
1813 let storage = Arc::new(MemoryStorage::new());
1815 let key = Key::from_str("keep_me");
1816 let original = CipherBlob::new(vec![42]);
1817 storage.put(&key, &original).await.expect("Failed to put");
1818
1819 let service = AqlServiceImpl::new(storage.clone());
1820
1821 let query = Query::Update {
1822 collection: "c".to_string(),
1823 predicate: dummy_predicate(),
1824 updates: vec![], };
1826
1827 let result = service
1828 .execute_query_internal(query)
1829 .await
1830 .expect("Update with empty ops should succeed");
1831 match result.result {
1832 Some(query::query_result::Result::Success(s)) => {
1833 assert_eq!(s.affected_rows, 1);
1835 }
1836 other => panic!("Expected Success, got {:?}", other),
1837 }
1838
1839 let stored = storage
1841 .get(&key)
1842 .await
1843 .expect("Failed to get")
1844 .expect("Key missing");
1845 assert_eq!(stored, original);
1846 }
1847
1848 #[cfg(not(feature = "compute"))]
1849 #[tokio::test]
1850 async fn test_update_then_select_verifies_changes() {
1851 let storage = Arc::new(MemoryStorage::new());
1852
1853 for i in 0u8..3 {
1855 let key = Key::from_str(&format!("sel_{:02}", i));
1856 let value = CipherBlob::new(vec![i, i, i]);
1857 storage.put(&key, &value).await.expect("Failed to put");
1858 }
1859
1860 let service = AqlServiceImpl::new(storage.clone());
1861
1862 let update_query = Query::Update {
1864 collection: "c".to_string(),
1865 predicate: dummy_predicate(),
1866 updates: vec![amaters_core::Update::Add(
1867 amaters_core::ColumnRef::new("v"),
1868 CipherBlob::new(vec![1, 1, 1]),
1869 )],
1870 };
1871
1872 service
1873 .execute_query_internal(update_query)
1874 .await
1875 .expect("Update failed");
1876
1877 for i in 0u8..3 {
1879 let key = Key::from_str(&format!("sel_{:02}", i));
1880 let get_query = Query::Get {
1881 collection: "c".to_string(),
1882 key: key.clone(),
1883 };
1884
1885 let result = service
1886 .execute_query_internal(get_query)
1887 .await
1888 .expect("Get failed");
1889
1890 match result.result {
1891 Some(query::query_result::Result::Single(single)) => {
1892 let proto_val = single.value.expect("Expected value from get");
1893 let expected = vec![i + 1, i + 1, i + 1];
1895 assert_eq!(
1896 proto_val.data, expected,
1897 "Row sel_{:02} should have been updated",
1898 i
1899 );
1900 }
1901 other => panic!("Expected Single result, got {:?}", other),
1902 }
1903 }
1904 }
1905
1906 #[cfg(feature = "compute")]
1911 #[tokio::test]
1912 async fn test_update_with_compute_feature() {
1913 use amaters_core::{ColumnRef, Predicate};
1914
1915 let storage = Arc::new(MemoryStorage::new());
1916
1917 for i in 0u8..3 {
1918 let key = Key::from_str(&format!("row_{:02}", i));
1919 let value = CipherBlob::new(vec![i]);
1920 storage
1921 .put(&key, &value)
1922 .await
1923 .expect("Failed to insert test data");
1924 }
1925
1926 let service = AqlServiceImpl::new(storage);
1927
1928 let rhs_blob = CipherBlob::new(vec![1]);
1929 let predicate = Predicate::Eq(ColumnRef::new("value"), rhs_blob);
1930
1931 let update_query = Query::Update {
1932 collection: "test".to_string(),
1933 predicate,
1934 updates: vec![amaters_core::Update::Set(
1935 ColumnRef::new("v"),
1936 CipherBlob::new(vec![99]),
1937 )],
1938 };
1939
1940 let result = service.execute_query_internal(update_query).await;
1941
1942 match result {
1944 Ok(query_result) => {
1945 match query_result.result {
1946 Some(query::query_result::Result::Success(s)) => {
1947 assert!(s.affected_rows <= 3);
1949 }
1950 other => panic!("Expected Success result from update, got {:?}", other),
1951 }
1952 }
1953 Err(e) => {
1954 let msg = e.to_string();
1955 assert!(
1956 msg.contains("FHE")
1957 || msg.contains("fhe")
1958 || msg.contains("Predicate compilation")
1959 || msg.contains("compilation failed")
1960 || msg.contains("execution")
1961 || msg.contains("RHS"),
1962 "Unexpected error from update query: {}",
1963 msg
1964 );
1965 }
1966 }
1967 }
1968
1969 include!("server_rollback_tests.rs");
1971
1972 #[tokio::test]
1977 async fn test_compression_feature_gate_disabled() {
1978 let storage = Arc::new(MemoryStorage::new());
1979 let builder = AqlServerBuilder::new(storage);
1980 let _server = builder.build_grpc_service();
1982 }
1984}