1use reqwest::{
4 header::{HeaderMap, HeaderValue, AUTHORIZATION},
5 Client, StatusCode,
6};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use tracing::{debug, instrument};
10
11use serde::Deserialize;
12
13use crate::error::{ClientError, Result, ServerErrorCode};
14use crate::types::*;
15
/// Request timeout, in seconds, that `DakeraClientBuilder::new` installs until
/// the caller overrides it with `timeout` / `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
18
/// Asynchronous HTTP client for the Dakera API.
///
/// Instances are created with [`DakeraClient::new`] or through
/// [`DakeraClientBuilder`]. Cloning is cheap with respect to the rate-limit
/// state: clones share the same `last_rate_limit` cell through its `Arc`.
#[derive(Debug, Clone)]
pub struct DakeraClient {
    /// Underlying `reqwest` client, configured by the builder with timeouts,
    /// user agent and default headers.
    pub(crate) client: Client,
    /// Server base URL; the builder strips trailing `/` characters from it.
    pub(crate) base_url: String,
    /// Base URL used for `/ode/extract` requests. `None` unless set on the
    /// builder, in which case `ode_extract_entities` returns a `Config` error.
    pub(crate) ode_url: Option<String>,
    // Only read by `execute_with_retry`, which is itself marked `dead_code`.
    #[allow(dead_code)]
    /// Retry settings consumed by `execute_with_retry`.
    pub(crate) retry_config: RetryConfig,
    /// Rate-limit headers captured from the most recent response that went
    /// through `handle_response` or `handle_text_response`.
    pub(crate) last_rate_limit: Arc<Mutex<Option<RateLimitHeaders>>>,
}
34
35impl DakeraClient {
36 pub fn new(base_url: impl Into<String>) -> Result<Self> {
46 DakeraClientBuilder::new(base_url).build()
47 }
48
49 pub fn builder(base_url: impl Into<String>) -> DakeraClientBuilder {
51 DakeraClientBuilder::new(base_url)
52 }
53
54 #[instrument(skip(self))]
60 pub async fn health(&self) -> Result<HealthResponse> {
61 let url = format!("{}/health", self.base_url);
62 let response = self.client.get(&url).send().await?;
63
64 if response.status().is_success() {
65 let json: serde_json::Value = response.json().await?;
66 let healthy = json
69 .get("healthy")
70 .and_then(|v| v.as_bool())
71 .unwrap_or_else(|| json.get("status").and_then(|v| v.as_str()) == Some("healthy"));
72 let version = json
73 .get("version")
74 .and_then(|v| v.as_str())
75 .map(String::from);
76 let uptime_seconds = json.get("uptime_seconds").and_then(|v| v.as_u64());
77 Ok(HealthResponse {
78 healthy,
79 version,
80 uptime_seconds,
81 })
82 } else {
83 Ok(HealthResponse {
85 healthy: true,
86 version: None,
87 uptime_seconds: None,
88 })
89 }
90 }
91
92 #[instrument(skip(self))]
94 pub async fn ready(&self) -> Result<ReadinessResponse> {
95 let url = format!("{}/health/ready", self.base_url);
96 let response = self.client.get(&url).send().await?;
97
98 if response.status().is_success() {
99 Ok(response.json().await?)
100 } else {
101 Ok(ReadinessResponse {
102 ready: false,
103 components: None,
104 })
105 }
106 }
107
108 #[instrument(skip(self))]
110 pub async fn live(&self) -> Result<bool> {
111 let url = format!("{}/health/live", self.base_url);
112 let response = self.client.get(&url).send().await?;
113 Ok(response.status().is_success())
114 }
115
116 #[instrument(skip(self))]
122 pub async fn list_namespaces(&self) -> Result<Vec<String>> {
123 let url = format!("{}/v1/namespaces", self.base_url);
124 let response = self.client.get(&url).send().await?;
125 self.handle_response::<ListNamespacesResponse>(response)
126 .await
127 .map(|r| r.namespaces)
128 }
129
130 #[instrument(skip(self))]
132 pub async fn get_namespace(&self, namespace: &str) -> Result<NamespaceInfo> {
133 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
134 let response = self.client.get(&url).send().await?;
135 self.handle_response(response).await
136 }
137
138 #[instrument(skip(self, request))]
140 pub async fn create_namespace(
141 &self,
142 namespace: &str,
143 request: CreateNamespaceRequest,
144 ) -> Result<NamespaceInfo> {
145 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
146 let response = self.client.put(&url).json(&request).send().await?;
147 self.handle_response(response).await
148 }
149
150 #[instrument(skip(self, request), fields(namespace = %namespace))]
156 pub async fn configure_namespace(
157 &self,
158 namespace: &str,
159 request: ConfigureNamespaceRequest,
160 ) -> Result<ConfigureNamespaceResponse> {
161 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
162 let response = self.client.put(&url).json(&request).send().await?;
163 self.handle_response(response).await
164 }
165
166 #[instrument(skip(self))]
168 pub async fn delete_namespace(&self, namespace: &str) -> Result<()> {
169 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
170 let response = self.client.delete(&url).send().await?;
171 if response.status().is_success() {
172 Ok(())
173 } else {
174 let status = response.status().as_u16();
175 let text = response.text().await.unwrap_or_default();
176 Err(ClientError::Server {
177 status,
178 message: text,
179 code: None,
180 })
181 }
182 }
183
184 #[instrument(skip(self))]
186 pub async fn flush(&self, namespace: &str) -> Result<serde_json::Value> {
187 let url = format!("{}/v1/namespaces/{}/flush", self.base_url, namespace);
188 let response = self.client.post(&url).send().await?;
189 self.handle_response(response).await
190 }
191
192 #[instrument(skip(self))]
194 pub async fn get_namespace_stats(&self, namespace: &str) -> Result<serde_json::Value> {
195 let url = format!("{}/v1/namespaces/{}/stats", self.base_url, namespace);
196 let response = self.client.get(&url).send().await?;
197 self.handle_response(response).await
198 }
199
200 #[instrument(skip(self))]
202 pub async fn get_index_stats(&self, namespace: &str) -> Result<serde_json::Value> {
203 self.get_namespace_stats(namespace).await
204 }
205
206 #[instrument(skip(self, request), fields(vector_count = request.vectors.len()))]
212 pub async fn upsert(&self, namespace: &str, request: UpsertRequest) -> Result<UpsertResponse> {
213 let url = format!("{}/v1/namespaces/{}/vectors", self.base_url, namespace);
214 debug!(
215 "Upserting {} vectors to {}",
216 request.vectors.len(),
217 namespace
218 );
219
220 let response = self.client.post(&url).json(&request).send().await?;
221 self.handle_response(response).await
222 }
223
224 #[instrument(skip(self, vector))]
226 pub async fn upsert_one(&self, namespace: &str, vector: Vector) -> Result<UpsertResponse> {
227 self.upsert(namespace, UpsertRequest::single(vector)).await
228 }
229
230 #[instrument(skip(self, request), fields(namespace = %namespace, count = request.ids.len()))]
263 pub async fn upsert_columns(
264 &self,
265 namespace: &str,
266 request: ColumnUpsertRequest,
267 ) -> Result<UpsertResponse> {
268 let url = format!(
269 "{}/v1/namespaces/{}/upsert-columns",
270 self.base_url, namespace
271 );
272 debug!(
273 "Upserting {} vectors in column format to {}",
274 request.ids.len(),
275 namespace
276 );
277
278 let response = self.client.post(&url).json(&request).send().await?;
279 self.handle_response(response).await
280 }
281
282 #[instrument(skip(self, request), fields(top_k = request.top_k))]
284 pub async fn query(&self, namespace: &str, request: QueryRequest) -> Result<QueryResponse> {
285 let url = format!("{}/v1/namespaces/{}/query", self.base_url, namespace);
286 debug!(
287 "Querying namespace {} for top {} results",
288 namespace, request.top_k
289 );
290
291 let response = self.client.post(&url).json(&request).send().await?;
292 self.handle_response(response).await
293 }
294
295 #[instrument(skip(self, vector))]
297 pub async fn query_simple(
298 &self,
299 namespace: &str,
300 vector: Vec<f32>,
301 top_k: u32,
302 ) -> Result<QueryResponse> {
303 self.query(namespace, QueryRequest::new(vector, top_k))
304 .await
305 }
306
307 #[instrument(skip(self, request), fields(namespace = %namespace, query_count = request.queries.len()))]
331 pub async fn batch_query(
332 &self,
333 namespace: &str,
334 request: BatchQueryRequest,
335 ) -> Result<BatchQueryResponse> {
336 let url = format!("{}/v1/namespaces/{}/batch-query", self.base_url, namespace);
337 debug!(
338 "Batch querying namespace {} with {} queries",
339 namespace,
340 request.queries.len()
341 );
342
343 let response = self.client.post(&url).json(&request).send().await?;
344 self.handle_response(response).await
345 }
346
347 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
349 pub async fn delete(&self, namespace: &str, request: DeleteRequest) -> Result<DeleteResponse> {
350 let url = format!(
351 "{}/v1/namespaces/{}/vectors/delete",
352 self.base_url, namespace
353 );
354 debug!("Deleting {} vectors from {}", request.ids.len(), namespace);
355
356 let response = self.client.post(&url).json(&request).send().await?;
357 self.handle_response(response).await
358 }
359
360 #[instrument(skip(self))]
362 pub async fn delete_one(&self, namespace: &str, id: &str) -> Result<DeleteResponse> {
363 self.delete(namespace, DeleteRequest::single(id)).await
364 }
365
366 #[instrument(skip(self, request))]
368 pub async fn bulk_update_vectors(
369 &self,
370 namespace: &str,
371 request: BulkUpdateRequest,
372 ) -> Result<BulkUpdateResponse> {
373 let url = format!(
374 "{}/v1/namespaces/{}/vectors/bulk-update",
375 self.base_url, namespace
376 );
377 let response = self.client.post(&url).json(&request).send().await?;
378 self.handle_response(response).await
379 }
380
381 #[instrument(skip(self, request))]
383 pub async fn bulk_delete_vectors(
384 &self,
385 namespace: &str,
386 request: BulkDeleteRequest,
387 ) -> Result<BulkDeleteResponse> {
388 let url = format!(
389 "{}/v1/namespaces/{}/vectors/bulk-delete",
390 self.base_url, namespace
391 );
392 let response = self.client.post(&url).json(&request).send().await?;
393 self.handle_response(response).await
394 }
395
396 #[instrument(skip(self, request))]
398 pub async fn count_vectors(
399 &self,
400 namespace: &str,
401 request: CountVectorsRequest,
402 ) -> Result<CountVectorsResponse> {
403 let url = format!(
404 "{}/v1/namespaces/{}/vectors/count",
405 self.base_url, namespace
406 );
407 let response = self.client.post(&url).json(&request).send().await?;
408 self.handle_response(response).await
409 }
410
411 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
417 pub async fn index_documents(
418 &self,
419 namespace: &str,
420 request: IndexDocumentsRequest,
421 ) -> Result<IndexDocumentsResponse> {
422 let url = format!(
423 "{}/v1/namespaces/{}/fulltext/index",
424 self.base_url, namespace
425 );
426 debug!(
427 "Indexing {} documents in {}",
428 request.documents.len(),
429 namespace
430 );
431
432 let response = self.client.post(&url).json(&request).send().await?;
433 self.handle_response(response).await
434 }
435
436 #[instrument(skip(self, document))]
438 pub async fn index_document(
439 &self,
440 namespace: &str,
441 document: Document,
442 ) -> Result<IndexDocumentsResponse> {
443 self.index_documents(
444 namespace,
445 IndexDocumentsRequest {
446 documents: vec![document],
447 },
448 )
449 .await
450 }
451
452 #[instrument(skip(self, request))]
454 pub async fn fulltext_search(
455 &self,
456 namespace: &str,
457 request: FullTextSearchRequest,
458 ) -> Result<FullTextSearchResponse> {
459 let url = format!(
460 "{}/v1/namespaces/{}/fulltext/search",
461 self.base_url, namespace
462 );
463 debug!("Full-text search in {} for: {}", namespace, request.query);
464
465 let response = self.client.post(&url).json(&request).send().await?;
466 self.handle_response(response).await
467 }
468
469 #[instrument(skip(self))]
471 pub async fn search_text(
472 &self,
473 namespace: &str,
474 query: &str,
475 top_k: u32,
476 ) -> Result<FullTextSearchResponse> {
477 self.fulltext_search(namespace, FullTextSearchRequest::new(query, top_k))
478 .await
479 }
480
481 #[instrument(skip(self))]
483 pub async fn fulltext_stats(&self, namespace: &str) -> Result<FullTextStats> {
484 let url = format!(
485 "{}/v1/namespaces/{}/fulltext/stats",
486 self.base_url, namespace
487 );
488 let response = self.client.get(&url).send().await?;
489 self.handle_response(response).await
490 }
491
492 #[instrument(skip(self, request))]
494 pub async fn fulltext_delete(
495 &self,
496 namespace: &str,
497 request: DeleteRequest,
498 ) -> Result<DeleteResponse> {
499 let url = format!(
500 "{}/v1/namespaces/{}/fulltext/delete",
501 self.base_url, namespace
502 );
503 let response = self.client.post(&url).json(&request).send().await?;
504 self.handle_response(response).await
505 }
506
507 #[instrument(skip(self, request), fields(top_k = request.top_k))]
513 pub async fn hybrid_search(
514 &self,
515 namespace: &str,
516 request: HybridSearchRequest,
517 ) -> Result<HybridSearchResponse> {
518 let url = format!("{}/v1/namespaces/{}/hybrid", self.base_url, namespace);
519 debug!(
520 "Hybrid search in {} with vector_weight={}",
521 namespace, request.vector_weight
522 );
523
524 let response = self.client.post(&url).json(&request).send().await?;
525 self.handle_response(response).await
526 }
527
528 #[instrument(skip(self, request), fields(namespace = %namespace))]
565 pub async fn multi_vector_search(
566 &self,
567 namespace: &str,
568 request: MultiVectorSearchRequest,
569 ) -> Result<MultiVectorSearchResponse> {
570 let url = format!("{}/v1/namespaces/{}/multi-vector", self.base_url, namespace);
571 debug!(
572 "Multi-vector search in {} with {} positive vectors",
573 namespace,
574 request.positive_vectors.len()
575 );
576
577 let response = self.client.post(&url).json(&request).send().await?;
578 self.handle_response(response).await
579 }
580
581 #[instrument(skip(self, request), fields(namespace = %namespace))]
615 pub async fn aggregate(
616 &self,
617 namespace: &str,
618 request: AggregationRequest,
619 ) -> Result<AggregationResponse> {
620 let url = format!("{}/v1/namespaces/{}/aggregate", self.base_url, namespace);
621 debug!(
622 "Aggregating in namespace {} with {} aggregations",
623 namespace,
624 request.aggregate_by.len()
625 );
626
627 let response = self.client.post(&url).json(&request).send().await?;
628 self.handle_response(response).await
629 }
630
631 #[instrument(skip(self, request), fields(namespace = %namespace))]
669 pub async fn unified_query(
670 &self,
671 namespace: &str,
672 request: UnifiedQueryRequest,
673 ) -> Result<UnifiedQueryResponse> {
674 let url = format!(
675 "{}/v1/namespaces/{}/unified-query",
676 self.base_url, namespace
677 );
678 debug!(
679 "Unified query in namespace {} with top_k={}",
680 namespace, request.top_k
681 );
682
683 let response = self.client.post(&url).json(&request).send().await?;
684 self.handle_response(response).await
685 }
686
687 #[instrument(skip(self, vector))]
691 pub async fn unified_vector_search(
692 &self,
693 namespace: &str,
694 vector: Vec<f32>,
695 top_k: usize,
696 ) -> Result<UnifiedQueryResponse> {
697 self.unified_query(namespace, UnifiedQueryRequest::vector_search(vector, top_k))
698 .await
699 }
700
701 #[instrument(skip(self))]
705 pub async fn unified_text_search(
706 &self,
707 namespace: &str,
708 field: &str,
709 query: &str,
710 top_k: usize,
711 ) -> Result<UnifiedQueryResponse> {
712 self.unified_query(
713 namespace,
714 UnifiedQueryRequest::fulltext_search(field, query, top_k),
715 )
716 .await
717 }
718
719 #[instrument(skip(self, request), fields(namespace = %namespace))]
756 pub async fn explain_query(
757 &self,
758 namespace: &str,
759 request: QueryExplainRequest,
760 ) -> Result<QueryExplainResponse> {
761 let url = format!("{}/v1/namespaces/{}/explain", self.base_url, namespace);
762 debug!(
763 "Explaining query in namespace {} (query_type={:?}, top_k={})",
764 namespace, request.query_type, request.top_k
765 );
766
767 let response = self.client.post(&url).json(&request).send().await?;
768 self.handle_response(response).await
769 }
770
771 #[instrument(skip(self, request), fields(namespace = %request.namespace, priority = ?request.priority))]
799 pub async fn warm_cache(&self, request: WarmCacheRequest) -> Result<WarmCacheResponse> {
800 let url = format!(
801 "{}/v1/namespaces/{}/cache/warm",
802 self.base_url, request.namespace
803 );
804 debug!(
805 "Warming cache for namespace {} with priority {:?}",
806 request.namespace, request.priority
807 );
808
809 let response = self.client.post(&url).json(&request).send().await?;
810 self.handle_response(response).await
811 }
812
813 #[instrument(skip(self, vector_ids))]
815 pub async fn warm_vectors(
816 &self,
817 namespace: &str,
818 vector_ids: Vec<String>,
819 ) -> Result<WarmCacheResponse> {
820 self.warm_cache(WarmCacheRequest::new(namespace).with_vector_ids(vector_ids))
821 .await
822 }
823
824 #[instrument(skip(self, request), fields(namespace = %namespace))]
857 pub async fn export(&self, namespace: &str, request: ExportRequest) -> Result<ExportResponse> {
858 let url = format!("{}/v1/namespaces/{}/export", self.base_url, namespace);
859 debug!(
860 "Exporting vectors from namespace {} (top_k={}, cursor={:?})",
861 namespace, request.top_k, request.cursor
862 );
863
864 let response = self.client.post(&url).json(&request).send().await?;
865 self.handle_response(response).await
866 }
867
868 #[instrument(skip(self))]
872 pub async fn export_all(&self, namespace: &str) -> Result<ExportResponse> {
873 self.export(namespace, ExportRequest::new()).await
874 }
875
876 #[instrument(skip(self, request), fields(namespace = %namespace))]
878 pub async fn export_vectors(
879 &self,
880 namespace: &str,
881 request: ExportRequest,
882 ) -> Result<ExportResponse> {
883 self.export(namespace, request).await
884 }
885
886 #[instrument(skip(self))]
892 pub async fn diagnostics(&self) -> Result<SystemDiagnostics> {
893 let url = format!("{}/ops/diagnostics", self.base_url);
894 let response = self.client.get(&url).send().await?;
895 self.handle_response(response).await
896 }
897
898 #[instrument(skip(self))]
900 pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
901 let url = format!("{}/ops/jobs", self.base_url);
902 let response = self.client.get(&url).send().await?;
903 self.handle_response(response).await
904 }
905
906 #[instrument(skip(self))]
908 pub async fn get_job(&self, job_id: &str) -> Result<Option<JobInfo>> {
909 let url = format!("{}/ops/jobs/{}", self.base_url, job_id);
910 let response = self.client.get(&url).send().await?;
911
912 if response.status() == StatusCode::NOT_FOUND {
913 return Ok(None);
914 }
915
916 self.handle_response(response).await.map(Some)
917 }
918
919 #[instrument(skip(self, request))]
921 pub async fn compact(&self, request: CompactionRequest) -> Result<CompactionResponse> {
922 let url = format!("{}/ops/compact", self.base_url);
923 let response = self.client.post(&url).json(&request).send().await?;
924 self.handle_response(response).await
925 }
926
927 #[instrument(skip(self))]
929 pub async fn shutdown(&self) -> Result<()> {
930 let url = format!("{}/ops/shutdown", self.base_url);
931 let response = self.client.post(&url).send().await?;
932
933 if response.status().is_success() {
934 Ok(())
935 } else {
936 let status = response.status().as_u16();
937 let text = response.text().await.unwrap_or_default();
938 Err(ClientError::Server {
939 status,
940 message: text,
941 code: None,
942 })
943 }
944 }
945
946 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
952 pub async fn fetch(&self, namespace: &str, request: FetchRequest) -> Result<FetchResponse> {
953 let url = format!("{}/v1/namespaces/{}/fetch", self.base_url, namespace);
954 debug!("Fetching {} vectors from {}", request.ids.len(), namespace);
955 let response = self.client.post(&url).json(&request).send().await?;
956 self.handle_response(response).await
957 }
958
959 #[instrument(skip(self))]
961 pub async fn fetch_by_ids(&self, namespace: &str, ids: &[&str]) -> Result<Vec<Vector>> {
962 let request = FetchRequest::new(ids.iter().map(|s| s.to_string()).collect());
963 self.fetch(namespace, request).await.map(|r| r.vectors)
964 }
965
966 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
972 pub async fn upsert_text(
973 &self,
974 namespace: &str,
975 request: UpsertTextRequest,
976 ) -> Result<TextUpsertResponse> {
977 let url = format!("{}/v1/namespaces/{}/upsert-text", self.base_url, namespace);
978 debug!(
979 "Upserting {} text documents to {}",
980 request.documents.len(),
981 namespace
982 );
983 let response = self.client.post(&url).json(&request).send().await?;
984 self.handle_response(response).await
985 }
986
987 #[instrument(skip(self, request), fields(top_k = request.top_k))]
989 pub async fn query_text(
990 &self,
991 namespace: &str,
992 request: QueryTextRequest,
993 ) -> Result<TextQueryResponse> {
994 let url = format!("{}/v1/namespaces/{}/query-text", self.base_url, namespace);
995 debug!("Text query in {} for: {}", namespace, request.text);
996 let response = self.client.post(&url).json(&request).send().await?;
997 self.handle_response(response).await
998 }
999
1000 #[instrument(skip(self))]
1002 pub async fn query_text_simple(
1003 &self,
1004 namespace: &str,
1005 text: &str,
1006 top_k: u32,
1007 ) -> Result<TextQueryResponse> {
1008 self.query_text(namespace, QueryTextRequest::new(text, top_k))
1009 .await
1010 }
1011
1012 #[instrument(skip(self, request), fields(query_count = request.queries.len()))]
1014 pub async fn batch_query_text(
1015 &self,
1016 namespace: &str,
1017 request: BatchQueryTextRequest,
1018 ) -> Result<BatchQueryTextResponse> {
1019 let url = format!(
1020 "{}/v1/namespaces/{}/batch-query-text",
1021 self.base_url, namespace
1022 );
1023 debug!(
1024 "Batch text query in {} with {} queries",
1025 namespace,
1026 request.queries.len()
1027 );
1028 let response = self.client.post(&url).json(&request).send().await?;
1029 self.handle_response(response).await
1030 }
1031
1032 #[instrument(skip(self))]
1038 pub async fn get_namespace_entity_config(
1039 &self,
1040 namespace: &str,
1041 ) -> Result<NamespaceEntityConfig> {
1042 let url = format!("{}/v1/namespaces/{}/config", self.base_url, namespace);
1043 let response = self.client.get(&url).send().await?;
1044 self.handle_response(response).await
1045 }
1046
1047 #[instrument(skip(self))]
1049 pub async fn get_namespace_extractor(
1050 &self,
1051 namespace: &str,
1052 ) -> Result<NamespaceExtractorConfig> {
1053 let url = format!("{}/v1/namespaces/{}/extractor", self.base_url, namespace);
1054 let response = self.client.get(&url).send().await?;
1055 self.handle_response(response).await
1056 }
1057
1058 #[instrument(skip(self, config))]
1063 pub async fn configure_namespace_ner(
1064 &self,
1065 namespace: &str,
1066 config: NamespaceNerConfig,
1067 ) -> Result<serde_json::Value> {
1068 let url = format!("{}/v1/namespaces/{}/config", self.base_url, namespace);
1069 let response = self.client.patch(&url).json(&config).send().await?;
1070 self.handle_response(response).await
1071 }
1072
1073 #[instrument(skip(self, text, entity_types))]
1078 pub async fn extract_entities(
1079 &self,
1080 text: &str,
1081 entity_types: Option<Vec<String>>,
1082 ) -> Result<EntityExtractionResponse> {
1083 let url = format!("{}/v1/memories/extract", self.base_url);
1084 let body = serde_json::json!({
1085 "content": text,
1086 "entity_types": entity_types,
1087 });
1088 let response = self.client.post(&url).json(&body).send().await?;
1089 self.handle_response(response).await
1090 }
1091
1092 #[instrument(skip(self))]
1096 pub async fn memory_entities(&self, memory_id: &str) -> Result<MemoryEntitiesResponse> {
1097 let url = format!("{}/v1/memory/entities/{}", self.base_url, memory_id);
1098 let response = self.client.get(&url).send().await?;
1099 self.handle_response(response).await
1100 }
1101
1102 pub fn last_rate_limit_headers(&self) -> Option<RateLimitHeaders> {
1110 self.last_rate_limit.lock().ok()?.clone()
1111 }
1112
1113 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
1115 &self,
1116 response: reqwest::Response,
1117 ) -> Result<T> {
1118 let status = response.status();
1119
1120 if let Ok(mut guard) = self.last_rate_limit.lock() {
1122 *guard = Some(RateLimitHeaders::from_response(&response));
1123 }
1124
1125 if status.is_success() {
1126 Ok(response.json().await?)
1127 } else {
1128 let status_code = status.as_u16();
1129 let retry_after = response
1131 .headers()
1132 .get("Retry-After")
1133 .and_then(|v| v.to_str().ok())
1134 .and_then(|s| s.parse::<u64>().ok());
1135 let text = response.text().await.unwrap_or_default();
1136
1137 if status_code == 429 {
1138 return Err(ClientError::RateLimitExceeded { retry_after });
1139 }
1140
1141 #[derive(Deserialize)]
1142 struct ErrorBody {
1143 error: Option<String>,
1144 code: Option<ServerErrorCode>,
1145 }
1146
1147 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1148 (body.error.unwrap_or_else(|| text.clone()), body.code)
1149 } else {
1150 (text, None)
1151 };
1152
1153 match status_code {
1154 401 => Err(ClientError::Server {
1155 status: 401,
1156 message,
1157 code,
1158 }),
1159 403 => Err(ClientError::Authorization {
1160 status: 403,
1161 message,
1162 code,
1163 }),
1164 404 => match &code {
1165 Some(ServerErrorCode::NamespaceNotFound) => {
1166 Err(ClientError::NamespaceNotFound(message))
1167 }
1168 Some(ServerErrorCode::VectorNotFound) => {
1169 Err(ClientError::VectorNotFound(message))
1170 }
1171 _ => Err(ClientError::Server {
1172 status: 404,
1173 message,
1174 code,
1175 }),
1176 },
1177 _ => Err(ClientError::Server {
1178 status: status_code,
1179 message,
1180 code,
1181 }),
1182 }
1183 }
1184 }
1185
1186 pub(crate) async fn handle_text_response(&self, response: reqwest::Response) -> Result<String> {
1188 let status = response.status();
1189
1190 if let Ok(mut guard) = self.last_rate_limit.lock() {
1192 *guard = Some(RateLimitHeaders::from_response(&response));
1193 }
1194
1195 let retry_after = response
1196 .headers()
1197 .get("Retry-After")
1198 .and_then(|v| v.to_str().ok())
1199 .and_then(|s| s.parse::<u64>().ok());
1200 let text = response.text().await.unwrap_or_default();
1201
1202 if status.is_success() {
1203 return Ok(text);
1204 }
1205
1206 let status_code = status.as_u16();
1207
1208 if status_code == 429 {
1209 return Err(ClientError::RateLimitExceeded { retry_after });
1210 }
1211
1212 #[derive(Deserialize)]
1213 struct ErrorBody {
1214 error: Option<String>,
1215 code: Option<ServerErrorCode>,
1216 }
1217
1218 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1219 (body.error.unwrap_or_else(|| text.clone()), body.code)
1220 } else {
1221 (text, None)
1222 };
1223
1224 match status_code {
1225 401 => Err(ClientError::Server {
1226 status: 401,
1227 message,
1228 code,
1229 }),
1230 403 => Err(ClientError::Authorization {
1231 status: 403,
1232 message,
1233 code,
1234 }),
1235 _ => Err(ClientError::Server {
1236 status: status_code,
1237 message,
1238 code,
1239 }),
1240 }
1241 }
1242
1243 #[allow(dead_code)]
1251 pub(crate) async fn execute_with_retry<F, Fut, T>(&self, f: F) -> Result<T>
1252 where
1253 F: Fn() -> Fut,
1254 Fut: std::future::Future<Output = Result<T>>,
1255 {
1256 let rc = &self.retry_config;
1257
1258 for attempt in 0..rc.max_retries {
1259 match f().await {
1260 Ok(v) => return Ok(v),
1261 Err(e) => {
1262 let is_last = attempt == rc.max_retries - 1;
1263 if is_last || !e.is_retryable() {
1264 return Err(e);
1265 }
1266
1267 let wait = match &e {
1268 ClientError::RateLimitExceeded {
1269 retry_after: Some(secs),
1270 } => Duration::from_secs(*secs),
1271 _ => {
1272 let base_ms = rc.base_delay.as_millis() as f64;
1273 let backoff_ms = base_ms * 2f64.powi(attempt as i32);
1274 let capped_ms = backoff_ms.min(rc.max_delay.as_millis() as f64);
1275 let final_ms = if rc.jitter {
1276 let seed = (attempt as u64).wrapping_mul(6364136223846793005);
1278 let factor = 0.5 + (seed % 1000) as f64 / 1000.0;
1279 capped_ms * factor
1280 } else {
1281 capped_ms
1282 };
1283 Duration::from_millis(final_ms as u64)
1284 }
1285 };
1286
1287 tokio::time::sleep(wait).await;
1288 }
1289 }
1290 }
1291
1292 Err(ClientError::Config("retry loop exhausted".to_string()))
1294 }
1295}
1296
1297impl DakeraClient {
1302 pub async fn ode_extract_entities(
1314 &self,
1315 req: ExtractEntitiesRequest,
1316 ) -> Result<ExtractEntitiesResponse> {
1317 let ode_url = self.ode_url.as_deref().ok_or_else(|| {
1318 ClientError::Config(
1319 "ode_url must be configured to use extract_entities(). \
1320 Call .ode_url(\"http://localhost:8080\") on the builder."
1321 .to_string(),
1322 )
1323 })?;
1324 let url = format!("{}/ode/extract", ode_url);
1325 let response = self.client.post(&url).json(&req).send().await?;
1326 if response.status().is_success() {
1327 Ok(response.json::<ExtractEntitiesResponse>().await?)
1328 } else {
1329 let status = response.status().as_u16();
1330 let body = response.text().await.unwrap_or_default();
1331 Err(ClientError::Server {
1332 status,
1333 message: format!("ODE sidecar error: {}", body),
1334 code: None,
1335 })
1336 }
1337 }
1338
1339 #[instrument(skip(self))]
1351 pub async fn get_memory_policy(&self, namespace: &str) -> Result<MemoryPolicy> {
1352 let url = format!(
1353 "{}/v1/namespaces/{}/memory_policy",
1354 self.base_url,
1355 urlencoding::encode(namespace)
1356 );
1357 let response = self.client.get(&url).send().await?;
1358 self.handle_response(response).await
1359 }
1360
1361 #[instrument(skip(self, policy))]
1368 pub async fn set_memory_policy(
1369 &self,
1370 namespace: &str,
1371 policy: MemoryPolicy,
1372 ) -> Result<MemoryPolicy> {
1373 let url = format!(
1374 "{}/v1/namespaces/{}/memory_policy",
1375 self.base_url,
1376 urlencoding::encode(namespace)
1377 );
1378 let response = self.client.put(&url).json(&policy).send().await?;
1379 self.handle_response(response).await
1380 }
1381}
1382
1383#[derive(Debug)]
1385pub struct DakeraClientBuilder {
1386 base_url: String,
1387 api_key: Option<String>,
1388 ode_url: Option<String>,
1389 timeout: Duration,
1390 connect_timeout: Option<Duration>,
1391 retry_config: RetryConfig,
1392 user_agent: Option<String>,
1393}
1394
1395impl DakeraClientBuilder {
1396 pub fn new(base_url: impl Into<String>) -> Self {
1398 Self {
1399 base_url: base_url.into(),
1400 api_key: None,
1401 ode_url: None,
1402 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
1403 connect_timeout: None,
1404 retry_config: RetryConfig::default(),
1405 user_agent: None,
1406 }
1407 }
1408
1409 pub fn api_key(mut self, key: impl Into<String>) -> Self {
1414 self.api_key = Some(key.into());
1415 self
1416 }
1417
1418 pub fn ode_url(mut self, ode_url: impl Into<String>) -> Self {
1422 self.ode_url = Some(ode_url.into().trim_end_matches('/').to_string());
1423 self
1424 }
1425
1426 pub fn timeout(mut self, timeout: Duration) -> Self {
1428 self.timeout = timeout;
1429 self
1430 }
1431
1432 pub fn timeout_secs(mut self, secs: u64) -> Self {
1434 self.timeout = Duration::from_secs(secs);
1435 self
1436 }
1437
1438 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
1440 self.connect_timeout = Some(timeout);
1441 self
1442 }
1443
1444 pub fn retry_config(mut self, config: RetryConfig) -> Self {
1446 self.retry_config = config;
1447 self
1448 }
1449
1450 pub fn max_retries(mut self, max_retries: u32) -> Self {
1452 self.retry_config.max_retries = max_retries;
1453 self
1454 }
1455
1456 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
1458 self.user_agent = Some(user_agent.into());
1459 self
1460 }
1461
1462 pub fn build(self) -> Result<DakeraClient> {
1464 let base_url = self.base_url.trim_end_matches('/').to_string();
1466
1467 if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
1469 return Err(ClientError::InvalidUrl(
1470 "URL must start with http:// or https://".to_string(),
1471 ));
1472 }
1473
1474 let user_agent = self
1475 .user_agent
1476 .unwrap_or_else(|| format!("dakera-client/{}", env!("CARGO_PKG_VERSION")));
1477
1478 let connect_timeout = self.connect_timeout.unwrap_or(self.timeout);
1479
1480 let api_key = self
1482 .api_key
1483 .or_else(|| std::env::var("DAKERA_API_KEY").ok());
1484
1485 let mut default_headers = HeaderMap::new();
1486 if let Some(key) = &api_key {
1487 let bearer = format!("Bearer {key}");
1488 let mut value = HeaderValue::from_str(&bearer)
1489 .map_err(|_| ClientError::Config("invalid API key".into()))?;
1490 value.set_sensitive(true);
1491 default_headers.insert(AUTHORIZATION, value);
1492 }
1493
1494 let client = Client::builder()
1495 .timeout(self.timeout)
1496 .connect_timeout(connect_timeout)
1497 .user_agent(user_agent)
1498 .default_headers(default_headers)
1499 .build()
1500 .map_err(|e| ClientError::Config(e.to_string()))?;
1501
1502 Ok(DakeraClient {
1503 client,
1504 base_url,
1505 ode_url: self.ode_url,
1506 retry_config: self.retry_config,
1507 last_rate_limit: Arc::new(Mutex::new(None)),
1508 })
1509 }
1510}
1511
1512impl DakeraClient {
1517 pub async fn stream_namespace_events(
1542 &self,
1543 namespace: &str,
1544 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1545 let url = format!(
1546 "{}/v1/namespaces/{}/events",
1547 self.base_url,
1548 urlencoding::encode(namespace)
1549 );
1550 self.stream_sse(url).await
1551 }
1552
1553 pub async fn stream_global_events(
1560 &self,
1561 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1562 let url = format!("{}/ops/events", self.base_url);
1563 self.stream_sse(url).await
1564 }
1565
1566 pub async fn stream_memory_events(
1575 &self,
1576 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::MemoryEvent>>> {
1577 let url = format!("{}/v1/events/stream", self.base_url);
1578 self.stream_sse(url).await
1579 }
1580
1581 pub(crate) async fn stream_sse<T>(
1583 &self,
1584 url: String,
1585 ) -> Result<tokio::sync::mpsc::Receiver<Result<T>>>
1586 where
1587 T: serde::de::DeserializeOwned + Send + 'static,
1588 {
1589 use futures_util::StreamExt;
1590
1591 let response = self
1592 .client
1593 .get(&url)
1594 .header("Accept", "text/event-stream")
1595 .header("Cache-Control", "no-cache")
1596 .send()
1597 .await?;
1598
1599 if !response.status().is_success() {
1600 let status = response.status().as_u16();
1601 let body = response.text().await.unwrap_or_default();
1602 return Err(ClientError::Server {
1603 status,
1604 message: body,
1605 code: None,
1606 });
1607 }
1608
1609 let (tx, rx) = tokio::sync::mpsc::channel(64);
1610
1611 tokio::spawn(async move {
1612 let mut byte_stream = response.bytes_stream();
1613 let mut remaining = String::new();
1614 let mut data_lines: Vec<String> = Vec::new();
1615
1616 while let Some(chunk) = byte_stream.next().await {
1617 match chunk {
1618 Ok(bytes) => {
1619 remaining.push_str(&String::from_utf8_lossy(&bytes));
1620 while let Some(pos) = remaining.find('\n') {
1621 let raw = &remaining[..pos];
1622 let line = raw.trim_end_matches('\r').to_string();
1623 remaining = remaining[pos + 1..].to_string();
1624
1625 if line.starts_with(':') {
1626 } else if let Some(data) = line.strip_prefix("data:") {
1628 data_lines.push(data.trim_start().to_string());
1629 } else if line.is_empty() {
1630 if !data_lines.is_empty() {
1631 let payload = data_lines.join("\n");
1632 data_lines.clear();
1633 let result = serde_json::from_str::<T>(&payload)
1634 .map_err(ClientError::Json);
1635 if tx.send(result).await.is_err() {
1636 return; }
1638 }
1639 } else {
1640 }
1642 }
1643 }
1644 Err(e) => {
1645 let _ = tx.send(Err(ClientError::Http(e))).await;
1646 return;
1647 }
1648 }
1649 }
1650 });
1651
1652 Ok(rx)
1653 }
1654
1655 #[instrument(skip(self, request))]
1661 pub async fn route_query(&self, request: RouteRequest) -> Result<RouteResponse> {
1662 let url = format!("{}/v1/route", self.base_url);
1663 let response = self.client.post(&url).json(&request).send().await?;
1664 self.handle_response(response).await
1665 }
1666
1667 #[instrument(skip(self))]
1673 pub async fn import_job_status(&self, job_id: &str) -> Result<ImportJobStatus> {
1674 let url = format!("{}/v1/import/{}/status", self.base_url, job_id);
1675 let response = self.client.get(&url).send().await?;
1676 self.handle_response(response).await
1677 }
1678}
1679
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `http://` URL is accepted by the convenience constructor.
    #[test]
    fn test_client_builder() {
        assert!(DakeraClient::new("http://localhost:3000").is_ok());
    }

    /// Optional builder settings do not prevent a successful build.
    #[test]
    fn test_client_builder_with_options() {
        let builder = DakeraClient::builder("http://localhost:3000")
            .timeout_secs(60)
            .user_agent("test-client/1.0");
        assert!(builder.build().is_ok());
    }

    /// A URL without an http/https scheme is rejected.
    #[test]
    fn test_client_builder_invalid_url() {
        assert!(DakeraClient::new("invalid-url").is_err());
    }

    /// The stored base URL never ends with a slash.
    #[test]
    fn test_client_builder_trailing_slash() {
        let built = DakeraClient::new("http://localhost:3000/").unwrap();
        assert!(!built.base_url.ends_with('/'));
    }

    /// `Vector::new` stores id and values and leaves metadata unset.
    #[test]
    fn test_vector_creation() {
        let vector = Vector::new("test", vec![0.1, 0.2, 0.3]);
        assert_eq!(vector.id, "test");
        assert_eq!(vector.values.len(), 3);
        assert!(vector.metadata.is_none());
    }

    /// Filter and metadata flag set through the builder are kept.
    #[test]
    fn test_query_request_builder() {
        let filter = serde_json::json!({"category": "test"});
        let request = QueryRequest::new(vec![0.1, 0.2], 10)
            .with_filter(filter)
            .include_metadata(false);

        assert_eq!(request.top_k, 10);
        assert!(request.filter.is_some());
        assert!(!request.include_metadata);
    }

    /// A hybrid request keeps its weight, text and vector.
    #[test]
    fn test_hybrid_search_request() {
        let base = HybridSearchRequest::new(vec![0.1], "test query", 5);
        let request = base.with_vector_weight(0.7);

        assert_eq!(request.vector_weight, 0.7);
        assert_eq!(request.text, "test query");
        assert!(request.vector.is_some());
    }

    /// A weight above 1.0 ends up stored as 1.0.
    #[test]
    fn test_hybrid_search_weight_clamping() {
        let base = HybridSearchRequest::new(vec![0.1], "test", 5);
        let request = base.with_vector_weight(1.5);
        assert_eq!(request.vector_weight, 1.0);
    }

    /// A text-only request has no vector, in memory and when serialized.
    #[test]
    fn test_hybrid_search_text_only() {
        let request = HybridSearchRequest::text_only("bm25 query", 10);

        assert!(request.vector.is_none());
        assert_eq!(request.text, "bm25 query");
        assert_eq!(request.top_k, 10);
        let serialized = serde_json::to_value(&request).unwrap();
        assert!(serialized.get("vector").is_none());
    }

    /// `with_ttl` sets the TTL; metadata stays unset.
    #[test]
    fn test_text_document_builder() {
        let document = TextDocument::new("doc1", "Hello world").with_ttl(3600);

        assert_eq!(document.id, "doc1");
        assert_eq!(document.text, "Hello world");
        assert_eq!(document.ttl_seconds, Some(3600));
        assert!(document.metadata.is_none());
    }

    /// Documents and the chosen model are kept on the upsert request.
    #[test]
    fn test_upsert_text_request_builder() {
        let first = TextDocument::new("doc1", "Hello");
        let second = TextDocument::new("doc2", "World");
        let request =
            UpsertTextRequest::new(vec![first, second]).with_model(EmbeddingModel::BgeSmall);

        assert_eq!(request.documents.len(), 2);
        assert_eq!(request.model, Some(EmbeddingModel::BgeSmall));
    }

    /// Every option set on a text query is reflected in the request.
    #[test]
    fn test_query_text_request_builder() {
        let filter = serde_json::json!({"category": "docs"});
        let request = QueryTextRequest::new("semantic search query", 5)
            .with_filter(filter)
            .include_vectors(true)
            .with_model(EmbeddingModel::E5Small);

        assert_eq!(request.text, "semantic search query");
        assert_eq!(request.top_k, 5);
        assert!(request.filter.is_some());
        assert!(request.include_vectors);
        assert_eq!(request.model, Some(EmbeddingModel::E5Small));
    }

    /// A fresh fetch request includes values and metadata.
    #[test]
    fn test_fetch_request_builder() {
        let ids = vec!["id1".to_string(), "id2".to_string()];
        let request = FetchRequest::new(ids);

        assert_eq!(request.ids.len(), 2);
        assert!(request.include_values);
        assert!(request.include_metadata);
    }

    /// Dimensions and index type are stored on the namespace request.
    #[test]
    fn test_create_namespace_request_builder() {
        let request = CreateNamespaceRequest::new()
            .with_dimensions(384)
            .with_index_type("hnsw");

        assert_eq!(request.dimensions, Some(384));
        assert_eq!(request.index_type.as_deref(), Some("hnsw"));
    }

    /// A fresh batch text query has no vectors requested and no model.
    #[test]
    fn test_batch_query_text_request() {
        let queries = vec!["query one".to_string(), "query two".to_string()];
        let request = BatchQueryTextRequest::new(queries, 10);

        assert_eq!(request.queries.len(), 2);
        assert_eq!(request.top_k, 10);
        assert!(!request.include_vectors);
        assert!(request.model.is_none());
    }

    /// Pins the default retry settings.
    #[test]
    fn test_retry_config_defaults() {
        let defaults = RetryConfig::default();
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.base_delay, Duration::from_millis(100));
        assert_eq!(defaults.max_delay, Duration::from_secs(60));
        assert!(defaults.jitter);
    }

    /// Setting both timeouts still yields a buildable client.
    #[test]
    fn test_builder_connect_timeout() {
        let builder = DakeraClient::builder("http://localhost:3000")
            .connect_timeout(Duration::from_secs(5))
            .timeout_secs(30);
        let built = builder.build().unwrap();
        assert!(built.base_url.starts_with("http"));
    }

    /// `max_retries` on the builder reaches the client's retry config.
    #[test]
    fn test_builder_max_retries() {
        let builder = DakeraClient::builder("http://localhost:3000").max_retries(5);
        let built = builder.build().unwrap();
        assert_eq!(built.retry_config.max_retries, 5);
    }

    /// A complete retry config passed to the builder is used as-is.
    #[test]
    fn test_builder_retry_config() {
        let custom = RetryConfig {
            max_retries: 7,
            base_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(30),
            jitter: false,
        };
        let built = DakeraClient::builder("http://localhost:3000")
            .retry_config(custom)
            .build()
            .unwrap();
        assert_eq!(built.retry_config.max_retries, 7);
        assert!(!built.retry_config.jitter);
    }

    /// Rate-limit errors are retryable.
    #[test]
    fn test_rate_limit_error_retryable() {
        let error = ClientError::RateLimitExceeded { retry_after: None };
        assert!(error.is_retryable());
    }

    /// HTTP 408 from the server is retryable.
    #[test]
    fn test_server_408_retryable() {
        let error = ClientError::Server {
            status: 408,
            message: String::new(),
            code: None,
        };
        assert!(error.is_retryable());
    }

    /// HTTP 400 from the server is not retryable.
    #[test]
    fn test_server_400_not_retryable() {
        let error = ClientError::Server {
            status: 400,
            message: String::new(),
            code: None,
        };
        assert!(!error.is_retryable());
    }

    /// A zero `retry_after` is still retryable and keeps its value.
    #[test]
    fn test_rate_limit_error_with_retry_after_zero() {
        let error = ClientError::RateLimitExceeded {
            retry_after: Some(0),
        };
        assert!(error.is_retryable());
        match &error {
            ClientError::RateLimitExceeded {
                retry_after: Some(secs),
            } => assert_eq!(*secs, 0u64),
            _ => panic!("unexpected variant"),
        }
    }

    /// A first-try success runs the operation exactly once.
    #[tokio::test]
    async fn test_execute_with_retry_succeeds_immediately() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(3)
            .build()
            .unwrap();

        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let counter = attempts.clone();
        let outcome = client
            .execute_with_retry(|| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<u32, ClientError>(42)
                }
            })
            .await;
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// A 400 response is returned after a single attempt.
    #[tokio::test]
    async fn test_execute_with_retry_no_retry_on_4xx() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(3)
            .build()
            .unwrap();

        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let counter = attempts.clone();
        let outcome = client
            .execute_with_retry(|| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<u32, ClientError>(ClientError::Server {
                        status: 400,
                        message: "bad request".to_string(),
                        code: None,
                    })
                }
            })
            .await;
        assert!(outcome.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// Two 503 responses followed by a success take three attempts.
    #[tokio::test]
    async fn test_execute_with_retry_retries_on_5xx() {
        use std::sync::atomic::{AtomicU32, Ordering};

        // Zero delays keep the test fast.
        let no_delay = RetryConfig {
            max_retries: 3,
            base_delay: Duration::from_millis(0),
            max_delay: Duration::from_millis(0),
            jitter: false,
        };
        let client = DakeraClient::builder("http://localhost:3000")
            .retry_config(no_delay)
            .build()
            .unwrap();

        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let counter = attempts.clone();
        let outcome = client
            .execute_with_retry(|| {
                let counter = counter.clone();
                async move {
                    let earlier_attempts = counter.fetch_add(1, Ordering::SeqCst);
                    if earlier_attempts >= 2 {
                        Ok::<u32, ClientError>(99)
                    } else {
                        Err::<u32, ClientError>(ClientError::Server {
                            status: 503,
                            message: "unavailable".to_string(),
                            code: None,
                        })
                    }
                }
            })
            .await;
        assert_eq!(outcome.unwrap(), 99);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A fresh batch recall request has a limit of 100.
    #[test]
    fn test_batch_recall_request_new() {
        use crate::memory::BatchRecallRequest;
        let request = BatchRecallRequest::new("agent-1");
        assert_eq!(request.agent_id, "agent-1");
        assert_eq!(request.limit, 100);
    }

    /// Filter and limit set through the builder are kept.
    #[test]
    fn test_batch_recall_request_builder() {
        use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
        let memory_filter = BatchMemoryFilter::default()
            .with_tags(vec!["qa".to_string()])
            .with_min_importance(0.7);
        let request = BatchRecallRequest::new("agent-1")
            .with_filter(memory_filter)
            .with_limit(50);
        assert_eq!(request.agent_id, "agent-1");
        assert_eq!(request.limit, 50);
        let expected_tags = ["qa".to_string()];
        assert_eq!(
            request.filter.tags.as_deref(),
            Some(expected_tags.as_slice())
        );
        assert_eq!(request.filter.min_importance, Some(0.7));
    }

    /// Checks the JSON keys of a batch recall request.
    #[test]
    fn test_batch_recall_request_serialization() {
        use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
        let memory_filter = BatchMemoryFilter::default().with_min_importance(0.5);
        let request = BatchRecallRequest::new("agent-1")
            .with_filter(memory_filter)
            .with_limit(25);
        let serialized = serde_json::to_value(&request).unwrap();
        assert_eq!(serialized["agent_id"], "agent-1");
        assert_eq!(serialized["limit"], 25);
        assert_eq!(serialized["filter"]["min_importance"], 0.5);
    }

    /// Agent id and filter are stored on a batch forget request.
    #[test]
    fn test_batch_forget_request_new() {
        use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
        let memory_filter = BatchMemoryFilter::default().with_min_importance(0.1);
        let request = BatchForgetRequest::new("agent-1", memory_filter);
        assert_eq!(request.agent_id, "agent-1");
        assert_eq!(request.filter.min_importance, Some(0.1));
    }

    /// Checks the JSON keys of a batch forget request.
    #[test]
    fn test_batch_forget_request_serialization() {
        use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
        let memory_filter = BatchMemoryFilter {
            created_before: Some(1_700_000_000),
            ..Default::default()
        };
        let request = BatchForgetRequest::new("agent-1", memory_filter);
        let serialized = serde_json::to_value(&request).unwrap();
        assert_eq!(serialized["agent_id"], "agent-1");
        assert_eq!(serialized["filter"]["created_before"], 1_700_000_000u64);
    }

    /// A batch recall response decodes its counters and memory list.
    #[test]
    fn test_batch_recall_response_deserialization() {
        use crate::memory::BatchRecallResponse;
        let body = serde_json::json!({
            "memories": [],
            "total": 42,
            "filtered": 7
        });
        let response: BatchRecallResponse = serde_json::from_value(body).unwrap();
        assert_eq!(response.total, 42);
        assert_eq!(response.filtered, 7);
        assert!(response.memories.is_empty());
    }

    /// A batch forget response decodes `deleted_count`.
    #[test]
    fn test_batch_forget_response_deserialization() {
        use crate::memory::BatchForgetResponse;
        let body = serde_json::json!({ "deleted_count": 13 });
        let response: BatchForgetResponse = serde_json::from_value(body).unwrap();
        assert_eq!(response.deleted_count, 13);
    }

    /// Every rate-limit field can be left unset.
    #[test]
    fn test_rate_limit_headers_default_all_none() {
        use crate::types::RateLimitHeaders;
        let headers = RateLimitHeaders {
            limit: None,
            remaining: None,
            reset: None,
            quota_used: None,
            quota_limit: None,
        };
        assert!(headers.limit.is_none());
        assert!(headers.remaining.is_none());
        assert!(headers.reset.is_none());
        assert!(headers.quota_used.is_none());
        assert!(headers.quota_limit.is_none());
    }

    /// Populated rate-limit fields read back unchanged.
    #[test]
    fn test_rate_limit_headers_populated() {
        use crate::types::RateLimitHeaders;
        let headers = RateLimitHeaders {
            limit: Some(1000),
            remaining: Some(750),
            reset: Some(1_700_000_060),
            quota_used: Some(500),
            quota_limit: Some(10_000),
        };
        assert_eq!(headers.limit, Some(1000));
        assert_eq!(headers.remaining, Some(750));
        assert_eq!(headers.reset, Some(1_700_000_060));
        assert_eq!(headers.quota_used, Some(500));
        assert_eq!(headers.quota_limit, Some(10_000));
    }

    /// A new client has no rate-limit headers recorded.
    #[test]
    fn test_last_rate_limit_headers_initially_none() {
        let fresh = DakeraClient::new("http://localhost:3000").unwrap();
        assert!(fresh.last_rate_limit_headers().is_none());
    }

    /// The default NER config has extraction off and no entity types.
    #[test]
    fn test_namespace_ner_config_default() {
        use crate::types::NamespaceNerConfig;
        let config = NamespaceNerConfig::default();
        assert!(!config.extract_entities);
        assert!(config.entity_types.is_none());
    }

    /// An unset `entity_types` is left out of the JSON.
    #[test]
    fn test_namespace_ner_config_serialization_skip_none() {
        use crate::types::NamespaceNerConfig;
        let config = NamespaceNerConfig {
            extract_entities: true,
            entity_types: None,
        };
        let serialized = serde_json::to_value(&config).unwrap();
        assert_eq!(serialized["extract_entities"], true);
        assert!(serialized.get("entity_types").is_none());
    }

    /// Entity types serialize as an array in the given order.
    #[test]
    fn test_namespace_ner_config_serialization_with_types() {
        use crate::types::NamespaceNerConfig;
        let wanted_types = vec!["PERSON".to_string(), "ORG".to_string()];
        let config = NamespaceNerConfig {
            extract_entities: true,
            entity_types: Some(wanted_types),
        };
        let serialized = serde_json::to_value(&config).unwrap();
        assert_eq!(serialized["extract_entities"], true);
        assert_eq!(serialized["entity_types"][0], "PERSON");
        assert_eq!(serialized["entity_types"][1], "ORG");
    }

    /// A single extracted entity decodes from JSON.
    #[test]
    fn test_extracted_entity_deserialization() {
        use crate::types::ExtractedEntity;
        let body = serde_json::json!({
            "entity_type": "PERSON",
            "value": "Alice",
            "score": 0.95
        });
        let decoded: ExtractedEntity = serde_json::from_value(body).unwrap();
        assert_eq!(decoded.entity_type, "PERSON");
        assert_eq!(decoded.value, "Alice");
        assert!((decoded.score - 0.95).abs() < f64::EPSILON);
    }

    /// An entity extraction response decodes its list of entities.
    #[test]
    fn test_entity_extraction_response_deserialization() {
        use crate::types::EntityExtractionResponse;
        let body = serde_json::json!({
            "entities": [
                { "entity_type": "PERSON", "value": "Bob", "score": 0.9 },
                { "entity_type": "ORG", "value": "Acme", "score": 0.87 }
            ]
        });
        let response: EntityExtractionResponse = serde_json::from_value(body).unwrap();
        assert_eq!(response.entities.len(), 2);
        assert_eq!(response.entities[0].entity_type, "PERSON");
        assert_eq!(response.entities[1].value, "Acme");
    }

    /// A memory entities response decodes the memory id and entities.
    #[test]
    fn test_memory_entities_response_deserialization() {
        use crate::types::MemoryEntitiesResponse;
        let body = serde_json::json!({
            "memory_id": "mem-abc-123",
            "entities": [
                { "entity_type": "LOC", "value": "London", "score": 0.88 }
            ]
        });
        let response: MemoryEntitiesResponse = serde_json::from_value(body).unwrap();
        assert_eq!(response.memory_id, "mem-abc-123");
        assert_eq!(response.entities.len(), 1);
        assert_eq!(response.entities[0].entity_type, "LOC");
        assert_eq!(response.entities[0].value, "London");
    }

    /// Pins the namespace config URL shape relative to the base URL.
    #[test]
    fn test_configure_namespace_ner_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let built_url = format!("{}/v1/namespaces/{}/config", client.base_url, "my-ns");
        let wanted_url = "http://localhost:3000/v1/namespaces/my-ns/config";
        assert_eq!(built_url, wanted_url);
    }

    /// Pins the entity extraction URL shape relative to the base URL.
    #[test]
    fn test_extract_entities_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let built_url = format!("{}/v1/memories/extract", client.base_url);
        let wanted_url = "http://localhost:3000/v1/memories/extract";
        assert_eq!(built_url, wanted_url);
    }

    /// Pins the memory entities URL shape relative to the base URL.
    #[test]
    fn test_memory_entities_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let memory_id = "mem-xyz-789";
        let built_url = format!("{}/v1/memory/entities/{}", client.base_url, memory_id);
        let wanted_url = "http://localhost:3000/v1/memory/entities/mem-xyz-789";
        assert_eq!(built_url, wanted_url);
    }

    /// Feedback signals serialize as lowercase strings.
    #[test]
    fn test_feedback_signal_serialization() {
        use crate::types::FeedbackSignal;
        let up = serde_json::to_value(FeedbackSignal::Upvote).unwrap();
        assert_eq!(up, serde_json::json!("upvote"));
        let down = serde_json::to_value(FeedbackSignal::Downvote).unwrap();
        assert_eq!(down, serde_json::json!("downvote"));
        let flagged = serde_json::to_value(FeedbackSignal::Flag).unwrap();
        assert_eq!(flagged, serde_json::json!("flag"));
    }

    /// Feedback signals decode from their string form.
    #[test]
    fn test_feedback_signal_deserialization() {
        use crate::types::FeedbackSignal;
        let from_upvote: FeedbackSignal = serde_json::from_str("\"upvote\"").unwrap();
        assert_eq!(from_upvote, FeedbackSignal::Upvote);
        let from_positive: FeedbackSignal = serde_json::from_str("\"positive\"").unwrap();
        assert_eq!(from_positive, FeedbackSignal::Positive);
    }

    /// A feedback response decodes id, importance and signal.
    #[test]
    fn test_feedback_response_deserialization() {
        use crate::types::{FeedbackResponse, FeedbackSignal};
        let body = serde_json::json!({
            "memory_id": "mem-abc",
            "new_importance": 0.92,
            "signal": "upvote"
        });
        let response: FeedbackResponse = serde_json::from_value(body).unwrap();
        assert_eq!(response.memory_id, "mem-abc");
        assert!((response.new_importance - 0.92).abs() < f32::EPSILON);
        assert_eq!(response.signal, FeedbackSignal::Upvote);
    }

    /// A feedback history response decodes its entries in order.
    #[test]
    fn test_feedback_history_response_deserialization() {
        use crate::types::{FeedbackHistoryResponse, FeedbackSignal};
        let body = serde_json::json!({
            "memory_id": "mem-abc",
            "entries": [
                {"signal": "upvote", "timestamp": 1774000000_u64, "old_importance": 0.5, "new_importance": 0.575},
                {"signal": "downvote", "timestamp": 1774001000_u64, "old_importance": 0.575, "new_importance": 0.489}
            ]
        });
        let history: FeedbackHistoryResponse = serde_json::from_value(body).unwrap();
        assert_eq!(history.memory_id, "mem-abc");
        assert_eq!(history.entries.len(), 2);
        assert_eq!(history.entries[0].signal, FeedbackSignal::Upvote);
        assert_eq!(history.entries[1].signal, FeedbackSignal::Downvote);
    }

    /// An agent feedback summary decodes its counters and health score.
    #[test]
    fn test_agent_feedback_summary_deserialization() {
        use crate::types::AgentFeedbackSummary;
        let body = serde_json::json!({
            "agent_id": "agent-1",
            "upvotes": 42_u64,
            "downvotes": 7_u64,
            "flags": 2_u64,
            "total_feedback": 51_u64,
            "health_score": 0.78
        });
        let decoded: AgentFeedbackSummary = serde_json::from_value(body).unwrap();
        assert_eq!(decoded.agent_id, "agent-1");
        assert_eq!(decoded.upvotes, 42);
        assert_eq!(decoded.total_feedback, 51);
        assert!((decoded.health_score - 0.78).abs() < f32::EPSILON);
    }

    /// A feedback health response decodes id, score and memory count.
    #[test]
    fn test_feedback_health_response_deserialization() {
        use crate::types::FeedbackHealthResponse;
        let body = serde_json::json!({
            "agent_id": "agent-1",
            "health_score": 0.78,
            "memory_count": 120_usize,
            "avg_importance": 0.72
        });
        let decoded: FeedbackHealthResponse = serde_json::from_value(body).unwrap();
        assert_eq!(decoded.agent_id, "agent-1");
        assert!((decoded.health_score - 0.78).abs() < f32::EPSILON);
        assert_eq!(decoded.memory_count, 120);
    }

    /// Checks the JSON keys of a feedback request body.
    #[test]
    fn test_memory_feedback_body_serialization() {
        use crate::types::{FeedbackSignal, MemoryFeedbackBody};
        let payload = MemoryFeedbackBody {
            agent_id: "agent-1".to_string(),
            signal: FeedbackSignal::Flag,
        };
        let serialized = serde_json::to_value(payload).unwrap();
        assert_eq!(serialized["agent_id"], "agent-1");
        assert_eq!(serialized["signal"], "flag");
    }

    /// Pins the feedback and importance URL shapes for one memory.
    #[test]
    fn test_feedback_memory_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let memory_id = "mem-abc";

        let feedback_url = format!("{}/v1/memories/{}/feedback", client.base_url, memory_id);
        assert_eq!(
            feedback_url,
            "http://localhost:3000/v1/memories/mem-abc/feedback"
        );

        let importance_url = format!("{}/v1/memories/{}/importance", client.base_url, memory_id);
        assert_eq!(
            importance_url,
            "http://localhost:3000/v1/memories/mem-abc/importance"
        );
    }

    /// Pins the feedback health URL shape, including its query string.
    #[test]
    fn test_feedback_health_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let agent_id = "agent-1";
        let built_url = format!(
            "{}/v1/feedback/health?agent_id={}",
            client.base_url, agent_id
        );
        let wanted_url = "http://localhost:3000/v1/feedback/health?agent_id=agent-1";
        assert_eq!(built_url, wanted_url);
    }

    /// Without an ODE URL, `ode_extract_entities` fails with a config error.
    #[test]
    fn test_ode_extract_entities_requires_ode_url() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let request = ExtractEntitiesRequest {
            content: "Alice lives in Paris.".to_string(),
            agent_id: "agent-1".to_string(),
            memory_id: None,
            entity_types: None,
        };
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let outcome = runtime.block_on(client.ode_extract_entities(request));
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), ClientError::Config(_)));
    }

    /// The configured ODE URL is stored and used as the endpoint prefix.
    #[test]
    fn test_ode_extract_entities_url_built_from_ode_url() {
        let built = DakeraClient::builder("http://localhost:3000")
            .ode_url("http://localhost:8080")
            .build()
            .unwrap();
        assert_eq!(built.ode_url.as_deref(), Some("http://localhost:8080"));
        let endpoint = format!("{}/ode/extract", built.ode_url.as_deref().unwrap());
        assert_eq!(endpoint, "http://localhost:8080/ode/extract");
    }

    /// All populated fields of an extraction request appear in the JSON.
    #[test]
    fn test_extract_entities_request_serialization() {
        let request = ExtractEntitiesRequest {
            content: "Alice in Wonderland".to_string(),
            agent_id: "agent-42".to_string(),
            memory_id: Some("mem-001".to_string()),
            entity_types: Some(vec!["person".to_string(), "location".to_string()]),
        };
        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("\"content\":\"Alice in Wonderland\""));
        assert!(serialized.contains("\"agent_id\":\"agent-42\""));
        assert!(serialized.contains("\"memory_id\":\"mem-001\""));
        assert!(serialized.contains("\"person\""));
    }

    /// Unset optional fields are left out of the JSON.
    #[test]
    fn test_extract_entities_request_omits_none_fields() {
        let request = ExtractEntitiesRequest {
            content: "hello".to_string(),
            agent_id: "a".to_string(),
            memory_id: None,
            entity_types: None,
        };
        let serialized = serde_json::to_string(&request).unwrap();
        assert!(!serialized.contains("memory_id"));
        assert!(!serialized.contains("entity_types"));
    }

    /// An ODE entity decodes text, label, span and score.
    #[test]
    fn test_ode_entity_deserialization() {
        let raw = r#"{"text":"Alice","label":"person","start":0,"end":5,"score":0.97}"#;
        let decoded: OdeEntity = serde_json::from_str(raw).unwrap();
        assert_eq!(decoded.text, "Alice");
        assert_eq!(decoded.label, "person");
        assert_eq!(decoded.start, 0);
        assert_eq!(decoded.end, 5);
        assert!((decoded.score - 0.97).abs() < 1e-4);
    }

    /// An ODE extraction response decodes entities, model and timing.
    #[test]
    fn test_extract_entities_response_deserialization() {
        let raw = r#"{
            "entities": [
                {"text":"Alice","label":"person","start":0,"end":5,"score":0.97},
                {"text":"Paris","label":"location","start":16,"end":21,"score":0.92}
            ],
            "model": "gliner-multi-v2.1",
            "processing_time_ms": 34
        }"#;
        let response: ExtractEntitiesResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(response.entities.len(), 2);
        assert_eq!(response.entities[0].text, "Alice");
        assert_eq!(response.model, "gliner-multi-v2.1");
        assert_eq!(response.processing_time_ms, 34);
    }
}