mod batch;
mod bedrock;
mod cache;
mod circuit_breaker;
mod cohere;
mod compression;
mod dedup;
mod errors;
mod fallback;
mod gemini;
mod health_check;
mod helpers;
mod interceptor;
mod llamacpp;
mod load_balancer;
mod mistral;
mod observability;
mod otel;
mod priority_queue;
mod prompt_engineering;
mod rate_limit;
mod recommender;
#[cfg(feature = "redis-cache")]
mod redis_budget;
#[cfg(feature = "redis-cache")]
mod redis_cache;
mod response_utils;
mod retry;
mod selector;
mod semantic_cache;
mod streaming;
mod templates;
mod timeout;
mod usage;
mod validation;
mod vllm;
mod workflow;

pub use batch::{BatchConfig, BatchProvider, BatchStats, EmbeddingBatchProvider};
pub use bedrock::BedrockProvider;
pub use cache::{CacheStats, CachedProvider, LlmCache};
pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerProvider, CircuitState};
pub use cohere::CohereProvider;
pub use compression::{CompressionStats, ModelLimits, PromptCompressor};
pub use dedup::{DedupProvider, DedupStats};
pub use errors::{ContextualError, ErrorContext, ErrorContextBuilder, ErrorContextExt};
pub use fallback::FallbackProvider;
pub use gemini::GeminiProvider;
pub use health_check::{HealthCheckConfig, HealthCheckProvider, HealthStats, HealthStatus};
pub use helpers::{LlmRequestBuilder, ModelUtils, QuickRequest, TokenUtils};
pub use interceptor::{
    ContentLengthInterceptor, EmbeddingInterceptorProvider, EmbeddingRequestInterceptor,
    EmbeddingResponseInterceptor, InterceptorProvider, LoggingInterceptor, RequestInterceptor,
    ResponseInterceptor, SanitizationInterceptor,
};
pub use llamacpp::LlamaCppProvider;
pub use load_balancer::{LoadBalancer, LoadBalancerStats, LoadBalancingStrategy};
pub use mistral::MistralProvider;
pub use observability::{Metrics, MetricsProvider, ObservableProvider};
pub use otel::{
    OtelEmbeddingProvider, OtelProvider, ResponseAttributes, SpanAttributes, TraceEvent,
};
pub use priority_queue::{
    PriorityQueueConfig, PriorityQueueProvider, PriorityQueueStats, RequestPriority,
};
pub use prompt_engineering::{
    ChainOfThought, Example, FewShotPrompt, InstructionPrompt, Role, RolePrompt, SystemPrompts,
};
pub use rate_limit::{RateLimitConfig, RateLimitProvider, RateLimitStats};
pub use recommender::{
    AlternativeModel, BudgetConstraint, ModelRecommendation, ModelRecommender, OptimizationGoal,
    RecommendationRequest, UseCase,
};
#[cfg(feature = "redis-cache")]
pub use redis_budget::{RedisBudgetStats, RedisBudgetStore};
#[cfg(feature = "redis-cache")]
pub use redis_cache::{
    RedisCache, RedisCacheStats, RedisCachedEmbeddingProvider, RedisCachedProvider,
};
pub use response_utils::{CodeBlock, ResponseUtils};
pub use retry::{RetryConfig, RetryProvider};
pub use selector::{ProviderMetadata, ProviderSelector, SelectionCriteria};
pub use semantic_cache::{
    SemanticCache, SemanticCacheStats, SemanticCachedProvider, SimilarityThreshold,
};
pub use streaming::{LlmChunk, LlmStream, StreamUsage, StreamingLlmProvider};
pub use templates::{PromptTemplate, TemplateLibrary};
pub use timeout::{TimeoutConfig, TimeoutProvider};
pub use usage::{
    BudgetLimit, BudgetProvider, ModelPricing, TrackedProvider, UsageStats, UsageTracker,
};
pub use validation::{RequestValidator, ValidationRules};
pub use vllm::VllmProvider;
pub use workflow::{WorkflowEmbeddingProvider, WorkflowProvider, WorkflowStats, WorkflowTracker};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;

pub type Result<T> = std::result::Result<T, LlmError>;

#[derive(Error, Debug)]
pub enum LlmError {
    #[error("API error: {0}")]
    ApiError(String),

    #[error("Invalid configuration: {0}")]
    ConfigError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("Network error: {0}")]
    NetworkError(#[from] reqwest::Error),

    #[error("Rate limited (retry after {0:?})")]
    RateLimited(Option<std::time::Duration>),

    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    #[error("Request timed out after {0:?}")]
    Timeout(std::time::Duration),

    #[error("Other error: {0}")]
    Other(String),
}

impl Clone for LlmError {
    fn clone(&self) -> Self {
        match self {
            Self::ApiError(s) => Self::ApiError(s.clone()),
            Self::ConfigError(s) => Self::ConfigError(s.clone()),
            Self::SerializationError(s) => Self::SerializationError(s.clone()),
            // reqwest::Error is not Clone, so a cloned NetworkError is
            // downgraded to an ApiError carrying the same message.
            Self::NetworkError(e) => Self::ApiError(format!("Network error: {}", e)),
            Self::RateLimited(d) => Self::RateLimited(*d),
            Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
            Self::Timeout(d) => Self::Timeout(*d),
            Self::Other(s) => Self::Other(s.clone()),
        }
    }
}

impl LlmError {
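    /// Returns the server-suggested retry delay, if this is a rate-limit error.
    /// A minimal usage sketch (illustrative only, not run as a doctest; assumes
    /// a Tokio runtime and some `provider: impl LlmProvider` and `request` in
    /// scope):
    ///
    /// ```ignore
    /// match provider.complete(request).await {
    ///     Err(e @ LlmError::RateLimited(_)) => {
    ///         let wait = e.retry_after().unwrap_or(std::time::Duration::from_secs(1));
    ///         tokio::time::sleep(wait).await;
    ///     }
    ///     other => { /* handle success or other errors */ }
    /// }
    /// ```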
    pub fn retry_after(&self) -> Option<std::time::Duration> {
        match self {
            LlmError::RateLimited(retry_after) => *retry_after,
            _ => None,
        }
    }
}

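/// Description of a function-style tool the model may call.
///
/// An illustrative sketch of defining one (the schema follows the JSON Schema
/// convention used by the provider APIs; the names and fields here are made up
/// for the example):
///
/// ```ignore
/// let weather_tool = Tool {
///     name: "get_weather".to_string(),
///     description: "Look up the current weather for a city".to_string(),
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": { "city": { "type": "string" } },
///         "required": ["city"]
///     }),
/// };
/// ```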
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    pub name: String,
    pub arguments: serde_json::Value,
}

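/// An image attached to a request, either referenced by URL or embedded as
/// base64 data. A hedged sketch of both forms (the URL and the base64 string
/// are placeholders):
///
/// ```ignore
/// let by_url = ImageInput {
///     data: "https://example.com/cat.png".to_string(),
///     source_type: ImageSourceType::Url,
///     media_type: None,
/// };
///
/// let inline = ImageInput {
///     data: base64_png_string, // already base64-encoded image bytes
///     source_type: ImageSourceType::Base64,
///     media_type: Some("image/png".to_string()),
/// };
/// ```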
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInput {
    pub data: String,
    pub source_type: ImageSourceType,
    pub media_type: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImageSourceType {
    Url,
    Base64,
}

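/// A provider-agnostic completion request. A minimal construction sketch using
/// a plain struct literal (the crate also re-exports `LlmRequestBuilder` from
/// the `helpers` module, whose exact API is not shown here):
///
/// ```ignore
/// let request = LlmRequest {
///     prompt: "Summarize the following text...".to_string(),
///     system_prompt: Some("You are a concise assistant.".to_string()),
///     temperature: Some(0.2),
///     max_tokens: Some(256),
///     tools: Vec::new(),
///     images: Vec::new(),
/// };
/// ```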
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub prompt: String,
    pub system_prompt: Option<String>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub tools: Vec<Tool>,
    #[serde(default)]
    pub images: Vec<ImageInput>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: String,
    pub model: String,
    pub usage: Option<Usage>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCall>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

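/// Core completion trait implemented by every chat/completion backend in this
/// crate. A hedged usage sketch (assumes a Tokio runtime and an `LlmRequest`
/// named `request` in scope; the API key and model name are placeholders):
///
/// ```ignore
/// let provider: Box<dyn LlmProvider> =
///     Box::new(OpenAIProvider::new("sk-...".to_string(), "gpt-4".to_string()));
/// let response = provider.complete(request).await?;
/// println!("{}", response.content);
/// ```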
#[async_trait]
pub trait LlmProvider: Send + Sync {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse>;
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub texts: Vec<String>,
    pub model: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub embeddings: Vec<Vec<f32>>,
    pub model: String,
    pub usage: Option<EmbeddingUsage>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

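/// Embedding counterpart to [`LlmProvider`]. An illustrative sketch of calling
/// it through the trait (assumes a Tokio runtime; the API key is a
/// placeholder):
///
/// ```ignore
/// let embedder: Box<dyn EmbeddingProvider> =
///     Box::new(OpenAIProvider::for_embeddings("sk-...".to_string()));
/// let response = embedder
///     .embed(EmbeddingRequest {
///         texts: vec!["hello world".to_string()],
///         model: None, // fall back to the provider's configured model
///     })
///     .await?;
/// assert_eq!(response.embeddings.len(), 1);
/// ```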
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse>;
}

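/// Provider for the OpenAI chat completions and embeddings APIs, also usable
/// against OpenAI-compatible servers via [`OpenAIProvider::with_base_url`].
/// A hedged construction sketch (key, model, and URL are placeholders):
///
/// ```ignore
/// let openai = OpenAIProvider::new("sk-...".to_string(), "gpt-4".to_string());
///
/// // Point the same provider at an OpenAI-compatible endpoint instead:
/// let local = OpenAIProvider::new("unused".to_string(), "my-model".to_string())
///     .with_base_url("http://localhost:8000/v1".to_string());
/// ```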
pub struct OpenAIProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAITool>,
}

#[derive(Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}

#[derive(Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAIMessageContent {
    Text(String),
    Parts(Vec<OpenAIContentPart>),
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum OpenAIContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: OpenAIImageUrl },
}

#[derive(Serialize, Deserialize)]
struct OpenAIImageUrl {
    url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    detail: Option<String>,
}

#[derive(Serialize, Deserialize)]
struct OpenAIMessage {
    role: String,
    content: OpenAIMessageContent,
}

#[derive(Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    tool_type: String,
    function: OpenAIFunctionCall,
}

#[derive(Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
    usage: OpenAIUsage,
    model: String,
}

#[derive(Deserialize)]
struct Choice {
    message: OpenAIResponseMessage,
}

#[derive(Deserialize)]
struct OpenAIResponseMessage {
    #[allow(dead_code)]
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<OpenAIToolCall>,
}

#[derive(Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

#[derive(Deserialize)]
struct OpenAIError {
    error: OpenAIErrorDetail,
}

#[derive(Deserialize)]
struct OpenAIErrorDetail {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
}

impl OpenAIProvider {
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Convenience constructor preconfigured with the `text-embedding-ada-002`
    /// embedding model.
    pub fn for_embeddings(api_key: String) -> Self {
        Self::new(api_key, "text-embedding-ada-002".to_string())
    }
}

#[async_trait]
impl LlmProvider for OpenAIProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let mut messages = Vec::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(OpenAIMessage {
                role: "system".to_string(),
                content: OpenAIMessageContent::Text(system_prompt.clone()),
            });
        }

        let user_content = if request.images.is_empty() {
            OpenAIMessageContent::Text(request.prompt.clone())
        } else {
            let mut parts = vec![OpenAIContentPart::Text {
                text: request.prompt.clone(),
            }];

            for image in &request.images {
                let url = match image.source_type {
                    ImageSourceType::Url => image.data.clone(),
                    ImageSourceType::Base64 => {
                        let media_type = image.media_type.as_deref().unwrap_or("image/jpeg");
                        format!("data:{};base64,{}", media_type, image.data)
                    }
                };

                parts.push(OpenAIContentPart::ImageUrl {
                    image_url: OpenAIImageUrl { url, detail: None },
                });
            }

            OpenAIMessageContent::Parts(parts)
        };

        messages.push(OpenAIMessage {
            role: "user".to_string(),
            content: user_content,
        });

        let tools: Vec<OpenAITool> = request
            .tools
            .iter()
            .map(|t| OpenAITool {
                tool_type: "function".to_string(),
                function: OpenAIFunction {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: t.parameters.clone(),
                },
            })
            .collect();

        let openai_request = OpenAIRequest {
            model: self.model.clone(),
            messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            tools,
        };

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            if let Ok(error) = serde_json::from_str::<OpenAIError>(&body) {
                return Err(LlmError::ApiError(format!(
                    "{}: {}",
                    error.error.error_type, error.error.message
                )));
            }
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let openai_response: OpenAIResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        if openai_response.choices.is_empty() {
            return Err(LlmError::ApiError("No choices in response".to_string()));
        }

        let message = &openai_response.choices[0].message;

        let tool_calls: Vec<ToolCall> = message
            .tool_calls
            .iter()
            .filter_map(|tc| {
                serde_json::from_str(&tc.function.arguments)
                    .ok()
                    .map(|args| ToolCall {
                        id: tc.id.clone(),
                        name: tc.function.name.clone(),
                        arguments: args,
                    })
            })
            .collect();

        Ok(LlmResponse {
            content: message.content.clone().unwrap_or_default(),
            model: openai_response.model,
            usage: Some(Usage {
                prompt_tokens: openai_response.usage.prompt_tokens,
                completion_tokens: openai_response.usage.completion_tokens,
                total_tokens: openai_response.usage.total_tokens,
            }),
            tool_calls,
        })
    }
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    input: Vec<String>,
    model: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    model: String,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: u32,
    total_tokens: u32,
}

#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let openai_request = OpenAIEmbeddingRequest {
            input: request.texts,
            model,
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            if let Ok(error) = serde_json::from_str::<OpenAIError>(&body) {
                return Err(LlmError::ApiError(format!(
                    "{}: {}",
                    error.error.error_type, error.error.message
                )));
            }
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let openai_response: OpenAIEmbeddingResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        // Sort by the returned index so embeddings line up with the input order.
        let mut data = openai_response.data;
        data.sort_by_key(|d| d.index);

        Ok(EmbeddingResponse {
            embeddings: data.into_iter().map(|d| d.embedding).collect(),
            model: openai_response.model,
            usage: Some(EmbeddingUsage {
                prompt_tokens: openai_response.usage.prompt_tokens,
                total_tokens: openai_response.usage.total_tokens,
            }),
        })
    }
}

use futures::stream::StreamExt;

#[derive(Deserialize)]
struct OpenAIStreamChunk {
    choices: Vec<OpenAIStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
    model: String,
}

#[derive(Deserialize)]
struct OpenAIStreamChoice {
    delta: OpenAIDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OpenAIDelta {
    #[serde(default)]
    content: Option<String>,
}

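/// Streaming (SSE) support for [`OpenAIProvider`]. A hedged sketch of
/// consuming the resulting stream (illustrative only; assumes a Tokio runtime,
/// `futures::StreamExt` in scope, and a `provider` and `request` defined
/// elsewhere):
///
/// ```ignore
/// let mut stream = provider.complete_stream(request).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     print!("{}", chunk.content);
///     if chunk.done {
///         break;
///     }
/// }
/// ```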
#[async_trait]
impl StreamingLlmProvider for OpenAIProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let mut messages = Vec::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(OpenAIMessage {
                role: "system".to_string(),
                content: OpenAIMessageContent::Text(system_prompt.clone()),
            });
        }

        messages.push(OpenAIMessage {
            role: "user".to_string(),
            content: OpenAIMessageContent::Text(request.prompt.clone()),
        });

        let openai_request = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": true
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await?;

        let status = response.status();
        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();

        // Parse the SSE stream line by line. Note that this assumes each
        // network chunk contains whole `data: ` lines; lines split across
        // chunks are not buffered.
        let parsed_stream = stream.filter_map(|chunk_result| async move {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    for line in text.lines() {
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                return Some(Ok(LlmChunk {
                                    content: String::new(),
                                    done: true,
                                    model: None,
                                    usage: None,
                                }));
                            }

                            if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(data) {
                                if let Some(choice) = chunk.choices.first() {
                                    let is_done = choice.finish_reason.is_some();
                                    let content = choice.delta.content.clone().unwrap_or_default();

                                    let usage = chunk.usage.as_ref().map(|u| StreamUsage {
                                        prompt_tokens: Some(u.prompt_tokens),
                                        completion_tokens: Some(u.completion_tokens),
                                        total_tokens: Some(u.total_tokens),
                                    });

                                    if !content.is_empty() || is_done {
                                        return Some(Ok(LlmChunk {
                                            content,
                                            done: is_done,
                                            model: if is_done { Some(chunk.model) } else { None },
                                            usage,
                                        }));
                                    }
                                }
                            }
                        }
                    }
                    None
                }
                Err(e) => Some(Err(LlmError::NetworkError(e))),
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

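/// Provider for the Anthropic Messages API. A hedged construction sketch
/// (the key is a placeholder; the model name matches the one used in the
/// tests below):
///
/// ```ignore
/// let claude = AnthropicProvider::new(
///     "sk-ant-...".to_string(),
///     "claude-3-opus-20240229".to_string(),
/// );
/// let response = claude.complete(request).await?;
/// ```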
pub struct AnthropicProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<AnthropicMessage>,
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<AnthropicTool>,
}

#[derive(Serialize)]
struct AnthropicTool {
    name: String,
    description: String,
    input_schema: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum AnthropicMessageContent {
    Text(String),
    Blocks(Vec<AnthropicInputBlock>),
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum AnthropicInputBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image")]
    Image { source: AnthropicImageSource },
}

#[derive(Serialize, Deserialize)]
struct AnthropicImageSource {
    #[serde(rename = "type")]
    source_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    media_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
}

#[derive(Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: AnthropicMessageContent,
}

#[derive(Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContentBlock>,
    usage: AnthropicUsage,
    model: String,
}

#[derive(Deserialize)]
#[serde(tag = "type")]
enum AnthropicContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

#[derive(Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

impl AnthropicProvider {
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.anthropic.com/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
}

#[async_trait]
impl LlmProvider for AnthropicProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let tools: Vec<AnthropicTool> = request
            .tools
            .iter()
            .map(|t| AnthropicTool {
                name: t.name.clone(),
                description: t.description.clone(),
                input_schema: t.parameters.clone(),
            })
            .collect();

        let user_content = if request.images.is_empty() {
            AnthropicMessageContent::Text(request.prompt.clone())
        } else {
            let mut blocks = vec![AnthropicInputBlock::Text {
                text: request.prompt.clone(),
            }];

            for image in &request.images {
                let source = match image.source_type {
                    ImageSourceType::Url => AnthropicImageSource {
                        source_type: "url".to_string(),
                        media_type: None,
                        data: None,
                        url: Some(image.data.clone()),
                    },
                    ImageSourceType::Base64 => AnthropicImageSource {
                        source_type: "base64".to_string(),
                        media_type: image.media_type.clone(),
                        data: Some(image.data.clone()),
                        url: None,
                    },
                };

                blocks.push(AnthropicInputBlock::Image { source });
            }

            AnthropicMessageContent::Blocks(blocks)
        };

        let anthropic_request = AnthropicRequest {
            model: self.model.clone(),
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: user_content,
            }],
            max_tokens: request.max_tokens.unwrap_or(4096),
            temperature: request.temperature,
            system: request.system_prompt,
            tools,
        };

        let response = self
            .client
            .post(format!("{}/messages", self.base_url))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&anthropic_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let anthropic_response: AnthropicResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        if anthropic_response.content.is_empty() {
            return Err(LlmError::ApiError("No content in response".to_string()));
        }

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for block in anthropic_response.content {
            match block {
                AnthropicContentBlock::Text { text } => {
                    if !text_content.is_empty() {
                        text_content.push('\n');
                    }
                    text_content.push_str(&text);
                }
                AnthropicContentBlock::ToolUse { id, name, input } => {
                    tool_calls.push(ToolCall {
                        id,
                        name,
                        arguments: input,
                    });
                }
            }
        }

        Ok(LlmResponse {
            content: text_content,
            model: anthropic_response.model,
            usage: Some(Usage {
                prompt_tokens: anthropic_response.usage.input_tokens,
                completion_tokens: anthropic_response.usage.output_tokens,
                total_tokens: anthropic_response.usage.input_tokens
                    + anthropic_response.usage.output_tokens,
            }),
            tool_calls,
        })
    }
}

use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Deserialize)]
#[serde(tag = "type")]
enum AnthropicStreamEvent {
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicStreamMessage },
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { delta: AnthropicDelta },
    #[serde(rename = "message_delta")]
    MessageDelta {
        #[allow(dead_code)]
        delta: AnthropicStopDelta,
        usage: AnthropicUsage,
    },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(other)]
    Other,
}

#[derive(Deserialize)]
struct AnthropicStreamMessage {
    model: String,
    usage: AnthropicUsage,
}

#[derive(Deserialize)]
#[serde(tag = "type")]
enum AnthropicDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(other)]
    Other,
}

#[derive(Deserialize)]
struct AnthropicStopDelta {
    #[allow(dead_code)]
    stop_reason: Option<String>,
}

#[derive(Clone)]
struct AnthropicStreamState {
    model: Option<String>,
    input_tokens: Option<u32>,
    output_tokens: Option<u32>,
}

#[async_trait]
impl StreamingLlmProvider for AnthropicProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let anthropic_request = serde_json::json!({
            "model": self.model,
            "messages": [{"role": "user", "content": request.prompt}],
            "max_tokens": request.max_tokens.unwrap_or(4096),
            "temperature": request.temperature,
            "system": request.system_prompt,
            "stream": true
        });

        let response = self
            .client
            .post(format!("{}/messages", self.base_url))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&anthropic_request)
            .send()
            .await?;

        let status = response.status();
        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();
        let state = Arc::new(Mutex::new(AnthropicStreamState {
            model: None,
            input_tokens: None,
            output_tokens: None,
        }));

        let parsed_stream = stream.filter_map(move |chunk_result| {
            let state = Arc::clone(&state);
            async move {
                match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        for line in text.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if let Ok(event) =
                                    serde_json::from_str::<AnthropicStreamEvent>(data)
                                {
                                    match event {
                                        AnthropicStreamEvent::MessageStart { message } => {
                                            let mut s = state.lock().await;
                                            s.model = Some(message.model);
                                            s.input_tokens = Some(message.usage.input_tokens);
                                        }
                                        AnthropicStreamEvent::ContentBlockDelta { delta } => {
                                            if let AnthropicDelta::TextDelta { text } = delta {
                                                return Some(Ok(LlmChunk {
                                                    content: text,
                                                    done: false,
                                                    model: None,
                                                    usage: None,
                                                }));
                                            }
                                        }
                                        AnthropicStreamEvent::MessageDelta { usage, .. } => {
                                            let mut s = state.lock().await;
                                            s.output_tokens = Some(usage.output_tokens);
                                        }
                                        AnthropicStreamEvent::MessageStop => {
                                            let s = state.lock().await;
                                            let usage = match (s.input_tokens, s.output_tokens) {
                                                (Some(input), Some(output)) => Some(StreamUsage {
                                                    prompt_tokens: Some(input),
                                                    completion_tokens: Some(output),
                                                    total_tokens: Some(input + output),
                                                }),
                                                _ => None,
                                            };

                                            return Some(Ok(LlmChunk {
                                                content: String::new(),
                                                done: true,
                                                model: s.model.clone(),
                                                usage,
                                            }));
                                        }
                                        AnthropicStreamEvent::Other => {}
                                    }
                                }
                            }
                        }
                        None
                    }
                    Err(e) => Some(Err(LlmError::NetworkError(e))),
                }
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

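/// Provider for a local Ollama server (defaults to `http://localhost:11434`);
/// no API key is required. A hedged usage sketch (model names are placeholders
/// and an Ollama daemon is assumed to be running locally):
///
/// ```ignore
/// let ollama = OllamaProvider::new("llama2".to_string());
/// let response = ollama.complete(request).await?;
///
/// // Embeddings go through the same provider type, one text per API call:
/// let embeddings = OllamaProvider::for_embeddings("nomic-embed-text".to_string())
///     .embed(EmbeddingRequest { texts: vec!["hi".to_string()], model: None })
///     .await?;
/// ```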
pub struct OllamaProvider {
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    stream: bool,
}

#[derive(Deserialize)]
struct OllamaResponse {
    response: String,
    model: String,
    #[serde(default)]
    #[allow(dead_code)]
    done: bool,
}

#[derive(Deserialize)]
struct OllamaStreamResponse {
    response: String,
    #[serde(default)]
    model: Option<String>,
    #[serde(default)]
    done: bool,
}

impl OllamaProvider {
    pub fn new(model: String) -> Self {
        Self {
            model,
            client: reqwest::Client::new(),
            base_url: "http://localhost:11434".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    pub fn for_embeddings(model: String) -> Self {
        Self::new(model)
    }
}

#[async_trait]
impl LlmProvider for OllamaProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let ollama_request = OllamaRequest {
            model: self.model.clone(),
            prompt: request.prompt.clone(),
            system: request.system_prompt,
            temperature: request.temperature,
            stream: false,
        };

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .header("Content-Type", "application/json")
            .json(&ollama_request)
            .send()
            .await?;

        let status = response.status();

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let body = response.text().await?;
        let ollama_response: OllamaResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        Ok(LlmResponse {
            content: ollama_response.response,
            model: ollama_response.model,
            usage: None,
            tool_calls: Vec::new(),
        })
    }
}

#[async_trait]
impl StreamingLlmProvider for OllamaProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let ollama_request = OllamaRequest {
            model: self.model.clone(),
            prompt: request.prompt.clone(),
            system: request.system_prompt,
            temperature: request.temperature,
            stream: true,
        };

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .header("Content-Type", "application/json")
            .json(&ollama_request)
            .send()
            .await?;

        let status = response.status();

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();

        let parsed_stream = stream.filter_map(|chunk_result| async move {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    for line in text.lines() {
                        if line.trim().is_empty() {
                            continue;
                        }

                        if let Ok(chunk) = serde_json::from_str::<OllamaStreamResponse>(line) {
                            if chunk.done {
                                return Some(Ok(LlmChunk {
                                    content: chunk.response,
                                    done: true,
                                    model: chunk.model,
                                    usage: None,
                                }));
                            } else {
                                return Some(Ok(LlmChunk {
                                    content: chunk.response,
                                    done: false,
                                    model: None,
                                    usage: None,
                                }));
                            }
                        }
                    }
                    None
                }
                Err(e) => Some(Err(LlmError::NetworkError(e))),
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

#[async_trait]
impl EmbeddingProvider for OllamaProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let mut embeddings = Vec::with_capacity(request.texts.len());

        // The endpoint used here takes a single prompt per call, so each text
        // is embedded with its own request.
        for text in &request.texts {
            let ollama_request = OllamaEmbeddingRequest {
                model: model.clone(),
                prompt: text.clone(),
            };

            let response = self
                .client
                .post(format!("{}/api/embeddings", self.base_url))
                .header("Content-Type", "application/json")
                .json(&ollama_request)
                .send()
                .await?;

            let status = response.status();

            if !status.is_success() {
                let body = response.text().await?;
                return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
            }

            let body = response.text().await?;
            let ollama_response: OllamaEmbeddingResponse = serde_json::from_str(&body)
                .map_err(|e| LlmError::SerializationError(e.to_string()))?;

            embeddings.push(ollama_response.embedding);
        }

        Ok(EmbeddingResponse {
            embeddings,
            model,
            usage: None,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_creation() {
        let provider = OpenAIProvider::new("test_key".to_string(), "gpt-4".to_string());
        assert_eq!(provider.model, "gpt-4");
        assert_eq!(provider.base_url, "https://api.openai.com/v1");

        let provider =
            AnthropicProvider::new("test_key".to_string(), "claude-3-opus-20240229".to_string());
        assert_eq!(provider.model, "claude-3-opus-20240229");

        let ollama_provider = OllamaProvider::new("llama2".to_string());
        assert_eq!(ollama_provider.model, "llama2");
        assert_eq!(ollama_provider.base_url, "http://localhost:11434");
    }
}