oxify_connect_llm/
lib.rs

//! LLM provider connections for OxiFY
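//!
//! Providers implement the [`LlmProvider`] trait and, where supported,
//! [`StreamingLlmProvider`] and [`EmbeddingProvider`]. A minimal streaming
//! sketch (illustrative only; it assumes a local Ollama instance with the
//! `llama2` model pulled):
//!
//! ```ignore
//! use futures::StreamExt;
//! use oxify_connect_llm::{LlmRequest, OllamaProvider, StreamingLlmProvider};
//!
//! async fn stream_demo() -> oxify_connect_llm::Result<()> {
//!     let provider = OllamaProvider::new("llama2".to_string());
//!     let mut stream = provider
//!         .complete_stream(LlmRequest {
//!             prompt: "Write a haiku about Rust".to_string(),
//!             system_prompt: None,
//!             temperature: None,
//!             max_tokens: None,
//!             tools: Vec::new(),
//!             images: Vec::new(),
//!         })
//!         .await?;
//!     while let Some(chunk) = stream.next().await {
//!         let chunk = chunk?;
//!         print!("{}", chunk.content);
//!         if chunk.done {
//!             break;
//!         }
//!     }
//!     Ok(())
//! }
//! ```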

mod batch;
mod bedrock;
mod cache;
mod circuit_breaker;
mod cohere;
mod compression;
mod dedup;
mod errors;
mod fallback;
mod gemini;
mod health_check;
mod helpers;
mod interceptor;
mod llamacpp;
mod load_balancer;
mod mistral;
mod observability;
mod otel;
mod priority_queue;
mod prompt_engineering;
mod rate_limit;
mod recommender;
#[cfg(feature = "redis-cache")]
mod redis_budget;
#[cfg(feature = "redis-cache")]
mod redis_cache;
mod response_utils;
mod retry;
mod selector;
mod semantic_cache;
mod streaming;
mod templates;
mod timeout;
mod usage;
mod validation;
mod vllm;
mod workflow;

pub use batch::{BatchConfig, BatchProvider, BatchStats, EmbeddingBatchProvider};
pub use bedrock::BedrockProvider;
pub use cache::{CacheStats, CachedProvider, LlmCache};
pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerProvider, CircuitState};
pub use cohere::CohereProvider;
pub use compression::{CompressionStats, ModelLimits, PromptCompressor};
pub use dedup::{DedupProvider, DedupStats};
pub use errors::{ContextualError, ErrorContext, ErrorContextBuilder, ErrorContextExt};
pub use fallback::FallbackProvider;
pub use gemini::GeminiProvider;
pub use health_check::{HealthCheckConfig, HealthCheckProvider, HealthStats, HealthStatus};
pub use helpers::{LlmRequestBuilder, ModelUtils, QuickRequest, TokenUtils};
pub use interceptor::{
    ContentLengthInterceptor, EmbeddingInterceptorProvider, EmbeddingRequestInterceptor,
    EmbeddingResponseInterceptor, InterceptorProvider, LoggingInterceptor, RequestInterceptor,
    ResponseInterceptor, SanitizationInterceptor,
};
pub use llamacpp::LlamaCppProvider;
pub use load_balancer::{LoadBalancer, LoadBalancerStats, LoadBalancingStrategy};
pub use mistral::MistralProvider;
pub use observability::{Metrics, MetricsProvider, ObservableProvider};
pub use otel::{
    OtelEmbeddingProvider, OtelProvider, ResponseAttributes, SpanAttributes, TraceEvent,
};
pub use priority_queue::{
    PriorityQueueConfig, PriorityQueueProvider, PriorityQueueStats, RequestPriority,
};
pub use prompt_engineering::{
    ChainOfThought, Example, FewShotPrompt, InstructionPrompt, Role, RolePrompt, SystemPrompts,
};
pub use rate_limit::{RateLimitConfig, RateLimitProvider, RateLimitStats};
pub use recommender::{
    AlternativeModel, BudgetConstraint, ModelRecommendation, ModelRecommender, OptimizationGoal,
    RecommendationRequest, UseCase,
};
#[cfg(feature = "redis-cache")]
pub use redis_budget::{RedisBudgetStats, RedisBudgetStore};
#[cfg(feature = "redis-cache")]
pub use redis_cache::{
    RedisCache, RedisCacheStats, RedisCachedEmbeddingProvider, RedisCachedProvider,
};
pub use response_utils::{CodeBlock, ResponseUtils};
pub use retry::{RetryConfig, RetryProvider};
pub use selector::{ProviderMetadata, ProviderSelector, SelectionCriteria};
pub use semantic_cache::{
    SemanticCache, SemanticCacheStats, SemanticCachedProvider, SimilarityThreshold,
};
pub use streaming::{LlmChunk, LlmStream, StreamUsage, StreamingLlmProvider};
pub use templates::{PromptTemplate, TemplateLibrary};
pub use timeout::{TimeoutConfig, TimeoutProvider};
pub use usage::{
    BudgetLimit, BudgetProvider, ModelPricing, TrackedProvider, UsageStats, UsageTracker,
};
pub use validation::{RequestValidator, ValidationRules};
pub use vllm::VllmProvider;
pub use workflow::{WorkflowEmbeddingProvider, WorkflowProvider, WorkflowStats, WorkflowTracker};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;

pub type Result<T> = std::result::Result<T, LlmError>;

#[derive(Error, Debug)]
pub enum LlmError {
    #[error("API error: {0}")]
    ApiError(String),

    #[error("Invalid configuration: {0}")]
    ConfigError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("Network error: {0}")]
    NetworkError(#[from] reqwest::Error),

    #[error("Rate limited (retry after {0:?})")]
    RateLimited(Option<std::time::Duration>),

    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    #[error("Request timed out after {0:?}")]
    Timeout(std::time::Duration),

    #[error("Other error: {0}")]
    Other(String),
}

impl Clone for LlmError {
    fn clone(&self) -> Self {
        match self {
            Self::ApiError(s) => Self::ApiError(s.clone()),
            Self::ConfigError(s) => Self::ConfigError(s.clone()),
            Self::SerializationError(s) => Self::SerializationError(s.clone()),
            Self::NetworkError(e) => Self::ApiError(format!("Network error: {}", e)),
            Self::RateLimited(d) => Self::RateLimited(*d),
            Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
            Self::Timeout(d) => Self::Timeout(*d),
            Self::Other(s) => Self::Other(s.clone()),
        }
    }
}

impl LlmError {
    /// Get the suggested retry delay if this is a rate limit error
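    ///
    /// A caller-side sketch (illustrative; assumes a tokio runtime with the
    /// `time` feature enabled):
    ///
    /// ```ignore
    /// if let Some(delay) = err.retry_after() {
    ///     tokio::time::sleep(delay).await;
    /// }
    /// ```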
    pub fn retry_after(&self) -> Option<std::time::Duration> {
        match self {
            LlmError::RateLimited(retry_after) => *retry_after,
            _ => None,
        }
    }
}

/// Tool/Function definition for function calling
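///
/// A minimal sketch of a tool definition (the tool name and schema are
/// hypothetical; `parameters` is expected to be a JSON Schema object):
///
/// ```
/// use oxify_connect_llm::Tool;
///
/// let tool = Tool {
///     name: "get_weather".to_string(),
///     description: "Look up the current weather for a city".to_string(),
///     parameters: serde_json::json!({
///         "type": "object",
///         "properties": {
///             "city": { "type": "string" }
///         },
///         "required": ["city"]
///     }),
/// };
/// ```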
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

/// Tool/Function call made by the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    pub name: String,
    pub arguments: serde_json::Value,
}

/// Image input for vision models
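///
/// A small sketch of both source types (the URL and data values are
/// placeholders):
///
/// ```
/// use oxify_connect_llm::{ImageInput, ImageSourceType};
///
/// let by_url = ImageInput {
///     data: "https://example.com/cat.png".to_string(),
///     source_type: ImageSourceType::Url,
///     media_type: None,
/// };
///
/// let by_base64 = ImageInput {
///     data: "iVBORw0KGgo...".to_string(), // placeholder for base64-encoded image bytes
///     source_type: ImageSourceType::Base64,
///     media_type: Some("image/png".to_string()),
/// };
/// ```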
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInput {
    /// Image data (base64 encoded) or URL
    pub data: String,
    /// Image type: "url" or "base64"
    pub source_type: ImageSourceType,
    /// Media type (e.g., "image/png", "image/jpeg")
    pub media_type: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImageSourceType {
    Url,
    Base64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub prompt: String,
    pub system_prompt: Option<String>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub tools: Vec<Tool>,
    #[serde(default)]
    pub images: Vec<ImageInput>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: String,
    pub model: String,
    pub usage: Option<Usage>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCall>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Trait for LLM providers
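///
/// A minimal completion sketch, assuming a valid OpenAI API key (the key,
/// model, and prompt values are placeholders):
///
/// ```no_run
/// use oxify_connect_llm::{LlmProvider, LlmRequest, OpenAIProvider};
///
/// async fn demo() -> oxify_connect_llm::Result<()> {
///     let provider = OpenAIProvider::new("sk-...".to_string(), "gpt-4".to_string());
///     let response = provider
///         .complete(LlmRequest {
///             prompt: "Summarize the borrow checker in one sentence".to_string(),
///             system_prompt: Some("You are a concise assistant".to_string()),
///             temperature: Some(0.2),
///             max_tokens: Some(128),
///             tools: Vec::new(),
///             images: Vec::new(),
///         })
///         .await?;
///     println!("{}", response.content);
///     Ok(())
/// }
/// ```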
#[async_trait]
pub trait LlmProvider: Send + Sync {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse>;
}

// ===== Embedding Support =====

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub texts: Vec<String>,
    pub model: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub embeddings: Vec<Vec<f32>>,
    pub model: String,
    pub usage: Option<EmbeddingUsage>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

/// Trait for embedding providers
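///
/// A minimal embedding sketch, assuming a valid OpenAI API key (placeholder
/// values throughout):
///
/// ```no_run
/// use oxify_connect_llm::{EmbeddingProvider, EmbeddingRequest, OpenAIProvider};
///
/// async fn demo() -> oxify_connect_llm::Result<()> {
///     let provider = OpenAIProvider::for_embeddings("sk-...".to_string());
///     let response = provider
///         .embed(EmbeddingRequest {
///             texts: vec!["first document".to_string(), "second document".to_string()],
///             model: None, // falls back to the provider's default model
///         })
///         .await?;
///     println!("{} embeddings returned", response.embeddings.len());
///     Ok(())
/// }
/// ```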
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse>;
}

// ===== OpenAI Provider =====

/// OpenAI provider implementation
pub struct OpenAIProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAITool>,
}

#[derive(Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}

#[derive(Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAIMessageContent {
    Text(String),
    Parts(Vec<OpenAIContentPart>),
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum OpenAIContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: OpenAIImageUrl },
}

#[derive(Serialize, Deserialize)]
struct OpenAIImageUrl {
    url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    detail: Option<String>,
}

#[derive(Serialize, Deserialize)]
struct OpenAIMessage {
    role: String,
    content: OpenAIMessageContent,
}

#[derive(Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    tool_type: String,
    function: OpenAIFunctionCall,
}

#[derive(Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
    usage: OpenAIUsage,
    model: String,
}

#[derive(Deserialize)]
struct Choice {
    message: OpenAIResponseMessage,
}

#[derive(Deserialize)]
struct OpenAIResponseMessage {
    #[allow(dead_code)]
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<OpenAIToolCall>,
}

#[derive(Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

#[derive(Deserialize)]
struct OpenAIError {
    error: OpenAIErrorDetail,
}

#[derive(Deserialize)]
struct OpenAIErrorDetail {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
}

impl OpenAIProvider {
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Create a provider specifically for embeddings
    pub fn for_embeddings(api_key: String) -> Self {
        Self::new(api_key, "text-embedding-ada-002".to_string())
    }
}

#[async_trait]
impl LlmProvider for OpenAIProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let mut messages = Vec::new();

        // Add system message if provided
        if let Some(system_prompt) = &request.system_prompt {
            messages.push(OpenAIMessage {
                role: "system".to_string(),
                content: OpenAIMessageContent::Text(system_prompt.clone()),
            });
        }

        // Add user message (with images if provided)
        let user_content = if request.images.is_empty() {
            OpenAIMessageContent::Text(request.prompt.clone())
        } else {
            let mut parts = vec![OpenAIContentPart::Text {
                text: request.prompt.clone(),
            }];

            for image in &request.images {
                let url = match image.source_type {
                    ImageSourceType::Url => image.data.clone(),
                    ImageSourceType::Base64 => {
                        let media_type = image.media_type.as_deref().unwrap_or("image/jpeg");
                        format!("data:{};base64,{}", media_type, image.data)
                    }
                };

                parts.push(OpenAIContentPart::ImageUrl {
                    image_url: OpenAIImageUrl { url, detail: None },
                });
            }

            OpenAIMessageContent::Parts(parts)
        };

        messages.push(OpenAIMessage {
            role: "user".to_string(),
            content: user_content,
        });

        // Convert tools to OpenAI format
        let tools: Vec<OpenAITool> = request
            .tools
            .iter()
            .map(|t| OpenAITool {
                tool_type: "function".to_string(),
                function: OpenAIFunction {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: t.parameters.clone(),
                },
            })
            .collect();

        let openai_request = OpenAIRequest {
            model: self.model.clone(),
            messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            tools,
        };

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            // Try to parse error response
            if let Ok(error) = serde_json::from_str::<OpenAIError>(&body) {
                return Err(LlmError::ApiError(format!(
                    "{}: {}",
                    error.error.error_type, error.error.message
                )));
            }
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let openai_response: OpenAIResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        if openai_response.choices.is_empty() {
            return Err(LlmError::ApiError("No choices in response".to_string()));
        }

        let message = &openai_response.choices[0].message;

        // Parse tool calls if present
        let tool_calls: Vec<ToolCall> = message
            .tool_calls
            .iter()
            .filter_map(|tc| {
                serde_json::from_str(&tc.function.arguments)
                    .ok()
                    .map(|args| ToolCall {
                        id: tc.id.clone(),
                        name: tc.function.name.clone(),
                        arguments: args,
                    })
            })
            .collect();

        Ok(LlmResponse {
            content: message.content.clone().unwrap_or_default(),
            model: openai_response.model,
            usage: Some(Usage {
                prompt_tokens: openai_response.usage.prompt_tokens,
                completion_tokens: openai_response.usage.completion_tokens,
                total_tokens: openai_response.usage.total_tokens,
            }),
            tool_calls,
        })
    }
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    input: Vec<String>,
    model: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    model: String,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: u32,
    total_tokens: u32,
}

#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let openai_request = OpenAIEmbeddingRequest {
            input: request.texts,
            model,
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            if let Ok(error) = serde_json::from_str::<OpenAIError>(&body) {
                return Err(LlmError::ApiError(format!(
                    "{}: {}",
                    error.error.error_type, error.error.message
                )));
            }
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let openai_response: OpenAIEmbeddingResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        // Sort by index to ensure correct order
        let mut data = openai_response.data;
        data.sort_by_key(|d| d.index);

        Ok(EmbeddingResponse {
            embeddings: data.into_iter().map(|d| d.embedding).collect(),
            model: openai_response.model,
            usage: Some(EmbeddingUsage {
                prompt_tokens: openai_response.usage.prompt_tokens,
                total_tokens: openai_response.usage.total_tokens,
            }),
        })
    }
}

// ===== OpenAI Streaming Implementation =====

use futures::stream::StreamExt;

#[derive(Deserialize)]
struct OpenAIStreamChunk {
    choices: Vec<OpenAIStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
    model: String,
}

#[derive(Deserialize)]
struct OpenAIStreamChoice {
    delta: OpenAIDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OpenAIDelta {
    #[serde(default)]
    content: Option<String>,
}

#[async_trait]
impl StreamingLlmProvider for OpenAIProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let mut messages = Vec::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(OpenAIMessage {
                role: "system".to_string(),
                content: OpenAIMessageContent::Text(system_prompt.clone()),
            });
        }

        messages.push(OpenAIMessage {
            role: "user".to_string(),
            content: OpenAIMessageContent::Text(request.prompt.clone()),
        });

        // Note: this is a basic streaming implementation; tool and vision inputs are not supported here.
        let openai_request = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": true
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await?;

        let status = response.status();
        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();

        let parsed_stream = stream.filter_map(|chunk_result| async move {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    for line in text.lines() {
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                return Some(Ok(LlmChunk {
                                    content: String::new(),
                                    done: true,
                                    model: None,
                                    usage: None,
                                }));
                            }

                            if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(data) {
                                if let Some(choice) = chunk.choices.first() {
                                    let is_done = choice.finish_reason.is_some();
                                    let content = choice.delta.content.clone().unwrap_or_default();

                                    let usage = chunk.usage.as_ref().map(|u| StreamUsage {
                                        prompt_tokens: Some(u.prompt_tokens),
                                        completion_tokens: Some(u.completion_tokens),
                                        total_tokens: Some(u.total_tokens),
                                    });

                                    if !content.is_empty() || is_done {
                                        return Some(Ok(LlmChunk {
                                            content,
                                            done: is_done,
                                            model: if is_done { Some(chunk.model) } else { None },
                                            usage,
                                        }));
                                    }
                                }
                            }
                        }
                    }
                    None
                }
                Err(e) => Some(Err(LlmError::NetworkError(e))),
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

// ===== Anthropic Provider =====

/// Anthropic (Claude) provider implementation
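///
/// A minimal sketch, assuming a valid Anthropic API key (placeholder values):
///
/// ```no_run
/// use oxify_connect_llm::{AnthropicProvider, LlmProvider, LlmRequest};
///
/// async fn demo() -> oxify_connect_llm::Result<()> {
///     let provider = AnthropicProvider::new(
///         "sk-ant-...".to_string(),
///         "claude-3-opus-20240229".to_string(),
///     );
///     let response = provider
///         .complete(LlmRequest {
///             prompt: "Explain lifetimes in Rust briefly".to_string(),
///             system_prompt: None,
///             temperature: Some(0.3),
///             max_tokens: Some(256), // defaults to 4096 when None
///             tools: Vec::new(),
///             images: Vec::new(),
///         })
///         .await?;
///     println!("{}", response.content);
///     Ok(())
/// }
/// ```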
pub struct AnthropicProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<AnthropicMessage>,
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<AnthropicTool>,
}

#[derive(Serialize)]
struct AnthropicTool {
    name: String,
    description: String,
    input_schema: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum AnthropicMessageContent {
    Text(String),
    Blocks(Vec<AnthropicInputBlock>),
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum AnthropicInputBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image")]
    Image { source: AnthropicImageSource },
}

#[derive(Serialize, Deserialize)]
struct AnthropicImageSource {
    #[serde(rename = "type")]
    source_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    media_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
}

#[derive(Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: AnthropicMessageContent,
}

#[derive(Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContentBlock>,
    usage: AnthropicUsage,
    model: String,
}

#[derive(Deserialize)]
#[serde(tag = "type")]
enum AnthropicContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

#[derive(Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

impl AnthropicProvider {
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.anthropic.com/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
}

#[async_trait]
impl LlmProvider for AnthropicProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        // Convert tools to Anthropic format
        let tools: Vec<AnthropicTool> = request
            .tools
            .iter()
            .map(|t| AnthropicTool {
                name: t.name.clone(),
                description: t.description.clone(),
                input_schema: t.parameters.clone(),
            })
            .collect();

        // Build user message content (with images if provided)
        let user_content = if request.images.is_empty() {
            AnthropicMessageContent::Text(request.prompt.clone())
        } else {
            let mut blocks = vec![AnthropicInputBlock::Text {
                text: request.prompt.clone(),
            }];

            for image in &request.images {
                let source = match image.source_type {
                    ImageSourceType::Url => AnthropicImageSource {
                        source_type: "url".to_string(),
                        media_type: None,
                        data: None,
                        url: Some(image.data.clone()),
                    },
                    ImageSourceType::Base64 => AnthropicImageSource {
                        source_type: "base64".to_string(),
                        media_type: image.media_type.clone(),
                        data: Some(image.data.clone()),
                        url: None,
                    },
                };

                blocks.push(AnthropicInputBlock::Image { source });
            }

            AnthropicMessageContent::Blocks(blocks)
        };

        let anthropic_request = AnthropicRequest {
            model: self.model.clone(),
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: user_content,
            }],
            max_tokens: request.max_tokens.unwrap_or(4096),
            temperature: request.temperature,
            system: request.system_prompt,
            tools,
        };

        let response = self
            .client
            .post(format!("{}/messages", self.base_url))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&anthropic_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let anthropic_response: AnthropicResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        if anthropic_response.content.is_empty() {
            return Err(LlmError::ApiError("No content in response".to_string()));
        }

        // Extract text content and tool calls
        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for block in anthropic_response.content {
            match block {
                AnthropicContentBlock::Text { text } => {
                    if !text_content.is_empty() {
                        text_content.push('\n');
                    }
                    text_content.push_str(&text);
                }
                AnthropicContentBlock::ToolUse { id, name, input } => {
                    tool_calls.push(ToolCall {
                        id,
                        name,
                        arguments: input,
                    });
                }
            }
        }

        Ok(LlmResponse {
            content: text_content,
            model: anthropic_response.model,
            usage: Some(Usage {
                prompt_tokens: anthropic_response.usage.input_tokens,
                completion_tokens: anthropic_response.usage.output_tokens,
                total_tokens: anthropic_response.usage.input_tokens
                    + anthropic_response.usage.output_tokens,
            }),
            tool_calls,
        })
    }
}

// ===== Anthropic Streaming Implementation =====

use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Deserialize)]
#[serde(tag = "type")]
enum AnthropicStreamEvent {
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicStreamMessage },
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { delta: AnthropicDelta },
    #[serde(rename = "message_delta")]
    MessageDelta {
        #[allow(dead_code)]
        delta: AnthropicStopDelta,
        usage: AnthropicUsage,
    },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(other)]
    Other,
}

#[derive(Deserialize)]
struct AnthropicStreamMessage {
    model: String,
    usage: AnthropicUsage,
}

#[derive(Deserialize)]
#[serde(tag = "type")]
enum AnthropicDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(other)]
    Other,
}

#[derive(Deserialize)]
struct AnthropicStopDelta {
    #[allow(dead_code)]
    stop_reason: Option<String>,
}

#[derive(Clone)]
struct AnthropicStreamState {
    model: Option<String>,
    input_tokens: Option<u32>,
    output_tokens: Option<u32>,
}

#[async_trait]
impl StreamingLlmProvider for AnthropicProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let anthropic_request = serde_json::json!({
            "model": self.model,
            "messages": [{"role": "user", "content": request.prompt}],
            "max_tokens": request.max_tokens.unwrap_or(4096),
            "temperature": request.temperature,
            "system": request.system_prompt,
            "stream": true
        });

        let response = self
            .client
            .post(format!("{}/messages", self.base_url))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&anthropic_request)
            .send()
            .await?;

        let status = response.status();
        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();
        let state = Arc::new(Mutex::new(AnthropicStreamState {
            model: None,
            input_tokens: None,
            output_tokens: None,
        }));

        let parsed_stream = stream.filter_map(move |chunk_result| {
            let state = Arc::clone(&state);
            async move {
                match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        for line in text.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                if let Ok(event) =
                                    serde_json::from_str::<AnthropicStreamEvent>(data)
                                {
                                    match event {
                                        AnthropicStreamEvent::MessageStart { message } => {
                                            let mut s = state.lock().await;
                                            s.model = Some(message.model);
                                            s.input_tokens = Some(message.usage.input_tokens);
                                        }
                                        AnthropicStreamEvent::ContentBlockDelta { delta } => {
                                            if let AnthropicDelta::TextDelta { text } = delta {
                                                return Some(Ok(LlmChunk {
                                                    content: text,
                                                    done: false,
                                                    model: None,
                                                    usage: None,
                                                }));
                                            }
                                        }
                                        AnthropicStreamEvent::MessageDelta { usage, .. } => {
                                            let mut s = state.lock().await;
                                            s.output_tokens = Some(usage.output_tokens);
                                        }
                                        AnthropicStreamEvent::MessageStop => {
                                            let s = state.lock().await;
                                            let usage = match (s.input_tokens, s.output_tokens) {
                                                (Some(input), Some(output)) => Some(StreamUsage {
                                                    prompt_tokens: Some(input),
                                                    completion_tokens: Some(output),
                                                    total_tokens: Some(input + output),
                                                }),
                                                _ => None,
                                            };

                                            return Some(Ok(LlmChunk {
                                                content: String::new(),
                                                done: true,
                                                model: s.model.clone(),
                                                usage,
                                            }));
                                        }
                                        AnthropicStreamEvent::Other => {}
                                    }
                                }
                            }
                        }
                        None
                    }
                    Err(e) => Some(Err(LlmError::NetworkError(e))),
                }
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

// ===== Ollama Provider =====

/// Ollama (local model) provider implementation
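///
/// A minimal sketch; Ollama runs locally and needs no API key (it assumes a
/// server on the default `http://localhost:11434` with `llama2` pulled):
///
/// ```no_run
/// use oxify_connect_llm::{LlmProvider, LlmRequest, OllamaProvider};
///
/// async fn demo() -> oxify_connect_llm::Result<()> {
///     let provider = OllamaProvider::new("llama2".to_string());
///     let response = provider
///         .complete(LlmRequest {
///             prompt: "Why is the sky blue?".to_string(),
///             system_prompt: None,
///             temperature: None,
///             max_tokens: None,
///             tools: Vec::new(),
///             images: Vec::new(),
///         })
///         .await?;
///     println!("{}", response.content);
///     Ok(())
/// }
/// ```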
pub struct OllamaProvider {
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    stream: bool,
}

#[derive(Deserialize)]
struct OllamaResponse {
    response: String,
    model: String,
    #[serde(default)]
    #[allow(dead_code)]
    done: bool,
}

#[derive(Deserialize)]
struct OllamaStreamResponse {
    response: String,
    #[serde(default)]
    model: Option<String>,
    #[serde(default)]
    done: bool,
}

impl OllamaProvider {
    pub fn new(model: String) -> Self {
        Self {
            model,
            client: reqwest::Client::new(),
            base_url: "http://localhost:11434".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Create a provider specifically for embeddings (e.g., "nomic-embed-text")
    pub fn for_embeddings(model: String) -> Self {
        Self::new(model)
    }
}

#[async_trait]
impl LlmProvider for OllamaProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let ollama_request = OllamaRequest {
            model: self.model.clone(),
            prompt: request.prompt.clone(),
            system: request.system_prompt,
            temperature: request.temperature,
            stream: false,
        };

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .header("Content-Type", "application/json")
            .json(&ollama_request)
            .send()
            .await?;

        let status = response.status();

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let body = response.text().await?;
        let ollama_response: OllamaResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        Ok(LlmResponse {
            content: ollama_response.response,
            model: ollama_response.model,
            usage: None, // token counts from Ollama's response are not parsed here
            tool_calls: Vec::new(),
        })
    }
}

// ===== Ollama Streaming Implementation =====

#[async_trait]
impl StreamingLlmProvider for OllamaProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let ollama_request = OllamaRequest {
            model: self.model.clone(),
            prompt: request.prompt.clone(),
            system: request.system_prompt,
            temperature: request.temperature,
            stream: true,
        };

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .header("Content-Type", "application/json")
            .json(&ollama_request)
            .send()
            .await?;

        let status = response.status();

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();

        let parsed_stream = stream.filter_map(|chunk_result| async move {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    // Ollama returns newline-delimited JSON
                    for line in text.lines() {
                        if line.trim().is_empty() {
                            continue;
                        }

                        if let Ok(chunk) = serde_json::from_str::<OllamaStreamResponse>(line) {
                            if chunk.done {
                                // Final chunk - includes model name
                                return Some(Ok(LlmChunk {
                                    content: chunk.response,
                                    done: true,
                                    model: chunk.model,
                                    usage: None, // token counts from the final chunk are not parsed here
                                }));
                            } else {
                                // Content chunk
                                return Some(Ok(LlmChunk {
                                    content: chunk.response,
                                    done: false,
                                    model: None,
                                    usage: None,
                                }));
                            }
                        }
                    }
                    None
                }
                Err(e) => Some(Err(LlmError::NetworkError(e))),
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

// ===== Ollama Embeddings Implementation =====

#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

#[async_trait]
impl EmbeddingProvider for OllamaProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let mut embeddings = Vec::with_capacity(request.texts.len());

        // Ollama embeddings API processes one text at a time
        for text in &request.texts {
            let ollama_request = OllamaEmbeddingRequest {
                model: model.clone(),
                prompt: text.clone(),
            };

            let response = self
                .client
                .post(format!("{}/api/embeddings", self.base_url))
                .header("Content-Type", "application/json")
                .json(&ollama_request)
                .send()
                .await?;

            let status = response.status();

            if !status.is_success() {
                let body = response.text().await?;
                return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
            }

            let body = response.text().await?;
            let ollama_response: OllamaEmbeddingResponse = serde_json::from_str(&body)
                .map_err(|e| LlmError::SerializationError(e.to_string()))?;

            embeddings.push(ollama_response.embedding);
        }

        Ok(EmbeddingResponse {
            embeddings,
            model,
            usage: None, // Ollama doesn't provide token usage for embeddings
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_creation() {
        let provider = OpenAIProvider::new("test_key".to_string(), "gpt-4".to_string());
        assert_eq!(provider.model, "gpt-4");
        assert_eq!(provider.base_url, "https://api.openai.com/v1");

        let provider =
            AnthropicProvider::new("test_key".to_string(), "claude-3-opus-20240229".to_string());
        assert_eq!(provider.model, "claude-3-opus-20240229");

        let ollama_provider = OllamaProvider::new("llama2".to_string());
        assert_eq!(ollama_provider.model, "llama2");
        assert_eq!(ollama_provider.base_url, "http://localhost:11434");
    }
}