1use reqwest::{
4 header::{HeaderMap, HeaderValue, AUTHORIZATION},
5 Client, StatusCode,
6};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use tracing::{debug, instrument};
10
11use serde::Deserialize;
12
13use crate::error::{ClientError, Result, ServerErrorCode};
14use crate::types::*;
15
/// Request timeout, in seconds, that `DakeraClientBuilder::new` installs
/// until the caller overrides it with `timeout` or `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
18
/// HTTP client for a Dakera server.
///
/// The rate-limit snapshot lives behind an `Arc`, so clones of the client
/// share the same `last_rate_limit` value.
#[derive(Debug, Clone)]
pub struct DakeraClient {
    /// Underlying `reqwest` client used to send every request.
    pub(crate) client: Client,
    /// Server base URL; `DakeraClientBuilder::build` strips trailing slashes.
    pub(crate) base_url: String,
    /// Optional base URL used by `ode_extract_entities`.
    pub(crate) ode_url: Option<String>,
    /// Retry settings read by `execute_with_retry`.
    #[allow(dead_code)]
    pub(crate) retry_config: RetryConfig,
    /// Rate-limit headers captured from the most recent response that went
    /// through `handle_response` or `handle_text_response`.
    pub(crate) last_rate_limit: Arc<Mutex<Option<RateLimitHeaders>>>,
}
34
35impl DakeraClient {
36 pub fn new(base_url: impl Into<String>) -> Result<Self> {
46 DakeraClientBuilder::new(base_url).build()
47 }
48
49 pub fn builder(base_url: impl Into<String>) -> DakeraClientBuilder {
51 DakeraClientBuilder::new(base_url)
52 }
53
54 #[instrument(skip(self))]
60 pub async fn health(&self) -> Result<HealthResponse> {
61 let url = format!("{}/health", self.base_url);
62 let response = self.client.get(&url).send().await?;
63
64 if response.status().is_success() {
65 let json: serde_json::Value = response.json().await?;
66 let healthy = json
69 .get("healthy")
70 .and_then(|v| v.as_bool())
71 .unwrap_or_else(|| json.get("status").and_then(|v| v.as_str()) == Some("healthy"));
72 let version = json
73 .get("version")
74 .and_then(|v| v.as_str())
75 .map(String::from);
76 let uptime_seconds = json.get("uptime_seconds").and_then(|v| v.as_u64());
77 Ok(HealthResponse {
78 healthy,
79 version,
80 uptime_seconds,
81 })
82 } else {
83 Ok(HealthResponse {
85 healthy: true,
86 version: None,
87 uptime_seconds: None,
88 })
89 }
90 }
91
92 #[instrument(skip(self))]
94 pub async fn ready(&self) -> Result<ReadinessResponse> {
95 let url = format!("{}/health/ready", self.base_url);
96 let response = self.client.get(&url).send().await?;
97
98 if response.status().is_success() {
99 Ok(response.json().await?)
100 } else {
101 Ok(ReadinessResponse {
102 ready: false,
103 components: None,
104 })
105 }
106 }
107
108 #[instrument(skip(self))]
110 pub async fn live(&self) -> Result<bool> {
111 let url = format!("{}/health/live", self.base_url);
112 let response = self.client.get(&url).send().await?;
113 Ok(response.status().is_success())
114 }
115
116 #[instrument(skip(self))]
122 pub async fn list_namespaces(&self) -> Result<Vec<String>> {
123 let url = format!("{}/v1/namespaces", self.base_url);
124 let response = self.client.get(&url).send().await?;
125 self.handle_response::<ListNamespacesResponse>(response)
126 .await
127 .map(|r| r.namespaces)
128 }
129
130 #[instrument(skip(self))]
132 pub async fn get_namespace(&self, namespace: &str) -> Result<NamespaceInfo> {
133 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
134 let response = self.client.get(&url).send().await?;
135 self.handle_response(response).await
136 }
137
138 #[instrument(skip(self, request))]
140 pub async fn create_namespace(
141 &self,
142 namespace: &str,
143 request: CreateNamespaceRequest,
144 ) -> Result<NamespaceInfo> {
145 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
146 let response = self.client.post(&url).json(&request).send().await?;
147 self.handle_response(response).await
148 }
149
150 #[instrument(skip(self, request), fields(namespace = %namespace))]
156 pub async fn configure_namespace(
157 &self,
158 namespace: &str,
159 request: ConfigureNamespaceRequest,
160 ) -> Result<ConfigureNamespaceResponse> {
161 let url = format!("{}/v1/namespaces/{}", self.base_url, namespace);
162 let response = self.client.put(&url).json(&request).send().await?;
163 self.handle_response(response).await
164 }
165
166 #[instrument(skip(self, request), fields(vector_count = request.vectors.len()))]
172 pub async fn upsert(&self, namespace: &str, request: UpsertRequest) -> Result<UpsertResponse> {
173 let url = format!("{}/v1/namespaces/{}/vectors", self.base_url, namespace);
174 debug!(
175 "Upserting {} vectors to {}",
176 request.vectors.len(),
177 namespace
178 );
179
180 let response = self.client.post(&url).json(&request).send().await?;
181 self.handle_response(response).await
182 }
183
184 #[instrument(skip(self, vector))]
186 pub async fn upsert_one(&self, namespace: &str, vector: Vector) -> Result<UpsertResponse> {
187 self.upsert(namespace, UpsertRequest::single(vector)).await
188 }
189
190 #[instrument(skip(self, request), fields(namespace = %namespace, count = request.ids.len()))]
223 pub async fn upsert_columns(
224 &self,
225 namespace: &str,
226 request: ColumnUpsertRequest,
227 ) -> Result<UpsertResponse> {
228 let url = format!(
229 "{}/v1/namespaces/{}/upsert-columns",
230 self.base_url, namespace
231 );
232 debug!(
233 "Upserting {} vectors in column format to {}",
234 request.ids.len(),
235 namespace
236 );
237
238 let response = self.client.post(&url).json(&request).send().await?;
239 self.handle_response(response).await
240 }
241
242 #[instrument(skip(self, request), fields(top_k = request.top_k))]
244 pub async fn query(&self, namespace: &str, request: QueryRequest) -> Result<QueryResponse> {
245 let url = format!("{}/v1/namespaces/{}/query", self.base_url, namespace);
246 debug!(
247 "Querying namespace {} for top {} results",
248 namespace, request.top_k
249 );
250
251 let response = self.client.post(&url).json(&request).send().await?;
252 self.handle_response(response).await
253 }
254
255 #[instrument(skip(self, vector))]
257 pub async fn query_simple(
258 &self,
259 namespace: &str,
260 vector: Vec<f32>,
261 top_k: u32,
262 ) -> Result<QueryResponse> {
263 self.query(namespace, QueryRequest::new(vector, top_k))
264 .await
265 }
266
267 #[instrument(skip(self, request), fields(namespace = %namespace, query_count = request.queries.len()))]
291 pub async fn batch_query(
292 &self,
293 namespace: &str,
294 request: BatchQueryRequest,
295 ) -> Result<BatchQueryResponse> {
296 let url = format!("{}/v1/namespaces/{}/batch-query", self.base_url, namespace);
297 debug!(
298 "Batch querying namespace {} with {} queries",
299 namespace,
300 request.queries.len()
301 );
302
303 let response = self.client.post(&url).json(&request).send().await?;
304 self.handle_response(response).await
305 }
306
307 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
309 pub async fn delete(&self, namespace: &str, request: DeleteRequest) -> Result<DeleteResponse> {
310 let url = format!(
311 "{}/v1/namespaces/{}/vectors/delete",
312 self.base_url, namespace
313 );
314 debug!("Deleting {} vectors from {}", request.ids.len(), namespace);
315
316 let response = self.client.post(&url).json(&request).send().await?;
317 self.handle_response(response).await
318 }
319
320 #[instrument(skip(self))]
322 pub async fn delete_one(&self, namespace: &str, id: &str) -> Result<DeleteResponse> {
323 self.delete(namespace, DeleteRequest::single(id)).await
324 }
325
326 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
332 pub async fn index_documents(
333 &self,
334 namespace: &str,
335 request: IndexDocumentsRequest,
336 ) -> Result<IndexDocumentsResponse> {
337 let url = format!(
338 "{}/v1/namespaces/{}/fulltext/index",
339 self.base_url, namespace
340 );
341 debug!(
342 "Indexing {} documents in {}",
343 request.documents.len(),
344 namespace
345 );
346
347 let response = self.client.post(&url).json(&request).send().await?;
348 self.handle_response(response).await
349 }
350
351 #[instrument(skip(self, document))]
353 pub async fn index_document(
354 &self,
355 namespace: &str,
356 document: Document,
357 ) -> Result<IndexDocumentsResponse> {
358 self.index_documents(
359 namespace,
360 IndexDocumentsRequest {
361 documents: vec![document],
362 },
363 )
364 .await
365 }
366
367 #[instrument(skip(self, request))]
369 pub async fn fulltext_search(
370 &self,
371 namespace: &str,
372 request: FullTextSearchRequest,
373 ) -> Result<FullTextSearchResponse> {
374 let url = format!(
375 "{}/v1/namespaces/{}/fulltext/search",
376 self.base_url, namespace
377 );
378 debug!("Full-text search in {} for: {}", namespace, request.query);
379
380 let response = self.client.post(&url).json(&request).send().await?;
381 self.handle_response(response).await
382 }
383
384 #[instrument(skip(self))]
386 pub async fn search_text(
387 &self,
388 namespace: &str,
389 query: &str,
390 top_k: u32,
391 ) -> Result<FullTextSearchResponse> {
392 self.fulltext_search(namespace, FullTextSearchRequest::new(query, top_k))
393 .await
394 }
395
396 #[instrument(skip(self))]
398 pub async fn fulltext_stats(&self, namespace: &str) -> Result<FullTextStats> {
399 let url = format!(
400 "{}/v1/namespaces/{}/fulltext/stats",
401 self.base_url, namespace
402 );
403 let response = self.client.get(&url).send().await?;
404 self.handle_response(response).await
405 }
406
407 #[instrument(skip(self, request))]
409 pub async fn fulltext_delete(
410 &self,
411 namespace: &str,
412 request: DeleteRequest,
413 ) -> Result<DeleteResponse> {
414 let url = format!(
415 "{}/v1/namespaces/{}/fulltext/delete",
416 self.base_url, namespace
417 );
418 let response = self.client.post(&url).json(&request).send().await?;
419 self.handle_response(response).await
420 }
421
422 #[instrument(skip(self, request), fields(top_k = request.top_k))]
428 pub async fn hybrid_search(
429 &self,
430 namespace: &str,
431 request: HybridSearchRequest,
432 ) -> Result<HybridSearchResponse> {
433 let url = format!("{}/v1/namespaces/{}/hybrid", self.base_url, namespace);
434 debug!(
435 "Hybrid search in {} with vector_weight={}",
436 namespace, request.vector_weight
437 );
438
439 let response = self.client.post(&url).json(&request).send().await?;
440 self.handle_response(response).await
441 }
442
443 #[instrument(skip(self, request), fields(namespace = %namespace))]
480 pub async fn multi_vector_search(
481 &self,
482 namespace: &str,
483 request: MultiVectorSearchRequest,
484 ) -> Result<MultiVectorSearchResponse> {
485 let url = format!("{}/v1/namespaces/{}/multi-vector", self.base_url, namespace);
486 debug!(
487 "Multi-vector search in {} with {} positive vectors",
488 namespace,
489 request.positive_vectors.len()
490 );
491
492 let response = self.client.post(&url).json(&request).send().await?;
493 self.handle_response(response).await
494 }
495
496 #[instrument(skip(self, request), fields(namespace = %namespace))]
530 pub async fn aggregate(
531 &self,
532 namespace: &str,
533 request: AggregationRequest,
534 ) -> Result<AggregationResponse> {
535 let url = format!("{}/v1/namespaces/{}/aggregate", self.base_url, namespace);
536 debug!(
537 "Aggregating in namespace {} with {} aggregations",
538 namespace,
539 request.aggregate_by.len()
540 );
541
542 let response = self.client.post(&url).json(&request).send().await?;
543 self.handle_response(response).await
544 }
545
546 #[instrument(skip(self, request), fields(namespace = %namespace))]
584 pub async fn unified_query(
585 &self,
586 namespace: &str,
587 request: UnifiedQueryRequest,
588 ) -> Result<UnifiedQueryResponse> {
589 let url = format!(
590 "{}/v1/namespaces/{}/unified-query",
591 self.base_url, namespace
592 );
593 debug!(
594 "Unified query in namespace {} with top_k={}",
595 namespace, request.top_k
596 );
597
598 let response = self.client.post(&url).json(&request).send().await?;
599 self.handle_response(response).await
600 }
601
602 #[instrument(skip(self, vector))]
606 pub async fn unified_vector_search(
607 &self,
608 namespace: &str,
609 vector: Vec<f32>,
610 top_k: usize,
611 ) -> Result<UnifiedQueryResponse> {
612 self.unified_query(namespace, UnifiedQueryRequest::vector_search(vector, top_k))
613 .await
614 }
615
616 #[instrument(skip(self))]
620 pub async fn unified_text_search(
621 &self,
622 namespace: &str,
623 field: &str,
624 query: &str,
625 top_k: usize,
626 ) -> Result<UnifiedQueryResponse> {
627 self.unified_query(
628 namespace,
629 UnifiedQueryRequest::fulltext_search(field, query, top_k),
630 )
631 .await
632 }
633
634 #[instrument(skip(self, request), fields(namespace = %namespace))]
671 pub async fn explain_query(
672 &self,
673 namespace: &str,
674 request: QueryExplainRequest,
675 ) -> Result<QueryExplainResponse> {
676 let url = format!("{}/v1/namespaces/{}/explain", self.base_url, namespace);
677 debug!(
678 "Explaining query in namespace {} (query_type={:?}, top_k={})",
679 namespace, request.query_type, request.top_k
680 );
681
682 let response = self.client.post(&url).json(&request).send().await?;
683 self.handle_response(response).await
684 }
685
686 #[instrument(skip(self, request), fields(namespace = %request.namespace, priority = ?request.priority))]
714 pub async fn warm_cache(&self, request: WarmCacheRequest) -> Result<WarmCacheResponse> {
715 let url = format!(
716 "{}/v1/namespaces/{}/cache/warm",
717 self.base_url, request.namespace
718 );
719 debug!(
720 "Warming cache for namespace {} with priority {:?}",
721 request.namespace, request.priority
722 );
723
724 let response = self.client.post(&url).json(&request).send().await?;
725 self.handle_response(response).await
726 }
727
728 #[instrument(skip(self, vector_ids))]
730 pub async fn warm_vectors(
731 &self,
732 namespace: &str,
733 vector_ids: Vec<String>,
734 ) -> Result<WarmCacheResponse> {
735 self.warm_cache(WarmCacheRequest::new(namespace).with_vector_ids(vector_ids))
736 .await
737 }
738
739 #[instrument(skip(self, request), fields(namespace = %namespace))]
772 pub async fn export(&self, namespace: &str, request: ExportRequest) -> Result<ExportResponse> {
773 let url = format!("{}/v1/namespaces/{}/export", self.base_url, namespace);
774 debug!(
775 "Exporting vectors from namespace {} (top_k={}, cursor={:?})",
776 namespace, request.top_k, request.cursor
777 );
778
779 let response = self.client.post(&url).json(&request).send().await?;
780 self.handle_response(response).await
781 }
782
783 #[instrument(skip(self))]
787 pub async fn export_all(&self, namespace: &str) -> Result<ExportResponse> {
788 self.export(namespace, ExportRequest::new()).await
789 }
790
791 #[instrument(skip(self))]
797 pub async fn diagnostics(&self) -> Result<SystemDiagnostics> {
798 let url = format!("{}/ops/diagnostics", self.base_url);
799 let response = self.client.get(&url).send().await?;
800 self.handle_response(response).await
801 }
802
803 #[instrument(skip(self))]
805 pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
806 let url = format!("{}/ops/jobs", self.base_url);
807 let response = self.client.get(&url).send().await?;
808 self.handle_response(response).await
809 }
810
811 #[instrument(skip(self))]
813 pub async fn get_job(&self, job_id: &str) -> Result<Option<JobInfo>> {
814 let url = format!("{}/ops/jobs/{}", self.base_url, job_id);
815 let response = self.client.get(&url).send().await?;
816
817 if response.status() == StatusCode::NOT_FOUND {
818 return Ok(None);
819 }
820
821 self.handle_response(response).await.map(Some)
822 }
823
824 #[instrument(skip(self, request))]
826 pub async fn compact(&self, request: CompactionRequest) -> Result<CompactionResponse> {
827 let url = format!("{}/ops/compact", self.base_url);
828 let response = self.client.post(&url).json(&request).send().await?;
829 self.handle_response(response).await
830 }
831
832 #[instrument(skip(self))]
834 pub async fn shutdown(&self) -> Result<()> {
835 let url = format!("{}/ops/shutdown", self.base_url);
836 let response = self.client.post(&url).send().await?;
837
838 if response.status().is_success() {
839 Ok(())
840 } else {
841 let status = response.status().as_u16();
842 let text = response.text().await.unwrap_or_default();
843 Err(ClientError::Server {
844 status,
845 message: text,
846 code: None,
847 })
848 }
849 }
850
851 #[instrument(skip(self, request), fields(id_count = request.ids.len()))]
857 pub async fn fetch(&self, namespace: &str, request: FetchRequest) -> Result<FetchResponse> {
858 let url = format!("{}/v1/namespaces/{}/fetch", self.base_url, namespace);
859 debug!("Fetching {} vectors from {}", request.ids.len(), namespace);
860 let response = self.client.post(&url).json(&request).send().await?;
861 self.handle_response(response).await
862 }
863
864 #[instrument(skip(self))]
866 pub async fn fetch_by_ids(&self, namespace: &str, ids: &[&str]) -> Result<Vec<Vector>> {
867 let request = FetchRequest::new(ids.iter().map(|s| s.to_string()).collect());
868 self.fetch(namespace, request).await.map(|r| r.vectors)
869 }
870
871 #[instrument(skip(self, request), fields(doc_count = request.documents.len()))]
877 pub async fn upsert_text(
878 &self,
879 namespace: &str,
880 request: UpsertTextRequest,
881 ) -> Result<TextUpsertResponse> {
882 let url = format!("{}/v1/namespaces/{}/upsert-text", self.base_url, namespace);
883 debug!(
884 "Upserting {} text documents to {}",
885 request.documents.len(),
886 namespace
887 );
888 let response = self.client.post(&url).json(&request).send().await?;
889 self.handle_response(response).await
890 }
891
892 #[instrument(skip(self, request), fields(top_k = request.top_k))]
894 pub async fn query_text(
895 &self,
896 namespace: &str,
897 request: QueryTextRequest,
898 ) -> Result<TextQueryResponse> {
899 let url = format!("{}/v1/namespaces/{}/query-text", self.base_url, namespace);
900 debug!("Text query in {} for: {}", namespace, request.text);
901 let response = self.client.post(&url).json(&request).send().await?;
902 self.handle_response(response).await
903 }
904
905 #[instrument(skip(self))]
907 pub async fn query_text_simple(
908 &self,
909 namespace: &str,
910 text: &str,
911 top_k: u32,
912 ) -> Result<TextQueryResponse> {
913 self.query_text(namespace, QueryTextRequest::new(text, top_k))
914 .await
915 }
916
917 #[instrument(skip(self, request), fields(query_count = request.queries.len()))]
919 pub async fn batch_query_text(
920 &self,
921 namespace: &str,
922 request: BatchQueryTextRequest,
923 ) -> Result<BatchQueryTextResponse> {
924 let url = format!(
925 "{}/v1/namespaces/{}/batch-query-text",
926 self.base_url, namespace
927 );
928 debug!(
929 "Batch text query in {} with {} queries",
930 namespace,
931 request.queries.len()
932 );
933 let response = self.client.post(&url).json(&request).send().await?;
934 self.handle_response(response).await
935 }
936
937 #[instrument(skip(self, config))]
946 pub async fn configure_namespace_ner(
947 &self,
948 namespace: &str,
949 config: NamespaceNerConfig,
950 ) -> Result<serde_json::Value> {
951 let url = format!("{}/v1/namespaces/{}/config", self.base_url, namespace);
952 let response = self.client.patch(&url).json(&config).send().await?;
953 self.handle_response(response).await
954 }
955
956 #[instrument(skip(self, text, entity_types))]
961 pub async fn extract_entities(
962 &self,
963 text: &str,
964 entity_types: Option<Vec<String>>,
965 ) -> Result<EntityExtractionResponse> {
966 let url = format!("{}/v1/memories/extract", self.base_url);
967 let body = serde_json::json!({
968 "text": text,
969 "entity_types": entity_types,
970 });
971 let response = self.client.post(&url).json(&body).send().await?;
972 self.handle_response(response).await
973 }
974
975 #[instrument(skip(self))]
979 pub async fn memory_entities(&self, memory_id: &str) -> Result<MemoryEntitiesResponse> {
980 let url = format!("{}/v1/memory/entities/{}", self.base_url, memory_id);
981 let response = self.client.get(&url).send().await?;
982 self.handle_response(response).await
983 }
984
985 pub fn last_rate_limit_headers(&self) -> Option<RateLimitHeaders> {
993 self.last_rate_limit.lock().ok()?.clone()
994 }
995
996 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
998 &self,
999 response: reqwest::Response,
1000 ) -> Result<T> {
1001 let status = response.status();
1002
1003 if let Ok(mut guard) = self.last_rate_limit.lock() {
1005 *guard = Some(RateLimitHeaders::from_response(&response));
1006 }
1007
1008 if status.is_success() {
1009 Ok(response.json().await?)
1010 } else {
1011 let status_code = status.as_u16();
1012 let retry_after = response
1014 .headers()
1015 .get("Retry-After")
1016 .and_then(|v| v.to_str().ok())
1017 .and_then(|s| s.parse::<u64>().ok());
1018 let text = response.text().await.unwrap_or_default();
1019
1020 if status_code == 429 {
1021 return Err(ClientError::RateLimitExceeded { retry_after });
1022 }
1023
1024 #[derive(Deserialize)]
1025 struct ErrorBody {
1026 error: Option<String>,
1027 code: Option<ServerErrorCode>,
1028 }
1029
1030 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1031 (body.error.unwrap_or_else(|| text.clone()), body.code)
1032 } else {
1033 (text, None)
1034 };
1035
1036 match status_code {
1037 401 => Err(ClientError::Server {
1038 status: 401,
1039 message,
1040 code,
1041 }),
1042 403 => Err(ClientError::Authorization {
1043 status: 403,
1044 message,
1045 code,
1046 }),
1047 404 => match &code {
1048 Some(ServerErrorCode::NamespaceNotFound) => {
1049 Err(ClientError::NamespaceNotFound(message))
1050 }
1051 Some(ServerErrorCode::VectorNotFound) => {
1052 Err(ClientError::VectorNotFound(message))
1053 }
1054 _ => Err(ClientError::Server {
1055 status: 404,
1056 message,
1057 code,
1058 }),
1059 },
1060 _ => Err(ClientError::Server {
1061 status: status_code,
1062 message,
1063 code,
1064 }),
1065 }
1066 }
1067 }
1068
1069 pub(crate) async fn handle_text_response(&self, response: reqwest::Response) -> Result<String> {
1071 let status = response.status();
1072
1073 if let Ok(mut guard) = self.last_rate_limit.lock() {
1075 *guard = Some(RateLimitHeaders::from_response(&response));
1076 }
1077
1078 let retry_after = response
1079 .headers()
1080 .get("Retry-After")
1081 .and_then(|v| v.to_str().ok())
1082 .and_then(|s| s.parse::<u64>().ok());
1083 let text = response.text().await.unwrap_or_default();
1084
1085 if status.is_success() {
1086 return Ok(text);
1087 }
1088
1089 let status_code = status.as_u16();
1090
1091 if status_code == 429 {
1092 return Err(ClientError::RateLimitExceeded { retry_after });
1093 }
1094
1095 #[derive(Deserialize)]
1096 struct ErrorBody {
1097 error: Option<String>,
1098 code: Option<ServerErrorCode>,
1099 }
1100
1101 let (message, code) = if let Ok(body) = serde_json::from_str::<ErrorBody>(&text) {
1102 (body.error.unwrap_or_else(|| text.clone()), body.code)
1103 } else {
1104 (text, None)
1105 };
1106
1107 match status_code {
1108 401 => Err(ClientError::Server {
1109 status: 401,
1110 message,
1111 code,
1112 }),
1113 403 => Err(ClientError::Authorization {
1114 status: 403,
1115 message,
1116 code,
1117 }),
1118 _ => Err(ClientError::Server {
1119 status: status_code,
1120 message,
1121 code,
1122 }),
1123 }
1124 }
1125
1126 #[allow(dead_code)]
1134 pub(crate) async fn execute_with_retry<F, Fut, T>(&self, f: F) -> Result<T>
1135 where
1136 F: Fn() -> Fut,
1137 Fut: std::future::Future<Output = Result<T>>,
1138 {
1139 let rc = &self.retry_config;
1140
1141 for attempt in 0..rc.max_retries {
1142 match f().await {
1143 Ok(v) => return Ok(v),
1144 Err(e) => {
1145 let is_last = attempt == rc.max_retries - 1;
1146 if is_last || !e.is_retryable() {
1147 return Err(e);
1148 }
1149
1150 let wait = match &e {
1151 ClientError::RateLimitExceeded {
1152 retry_after: Some(secs),
1153 } => Duration::from_secs(*secs),
1154 _ => {
1155 let base_ms = rc.base_delay.as_millis() as f64;
1156 let backoff_ms = base_ms * 2f64.powi(attempt as i32);
1157 let capped_ms = backoff_ms.min(rc.max_delay.as_millis() as f64);
1158 let final_ms = if rc.jitter {
1159 let seed = (attempt as u64).wrapping_mul(6364136223846793005);
1161 let factor = 0.5 + (seed % 1000) as f64 / 1000.0;
1162 capped_ms * factor
1163 } else {
1164 capped_ms
1165 };
1166 Duration::from_millis(final_ms as u64)
1167 }
1168 };
1169
1170 tokio::time::sleep(wait).await;
1171 }
1172 }
1173 }
1174
1175 Err(ClientError::Config("retry loop exhausted".to_string()))
1177 }
1178}
1179
1180impl DakeraClient {
1185 pub async fn ode_extract_entities(
1197 &self,
1198 req: ExtractEntitiesRequest,
1199 ) -> Result<ExtractEntitiesResponse> {
1200 let ode_url = self.ode_url.as_deref().ok_or_else(|| {
1201 ClientError::Config(
1202 "ode_url must be configured to use extract_entities(). \
1203 Call .ode_url(\"http://localhost:8080\") on the builder."
1204 .to_string(),
1205 )
1206 })?;
1207 let url = format!("{}/ode/extract", ode_url);
1208 let response = self.client.post(&url).json(&req).send().await?;
1209 if response.status().is_success() {
1210 Ok(response.json::<ExtractEntitiesResponse>().await?)
1211 } else {
1212 let status = response.status().as_u16();
1213 let body = response.text().await.unwrap_or_default();
1214 Err(ClientError::Server {
1215 status,
1216 message: format!("ODE sidecar error: {}", body),
1217 code: None,
1218 })
1219 }
1220 }
1221
1222 #[instrument(skip(self))]
1234 pub async fn get_memory_policy(&self, namespace: &str) -> Result<MemoryPolicy> {
1235 let url = format!(
1236 "{}/v1/namespaces/{}/memory_policy",
1237 self.base_url,
1238 urlencoding::encode(namespace)
1239 );
1240 let response = self.client.get(&url).send().await?;
1241 self.handle_response(response).await
1242 }
1243
1244 #[instrument(skip(self, policy))]
1251 pub async fn set_memory_policy(
1252 &self,
1253 namespace: &str,
1254 policy: MemoryPolicy,
1255 ) -> Result<MemoryPolicy> {
1256 let url = format!(
1257 "{}/v1/namespaces/{}/memory_policy",
1258 self.base_url,
1259 urlencoding::encode(namespace)
1260 );
1261 let response = self.client.put(&url).json(&policy).send().await?;
1262 self.handle_response(response).await
1263 }
1264}
1265
/// Builder that collects settings for a `DakeraClient`; finish with `build`.
#[derive(Debug)]
pub struct DakeraClientBuilder {
    /// Server base URL as supplied; `build` strips trailing slashes and
    /// checks the scheme.
    base_url: String,
    /// Explicit API key; when `None`, `build` falls back to the
    /// `DAKERA_API_KEY` environment variable.
    api_key: Option<String>,
    /// Optional ODE base URL, stored with trailing slashes removed.
    ode_url: Option<String>,
    /// Request timeout handed to the `reqwest` client.
    timeout: Duration,
    /// Connect timeout; `build` uses `timeout` when this is `None`.
    connect_timeout: Option<Duration>,
    /// Retry settings copied into the built client.
    retry_config: RetryConfig,
    /// Custom `User-Agent`; `build` uses `dakera-client/<crate version>`
    /// when this is `None`.
    user_agent: Option<String>,
}
1277
1278impl DakeraClientBuilder {
1279 pub fn new(base_url: impl Into<String>) -> Self {
1281 Self {
1282 base_url: base_url.into(),
1283 api_key: None,
1284 ode_url: None,
1285 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
1286 connect_timeout: None,
1287 retry_config: RetryConfig::default(),
1288 user_agent: None,
1289 }
1290 }
1291
1292 pub fn api_key(mut self, key: impl Into<String>) -> Self {
1297 self.api_key = Some(key.into());
1298 self
1299 }
1300
1301 pub fn ode_url(mut self, ode_url: impl Into<String>) -> Self {
1305 self.ode_url = Some(ode_url.into().trim_end_matches('/').to_string());
1306 self
1307 }
1308
1309 pub fn timeout(mut self, timeout: Duration) -> Self {
1311 self.timeout = timeout;
1312 self
1313 }
1314
1315 pub fn timeout_secs(mut self, secs: u64) -> Self {
1317 self.timeout = Duration::from_secs(secs);
1318 self
1319 }
1320
1321 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
1323 self.connect_timeout = Some(timeout);
1324 self
1325 }
1326
1327 pub fn retry_config(mut self, config: RetryConfig) -> Self {
1329 self.retry_config = config;
1330 self
1331 }
1332
1333 pub fn max_retries(mut self, max_retries: u32) -> Self {
1335 self.retry_config.max_retries = max_retries;
1336 self
1337 }
1338
1339 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
1341 self.user_agent = Some(user_agent.into());
1342 self
1343 }
1344
1345 pub fn build(self) -> Result<DakeraClient> {
1347 let base_url = self.base_url.trim_end_matches('/').to_string();
1349
1350 if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
1352 return Err(ClientError::InvalidUrl(
1353 "URL must start with http:// or https://".to_string(),
1354 ));
1355 }
1356
1357 let user_agent = self
1358 .user_agent
1359 .unwrap_or_else(|| format!("dakera-client/{}", env!("CARGO_PKG_VERSION")));
1360
1361 let connect_timeout = self.connect_timeout.unwrap_or(self.timeout);
1362
1363 let api_key = self
1365 .api_key
1366 .or_else(|| std::env::var("DAKERA_API_KEY").ok());
1367
1368 let mut default_headers = HeaderMap::new();
1369 if let Some(key) = &api_key {
1370 let bearer = format!("Bearer {key}");
1371 let mut value = HeaderValue::from_str(&bearer)
1372 .map_err(|_| ClientError::Config("invalid API key".into()))?;
1373 value.set_sensitive(true);
1374 default_headers.insert(AUTHORIZATION, value);
1375 }
1376
1377 let client = Client::builder()
1378 .timeout(self.timeout)
1379 .connect_timeout(connect_timeout)
1380 .user_agent(user_agent)
1381 .default_headers(default_headers)
1382 .build()
1383 .map_err(|e| ClientError::Config(e.to_string()))?;
1384
1385 Ok(DakeraClient {
1386 client,
1387 base_url,
1388 ode_url: self.ode_url,
1389 retry_config: self.retry_config,
1390 last_rate_limit: Arc::new(Mutex::new(None)),
1391 })
1392 }
1393}
1394
1395impl DakeraClient {
1400 pub async fn stream_namespace_events(
1425 &self,
1426 namespace: &str,
1427 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1428 let url = format!(
1429 "{}/v1/namespaces/{}/events",
1430 self.base_url,
1431 urlencoding::encode(namespace)
1432 );
1433 self.stream_sse(url).await
1434 }
1435
1436 pub async fn stream_global_events(
1443 &self,
1444 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::DakeraEvent>>> {
1445 let url = format!("{}/ops/events", self.base_url);
1446 self.stream_sse(url).await
1447 }
1448
1449 pub async fn stream_memory_events(
1458 &self,
1459 ) -> Result<tokio::sync::mpsc::Receiver<Result<crate::events::MemoryEvent>>> {
1460 let url = format!("{}/v1/events/stream", self.base_url);
1461 self.stream_sse(url).await
1462 }
1463
1464 pub(crate) async fn stream_sse<T>(
1466 &self,
1467 url: String,
1468 ) -> Result<tokio::sync::mpsc::Receiver<Result<T>>>
1469 where
1470 T: serde::de::DeserializeOwned + Send + 'static,
1471 {
1472 use futures_util::StreamExt;
1473
1474 let response = self
1475 .client
1476 .get(&url)
1477 .header("Accept", "text/event-stream")
1478 .header("Cache-Control", "no-cache")
1479 .send()
1480 .await?;
1481
1482 if !response.status().is_success() {
1483 let status = response.status().as_u16();
1484 let body = response.text().await.unwrap_or_default();
1485 return Err(ClientError::Server {
1486 status,
1487 message: body,
1488 code: None,
1489 });
1490 }
1491
1492 let (tx, rx) = tokio::sync::mpsc::channel(64);
1493
1494 tokio::spawn(async move {
1495 let mut byte_stream = response.bytes_stream();
1496 let mut remaining = String::new();
1497 let mut data_lines: Vec<String> = Vec::new();
1498
1499 while let Some(chunk) = byte_stream.next().await {
1500 match chunk {
1501 Ok(bytes) => {
1502 remaining.push_str(&String::from_utf8_lossy(&bytes));
1503 while let Some(pos) = remaining.find('\n') {
1504 let raw = &remaining[..pos];
1505 let line = raw.trim_end_matches('\r').to_string();
1506 remaining = remaining[pos + 1..].to_string();
1507
1508 if line.starts_with(':') {
1509 } else if let Some(data) = line.strip_prefix("data:") {
1511 data_lines.push(data.trim_start().to_string());
1512 } else if line.is_empty() {
1513 if !data_lines.is_empty() {
1514 let payload = data_lines.join("\n");
1515 data_lines.clear();
1516 let result = serde_json::from_str::<T>(&payload)
1517 .map_err(ClientError::Json);
1518 if tx.send(result).await.is_err() {
1519 return; }
1521 }
1522 } else {
1523 }
1525 }
1526 }
1527 Err(e) => {
1528 let _ = tx.send(Err(ClientError::Http(e))).await;
1529 return;
1530 }
1531 }
1532 }
1533 });
1534
1535 Ok(rx)
1536 }
1537}
1538
1539#[cfg(test)]
1540mod tests {
1541 use super::*;
1542
/// A plain `http://` URL builds successfully with default settings.
#[test]
fn test_client_builder() {
    assert!(DakeraClient::new("http://localhost:3000").is_ok());
}
1548
/// Builder options (timeout, user agent) do not prevent a successful build.
#[test]
fn test_client_builder_with_options() {
    let builder = DakeraClient::builder("http://localhost:3000")
        .timeout_secs(60)
        .user_agent("test-client/1.0");
    let built = builder.build();
    assert!(built.is_ok());
}
1557
/// A URL without an `http://` or `https://` scheme is rejected.
#[test]
fn test_client_builder_invalid_url() {
    assert!(DakeraClient::new("invalid-url").is_err());
}
1563
/// The stored base URL never ends with a slash.
#[test]
fn test_client_builder_trailing_slash() {
    let built = DakeraClient::new("http://localhost:3000/").unwrap();
    let ends_with_slash = built.base_url.ends_with('/');
    assert!(!ends_with_slash);
}
1569
1570 #[test]
1571 fn test_vector_creation() {
1572 let v = Vector::new("test", vec![0.1, 0.2, 0.3]);
1573 assert_eq!(v.id, "test");
1574 assert_eq!(v.values.len(), 3);
1575 assert!(v.metadata.is_none());
1576 }
1577
1578 #[test]
1579 fn test_query_request_builder() {
1580 let req = QueryRequest::new(vec![0.1, 0.2], 10)
1581 .with_filter(serde_json::json!({"category": "test"}))
1582 .include_metadata(false);
1583
1584 assert_eq!(req.top_k, 10);
1585 assert!(req.filter.is_some());
1586 assert!(!req.include_metadata);
1587 }
1588
1589 #[test]
1590 fn test_hybrid_search_request() {
1591 let req = HybridSearchRequest::new(vec![0.1], "test query", 5).with_vector_weight(0.7);
1592
1593 assert_eq!(req.vector_weight, 0.7);
1594 assert_eq!(req.text, "test query");
1595 assert!(req.vector.is_some());
1596 }
1597
1598 #[test]
1599 fn test_hybrid_search_weight_clamping() {
1600 let req = HybridSearchRequest::new(vec![0.1], "test", 5).with_vector_weight(1.5); assert_eq!(req.vector_weight, 1.0);
1603 }
1604
1605 #[test]
1606 fn test_hybrid_search_text_only() {
1607 let req = HybridSearchRequest::text_only("bm25 query", 10);
1608
1609 assert!(req.vector.is_none());
1610 assert_eq!(req.text, "bm25 query");
1611 assert_eq!(req.top_k, 10);
1612 let json = serde_json::to_value(&req).unwrap();
1614 assert!(json.get("vector").is_none());
1615 }
1616
1617 #[test]
1618 fn test_text_document_builder() {
1619 let doc = TextDocument::new("doc1", "Hello world").with_ttl(3600);
1620
1621 assert_eq!(doc.id, "doc1");
1622 assert_eq!(doc.text, "Hello world");
1623 assert_eq!(doc.ttl_seconds, Some(3600));
1624 assert!(doc.metadata.is_none());
1625 }
1626
1627 #[test]
1628 fn test_upsert_text_request_builder() {
1629 let docs = vec![
1630 TextDocument::new("doc1", "Hello"),
1631 TextDocument::new("doc2", "World"),
1632 ];
1633 let req = UpsertTextRequest::new(docs).with_model(EmbeddingModel::BgeSmall);
1634
1635 assert_eq!(req.documents.len(), 2);
1636 assert_eq!(req.model, Some(EmbeddingModel::BgeSmall));
1637 }
1638
1639 #[test]
1640 fn test_query_text_request_builder() {
1641 let req = QueryTextRequest::new("semantic search query", 5)
1642 .with_filter(serde_json::json!({"category": "docs"}))
1643 .include_vectors(true)
1644 .with_model(EmbeddingModel::E5Small);
1645
1646 assert_eq!(req.text, "semantic search query");
1647 assert_eq!(req.top_k, 5);
1648 assert!(req.filter.is_some());
1649 assert!(req.include_vectors);
1650 assert_eq!(req.model, Some(EmbeddingModel::E5Small));
1651 }
1652
1653 #[test]
1654 fn test_fetch_request_builder() {
1655 let req = FetchRequest::new(vec!["id1".to_string(), "id2".to_string()]);
1656
1657 assert_eq!(req.ids.len(), 2);
1658 assert!(req.include_values);
1659 assert!(req.include_metadata);
1660 }
1661
1662 #[test]
1663 fn test_create_namespace_request_builder() {
1664 let req = CreateNamespaceRequest::new()
1665 .with_dimensions(384)
1666 .with_index_type("hnsw");
1667
1668 assert_eq!(req.dimensions, Some(384));
1669 assert_eq!(req.index_type.as_deref(), Some("hnsw"));
1670 }
1671
1672 #[test]
1673 fn test_batch_query_text_request() {
1674 let req =
1675 BatchQueryTextRequest::new(vec!["query one".to_string(), "query two".to_string()], 10);
1676
1677 assert_eq!(req.queries.len(), 2);
1678 assert_eq!(req.top_k, 10);
1679 assert!(!req.include_vectors);
1680 assert!(req.model.is_none());
1681 }
1682
1683 #[test]
1688 fn test_retry_config_defaults() {
1689 let rc = RetryConfig::default();
1690 assert_eq!(rc.max_retries, 3);
1691 assert_eq!(rc.base_delay, Duration::from_millis(100));
1692 assert_eq!(rc.max_delay, Duration::from_secs(60));
1693 assert!(rc.jitter);
1694 }
1695
1696 #[test]
1697 fn test_builder_connect_timeout() {
1698 let client = DakeraClient::builder("http://localhost:3000")
1699 .connect_timeout(Duration::from_secs(5))
1700 .timeout_secs(30)
1701 .build()
1702 .unwrap();
1703 assert!(client.base_url.starts_with("http"));
1705 }
1706
1707 #[test]
1708 fn test_builder_max_retries() {
1709 let client = DakeraClient::builder("http://localhost:3000")
1710 .max_retries(5)
1711 .build()
1712 .unwrap();
1713 assert_eq!(client.retry_config.max_retries, 5);
1714 }
1715
1716 #[test]
1717 fn test_builder_retry_config() {
1718 let rc = RetryConfig {
1719 max_retries: 7,
1720 base_delay: Duration::from_millis(200),
1721 max_delay: Duration::from_secs(30),
1722 jitter: false,
1723 };
1724 let client = DakeraClient::builder("http://localhost:3000")
1725 .retry_config(rc)
1726 .build()
1727 .unwrap();
1728 assert_eq!(client.retry_config.max_retries, 7);
1729 assert!(!client.retry_config.jitter);
1730 }
1731
1732 #[test]
1733 fn test_rate_limit_error_retryable() {
1734 let e = ClientError::RateLimitExceeded { retry_after: None };
1735 assert!(e.is_retryable());
1736 }
1737
1738 #[test]
1739 fn test_server_408_retryable() {
1740 let e = ClientError::Server {
1741 status: 408,
1742 message: String::new(),
1743 code: None,
1744 };
1745 assert!(e.is_retryable());
1746 }
1747
1748 #[test]
1749 fn test_server_400_not_retryable() {
1750 let e = ClientError::Server {
1751 status: 400,
1752 message: String::new(),
1753 code: None,
1754 };
1755 assert!(!e.is_retryable());
1756 }
1757
1758 #[test]
1759 fn test_rate_limit_error_with_retry_after_zero() {
1760 let e = ClientError::RateLimitExceeded {
1762 retry_after: Some(0),
1763 };
1764 assert!(e.is_retryable());
1765 if let ClientError::RateLimitExceeded {
1766 retry_after: Some(secs),
1767 } = &e
1768 {
1769 assert_eq!(*secs, 0u64);
1770 } else {
1771 panic!("unexpected variant");
1772 }
1773 }
1774
1775 #[tokio::test]
1776 async fn test_execute_with_retry_succeeds_immediately() {
1777 let client = DakeraClient::builder("http://localhost:3000")
1778 .max_retries(3)
1779 .build()
1780 .unwrap();
1781
1782 let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
1783 let cc = call_count.clone();
1784 let result = client
1785 .execute_with_retry(|| {
1786 let cc = cc.clone();
1787 async move {
1788 cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1789 Ok::<u32, ClientError>(42)
1790 }
1791 })
1792 .await;
1793 assert_eq!(result.unwrap(), 42);
1794 assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
1795 }
1796
1797 #[tokio::test]
1798 async fn test_execute_with_retry_no_retry_on_4xx() {
1799 let client = DakeraClient::builder("http://localhost:3000")
1800 .max_retries(3)
1801 .build()
1802 .unwrap();
1803
1804 let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
1805 let cc = call_count.clone();
1806 let result = client
1807 .execute_with_retry(|| {
1808 let cc = cc.clone();
1809 async move {
1810 cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1811 Err::<u32, ClientError>(ClientError::Server {
1812 status: 400,
1813 message: "bad request".to_string(),
1814 code: None,
1815 })
1816 }
1817 })
1818 .await;
1819 assert!(result.is_err());
1820 assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
1822 }
1823
1824 #[tokio::test]
1825 async fn test_execute_with_retry_retries_on_5xx() {
1826 let client = DakeraClient::builder("http://localhost:3000")
1827 .retry_config(RetryConfig {
1828 max_retries: 3,
1829 base_delay: Duration::from_millis(0),
1830 max_delay: Duration::from_millis(0),
1831 jitter: false,
1832 })
1833 .build()
1834 .unwrap();
1835
1836 let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
1837 let cc = call_count.clone();
1838 let result = client
1839 .execute_with_retry(|| {
1840 let cc = cc.clone();
1841 async move {
1842 let n = cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1843 if n < 2 {
1844 Err::<u32, ClientError>(ClientError::Server {
1845 status: 503,
1846 message: "unavailable".to_string(),
1847 code: None,
1848 })
1849 } else {
1850 Ok(99)
1851 }
1852 }
1853 })
1854 .await;
1855 assert_eq!(result.unwrap(), 99);
1856 assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3);
1857 }
1858
1859 #[test]
1864 fn test_batch_recall_request_new() {
1865 use crate::memory::BatchRecallRequest;
1866 let req = BatchRecallRequest::new("agent-1");
1867 assert_eq!(req.agent_id, "agent-1");
1868 assert_eq!(req.limit, 100);
1869 }
1870
1871 #[test]
1872 fn test_batch_recall_request_builder() {
1873 use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
1874 let filter = BatchMemoryFilter::default()
1875 .with_tags(vec!["qa".to_string()])
1876 .with_min_importance(0.7);
1877 let req = BatchRecallRequest::new("agent-1")
1878 .with_filter(filter)
1879 .with_limit(50);
1880 assert_eq!(req.agent_id, "agent-1");
1881 assert_eq!(req.limit, 50);
1882 assert_eq!(
1883 req.filter.tags.as_deref(),
1884 Some(["qa".to_string()].as_slice())
1885 );
1886 assert_eq!(req.filter.min_importance, Some(0.7));
1887 }
1888
1889 #[test]
1890 fn test_batch_recall_request_serialization() {
1891 use crate::memory::{BatchMemoryFilter, BatchRecallRequest};
1892 let filter = BatchMemoryFilter::default().with_min_importance(0.5);
1893 let req = BatchRecallRequest::new("agent-1")
1894 .with_filter(filter)
1895 .with_limit(25);
1896 let json = serde_json::to_value(&req).unwrap();
1897 assert_eq!(json["agent_id"], "agent-1");
1898 assert_eq!(json["limit"], 25);
1899 assert_eq!(json["filter"]["min_importance"], 0.5);
1900 }
1901
1902 #[test]
1903 fn test_batch_forget_request_new() {
1904 use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
1905 let filter = BatchMemoryFilter::default().with_min_importance(0.1);
1906 let req = BatchForgetRequest::new("agent-1", filter);
1907 assert_eq!(req.agent_id, "agent-1");
1908 assert_eq!(req.filter.min_importance, Some(0.1));
1909 }
1910
1911 #[test]
1912 fn test_batch_forget_request_serialization() {
1913 use crate::memory::{BatchForgetRequest, BatchMemoryFilter};
1914 let filter = BatchMemoryFilter {
1915 created_before: Some(1_700_000_000),
1916 ..Default::default()
1917 };
1918 let req = BatchForgetRequest::new("agent-1", filter);
1919 let json = serde_json::to_value(&req).unwrap();
1920 assert_eq!(json["agent_id"], "agent-1");
1921 assert_eq!(json["filter"]["created_before"], 1_700_000_000u64);
1922 }
1923
1924 #[test]
1925 fn test_batch_recall_response_deserialization() {
1926 use crate::memory::BatchRecallResponse;
1927 let json = serde_json::json!({
1928 "memories": [],
1929 "total": 42,
1930 "filtered": 7
1931 });
1932 let resp: BatchRecallResponse = serde_json::from_value(json).unwrap();
1933 assert_eq!(resp.total, 42);
1934 assert_eq!(resp.filtered, 7);
1935 assert!(resp.memories.is_empty());
1936 }
1937
1938 #[test]
1939 fn test_batch_forget_response_deserialization() {
1940 use crate::memory::BatchForgetResponse;
1941 let json = serde_json::json!({ "deleted_count": 13 });
1942 let resp: BatchForgetResponse = serde_json::from_value(json).unwrap();
1943 assert_eq!(resp.deleted_count, 13);
1944 }
1945
1946 #[test]
1951 fn test_rate_limit_headers_default_all_none() {
1952 use crate::types::RateLimitHeaders;
1953 let rl = RateLimitHeaders {
1954 limit: None,
1955 remaining: None,
1956 reset: None,
1957 quota_used: None,
1958 quota_limit: None,
1959 };
1960 assert!(rl.limit.is_none());
1961 assert!(rl.remaining.is_none());
1962 assert!(rl.reset.is_none());
1963 assert!(rl.quota_used.is_none());
1964 assert!(rl.quota_limit.is_none());
1965 }
1966
1967 #[test]
1968 fn test_rate_limit_headers_populated() {
1969 use crate::types::RateLimitHeaders;
1970 let rl = RateLimitHeaders {
1971 limit: Some(1000),
1972 remaining: Some(750),
1973 reset: Some(1_700_000_060),
1974 quota_used: Some(500),
1975 quota_limit: Some(10_000),
1976 };
1977 assert_eq!(rl.limit, Some(1000));
1978 assert_eq!(rl.remaining, Some(750));
1979 assert_eq!(rl.reset, Some(1_700_000_060));
1980 assert_eq!(rl.quota_used, Some(500));
1981 assert_eq!(rl.quota_limit, Some(10_000));
1982 }
1983
1984 #[test]
1985 fn test_last_rate_limit_headers_initially_none() {
1986 let client = DakeraClient::new("http://localhost:3000").unwrap();
1987 assert!(client.last_rate_limit_headers().is_none());
1988 }
1989
1990 #[test]
1995 fn test_namespace_ner_config_default() {
1996 use crate::types::NamespaceNerConfig;
1997 let cfg = NamespaceNerConfig::default();
1998 assert!(!cfg.extract_entities);
1999 assert!(cfg.entity_types.is_none());
2000 }
2001
2002 #[test]
2003 fn test_namespace_ner_config_serialization_skip_none() {
2004 use crate::types::NamespaceNerConfig;
2005 let cfg = NamespaceNerConfig {
2006 extract_entities: true,
2007 entity_types: None,
2008 };
2009 let json = serde_json::to_value(&cfg).unwrap();
2010 assert_eq!(json["extract_entities"], true);
2011 assert!(json.get("entity_types").is_none());
2013 }
2014
2015 #[test]
2016 fn test_namespace_ner_config_serialization_with_types() {
2017 use crate::types::NamespaceNerConfig;
2018 let cfg = NamespaceNerConfig {
2019 extract_entities: true,
2020 entity_types: Some(vec!["PERSON".to_string(), "ORG".to_string()]),
2021 };
2022 let json = serde_json::to_value(&cfg).unwrap();
2023 assert_eq!(json["extract_entities"], true);
2024 assert_eq!(json["entity_types"][0], "PERSON");
2025 assert_eq!(json["entity_types"][1], "ORG");
2026 }
2027
2028 #[test]
2029 fn test_extracted_entity_deserialization() {
2030 use crate::types::ExtractedEntity;
2031 let json = serde_json::json!({
2032 "entity_type": "PERSON",
2033 "value": "Alice",
2034 "score": 0.95
2035 });
2036 let entity: ExtractedEntity = serde_json::from_value(json).unwrap();
2037 assert_eq!(entity.entity_type, "PERSON");
2038 assert_eq!(entity.value, "Alice");
2039 assert!((entity.score - 0.95).abs() < f64::EPSILON);
2040 }
2041
2042 #[test]
2043 fn test_entity_extraction_response_deserialization() {
2044 use crate::types::EntityExtractionResponse;
2045 let json = serde_json::json!({
2046 "entities": [
2047 { "entity_type": "PERSON", "value": "Bob", "score": 0.9 },
2048 { "entity_type": "ORG", "value": "Acme", "score": 0.87 }
2049 ]
2050 });
2051 let resp: EntityExtractionResponse = serde_json::from_value(json).unwrap();
2052 assert_eq!(resp.entities.len(), 2);
2053 assert_eq!(resp.entities[0].entity_type, "PERSON");
2054 assert_eq!(resp.entities[1].value, "Acme");
2055 }
2056
2057 #[test]
2058 fn test_memory_entities_response_deserialization() {
2059 use crate::types::MemoryEntitiesResponse;
2060 let json = serde_json::json!({
2061 "memory_id": "mem-abc-123",
2062 "entities": [
2063 { "entity_type": "LOC", "value": "London", "score": 0.88 }
2064 ]
2065 });
2066 let resp: MemoryEntitiesResponse = serde_json::from_value(json).unwrap();
2067 assert_eq!(resp.memory_id, "mem-abc-123");
2068 assert_eq!(resp.entities.len(), 1);
2069 assert_eq!(resp.entities[0].entity_type, "LOC");
2070 assert_eq!(resp.entities[0].value, "London");
2071 }
2072
2073 #[test]
2074 fn test_configure_namespace_ner_url_pattern() {
2075 let client = DakeraClient::new("http://localhost:3000").unwrap();
2077 let expected = "http://localhost:3000/v1/namespaces/my-ns/config";
2078 let actual = format!("{}/v1/namespaces/{}/config", client.base_url, "my-ns");
2079 assert_eq!(actual, expected);
2080 }
2081
2082 #[test]
2083 fn test_extract_entities_url_pattern() {
2084 let client = DakeraClient::new("http://localhost:3000").unwrap();
2085 let expected = "http://localhost:3000/v1/memories/extract";
2086 let actual = format!("{}/v1/memories/extract", client.base_url);
2087 assert_eq!(actual, expected);
2088 }
2089
2090 #[test]
2091 fn test_memory_entities_url_pattern() {
2092 let client = DakeraClient::new("http://localhost:3000").unwrap();
2093 let memory_id = "mem-xyz-789";
2094 let expected = "http://localhost:3000/v1/memory/entities/mem-xyz-789";
2095 let actual = format!("{}/v1/memory/entities/{}", client.base_url, memory_id);
2096 assert_eq!(actual, expected);
2097 }
2098
2099 #[test]
2104 fn test_feedback_signal_serialization() {
2105 use crate::types::FeedbackSignal;
2106 let upvote = serde_json::to_value(FeedbackSignal::Upvote).unwrap();
2107 assert_eq!(upvote, serde_json::json!("upvote"));
2108 let downvote = serde_json::to_value(FeedbackSignal::Downvote).unwrap();
2109 assert_eq!(downvote, serde_json::json!("downvote"));
2110 let flag = serde_json::to_value(FeedbackSignal::Flag).unwrap();
2111 assert_eq!(flag, serde_json::json!("flag"));
2112 }
2113
2114 #[test]
2115 fn test_feedback_signal_deserialization() {
2116 use crate::types::FeedbackSignal;
2117 let signal: FeedbackSignal = serde_json::from_str("\"upvote\"").unwrap();
2118 assert_eq!(signal, FeedbackSignal::Upvote);
2119 let signal: FeedbackSignal = serde_json::from_str("\"positive\"").unwrap();
2120 assert_eq!(signal, FeedbackSignal::Positive);
2121 }
2122
2123 #[test]
2124 fn test_feedback_response_deserialization() {
2125 use crate::types::{FeedbackResponse, FeedbackSignal};
2126 let json = serde_json::json!({
2127 "memory_id": "mem-abc",
2128 "new_importance": 0.92,
2129 "signal": "upvote"
2130 });
2131 let resp: FeedbackResponse = serde_json::from_value(json).unwrap();
2132 assert_eq!(resp.memory_id, "mem-abc");
2133 assert!((resp.new_importance - 0.92).abs() < f32::EPSILON);
2134 assert_eq!(resp.signal, FeedbackSignal::Upvote);
2135 }
2136
2137 #[test]
2138 fn test_feedback_history_response_deserialization() {
2139 use crate::types::{FeedbackHistoryResponse, FeedbackSignal};
2140 let json = serde_json::json!({
2141 "memory_id": "mem-abc",
2142 "entries": [
2143 {"signal": "upvote", "timestamp": 1774000000_u64, "old_importance": 0.5, "new_importance": 0.575},
2144 {"signal": "downvote", "timestamp": 1774001000_u64, "old_importance": 0.575, "new_importance": 0.489}
2145 ]
2146 });
2147 let resp: FeedbackHistoryResponse = serde_json::from_value(json).unwrap();
2148 assert_eq!(resp.memory_id, "mem-abc");
2149 assert_eq!(resp.entries.len(), 2);
2150 assert_eq!(resp.entries[0].signal, FeedbackSignal::Upvote);
2151 assert_eq!(resp.entries[1].signal, FeedbackSignal::Downvote);
2152 }
2153
2154 #[test]
2155 fn test_agent_feedback_summary_deserialization() {
2156 use crate::types::AgentFeedbackSummary;
2157 let json = serde_json::json!({
2158 "agent_id": "agent-1",
2159 "upvotes": 42_u64,
2160 "downvotes": 7_u64,
2161 "flags": 2_u64,
2162 "total_feedback": 51_u64,
2163 "health_score": 0.78
2164 });
2165 let summary: AgentFeedbackSummary = serde_json::from_value(json).unwrap();
2166 assert_eq!(summary.agent_id, "agent-1");
2167 assert_eq!(summary.upvotes, 42);
2168 assert_eq!(summary.total_feedback, 51);
2169 assert!((summary.health_score - 0.78).abs() < f32::EPSILON);
2170 }
2171
2172 #[test]
2173 fn test_feedback_health_response_deserialization() {
2174 use crate::types::FeedbackHealthResponse;
2175 let json = serde_json::json!({
2176 "agent_id": "agent-1",
2177 "health_score": 0.78,
2178 "memory_count": 120_usize,
2179 "avg_importance": 0.72
2180 });
2181 let health: FeedbackHealthResponse = serde_json::from_value(json).unwrap();
2182 assert_eq!(health.agent_id, "agent-1");
2183 assert!((health.health_score - 0.78).abs() < f32::EPSILON);
2184 assert_eq!(health.memory_count, 120);
2185 }
2186
2187 #[test]
2188 fn test_memory_feedback_body_serialization() {
2189 use crate::types::{FeedbackSignal, MemoryFeedbackBody};
2190 let body = MemoryFeedbackBody {
2191 agent_id: "agent-1".to_string(),
2192 signal: FeedbackSignal::Flag,
2193 };
2194 let json = serde_json::to_value(body).unwrap();
2195 assert_eq!(json["agent_id"], "agent-1");
2196 assert_eq!(json["signal"], "flag");
2197 }
2198
2199 #[test]
2200 fn test_feedback_memory_url_pattern() {
2201 let client = DakeraClient::new("http://localhost:3000").unwrap();
2202 let memory_id = "mem-abc";
2203 let expected_post = "http://localhost:3000/v1/memories/mem-abc/feedback";
2204 let actual_post = format!("{}/v1/memories/{}/feedback", client.base_url, memory_id);
2205 assert_eq!(actual_post, expected_post);
2206
2207 let expected_patch = "http://localhost:3000/v1/memories/mem-abc/importance";
2208 let actual_patch = format!("{}/v1/memories/{}/importance", client.base_url, memory_id);
2209 assert_eq!(actual_patch, expected_patch);
2210 }
2211
2212 #[test]
2213 fn test_feedback_health_url_pattern() {
2214 let client = DakeraClient::new("http://localhost:3000").unwrap();
2215 let agent_id = "agent-1";
2216 let expected = "http://localhost:3000/v1/feedback/health?agent_id=agent-1";
2217 let actual = format!(
2218 "{}/v1/feedback/health?agent_id={}",
2219 client.base_url, agent_id
2220 );
2221 assert_eq!(actual, expected);
2222 }
2223
2224 #[test]
2226 fn test_ode_extract_entities_requires_ode_url() {
2227 let client = DakeraClient::new("http://localhost:3000").unwrap();
2229 let rt = tokio::runtime::Runtime::new().unwrap();
2230 let result = rt.block_on(client.ode_extract_entities(ExtractEntitiesRequest {
2231 content: "Alice lives in Paris.".to_string(),
2232 agent_id: "agent-1".to_string(),
2233 memory_id: None,
2234 entity_types: None,
2235 }));
2236 assert!(result.is_err());
2237 let err = result.unwrap_err();
2238 assert!(matches!(err, ClientError::Config(_)));
2239 }
2240
2241 #[test]
2242 fn test_ode_extract_entities_url_built_from_ode_url() {
2243 let client = DakeraClient::builder("http://localhost:3000")
2245 .ode_url("http://localhost:8080")
2246 .build()
2247 .unwrap();
2248 assert_eq!(client.ode_url.as_deref(), Some("http://localhost:8080"));
2249 let expected = "http://localhost:8080/ode/extract";
2250 let actual = format!("{}/ode/extract", client.ode_url.as_deref().unwrap());
2251 assert_eq!(actual, expected);
2252 }
2253
2254 #[test]
2255 fn test_extract_entities_request_serialization() {
2256 let req = ExtractEntitiesRequest {
2257 content: "Alice in Wonderland".to_string(),
2258 agent_id: "agent-42".to_string(),
2259 memory_id: Some("mem-001".to_string()),
2260 entity_types: Some(vec!["person".to_string(), "location".to_string()]),
2261 };
2262 let json = serde_json::to_string(&req).unwrap();
2263 assert!(json.contains("\"content\":\"Alice in Wonderland\""));
2264 assert!(json.contains("\"agent_id\":\"agent-42\""));
2265 assert!(json.contains("\"memory_id\":\"mem-001\""));
2266 assert!(json.contains("\"person\""));
2267 }
2268
2269 #[test]
2270 fn test_extract_entities_request_omits_none_fields() {
2271 let req = ExtractEntitiesRequest {
2272 content: "hello".to_string(),
2273 agent_id: "a".to_string(),
2274 memory_id: None,
2275 entity_types: None,
2276 };
2277 let json = serde_json::to_string(&req).unwrap();
2278 assert!(!json.contains("memory_id"));
2279 assert!(!json.contains("entity_types"));
2280 }
2281
2282 #[test]
2283 fn test_ode_entity_deserialization() {
2284 let json = r#"{"text":"Alice","label":"person","start":0,"end":5,"score":0.97}"#;
2285 let entity: OdeEntity = serde_json::from_str(json).unwrap();
2286 assert_eq!(entity.text, "Alice");
2287 assert_eq!(entity.label, "person");
2288 assert_eq!(entity.start, 0);
2289 assert_eq!(entity.end, 5);
2290 assert!((entity.score - 0.97).abs() < 1e-4);
2291 }
2292
2293 #[test]
2294 fn test_extract_entities_response_deserialization() {
2295 let json = r#"{
2296 "entities": [
2297 {"text":"Alice","label":"person","start":0,"end":5,"score":0.97},
2298 {"text":"Paris","label":"location","start":16,"end":21,"score":0.92}
2299 ],
2300 "model": "gliner-multi-v2.1",
2301 "processing_time_ms": 34
2302 }"#;
2303 let resp: ExtractEntitiesResponse = serde_json::from_str(json).unwrap();
2304 assert_eq!(resp.entities.len(), 2);
2305 assert_eq!(resp.entities[0].text, "Alice");
2306 assert_eq!(resp.model, "gliner-multi-v2.1");
2307 assert_eq!(resp.processing_time_ms, 34);
2308 }
2309}