1use reqwest::{Client, StatusCode};
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6use tracing::{debug, instrument};
7
8use serde::Deserialize;
9
10use crate::error::{ClientError, Result, ServerErrorCode};
11use crate::types::*;
12
/// Default overall request timeout, in seconds, applied by `DakeraClientBuilder::new`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
15
/// HTTP client for a Dakera server.
///
/// Construct it with `DakeraClient::new` or `DakeraClient::builder`.
#[derive(Debug, Clone)]
pub struct DakeraClient {
    /// Underlying `reqwest` client used to send every request.
    pub(crate) client: Client,
    /// Server base URL; the builder strips trailing `/` characters from it.
    pub(crate) base_url: String,
    /// Optional base URL required by `ode_extract_entities`.
    pub(crate) ode_url: Option<String>,
    // In this file the field is only read by `execute_with_retry`, which is
    // itself marked `#[allow(dead_code)]`.
    #[allow(dead_code)]
    pub(crate) retry_config: RetryConfig,
    /// Rate-limit headers captured from the most recently handled response.
    pub(crate) last_rate_limit: Arc<Mutex<Option<RateLimitHeaders>>>,
}
31
32impl DakeraClient {
33 pub fn new(base_url: impl Into<String>) -> Result<Self> {
43 DakeraClientBuilder::new(base_url).build()
44 }
45
46 pub fn builder(base_url: impl Into<String>) -> DakeraClientBuilder {
48 DakeraClientBuilder::new(base_url)
49 }
50
51 #[instrument(skip(self))]
57 pub async fn health(&self) -> Result<HealthResponse> {
58 let url = format!("{}/health", self.base_url);
59 let response = self.client.get(&url).send().await?;
60
61 if response.status().is_success() {
62 Ok(response.json().await?)
63 } else {
64 Ok(HealthResponse {
66 healthy: true,
67 version: None,
68 uptime_seconds: None,
69 })
70 }
71 }
72
73 #[instrument(skip(self))]
75 pub async fn ready(&self) -> Result<ReadinessResponse> {
76 let url = format!("{}/health/ready", self.base_url);
77 let response = self.client.get(&url).send().await?;
78
79 if response.status().is_success() {
80 Ok(response.json().await?)
81 } else {
82 Ok(ReadinessResponse {
83 ready: false,
84 components: None,
85 })
86 }
87 }
88
89 #[instrument(skip(self))]
91 pub async fn live(&self) -> Result<bool> {
92 let url = format!("{}/health/live", self.base_url);
93 let response = self.client.get(&url).send().await?;
94 Ok(response.status().is_success())
95 }
96
97 #[instrument(skip(self))]
103 pub async fn list_namespaces(&self) -> Result<Vec<String>> {
104 let url = format!("{}/v1/namespaces", self.base_url);
105 let response = self.client.get(&url).send().await?;
106 self.handle_response::<ListNamespacesResponse>(response)
107 .await
108 .map(|r| r.namespaces)
109 }
110
111 #[instrument(skip(self))]
113 pub async fn get_namespace(&self, namespace: &str) -> Result<NamespaceInfo> {
114 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
115 let response = self.client.get(&url).send().await?;
116 self.handle_response(response).await
117 }
118
119 #[instrument(skip(self, request))]
121 pub async fn create_namespace(
122 &self,
123 namespace: &str,
124 request: CreateNamespaceRequest,
125 ) -> Result<NamespaceInfo> {
126 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
127 let response = self.client.post(&url).json(&request).send().await?;
128 self.handle_response(response).await
129 }
130
131 #[instrument(skip(self, request), fields(namespace = %namespace))]
137 pub async fn configure_namespace(
138 &self,
139 namespace: &str,
140 request: ConfigureNamespaceRequest,
141 ) -> Result<ConfigureNamespaceResponse> {
142 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
143 let response = self.client.put(&url).json(&request).send().await?;
144 self.handle_response(response).await
145 }
146
147 #[instrument(skip(self, request), fields(vector_count = request.vectors.len()))]
153 pub async fn upsert(&self, namespace: &str, request: UpsertRequest) -> Result<UpsertResponse> {
154 let url = format!("{}/v1/namespaces/{}/vectors", self.base_url, namespace);
155 debug!(
156 "Upserting {} vectors to {}",
157 request.vectors.len(),
158 namespace
159 );
160
161 let response = self.client.post(&url).json(&request).send().await?;
162 self.handle_response(response).await
163 }
164
165 #[instrument(skip(self, vector))]
167 pub async fn upsert_one(&self, namespace: &str, vector: Vector) -> Result<UpsertResponse> {
168 self.upsert(namespace, UpsertRequest::single(vector)).await
169 }
170
171 #[instrument(skip(self, request), fields(namespace = %namespace, count = request.ids.len()))]
204 pub async fn upsert_columns(
205 &self,
206 namespace: &str,
207 request: ColumnUpsertRequest,
208 ) -> Result<UpsertResponse> {
209 let url = format!(
210 "{}/v1/namespaces/{}/upsert-columns",
211 self.base_url, namespace
212 );
213 debug!(
214 "Upserting {} vectors in column format to {}",
215 request.ids.len(),
216 namespace
217 );
218
219 let response = self.client.post(&url).json(&request).send().await?;
220 self.handle_response(response).await
221 }
222
223 #[instrument(skip(self, request), fields(top_k = request.top_k))]
225 pub async fn query(&self, namespace: &str, request: QueryRequest) -> Result<QueryResponse> {
226 let url = format!("{}/v1/namespaces/{}/query", self.base_url, namespace);
227 debug!(
228 "Querying namespace {} for top {} results",
229 namespace, request.top_k
230 );
231
232 let response = self.client.post(&url).json(&request).send().await?;
233 self.handle_response(response).await
234 }
235
236 #[instrument(skip(self, vector))]
238 pub async fn query_simple(
239 &self,
240 namespace: &str,
241 vector: Vec<f32>,
242 top_k: u32,
243 ) -> Result<QueryResponse> {
244 self.query(namespace, QueryRequest::new(vector, top_k))
245 .await
246 }
247
248 #[instrument(skip(self, request), fields(namespace = %namespace, query_count = request.queries.len()))]
272 pub async fn batch_query(
273 &self,
274 namespace: &str,
275 request: BatchQueryRequest,
276 ) -> Result<BatchQueryResponse> {
277 let url = format!("{}/v1/namespaces/{}/batch-query", self.base_url, namespace);
278 debug!(
279 "Batch querying namespace {} with {} queries",
280 namespace,
281 request.queries.len()
282 );
283
284 let response = self.client.post(&url).json(&request).send().await?;
285 self.handle_response(response).await
286 }
287
288 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
290 pub async fn delete(&self, namespace: &str, request: DeleteRequest) -> Result<DeleteResponse> {
291 let url = format!(
292 "{}/v1/namespaces/{}/vectors/delete",
293 self.base_url, namespace
294 );
295 debug!("Deleting {} vectors from {}", request.ids.len(), namespace);
296
297 let response = self.client.post(&url).json(&request).send().await?;
298 self.handle_response(response).await
299 }
300
301 #[instrument(skip(self))]
303 pub async fn delete_one(&self, namespace: &str, id: &str) -> Result<DeleteResponse> {
304 self.delete(namespace, DeleteRequest::single(id)).await
305 }
306
307 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
313 pub async fn index_documents(
314 &self,
315 namespace: &str,
316 request: IndexDocumentsRequest,
317 ) -> Result<IndexDocumentsResponse> {
318 let url = format!(
319 "{}/v1/namespaces/{}/fulltext/index",
320 self.base_url, namespace
321 );
322 debug!(
323 "Indexing {} documents in {}",
324 request.documents.len(),
325 namespace
326 );
327
328 let response = self.client.post(&url).json(&request).send().await?;
329 self.handle_response(response).await
330 }
331
332 #[instrument(skip(self, document))]
334 pub async fn index_document(
335 &self,
336 namespace: &str,
337 document: Document,
338 ) -> Result<IndexDocumentsResponse> {
339 self.index_documents(
340 namespace,
341 IndexDocumentsRequest {
342 documents: vec![document],
343 },
344 )
345 .await
346 }
347
348 #[instrument(skip(self, request))]
350 pub async fn fulltext_search(
351 &self,
352 namespace: &str,
353 request: FullTextSearchRequest,
354 ) -> Result<FullTextSearchResponse> {
355 let url = format!(
356 "{}/v1/namespaces/{}/fulltext/search",
357 self.base_url, namespace
358 );
359 debug!("Full-text search in {} for: {}", namespace, request.query);
360
361 let response = self.client.post(&url).json(&request).send().await?;
362 self.handle_response(response).await
363 }
364
365 #[instrument(skip(self))]
367 pub async fn search_text(
368 &self,
369 namespace: &str,
370 query: &str,
371 top_k: u32,
372 ) -> Result<FullTextSearchResponse> {
373 self.fulltext_search(namespace, FullTextSearchRequest::new(query, top_k))
374 .await
375 }
376
377 #[instrument(skip(self))]
379 pub async fn fulltext_stats(&self, namespace: &str) -> Result<FullTextStats> {
380 let url = format!(
381 "{}/v1/namespaces/{}/fulltext/stats",
382 self.base_url, namespace
383 );
384 let response = self.client.get(&url).send().await?;
385 self.handle_response(response).await
386 }
387
388 #[instrument(skip(self, request))]
390 pub async fn fulltext_delete(
391 &self,
392 namespace: &str,
393 request: DeleteRequest,
394 ) -> Result<DeleteResponse> {
395 let url = format!(
396 "{}/v1/namespaces/{}/fulltext/delete",
397 self.base_url, namespace
398 );
399 let response = self.client.post(&url).json(&request).send().await?;
400 self.handle_response(response).await
401 }
402
403 #[instrument(skip(self, request), fields(top_k = request.top_k))]
409 pub async fn hybrid_search(
410 &self,
411 namespace: &str,
412 request: HybridSearchRequest,
413 ) -> Result<HybridSearchResponse> {
414 let url = format!("{}/v1/namespaces/{}/hybrid", self.base_url, namespace);
415 debug!(
416 "Hybrid search in {} with vector_weight={}",
417 namespace, request.vector_weight
418 );
419
420 let response = self.client.post(&url).json(&request).send().await?;
421 self.handle_response(response).await
422 }
423
424 #[instrument(skip(self, request), fields(namespace = %namespace))]
461 pub async fn multi_vector_search(
462 &self,
463 namespace: &str,
464 request: MultiVectorSearchRequest,
465 ) -> Result<MultiVectorSearchResponse> {
466 let url = format!("{}/v1/namespaces/{}/multi-vector", self.base_url, namespace);
467 debug!(
468 "Multi-vector search in {} with {} positive vectors",
469 namespace,
470 request.positive_vectors.len()
471 );
472
473 let response = self.client.post(&url).json(&request).send().await?;
474 self.handle_response(response).await
475 }
476
477 #[instrument(skip(self, request), fields(namespace = %namespace))]
511 pub async fn aggregate(
512 &self,
513 namespace: &str,
514 request: AggregationRequest,
515 ) -> Result<AggregationResponse> {
516 let url = format!("{}/v1/namespaces/{}/aggregate", self.base_url, namespace);
517 debug!(
518 "Aggregating in namespace {} with {} aggregations",
519 namespace,
520 request.aggregate_by.len()
521 );
522
523 let response = self.client.post(&url).json(&request).send().await?;
524 self.handle_response(response).await
525 }
526
527 #[instrument(skip(self, request), fields(namespace = %namespace))]
565 pub async fn unified_query(
566 &self,
567 namespace: &str,
568 request: UnifiedQueryRequest,
569 ) -> Result<UnifiedQueryResponse> {
570 let url = format!(
571 "{}/v1/namespaces/{}/unified-query",
572 self.base_url, namespace
573 );
574 debug!(
575 "Unified query in namespace {} with top_k={}",
576 namespace, request.top_k
577 );
578
579 let response = self.client.post(&url).json(&request).send().await?;
580 self.handle_response(response).await
581 }
582
583 #[instrument(skip(self, vector))]
587 pub async fn unified_vector_search(
588 &self,
589 namespace: &str,
590 vector: Vec<f32>,
591 top_k: usize,
592 ) -> Result<UnifiedQueryResponse> {
593 self.unified_query(namespace, UnifiedQueryRequest::vector_search(vector, top_k))
594 .await
595 }
596
597 #[instrument(skip(self))]
601 pub async fn unified_text_search(
602 &self,
603 namespace: &str,
604 field: &str,
605 query: &str,
606 top_k: usize,
607 ) -> Result<UnifiedQueryResponse> {
608 self.unified_query(
609 namespace,
610 UnifiedQueryRequest::fulltext_search(field, query, top_k),
611 )
612 .await
613 }
614
615 #[instrument(skip(self, request), fields(namespace = %namespace))]
652 pub async fn explain_query(
653 &self,
654 namespace: &str,
655 request: QueryExplainRequest,
656 ) -> Result<QueryExplainResponse> {
657 let url = format!("{}/v1/namespaces/{}/explain", self.base_url, namespace);
658 debug!(
659 "Explaining query in namespace {} (query_type={:?}, top_k={})",
660 namespace, request.query_type, request.top_k
661 );
662
663 let response = self.client.post(&url).json(&request).send().await?;
664 self.handle_response(response).await
665 }
666
667 #[instrument(skip(self, request), fields(namespace = %request.namespace, priority = ?request.priority))]
695 pub async fn warm_cache(&self, request: WarmCacheRequest) -> Result<WarmCacheResponse> {
696 let url = format!(
697 "{}/v1/namespaces/{}/cache/warm",
698 self.base_url, request.namespace
699 );
700 debug!(
701 "Warming cache for namespace {} with priority {:?}",
702 request.namespace, request.priority
703 );
704
705 let response = self.client.post(&url).json(&request).send().await?;
706 self.handle_response(response).await
707 }
708
709 #[instrument(skip(self, vector_ids))]
711 pub async fn warm_vectors(
712 &self,
713 namespace: &str,
714 vector_ids: Vec<String>,
715 ) -> Result<WarmCacheResponse> {
716 self.warm_cache(WarmCacheRequest::new(namespace).with_vector_ids(vector_ids))
717 .await
718 }
719
720 #[instrument(skip(self, request), fields(namespace = %namespace))]
753 pub async fn export(&self, namespace: &str, request: ExportRequest) -> Result<ExportResponse> {
754 let url = format!("{}/v1/namespaces/{}/export", self.base_url, namespace);
755 debug!(
756 "Exporting vectors from namespace {} (top_k={}, cursor={:?})",
757 namespace, request.top_k, request.cursor
758 );
759
760 let response = self.client.post(&url).json(&request).send().await?;
761 self.handle_response(response).await
762 }
763
764 #[instrument(skip(self))]
768 pub async fn export_all(&self, namespace: &str) -> Result<ExportResponse> {
769 self.export(namespace, ExportRequest::new()).await
770 }
771
772 #[instrument(skip(self))]
778 pub async fn diagnostics(&self) -> Result<SystemDiagnostics> {
779 let url = format!("{}/ops/diagnostics", self.base_url);
780 let response = self.client.get(&url).send().await?;
781 self.handle_response(response).await
782 }
783
784 #[instrument(skip(self))]
786 pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
787 let url = format!("{}/ops/jobs", self.base_url);
788 let response = self.client.get(&url).send().await?;
789 self.handle_response(response).await
790 }
791
792 #[instrument(skip(self))]
794 pub async fn get_job(&self, job_id: &str) -> Result<Option<JobInfo>> {
795 let url = format!("{}/ops/jobs/{}", self.base_url, job_id);
796 let response = self.client.get(&url).send().await?;
797
798 if response.status() == StatusCode::NOT_FOUND {
799 return Ok(None);
800 }
801
802 self.handle_response(response).await.map(Some)
803 }
804
805 #[instrument(skip(self, request))]
807 pub async fn compact(&self, request: CompactionRequest) -> Result<CompactionResponse> {
808 let url = format!("{}/ops/compact", self.base_url);
809 let response = self.client.post(&url).json(&request).send().await?;
810 self.handle_response(response).await
811 }
812
813 #[instrument(skip(self))]
815 pub async fn shutdown(&self) -> Result<()> {
816 let url = format!("{}/ops/shutdown", self.base_url);
817 let response = self.client.post(&url).send().await?;
818
819 if response.status().is_success() {
820 Ok(())
821 } else {
822 let status = response.status().as_u16();
823 let text = response.text().await.unwrap_or_default();
824 Err(ClientError::Server {
825 status,
826 message: text,
827 code: None,
828 })
829 }
830 }
831
832 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
838 pub async fn fetch(&self, namespace: &str, request: FetchRequest) -> Result<FetchResponse> {
839 let url = format!("{}/v1/namespaces/{}/fetch", self.base_url, namespace);
840 debug!("Fetching {} vectors from {}", request.ids.len(), namespace);
841 let response = self.client.post(&url).json(&request).send().await?;
842 self.handle_response(response).await
843 }
844
845 #[instrument(skip(self))]
847 pub async fn fetch_by_ids(&self, namespace: &str, ids: &[&str]) -> Result<Vec<Vector>> {
848 let request = FetchRequest::new(ids.iter().map(|s| s.to_string()).collect());
849 self.fetch(namespace, request).await.map(|r| r.vectors)
850 }
851
852 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
858 pub async fn upsert_text(
859 &self,
860 namespace: &str,
861 request: UpsertTextRequest,
862 ) -> Result<TextUpsertResponse> {
863 let url = format!("{}/v1/namespaces/{}/upsert-text", self.base_url, namespace);
864 debug!(
865 "Upserting {} text documents to {}",
866 request.documents.len(),
867 namespace
868 );
869 let response = self.client.post(&url).json(&request).send().await?;
870 self.handle_response(response).await
871 }
872
873 #[instrument(skip(self, request), fields(top_k = request.top_k))]
875 pub async fn query_text(
876 &self,
877 namespace: &str,
878 request: QueryTextRequest,
879 ) -> Result<TextQueryResponse> {
880 let url = format!("{}/v1/namespaces/{}/query-text", self.base_url, namespace);
881 debug!("Text query in {} for: {}", namespace, request.text);
882 let response = self.client.post(&url).json(&request).send().await?;
883 self.handle_response(response).await
884 }
885
886 #[instrument(skip(self))]
888 pub async fn query_text_simple(
889 &self,
890 namespace: &str,
891 text: &str,
892 top_k: u32,
893 ) -> Result<TextQueryResponse> {
894 self.query_text(namespace, QueryTextRequest::new(text, top_k))
895 .await
896 }
897
898 #[instrument(skip(self, request), fields(query_count = request.queries.len()))]
900 pub async fn batch_query_text(
901 &self,
902 namespace: &str,
903 request: BatchQueryTextRequest,
904 ) -> Result<BatchQueryTextResponse> {
905 let url = format!(
906 "{}/v1/namespaces/{}/batch-query-text",
907 self.base_url, namespace
908 );
909 debug!(
910 "Batch text query in {} with {} queries",
911 namespace,
912 request.queries.len()
913 );
914 let response = self.client.post(&url).json(&request).send().await?;
915 self.handle_response(response).await
916 }
917
918 #[instrument(skip(self, config))]
927 pub async fn configure_namespace_ner(
928 &self,
929 namespace: &str,
930 config: NamespaceNerConfig,
931 ) -> Result<serde_json::Value> {
932 let url = format!("{}/v1/namespaces/{}/config", self.base_url, namespace);
933 let response = self.client.patch(&url).json(&config).send().await?;
934 self.handle_response(response).await
935 }
936
937 #[instrument(skip(self, text, entity_types))]
942 pub async fn extract_entities(
943 &self,
944 text: &str,
945 entity_types: Option<Vec<String>>,
946 ) -> Result<EntityExtractionResponse> {
947 let url = format!("{}/v1/memories/extract", self.base_url);
948 let body = serde_json::json!({
949 "text": text,
950 "entity_types": entity_types,
951 });
952 let response = self.client.post(&url).json(&body).send().await?;
953 self.handle_response(response).await
954 }
955
956 #[instrument(skip(self))]
960 pub async fn memory_entities(&self, memory_id: &str) -> Result<MemoryEntitiesResponse> {
961 let url = format!("{}/v1/memory/entities/{}", self.base_url, memory_id);
962 let response = self.client.get(&url).send().await?;
963 self.handle_response(response).await
964 }
965
966 pub fn last_rate_limit_headers(&self) -> Option<RateLimitHeaders> {
974 self.last_rate_limit.lock().ok()?.clone()
975 }
976
977 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
979 &self,
980 response: reqwest::Response,
981 ) -> Result<T> {
982 let status = response.status();
983
984 if let Ok(mut guard) = self.last_rate_limit.lock() {
986 *guard = Some(RateLimitHeaders::from_response(&response));
987 }
988
989 if status.is_success() {
990 Ok(response.json().await?)
991 } else {
992 let status_code = status.as_u16();
993 let retry_after = response
995 .headers()
996 .get("Retry-After")
997 .and_then(|v| v.to_str().ok())
998 .and_then(|s| s.parse::<u64>().ok());
999 let text = response.text().await.unwrap_or_default();
1000
1001 if status_code == 429 {
1002 return Err(ClientError::RateLimitExceeded { retry_after });
1003 }
1004
1005 #[derive(Deserialize)]
1006 struct ErrorBody {
1007 error: Option<String>,
1008 code: Option<ServerErrorCode>,
1009 }
1010
1011 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1012 (body.error.unwrap_or_else(|| text.clone()), body.code)
1013 } else {
1014 (text, None)
1015 };
1016
1017 match status_code {
1018 401 => Err(ClientError::Server {
1019 status: 401,
1020 message,
1021 code,
1022 }),
1023 403 => Err(ClientError::Authorization {
1024 status: 403,
1025 message,
1026 code,
1027 }),
1028 404 => match &code {
1029 Some(ServerErrorCode::NamespaceNotFound) => {
1030 Err(ClientError::NamespaceNotFound(message))
1031 }
1032 Some(ServerErrorCode::VectorNotFound) => {
1033 Err(ClientError::VectorNotFound(message))
1034 }
1035 _ => Err(ClientError::Server {
1036 status: 404,
1037 message,
1038 code,
1039 }),
1040 },
1041 _ => Err(ClientError::Server {
1042 status: status_code,
1043 message,
1044 code,
1045 }),
1046 }
1047 }
1048 }
1049
1050 pub(crate) async fn handle_text_response(&self, response: reqwest::Response) -> Result<String> {
1052 let status = response.status();
1053
1054 if let Ok(mut guard) = self.last_rate_limit.lock() {
1056 *guard = Some(RateLimitHeaders::from_response(&response));
1057 }
1058
1059 let retry_after = response
1060 .headers()
1061 .get("Retry-After")
1062 .and_then(|v| v.to_str().ok())
1063 .and_then(|s| s.parse::<u64>().ok());
1064 let text = response.text().await.unwrap_or_default();
1065
1066 if status.is_success() {
1067 return Ok(text);
1068 }
1069
1070 let status_code = status.as_u16();
1071
1072 if status_code == 429 {
1073 return Err(ClientError::RateLimitExceeded { retry_after });
1074 }
1075
1076 #[derive(Deserialize)]
1077 struct ErrorBody {
1078 error: Option<String>,
1079 code: Option<ServerErrorCode>,
1080 }
1081
1082 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1083 (body.error.unwrap_or_else(|| text.clone()), body.code)
1084 } else {
1085 (text, None)
1086 };
1087
1088 match status_code {
1089 401 => Err(ClientError::Server {
1090 status: 401,
1091 message,
1092 code,
1093 }),
1094 403 => Err(ClientError::Authorization {
1095 status: 403,
1096 message,
1097 code,
1098 }),
1099 _ => Err(ClientError::Server {
1100 status: status_code,
1101 message,
1102 code,
1103 }),
1104 }
1105 }
1106
1107 #[allow(dead_code)]
1115 pub(crate) async fn execute_with_retry<F, Fut, T>(&self, f: F) -> Result<T>
1116 where
1117 F: Fn() -> Fut,
1118 Fut: std::future::Future<Output = Result<T>>,
1119 {
1120 let rc = &self.retry_config;
1121
1122 for attempt in 0..rc.max_retries {
1123 match f().await {
1124 Ok(v) => return Ok(v),
1125 Err(e) => {
1126 let is_last = attempt == rc.max_retries - 1;
1127 if is_last || !e.is_retryable() {
1128 return Err(e);
1129 }
1130
1131 let wait = match &e {
1132 ClientError::RateLimitExceeded {
1133 retry_after: Some(secs),
1134 } => Duration::from_secs(*secs),
1135 _ => {
1136 let base_ms = rc.base_delay.as_millis() as f64;
1137 let backoff_ms = base_ms * 2f64.powi(attempt as i32);
1138 let capped_ms = backoff_ms.min(rc.max_delay.as_millis() as f64);
1139 let final_ms = if rc.jitter {
1140 let seed = (attempt as u64).wrapping_mul(6364136223846793005);
1142 let factor = 0.5 + (seed % 1000) as f64 / 1000.0;
1143 capped_ms * factor
1144 } else {
1145 capped_ms
1146 };
1147 Duration::from_millis(final_ms as u64)
1148 }
1149 };
1150
1151 tokio::time::sleep(wait).await;
1152 }
1153 }
1154 }
1155
1156 Err(ClientError::Config("retry loop exhausted".to_string()))
1158 }
1159}
1160
1161impl DakeraClient {
1166 pub async fn ode_extract_entities(
1178 &self,
1179 req: ExtractEntitiesRequest,
1180 ) -> Result<ExtractEntitiesResponse> {
1181 let ode_url = self.ode_url.as_deref().ok_or_else(|| {
1182 ClientError::Config(
1183 "ode_url must be configured to use extract_entities(). \
1184 Call .ode_url(\"http://localhost:8080\") on the builder."
1185 .to_string(),
1186 )
1187 })?;
1188 let url = format!("{}/ode/extract", ode_url);
1189 let response = self.client.post(&url).json(&req).send().await?;
1190 if response.status().is_success() {
1191 Ok(response.json::<ExtractEntitiesResponse>().await?)
1192 } else {
1193 let status = response.status().as_u16();
1194 let body = response.text().await.unwrap_or_default();
1195 Err(ClientError::Server {
1196 status,
1197 message: format!("ODE sidecar error: {}", body),
1198 code: None,
1199 })
1200 }
1201 }
1202
1203 #[instrument(skip(self))]
1215 pub async fn get_memory_policy(&self, namespace: &str) -> Result<MemoryPolicy> {
1216 let url = format!(
1217 "{}/v1/namespaces/{}/memory_policy",
1218 self.base_url,
1219 urlencoding::encode(namespace)
1220 );
1221 let response = self.client.get(&url).send().await?;
1222 self.handle_response(response).await
1223 }
1224
1225 #[instrument(skip(self, policy))]
1232 pub async fn set_memory_policy(
1233 &self,
1234 namespace: &str,
1235 policy: MemoryPolicy,
1236 ) -> Result<MemoryPolicy> {
1237 let url = format!(
1238 "{}/v1/namespaces/{}/memory_policy",
1239 self.base_url,
1240 urlencoding::encode(namespace)
1241 );
1242 let response = self.client.put(&url).json(&policy).send().await?;
1243 self.handle_response(response).await
1244 }
1245}
1246
/// Builder that collects options and produces a `DakeraClient` via `build`.
#[derive(Debug)]
pub struct DakeraClientBuilder {
    /// Server base URL as supplied; `build` strips trailing `/` and validates the scheme.
    base_url: String,
    /// Optional ODE base URL, stored with trailing `/` removed.
    ode_url: Option<String>,
    /// Overall request timeout passed to the `reqwest` client.
    timeout: Duration,
    /// Connection timeout; `build` falls back to `timeout` when unset.
    connect_timeout: Option<Duration>,
    /// Retry settings copied into the built client.
    retry_config: RetryConfig,
    /// Custom `User-Agent`; `build` uses `dakera-client/<crate version>` when unset.
    user_agent: Option<String>,
}
1257
1258impl DakeraClientBuilder {
1259 pub fn new(base_url: impl Into<String>) -> Self {
1261 Self {
1262 base_url: base_url.into(),
1263 ode_url: None,
1264 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
1265 connect_timeout: None,
1266 retry_config: RetryConfig::default(),
1267 user_agent: None,
1268 }
1269 }
1270
1271 pub fn ode_url(mut self, ode_url: impl Into<String>) -> Self {
1275 self.ode_url = Some(ode_url.into().trim_end_matches('/').to_string());
1276 self
1277 }
1278
1279 pub fn timeout(mut self, timeout: Duration) -> Self {
1281 self.timeout = timeout;
1282 self
1283 }
1284
1285 pub fn timeout_secs(mut self, secs: u64) -> Self {
1287 self.timeout = Duration::from_secs(secs);
1288 self
1289 }
1290
1291 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
1293 self.connect_timeout = Some(timeout);
1294 self
1295 }
1296
1297 pub fn retry_config(mut self, config: RetryConfig) -> Self {
1299 self.retry_config = config;
1300 self
1301 }
1302
1303 pub fn max_retries(mut self, max_retries: u32) -> Self {
1305 self.retry_config.max_retries = max_retries;
1306 self
1307 }
1308
1309 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
1311 self.user_agent = Some(user_agent.into());
1312 self
1313 }
1314
1315 pub fn build(self) -> Result<DakeraClient> {
1317 let base_url = self.base_url.trim_end_matches('/').to_string();
1319
1320 if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
1322 return Err(ClientError::InvalidUrl(
1323 "URL must start with http:// or https://".to_string(),
1324 ));
1325 }
1326
1327 let user_agent = self
1328 .user_agent
1329 .unwrap_or_else(|| format!("dakera-client/{}", env!("CARGO_PKG_VERSION")));
1330
1331 let connect_timeout = self.connect_timeout.unwrap_or(self.timeout);
1332
1333 let client = Client::builder()
1334 .timeout(self.timeout)
1335 .connect_timeout(connect_timeout)
1336 .user_agent(user_agent)
1337 .build()
1338 .map_err(|e| ClientError::Config(e.to_string()))?;
1339
1340 Ok(DakeraClient {
1341 client,
1342 base_url,
1343 ode_url: self.ode_url,
1344 retry_config: self.retry_config,
1345 last_rate_limit: Arc::new(Mutex::new(None)),
1346 })
1347 }
1348}
1349
1350impl DakeraClient {
1355 pub async fn stream_namespace_events(
1380 &self,
1381 namespace: &str,
1382 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1383 let url = format!(
1384 "{}/v1/namespaces/{}/events",
1385 self.base_url,
1386 urlencoding::encode(namespace)
1387 );
1388 self.stream_sse(url).await
1389 }
1390
1391 pub async fn stream_global_events(
1398 &self,
1399 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1400 let url = format!("{}/ops/events", self.base_url);
1401 self.stream_sse(url).await
1402 }
1403
1404 pub async fn stream_memory_events(
1413 &self,
1414 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::MemoryEvent>>> {
1415 let url = format!("{}/v1/events/stream", self.base_url);
1416 self.stream_sse(url).await
1417 }
1418
1419 pub(crate) async fn stream_sse<T>(
1421 &self,
1422 url: String,
1423 ) -> Result<tokio::sync::mpsc::Receiver<Result<T>>>
1424 where
1425 T: serde::de::DeserializeOwned + Send + 'static,
1426 {
1427 use futures_util::StreamExt;
1428
1429 let response = self
1430 .client
1431 .get(&url)
1432 .header("Accept", "text/event-stream")
1433 .header("Cache-Control", "no-cache")
1434 .send()
1435 .await?;
1436
1437 if !response.status().is_success() {
1438 let status = response.status().as_u16();
1439 let body = response.text().await.unwrap_or_default();
1440 return Err(ClientError::Server {
1441 status,
1442 message: body,
1443 code: None,
1444 });
1445 }
1446
1447 let (tx, rx) = tokio::sync::mpsc::channel(64);
1448
1449 tokio::spawn(async move {
1450 let mut byte_stream = response.bytes_stream();
1451 let mut remaining = String::new();
1452 let mut data_lines: Vec<String> = Vec::new();
1453
1454 while let Some(chunk) = byte_stream.next().await {
1455 match chunk {
1456 Ok(bytes) => {
1457 remaining.push_str(&String::from_utf8_lossy(&bytes));
1458 while let Some(pos) = remaining.find('\n') {
1459 let raw = &remaining[..pos];
1460 let line = raw.trim_end_matches('\r').to_string();
1461 remaining = remaining[pos + 1..].to_string();
1462
1463 if line.starts_with(':') {
1464 } else if let Some(data) = line.strip_prefix("data:") {
1466 data_lines.push(data.trim_start().to_string());
1467 } else if line.is_empty() {
1468 if !data_lines.is_empty() {
1469 let payload = data_lines.join("\n");
1470 data_lines.clear();
1471 let result = serde_json::from_str::<T>(&payload)
1472 .map_err(ClientError::Json);
1473 if tx.send(result).await.is_err() {
1474 return; }
1476 }
1477 } else {
1478 }
1480 }
1481 }
1482 Err(e) => {
1483 let _ = tx.send(Err(ClientError::Http(e))).await;
1484 return;
1485 }
1486 }
1487 }
1488 });
1489
1490 Ok(rx)
1491 }
1492}
1493
#[cfg(test)]
mod tests {
    // Unit tests for client construction, request builders, (de)serialization
    // of API types, retry behaviour, and the expected URL layout.
    use super::*;

    // ---------------------------------------------------------------------
    // Client construction
    // ---------------------------------------------------------------------

    /// `DakeraClient::new` succeeds for a plain HTTP base URL.
    #[test]
    fn test_client_builder() {
        let client = DakeraClient::new("http://localhost:3000");
        assert!(client.is_ok());
    }

    /// The builder accepts a request timeout and a custom user agent.
    #[test]
    fn test_client_builder_with_options() {
        let client = DakeraClient::builder("http://localhost:3000")
            .timeout_secs(60)
            .user_agent("test-client/1.0")
            .build();
        assert!(client.is_ok());
    }

    /// `"invalid-url"` is rejected at construction time.
    #[test]
    fn test_client_builder_invalid_url() {
        let client = DakeraClient::new("invalid-url");
        assert!(client.is_err());
    }

    /// A trailing slash on the base URL is not kept in the stored `base_url`.
    #[test]
    fn test_client_builder_trailing_slash() {
        let client = DakeraClient::new("http://localhost:3000/").unwrap();
        assert!(!client.base_url.ends_with('/'));
    }

    // ---------------------------------------------------------------------
    // Request builders
    // ---------------------------------------------------------------------

    /// `Vector::new` stores the id and values and leaves metadata unset.
    #[test]
    fn test_vector_creation() {
        let v = Vector::new("test", vec![0.1, 0.2, 0.3]);
        assert_eq!(v.id, "test");
        assert_eq!(v.values.len(), 3);
        assert!(v.metadata.is_none());
    }

    /// `QueryRequest` builder methods set the filter and metadata flag.
    #[test]
    fn test_query_request_builder() {
        let req = QueryRequest::new(vec![0.1, 0.2], 10)
            .with_filter(serde_json::json!({"category": "test"}))
            .include_metadata(false);

        assert_eq!(req.top_k, 10);
        assert!(req.filter.is_some());
        assert!(!req.include_metadata);
    }

    /// An in-range vector weight is stored unchanged on a hybrid request.
    #[test]
    fn test_hybrid_search_request() {
        let req = HybridSearchRequest::new(vec![0.1], "test query", 5).with_vector_weight(0.7);

        assert_eq!(req.vector_weight, 0.7);
        assert_eq!(req.text, "test query");
        assert!(req.vector.is_some());
    }

    /// A vector weight of 1.5 is clamped down to 1.0.
    #[test]
    fn test_hybrid_search_weight_clamping() {
        let req = HybridSearchRequest::new(vec![0.1], "test", 5).with_vector_weight(1.5);
        assert_eq!(req.vector_weight, 1.0);
    }

    /// A text-only hybrid request has no vector, and the serialized JSON
    /// omits the `vector` key entirely.
    #[test]
    fn test_hybrid_search_text_only() {
        let req = HybridSearchRequest::text_only("bm25 query", 10);

        assert!(req.vector.is_none());
        assert_eq!(req.text, "bm25 query");
        assert_eq!(req.top_k, 10);
        let json = serde_json::to_value(&req).unwrap();
        assert!(json.get("vector").is_none());
    }

    /// `TextDocument::with_ttl` records the TTL; metadata stays unset.
    #[test]
    fn test_text_document_builder() {
        let doc = TextDocument::new("doc1", "Hello world").with_ttl(3600);

        assert_eq!(doc.id, "doc1");
        assert_eq!(doc.text, "Hello world");
        assert_eq!(doc.ttl_seconds, Some(3600));
        assert!(doc.metadata.is_none());
    }

    /// `UpsertTextRequest` keeps all documents and the chosen model.
    #[test]
    fn test_upsert_text_request_builder() {
        let docs = vec![
            TextDocument::new("doc1", "Hello"),
            TextDocument::new("doc2", "World"),
        ];
        let req = UpsertTextRequest::new(docs).with_model(EmbeddingModel::BgeSmall);

        assert_eq!(req.documents.len(), 2);
        assert_eq!(req.model, Some(EmbeddingModel::BgeSmall));
    }

    /// `QueryTextRequest` builder methods set filter, vector flag and model.
    #[test]
    fn test_query_text_request_builder() {
        let req = QueryTextRequest::new("semantic search query", 5)
            .with_filter(serde_json::json!({"category": "docs"}))
            .include_vectors(true)
            .with_model(EmbeddingModel::E5Small);

        assert_eq!(req.text, "semantic search query");
        assert_eq!(req.top_k, 5);
        assert!(req.filter.is_some());
        assert!(req.include_vectors);
        assert_eq!(req.model, Some(EmbeddingModel::E5Small));
    }

    /// A fresh `FetchRequest` includes both values and metadata.
    #[test]
    fn test_fetch_request_builder() {
        let req = FetchRequest::new(vec!["id1".to_string(), "id2".to_string()]);

        assert_eq!(req.ids.len(), 2);
        assert!(req.include_values);
        assert!(req.include_metadata);
    }

    /// `CreateNamespaceRequest` records dimensions and index type.
    #[test]
    fn test_create_namespace_request_builder() {
        let req = CreateNamespaceRequest::new()
            .with_dimensions(384)
            .with_index_type("hnsw");

        assert_eq!(req.dimensions, Some(384));
        assert_eq!(req.index_type.as_deref(), Some("hnsw"));
    }

    /// A fresh `BatchQueryTextRequest` excludes vectors and has no model.
    #[test]
    fn test_batch_query_text_request() {
        let req =
            BatchQueryTextRequest::new(vec!["query one".to_string(), "query two".to_string()], 10);

        assert_eq!(req.queries.len(), 2);
        assert_eq!(req.top_k, 10);
        assert!(!req.include_vectors);
        assert!(req.model.is_none());
    }

    // ---------------------------------------------------------------------
    // Retry configuration and `execute_with_retry`
    // ---------------------------------------------------------------------

    /// `RetryConfig::default()` is 3 retries, 100 ms base delay, 60 s cap,
    /// jitter enabled.
    #[test]
    fn test_retry_config_defaults() {
        let rc = RetryConfig::default();
        assert_eq!(rc.max_retries, 3);
        assert_eq!(rc.base_delay, Duration::from_millis(100));
        assert_eq!(rc.max_delay, Duration::from_secs(60));
        assert!(rc.jitter);
    }

    /// Smoke test: building with a connect timeout plus a request timeout
    /// succeeds. The assertion only confirms a usable client was produced.
    #[test]
    fn test_builder_connect_timeout() {
        let client = DakeraClient::builder("http://localhost:3000")
            .connect_timeout(Duration::from_secs(5))
            .timeout_secs(30)
            .build()
            .unwrap();
        assert!(client.base_url.starts_with("http"));
    }

    /// `max_retries` on the builder ends up in the client's retry config.
    #[test]
    fn test_builder_max_retries() {
        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(5)
            .build()
            .unwrap();
        assert_eq!(client.retry_config.max_retries, 5);
    }

    /// A complete `RetryConfig` passed to the builder is stored on the client.
    #[test]
    fn test_builder_retry_config() {
        let rc = RetryConfig {
            max_retries: 7,
            base_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(30),
            jitter: false,
        };
        let client = DakeraClient::builder("http://localhost:3000")
            .retry_config(rc)
            .build()
            .unwrap();
        assert_eq!(client.retry_config.max_retries, 7);
        assert!(!client.retry_config.jitter);
    }

    /// A rate-limit error without `retry_after` is retryable.
    #[test]
    fn test_rate_limit_error_retryable() {
        let e = ClientError::RateLimitExceeded { retry_after: None };
        assert!(e.is_retryable());
    }

    /// A rate-limit error with `retry_after: Some(0)` is retryable and keeps
    /// the zero value (compared as `u64`).
    #[test]
    fn test_rate_limit_error_with_retry_after_zero() {
        let e = ClientError::RateLimitExceeded {
            retry_after: Some(0),
        };
        assert!(e.is_retryable());
        if let ClientError::RateLimitExceeded {
            retry_after: Some(secs),
        } = &e
        {
            assert_eq!(*secs, 0u64);
        } else {
            panic!("unexpected variant");
        }
    }

    /// When the operation succeeds on the first attempt it is invoked
    /// exactly once and its value is returned.
    #[tokio::test]
    async fn test_execute_with_retry_succeeds_immediately() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(3)
            .build()
            .unwrap();

        // Counts how many times the closure is actually run.
        let call_count = std::sync::Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = client
            .execute_with_retry(|| {
                let cc = cc.clone();
                async move {
                    cc.fetch_add(1, Ordering::SeqCst);
                    Ok::<u32, ClientError>(42)
                }
            })
            .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// A `Server` error with status 400 is returned after a single attempt.
    #[tokio::test]
    async fn test_execute_with_retry_no_retry_on_4xx() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let client = DakeraClient::builder("http://localhost:3000")
            .max_retries(3)
            .build()
            .unwrap();

        // Counts how many times the closure is actually run.
        let call_count = std::sync::Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = client
            .execute_with_retry(|| {
                let cc = cc.clone();
                async move {
                    cc.fetch_add(1, Ordering::SeqCst);
                    Err::<u32, ClientError>(ClientError::Server {
                        status: 400,
                        message: "bad request".to_string(),
                        code: None,
                    })
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// Two `Server` errors with status 503 followed by a success: the
    /// operation runs three times and the final value is returned.
    #[tokio::test]
    async fn test_execute_with_retry_retries_on_5xx() {
        use std::sync::atomic::{AtomicU32, Ordering};

        // Zero delays and no jitter keep the test fast.
        let client = DakeraClient::builder("http://localhost:3000")
            .retry_config(RetryConfig {
                max_retries: 3,
                base_delay: Duration::from_millis(0),
                max_delay: Duration::from_millis(0),
                jitter: false,
            })
            .build()
            .unwrap();

        let call_count = std::sync::Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = client
            .execute_with_retry(|| {
                let cc = cc.clone();
                async move {
                    // `fetch_add` returns the previous count, so attempts
                    // 0 and 1 fail and attempt 2 succeeds.
                    let n = cc.fetch_add(1, Ordering::SeqCst);
                    if n < 2 {
                        Err::<u32, ClientError>(ClientError::Server {
                            status: 503,
                            message: "unavailable".to_string(),
                            code: None,
                        })
                    } else {
                        Ok(99)
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), 99);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // ---------------------------------------------------------------------
    // Batch recall / forget
    // ---------------------------------------------------------------------

    /// `BatchRecallRequest::new` stores the agent id with a limit of 100.
    #[test]
    fn test_batch_recall_request_new() {
        use crate::memory::BatchRecallRequest;
        let req = BatchRecallRequest::new("agent-1");
        assert_eq!(req.agent_id, "agent-1");
        assert_eq!(req.limit, 100);
    }

    /// Filter and limit set through the builder are stored on the request.
    #[test]
    fn test_batch_recall_request_builder() {
        use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
        let filter = BatchMemoryFilter::default()
            .with_tags(vec!["qa".to_string()])
            .with_min_importance(0.7);
        let req = BatchRecallRequest::new("agent-1")
            .with_filter(filter)
            .with_limit(50);
        assert_eq!(req.agent_id, "agent-1");
        assert_eq!(req.limit, 50);
        assert_eq!(
            req.filter.tags.as_deref(),
            Some(["qa".to_string()].as_slice())
        );
        assert_eq!(req.filter.min_importance, Some(0.7));
    }

    /// A batch recall request serializes `agent_id`, `limit` and the nested
    /// `filter.min_importance`.
    #[test]
    fn test_batch_recall_request_serialization() {
        use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
        let filter = BatchMemoryFilter::default().with_min_importance(0.5);
        let req = BatchRecallRequest::new("agent-1")
            .with_filter(filter)
            .with_limit(25);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["agent_id"], "agent-1");
        assert_eq!(json["limit"], 25);
        assert_eq!(json["filter"]["min_importance"], 0.5);
    }

    /// `BatchForgetRequest::new` stores the agent id and the given filter.
    #[test]
    fn test_batch_forget_request_new() {
        use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
        let filter = BatchMemoryFilter::default().with_min_importance(0.1);
        let req = BatchForgetRequest::new("agent-1", filter);
        assert_eq!(req.agent_id, "agent-1");
        assert_eq!(req.filter.min_importance, Some(0.1));
    }

    /// A batch forget request serializes `agent_id` and
    /// `filter.created_before`.
    #[test]
    fn test_batch_forget_request_serialization() {
        use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
        let filter = BatchMemoryFilter {
            created_before: Some(1_700_000_000),
            ..Default::default()
        };
        let req = BatchForgetRequest::new("agent-1", filter);
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["agent_id"], "agent-1");
        assert_eq!(json["filter"]["created_before"], 1_700_000_000u64);
    }

    /// `BatchRecallResponse` deserializes `memories`, `total` and `filtered`.
    #[test]
    fn test_batch_recall_response_deserialization() {
        use crate::memory::BatchRecallResponse;
        let json = serde_json::json!({
            "memories": [],
            "total": 42,
            "filtered": 7
        });
        let resp: BatchRecallResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.total, 42);
        assert_eq!(resp.filtered, 7);
        assert!(resp.memories.is_empty());
    }

    /// `BatchForgetResponse` deserializes `deleted_count`.
    #[test]
    fn test_batch_forget_response_deserialization() {
        use crate::memory::BatchForgetResponse;
        let json = serde_json::json!({ "deleted_count": 13 });
        let resp: BatchForgetResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.deleted_count, 13);
    }

    // ---------------------------------------------------------------------
    // Rate-limit headers
    // ---------------------------------------------------------------------

    /// Every `RateLimitHeaders` field can be `None`.
    #[test]
    fn test_rate_limit_headers_default_all_none() {
        use crate::types::RateLimitHeaders;
        let rl = RateLimitHeaders {
            limit: None,
            remaining: None,
            reset: None,
            quota_used: None,
            quota_limit: None,
        };
        assert!(rl.limit.is_none());
        assert!(rl.remaining.is_none());
        assert!(rl.reset.is_none());
        assert!(rl.quota_used.is_none());
        assert!(rl.quota_limit.is_none());
    }

    /// Populated `RateLimitHeaders` fields read back the values they were
    /// constructed with.
    #[test]
    fn test_rate_limit_headers_populated() {
        use crate::types::RateLimitHeaders;
        let rl = RateLimitHeaders {
            limit: Some(1000),
            remaining: Some(750),
            reset: Some(1_700_000_060),
            quota_used: Some(500),
            quota_limit: Some(10_000),
        };
        assert_eq!(rl.limit, Some(1000));
        assert_eq!(rl.remaining, Some(750));
        assert_eq!(rl.reset, Some(1_700_000_060));
        assert_eq!(rl.quota_used, Some(500));
        assert_eq!(rl.quota_limit, Some(10_000));
    }

    /// A freshly built client reports no rate-limit headers.
    #[test]
    fn test_last_rate_limit_headers_initially_none() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        assert!(client.last_rate_limit_headers().is_none());
    }

    // ---------------------------------------------------------------------
    // Namespace NER configuration and entity extraction types
    // ---------------------------------------------------------------------

    /// The default NER config has extraction off and no entity types.
    #[test]
    fn test_namespace_ner_config_default() {
        use crate::types::NamespaceNerConfig;
        let cfg = NamespaceNerConfig::default();
        assert!(!cfg.extract_entities);
        assert!(cfg.entity_types.is_none());
    }

    /// With `entity_types: None` the serialized JSON has no `entity_types`
    /// key.
    #[test]
    fn test_namespace_ner_config_serialization_skip_none() {
        use crate::types::NamespaceNerConfig;
        let cfg = NamespaceNerConfig {
            extract_entities: true,
            entity_types: None,
        };
        let json = serde_json::to_value(&cfg).unwrap();
        assert_eq!(json["extract_entities"], true);
        assert!(json.get("entity_types").is_none());
    }

    /// Entity types are serialized as an array in their original order.
    #[test]
    fn test_namespace_ner_config_serialization_with_types() {
        use crate::types::NamespaceNerConfig;
        let cfg = NamespaceNerConfig {
            extract_entities: true,
            entity_types: Some(vec!["PERSON".to_string(), "ORG".to_string()]),
        };
        let json = serde_json::to_value(&cfg).unwrap();
        assert_eq!(json["extract_entities"], true);
        assert_eq!(json["entity_types"][0], "PERSON");
        assert_eq!(json["entity_types"][1], "ORG");
    }

    /// `ExtractedEntity` deserializes `entity_type`, `value` and `score`.
    #[test]
    fn test_extracted_entity_deserialization() {
        use crate::types::ExtractedEntity;
        let json = serde_json::json!({
            "entity_type": "PERSON",
            "value": "Alice",
            "score": 0.95
        });
        let entity: ExtractedEntity = serde_json::from_value(json).unwrap();
        assert_eq!(entity.entity_type, "PERSON");
        assert_eq!(entity.value, "Alice");
        // Float field: compare with a tolerance rather than `==`.
        assert!((entity.score - 0.95).abs() < f64::EPSILON);
    }

    /// `EntityExtractionResponse` deserializes a list of entities in order.
    #[test]
    fn test_entity_extraction_response_deserialization() {
        use crate::types::EntityExtractionResponse;
        let json = serde_json::json!({
            "entities": [
                { "entity_type": "PERSON", "value": "Bob", "score": 0.9 },
                { "entity_type": "ORG", "value": "Acme", "score": 0.87 }
            ]
        });
        let resp: EntityExtractionResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.entities.len(), 2);
        assert_eq!(resp.entities[0].entity_type, "PERSON");
        assert_eq!(resp.entities[1].value, "Acme");
    }

    /// `MemoryEntitiesResponse` deserializes the memory id and its entities.
    #[test]
    fn test_memory_entities_response_deserialization() {
        use crate::types::MemoryEntitiesResponse;
        let json = serde_json::json!({
            "memory_id": "mem-abc-123",
            "entities": [
                { "entity_type": "LOC", "value": "London", "score": 0.88 }
            ]
        });
        let resp: MemoryEntitiesResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.memory_id, "mem-abc-123");
        assert_eq!(resp.entities.len(), 1);
        assert_eq!(resp.entities[0].entity_type, "LOC");
        assert_eq!(resp.entities[0].value, "London");
    }

    /// Pins the expected namespace-config path layout. The URL is built
    /// locally from `base_url`; no client method is called.
    #[test]
    fn test_configure_namespace_ner_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let expected = "http://localhost:3000/v1/namespaces/my-ns/config";
        let actual = format!("{}/v1/namespaces/{}/config", client.base_url, "my-ns");
        assert_eq!(actual, expected);
    }

    /// Pins the expected entity-extraction path layout. The URL is built
    /// locally from `base_url`; no client method is called.
    #[test]
    fn test_extract_entities_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let expected = "http://localhost:3000/v1/memories/extract";
        let actual = format!("{}/v1/memories/extract", client.base_url);
        assert_eq!(actual, expected);
    }

    /// Pins the expected memory-entities path layout. The URL is built
    /// locally from `base_url`; no client method is called.
    #[test]
    fn test_memory_entities_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let memory_id = "mem-xyz-789";
        let expected = "http://localhost:3000/v1/memory/entities/mem-xyz-789";
        let actual = format!("{}/v1/memory/entities/{}", client.base_url, memory_id);
        assert_eq!(actual, expected);
    }

    // ---------------------------------------------------------------------
    // Feedback
    // ---------------------------------------------------------------------

    /// `Upvote`, `Downvote` and `Flag` serialize to lowercase strings.
    #[test]
    fn test_feedback_signal_serialization() {
        use crate::types::FeedbackSignal;
        let upvote = serde_json::to_value(FeedbackSignal::Upvote).unwrap();
        assert_eq!(upvote, serde_json::json!("upvote"));
        let downvote = serde_json::to_value(FeedbackSignal::Downvote).unwrap();
        assert_eq!(downvote, serde_json::json!("downvote"));
        let flag = serde_json::to_value(FeedbackSignal::Flag).unwrap();
        assert_eq!(flag, serde_json::json!("flag"));
    }

    /// `"upvote"` and `"positive"` deserialize to `Upvote` and `Positive`.
    #[test]
    fn test_feedback_signal_deserialization() {
        use crate::types::FeedbackSignal;
        let signal: FeedbackSignal = serde_json::from_str("\"upvote\"").unwrap();
        assert_eq!(signal, FeedbackSignal::Upvote);
        let signal: FeedbackSignal = serde_json::from_str("\"positive\"").unwrap();
        assert_eq!(signal, FeedbackSignal::Positive);
    }

    /// `FeedbackResponse` deserializes memory id, new importance and signal.
    #[test]
    fn test_feedback_response_deserialization() {
        use crate::types::{FeedbackResponse, FeedbackSignal};
        let json = serde_json::json!({
            "memory_id": "mem-abc",
            "new_importance": 0.92,
            "signal": "upvote"
        });
        let resp: FeedbackResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.memory_id, "mem-abc");
        // Float field: compare with a tolerance rather than `==`.
        assert!((resp.new_importance - 0.92).abs() < f32::EPSILON);
        assert_eq!(resp.signal, FeedbackSignal::Upvote);
    }

    /// `FeedbackHistoryResponse` deserializes its entries in order.
    #[test]
    fn test_feedback_history_response_deserialization() {
        use crate::types::{FeedbackHistoryResponse, FeedbackSignal};
        let json = serde_json::json!({
            "memory_id": "mem-abc",
            "entries": [
                {"signal": "upvote", "timestamp": 1774000000_u64, "old_importance": 0.5, "new_importance": 0.575},
                {"signal": "downvote", "timestamp": 1774001000_u64, "old_importance": 0.575, "new_importance": 0.489}
            ]
        });
        let resp: FeedbackHistoryResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.memory_id, "mem-abc");
        assert_eq!(resp.entries.len(), 2);
        assert_eq!(resp.entries[0].signal, FeedbackSignal::Upvote);
        assert_eq!(resp.entries[1].signal, FeedbackSignal::Downvote);
    }

    /// `AgentFeedbackSummary` deserializes its counters and health score.
    #[test]
    fn test_agent_feedback_summary_deserialization() {
        use crate::types::AgentFeedbackSummary;
        let json = serde_json::json!({
            "agent_id": "agent-1",
            "upvotes": 42_u64,
            "downvotes": 7_u64,
            "flags": 2_u64,
            "total_feedback": 51_u64,
            "health_score": 0.78
        });
        let summary: AgentFeedbackSummary = serde_json::from_value(json).unwrap();
        assert_eq!(summary.agent_id, "agent-1");
        assert_eq!(summary.upvotes, 42);
        assert_eq!(summary.total_feedback, 51);
        // Float field: compare with a tolerance rather than `==`.
        assert!((summary.health_score - 0.78).abs() < f32::EPSILON);
    }

    /// `FeedbackHealthResponse` deserializes agent id, health score and
    /// memory count.
    #[test]
    fn test_feedback_health_response_deserialization() {
        use crate::types::FeedbackHealthResponse;
        let json = serde_json::json!({
            "agent_id": "agent-1",
            "health_score": 0.78,
            "memory_count": 120_usize,
            "avg_importance": 0.72
        });
        let health: FeedbackHealthResponse = serde_json::from_value(json).unwrap();
        assert_eq!(health.agent_id, "agent-1");
        // Float field: compare with a tolerance rather than `==`.
        assert!((health.health_score - 0.78).abs() < f32::EPSILON);
        assert_eq!(health.memory_count, 120);
    }

    /// `MemoryFeedbackBody` serializes `agent_id` and a lowercase `signal`.
    #[test]
    fn test_memory_feedback_body_serialization() {
        use crate::types::{FeedbackSignal, MemoryFeedbackBody};
        let body = MemoryFeedbackBody {
            agent_id: "agent-1".to_string(),
            signal: FeedbackSignal::Flag,
        };
        let json = serde_json::to_value(body).unwrap();
        assert_eq!(json["agent_id"], "agent-1");
        assert_eq!(json["signal"], "flag");
    }

    /// Pins the expected feedback and importance path layouts. The URLs are
    /// built locally from `base_url`; no client method is called.
    #[test]
    fn test_feedback_memory_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let memory_id = "mem-abc";
        let expected_post = "http://localhost:3000/v1/memories/mem-abc/feedback";
        let actual_post = format!("{}/v1/memories/{}/feedback", client.base_url, memory_id);
        assert_eq!(actual_post, expected_post);

        let expected_patch = "http://localhost:3000/v1/memories/mem-abc/importance";
        let actual_patch = format!("{}/v1/memories/{}/importance", client.base_url, memory_id);
        assert_eq!(actual_patch, expected_patch);
    }

    /// Pins the expected feedback-health path and query string layout. The
    /// URL is built locally from `base_url`; no client method is called.
    #[test]
    fn test_feedback_health_url_pattern() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let agent_id = "agent-1";
        let expected = "http://localhost:3000/v1/feedback/health?agent_id=agent-1";
        let actual = format!(
            "{}/v1/feedback/health?agent_id={}",
            client.base_url, agent_id
        );
        assert_eq!(actual, expected);
    }

    // ---------------------------------------------------------------------
    // ODE entity extraction
    // ---------------------------------------------------------------------

    /// On a client built without an ODE URL, `ode_extract_entities` returns
    /// `ClientError::Config`.
    #[tokio::test]
    async fn test_ode_extract_entities_requires_ode_url() {
        let client = DakeraClient::new("http://localhost:3000").unwrap();
        let result = client
            .ode_extract_entities(ExtractEntitiesRequest {
                content: "Alice lives in Paris.".to_string(),
                agent_id: "agent-1".to_string(),
                memory_id: None,
                entity_types: None,
            })
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, ClientError::Config(_)));
    }

    /// The builder stores the ODE URL. The extract path is then built
    /// locally from it to pin the expected layout; no client method is
    /// called.
    #[test]
    fn test_ode_extract_entities_url_built_from_ode_url() {
        let client = DakeraClient::builder("http://localhost:3000")
            .ode_url("http://localhost:8080")
            .build()
            .unwrap();
        assert_eq!(client.ode_url.as_deref(), Some("http://localhost:8080"));
        let expected = "http://localhost:8080/ode/extract";
        let actual = format!("{}/ode/extract", client.ode_url.as_deref().unwrap());
        assert_eq!(actual, expected);
    }

    /// A fully populated `ExtractEntitiesRequest` serializes every field,
    /// including both requested entity types.
    #[test]
    fn test_extract_entities_request_serialization() {
        let req = ExtractEntitiesRequest {
            content: "Alice in Wonderland".to_string(),
            agent_id: "agent-42".to_string(),
            memory_id: Some("mem-001".to_string()),
            entity_types: Some(vec!["person".to_string(), "location".to_string()]),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"content\":\"Alice in Wonderland\""));
        assert!(json.contains("\"agent_id\":\"agent-42\""));
        assert!(json.contains("\"memory_id\":\"mem-001\""));
        assert!(json.contains("\"person\""));
        assert!(json.contains("\"location\""));
    }

    /// `None` optional fields are left out of the serialized request.
    #[test]
    fn test_extract_entities_request_omits_none_fields() {
        let req = ExtractEntitiesRequest {
            content: "hello".to_string(),
            agent_id: "a".to_string(),
            memory_id: None,
            entity_types: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("memory_id"));
        assert!(!json.contains("entity_types"));
    }

    /// `OdeEntity` deserializes text, label, span offsets and score.
    #[test]
    fn test_ode_entity_deserialization() {
        let json = r#"{"text":"Alice","label":"person","start":0,"end":5,"score":0.97}"#;
        let entity: OdeEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.text, "Alice");
        assert_eq!(entity.label, "person");
        assert_eq!(entity.start, 0);
        assert_eq!(entity.end, 5);
        // Float field: compare with a tolerance rather than `==`.
        assert!((entity.score - 0.97).abs() < 1e-4);
    }

    /// `ExtractEntitiesResponse` deserializes entities, model name and
    /// processing time.
    #[test]
    fn test_extract_entities_response_deserialization() {
        let json = r#"{
            "entities": [
                {"text":"Alice","label":"person","start":0,"end":5,"score":0.97},
                {"text":"Paris","label":"location","start":16,"end":21,"score":0.92}
            ],
            "model": "gliner-multi-v2.1",
            "processing_time_ms": 34
        }"#;
        let resp: ExtractEntitiesResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.entities.len(), 2);
        assert_eq!(resp.entities[0].text, "Alice");
        assert_eq!(resp.model, "gliner-multi-v2.1");
        assert_eq!(resp.processing_time_ms, 34);
    }
}