Skip to main content

ironclad_llm/
lib.rs

1//! # ironclad-llm
2//!
3//! LLM client pipeline for the Ironclad agent runtime. Requests flow through a
4//! multi-stage pipeline: cache check, routing (heuristic or ML), circuit
5//! breaker, dedup, format translation, prompt compression, tier adaptation,
6//! and HTTP forwarding.
7//!
8//! ## Key Types
9//!
10//! - [`LlmService`] -- Top-level facade composing all pipeline stages
11//! - [`SemanticCache`] -- 3-level cache (exact hash, tool TTL, semantic cosine)
12//! - [`ModelRouter`] -- Runtime model selection and override control
13//! - [`LlmClient`] -- HTTP/2 client pool with streaming support
14//! - [`EmbeddingClient`] -- Multi-provider embedding client with n-gram fallback
15//! - [`SseChunkStream`] -- SSE byte stream to parsed `StreamChunk` adapter
16//!
17//! ## Modules
18//!
19//! - `cache` -- Semantic cache with HashMap + SQLite persistence
20//! - `router` -- Heuristic model router (feature extraction, complexity scoring)
21//! - `ml_router` -- Logistic regression backend + preference learning
22//! - `tiered` -- Tiered inference with confidence evaluation and escalation
23//! - `cascade` -- Cascade optimizer (cheapest-first, fallback chain)
24//! - `circuit` -- Per-provider circuit breaker with exponential backoff
25//! - `dedup` -- In-flight duplicate request detection
26//! - `format` -- API format translation (OpenAI, Ollama, Google, Anthropic)
27//! - `compression` -- Prompt compression and token estimation
28//! - `tier` -- Tier-based prompt adaptation (T1 strip, T2 preamble, T3/T4 pass)
29//! - `client` -- HTTP client pool, request forwarding, cost tracking
30//! - `provider` -- Provider definitions and registry
31//! - `embedding` -- Multi-provider embedding client
32//! - `capacity` -- TPM/RPM sliding-window capacity tracking
33//! - `accuracy` -- Per-model quality tracking
34//! - `oauth` -- OAuth2 token management and refresh
35//! - `transform` -- Request/response transform pipeline
36
37pub mod accuracy;
38pub mod cache;
39pub mod capacity;
40pub mod cascade;
41pub mod circuit;
42pub mod client;
43pub mod compression;
44pub mod dedup;
45pub mod embedding;
46/// Offline routing evaluation harness for replaying historical decisions.
47pub mod eval_harness;
48pub mod format;
49pub mod ml_router;
50pub mod oauth;
51pub mod profile;
52pub mod provider;
53pub mod router;
54pub mod tier;
55pub mod tiered;
56
57pub use accuracy::QualityTracker;
58pub use cache::{CachedResponse, ExportedCacheEntry, SemanticCache};
59pub use capacity::CapacityTracker;
60pub use cascade::{CascadeOptimizer, CascadeOutcome, CascadeStrategy};
61pub use circuit::{CircuitBreakerRegistry, CircuitState};
62pub use client::LlmClient;
63pub use compression::{CompressionEstimate, PromptCompressor};
64pub use dedup::DedupTracker;
65pub use embedding::{EmbeddingClient, EmbeddingConfig};
66pub use ml_router::{LogisticBackend, PreferenceCollector, PreferenceRecord};
67pub use oauth::OAuthManager;
68pub use profile::{MetascoreBreakdown, ModelProfile, build_model_profiles, select_by_metascore};
69pub use provider::{Provider, ProviderRegistry};
70pub use router::{ModelRouter, classify_complexity, extract_features};
71pub use tiered::{ConfidenceEvaluator, EscalationTracker, InferenceTier};
72
73pub use format::StreamChunk;
74
75use std::collections::HashMap;
76use std::pin::Pin;
77use std::task::{Context, Poll};
78
79use bytes::Bytes;
80use futures::Stream;
81use std::sync::Arc;
82
83use ironclad_core::{ApiFormat, IroncladConfig, PaymentHandler, Result};
84use router::HeuristicBackend;
85
/// Top-level facade that owns one instance of every pipeline stage.
///
/// Constructed from an [`IroncladConfig`] via [`LlmService::new`]. All fields
/// are public.
86pub struct LlmService {
    /// Response cache (exact hash, tool TTL, and semantic cosine levels).
87    pub cache: SemanticCache,
    /// Per-provider circuit breakers.
88    pub breakers: CircuitBreakerRegistry,
    /// In-flight duplicate request detection.
89    pub dedup: DedupTracker,
    /// Runtime model selection; built with the heuristic backend by `new`.
90    pub router: ModelRouter,
    /// HTTP client used to forward (and stream) requests to providers.
91    pub client: LlmClient,
    /// Registry of providers built from the `providers` config section.
92    pub providers: ProviderRegistry,
    /// TPM/RPM capacity tracking; `new` registers every known provider.
93    pub capacity: CapacityTracker,
    /// Per-model quality tracking.
94    pub quality: QualityTracker,
    /// Confidence evaluation for tiered inference, seeded with the configured
    /// `confidence_floor`.
95    pub confidence: ConfidenceEvaluator,
    /// Escalation tracking for tiered inference.
96    pub escalation: EscalationTracker,
    /// Embedding client; configured from the memory settings when an
    /// embedding provider can be resolved.
97    pub embedding: EmbeddingClient,
98}
99
100impl LlmService {
101    pub fn new(config: &IroncladConfig) -> Result<Self> {
102        let cache = SemanticCache::with_threshold(
103            config.cache.enabled,
104            config.cache.exact_match_ttl_seconds,
105            config.cache.max_entries,
106            config.cache.semantic_threshold as f32,
107        );
108
109        let breakers = CircuitBreakerRegistry::new(&config.circuit_breaker);
110
111        let dedup = DedupTracker::default();
112
113        let routing_config = config.models.routing.clone();
114
115        let router = ModelRouter::new(
116            config.models.primary.clone(),
117            config.models.fallbacks.clone(),
118            routing_config,
119            Box::new(HeuristicBackend),
120        );
121
122        let client = LlmClient::new()?;
123
124        let providers = ProviderRegistry::from_config(&config.providers);
125
126        let capacity = CapacityTracker::new(60);
127        for provider in providers.list() {
128            capacity.register(&provider.name, provider.tpm_limit, provider.rpm_limit);
129        }
130
131        let quality = QualityTracker::new(100);
132        let confidence = ConfidenceEvaluator::new(config.models.tiered_inference.confidence_floor);
133        let escalation = EscalationTracker::default();
134
135        let embedding_config = Self::resolve_embedding_config(&config.memory, &providers);
136        let embedding = EmbeddingClient::new(embedding_config)?;
137
138        Ok(Self {
139            cache,
140            breakers,
141            dedup,
142            router,
143            client,
144            providers,
145            capacity,
146            quality,
147            confidence,
148            escalation,
149            embedding,
150        })
151    }
152
153    /// Inject an x402 payment handler so the LLM client can autonomously pay
154    /// for 402-gated resources. Call this after construction when the wallet
155    /// is available.
156    pub fn set_payment_handler(&mut self, handler: Arc<dyn PaymentHandler>) {
157        self.client = self.client.clone().with_payment_handler(handler);
158    }
159
160    /// Stream a request to the given provider, returning parsed `StreamChunk`s.
161    ///
162    /// The caller is responsible for provider selection and key resolution.
163    /// `body` should already be translated via `format::translate_request`.
164    /// This method injects `"stream": true` into the body before sending.
165    pub async fn stream_to_provider(
166        &self,
167        url: String,
168        api_key: String,
169        mut body: serde_json::Value,
170        auth_header: String,
171        extra_headers: HashMap<String, String>,
172        api_format: ApiFormat,
173    ) -> Result<SseChunkStream> {
174        body["stream"] = serde_json::json!(true);
175
176        let raw_stream = self
177            .client
178            .forward_stream(&url, &api_key, body, &auth_header, &extra_headers)
179            .await?;
180
181        Ok(SseChunkStream::new(raw_stream, api_format))
182    }
183
184    fn resolve_embedding_config(
185        memory: &ironclad_core::config::MemoryConfig,
186        providers: &ProviderRegistry,
187    ) -> Option<EmbeddingConfig> {
188        let provider_name = memory.embedding_provider.as_deref()?;
189        let provider = providers.get(provider_name)?;
190        let embedding_path = provider.embedding_path.as_deref()?;
191
192        let model = memory
193            .embedding_model
194            .clone()
195            .or_else(|| provider.embedding_model.clone())?;
196
197        let dimensions = provider.embedding_dimensions.unwrap_or(768);
198
199        Some(EmbeddingConfig {
200            base_url: provider.url.clone(),
201            embedding_path: embedding_path.to_string(),
202            model,
203            dimensions,
204            format: provider.format,
205            api_key_env: provider.api_key_env.clone(),
206            auth_header: provider.auth_header.clone(),
207            extra_headers: provider.extra_headers.clone(),
208            is_local: provider.is_local,
209        })
210    }
211}
212
/// Upper bound, in bytes, on buffered-but-unparsed SSE data (10 MB).
///
/// When the decoded text plus any undecoded trailing bytes exceed this, the
/// stream yields an error, guarding against unbounded memory growth from a
/// misbehaving provider.
const MAX_SSE_BUFFER: usize = 10 << 20;
216
217/// A `Stream` adapter that converts raw SSE byte chunks from an LLM provider
218/// into parsed `StreamChunk` items. Handles buffering across chunk boundaries
219/// with proper incremental UTF-8 decoding.
220pub struct SseChunkStream {
    /// Raw byte stream being adapted; polled only when `text_buffer` holds no
    /// complete line.
221    inner: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
    /// Wire format handed to `format::parse_sse_chunk` for every line.
222    format: ApiFormat,
223    /// Validated UTF-8 text ready for line parsing.
224    text_buffer: String,
225    /// Raw byte buffer holding trailing bytes from an incomplete UTF-8 sequence.
226    /// These bytes are prepended to the next incoming chunk before decoding.
227    raw_tail: Vec<u8>,
228    /// Chunks parsed from the buffer remainder when the inner stream ends.
229    /// Drained before returning `None` to avoid dropping trailing data.
230    pending: std::collections::VecDeque<format::StreamChunk>,
    /// Set once `inner` must no longer be polled; from then on only `pending`
    /// is drained before the stream returns `None`.
231    inner_done: bool,
232}
233
234impl SseChunkStream {
235    pub fn new(
236        inner: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
237        format: ApiFormat,
238    ) -> Self {
239        Self {
240            inner,
241            format,
242            text_buffer: String::new(),
243            raw_tail: Vec::new(),
244            pending: std::collections::VecDeque::new(),
245            inner_done: false,
246        }
247    }
248}
249
250impl Stream for SseChunkStream {
251    type Item = Result<format::StreamChunk>;
252
253    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254        let this = self.get_mut();
255
256        // Drain any chunks buffered from the final flush before signaling end-of-stream
257        if let Some(chunk) = this.pending.pop_front() {
258            return Poll::Ready(Some(Ok(chunk)));
259        }
260        if this.inner_done {
261            return Poll::Ready(None);
262        }
263
264        loop {
265            // First, try to parse a complete line from the text buffer
266            if let Some(newline_pos) = this.text_buffer.find('\n') {
267                let line = this.text_buffer[..newline_pos].trim().to_string();
268                this.text_buffer = this.text_buffer[newline_pos + 1..].to_string();
269
270                if line.is_empty() {
271                    continue;
272                }
273
274                if let Some(chunk) = format::parse_sse_chunk(&line, &this.format) {
275                    return Poll::Ready(Some(Ok(chunk)));
276                }
277                continue;
278            }
279
280            // No complete line in buffer -- poll for more bytes
281            match Pin::new(&mut this.inner).poll_next(cx) {
282                Poll::Ready(Some(Ok(bytes))) => {
283                    // Prepend any leftover incomplete UTF-8 bytes from the previous chunk
284                    let combined = if this.raw_tail.is_empty() {
285                        bytes.to_vec()
286                    } else {
287                        let mut buf = std::mem::take(&mut this.raw_tail);
288                        buf.extend_from_slice(&bytes);
289                        buf
290                    };
291
292                    // Decode as much valid UTF-8 as possible, keeping any
293                    // incomplete trailing sequence for the next chunk.
294                    match std::str::from_utf8(&combined) {
295                        Ok(valid) => {
296                            this.text_buffer.push_str(valid);
297                        }
298                        Err(e) => {
299                            let valid_up_to = e.valid_up_to();
300                            // valid_up_to is a confirmed UTF-8 boundary from Utf8Error.
301                            let valid = std::str::from_utf8(&combined[..valid_up_to])
302                                .expect("valid_up_to guarantees valid UTF-8");
303                            this.text_buffer.push_str(valid);
304                            this.raw_tail = combined[valid_up_to..].to_vec();
305                        }
306                    }
307
308                    // Guard against unbounded buffer growth
309                    if this.text_buffer.len() + this.raw_tail.len() > MAX_SSE_BUFFER {
310                        return Poll::Ready(Some(Err(ironclad_core::IroncladError::Llm(
311                            "SSE stream buffer exceeded 10 MB limit".into(),
312                        ))));
313                    }
314                }
315                Poll::Ready(Some(Err(e))) => {
316                    return Poll::Ready(Some(Err(ironclad_core::IroncladError::Network(format!(
317                        "stream error: {e}"
318                    )))));
319                }
320                Poll::Ready(None) => {
321                    this.inner_done = true;
322
323                    // Convert any remaining raw tail bytes lossily (stream ended
324                    // mid-character, so these are genuinely malformed).
325                    if !this.raw_tail.is_empty() {
326                        let tail = std::mem::take(&mut this.raw_tail);
327                        this.text_buffer.push_str(&String::from_utf8_lossy(&tail));
328                    }
329
330                    // Parse ALL remaining lines and queue them for delivery
331                    if !this.text_buffer.trim().is_empty() {
332                        let remaining = std::mem::take(&mut this.text_buffer);
333                        for line in remaining.lines() {
334                            let line = line.trim();
335                            if line.is_empty() {
336                                continue;
337                            }
338                            if let Some(chunk) = format::parse_sse_chunk(line, &this.format) {
339                                this.pending.push_back(chunk);
340                            }
341                        }
342                    }
343                    return match this.pending.pop_front() {
344                        Some(chunk) => Poll::Ready(Some(Ok(chunk))),
345                        None => Poll::Ready(None),
346                    };
347                }
348                Poll::Pending => return Poll::Pending,
349            }
350        }
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    use futures::stream;

    // ── LlmService construction ───────────────────────────────

    #[test]
    fn llm_service_construction() {
        let toml = r#"
[agent]
name = "TestBot"
id = "test"

[server]
port = 9999

[database]
path = "/tmp/test.db"

[models]
primary = "ollama/qwen3:8b"
fallbacks = ["openai/gpt-4o"]

[providers.ollama]
url = "http://localhost:11434"
tier = "T1"

[providers.openai]
url = "https://api.openai.com"
tier = "T3"
"#;
        let config = IroncladConfig::from_str(toml).unwrap();
        let service = LlmService::new(&config).unwrap();

        assert_eq!(service.router.select_model(), "ollama/qwen3:8b");
        assert_eq!(service.cache.size(), 0);
        assert!(service.providers.get("ollama").is_some());
        assert!(service.providers.get("openai").is_some());
        assert!(!service.embedding.has_provider());
    }

    #[test]
    fn llm_service_with_embedding_provider() {
        let toml = r#"
[agent]
name = "TestBot"
id = "test"

[server]
port = 9999

[database]
path = "/tmp/test.db"

[models]
primary = "ollama/qwen3:8b"

[memory]
embedding_provider = "ollama"

[providers.ollama]
url = "http://localhost:11434"
tier = "T1"
embedding_path = "/api/embed"
embedding_model = "nomic-embed-text"
embedding_dimensions = 768
"#;
        let config = IroncladConfig::from_str(toml).unwrap();
        let service = LlmService::new(&config).unwrap();
        assert!(service.embedding.has_provider());
        assert_eq!(service.embedding.dimensions(), 768);
    }

    // ── resolve_embedding_config ──────────────────────────────

    #[test]
    fn resolve_embedding_config_no_provider() {
        let memory = ironclad_core::config::MemoryConfig::default();
        let providers = ProviderRegistry::new();
        let result = LlmService::resolve_embedding_config(&memory, &providers);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_embedding_config_missing_provider() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("nonexistent".into()),
            ..Default::default()
        };
        let providers = ProviderRegistry::new();
        let result = LlmService::resolve_embedding_config(&memory, &providers);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_embedding_config_provider_no_embedding_path() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("anthropic".into()),
            ..Default::default()
        };
        let mut providers_cfg = std::collections::HashMap::new();
        providers_cfg.insert(
            "anthropic".to_string(),
            ironclad_core::config::ProviderConfig::new("https://api.anthropic.com", "T3"),
        );
        let providers = ProviderRegistry::from_config(&providers_cfg);
        let result = LlmService::resolve_embedding_config(&memory, &providers);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_embedding_config_uses_memory_model_override() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("openai".into()),
            embedding_model: Some("text-embedding-3-large".into()),
            ..Default::default()
        };
        let mut cfg = ironclad_core::config::ProviderConfig::new("https://api.openai.com", "T3");
        cfg.embedding_path = Some("/v1/embeddings".into());
        cfg.embedding_model = Some("text-embedding-3-small".into());
        cfg.embedding_dimensions = Some(1536);
        let mut providers_cfg = std::collections::HashMap::new();
        providers_cfg.insert("openai".to_string(), cfg);
        let providers = ProviderRegistry::from_config(&providers_cfg);

        let result = LlmService::resolve_embedding_config(&memory, &providers).unwrap();
        assert_eq!(result.model, "text-embedding-3-large");
        assert_eq!(result.dimensions, 1536);
    }

    #[test]
    fn resolve_embedding_config_falls_back_to_provider_model() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("ollama".into()),
            embedding_model: None,
            ..Default::default()
        };
        let mut cfg = ironclad_core::config::ProviderConfig::new("http://localhost:11434", "T1");
        cfg.embedding_path = Some("/api/embed".into());
        cfg.embedding_model = Some("nomic-embed-text".into());
        cfg.embedding_dimensions = Some(768);
        let mut providers_cfg = std::collections::HashMap::new();
        providers_cfg.insert("ollama".to_string(), cfg);
        let providers = ProviderRegistry::from_config(&providers_cfg);

        let result = LlmService::resolve_embedding_config(&memory, &providers).unwrap();
        assert_eq!(result.model, "nomic-embed-text");
        assert_eq!(result.dimensions, 768);
        assert_eq!(result.base_url, "http://localhost:11434");
        assert_eq!(result.embedding_path, "/api/embed");
    }

    #[test]
    fn resolve_embedding_config_default_dimensions() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("custom".into()),
            embedding_model: Some("my-model".into()),
            ..Default::default()
        };
        let mut cfg = ironclad_core::config::ProviderConfig::new("http://localhost:8080", "T1");
        cfg.embedding_path = Some("/embed".into());
        cfg.embedding_model = Some("my-model".into());
        // No dimensions set — should default to 768
        let mut providers_cfg = std::collections::HashMap::new();
        providers_cfg.insert("custom".to_string(), cfg);
        let providers = ProviderRegistry::from_config(&providers_cfg);

        let result = LlmService::resolve_embedding_config(&memory, &providers).unwrap();
        assert_eq!(result.dimensions, 768);
    }

    // ── SseChunkStream helpers ────────────────────────────────

    /// Helper: feed `data` through an `SseChunkStream` using `api_format`,
    /// drive it to completion and collect every item (including errors).
    fn collect_sse_results_as(
        data: Vec<Vec<u8>>,
        api_format: ApiFormat,
    ) -> Vec<Result<format::StreamChunk>> {
        let byte_stream = stream::iter(
            data.into_iter()
                .map(|b| Ok::<_, reqwest::Error>(Bytes::from(b))),
        );
        let mut sse = SseChunkStream::new(Box::pin(byte_stream), api_format);

        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        rt.block_on(async {
            let mut items = vec![];
            while let Some(item) = futures::StreamExt::next(&mut sse).await {
                items.push(item);
            }
            items
        })
    }

    /// Helper: like `collect_sse_results_as`, but panics on any error item.
    fn collect_sse_chunks_as(data: Vec<Vec<u8>>, api_format: ApiFormat) -> Vec<format::StreamChunk> {
        collect_sse_results_as(data, api_format)
            .into_iter()
            .map(|item| item.unwrap())
            .collect()
    }

    /// Helper: drive an OpenAI-format `SseChunkStream` and collect all items
    /// (including errors).
    fn collect_sse_results(data: Vec<Vec<u8>>) -> Vec<Result<format::StreamChunk>> {
        collect_sse_results_as(data, ApiFormat::OpenAiCompletions)
    }

    /// Helper: drive an OpenAI-format `SseChunkStream` to completion and
    /// collect all chunks.
    fn collect_sse_chunks(data: Vec<Vec<u8>>) -> Vec<format::StreamChunk> {
        collect_sse_chunks_as(data, ApiFormat::OpenAiCompletions)
    }

    // ── SseChunkStream tests ──────────────────────────────────

    #[test]
    fn sse_chunk_stream_multiple_trailing_chunks() {
        let data = vec![
            b"data: {\"choices\":[{\"delta\":{\"content\":\"A\"}}]}\ndata: {\"choices\":[{\"delta\":{\"content\":\"B\"}}]}\n".to_vec(),
            b"data: {\"choices\":[{\"delta\":{\"content\":\"C\"}}]}\ndata: {\"choices\":[{\"delta\":{\"content\":\"D\"}}]}".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        let text: String = chunks.iter().map(|c| c.delta.as_str()).collect();
        assert_eq!(text, "ABCD", "all four chunks should be yielded");
    }

    #[test]
    fn sse_chunk_stream_trailing_done_not_lost() {
        // The unterminated `[DONE]` line must not swallow the chunk before it;
        // this test expects `[DONE]` itself to produce no chunk.
        let data = vec![
            b"data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\ndata: [DONE]".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "hello");
    }

    #[test]
    fn sse_chunk_stream_empty_buffer_at_end() {
        let data = vec![b"data: {\"choices\":[{\"delta\":{\"content\":\"only\"}}]}\n".to_vec()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "only");
    }

    // ── SseChunkStream additional edge cases ──────────────────────

    #[test]
    fn sse_chunk_stream_empty_input() {
        let chunks = collect_sse_chunks(vec![]);
        assert!(chunks.is_empty());
    }

    #[test]
    fn sse_chunk_stream_empty_bytes() {
        let chunks = collect_sse_chunks(vec![b"".to_vec()]);
        assert!(chunks.is_empty());
    }

    #[test]
    fn sse_chunk_stream_only_whitespace_lines() {
        let data = vec![b"\n\n\n".to_vec()];
        let chunks = collect_sse_chunks(data);
        assert!(chunks.is_empty());
    }

    #[test]
    fn sse_chunk_stream_non_data_lines_skipped() {
        let data = vec![
            b"event: message\nid: 123\ndata: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n"
                .to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "ok");
    }

    #[test]
    fn sse_chunk_stream_split_across_boundaries() {
        // Split a single SSE line across two byte chunks
        let data = vec![
            b"data: {\"choices\":[{\"del".to_vec(),
            b"ta\":{\"content\":\"split\"}}]}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "split");
    }

    #[test]
    fn sse_chunk_stream_split_utf8_boundary() {
        // Multi-byte UTF-8 char split across chunk boundary
        // "Hello\xC3" in chunk 1, "\xA9world" in chunk 2 (copyright sign = 0xC3 0xA9)
        let data = vec![
            b"data: {\"choices\":[{\"delta\":{\"content\":\"Hello\xC3".to_vec(),
            b"\xA9world\"}}]}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        // The split character must be reassembled exactly, not dropped or
        // replaced with U+FFFD.
        assert_eq!(chunks[0].delta, "Hello\u{a9}world");
    }

    #[test]
    fn sse_chunk_stream_invalid_utf8_is_replaced() {
        // 0xFF can never appear in UTF-8; it should become U+FFFD without
        // losing the text around it.
        let data = vec![b"data: {\"choices\":[{\"delta\":{\"content\":\"a\xFFb\"}}]}\n".to_vec()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "a\u{fffd}b");
    }

    #[test]
    fn sse_chunk_stream_multiple_lines_in_one_chunk() {
        let data = vec![
            b"data: {\"choices\":[{\"delta\":{\"content\":\"A\"}}]}\ndata: {\"choices\":[{\"delta\":{\"content\":\"B\"}}]}\ndata: {\"choices\":[{\"delta\":{\"content\":\"C\"}}]}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].delta, "A");
        assert_eq!(chunks[1].delta, "B");
        assert_eq!(chunks[2].delta, "C");
    }

    #[test]
    fn sse_chunk_stream_buffer_overflow_error() {
        // Create a chunk large enough to exceed the 10 MB limit
        let huge = vec![b'x'; 11 * 1024 * 1024];
        let results = collect_sse_results(vec![huge]);
        let last = results.last().unwrap();
        assert!(last.is_err());
        let err_msg = format!("{}", last.as_ref().unwrap_err());
        assert!(
            err_msg.contains("10 MB"),
            "error should mention buffer limit: {err_msg}"
        );
    }

    #[test]
    fn sse_chunk_stream_anthropic_format() {
        // Test with Anthropic format
        let data = vec![
            b"data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks_as(data, ApiFormat::AnthropicMessages);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "Hi");
    }

    #[test]
    fn sse_chunk_stream_google_format() {
        let data = vec![
            b"data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Gemini\"}],\"role\":\"model\"}}]}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks_as(data, ApiFormat::GoogleGenerativeAi);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "Gemini");
    }

    #[test]
    fn sse_chunk_stream_trailing_data_no_newline() {
        // Data that doesn't end with a newline should still be parsed on stream end
        let data = vec![b"data: {\"choices\":[{\"delta\":{\"content\":\"tail\"}}]}".to_vec()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "tail");
    }

    #[test]
    fn sse_chunk_stream_pending_queue_drains_correctly() {
        // Multiple trailing lines with no final newline
        let data = vec![
            b"data: {\"choices\":[{\"delta\":{\"content\":\"X\"}}]}\ndata: {\"choices\":[{\"delta\":{\"content\":\"Y\"}}]}".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        let text: String = chunks.iter().map(|c| c.delta.as_str()).collect();
        assert_eq!(text, "XY");
    }
}
742        let chunks = collect_sse_chunks(data);
743        let text: String = chunks.iter().map(|c| c.delta.as_str()).collect();
744        assert_eq!(text, "XY");
745    }
746}