1pub mod accuracy;
38pub mod cache;
39pub mod capacity;
40pub mod cascade;
41pub mod circuit;
42pub mod client;
43pub mod compression;
44pub mod dedup;
45pub mod embedding;
46pub mod eval_harness;
48pub mod format;
49pub mod ml_router;
50pub mod oauth;
51pub mod profile;
52pub mod provider;
53pub mod router;
54pub mod tier;
55pub mod tiered;
56
57pub use accuracy::QualityTracker;
58pub use cache::{CachedResponse, ExportedCacheEntry, SemanticCache};
59pub use capacity::CapacityTracker;
60pub use cascade::{CascadeOptimizer, CascadeOutcome, CascadeStrategy};
61pub use circuit::{CircuitBreakerRegistry, CircuitState};
62pub use client::LlmClient;
63pub use compression::{CompressionEstimate, PromptCompressor};
64pub use dedup::DedupTracker;
65pub use embedding::{EmbeddingClient, EmbeddingConfig};
66pub use ml_router::{LogisticBackend, PreferenceCollector, PreferenceRecord};
67pub use oauth::OAuthManager;
68pub use profile::{MetascoreBreakdown, ModelProfile, build_model_profiles, select_by_metascore};
69pub use provider::{Provider, ProviderRegistry};
70pub use router::{ModelRouter, classify_complexity, extract_features};
71pub use tiered::{ConfidenceEvaluator, EscalationTracker, InferenceTier};
72
73pub use format::StreamChunk;
74
75use std::collections::HashMap;
76use std::pin::Pin;
77use std::task::{Context, Poll};
78
79use bytes::Bytes;
80use futures::Stream;
81use std::sync::Arc;
82
83use ironclad_core::{ApiFormat, IroncladConfig, PaymentHandler, Result};
84use router::HeuristicBackend;
85
86pub struct LlmService {
87 pub cache: SemanticCache,
88 pub breakers: CircuitBreakerRegistry,
89 pub dedup: DedupTracker,
90 pub router: ModelRouter,
91 pub client: LlmClient,
92 pub providers: ProviderRegistry,
93 pub capacity: CapacityTracker,
94 pub quality: QualityTracker,
95 pub confidence: ConfidenceEvaluator,
96 pub escalation: EscalationTracker,
97 pub embedding: EmbeddingClient,
98}
99
100impl LlmService {
101 pub fn new(config: &IroncladConfig) -> Result<Self> {
102 let cache = SemanticCache::with_threshold(
103 config.cache.enabled,
104 config.cache.exact_match_ttl_seconds,
105 config.cache.max_entries,
106 config.cache.semantic_threshold as f32,
107 );
108
109 let breakers = CircuitBreakerRegistry::new(&config.circuit_breaker);
110
111 let dedup = DedupTracker::default();
112
113 let routing_config = config.models.routing.clone();
114
115 let router = ModelRouter::new(
116 config.models.primary.clone(),
117 config.models.fallbacks.clone(),
118 routing_config,
119 Box::new(HeuristicBackend),
120 );
121
122 let client = LlmClient::new()?;
123
124 let providers = ProviderRegistry::from_config(&config.providers);
125
126 let capacity = CapacityTracker::new(60);
127 for provider in providers.list() {
128 capacity.register(&provider.name, provider.tpm_limit, provider.rpm_limit);
129 }
130
131 let quality = QualityTracker::new(100);
132 let confidence = ConfidenceEvaluator::new(config.models.tiered_inference.confidence_floor);
133 let escalation = EscalationTracker::default();
134
135 let embedding_config = Self::resolve_embedding_config(&config.memory, &providers);
136 let embedding = EmbeddingClient::new(embedding_config)?;
137
138 Ok(Self {
139 cache,
140 breakers,
141 dedup,
142 router,
143 client,
144 providers,
145 capacity,
146 quality,
147 confidence,
148 escalation,
149 embedding,
150 })
151 }
152
153 pub fn set_payment_handler(&mut self, handler: Arc<dyn PaymentHandler>) {
157 self.client = self.client.clone().with_payment_handler(handler);
158 }
159
160 pub async fn stream_to_provider(
166 &self,
167 url: String,
168 api_key: String,
169 mut body: serde_json::Value,
170 auth_header: String,
171 extra_headers: HashMap<String, String>,
172 api_format: ApiFormat,
173 ) -> Result<SseChunkStream> {
174 body["stream"] = serde_json::json!(true);
175
176 let raw_stream = self
177 .client
178 .forward_stream(&url, &api_key, body, &auth_header, &extra_headers)
179 .await?;
180
181 Ok(SseChunkStream::new(raw_stream, api_format))
182 }
183
184 fn resolve_embedding_config(
185 memory: &ironclad_core::config::MemoryConfig,
186 providers: &ProviderRegistry,
187 ) -> Option<EmbeddingConfig> {
188 let provider_name = memory.embedding_provider.as_deref()?;
189 let provider = providers.get(provider_name)?;
190 let embedding_path = provider.embedding_path.as_deref()?;
191
192 let model = memory
193 .embedding_model
194 .clone()
195 .or_else(|| provider.embedding_model.clone())?;
196
197 let dimensions = provider.embedding_dimensions.unwrap_or(768);
198
199 Some(EmbeddingConfig {
200 base_url: provider.url.clone(),
201 embedding_path: embedding_path.to_string(),
202 model,
203 dimensions,
204 format: provider.format,
205 api_key_env: provider.api_key_env.clone(),
206 auth_header: provider.auth_header.clone(),
207 extra_headers: provider.extra_headers.clone(),
208 is_local: provider.is_local,
209 })
210 }
211}
212
/// Maximum number of bytes of undelivered SSE data that may be buffered (10 MiB).
const MAX_SSE_BUFFER: usize = 10 << 20;
216
217pub struct SseChunkStream {
221 inner: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
222 format: ApiFormat,
223 text_buffer: String,
225 raw_tail: Vec<u8>,
228 pending: std::collections::VecDeque<format::StreamChunk>,
231 inner_done: bool,
232}
233
234impl SseChunkStream {
235 pub fn new(
236 inner: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
237 format: ApiFormat,
238 ) -> Self {
239 Self {
240 inner,
241 format,
242 text_buffer: String::new(),
243 raw_tail: Vec::new(),
244 pending: std::collections::VecDeque::new(),
245 inner_done: false,
246 }
247 }
248}
249
250impl Stream for SseChunkStream {
251 type Item = Result<format::StreamChunk>;
252
253 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254 let this = self.get_mut();
255
256 if let Some(chunk) = this.pending.pop_front() {
258 return Poll::Ready(Some(Ok(chunk)));
259 }
260 if this.inner_done {
261 return Poll::Ready(None);
262 }
263
264 loop {
265 if let Some(newline_pos) = this.text_buffer.find('\n') {
267 let line = this.text_buffer[..newline_pos].trim().to_string();
268 this.text_buffer = this.text_buffer[newline_pos + 1..].to_string();
269
270 if line.is_empty() {
271 continue;
272 }
273
274 if let Some(chunk) = format::parse_sse_chunk(&line, &this.format) {
275 return Poll::Ready(Some(Ok(chunk)));
276 }
277 continue;
278 }
279
280 match Pin::new(&mut this.inner).poll_next(cx) {
282 Poll::Ready(Some(Ok(bytes))) => {
283 let combined = if this.raw_tail.is_empty() {
285 bytes.to_vec()
286 } else {
287 let mut buf = std::mem::take(&mut this.raw_tail);
288 buf.extend_from_slice(&bytes);
289 buf
290 };
291
292 match std::str::from_utf8(&combined) {
295 Ok(valid) => {
296 this.text_buffer.push_str(valid);
297 }
298 Err(e) => {
299 let valid_up_to = e.valid_up_to();
300 let valid = std::str::from_utf8(&combined[..valid_up_to])
302 .expect("valid_up_to guarantees valid UTF-8");
303 this.text_buffer.push_str(valid);
304 this.raw_tail = combined[valid_up_to..].to_vec();
305 }
306 }
307
308 if this.text_buffer.len() + this.raw_tail.len() > MAX_SSE_BUFFER {
310 return Poll::Ready(Some(Err(ironclad_core::IroncladError::Llm(
311 "SSE stream buffer exceeded 10 MB limit".into(),
312 ))));
313 }
314 }
315 Poll::Ready(Some(Err(e))) => {
316 return Poll::Ready(Some(Err(ironclad_core::IroncladError::Network(format!(
317 "stream error: {e}"
318 )))));
319 }
320 Poll::Ready(None) => {
321 this.inner_done = true;
322
323 if !this.raw_tail.is_empty() {
326 let tail = std::mem::take(&mut this.raw_tail);
327 this.text_buffer.push_str(&String::from_utf8_lossy(&tail));
328 }
329
330 if !this.text_buffer.trim().is_empty() {
332 let remaining = std::mem::take(&mut this.text_buffer);
333 for line in remaining.lines() {
334 let line = line.trim();
335 if line.is_empty() {
336 continue;
337 }
338 if let Some(chunk) = format::parse_sse_chunk(line, &this.format) {
339 this.pending.push_back(chunk);
340 }
341 }
342 }
343 return match this.pending.pop_front() {
344 Some(chunk) => Poll::Ready(Some(Ok(chunk))),
345 None => Poll::Ready(None),
346 };
347 }
348 Poll::Pending => return Poll::Pending,
349 }
350 }
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Runs an [`SseChunkStream`] over `data` to completion on a throwaway
    /// current-thread runtime and returns every item it yielded, errors included.
    fn run_sse(data: Vec<Vec<u8>>, api_format: ApiFormat) -> Vec<Result<format::StreamChunk>> {
        let byte_stream = stream::iter(
            data.into_iter()
                .map(|b| Ok::<_, reqwest::Error>(Bytes::from(b))),
        );
        let mut sse = SseChunkStream::new(Box::pin(byte_stream), api_format);

        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        rt.block_on(async {
            let mut items = vec![];
            while let Some(item) = futures::StreamExt::next(&mut sse).await {
                items.push(item);
            }
            items
        })
    }

    /// Like [`run_sse`], but panics if the stream yielded any error.
    fn collect_chunks(data: Vec<Vec<u8>>, api_format: ApiFormat) -> Vec<format::StreamChunk> {
        run_sse(data, api_format)
            .into_iter()
            .map(|item| item.unwrap())
            .collect()
    }

    /// [`collect_chunks`] for the OpenAI completions wire format.
    fn collect_sse_chunks(data: Vec<Vec<u8>>) -> Vec<format::StreamChunk> {
        collect_chunks(data, ApiFormat::OpenAiCompletions)
    }

    /// One OpenAI-style SSE `data:` line (no trailing newline) carrying `content`.
    fn openai_event(content: &str) -> String {
        format!(r#"data: {{"choices":[{{"delta":{{"content":"{content}"}}}}]}}"#)
    }

    /// Builds a registry that contains exactly one provider.
    fn registry_with(name: &str, cfg: ironclad_core::config::ProviderConfig) -> ProviderRegistry {
        let mut providers_cfg = std::collections::HashMap::new();
        providers_cfg.insert(name.to_string(), cfg);
        ProviderRegistry::from_config(&providers_cfg)
    }

    #[test]
    fn llm_service_construction() {
        let toml = r#"
[agent]
name = "TestBot"
id = "test"

[server]
port = 9999

[database]
path = "/tmp/test.db"

[models]
primary = "ollama/qwen3:8b"
fallbacks = ["openai/gpt-4o"]

[providers.ollama]
url = "http://localhost:11434"
tier = "T1"

[providers.openai]
url = "https://api.openai.com"
tier = "T3"
"#;
        let config = IroncladConfig::from_str(toml).unwrap();
        let service = LlmService::new(&config).unwrap();

        assert_eq!(service.router.select_model(), "ollama/qwen3:8b");
        assert_eq!(service.cache.size(), 0);
        assert!(service.providers.get("ollama").is_some());
        assert!(service.providers.get("openai").is_some());
        assert!(!service.embedding.has_provider());
    }

    #[test]
    fn llm_service_with_embedding_provider() {
        let toml = r#"
[agent]
name = "TestBot"
id = "test"

[server]
port = 9999

[database]
path = "/tmp/test.db"

[models]
primary = "ollama/qwen3:8b"

[memory]
embedding_provider = "ollama"

[providers.ollama]
url = "http://localhost:11434"
tier = "T1"
embedding_path = "/api/embed"
embedding_model = "nomic-embed-text"
embedding_dimensions = 768
"#;
        let config = IroncladConfig::from_str(toml).unwrap();
        let service = LlmService::new(&config).unwrap();
        assert!(service.embedding.has_provider());
        assert_eq!(service.embedding.dimensions(), 768);
    }

    #[test]
    fn resolve_embedding_config_no_provider() {
        let memory = ironclad_core::config::MemoryConfig::default();
        let providers = ProviderRegistry::new();
        let result = LlmService::resolve_embedding_config(&memory, &providers);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_embedding_config_missing_provider() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("nonexistent".into()),
            ..Default::default()
        };
        let providers = ProviderRegistry::new();
        let result = LlmService::resolve_embedding_config(&memory, &providers);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_embedding_config_provider_no_embedding_path() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("anthropic".into()),
            ..Default::default()
        };
        let providers = registry_with(
            "anthropic",
            ironclad_core::config::ProviderConfig::new("https://api.anthropic.com", "T3"),
        );
        let result = LlmService::resolve_embedding_config(&memory, &providers);
        assert!(result.is_none());
    }

    #[test]
    fn resolve_embedding_config_uses_memory_model_override() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("openai".into()),
            embedding_model: Some("text-embedding-3-large".into()),
            ..Default::default()
        };
        let mut cfg = ironclad_core::config::ProviderConfig::new("https://api.openai.com", "T3");
        cfg.embedding_path = Some("/v1/embeddings".into());
        cfg.embedding_model = Some("text-embedding-3-small".into());
        cfg.embedding_dimensions = Some(1536);
        let providers = registry_with("openai", cfg);

        let result = LlmService::resolve_embedding_config(&memory, &providers).unwrap();
        assert_eq!(result.model, "text-embedding-3-large");
        assert_eq!(result.dimensions, 1536);
    }

    #[test]
    fn resolve_embedding_config_falls_back_to_provider_model() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("ollama".into()),
            embedding_model: None,
            ..Default::default()
        };
        let mut cfg = ironclad_core::config::ProviderConfig::new("http://localhost:11434", "T1");
        cfg.embedding_path = Some("/api/embed".into());
        cfg.embedding_model = Some("nomic-embed-text".into());
        cfg.embedding_dimensions = Some(768);
        let providers = registry_with("ollama", cfg);

        let result = LlmService::resolve_embedding_config(&memory, &providers).unwrap();
        assert_eq!(result.model, "nomic-embed-text");
        assert_eq!(result.dimensions, 768);
        assert_eq!(result.base_url, "http://localhost:11434");
        assert_eq!(result.embedding_path, "/api/embed");
    }

    #[test]
    fn resolve_embedding_config_default_dimensions() {
        let memory = ironclad_core::config::MemoryConfig {
            embedding_provider: Some("custom".into()),
            embedding_model: Some("my-model".into()),
            ..Default::default()
        };
        let mut cfg = ironclad_core::config::ProviderConfig::new("http://localhost:8080", "T1");
        cfg.embedding_path = Some("/embed".into());
        cfg.embedding_model = Some("my-model".into());
        // `embedding_dimensions` is deliberately left unset to exercise the default.
        let providers = registry_with("custom", cfg);

        let result = LlmService::resolve_embedding_config(&memory, &providers).unwrap();
        assert_eq!(result.dimensions, 768);
    }

    #[test]
    fn sse_chunk_stream_multiple_trailing_chunks() {
        let data = vec![
            format!("{}\n{}\n", openai_event("A"), openai_event("B")).into_bytes(),
            format!("{}\n{}", openai_event("C"), openai_event("D")).into_bytes(),
        ];
        let chunks = collect_sse_chunks(data);
        let text: String = chunks.iter().map(|c| c.delta.as_str()).collect();
        assert_eq!(text, "ABCD", "all four chunks should be yielded");
    }

    #[test]
    fn sse_chunk_stream_trailing_done_not_lost() {
        let data = vec![format!("{}\ndata: [DONE]", openai_event("hello")).into_bytes()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "hello");
    }

    #[test]
    fn sse_chunk_stream_empty_buffer_at_end() {
        let data = vec![format!("{}\n", openai_event("only")).into_bytes()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "only");
    }

    #[test]
    fn sse_chunk_stream_empty_input() {
        let chunks = collect_sse_chunks(vec![]);
        assert!(chunks.is_empty());
    }

    #[test]
    fn sse_chunk_stream_empty_bytes() {
        let chunks = collect_sse_chunks(vec![b"".to_vec()]);
        assert!(chunks.is_empty());
    }

    #[test]
    fn sse_chunk_stream_only_whitespace_lines() {
        let data = vec![b"\n\n\n".to_vec()];
        let chunks = collect_sse_chunks(data);
        assert!(chunks.is_empty());
    }

    #[test]
    fn sse_chunk_stream_non_data_lines_skipped() {
        let data =
            vec![format!("event: message\nid: 123\n{}\n", openai_event("ok")).into_bytes()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "ok");
    }

    #[test]
    fn sse_chunk_stream_split_across_boundaries() {
        // One JSON payload delivered in two network chunks.
        let data = vec![
            b"data: {\"choices\":[{\"del".to_vec(),
            b"ta\":{\"content\":\"split\"}}]}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "split");
    }

    #[test]
    fn sse_chunk_stream_split_utf8_boundary() {
        // "é" is 0xC3 0xA9; the two bytes arrive in different network chunks.
        let data = vec![
            b"data: {\"choices\":[{\"delta\":{\"content\":\"Hello\xC3".to_vec(),
            b"\xA9world\"}}]}\n".to_vec(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        // Assert the exact text so a dropped or replaced character fails the test.
        assert_eq!(chunks[0].delta, "Hello\u{e9}world");
    }

    #[test]
    fn sse_chunk_stream_multiple_lines_in_one_chunk() {
        let data = vec![
            format!(
                "{}\n{}\n{}\n",
                openai_event("A"),
                openai_event("B"),
                openai_event("C")
            )
            .into_bytes(),
        ];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].delta, "A");
        assert_eq!(chunks[1].delta, "B");
        assert_eq!(chunks[2].delta, "C");
    }

    #[test]
    fn sse_chunk_stream_buffer_overflow_error() {
        // 11 MiB with no newline can never be consumed as a line.
        let huge = vec![b'x'; 11 * 1024 * 1024];
        let results = run_sse(vec![huge], ApiFormat::OpenAiCompletions);
        let last = results.last().unwrap();
        assert!(last.is_err());
        let err_msg = format!("{}", last.as_ref().unwrap_err());
        assert!(
            err_msg.contains("10 MB"),
            "error should mention buffer limit: {err_msg}"
        );
    }

    #[test]
    fn sse_chunk_stream_anthropic_format() {
        let data = vec![
            b"data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n".to_vec(),
        ];
        let chunks = collect_chunks(data, ApiFormat::AnthropicMessages);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "Hi");
    }

    #[test]
    fn sse_chunk_stream_google_format() {
        let data = vec![
            b"data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Gemini\"}],\"role\":\"model\"}}]}\n".to_vec(),
        ];
        let chunks = collect_chunks(data, ApiFormat::GoogleGenerativeAi);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "Gemini");
    }

    #[test]
    fn sse_chunk_stream_trailing_data_no_newline() {
        let data = vec![openai_event("tail").into_bytes()];
        let chunks = collect_sse_chunks(data);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].delta, "tail");
    }

    #[test]
    fn sse_chunk_stream_pending_queue_drains_correctly() {
        // The second event has no trailing newline, so it is only parsed at end-of-stream.
        let data = vec![format!("{}\n{}", openai_event("X"), openai_event("Y")).into_bytes()];
        let chunks = collect_sse_chunks(data);
        let text: String = chunks.iter().map(|c| c.delta.as_str()).collect();
        assert_eq!(text, "XY");
    }
}