1use reqwest::{
4 header::{HeaderMap, HeaderValue, AUTHORIZATION},
5 Client, StatusCode,
6};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use tracing::{debug, instrument};
10
11use serde::Deserialize;
12
13use crate::error::{ClientError, Result, ServerErrorCode};
14use crate::types::*;
15
/// Request timeout, in seconds, that `DakeraClientBuilder::new` starts with.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
18
/// HTTP client handle for a Dakera server.
///
/// The type derives `Clone`; the rate-limit slot sits behind an `Arc`, so clones
/// share the same slot.
#[derive(Debug, Clone)]
pub struct DakeraClient {
    /// Underlying `reqwest` client used to send every request.
    pub(crate) client: Client,
    /// Server base URL that endpoint paths are appended to.
    pub(crate) base_url: String,
    /// Optional base URL of the ODE sidecar; `ode_extract_entities` fails without it.
    pub(crate) ode_url: Option<String>,
    /// Retry settings read by `execute_with_retry`.
    #[allow(dead_code)]
    pub(crate) retry_config: RetryConfig,
    /// Rate-limit headers captured from the most recently handled response.
    pub(crate) last_rate_limit: Arc<Mutex<Option<RateLimitHeaders>>>,
}
34
35impl DakeraClient {
36 pub fn new(base_url: impl Into<String>) -> Result<Self> {
46 DakeraClientBuilder::new(base_url).build()
47 }
48
49 pub fn builder(base_url: impl Into<String>) -> DakeraClientBuilder {
51 DakeraClientBuilder::new(base_url)
52 }
53
54 #[instrument(skip(self))]
60 pub async fn health(&self) -> Result<HealthResponse> {
61 let url = format!("{}/health", self.base_url);
62 let response = self.client.get(&url).send().await?;
63
64 if response.status().is_success() {
65 let json: serde_json::Value = response.json().await?;
66 let healthy = json
69 .get("healthy")
70 .and_then(|v| v.as_bool())
71 .unwrap_or_else(|| json.get("status").and_then(|v| v.as_str()) == Some("healthy"));
72 let version = json
73 .get("version")
74 .and_then(|v| v.as_str())
75 .map(String::from);
76 let uptime_seconds = json.get("uptime_seconds").and_then(|v| v.as_u64());
77 Ok(HealthResponse {
78 healthy,
79 version,
80 uptime_seconds,
81 build_sha: json
82 .get("build_sha")
83 .and_then(|v| v.as_str())
84 .map(String::from),
85 })
86 } else {
87 Ok(HealthResponse {
89 healthy: true,
90 version: None,
91 uptime_seconds: None,
92 build_sha: None,
93 })
94 }
95 }
96
97 #[instrument(skip(self))]
99 pub async fn ready(&self) -> Result<ReadinessResponse> {
100 let url = format!("{}/health/ready", self.base_url);
101 let response = self.client.get(&url).send().await?;
102
103 if response.status().is_success() {
104 Ok(response.json().await?)
105 } else {
106 Ok(ReadinessResponse {
107 ready: false,
108 components: None,
109 })
110 }
111 }
112
113 #[instrument(skip(self))]
115 pub async fn live(&self) -> Result<bool> {
116 let url = format!("{}/health/live", self.base_url);
117 let response = self.client.get(&url).send().await?;
118 Ok(response.status().is_success())
119 }
120
121 #[instrument(skip(self))]
127 pub async fn list_namespaces(&self) -> Result<Vec<String>> {
128 let url = format!("{}/v1/namespaces", self.base_url);
129 let response = self.client.get(&url).send().await?;
130 self.handle_response::<ListNamespacesResponse>(response)
131 .await
132 .map(|r| r.namespaces)
133 }
134
135 #[instrument(skip(self))]
137 pub async fn get_namespace(&self, namespace: &str) -> Result<NamespaceInfo> {
138 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
139 let response = self.client.get(&url).send().await?;
140 self.handle_response(response).await
141 }
142
143 #[instrument(skip(self, request))]
145 pub async fn create_namespace(
146 &self,
147 namespace: &str,
148 request: CreateNamespaceRequest,
149 ) -> Result<NamespaceInfo> {
150 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
151 let response = self.client.put(&url).json(&request).send().await?;
152 self.handle_response(response).await
153 }
154
155 #[instrument(skip(self, request), fields(namespace = %namespace))]
161 pub async fn configure_namespace(
162 &self,
163 namespace: &str,
164 request: ConfigureNamespaceRequest,
165 ) -> Result<ConfigureNamespaceResponse> {
166 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
167 let response = self.client.put(&url).json(&request).send().await?;
168 self.handle_response(response).await
169 }
170
171 #[instrument(skip(self))]
173 pub async fn delete_namespace(&self, namespace: &str) -> Result<()> {
174 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
175 let response = self.client.delete(&url).send().await?;
176 if response.status().is_success() {
177 Ok(())
178 } else {
179 let status = response.status().as_u16();
180 let text = response.text().await.unwrap_or_default();
181 Err(ClientError::Server {
182 status,
183 message: text,
184 code: None,
185 })
186 }
187 }
188
189 #[instrument(skip(self))]
191 pub async fn flush(&self, namespace: &str) -> Result<serde_json::Value> {
192 let url = format!("{}/v1/namespaces/{}/flush", self.base_url, namespace);
193 let response = self.client.post(&url).send().await?;
194 self.handle_response(response).await
195 }
196
197 #[instrument(skip(self))]
199 pub async fn get_namespace_stats(&self, namespace: &str) -> Result<serde_json::Value> {
200 let url = format!("{}/v1/namespaces/{}/stats", self.base_url, namespace);
201 let response = self.client.get(&url).send().await?;
202 self.handle_response(response).await
203 }
204
205 #[instrument(skip(self))]
207 pub async fn get_index_stats(&self, namespace: &str) -> Result<serde_json::Value> {
208 self.get_namespace_stats(namespace).await
209 }
210
211 #[instrument(skip(self, request), fields(vector_count = request.vectors.len()))]
217 pub async fn upsert(&self, namespace: &str, request: UpsertRequest) -> Result<UpsertResponse> {
218 let url = format!("{}/v1/namespaces/{}/vectors", self.base_url, namespace);
219 debug!(
220 "Upserting {} vectors to {}",
221 request.vectors.len(),
222 namespace
223 );
224
225 let response = self.client.post(&url).json(&request).send().await?;
226 self.handle_response(response).await
227 }
228
229 #[instrument(skip(self, vector))]
231 pub async fn upsert_one(&self, namespace: &str, vector: Vector) -> Result<UpsertResponse> {
232 self.upsert(namespace, UpsertRequest::single(vector)).await
233 }
234
235 #[instrument(skip(self, request), fields(namespace = %namespace, count = request.ids.len()))]
268 pub async fn upsert_columns(
269 &self,
270 namespace: &str,
271 request: ColumnUpsertRequest,
272 ) -> Result<UpsertResponse> {
273 let url = format!(
274 "{}/v1/namespaces/{}/upsert-columns",
275 self.base_url, namespace
276 );
277 debug!(
278 "Upserting {} vectors in column format to {}",
279 request.ids.len(),
280 namespace
281 );
282
283 let response = self.client.post(&url).json(&request).send().await?;
284 self.handle_response(response).await
285 }
286
287 #[instrument(skip(self, request), fields(top_k = request.top_k))]
289 pub async fn query(&self, namespace: &str, request: QueryRequest) -> Result<QueryResponse> {
290 let url = format!("{}/v1/namespaces/{}/query", self.base_url, namespace);
291 debug!(
292 "Querying namespace {} for top {} results",
293 namespace, request.top_k
294 );
295
296 let response = self.client.post(&url).json(&request).send().await?;
297 self.handle_response(response).await
298 }
299
300 #[instrument(skip(self, vector))]
302 pub async fn query_simple(
303 &self,
304 namespace: &str,
305 vector: Vec<f32>,
306 top_k: u32,
307 ) -> Result<QueryResponse> {
308 self.query(namespace, QueryRequest::new(vector, top_k))
309 .await
310 }
311
312 #[instrument(skip(self, request), fields(namespace = %namespace, query_count = request.queries.len()))]
336 pub async fn batch_query(
337 &self,
338 namespace: &str,
339 request: BatchQueryRequest,
340 ) -> Result<BatchQueryResponse> {
341 let url = format!("{}/v1/namespaces/{}/batch-query", self.base_url, namespace);
342 debug!(
343 "Batch querying namespace {} with {} queries",
344 namespace,
345 request.queries.len()
346 );
347
348 let response = self.client.post(&url).json(&request).send().await?;
349 self.handle_response(response).await
350 }
351
352 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
354 pub async fn delete(&self, namespace: &str, request: DeleteRequest) -> Result<DeleteResponse> {
355 let url = format!(
356 "{}/v1/namespaces/{}/vectors/delete",
357 self.base_url, namespace
358 );
359 debug!("Deleting {} vectors from {}", request.ids.len(), namespace);
360
361 let response = self.client.post(&url).json(&request).send().await?;
362 self.handle_response(response).await
363 }
364
365 #[instrument(skip(self))]
367 pub async fn delete_one(&self, namespace: &str, id: &str) -> Result<DeleteResponse> {
368 self.delete(namespace, DeleteRequest::single(id)).await
369 }
370
371 #[instrument(skip(self, request))]
373 pub async fn bulk_update_vectors(
374 &self,
375 namespace: &str,
376 request: BulkUpdateRequest,
377 ) -> Result<BulkUpdateResponse> {
378 let url = format!(
379 "{}/v1/namespaces/{}/vectors/bulk-update",
380 self.base_url, namespace
381 );
382 let response = self.client.post(&url).json(&request).send().await?;
383 self.handle_response(response).await
384 }
385
386 #[instrument(skip(self, request))]
388 pub async fn bulk_delete_vectors(
389 &self,
390 namespace: &str,
391 request: BulkDeleteRequest,
392 ) -> Result<BulkDeleteResponse> {
393 let url = format!(
394 "{}/v1/namespaces/{}/vectors/bulk-delete",
395 self.base_url, namespace
396 );
397 let response = self.client.post(&url).json(&request).send().await?;
398 self.handle_response(response).await
399 }
400
401 #[instrument(skip(self, request))]
403 pub async fn count_vectors(
404 &self,
405 namespace: &str,
406 request: CountVectorsRequest,
407 ) -> Result<CountVectorsResponse> {
408 let url = format!(
409 "{}/v1/namespaces/{}/vectors/count",
410 self.base_url, namespace
411 );
412 let response = self.client.post(&url).json(&request).send().await?;
413 self.handle_response(response).await
414 }
415
416 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
422 pub async fn index_documents(
423 &self,
424 namespace: &str,
425 request: IndexDocumentsRequest,
426 ) -> Result<IndexDocumentsResponse> {
427 let url = format!(
428 "{}/v1/namespaces/{}/fulltext/index",
429 self.base_url, namespace
430 );
431 debug!(
432 "Indexing {} documents in {}",
433 request.documents.len(),
434 namespace
435 );
436
437 let response = self.client.post(&url).json(&request).send().await?;
438 self.handle_response(response).await
439 }
440
441 #[instrument(skip(self, document))]
443 pub async fn index_document(
444 &self,
445 namespace: &str,
446 document: Document,
447 ) -> Result<IndexDocumentsResponse> {
448 self.index_documents(
449 namespace,
450 IndexDocumentsRequest {
451 documents: vec![document],
452 },
453 )
454 .await
455 }
456
457 #[instrument(skip(self, request))]
459 pub async fn fulltext_search(
460 &self,
461 namespace: &str,
462 request: FullTextSearchRequest,
463 ) -> Result<FullTextSearchResponse> {
464 let url = format!(
465 "{}/v1/namespaces/{}/fulltext/search",
466 self.base_url, namespace
467 );
468 debug!("Full-text search in {} for: {}", namespace, request.query);
469
470 let response = self.client.post(&url).json(&request).send().await?;
471 self.handle_response(response).await
472 }
473
474 #[instrument(skip(self))]
476 pub async fn search_text(
477 &self,
478 namespace: &str,
479 query: &str,
480 top_k: u32,
481 ) -> Result<FullTextSearchResponse> {
482 self.fulltext_search(namespace, FullTextSearchRequest::new(query, top_k))
483 .await
484 }
485
486 #[instrument(skip(self))]
488 pub async fn fulltext_stats(&self, namespace: &str) -> Result<FullTextStats> {
489 let url = format!(
490 "{}/v1/namespaces/{}/fulltext/stats",
491 self.base_url, namespace
492 );
493 let response = self.client.get(&url).send().await?;
494 self.handle_response(response).await
495 }
496
497 #[instrument(skip(self, request))]
499 pub async fn fulltext_delete(
500 &self,
501 namespace: &str,
502 request: DeleteRequest,
503 ) -> Result<DeleteResponse> {
504 let url = format!(
505 "{}/v1/namespaces/{}/fulltext/delete",
506 self.base_url, namespace
507 );
508 let response = self.client.post(&url).json(&request).send().await?;
509 self.handle_response(response).await
510 }
511
512 #[instrument(skip(self, request), fields(top_k = request.top_k))]
518 pub async fn hybrid_search(
519 &self,
520 namespace: &str,
521 request: HybridSearchRequest,
522 ) -> Result<HybridSearchResponse> {
523 let url = format!("{}/v1/namespaces/{}/hybrid", self.base_url, namespace);
524 debug!(
525 "Hybrid search in {} with vector_weight={}",
526 namespace, request.vector_weight
527 );
528
529 let response = self.client.post(&url).json(&request).send().await?;
530 self.handle_response(response).await
531 }
532
533 #[instrument(skip(self, request), fields(namespace = %namespace))]
570 pub async fn multi_vector_search(
571 &self,
572 namespace: &str,
573 request: MultiVectorSearchRequest,
574 ) -> Result<MultiVectorSearchResponse> {
575 let url = format!("{}/v1/namespaces/{}/multi-vector", self.base_url, namespace);
576 debug!(
577 "Multi-vector search in {} with {} positive vectors",
578 namespace,
579 request.positive_vectors.len()
580 );
581
582 let response = self.client.post(&url).json(&request).send().await?;
583 self.handle_response(response).await
584 }
585
586 #[instrument(skip(self, request), fields(namespace = %namespace))]
620 pub async fn aggregate(
621 &self,
622 namespace: &str,
623 request: AggregationRequest,
624 ) -> Result<AggregationResponse> {
625 let url = format!("{}/v1/namespaces/{}/aggregate", self.base_url, namespace);
626 debug!(
627 "Aggregating in namespace {} with {} aggregations",
628 namespace,
629 request.aggregate_by.len()
630 );
631
632 let response = self.client.post(&url).json(&request).send().await?;
633 self.handle_response(response).await
634 }
635
636 #[instrument(skip(self, request), fields(namespace = %namespace))]
674 pub async fn unified_query(
675 &self,
676 namespace: &str,
677 request: UnifiedQueryRequest,
678 ) -> Result<UnifiedQueryResponse> {
679 let url = format!(
680 "{}/v1/namespaces/{}/unified-query",
681 self.base_url, namespace
682 );
683 debug!(
684 "Unified query in namespace {} with top_k={}",
685 namespace, request.top_k
686 );
687
688 let response = self.client.post(&url).json(&request).send().await?;
689 self.handle_response(response).await
690 }
691
692 #[instrument(skip(self, vector))]
696 pub async fn unified_vector_search(
697 &self,
698 namespace: &str,
699 vector: Vec<f32>,
700 top_k: usize,
701 ) -> Result<UnifiedQueryResponse> {
702 self.unified_query(namespace, UnifiedQueryRequest::vector_search(vector, top_k))
703 .await
704 }
705
706 #[instrument(skip(self))]
710 pub async fn unified_text_search(
711 &self,
712 namespace: &str,
713 field: &str,
714 query: &str,
715 top_k: usize,
716 ) -> Result<UnifiedQueryResponse> {
717 self.unified_query(
718 namespace,
719 UnifiedQueryRequest::fulltext_search(field, query, top_k),
720 )
721 .await
722 }
723
724 #[instrument(skip(self, request), fields(namespace = %namespace))]
761 pub async fn explain_query(
762 &self,
763 namespace: &str,
764 request: QueryExplainRequest,
765 ) -> Result<QueryExplainResponse> {
766 let url = format!("{}/v1/namespaces/{}/explain", self.base_url, namespace);
767 debug!(
768 "Explaining query in namespace {} (query_type={:?}, top_k={})",
769 namespace, request.query_type, request.top_k
770 );
771
772 let response = self.client.post(&url).json(&request).send().await?;
773 self.handle_response(response).await
774 }
775
776 #[instrument(skip(self, request), fields(namespace = %request.namespace, priority = ?request.priority))]
804 pub async fn warm_cache(&self, request: WarmCacheRequest) -> Result<WarmCacheResponse> {
805 let url = format!(
806 "{}/v1/namespaces/{}/cache/warm",
807 self.base_url, request.namespace
808 );
809 debug!(
810 "Warming cache for namespace {} with priority {:?}",
811 request.namespace, request.priority
812 );
813
814 let response = self.client.post(&url).json(&request).send().await?;
815 self.handle_response(response).await
816 }
817
818 #[instrument(skip(self, vector_ids))]
820 pub async fn warm_vectors(
821 &self,
822 namespace: &str,
823 vector_ids: Vec<String>,
824 ) -> Result<WarmCacheResponse> {
825 self.warm_cache(WarmCacheRequest::new(namespace).with_vector_ids(vector_ids))
826 .await
827 }
828
829 #[instrument(skip(self, request), fields(namespace = %namespace))]
862 pub async fn export(&self, namespace: &str, request: ExportRequest) -> Result<ExportResponse> {
863 let url = format!("{}/v1/namespaces/{}/export", self.base_url, namespace);
864 debug!(
865 "Exporting vectors from namespace {} (top_k={}, cursor={:?})",
866 namespace, request.top_k, request.cursor
867 );
868
869 let response = self.client.post(&url).json(&request).send().await?;
870 self.handle_response(response).await
871 }
872
873 #[instrument(skip(self))]
877 pub async fn export_all(&self, namespace: &str) -> Result<ExportResponse> {
878 self.export(namespace, ExportRequest::new()).await
879 }
880
881 #[instrument(skip(self, request), fields(namespace = %namespace))]
883 pub async fn export_vectors(
884 &self,
885 namespace: &str,
886 request: ExportRequest,
887 ) -> Result<ExportResponse> {
888 self.export(namespace, request).await
889 }
890
891 #[instrument(skip(self))]
897 pub async fn diagnostics(&self) -> Result<SystemDiagnostics> {
898 let url = format!("{}/ops/diagnostics", self.base_url);
899 let response = self.client.get(&url).send().await?;
900 self.handle_response(response).await
901 }
902
903 #[instrument(skip(self))]
905 pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
906 let url = format!("{}/ops/jobs", self.base_url);
907 let response = self.client.get(&url).send().await?;
908 self.handle_response(response).await
909 }
910
911 #[instrument(skip(self))]
913 pub async fn get_job(&self, job_id: &str) -> Result<Option<JobInfo>> {
914 let url = format!("{}/ops/jobs/{}", self.base_url, job_id);
915 let response = self.client.get(&url).send().await?;
916
917 if response.status() == StatusCode::NOT_FOUND {
918 return Ok(None);
919 }
920
921 self.handle_response(response).await.map(Some)
922 }
923
924 #[instrument(skip(self, request))]
926 pub async fn compact(&self, request: CompactionRequest) -> Result<CompactionResponse> {
927 let url = format!("{}/ops/compact", self.base_url);
928 let response = self.client.post(&url).json(&request).send().await?;
929 self.handle_response(response).await
930 }
931
932 #[instrument(skip(self))]
934 pub async fn shutdown(&self) -> Result<()> {
935 let url = format!("{}/ops/shutdown", self.base_url);
936 let response = self.client.post(&url).send().await?;
937
938 if response.status().is_success() {
939 Ok(())
940 } else {
941 let status = response.status().as_u16();
942 let text = response.text().await.unwrap_or_default();
943 Err(ClientError::Server {
944 status,
945 message: text,
946 code: None,
947 })
948 }
949 }
950
951 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
957 pub async fn fetch(&self, namespace: &str, request: FetchRequest) -> Result<FetchResponse> {
958 let url = format!("{}/v1/namespaces/{}/fetch", self.base_url, namespace);
959 debug!("Fetching {} vectors from {}", request.ids.len(), namespace);
960 let response = self.client.post(&url).json(&request).send().await?;
961 self.handle_response(response).await
962 }
963
964 #[instrument(skip(self))]
966 pub async fn fetch_by_ids(&self, namespace: &str, ids: &[&str]) -> Result<Vec<Vector>> {
967 let request = FetchRequest::new(ids.iter().map(|s| s.to_string()).collect());
968 self.fetch(namespace, request).await.map(|r| r.vectors)
969 }
970
971 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
977 pub async fn upsert_text(
978 &self,
979 namespace: &str,
980 request: UpsertTextRequest,
981 ) -> Result<TextUpsertResponse> {
982 let url = format!("{}/v1/namespaces/{}/upsert-text", self.base_url, namespace);
983 debug!(
984 "Upserting {} text documents to {}",
985 request.documents.len(),
986 namespace
987 );
988 let response = self.client.post(&url).json(&request).send().await?;
989 self.handle_response(response).await
990 }
991
992 #[instrument(skip(self, request), fields(top_k = request.top_k))]
994 pub async fn query_text(
995 &self,
996 namespace: &str,
997 request: QueryTextRequest,
998 ) -> Result<TextQueryResponse> {
999 let url = format!("{}/v1/namespaces/{}/query-text", self.base_url, namespace);
1000 debug!("Text query in {} for: {}", namespace, request.text);
1001 let response = self.client.post(&url).json(&request).send().await?;
1002 self.handle_response(response).await
1003 }
1004
1005 #[instrument(skip(self))]
1007 pub async fn query_text_simple(
1008 &self,
1009 namespace: &str,
1010 text: &str,
1011 top_k: u32,
1012 ) -> Result<TextQueryResponse> {
1013 self.query_text(namespace, QueryTextRequest::new(text, top_k))
1014 .await
1015 }
1016
1017 #[instrument(skip(self, request), fields(query_count = request.queries.len()))]
1019 pub async fn batch_query_text(
1020 &self,
1021 namespace: &str,
1022 request: BatchQueryTextRequest,
1023 ) -> Result<BatchQueryTextResponse> {
1024 let url = format!(
1025 "{}/v1/namespaces/{}/batch-query-text",
1026 self.base_url, namespace
1027 );
1028 debug!(
1029 "Batch text query in {} with {} queries",
1030 namespace,
1031 request.queries.len()
1032 );
1033 let response = self.client.post(&url).json(&request).send().await?;
1034 self.handle_response(response).await
1035 }
1036
1037 #[instrument(skip(self))]
1043 pub async fn get_namespace_entity_config(
1044 &self,
1045 namespace: &str,
1046 ) -> Result<NamespaceEntityConfig> {
1047 let url = format!("{}/v1/namespaces/{}/config", self.base_url, namespace);
1048 let response = self.client.get(&url).send().await?;
1049 self.handle_response(response).await
1050 }
1051
1052 #[instrument(skip(self))]
1054 pub async fn get_namespace_extractor(
1055 &self,
1056 namespace: &str,
1057 ) -> Result<NamespaceExtractorConfig> {
1058 let url = format!("{}/v1/namespaces/{}/extractor", self.base_url, namespace);
1059 let response = self.client.get(&url).send().await?;
1060 self.handle_response(response).await
1061 }
1062
1063 #[instrument(skip(self, config))]
1068 pub async fn configure_namespace_ner(
1069 &self,
1070 namespace: &str,
1071 config: NamespaceNerConfig,
1072 ) -> Result<serde_json::Value> {
1073 let url = format!("{}/v1/namespaces/{}/config", self.base_url, namespace);
1074 let response = self.client.patch(&url).json(&config).send().await?;
1075 self.handle_response(response).await
1076 }
1077
1078 #[instrument(skip(self, text, entity_types))]
1083 pub async fn extract_entities(
1084 &self,
1085 text: &str,
1086 entity_types: Option<Vec<String>>,
1087 ) -> Result<EntityExtractionResponse> {
1088 let url = format!("{}/v1/memories/extract", self.base_url);
1089 let body = serde_json::json!({
1090 "content": text,
1091 "entity_types": entity_types,
1092 });
1093 let response = self.client.post(&url).json(&body).send().await?;
1094 self.handle_response(response).await
1095 }
1096
1097 #[instrument(skip(self))]
1101 pub async fn memory_entities(&self, memory_id: &str) -> Result<MemoryEntitiesResponse> {
1102 let url = format!("{}/v1/memory/entities/{}", self.base_url, memory_id);
1103 let response = self.client.get(&url).send().await?;
1104 self.handle_response(response).await
1105 }
1106
1107 pub fn last_rate_limit_headers(&self) -> Option<RateLimitHeaders> {
1115 self.last_rate_limit.lock().ok()?.clone()
1116 }
1117
1118 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
1120 &self,
1121 response: reqwest::Response,
1122 ) -> Result<T> {
1123 let status = response.status();
1124
1125 if let Ok(mut guard) = self.last_rate_limit.lock() {
1127 *guard = Some(RateLimitHeaders::from_response(&response));
1128 }
1129
1130 if status.is_success() {
1131 Ok(response.json().await?)
1132 } else {
1133 let status_code = status.as_u16();
1134 let retry_after = response
1136 .headers()
1137 .get("Retry-After")
1138 .and_then(|v| v.to_str().ok())
1139 .and_then(|s| s.parse::<u64>().ok());
1140 let text = response.text().await.unwrap_or_default();
1141
1142 if status_code == 429 {
1143 return Err(ClientError::RateLimitExceeded { retry_after });
1144 }
1145
1146 #[derive(Deserialize)]
1147 struct ErrorBody {
1148 error: Option<String>,
1149 code: Option<ServerErrorCode>,
1150 }
1151
1152 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1153 (body.error.unwrap_or_else(|| text.clone()), body.code)
1154 } else {
1155 (text, None)
1156 };
1157
1158 match status_code {
1159 401 => Err(ClientError::Server {
1160 status: 401,
1161 message,
1162 code,
1163 }),
1164 403 => Err(ClientError::Authorization {
1165 status: 403,
1166 message,
1167 code,
1168 }),
1169 404 => match &code {
1170 Some(ServerErrorCode::NamespaceNotFound) => {
1171 Err(ClientError::NamespaceNotFound(message))
1172 }
1173 Some(ServerErrorCode::VectorNotFound) => {
1174 Err(ClientError::VectorNotFound(message))
1175 }
1176 _ => Err(ClientError::Server {
1177 status: 404,
1178 message,
1179 code,
1180 }),
1181 },
1182 _ => Err(ClientError::Server {
1183 status: status_code,
1184 message,
1185 code,
1186 }),
1187 }
1188 }
1189 }
1190
1191 pub(crate) async fn handle_text_response(&self, response: reqwest::Response) -> Result<String> {
1193 let status = response.status();
1194
1195 if let Ok(mut guard) = self.last_rate_limit.lock() {
1197 *guard = Some(RateLimitHeaders::from_response(&response));
1198 }
1199
1200 let retry_after = response
1201 .headers()
1202 .get("Retry-After")
1203 .and_then(|v| v.to_str().ok())
1204 .and_then(|s| s.parse::<u64>().ok());
1205 let text = response.text().await.unwrap_or_default();
1206
1207 if status.is_success() {
1208 return Ok(text);
1209 }
1210
1211 let status_code = status.as_u16();
1212
1213 if status_code == 429 {
1214 return Err(ClientError::RateLimitExceeded { retry_after });
1215 }
1216
1217 #[derive(Deserialize)]
1218 struct ErrorBody {
1219 error: Option<String>,
1220 code: Option<ServerErrorCode>,
1221 }
1222
1223 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1224 (body.error.unwrap_or_else(|| text.clone()), body.code)
1225 } else {
1226 (text, None)
1227 };
1228
1229 match status_code {
1230 401 => Err(ClientError::Server {
1231 status: 401,
1232 message,
1233 code,
1234 }),
1235 403 => Err(ClientError::Authorization {
1236 status: 403,
1237 message,
1238 code,
1239 }),
1240 _ => Err(ClientError::Server {
1241 status: status_code,
1242 message,
1243 code,
1244 }),
1245 }
1246 }
1247
1248 #[allow(dead_code)]
1256 pub(crate) async fn execute_with_retry<F, Fut, T>(&self, f: F) -> Result<T>
1257 where
1258 F: Fn() -> Fut,
1259 Fut: std::future::Future<Output = Result<T>>,
1260 {
1261 let rc = &self.retry_config;
1262
1263 for attempt in 0..rc.max_retries {
1264 match f().await {
1265 Ok(v) => return Ok(v),
1266 Err(e) => {
1267 let is_last = attempt == rc.max_retries - 1;
1268 if is_last || !e.is_retryable() {
1269 return Err(e);
1270 }
1271
1272 let wait = match &e {
1273 ClientError::RateLimitExceeded {
1274 retry_after: Some(secs),
1275 } => Duration::from_secs(*secs),
1276 _ => {
1277 let base_ms = rc.base_delay.as_millis() as f64;
1278 let backoff_ms = base_ms * 2f64.powi(attempt as i32);
1279 let capped_ms = backoff_ms.min(rc.max_delay.as_millis() as f64);
1280 let final_ms = if rc.jitter {
1281 let seed = (attempt as u64).wrapping_mul(6364136223846793005);
1283 let factor = 0.5 + (seed % 1000) as f64 / 1000.0;
1284 capped_ms * factor
1285 } else {
1286 capped_ms
1287 };
1288 Duration::from_millis(final_ms as u64)
1289 }
1290 };
1291
1292 tokio::time::sleep(wait).await;
1293 }
1294 }
1295 }
1296
1297 Err(ClientError::Config("retry loop exhausted".to_string()))
1299 }
1300}
1301
1302impl DakeraClient {
1307 pub async fn ode_extract_entities(
1319 &self,
1320 req: ExtractEntitiesRequest,
1321 ) -> Result<ExtractEntitiesResponse> {
1322 let ode_url = self.ode_url.as_deref().ok_or_else(|| {
1323 ClientError::Config(
1324 "ode_url must be configured to use extract_entities(). \
1325 Call .ode_url(\"http://localhost:8080\") on the builder."
1326 .to_string(),
1327 )
1328 })?;
1329 let url = format!("{}/ode/extract", ode_url);
1330 let response = self.client.post(&url).json(&req).send().await?;
1331 if response.status().is_success() {
1332 Ok(response.json::<ExtractEntitiesResponse>().await?)
1333 } else {
1334 let status = response.status().as_u16();
1335 let body = response.text().await.unwrap_or_default();
1336 Err(ClientError::Server {
1337 status,
1338 message: format!("ODE sidecar error: {}", body),
1339 code: None,
1340 })
1341 }
1342 }
1343
1344 #[instrument(skip(self))]
1356 pub async fn get_memory_policy(&self, namespace: &str) -> Result<MemoryPolicy> {
1357 let url = format!(
1358 "{}/v1/namespaces/{}/memory_policy",
1359 self.base_url,
1360 urlencoding::encode(namespace)
1361 );
1362 let response = self.client.get(&url).send().await?;
1363 self.handle_response(response).await
1364 }
1365
1366 #[instrument(skip(self, policy))]
1373 pub async fn set_memory_policy(
1374 &self,
1375 namespace: &str,
1376 policy: MemoryPolicy,
1377 ) -> Result<MemoryPolicy> {
1378 let url = format!(
1379 "{}/v1/namespaces/{}/memory_policy",
1380 self.base_url,
1381 urlencoding::encode(namespace)
1382 );
1383 let response = self.client.put(&url).json(&policy).send().await?;
1384 self.handle_response(response).await
1385 }
1386}
1387
/// Accumulates settings for a [`DakeraClient`]; consumed by `build`.
#[derive(Debug)]
pub struct DakeraClientBuilder {
    /// Server base URL; `build` strips trailing slashes and requires an http(s) scheme.
    base_url: String,
    /// Explicit API key; when unset, `build` falls back to the `DAKERA_API_KEY` env var.
    api_key: Option<String>,
    /// Optional ODE sidecar base URL, stored with trailing slashes removed.
    ode_url: Option<String>,
    /// Overall request timeout handed to the HTTP client.
    timeout: Duration,
    /// Connection timeout; `build` uses `timeout` when this is unset.
    connect_timeout: Option<Duration>,
    /// Retry settings copied into the built client.
    retry_config: RetryConfig,
    /// Custom `User-Agent`; `build` uses `dakera-client/<crate version>` when unset.
    user_agent: Option<String>,
    /// Extra default headers as `(name, value)` pairs, validated in `build`.
    extra_headers: Vec<(String, String)>,
}
1400
1401impl DakeraClientBuilder {
1402 pub fn new(base_url: impl Into<String>) -> Self {
1404 Self {
1405 base_url: base_url.into(),
1406 api_key: None,
1407 ode_url: None,
1408 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
1409 connect_timeout: None,
1410 retry_config: RetryConfig::default(),
1411 user_agent: None,
1412 extra_headers: Vec::new(),
1413 }
1414 }
1415
1416 pub fn api_key(mut self, key: impl Into<String>) -> Self {
1421 self.api_key = Some(key.into());
1422 self
1423 }
1424
1425 pub fn ode_url(mut self, ode_url: impl Into<String>) -> Self {
1429 self.ode_url = Some(ode_url.into().trim_end_matches('/').to_string());
1430 self
1431 }
1432
1433 pub fn timeout(mut self, timeout: Duration) -> Self {
1435 self.timeout = timeout;
1436 self
1437 }
1438
1439 pub fn timeout_secs(mut self, secs: u64) -> Self {
1441 self.timeout = Duration::from_secs(secs);
1442 self
1443 }
1444
1445 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
1447 self.connect_timeout = Some(timeout);
1448 self
1449 }
1450
1451 pub fn retry_config(mut self, config: RetryConfig) -> Self {
1453 self.retry_config = config;
1454 self
1455 }
1456
1457 pub fn max_retries(mut self, max_retries: u32) -> Self {
1459 self.retry_config.max_retries = max_retries;
1460 self
1461 }
1462
1463 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
1465 self.user_agent = Some(user_agent.into());
1466 self
1467 }
1468
1469 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
1475 self.extra_headers.push((name.into(), value.into()));
1476 self
1477 }
1478
1479 pub fn build(self) -> Result<DakeraClient> {
1481 let base_url = self.base_url.trim_end_matches('/').to_string();
1483
1484 if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
1486 return Err(ClientError::InvalidUrl(
1487 "URL must start with http:// or https://".to_string(),
1488 ));
1489 }
1490
1491 let user_agent = self
1492 .user_agent
1493 .unwrap_or_else(|| format!("dakera-client/{}", env!("CARGO_PKG_VERSION")));
1494
1495 let connect_timeout = self.connect_timeout.unwrap_or(self.timeout);
1496
1497 let api_key = self
1499 .api_key
1500 .or_else(|| std::env::var("DAKERA_API_KEY").ok());
1501
1502 let mut default_headers = HeaderMap::new();
1503 if let Some(key) = &api_key {
1504 let bearer = format!("Bearer {key}");
1505 let mut value = HeaderValue::from_str(&bearer)
1506 .map_err(|_| ClientError::Config("invalid API key".into()))?;
1507 value.set_sensitive(true);
1508 default_headers.insert(AUTHORIZATION, value);
1509 }
1510 for (name, value) in &self.extra_headers {
1511 let header_name = reqwest::header::HeaderName::from_bytes(name.as_bytes())
1512 .map_err(|_| ClientError::Config(format!("invalid header name: {name}")))?;
1513 let header_value = HeaderValue::from_str(value)
1514 .map_err(|_| ClientError::Config(format!("invalid header value for {name}")))?;
1515 default_headers.insert(header_name, header_value);
1516 }
1517
1518 let client = Client::builder()
1519 .timeout(self.timeout)
1520 .connect_timeout(connect_timeout)
1521 .user_agent(user_agent)
1522 .default_headers(default_headers)
1523 .build()
1524 .map_err(|e| ClientError::Config(e.to_string()))?;
1525
1526 Ok(DakeraClient {
1527 client,
1528 base_url,
1529 ode_url: self.ode_url,
1530 retry_config: self.retry_config,
1531 last_rate_limit: Arc::new(Mutex::new(None)),
1532 })
1533 }
1534}
1535
1536impl DakeraClient {
1541 pub async fn stream_namespace_events(
1566 &self,
1567 namespace: &str,
1568 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1569 let url = format!(
1570 "{}/v1/namespaces/{}/events",
1571 self.base_url,
1572 urlencoding::encode(namespace)
1573 );
1574 self.stream_sse(url).await
1575 }
1576
1577 pub async fn stream_global_events(
1584 &self,
1585 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1586 let url = format!("{}/ops/events", self.base_url);
1587 self.stream_sse(url).await
1588 }
1589
1590 pub async fn stream_memory_events(
1599 &self,
1600 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::MemoryEvent>>> {
1601 let url = format!("{}/v1/events/stream", self.base_url);
1602 self.stream_sse(url).await
1603 }
1604
1605 pub(crate) async fn stream_sse<T>(
1607 &self,
1608 url: String,
1609 ) -> Result<tokio::sync::mpsc::Receiver<Result<T>>>
1610 where
1611 T: serde::de::DeserializeOwned + Send + 'static,
1612 {
1613 use futures_util::StreamExt;
1614
1615 let response = self
1616 .client
1617 .get(&url)
1618 .header("Accept", "text/event-stream")
1619 .header("Cache-Control", "no-cache")
1620 .send()
1621 .await?;
1622
1623 if !response.status().is_success() {
1624 let status = response.status().as_u16();
1625 let body = response.text().await.unwrap_or_default();
1626 return Err(ClientError::Server {
1627 status,
1628 message: body,
1629 code: None,
1630 });
1631 }
1632
1633 let (tx, rx) = tokio::sync::mpsc::channel(64);
1634
1635 tokio::spawn(async move {
1636 let mut byte_stream = response.bytes_stream();
1637 let mut remaining = String::new();
1638 let mut data_lines: Vec<String> = Vec::new();
1639
1640 while let Some(chunk) = byte_stream.next().await {
1641 match chunk {
1642 Ok(bytes) => {
1643 remaining.push_str(&String::from_utf8_lossy(&bytes));
1644 while let Some(pos) = remaining.find('\n') {
1645 let raw = &remaining[..pos];
1646 let line = raw.trim_end_matches('\r').to_string();
1647 remaining = remaining[pos + 1..].to_string();
1648
1649 if line.starts_with(':') {
1650 } else if let Some(data) = line.strip_prefix("data:") {
1652 data_lines.push(data.trim_start().to_string());
1653 } else if line.is_empty() {
1654 if !data_lines.is_empty() {
1655 let payload = data_lines.join("\n");
1656 data_lines.clear();
1657 let result = serde_json::from_str::<T>(&payload)
1658 .map_err(ClientError::Json);
1659 if tx.send(result).await.is_err() {
1660 return; }
1662 }
1663 } else {
1664 }
1666 }
1667 }
1668 Err(e) => {
1669 let _ = tx.send(Err(ClientError::Http(e))).await;
1670 return;
1671 }
1672 }
1673 }
1674 });
1675
1676 Ok(rx)
1677 }
1678
1679 #[instrument(skip(self, request))]
1685 pub async fn route_query(&self, request: RouteRequest) -> Result<RouteResponse> {
1686 let url = format!("{}/v1/route", self.base_url);
1687 let response = self.client.post(&url).json(&request).send().await?;
1688 self.handle_response(response).await
1689 }
1690
1691 #[instrument(skip(self))]
1697 pub async fn import_job_status(&self, job_id: &str) -> Result<ImportJobStatus> {
1698 let url = format!("{}/v1/import/{}/status", self.base_url, job_id);
1699 let response = self.client.get(&url).send().await?;
1700 self.handle_response(response).await
1701 }
1702}
1703
#[cfg(test)]
mod tests {
    //! Unit tests for the client builder, request/response types and retry
    //! helper.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // ---------------------------------------------------------------------
    // Client construction
    // ---------------------------------------------------------------------

    #[test]
    fn test_client_builder() {
        let client = DakeraClient::new("http://localhost:3000");
        assert!(client.is_ok());
    }

    #[test]
    fn test_client_builder_with_options() {
        let client = DakeraClient::builder("http://localhost:3000")
            .timeout_secs(60)
            .user_agent("test-client/1.0")
            .build();
        assert!(client.is_ok());
    }

    #[test]
    fn test_client_builder_invalid_url() {
        // No http:// or https:// prefix, so `build` must reject it.
        let client = DakeraClient::new("invalid-url");
        assert!(client.is_err());
    }

    #[test]
    fn test_client_builder_trailing_slash() {
        let client = DakeraClient::new("http://localhost:3000/").unwrap();
        assert!(!client.base_url.ends_with('/'));
    }

    #[test]
    fn test_client_builder_with_extra_header() {
        let client = DakeraClient::builder("http://localhost:3000")
            .header("X-Playground-Session", "pg_abc123def456")
            .build();
        assert!(client.is_ok());
    }

    #[test]
    fn test_client_builder_invalid_header_name() {
        // Spaces and `!` are not valid in an HTTP header name.
        let client = DakeraClient::builder("http://localhost:3000")
            .header("invalid header name!", "value")
            .build();
        assert!(client.is_err());
    }

    // ---------------------------------------------------------------------
    // Request type builders
    // ---------------------------------------------------------------------

    #[test]
    fn test_vector_creation() {
        let v = Vector::new("test", vec![0.1, 0.2, 0.3]);
        assert_eq!(v.id, "test");
        assert_eq!(v.values.len(), 3);
        assert!(v.metadata.is_none());
    }

    #[test]
    fn test_query_request_builder() {
        let req = QueryRequest::new(vec![0.1, 0.2], 10)
            .with_filter(serde_json::json!({"category": "test"}))
            .include_metadata(false);

        assert_eq!(req.top_k, 10);
        assert!(req.filter.is_some());
        assert!(!req.include_metadata);
    }

    #[test]
    fn test_hybrid_search_request() {
        let req = HybridSearchRequest::new(vec![0.1], "test query", 5).with_vector_weight(0.7);

        assert_eq!(req.vector_weight, 0.7);
        assert_eq!(req.text, "test query");
        assert!(req.vector.is_some());
    }

    #[test]
    fn test_hybrid_search_weight_clamping() {
        // A weight above 1.0 is expected to be clamped down to 1.0.
        let req = HybridSearchRequest::new(vec![0.1], "test", 5).with_vector_weight(1.5);
        assert_eq!(req.vector_weight, 1.0);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let req = HybridSearchRequest::text_only("bm25 query", 10);

        assert!(req.vector.is_none());
        assert_eq!(req.text, "bm25 query");
        assert_eq!(req.top_k, 10);
        // The absent vector must be omitted from the JSON, not sent as null.
        let json = serde_json::to_value(&req).unwrap();
        assert!(json.get("vector").is_none());
    }

    #[test]
    fn test_text_document_builder() {
        let doc = TextDocument::new("doc1", "Hello world").with_ttl(3600);

        assert_eq!(doc.id, "doc1");
        assert_eq!(doc.text, "Hello world");
        assert_eq!(doc.ttl_seconds, Some(3600));
        assert!(doc.metadata.is_none());
    }

    #[test]
    fn test_upsert_text_request_builder() {
        let docs = vec![
            TextDocument::new("doc1", "Hello"),
            TextDocument::new("doc2", "World"),
        ];
        let req = UpsertTextRequest::new(docs).with_model(EmbeddingModel::BgeSmall);

        assert_eq!(req.documents.len(), 2);
        assert_eq!(req.model, Some(EmbeddingModel::BgeSmall));
    }

    #[test]
    fn test_query_text_request_builder() {
        let req = QueryTextRequest::new("semantic search query", 5)
            .with_filter(serde_json::json!({"category": "docs"}))
            .include_vectors(true)
            .with_model(EmbeddingModel::E5Small);

        assert_eq!(req.text, "semantic search query");
        assert_eq!(req.top_k, 5);
        assert!(req.filter.is_some());
        assert!(req.include_vectors);
        assert_eq!(req.model, Some(EmbeddingModel::E5Small));
    }

    #[test]
    fn test_fetch_request_builder() {
        let req = FetchRequest::new(vec!["id1".to_string(), "id2".to_string()]);

        assert_eq!(req.ids.len(), 2);
        assert!(req.include_values);
        assert!(req.include_metadata);
    }

    #[test]
    fn test_create_namespace_request_builder() {
        let req = CreateNamespaceRequest::new()
            .with_dimensions(384)
            .with_index_type("hnsw");

        assert_eq!(req.dimensions, Some(384));
        assert_eq!(req.index_type.as_deref(), Some("hnsw"));
    }

    #[test]
    fn test_batch_query_text_request() {
        let req =
            BatchQueryTextRequest::new(vec!["query one".to_string(), "query two".to_string()], 10);

        assert_eq!(req.queries.len(), 2);
        assert_eq!(req.top_k, 10);
        assert!(!req.include_vectors);
        assert!(req.model.is_none());
    }

    // ---------------------------------------------------------------------
    // Retry configuration and retry behaviour
    // ---------------------------------------------------------------------

    #[test]
    fn test_retry_config_defaults() {
        let rc = RetryConfig::default();
        assert_eq!(rc.max_retries, 3);
        assert_eq!(rc.base_delay, Duration::from_millis(100));
        assert_eq!(rc.max_delay, Duration::from_secs(60));
        assert!(rc.jitter);
    }

    #[test]
    fn test_builder_connect_timeout() {
        // The timeouts live inside the HTTP client and cannot be read back, so
        // this only verifies that the combination builds successfully.
        let client = DakeraClient::builder("http://localhost:3000")
            .connect_timeout(Duration::from_secs(5))
            .timeout_secs(30)
            .build()
            .unwrap();
        assert!(client.base_url.starts_with("http"));
    }

    #[test]
    fn test_builder_max_retries() {
        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(5)
            .build()
            .unwrap();
        assert_eq!(client.retry_config.max_retries, 5);
    }

    #[test]
    fn test_builder_retry_config() {
        let rc = RetryConfig {
            max_retries: 7,
            base_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(30),
            jitter: false,
        };
        let client = DakeraClient::builder("http://localhost:3000")
            .retry_config(rc)
            .build()
            .unwrap();
        assert_eq!(client.retry_config.max_retries, 7);
        assert!(!client.retry_config.jitter);
    }

    #[test]
    fn test_rate_limit_error_retryable() {
        let e = ClientError::RateLimitExceeded { retry_after: None };
        assert!(e.is_retryable());
    }

    #[test]
    fn test_server_408_retryable() {
        let e = ClientError::Server {
            status: 408,
            message: String::new(),
            code: None,
        };
        assert!(e.is_retryable());
    }

    #[test]
    fn test_server_400_not_retryable() {
        let e = ClientError::Server {
            status: 400,
            message: String::new(),
            code: None,
        };
        assert!(!e.is_retryable());
    }

    #[test]
    fn test_rate_limit_error_with_retry_after_zero() {
        // `Some(0)` must be kept as a real value, distinct from `None`.
        let e = ClientError::RateLimitExceeded {
            retry_after: Some(0),
        };
        assert!(e.is_retryable());
        if let ClientError::RateLimitExceeded {
            retry_after: Some(secs),
        } = &e
        {
            assert_eq!(*secs, 0u64);
        } else {
            panic!("unexpected variant");
        }
    }

    #[tokio::test]
    async fn test_execute_with_retry_succeeds_immediately() {
        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(3)
            .build()
            .unwrap();

        let call_count = Arc::new(AtomicU32::new(0));
        let cc = Arc::clone(&call_count);
        let result = client
            .execute_with_retry(|| {
                let cc = Arc::clone(&cc);
                async move {
                    cc.fetch_add(1, Ordering::SeqCst);
                    Ok::<u32, ClientError>(42)
                }
            })
            .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_execute_with_retry_no_retry_on_4xx() {
        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(3)
            .build()
            .unwrap();

        let call_count = Arc::new(AtomicU32::new(0));
        let cc = Arc::clone(&call_count);
        let result = client
            .execute_with_retry(|| {
                let cc = Arc::clone(&cc);
                async move {
                    cc.fetch_add(1, Ordering::SeqCst);
                    Err::<u32, ClientError>(ClientError::Server {
                        status: 400,
                        message: "bad request".to_string(),
                        code: None,
                    })
                }
            })
            .await;
        assert!(result.is_err());
        // A 400 is not retryable, so the operation runs exactly once.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_execute_with_retry_retries_on_5xx() {
        // Zero delays keep the retries from slowing the test down.
        let client = DakeraClient::builder("http://localhost:3000")
            .retry_config(RetryConfig {
                max_retries: 3,
                base_delay: Duration::from_millis(0),
                max_delay: Duration::from_millis(0),
                jitter: false,
            })
            .build()
            .unwrap();

        let call_count = Arc::new(AtomicU32::new(0));
        let cc = Arc::clone(&call_count);
        let result = client
            .execute_with_retry(|| {
                let cc = Arc::clone(&cc);
                async move {
                    // Fail the first two attempts, succeed on the third.
                    let n = cc.fetch_add(1, Ordering::SeqCst);
                    if n < 2 {
                        Err::<u32, ClientError>(ClientError::Server {
                            status: 503,
                            message: "unavailable".to_string(),
                            code: None,
                        })
                    } else {
                        Ok(99)
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), 99);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // ---------------------------------------------------------------------
    // Batch recall / forget
    // ---------------------------------------------------------------------

    #[test]
    fn test_batch_recall_request_new() {
        use crate::memory::BatchRecallRequest;
        let req = BatchRecallRequest::new("agent-1");
        assert_eq!(req.agent_id, "agent-1");
        assert_eq!(req.limit, 100);
    }

    #[test]
    fn test_batch_recall_request_builder() {
        use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
        let filter = BatchMemoryFilter::default()
            .with_tags(vec!["qa".to_string()])
            .with_min_importance(0.7);
        let req = BatchRecallRequest::new("agent-1")
            .with_filter(filter)
            .with_limit(50);
        assert_eq!(req.agent_id, "agent-1");
        assert_eq!(req.limit, 50);
        assert_eq!(
            req.filter.tags.as_deref(),
            Some(["qa".to_string()].as_slice())
        );
        assert_eq!(req.filter.min_importance, Some(0.7));
    }

    #[test]
    fn test_batch_recall_request_serialization() {
        use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
        let filter = BatchMemoryFilter::default().with_min_importance(0.5);
        let req = BatchRecallRequest::new("agent-1")
            .with_filter(filter)
            .with_limit(25);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["agent_id"], "agent-1");
        assert_eq!(json["limit"], 25);
        assert_eq!(json["filter"]["min_importance"], 0.5);
    }

    #[test]
    fn test_batch_forget_request_new() {
        use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
        let filter = BatchMemoryFilter::default().with_min_importance(0.1);
        let req = BatchForgetRequest::new("agent-1", filter);
        assert_eq!(req.agent_id, "agent-1");
        assert_eq!(req.filter.min_importance, Some(0.1));
    }

    #[test]
    fn test_batch_forget_request_serialization() {
        use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
        let filter = BatchMemoryFilter {
            created_before: Some(1_700_000_000),
            ..Default::default()
        };
        let req = BatchForgetRequest::new("agent-1", filter);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["agent_id"], "agent-1");
        assert_eq!(json["filter"]["created_before"], 1_700_000_000u64);
    }

    #[test]
    fn test_batch_recall_response_deserialization() {
        use crate::memory::BatchRecallResponse;
        let json = serde_json::json!({
            "memories": [],
            "total": 42,
            "filtered": 7
        });
        let resp: BatchRecallResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.total, 42);
        assert_eq!(resp.filtered, 7);
        assert!(resp.memories.is_empty());
    }

    #[test]
    fn test_batch_forget_response_deserialization() {
        use crate::memory::BatchForgetResponse;
        let json = serde_json::json!({ "deleted_count": 13 });
        let resp: BatchForgetResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.deleted_count, 13);
    }

    // ---------------------------------------------------------------------
    // Rate-limit headers
    // ---------------------------------------------------------------------

    #[test]
    fn test_rate_limit_headers_default_all_none() {
        use crate::types::RateLimitHeaders;
        let rl = RateLimitHeaders {
            limit: None,
            remaining: None,
            reset: None,
            quota_used: None,
            quota_limit: None,
        };
        assert!(rl.limit.is_none());
        assert!(rl.remaining.is_none());
        assert!(rl.reset.is_none());
        assert!(rl.quota_used.is_none());
        assert!(rl.quota_limit.is_none());
    }

    #[test]
    fn test_rate_limit_headers_populated() {
        use crate::types::RateLimitHeaders;
        let rl = RateLimitHeaders {
            limit: Some(1000),
            remaining: Some(750),
            reset: Some(1_700_000_060),
            quota_used: Some(500),
            quota_limit: Some(10_000),
        };
        assert_eq!(rl.limit, Some(1000));
        assert_eq!(rl.remaining, Some(750));
        assert_eq!(rl.reset, Some(1_700_000_060));
        assert_eq!(rl.quota_used, Some(500));
        assert_eq!(rl.quota_limit, Some(10_000));
    }

    #[test]
    fn test_last_rate_limit_headers_initially_none() {
        // A freshly built client has not observed any response yet.
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        assert!(client.last_rate_limit_headers().is_none());
    }

    // ---------------------------------------------------------------------
    // Namespace NER configuration and entity extraction types
    // ---------------------------------------------------------------------

    #[test]
    fn test_namespace_ner_config_default() {
        use crate::types::NamespaceNerConfig;
        let cfg = NamespaceNerConfig::default();
        assert!(!cfg.extract_entities);
        assert!(cfg.entity_types.is_none());
    }

    #[test]
    fn test_namespace_ner_config_serialization_skip_none() {
        use crate::types::NamespaceNerConfig;
        let cfg = NamespaceNerConfig {
            extract_entities: true,
            entity_types: None,
        };
        let json = serde_json::to_value(&cfg).unwrap();
        assert_eq!(json["extract_entities"], true);
        // `None` must be omitted from the output rather than emitted as null.
        assert!(json.get("entity_types").is_none());
    }

    #[test]
    fn test_namespace_ner_config_serialization_with_types() {
        use crate::types::NamespaceNerConfig;
        let cfg = NamespaceNerConfig {
            extract_entities: true,
            entity_types: Some(vec!["PERSON".to_string(), "ORG".to_string()]),
        };
        let json = serde_json::to_value(&cfg).unwrap();
        assert_eq!(json["extract_entities"], true);
        assert_eq!(json["entity_types"][0], "PERSON");
        assert_eq!(json["entity_types"][1], "ORG");
    }

    #[test]
    fn test_extracted_entity_deserialization() {
        use crate::types::ExtractedEntity;
        let json = serde_json::json!({
            "entity_type": "PERSON",
            "value": "Alice",
            "score": 0.95
        });
        let entity: ExtractedEntity = serde_json::from_value(json).unwrap();
        assert_eq!(entity.entity_type, "PERSON");
        assert_eq!(entity.value, "Alice");
        assert!((entity.score - 0.95).abs() < f64::EPSILON);
    }

    #[test]
    fn test_entity_extraction_response_deserialization() {
        use crate::types::EntityExtractionResponse;
        let json = serde_json::json!({
            "entities": [
                { "entity_type": "PERSON", "value": "Bob", "score": 0.9 },
                { "entity_type": "ORG", "value": "Acme", "score": 0.87 }
            ]
        });
        let resp: EntityExtractionResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.entities.len(), 2);
        assert_eq!(resp.entities[0].entity_type, "PERSON");
        assert_eq!(resp.entities[1].value, "Acme");
    }

    #[test]
    fn test_memory_entities_response_deserialization() {
        use crate::types::MemoryEntitiesResponse;
        let json = serde_json::json!({
            "memory_id": "mem-abc-123",
            "entities": [
                { "entity_type": "LOC", "value": "London", "score": 0.88 }
            ]
        });
        let resp: MemoryEntitiesResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.memory_id, "mem-abc-123");
        assert_eq!(resp.entities.len(), 1);
        assert_eq!(resp.entities[0].entity_type, "LOC");
        assert_eq!(resp.entities[0].value, "London");
    }

    // NOTE(review): the `*_url_pattern` tests rebuild the URL with `format!`
    // instead of calling the client method, so they pin the expected shape of
    // the path but would not notice the method itself drifting.

    #[test]
    fn test_configure_namespace_ner_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let expected = "http://localhost:3000/v1/namespaces/my-ns/config";
        let actual = format!("{}/v1/namespaces/{}/config", client.base_url, "my-ns");
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_extract_entities_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let expected = "http://localhost:3000/v1/memories/extract";
        let actual = format!("{}/v1/memories/extract", client.base_url);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_memory_entities_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let memory_id = "mem-xyz-789";
        let expected = "http://localhost:3000/v1/memory/entities/mem-xyz-789";
        let actual = format!("{}/v1/memory/entities/{}", client.base_url, memory_id);
        assert_eq!(actual, expected);
    }

    // ---------------------------------------------------------------------
    // Feedback
    // ---------------------------------------------------------------------

    #[test]
    fn test_feedback_signal_serialization() {
        use crate::types::FeedbackSignal;
        let upvote = serde_json::to_value(FeedbackSignal::Upvote).unwrap();
        assert_eq!(upvote, serde_json::json!("upvote"));
        let downvote = serde_json::to_value(FeedbackSignal::Downvote).unwrap();
        assert_eq!(downvote, serde_json::json!("downvote"));
        let flag = serde_json::to_value(FeedbackSignal::Flag).unwrap();
        assert_eq!(flag, serde_json::json!("flag"));
    }

    #[test]
    fn test_feedback_signal_deserialization() {
        use crate::types::FeedbackSignal;
        let signal: FeedbackSignal = serde_json::from_str("\"upvote\"").unwrap();
        assert_eq!(signal, FeedbackSignal::Upvote);
        let signal: FeedbackSignal = serde_json::from_str("\"positive\"").unwrap();
        assert_eq!(signal, FeedbackSignal::Positive);
    }

    #[test]
    fn test_feedback_response_deserialization() {
        use crate::types::{FeedbackResponse, FeedbackSignal};
        let json = serde_json::json!({
            "memory_id": "mem-abc",
            "new_importance": 0.92,
            "signal": "upvote"
        });
        let resp: FeedbackResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.memory_id, "mem-abc");
        assert!((resp.new_importance - 0.92).abs() < f32::EPSILON);
        assert_eq!(resp.signal, FeedbackSignal::Upvote);
    }

    #[test]
    fn test_feedback_history_response_deserialization() {
        use crate::types::{FeedbackHistoryResponse, FeedbackSignal};
        let json = serde_json::json!({
            "memory_id": "mem-abc",
            "entries": [
                {"signal": "upvote", "timestamp": 1774000000_u64, "old_importance": 0.5, "new_importance": 0.575},
                {"signal": "downvote", "timestamp": 1774001000_u64, "old_importance": 0.575, "new_importance": 0.489}
            ]
        });
        let resp: FeedbackHistoryResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.memory_id, "mem-abc");
        assert_eq!(resp.entries.len(), 2);
        assert_eq!(resp.entries[0].signal, FeedbackSignal::Upvote);
        assert_eq!(resp.entries[1].signal, FeedbackSignal::Downvote);
    }

    #[test]
    fn test_agent_feedback_summary_deserialization() {
        use crate::types::AgentFeedbackSummary;
        let json = serde_json::json!({
            "agent_id": "agent-1",
            "upvotes": 42_u64,
            "downvotes": 7_u64,
            "flags": 2_u64,
            "total_feedback": 51_u64,
            "health_score": 0.78
        });
        let summary: AgentFeedbackSummary = serde_json::from_value(json).unwrap();
        assert_eq!(summary.agent_id, "agent-1");
        assert_eq!(summary.upvotes, 42);
        assert_eq!(summary.total_feedback, 51);
        assert!((summary.health_score - 0.78).abs() < f32::EPSILON);
    }

    #[test]
    fn test_feedback_health_response_deserialization() {
        use crate::types::FeedbackHealthResponse;
        let json = serde_json::json!({
            "agent_id": "agent-1",
            "health_score": 0.78,
            "memory_count": 120_usize,
            "avg_importance": 0.72
        });
        let health: FeedbackHealthResponse = serde_json::from_value(json).unwrap();
        assert_eq!(health.agent_id, "agent-1");
        assert!((health.health_score - 0.78).abs() < f32::EPSILON);
        assert_eq!(health.memory_count, 120);
    }

    #[test]
    fn test_memory_feedback_body_serialization() {
        use crate::types::{FeedbackSignal, MemoryFeedbackBody};
        let body = MemoryFeedbackBody {
            agent_id: "agent-1".to_string(),
            signal: FeedbackSignal::Flag,
        };
        let json = serde_json::to_value(body).unwrap();
        assert_eq!(json["agent_id"], "agent-1");
        assert_eq!(json["signal"], "flag");
    }

    #[test]
    fn test_feedback_memory_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let memory_id = "mem-abc";
        let expected_post = "http://localhost:3000/v1/memories/mem-abc/feedback";
        let actual_post = format!("{}/v1/memories/{}/feedback", client.base_url, memory_id);
        assert_eq!(actual_post, expected_post);

        let expected_patch = "http://localhost:3000/v1/memories/mem-abc/importance";
        let actual_patch = format!("{}/v1/memories/{}/importance", client.base_url, memory_id);
        assert_eq!(actual_patch, expected_patch);
    }

    #[test]
    fn test_feedback_health_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let agent_id = "agent-1";
        let expected = "http://localhost:3000/v1/feedback/health?agent_id=agent-1";
        let actual = format!(
            "{}/v1/feedback/health?agent_id={}",
            client.base_url, agent_id
        );
        assert_eq!(actual, expected);
    }

    // ---------------------------------------------------------------------
    // ODE entity extraction
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_ode_extract_entities_requires_ode_url() {
        // The client is built without `ode_url`, so the call is expected to
        // fail with a configuration error.
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let result = client
            .ode_extract_entities(ExtractEntitiesRequest {
                content: "Alice lives in Paris.".to_string(),
                agent_id: "agent-1".to_string(),
                memory_id: None,
                entity_types: None,
            })
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, ClientError::Config(_)));
    }

    #[test]
    fn test_ode_extract_entities_url_built_from_ode_url() {
        let client = DakeraClient::builder("http://localhost:3000")
            .ode_url("http://localhost:8080")
            .build()
            .unwrap();
        assert_eq!(client.ode_url.as_deref(), Some("http://localhost:8080"));
        let expected = "http://localhost:8080/ode/extract";
        let actual = format!("{}/ode/extract", client.ode_url.as_deref().unwrap());
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_extract_entities_request_serialization() {
        let req = ExtractEntitiesRequest {
            content: "Alice in Wonderland".to_string(),
            agent_id: "agent-42".to_string(),
            memory_id: Some("mem-001".to_string()),
            entity_types: Some(vec!["person".to_string(), "location".to_string()]),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"content\":\"Alice in Wonderland\""));
        assert!(json.contains("\"agent_id\":\"agent-42\""));
        assert!(json.contains("\"memory_id\":\"mem-001\""));
        assert!(json.contains("\"person\""));
    }

    #[test]
    fn test_extract_entities_request_omits_none_fields() {
        let req = ExtractEntitiesRequest {
            content: "hello".to_string(),
            agent_id: "a".to_string(),
            memory_id: None,
            entity_types: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("memory_id"));
        assert!(!json.contains("entity_types"));
    }

    #[test]
    fn test_ode_entity_deserialization() {
        let json = r#"{"text":"Alice","label":"person","start":0,"end":5,"score":0.97}"#;
        let entity: OdeEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.text, "Alice");
        assert_eq!(entity.label, "person");
        assert_eq!(entity.start, 0);
        assert_eq!(entity.end, 5);
        assert!((entity.score - 0.97).abs() < 1e-4);
    }

    #[test]
    fn test_extract_entities_response_deserialization() {
        let json = r#"{
            "entities": [
                {"text":"Alice","label":"person","start":0,"end":5,"score":0.97},
                {"text":"Paris","label":"location","start":16,"end":21,"score":0.92}
            ],
            "model": "gliner-multi-v2.1",
            "processing_time_ms": 34
        }"#;
        let resp: ExtractEntitiesResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.entities.len(), 2);
        assert_eq!(resp.entities[0].text, "Alice");
        assert_eq!(resp.model, "gliner-multi-v2.1");
        assert_eq!(resp.processing_time_ms, 34);
    }
}