use std::future::Future;
use std::sync::Arc;
use std::time::Instant;

use crate::ai::AiProvider;
use crate::json::{Map, Value as JsonValue};
use crate::runtime::ai::dedup_cache::{
    EmbeddingDedupCache, DEFAULT_DEDUP_LRU_SIZE, DEFAULT_DEDUP_TTL_MS,
};
use crate::runtime::ai::text_chunker::{ChunkMode, DEFAULT_MAX_TOKENS};
use crate::runtime::ai::transport::{AiHttpRequest, AiTransport, AiTransportError};
use crate::runtime::audit_log::AuditLogger;

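/// Config key for capping the embedding sub-batch size.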
pub const CONFIG_MAX_BATCH_SIZE: &str = "runtime.ai.embedding_max_batch_size";
pub const DEFAULT_OPENAI_MAX_BATCH: usize = 2048;
pub const DEFAULT_OTHER_MAX_BATCH: usize = 256;

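/// One provider-bound sub-batch: everything the sender needs to build and
/// authorize a single embeddings call.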
pub struct SubBatchRequest {
    pub provider: String,
    pub api_key: String,
    pub api_base: String,
    pub model: String,
    pub inputs: Vec<String>,
}

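/// Parsed provider response for one sub-batch, with retry and latency
/// bookkeeping so the caller can aggregate metrics across sub-batches.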
pub struct SubBatchResponse {
    pub embeddings: Vec<Vec<f32>>,
    pub model: String,
    pub prompt_tokens: Option<u64>,
    pub total_tokens: Option<u64>,
    pub attempt_count: u32,
    pub total_wait_ms: u64,
}

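/// Transport seam for `AiBatchClient`: production code sends over
/// `AiTransport`; tests substitute in-memory mock senders.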
pub trait SubBatchSender: Send + Sync {
    fn send(
        &self,
        request: SubBatchRequest,
    ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_;
}

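/// Default sender that performs a real `POST {api_base}/embeddings` call
/// through the retrying `AiTransport`.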
pub struct AiTransportSender {
    pub transport: AiTransport,
}

impl SubBatchSender for AiTransportSender {
    #[allow(clippy::manual_async_fn)]
    fn send(
        &self,
        request: SubBatchRequest,
    ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_ {
        async move {
            let payload = crate::ai::build_embedding_payload(&request.model, &request.inputs);
            let url = format!("{}/embeddings", request.api_base.trim_end_matches('/'));
            let http_req = AiHttpRequest::post_json(request.provider.as_str(), url, payload)
                .model(request.model.clone())
                .header("authorization", format!("Bearer {}", request.api_key));

            let response = self.transport.request(http_req).await?;

            let parsed = crate::ai::parse_embedding_response(&response.body).map_err(|msg| {
                AiTransportError {
                    provider: request.provider.clone(),
                    status_code: None,
                    attempt_count: 1,
                    total_wait_ms: 0,
                    message: msg,
                }
            })?;

            Ok(SubBatchResponse {
                embeddings: parsed.embeddings,
                model: parsed.model,
                prompt_tokens: parsed.prompt_tokens,
                total_tokens: parsed.total_tokens,
                attempt_count: response.attempt_count,
                total_wait_ms: response.total_wait_ms,
            })
        }
    }
}

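/// Batching embedding client: splits inputs into provider-sized sub-batches,
/// optionally dedups via an LRU cache, chunks over-long texts, and records
/// metrics plus audit events. Results always line up positionally with the
/// input `Vec`.
///
/// Illustrative sketch only; assumes an already-configured `AiTransport`, and
/// the provider, model name, and key below are placeholders:
///
/// ```ignore
/// let client = AiBatchClient::new(transport)
///     .with_max_batch_size(512)
///     .with_max_tokens(8192);
/// let embeddings = client
///     .embed_batch(&AiProvider::OpenAi, "text-embedding-3-small", api_key, texts)
///     .await?;
/// ```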
pub struct AiBatchClient<S = AiTransportSender> {
    sender: S,
    max_batch_size_override: Option<usize>,
    dedup_cache: Option<Arc<EmbeddingDedupCache>>,
    chunk_mode: ChunkMode,
    max_tokens: usize,
    audit_log: Option<Arc<AuditLogger>>,
}

impl AiBatchClient<AiTransportSender> {
    pub fn new(transport: AiTransport) -> Self {
        Self {
            sender: AiTransportSender { transport },
            max_batch_size_override: None,
            dedup_cache: None,
            chunk_mode: ChunkMode::Single,
            max_tokens: DEFAULT_MAX_TOKENS,
            audit_log: None,
        }
    }

    pub fn from_runtime(runtime: &crate::runtime::RedDBRuntime) -> Self {
        use crate::runtime::ai::dedup_cache::{
            CONFIG_DEDUP_ENABLED, CONFIG_DEDUP_LRU_SIZE, CONFIG_DEDUP_TTL_MS,
        };
        use crate::runtime::ai::text_chunker::{CONFIG_CHUNK_MODE, CONFIG_MAX_TOKENS};
        use std::time::Duration;

        let transport = AiTransport::from_runtime(runtime);
        let dedup_enabled = runtime.config_bool(CONFIG_DEDUP_ENABLED, false);
        let dedup_cache = if dedup_enabled {
            let lru_size =
                runtime.config_u64(CONFIG_DEDUP_LRU_SIZE, DEFAULT_DEDUP_LRU_SIZE as u64) as usize;
            let ttl_ms = runtime.config_u64(CONFIG_DEDUP_TTL_MS, DEFAULT_DEDUP_TTL_MS);
            Some(Arc::new(EmbeddingDedupCache::new(
                lru_size,
                Duration::from_millis(ttl_ms),
            )))
        } else {
            None
        };
        let chunk_mode = ChunkMode::from_str(&runtime.config_string(CONFIG_CHUNK_MODE, "single"));
        let max_tokens = runtime.config_u64(CONFIG_MAX_TOKENS, DEFAULT_MAX_TOKENS as u64) as usize;

        Self {
            sender: AiTransportSender { transport },
            max_batch_size_override: None,
            dedup_cache,
            chunk_mode,
            max_tokens,
            audit_log: Some(runtime.audit_log_arc()),
        }
    }
}

impl<S: SubBatchSender> AiBatchClient<S> {
    pub fn with_sender(sender: S) -> Self {
        Self {
            sender,
            max_batch_size_override: None,
            dedup_cache: None,
            chunk_mode: ChunkMode::Single,
            max_tokens: DEFAULT_MAX_TOKENS,
            audit_log: None,
        }
    }

    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size_override = Some(size.max(1));
        self
    }

    pub fn with_dedup_cache(mut self, cache: Arc<EmbeddingDedupCache>) -> Self {
        self.dedup_cache = Some(cache);
        self
    }

    pub fn with_chunk_mode(mut self, mode: ChunkMode) -> Self {
        self.chunk_mode = mode;
        self
    }

    pub fn with_max_tokens(mut self, max: usize) -> Self {
        self.max_tokens = max.max(1);
        self
    }

    pub fn with_audit_log(mut self, audit_log: Arc<AuditLogger>) -> Self {
        self.audit_log = Some(audit_log);
        self
    }

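    /// Embed `texts` in provider-sized sub-batches, returning one embedding
    /// per input in the same order.
    ///
    /// Empty or whitespace-only inputs yield empty vectors without a provider
    /// call; duplicate texts are sent once and fanned back out; texts over
    /// `max_tokens` are chunked according to `chunk_mode`. The first transport
    /// error aborts the whole batch.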
    pub async fn embed_batch(
        &self,
        provider: &AiProvider,
        model: &str,
        api_key: &str,
        texts: Vec<String>,
    ) -> Result<Vec<Vec<f32>>, AiTransportError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let max_batch = self
            .max_batch_size_override
            .unwrap_or_else(|| default_max_batch_size(provider));
        let api_base = provider.resolve_api_base();
        let started = Instant::now();
        let mut local_dedup_hits = 0u64;
        let mut any_chunked = false;
        let mut retries_total = 0u64;
        let mut total_wait_ms = 0u64;
        let mut prompt_tokens_total = 0u64;
        let mut total_tokens_total = 0u64;

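        // Pass 0: chunk over-long texts; in Single mode only the first chunk
        // of each input is embedded.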
        let mut chunked_texts: Vec<String> = Vec::with_capacity(texts.len());
        for t in &texts {
            let chunks = crate::runtime::ai::text_chunker::chunk(t, self.max_tokens);
            if chunks.len() > 1 {
                any_chunked = true;
            }
            let chosen = crate::runtime::ai::text_chunker::apply_mode(chunks, self.chunk_mode);
            chunked_texts.push(chosen.into_iter().next().unwrap_or_default());
        }

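        // One slot per input position; filled from the cache, from the
        // provider, or left as an empty vector for blank inputs.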
        let mut result: Vec<Option<Vec<f32>>> = vec![None; texts.len()];

        let mut unique_text_index: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        let mut unique_texts_to_embed: Vec<String> = Vec::new();

        let mut pos_to_unique: Vec<Option<usize>> = vec![None; texts.len()];

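        // Pass 1: satisfy what we can from the dedup cache and assign each
        // remaining position a unique-text index so duplicates are sent once.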
        for (i, text) in chunked_texts.iter().enumerate() {
            if text.trim().is_empty() {
                result[i] = Some(vec![]);
                continue;
            }
            if let Some(cache) = &self.dedup_cache {
                if let Some(cached) = cache.get(text) {
                    local_dedup_hits = local_dedup_hits.saturating_add(1);
                    result[i] = Some(cached);
                    continue;
                }
            }
            let unique_idx = if let Some(&existing) = unique_text_index.get(text.as_str()) {
                existing
            } else {
                let idx = unique_texts_to_embed.len();
                unique_text_index.insert(text.clone(), idx);
                unique_texts_to_embed.push(text.clone());
                idx
            };
            pos_to_unique[i] = Some(unique_idx);
        }

        let mut unique_embeddings: Vec<Vec<f32>> = vec![vec![]; unique_texts_to_embed.len()];

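        // Pass 2: send the unique texts to the provider in sub-batches of at
        // most `max_batch`, in order, so flat offsets can be recomputed.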
        for (batch_idx, chunk) in unique_texts_to_embed.chunks(max_batch).enumerate() {
            crate::runtime::ai::metrics::record_batch_size(provider.token(), chunk.len());
            // `chunks` yields windows of `max_batch` items in order, so this
            // sub-batch starts at a flat offset of `batch_idx * max_batch`.
            let chunk_start = batch_idx * max_batch;

            let request = SubBatchRequest {
                provider: provider.token().to_string(),
                api_key: api_key.to_string(),
                api_base: api_base.clone(),
                model: model.to_string(),
                inputs: chunk.to_vec(),
            };

            let response = match self.sender.send(request).await {
                Ok(response) => response,
                Err(err) => {
                    self.record_error_audit(provider.token(), &err);
                    return Err(err);
                }
            };
            retries_total =
                retries_total.saturating_add(u64::from(response.attempt_count.saturating_sub(1)));
            total_wait_ms = total_wait_ms.saturating_add(response.total_wait_ms);
            if let Some(tokens) = response.prompt_tokens {
                prompt_tokens_total = prompt_tokens_total.saturating_add(tokens);
            }
            if let Some(tokens) = response.total_tokens {
                total_tokens_total = total_tokens_total.saturating_add(tokens);
            }
            let token_metric = response
                .prompt_tokens
                .unwrap_or(0)
                .saturating_add(response.total_tokens.unwrap_or(0));
            crate::runtime::ai::metrics::record_tokens(
                provider.token(),
                &response.model,
                token_metric,
            );
            let embeddings = response.embeddings;

            if embeddings.len() != chunk.len() {
                let err = AiTransportError {
                    provider: provider.token().to_string(),
                    status_code: None,
                    attempt_count: 0,
                    total_wait_ms: 0,
                    message: format!(
                        "provider returned {} embeddings for {} inputs",
                        embeddings.len(),
                        chunk.len()
                    ),
                };
                self.record_error_audit(provider.token(), &err);
                return Err(err);
            }

            for (j, embedding) in embeddings.into_iter().enumerate() {
                let unique_idx = chunk_start + j;
                if let Some(cache) = &self.dedup_cache {
                    cache.insert(&unique_texts_to_embed[unique_idx], embedding.clone());
                }
                unique_embeddings[unique_idx] = embedding;
            }
        }

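        // Pass 3: fan the unique embeddings back out to every input position.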
        for (i, unique_idx_opt) in pos_to_unique.into_iter().enumerate() {
            if let Some(unique_idx) = unique_idx_opt {
                result[i] = Some(unique_embeddings[unique_idx].clone());
            }
        }

        self.record_batch_audit(BatchAudit {
            provider: provider.token(),
            model,
            batch_size: texts.len(),
            total_tokens: total_tokens_total,
            duration_ms: millis_u64(started.elapsed()),
            retries: retries_total,
            dedup_hits: local_dedup_hits,
            chunked: any_chunked,
            total_wait_ms,
            prompt_tokens: prompt_tokens_total,
        });

        Ok(result.into_iter().map(|v| v.unwrap_or_default()).collect())
    }

    fn record_batch_audit(&self, audit: BatchAudit<'_>) {
        tracing::info!(
            target: "reddb::developer",
            provider = audit.provider,
            model = audit.model,
            batch_size = audit.batch_size,
            total_tokens = audit.total_tokens,
            duration_ms = audit.duration_ms,
            retries = audit.retries,
            dedup_hits = audit.dedup_hits,
            chunked = audit.chunked,
            "ai embedding batch completed"
        );

        let Some(audit_log) = &self.audit_log else {
            return;
        };
        let mut details = Map::new();
        details.insert(
            "provider".to_string(),
            JsonValue::String(audit.provider.to_string()),
        );
        details.insert(
            "model".to_string(),
            JsonValue::String(audit.model.to_string()),
        );
        details.insert(
            "batch_size".to_string(),
            JsonValue::Number(audit.batch_size as f64),
        );
        details.insert(
            "total_tokens".to_string(),
            JsonValue::Number(audit.total_tokens as f64),
        );
        details.insert(
            "duration_ms".to_string(),
            JsonValue::Number(audit.duration_ms as f64),
        );
        details.insert(
            "retries".to_string(),
            JsonValue::Number(audit.retries as f64),
        );
        details.insert(
            "dedup_hits".to_string(),
            JsonValue::Number(audit.dedup_hits as f64),
        );
        details.insert("chunked".to_string(), JsonValue::Bool(audit.chunked));
        details.insert(
            "total_wait_ms".to_string(),
            JsonValue::Number(audit.total_wait_ms as f64),
        );
        details.insert(
            "prompt_tokens".to_string(),
            JsonValue::Number(audit.prompt_tokens as f64),
        );
        audit_log.record(
            "ai/embedding_batch",
            "system",
            audit.provider,
            "ok",
            JsonValue::Object(details),
        );
    }

    fn record_error_audit(&self, provider: &str, err: &AiTransportError) {
        tracing::warn!(
            target: "reddb::developer",
            provider = provider,
            status_code = err.status_code.unwrap_or(0),
            attempt_count = err.attempt_count,
            total_wait_ms = err.total_wait_ms,
            "ai embedding provider error"
        );

        let Some(audit_log) = &self.audit_log else {
            return;
        };
        let mut details = Map::new();
        details.insert(
            "provider".to_string(),
            JsonValue::String(provider.to_string()),
        );
        details.insert(
            "status_code".to_string(),
            err.status_code
                .map(|status| JsonValue::Number(status as f64))
                .unwrap_or(JsonValue::Null),
        );
        details.insert(
            "attempt_count".to_string(),
            JsonValue::Number(err.attempt_count as f64),
        );
        details.insert(
            "total_wait_ms".to_string(),
            JsonValue::Number(err.total_wait_ms as f64),
        );
        audit_log.record(
            "ai/embedding_error",
            "system",
            provider,
            "error",
            JsonValue::Object(details),
        );
    }
}

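/// Per-batch summary passed to the structured audit log and tracing output.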
struct BatchAudit<'a> {
    provider: &'a str,
    model: &'a str,
    batch_size: usize,
    total_tokens: u64,
    duration_ms: u64,
    retries: u64,
    dedup_hits: u64,
    chunked: bool,
    total_wait_ms: u64,
    prompt_tokens: u64,
}

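/// Convert a `Duration` to whole milliseconds, saturating at `u64::MAX`.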
fn millis_u64(duration: std::time::Duration) -> u64 {
    duration.as_millis().min(u128::from(u64::MAX)) as u64
}

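/// Default sub-batch ceiling per provider family: OpenAI-style and custom
/// providers get the larger 2048 default, everything else a conservative 256.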
fn default_max_batch_size(provider: &AiProvider) -> usize {
    match provider {
        AiProvider::OpenAi
        | AiProvider::OpenRouter
        | AiProvider::Together
        | AiProvider::Venice
        | AiProvider::Groq
        | AiProvider::DeepSeek
        | AiProvider::Custom(_) => DEFAULT_OPENAI_MAX_BATCH,
        _ => DEFAULT_OTHER_MAX_BATCH,
    }
}

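// The tests below exercise the client against in-memory mock senders; no
// network calls or provider credentials are involved.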
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    struct MockSender {
        call_count: Arc<AtomicUsize>,
        dims: usize,
    }

    impl SubBatchSender for MockSender {
        fn send(
            &self,
            request: SubBatchRequest,
        ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_ {
            let n = request.inputs.len();
            let dims = self.dims;
            self.call_count.fetch_add(1, Ordering::SeqCst);
            async move {
                Ok(SubBatchResponse {
                    embeddings: (0..n).map(|_| vec![0.1f32; dims]).collect(),
                    model: request.model,
                    prompt_tokens: Some(n as u64),
                    total_tokens: Some(n as u64),
                    attempt_count: 1,
                    total_wait_ms: 0,
                })
            }
        }
    }

    fn mock_client(dims: usize) -> (AiBatchClient<MockSender>, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let client = AiBatchClient::with_sender(MockSender {
            call_count: Arc::clone(&counter),
            dims,
        });
        (client, counter)
    }

    #[tokio::test]
    async fn embed_three_texts_returns_three_vectors() {
        let (client, _) = mock_client(3);
        let result = client
            .embed_batch(
                &AiProvider::OpenAi,
                "model",
                "key",
                vec!["a".into(), "b".into(), "c".into()],
            )
            .await
            .unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|v| v.len() == 3));
    }

    #[tokio::test]
    async fn embed_empty_input_zero_requests() {
        let (client, counter) = mock_client(3);
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", vec![])
            .await
            .unwrap();
        assert!(result.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn embed_1000_inputs_single_request_openai() {
        let (client, counter) = mock_client(4);
        let texts: Vec<String> = (0..1000).map(|i| format!("text {i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result.len(), 1000);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn embed_splits_when_over_max_batch() {
        let (client, counter) = mock_client(2);
        let client = client.with_max_batch_size(3);
        let texts: Vec<String> = (0..7).map(|i| format!("t{i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result.len(), 7);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn embed_records_batch_size_and_token_metrics() {
        let (client, _) = mock_client(2);
        let provider = AiProvider::Custom("test_batch_metrics_provider".to_string());
        let _ = client
            .with_max_batch_size(2)
            .embed_batch(
                &provider,
                "test-batch-metrics-model",
                "key",
                vec!["a".into(), "b".into(), "c".into()],
            )
            .await
            .unwrap();

        let mut body = String::new();
        crate::runtime::ai::metrics::render_ai_metrics(&mut body);
        assert!(
            body.contains(
                "reddb_ai_embedding_batch_size_count{provider=\"test_batch_metrics_provider\"} 2"
            ),
            "{body}"
        );
        assert!(
            body.contains(
                "reddb_ai_text_tokens_total{provider=\"test_batch_metrics_provider\",model=\"test-batch-metrics-model\"} 6"
            ),
            "{body}"
        );
    }

    #[tokio::test]
    async fn embed_empty_strings_skipped_positions_preserved() {
        let (client, counter) = mock_client(2);
        let texts = vec![
            "".to_string(),
            "hello".to_string(),
            " ".to_string(),
            "world".to_string(),
        ];
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result.len(), 4);
        assert!(result[0].is_empty(), "empty string → empty vec");
        assert_eq!(result[1].len(), 2, "hello → embedding");
        assert!(result[2].is_empty(), "whitespace-only → empty vec");
        assert_eq!(result[3].len(), 2, "world → embedding");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn embed_error_propagated() {
        struct ErrorSender;

        impl SubBatchSender for ErrorSender {
            fn send(
                &self,
                request: SubBatchRequest,
            ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_
            {
                async move {
                    Err(AiTransportError {
                        provider: request.provider,
                        status_code: Some(500),
                        attempt_count: 3,
                        total_wait_ms: 2000,
                        message: "server error".to_string(),
                    })
                }
            }
        }

        let client = AiBatchClient::with_sender(ErrorSender);
        let err = client
            .embed_batch(
                &AiProvider::OpenAi,
                "model",
                "key",
                vec!["text".to_string()],
            )
            .await
            .unwrap_err();
        assert_eq!(err.status_code, Some(500));
        assert_eq!(err.attempt_count, 3);
    }

    #[tokio::test]
    async fn embed_writes_structured_audit_line_when_logger_attached() {
        let (client, _) = mock_client(2);
        let dir = tempfile::tempdir().unwrap();
        let audit_path = dir.path().join(".audit.log");
        let audit_log = Arc::new(AuditLogger::with_max_bytes(audit_path, 1024 * 1024));
        let provider = AiProvider::Custom("test_audit_provider".to_string());

        let _ = client
            .with_audit_log(Arc::clone(&audit_log))
            .embed_batch(
                &provider,
                "test-audit-model",
                "key",
                vec!["alpha".into(), "beta".into()],
            )
            .await
            .unwrap();

        assert!(audit_log.wait_idle(Duration::from_secs(2)));
        let body = std::fs::read_to_string(audit_log.path()).unwrap();
        assert!(body.contains("\"action\":\"ai/embedding_batch\""), "{body}");
        assert!(
            body.contains("\"provider\":\"test_audit_provider\""),
            "{body}"
        );
        assert!(body.contains("\"model\":\"test-audit-model\""), "{body}");
        assert!(body.contains("\"batch_size\":2"), "{body}");
        assert!(body.contains("\"total_tokens\":2"), "{body}");
        assert!(body.contains("\"duration_ms\""), "{body}");
        assert!(body.contains("\"retries\":0"), "{body}");
        assert!(body.contains("\"dedup_hits\":0"), "{body}");
        assert!(body.contains("\"chunked\":false"), "{body}");
    }

    #[tokio::test]
    async fn embed_order_preserved_across_batches() {
        struct BatchNumberSender {
            call_count: Arc<AtomicUsize>,
        }

        impl SubBatchSender for BatchNumberSender {
            fn send(
                &self,
                request: SubBatchRequest,
            ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_
            {
                let call = self.call_count.fetch_add(1, Ordering::SeqCst);
                let n = request.inputs.len();
                async move {
                    Ok(SubBatchResponse {
                        embeddings: (0..n).map(|_| vec![call as f32]).collect(),
                        model: request.model,
                        prompt_tokens: Some(n as u64),
                        total_tokens: Some(n as u64),
                        attempt_count: 1,
                        total_wait_ms: 0,
                    })
                }
            }
        }

        let counter = Arc::new(AtomicUsize::new(0));
        let client = AiBatchClient::with_sender(BatchNumberSender {
            call_count: Arc::clone(&counter),
        })
        .with_max_batch_size(3);

        let texts: Vec<String> = (0..5).map(|i| format!("t{i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();

        assert_eq!(result.len(), 5);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(result[0], vec![0.0]);
        assert_eq!(result[1], vec![0.0]);
        assert_eq!(result[2], vec![0.0]);
        assert_eq!(result[3], vec![1.0]);
        assert_eq!(result[4], vec![1.0]);
    }

    #[test]
    fn default_max_batch_size_openai_is_2048() {
        assert_eq!(default_max_batch_size(&AiProvider::OpenAi), 2048);
    }

    #[test]
    fn default_max_batch_size_ollama_is_256() {
        assert_eq!(default_max_batch_size(&AiProvider::Ollama), 256);
    }

    #[tokio::test]
    async fn dedup_on_1000_inputs_10_unique_sends_10_to_provider() {
        let (base_client, counter) = mock_client(4);
        let cache = Arc::new(EmbeddingDedupCache::new(1024, Duration::from_secs(60)));
        let client = base_client.with_dedup_cache(Arc::clone(&cache));

        let unique: Vec<String> = (0..10).map(|i| format!("unique text {i}")).collect();
        let texts: Vec<String> = (0..1000).map(|i| unique[i % 10].clone()).collect();

        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts.clone())
            .await
            .unwrap();

        assert_eq!(result.len(), 1000);
        assert_eq!(counter.load(Ordering::SeqCst), 1, "1 sub-batch request");
        assert_eq!(cache.misses(), 1000);
        assert_eq!(cache.hits(), 0);

        let result2 = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result2.len(), 1000);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "still 1 provider request total"
        );
        assert_eq!(cache.hits(), 1000, "all 1000 hit cache on second call");
    }

    #[tokio::test]
    async fn dedup_off_by_default_all_texts_sent() {
        let (client, counter) = mock_client(4);
        let texts: Vec<String> = (0..10).map(|i| format!("text {i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts.clone())
            .await
            .unwrap();
        assert_eq!(result.len(), 10);
        let _ = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chunker_long_text_truncated_to_first_chunk_single_mode() {
        let (base_client, counter) = mock_client(2);
        let client = base_client.with_max_tokens(10);
        let long_text = "a".repeat(200);
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", vec![long_text])
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}