
reddb_server/runtime/ai/batch_client.rs

//! Batch embedding client — issue #275.
//!
//! Accepts a `Vec<String>`, splits into sub-batches up to `max_batch_size`,
//! sends each via `AiTransport` with retry, and reassembles results in the
//! original order. Empty texts are skipped; their positions get an empty
//! `Vec<f32>` in the output.
//!
//! Issue #277 adds optional dedup cache and text chunking.
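//!
//! # Example
//!
//! A minimal usage sketch; the transport value, provider, model name, and key
//! below are placeholders rather than prescribed values:
//!
//! ```ignore
//! let client = AiBatchClient::new(transport).with_max_batch_size(512);
//! let texts = vec!["alpha".to_string(), "beta".to_string()];
//! let expected = texts.len();
//! let embeddings = client
//!     .embed_batch(&AiProvider::OpenAi, "embedding-model", "api-key", texts)
//!     .await?;
//! // One vector per input, in the original order.
//! assert_eq!(embeddings.len(), expected);
//! ```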

use std::future::Future;
use std::sync::Arc;
use std::time::Instant;

use crate::ai::AiProvider;
use crate::json::{Map, Value as JsonValue};
use crate::runtime::ai::dedup_cache::{
    EmbeddingDedupCache, DEFAULT_DEDUP_LRU_SIZE, DEFAULT_DEDUP_TTL_MS,
};
use crate::runtime::ai::text_chunker::{ChunkMode, DEFAULT_MAX_TOKENS};
use crate::runtime::ai::transport::{AiHttpRequest, AiTransport, AiTransportError};
use crate::runtime::audit_log::AuditLogger;

pub const CONFIG_MAX_BATCH_SIZE: &str = "runtime.ai.embedding_max_batch_size";
pub const DEFAULT_OPENAI_MAX_BATCH: usize = 2048;
pub const DEFAULT_OTHER_MAX_BATCH: usize = 256;

/// One sub-batch worth of work.
pub struct SubBatchRequest {
    pub provider: String,
    pub api_key: String,
    pub api_base: String,
    pub model: String,
    pub inputs: Vec<String>,
}

pub struct SubBatchResponse {
    pub embeddings: Vec<Vec<f32>>,
    pub model: String,
    pub prompt_tokens: Option<u64>,
    pub total_tokens: Option<u64>,
    pub attempt_count: u32,
    pub total_wait_ms: u64,
}

/// Backend abstraction. Production uses `AiTransportSender`; tests use mocks.
pub trait SubBatchSender: Send + Sync {
    fn send(
        &self,
        request: SubBatchRequest,
    ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_;
}

/// Production backend: routes sub-batches through `AiTransport`.
pub struct AiTransportSender {
    pub transport: AiTransport,
}

impl SubBatchSender for AiTransportSender {
    #[allow(clippy::manual_async_fn)]
    fn send(
        &self,
        request: SubBatchRequest,
    ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_ {
        async move {
            let payload = crate::ai::build_embedding_payload(&request.model, &request.inputs);
            let url = format!("{}/embeddings", request.api_base.trim_end_matches('/'));
            let http_req = AiHttpRequest::post_json(request.provider.as_str(), url, payload)
                .model(request.model.clone())
                .header("authorization", format!("Bearer {}", request.api_key));

            let response = self.transport.request(http_req).await?;

            let parsed = crate::ai::parse_embedding_response(&response.body).map_err(|msg| {
                AiTransportError {
                    provider: request.provider.clone(),
                    status_code: None,
                    attempt_count: 1,
                    total_wait_ms: 0,
                    message: msg,
                }
            })?;

            Ok(SubBatchResponse {
                embeddings: parsed.embeddings,
                model: parsed.model,
                prompt_tokens: parsed.prompt_tokens,
                total_tokens: parsed.total_tokens,
                attempt_count: response.attempt_count,
                total_wait_ms: response.total_wait_ms,
            })
        }
    }
}

/// Batch embedding client.
///
/// Generic over the backend so tests can inject mocks without HTTP.
/// The default type parameter is `AiTransportSender` (production).
pub struct AiBatchClient<S = AiTransportSender> {
    sender: S,
    max_batch_size_override: Option<usize>,
    /// Optional dedup cache. None = dedup disabled (default).
    dedup_cache: Option<Arc<EmbeddingDedupCache>>,
    /// Chunk mode applied before sending. Default = Single.
    chunk_mode: ChunkMode,
    /// Max tokens per chunk (approximate: 1 token ≈ 4 bytes).
    max_tokens: usize,
    audit_log: Option<Arc<AuditLogger>>,
}

impl AiBatchClient<AiTransportSender> {
    pub fn new(transport: AiTransport) -> Self {
        Self {
            sender: AiTransportSender { transport },
            max_batch_size_override: None,
            dedup_cache: None,
            chunk_mode: ChunkMode::Single,
            max_tokens: DEFAULT_MAX_TOKENS,
            audit_log: None,
        }
    }

    pub fn from_runtime(runtime: &crate::runtime::RedDBRuntime) -> Self {
        use crate::runtime::ai::dedup_cache::{
            CONFIG_DEDUP_ENABLED, CONFIG_DEDUP_LRU_SIZE, CONFIG_DEDUP_TTL_MS,
        };
        use crate::runtime::ai::text_chunker::{CONFIG_CHUNK_MODE, CONFIG_MAX_TOKENS};
        use std::time::Duration;

        let transport = AiTransport::from_runtime(runtime);
        let dedup_enabled = runtime.config_bool(CONFIG_DEDUP_ENABLED, false);
        let dedup_cache = if dedup_enabled {
            let lru_size =
                runtime.config_u64(CONFIG_DEDUP_LRU_SIZE, DEFAULT_DEDUP_LRU_SIZE as u64) as usize;
            let ttl_ms = runtime.config_u64(CONFIG_DEDUP_TTL_MS, DEFAULT_DEDUP_TTL_MS);
            Some(Arc::new(EmbeddingDedupCache::new(
                lru_size,
                Duration::from_millis(ttl_ms),
            )))
        } else {
            None
        };
        let chunk_mode = ChunkMode::from_str(&runtime.config_string(CONFIG_CHUNK_MODE, "single"));
        let max_tokens = runtime.config_u64(CONFIG_MAX_TOKENS, DEFAULT_MAX_TOKENS as u64) as usize;

        Self {
            sender: AiTransportSender { transport },
            max_batch_size_override: None,
            dedup_cache,
            chunk_mode,
            max_tokens,
            audit_log: Some(runtime.audit_log_arc()),
        }
    }
}

impl<S: SubBatchSender> AiBatchClient<S> {
    /// Create with a custom backend (useful in tests).
    pub fn with_sender(sender: S) -> Self {
        Self {
            sender,
            max_batch_size_override: None,
            dedup_cache: None,
            chunk_mode: ChunkMode::Single,
            max_tokens: DEFAULT_MAX_TOKENS,
            audit_log: None,
        }
    }

    /// Override the max sub-batch size (defaults per provider if not set).
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size_override = Some(size.max(1));
        self
    }

    /// Enable dedup cache.
    pub fn with_dedup_cache(mut self, cache: Arc<EmbeddingDedupCache>) -> Self {
        self.dedup_cache = Some(cache);
        self
    }

    /// Set chunk mode (Single or Multi).
    pub fn with_chunk_mode(mut self, mode: ChunkMode) -> Self {
        self.chunk_mode = mode;
        self
    }

    /// Set max tokens per chunk.
    pub fn with_max_tokens(mut self, max: usize) -> Self {
        self.max_tokens = max.max(1);
        self
    }

    pub fn with_audit_log(mut self, audit_log: Arc<AuditLogger>) -> Self {
        self.audit_log = Some(audit_log);
        self
    }

    /// Embed `texts` in batch. Returns one `Vec<f32>` per input in order.
    /// Empty/whitespace-only inputs yield an empty `Vec<f32>` at their position
    /// without consuming a provider request slot.
    ///
    /// When dedup is enabled, previously-seen texts are served from cache and
    /// only unseen texts are sent to the provider. Duplicate texts within a
    /// single call are also deduplicated — the provider receives each unique
    /// text only once.
    ///
    /// When chunking is enabled, texts exceeding `max_tokens` are chunked;
    /// in Single mode the first chunk is sent to the provider.
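    ///
    /// # Example
    ///
    /// A minimal sketch; `client`, the model name, and the key are placeholders:
    ///
    /// ```ignore
    /// let vectors = client
    ///     .embed_batch(
    ///         &AiProvider::OpenAi,
    ///         "embedding-model",
    ///         "api-key",
    ///         vec!["hello".to_string(), String::new()],
    ///     )
    ///     .await?;
    /// assert_eq!(vectors.len(), 2);
    /// // The empty input yields an empty vector and never reaches the provider.
    /// assert!(vectors[1].is_empty());
    /// ```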
    pub async fn embed_batch(
        &self,
        provider: &AiProvider,
        model: &str,
        api_key: &str,
        texts: Vec<String>,
    ) -> Result<Vec<Vec<f32>>, AiTransportError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let max_batch = self
            .max_batch_size_override
            .unwrap_or_else(|| default_max_batch_size(provider));
        let api_base = provider.resolve_api_base();
        let started = Instant::now();
        let mut local_dedup_hits = 0u64;
        let mut any_chunked = false;
        let mut retries_total = 0u64;
        let mut total_wait_ms = 0u64;
        let mut prompt_tokens_total = 0u64;
        let mut total_tokens_total = 0u64;

        // Step 1: apply chunking — each text → representative text to embed.
        // In Single mode this is the first (or only) chunk.
        let mut chunked_texts: Vec<String> = Vec::with_capacity(texts.len());
        for t in &texts {
            let chunks = crate::runtime::ai::text_chunker::chunk(t, self.max_tokens);
            if chunks.len() > 1 {
                any_chunked = true;
            }
            let chosen = crate::runtime::ai::text_chunker::apply_mode(chunks, self.chunk_mode);
            chunked_texts.push(chosen.into_iter().next().unwrap_or_default());
        }

        // Step 2: check dedup cache and collect unique provider misses.
        // result[i] = Some(embedding) when resolved, None when pending.
        let mut result: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        // unique_text_index: text → index into unique_texts_to_embed.
        // unique_texts_to_embed: insertion-ordered unique texts that still
        // need a provider call.
        let mut unique_text_index: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        let mut unique_texts_to_embed: Vec<String> = Vec::new();

        // For each input, map position → unique_texts index (for fan-out later).
        let mut pos_to_unique: Vec<Option<usize>> = vec![None; texts.len()];
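        // Worked example (illustrative): inputs ["", "a", "a", "b"] with a cold
        // cache produce result = [Some(vec![]), None, None, None],
        // unique_texts_to_embed = ["a", "b"], and
        // pos_to_unique = [None, Some(0), Some(0), Some(1)].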

        for (i, text) in chunked_texts.iter().enumerate() {
            if text.trim().is_empty() {
                result[i] = Some(vec![]);
                continue;
            }
            // Warm-cache lookup: texts embedded by earlier calls are served from
            // the dedup cache. Duplicates within this call are handled below via
            // `unique_text_index`; they still register as cache misses here
            // because each position is checked individually.
            if let Some(cache) = &self.dedup_cache {
                if let Some(cached) = cache.get(text) {
                    local_dedup_hits = local_dedup_hits.saturating_add(1);
                    result[i] = Some(cached);
                    continue;
                }
            }
            // Dedup within this batch: if text already queued, reuse its slot.
            let unique_idx = if let Some(&existing) = unique_text_index.get(text.as_str()) {
                existing
            } else {
                let idx = unique_texts_to_embed.len();
                unique_text_index.insert(text.clone(), idx);
                unique_texts_to_embed.push(text.clone());
                idx
            };
            pos_to_unique[i] = Some(unique_idx);
        }
        // Step 3: send unique_texts_to_embed in sub-batches.
        let mut unique_embeddings: Vec<Vec<f32>> = vec![vec![]; unique_texts_to_embed.len()];

        for (batch_idx, chunk) in unique_texts_to_embed.chunks(max_batch).enumerate() {
            crate::runtime::ai::metrics::record_batch_size(provider.token(), chunk.len());
            // Start index of this chunk within unique_texts_to_embed, used to
            // write results back into unique_embeddings. Every chunk except
            // possibly the last holds exactly max_batch items.
            let chunk_start = batch_idx * max_batch;

            let request = SubBatchRequest {
                provider: provider.token().to_string(),
                api_key: api_key.to_string(),
                api_base: api_base.clone(),
                model: model.to_string(),
                inputs: chunk.to_vec(),
            };

            let response = match self.sender.send(request).await {
                Ok(response) => response,
                Err(err) => {
                    self.record_error_audit(provider.token(), &err);
                    return Err(err);
                }
            };
            retries_total =
                retries_total.saturating_add(u64::from(response.attempt_count.saturating_sub(1)));
            total_wait_ms = total_wait_ms.saturating_add(response.total_wait_ms);
            if let Some(tokens) = response.prompt_tokens {
                prompt_tokens_total = prompt_tokens_total.saturating_add(tokens);
            }
            if let Some(tokens) = response.total_tokens {
                total_tokens_total = total_tokens_total.saturating_add(tokens);
            }
            let token_metric = response
                .prompt_tokens
                .unwrap_or(0)
                .saturating_add(response.total_tokens.unwrap_or(0));
            crate::runtime::ai::metrics::record_tokens(
                provider.token(),
                &response.model,
                token_metric,
            );
            let embeddings = response.embeddings;

            if embeddings.len() != chunk.len() {
                let err = AiTransportError {
                    provider: provider.token().to_string(),
                    status_code: None,
                    attempt_count: 0,
                    total_wait_ms: 0,
                    message: format!(
                        "provider returned {} embeddings for {} inputs",
                        embeddings.len(),
                        chunk.len()
                    ),
                };
                self.record_error_audit(provider.token(), &err);
                return Err(err);
            }

            for (j, embedding) in embeddings.into_iter().enumerate() {
                let unique_idx = chunk_start + j;
                // Insert into dedup cache
                if let Some(cache) = &self.dedup_cache {
                    cache.insert(&unique_texts_to_embed[unique_idx], embedding.clone());
                }
                unique_embeddings[unique_idx] = embedding;
            }
        }

        // Step 4: fan-out unique_embeddings back to result positions.
        for (i, unique_idx_opt) in pos_to_unique.into_iter().enumerate() {
            if let Some(unique_idx) = unique_idx_opt {
                result[i] = Some(unique_embeddings[unique_idx].clone());
            }
        }

        self.record_batch_audit(BatchAudit {
            provider: provider.token(),
            model,
            batch_size: texts.len(),
            total_tokens: total_tokens_total,
            duration_ms: millis_u64(started.elapsed()),
            retries: retries_total,
            dedup_hits: local_dedup_hits,
            chunked: any_chunked,
            total_wait_ms,
            prompt_tokens: prompt_tokens_total,
        });

        Ok(result.into_iter().map(|v| v.unwrap_or_default()).collect())
    }

    fn record_batch_audit(&self, audit: BatchAudit<'_>) {
        tracing::info!(
            target: "reddb::developer",
            provider = audit.provider,
            model = audit.model,
            batch_size = audit.batch_size,
            total_tokens = audit.total_tokens,
            duration_ms = audit.duration_ms,
            retries = audit.retries,
            dedup_hits = audit.dedup_hits,
            chunked = audit.chunked,
            "ai embedding batch completed"
        );

        let Some(audit_log) = &self.audit_log else {
            return;
        };
        let mut details = Map::new();
        details.insert(
            "provider".to_string(),
            JsonValue::String(audit.provider.to_string()),
        );
        details.insert(
            "model".to_string(),
            JsonValue::String(audit.model.to_string()),
        );
        details.insert(
            "batch_size".to_string(),
            JsonValue::Number(audit.batch_size as f64),
        );
        details.insert(
            "total_tokens".to_string(),
            JsonValue::Number(audit.total_tokens as f64),
        );
        details.insert(
            "duration_ms".to_string(),
            JsonValue::Number(audit.duration_ms as f64),
        );
        details.insert(
            "retries".to_string(),
            JsonValue::Number(audit.retries as f64),
        );
        details.insert(
            "dedup_hits".to_string(),
            JsonValue::Number(audit.dedup_hits as f64),
        );
        details.insert("chunked".to_string(), JsonValue::Bool(audit.chunked));
        details.insert(
            "total_wait_ms".to_string(),
            JsonValue::Number(audit.total_wait_ms as f64),
        );
        details.insert(
            "prompt_tokens".to_string(),
            JsonValue::Number(audit.prompt_tokens as f64),
        );
        audit_log.record(
            "ai/embedding_batch",
            "system",
            audit.provider,
            "ok",
            JsonValue::Object(details),
        );
    }

    fn record_error_audit(&self, provider: &str, err: &AiTransportError) {
        tracing::warn!(
            target: "reddb::developer",
            provider = provider,
            status_code = err.status_code.unwrap_or(0),
            attempt_count = err.attempt_count,
            total_wait_ms = err.total_wait_ms,
            "ai embedding provider error"
        );

        let Some(audit_log) = &self.audit_log else {
            return;
        };
        let mut details = Map::new();
        details.insert(
            "provider".to_string(),
            JsonValue::String(provider.to_string()),
        );
        details.insert(
            "status_code".to_string(),
            err.status_code
                .map(|status| JsonValue::Number(status as f64))
                .unwrap_or(JsonValue::Null),
        );
        details.insert(
            "attempt_count".to_string(),
            JsonValue::Number(err.attempt_count as f64),
        );
        details.insert(
            "total_wait_ms".to_string(),
            JsonValue::Number(err.total_wait_ms as f64),
        );
        audit_log.record(
            "ai/embedding_error",
            "system",
            provider,
            "error",
            JsonValue::Object(details),
        );
    }
}

struct BatchAudit<'a> {
    provider: &'a str,
    model: &'a str,
    batch_size: usize,
    total_tokens: u64,
    duration_ms: u64,
    retries: u64,
    dedup_hits: u64,
    chunked: bool,
    total_wait_ms: u64,
    prompt_tokens: u64,
}

fn millis_u64(duration: std::time::Duration) -> u64 {
    duration.as_millis().min(u128::from(u64::MAX)) as u64
}

fn default_max_batch_size(provider: &AiProvider) -> usize {
    match provider {
        AiProvider::OpenAi
        | AiProvider::OpenRouter
        | AiProvider::Together
        | AiProvider::Venice
        | AiProvider::Groq
        | AiProvider::DeepSeek
        | AiProvider::Custom(_) => DEFAULT_OPENAI_MAX_BATCH,
        _ => DEFAULT_OTHER_MAX_BATCH,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    struct MockSender {
        call_count: Arc<AtomicUsize>,
        dims: usize,
    }

    impl SubBatchSender for MockSender {
        fn send(
            &self,
            request: SubBatchRequest,
        ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_ {
            let n = request.inputs.len();
            let dims = self.dims;
            self.call_count.fetch_add(1, Ordering::SeqCst);
            async move {
                Ok(SubBatchResponse {
                    embeddings: (0..n).map(|_| vec![0.1f32; dims]).collect(),
                    model: request.model,
                    prompt_tokens: Some(n as u64),
                    total_tokens: Some(n as u64),
                    attempt_count: 1,
                    total_wait_ms: 0,
                })
            }
        }
    }

    fn mock_client(dims: usize) -> (AiBatchClient<MockSender>, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let client = AiBatchClient::with_sender(MockSender {
            call_count: Arc::clone(&counter),
            dims,
        });
        (client, counter)
    }

    #[tokio::test]
    async fn embed_three_texts_returns_three_vectors() {
        let (client, _) = mock_client(3);
        let result = client
            .embed_batch(
                &AiProvider::OpenAi,
                "model",
                "key",
                vec!["a".into(), "b".into(), "c".into()],
            )
            .await
            .unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|v| v.len() == 3));
    }

    #[tokio::test]
    async fn embed_empty_input_zero_requests() {
        let (client, counter) = mock_client(3);
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", vec![])
            .await
            .unwrap();
        assert!(result.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn embed_1000_inputs_single_request_openai() {
        let (client, counter) = mock_client(4);
        let texts: Vec<String> = (0..1000).map(|i| format!("text {i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result.len(), 1000);
        // 1000 < DEFAULT_OPENAI_MAX_BATCH (2048) → exactly 1 request
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn embed_splits_when_over_max_batch() {
        let (client, counter) = mock_client(2);
        let client = client.with_max_batch_size(3);
        let texts: Vec<String> = (0..7).map(|i| format!("t{i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result.len(), 7);
        // ceil(7/3) = 3 batches
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn embed_records_batch_size_and_token_metrics() {
        let (client, _) = mock_client(2);
        let provider = AiProvider::Custom("test_batch_metrics_provider".to_string());
        let _ = client
            .with_max_batch_size(2)
            .embed_batch(
                &provider,
                "test-batch-metrics-model",
                "key",
                vec!["a".into(), "b".into(), "c".into()],
            )
            .await
            .unwrap();

        let mut body = String::new();
        crate::runtime::ai::metrics::render_ai_metrics(&mut body);
        assert!(
            body.contains(
                "reddb_ai_embedding_batch_size_count{provider=\"test_batch_metrics_provider\"} 2"
            ),
            "{body}"
        );
        assert!(
            body.contains(
                "reddb_ai_text_tokens_total{provider=\"test_batch_metrics_provider\",model=\"test-batch-metrics-model\"} 6"
            ),
            "{body}"
        );
    }

    #[tokio::test]
    async fn embed_empty_strings_skipped_positions_preserved() {
        let (client, counter) = mock_client(2);
        let texts = vec![
            "".to_string(),
            "hello".to_string(),
            "  ".to_string(),
            "world".to_string(),
        ];
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result.len(), 4);
        assert!(result[0].is_empty(), "empty string → empty vec");
        assert_eq!(result[1].len(), 2, "hello → embedding");
        assert!(result[2].is_empty(), "whitespace-only → empty vec");
        assert_eq!(result[3].len(), 2, "world → embedding");
        // Only 2 non-empty texts → 1 request
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn embed_error_propagated() {
        struct ErrorSender;

        impl SubBatchSender for ErrorSender {
            fn send(
                &self,
                request: SubBatchRequest,
            ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_
            {
                async move {
                    Err(AiTransportError {
                        provider: request.provider,
                        status_code: Some(500),
                        attempt_count: 3,
                        total_wait_ms: 2000,
                        message: "server error".to_string(),
                    })
                }
            }
        }

        let client = AiBatchClient::with_sender(ErrorSender);
        let err = client
            .embed_batch(
                &AiProvider::OpenAi,
                "model",
                "key",
                vec!["text".to_string()],
            )
            .await
            .unwrap_err();
        assert_eq!(err.status_code, Some(500));
        assert_eq!(err.attempt_count, 3);
    }

    #[tokio::test]
    async fn embed_writes_structured_audit_line_when_logger_attached() {
        let (client, _) = mock_client(2);
        let dir = tempfile::tempdir().unwrap();
        let audit_path = dir.path().join(".audit.log");
        let audit_log = Arc::new(AuditLogger::with_max_bytes(audit_path, 1024 * 1024));
        let provider = AiProvider::Custom("test_audit_provider".to_string());

        let _ = client
            .with_audit_log(Arc::clone(&audit_log))
            .embed_batch(
                &provider,
                "test-audit-model",
                "key",
                vec!["alpha".into(), "beta".into()],
            )
            .await
            .unwrap();

        assert!(audit_log.wait_idle(Duration::from_secs(2)));
        let body = std::fs::read_to_string(audit_log.path()).unwrap();
        assert!(body.contains("\"action\":\"ai/embedding_batch\""), "{body}");
        assert!(
            body.contains("\"provider\":\"test_audit_provider\""),
            "{body}"
        );
        assert!(body.contains("\"model\":\"test-audit-model\""), "{body}");
        assert!(body.contains("\"batch_size\":2"), "{body}");
        assert!(body.contains("\"total_tokens\":2"), "{body}");
        assert!(body.contains("\"duration_ms\""), "{body}");
        assert!(body.contains("\"retries\":0"), "{body}");
        assert!(body.contains("\"dedup_hits\":0"), "{body}");
        assert!(body.contains("\"chunked\":false"), "{body}");
    }

    #[tokio::test]
    async fn embed_order_preserved_across_batches() {
        struct BatchNumberSender {
            call_count: Arc<AtomicUsize>,
        }

        impl SubBatchSender for BatchNumberSender {
            fn send(
                &self,
                request: SubBatchRequest,
            ) -> impl Future<Output = Result<SubBatchResponse, AiTransportError>> + Send + '_
            {
                let call = self.call_count.fetch_add(1, Ordering::SeqCst);
                let n = request.inputs.len();
                async move {
                    // encode batch number as first float for order verification
                    Ok(SubBatchResponse {
                        embeddings: (0..n).map(|_| vec![call as f32]).collect(),
                        model: request.model,
                        prompt_tokens: Some(n as u64),
                        total_tokens: Some(n as u64),
                        attempt_count: 1,
                        total_wait_ms: 0,
                    })
                }
            }
        }

        let counter = Arc::new(AtomicUsize::new(0));
        let client = AiBatchClient::with_sender(BatchNumberSender {
            call_count: Arc::clone(&counter),
        })
        .with_max_batch_size(3);

        // 5 texts → 2 batches (3 + 2)
        let texts: Vec<String> = (0..5).map(|i| format!("t{i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();

        assert_eq!(result.len(), 5);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        // First 3 from batch 0
        assert_eq!(result[0], vec![0.0]);
        assert_eq!(result[1], vec![0.0]);
        assert_eq!(result[2], vec![0.0]);
        // Last 2 from batch 1
        assert_eq!(result[3], vec![1.0]);
        assert_eq!(result[4], vec![1.0]);
    }

    #[tokio::test]
    async fn default_max_batch_size_openai_is_2048() {
        assert_eq!(default_max_batch_size(&AiProvider::OpenAi), 2048);
    }

    #[tokio::test]
    async fn default_max_batch_size_ollama_is_256() {
        assert_eq!(default_max_batch_size(&AiProvider::Ollama), 256);
    }

    // ── Issue #277: dedup cache tests ──────────────────────────────────────

    #[tokio::test]
    async fn dedup_on_1000_inputs_10_unique_sends_10_to_provider() {
        let (base_client, counter) = mock_client(4);
        let cache = Arc::new(EmbeddingDedupCache::new(1024, Duration::from_secs(60)));
        let client = base_client.with_dedup_cache(Arc::clone(&cache));

        let unique: Vec<String> = (0..10).map(|i| format!("unique text {i}")).collect();
        let texts: Vec<String> = (0..1000).map(|i| unique[i % 10].clone()).collect();

        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts.clone())
            .await
            .unwrap();

        assert_eq!(result.len(), 1000);
        // Intra-batch dedup: provider receives 10 unique texts in 1 sub-batch.
        assert_eq!(counter.load(Ordering::SeqCst), 1, "1 sub-batch request");
        // First call: all 1000 texts check the cache → 1000 misses (cache empty).
        // Intra-batch duplicates are deduplicated via HashMap, but still count
        // as cache misses since each text is checked individually.
        assert_eq!(cache.misses(), 1000);
        assert_eq!(cache.hits(), 0);

        // Second call with same texts: cache now has 10 entries → all 1000
        // input texts hit cache (10 unique + 990 duplicates each hit).
        let result2 = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        assert_eq!(result2.len(), 1000);
        // No new provider call — all served from cache
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "still 1 provider request total"
        );
        assert_eq!(cache.hits(), 1000, "all 1000 hit cache on second call");
    }

    #[tokio::test]
    async fn dedup_off_by_default_all_texts_sent() {
        let (client, counter) = mock_client(4);
        // no dedup cache attached
        let texts: Vec<String> = (0..10).map(|i| format!("text {i}")).collect();
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts.clone())
            .await
            .unwrap();
        assert_eq!(result.len(), 10);
        // Second call with same texts — still sends all (no cache)
        let _ = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", texts)
            .await
            .unwrap();
        // 2 calls, each 1 sub-batch
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chunker_long_text_truncated_to_first_chunk_single_mode() {
        // 1 token ≈ 4 bytes; max_tokens=10 → max 40 bytes
        let (base_client, counter) = mock_client(2);
        let client = base_client.with_max_tokens(10); // 40 byte chunks

        let long_text = "a".repeat(200); // 200 bytes >> 40 byte limit
        let result = client
            .embed_batch(&AiProvider::OpenAi, "model", "key", vec![long_text])
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        // provider received 1 item (first chunk only in Single mode)
    }
}