// rexis_llm/provider.rs

1//! # RSLLM Provider Abstraction
2//!
3//! Multi-provider support for different LLM APIs with unified interface.
4//! Supports OpenAI, Claude (Anthropic), Ollama, and custom providers.
5
6use crate::{ChatMessage, ChatResponse, RsllmError, RsllmResult, StreamChunk};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::str::FromStr;
11use url::Url;
12
13/// Normalize URL to ensure it has a trailing slash for proper path joining
14/// This allows users to provide URLs with or without trailing slashes
15fn normalize_base_url(url: &Url) -> Url {
16    let url_str = url.as_str();
17    if url_str.ends_with('/') {
18        url.clone()
19    } else {
20        // Add trailing slash
21        format!("{}/", url_str)
22            .parse()
23            .unwrap_or_else(|_| url.clone())
24    }
25}
26
/// Supported LLM providers
///
/// Serialized via serde; also implements `fmt::Display` (lowercase name)
/// and `FromStr` (accepts aliases such as "gpt" and "anthropic").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    /// OpenAI (GPT models)
    OpenAI,
    /// Anthropic Claude
    Claude,
    /// Ollama (local models)
    Ollama,
}
37
38impl Provider {
39    /// Get the default base URL for this provider
40    pub fn default_base_url(&self) -> Url {
41        match self {
42            Provider::OpenAI => "https://api.openai.com/v1/".parse().unwrap(),
43            Provider::Claude => "https://api.anthropic.com/v1/".parse().unwrap(),
44            Provider::Ollama => "http://localhost:11434/api/".parse().unwrap(),
45        }
46    }
47
48    /// Get the default models for this provider
49    pub fn default_models(&self) -> Vec<&'static str> {
50        match self {
51            Provider::OpenAI => vec![
52                "gpt-4o",
53                "gpt-4o-mini",
54                "gpt-4-turbo",
55                "gpt-4",
56                "gpt-3.5-turbo",
57                "gpt-3.5-turbo-instruct",
58            ],
59            Provider::Claude => vec![
60                "claude-3-5-sonnet-20241022",
61                "claude-3-5-haiku-20241022",
62                "claude-3-opus-20240229",
63                "claude-3-sonnet-20240229",
64                "claude-3-haiku-20240307",
65            ],
66            Provider::Ollama => vec![
67                "llama3.1",
68                "llama3.1:70b",
69                "llama3.1:405b",
70                "mistral",
71                "codellama",
72                "vicuna",
73            ],
74        }
75    }
76
77    /// Get the recommended model for this provider
78    pub fn default_model(&self) -> &'static str {
79        match self {
80            Provider::OpenAI => "gpt-4o-mini",
81            Provider::Claude => "claude-3-5-haiku-20241022",
82            Provider::Ollama => "llama3.1",
83        }
84    }
85
86    /// Check if this provider supports streaming
87    pub fn supports_streaming(&self) -> bool {
88        match self {
89            Provider::OpenAI => true,
90            Provider::Claude => true,
91            Provider::Ollama => true,
92        }
93    }
94
95    /// Check if this provider requires authentication
96    pub fn requires_auth(&self) -> bool {
97        match self {
98            Provider::OpenAI => true,
99            Provider::Claude => true,
100            Provider::Ollama => false, // Local deployment typically doesn't need auth
101        }
102    }
103}
104
105impl fmt::Display for Provider {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match self {
108            Provider::OpenAI => write!(f, "openai"),
109            Provider::Claude => write!(f, "claude"),
110            Provider::Ollama => write!(f, "ollama"),
111        }
112    }
113}
114
115impl FromStr for Provider {
116    type Err = RsllmError;
117
118    fn from_str(s: &str) -> Result<Self, Self::Err> {
119        match s.to_lowercase().as_str() {
120            "openai" | "gpt" => Ok(Provider::OpenAI),
121            "claude" | "anthropic" => Ok(Provider::Claude),
122            "ollama" => Ok(Provider::Ollama),
123            _ => Err(RsllmError::configuration(format!(
124                "Unknown provider: {}",
125                s
126            ))),
127        }
128    }
129}
130
/// Provider configuration
///
/// Serializable bundle of the settings needed to construct a concrete
/// provider (see `OpenAIProvider::new` / `OllamaProvider::new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider type
    pub provider: Provider,

    /// API key (if required — see `Provider::requires_auth`)
    pub api_key: Option<String>,

    /// Base URL override; `None` means `Provider::default_base_url` is used
    pub base_url: Option<Url>,

    /// Organization ID (for providers that support it, e.g. OpenAI)
    pub organization_id: Option<String>,
}
146
147impl Default for ProviderConfig {
148    fn default() -> Self {
149        Self {
150            provider: Provider::OpenAI,
151            api_key: None,
152            base_url: None,
153            organization_id: None,
154        }
155    }
156}
157
/// Core provider trait for LLM interactions
///
/// `Send + Sync` so implementations can be shared across async tasks.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Provider name/identifier
    fn name(&self) -> &str;

    /// Provider type
    fn provider_type(&self) -> Provider;

    /// Supported models
    fn supported_models(&self) -> Vec<String>;

    /// Health check
    ///
    /// Implementations in this file return whether an HTTP probe of the
    /// backing service succeeded.
    async fn health_check(&self) -> RsllmResult<bool>;

    /// Chat completion (non-streaming)
    ///
    /// `model` of `None` selects the provider's default model.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse>;

    /// Chat completion (streaming)
    ///
    /// NOTE(review): takes `model: Option<String>` while `chat_completion`
    /// takes `Option<&str>` — consider unifying the signatures in a future
    /// breaking release.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>;

    /// Chat completion with tool calling support
    ///
    /// Providers that understand tool calling override this; the default
    /// simply ignores `tools` and delegates to `chat_completion`.
    async fn chat_completion_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<crate::tools::ToolDefinition>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        // Default implementation: call without tools (fallback for providers without tool support)
        let _ = tools; // Suppress unused warning
        self.chat_completion(messages, model, temperature, max_tokens)
            .await
    }
}
206
/// OpenAI provider implementation
#[cfg(feature = "openai")]
pub struct OpenAIProvider {
    // HTTP client built in `new` with a 30s timeout.
    client: reqwest::Client,
    // Bearer token sent on every request.
    api_key: String,
    // Base URL, normalized in `new` to end with a trailing slash.
    base_url: Url,
    // Optional org, sent as the `OpenAI-Organization` header when present.
    organization_id: Option<String>,
}
215
#[cfg(feature = "openai")]
impl OpenAIProvider {
    /// Create a new OpenAI provider
    ///
    /// Falls back to `Provider::OpenAI.default_base_url()` when `base_url`
    /// is `None`, then normalizes it so `Url::join` preserves the path.
    ///
    /// # Errors
    /// Returns a configuration error if the HTTP client cannot be built.
    pub fn new(
        api_key: String,
        base_url: Option<Url>,
        organization_id: Option<String>,
    ) -> RsllmResult<Self> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| {
                RsllmError::configuration_with_source("Failed to create HTTP client", e)
            })?;

        let base = base_url.unwrap_or_else(|| Provider::OpenAI.default_base_url());

        Ok(Self {
            client,
            api_key,
            base_url: normalize_base_url(&base),
            organization_id,
        })
    }

    /// Build request headers
    ///
    /// # Panics
    /// Panics if the API key or organization ID contains bytes that are
    /// not valid in an HTTP header value (e.g. control characters).
    fn build_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();

        let mut auth: reqwest::header::HeaderValue = format!("Bearer {}", self.api_key)
            .parse()
            .expect("API key must be a valid HTTP header value");
        // Mark the credential as sensitive so it is redacted from Debug
        // output and cannot leak through request logging.
        auth.set_sensitive(true);
        headers.insert(reqwest::header::AUTHORIZATION, auth);

        headers.insert(
            reqwest::header::CONTENT_TYPE,
            // `from_static` is infallible for this known-good literal.
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        if let Some(org_id) = &self.organization_id {
            headers.insert(
                "OpenAI-Organization",
                org_id
                    .parse()
                    .expect("organization ID must be a valid HTTP header value"),
            );
        }

        headers
    }
}
263
#[cfg(feature = "openai")]
#[async_trait]
impl LLMProvider for OpenAIProvider {
    fn name(&self) -> &str {
        "OpenAI"
    }

    fn provider_type(&self) -> Provider {
        Provider::OpenAI
    }

    fn supported_models(&self) -> Vec<String> {
        Provider::OpenAI
            .default_models()
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// Probe the `models` endpoint; a success status means the API is
    /// reachable and the credentials were accepted.
    async fn health_check(&self) -> RsllmResult<bool> {
        let url = self.base_url.join("models")?;
        let response = self
            .client
            .get(url)
            .headers(self.build_headers())
            .send()
            .await?;

        Ok(response.status().is_success())
    }

    /// Non-streaming chat completion via `chat/completions`.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat/completions")?;
        let model_name = model.unwrap_or(Provider::OpenAI.default_model());

        let mut request_body = serde_json::json!({
            "model": model_name,
            "messages": messages,
        });

        if let Some(temp) = temperature {
            request_body["temperature"] = temp.into();
        }

        if let Some(max_tokens) = max_tokens {
            request_body["max_tokens"] = max_tokens.into();
        }

        let response = self
            .client
            .post(url)
            .headers(self.build_headers())
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "OpenAI",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        // Extract the response content (missing/null content becomes "").
        let content = response_data["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Fix: report the finish reason the API actually returned instead
        // of hard-coding "stop" — it may be "length", "content_filter", etc.
        let finish_reason = response_data["choices"][0]["finish_reason"]
            .as_str()
            .unwrap_or("stop");

        Ok(ChatResponse::new(content, model_name).with_finish_reason(finish_reason))
    }

    /// Streaming chat completion.
    ///
    /// Currently a mock: a production implementation would POST the built
    /// request and decode the Server-Sent Events response incrementally.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>
    {
        use futures_util::stream;

        let _url = self.base_url.join("chat/completions")?;

        let model_name = model.unwrap_or_else(|| Provider::OpenAI.default_model().to_string());
        let mut _request_body = serde_json::json!({
            "model": &model_name,
            "messages": messages,
            "stream": true,
        });

        if let Some(temp) = temperature {
            _request_body["temperature"] = temp.into();
        }

        if let Some(max_tokens) = max_tokens {
            _request_body["max_tokens"] = max_tokens.into();
        }

        // Mock streaming response. Note: the original code created a
        // `tokio::time::sleep` future here and dropped it without awaiting,
        // which does nothing; that dead call has been removed.
        let chunks = vec![
            "Hello",
            " there!",
            " This",
            " is",
            " a",
            " streaming",
            " response",
            " from",
            " OpenAI.",
        ];
        let last_index = chunks.len() - 1;

        let stream = stream::iter(chunks.into_iter().enumerate().map(move |(i, chunk)| {
            if i == last_index {
                // Final chunk carries the finish reason.
                Ok(StreamChunk::done(&model_name).with_finish_reason("stop"))
            } else {
                Ok(StreamChunk::delta(chunk, &model_name))
            }
        }));

        Ok(Box::new(stream))
    }

    /// Chat completion with OpenAI function/tool calling.
    async fn chat_completion_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<crate::tools::ToolDefinition>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat/completions")?;
        let model_name = model.unwrap_or(Provider::OpenAI.default_model());

        // Build tools in OpenAI's {"type":"function","function":{...}} shape.
        let tools_json: Vec<serde_json::Value> = tools
            .iter()
            .map(|tool| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters
                    }
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": model_name,
            "messages": messages,
            "tools": tools_json,
        });

        if let Some(temp) = temperature {
            request_body["temperature"] = temp.into();
        }

        if let Some(max_tokens) = max_tokens {
            request_body["max_tokens"] = max_tokens.into();
        }

        let response = self
            .client
            .post(url)
            .headers(self.build_headers())
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "OpenAI",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        // Content may legitimately be null when the model chose a tool call.
        let content = response_data["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Parse tool calls if present (OpenAI format). Malformed entries
        // (missing id/name, or arguments that are not valid JSON) are
        // skipped via `filter_map`.
        let tool_calls = if let Some(calls_array) =
            response_data["choices"][0]["message"]["tool_calls"].as_array()
        {
            let parsed_calls: Vec<crate::message::ToolCall> = calls_array
                .iter()
                .filter_map(|call| {
                    Some(crate::message::ToolCall {
                        id: call["id"].as_str()?.to_string(),
                        call_type: crate::message::ToolCallType::Function,
                        function: crate::message::ToolFunction {
                            name: call["function"]["name"].as_str()?.to_string(),
                            // OpenAI sends arguments as a JSON-encoded string.
                            arguments: serde_json::from_str(
                                call["function"]["arguments"].as_str()?,
                            )
                            .ok()?,
                        },
                    })
                })
                .collect();

            if parsed_calls.is_empty() {
                None
            } else {
                Some(parsed_calls)
            }
        } else {
            None
        };

        // Fix: use the API's finish reason ("tool_calls" when the model
        // invoked a tool) rather than hard-coding "stop".
        let finish_reason = response_data["choices"][0]["finish_reason"]
            .as_str()
            .unwrap_or("stop");

        let mut response =
            ChatResponse::new(content, model_name).with_finish_reason(finish_reason);

        if let Some(calls) = tool_calls {
            response = response.with_tool_calls(calls);
        }

        Ok(response)
    }
}
517
/// Ollama provider implementation
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    // HTTP client built in `new` with a 60s timeout (local models are slow).
    client: reqwest::Client,
    // Base URL, normalized in `new` to end with a trailing slash.
    base_url: Url,
}
524
#[cfg(feature = "ollama")]
impl OllamaProvider {
    /// Create a new Ollama provider
    ///
    /// When `base_url` is `None`, the well-known local endpoint from
    /// `Provider::Ollama.default_base_url()` is used.
    pub fn new(base_url: Option<Url>) -> RsllmResult<Self> {
        // Local models can take a while to answer, so allow a longer
        // timeout than the hosted providers get.
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()
            .map_err(|e| {
                RsllmError::configuration_with_source("Failed to create HTTP client", e)
            })?;

        let raw_base = base_url.unwrap_or_else(|| Provider::Ollama.default_base_url());

        Ok(Self {
            client,
            base_url: normalize_base_url(&raw_base),
        })
    }
}
545
#[cfg(feature = "ollama")]
#[async_trait]
impl LLMProvider for OllamaProvider {
    fn name(&self) -> &str {
        "Ollama"
    }

    fn provider_type(&self) -> Provider {
        Provider::Ollama
    }

    fn supported_models(&self) -> Vec<String> {
        Provider::Ollama
            .default_models()
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// Probe the `tags` (model listing) endpoint; no auth is required.
    async fn health_check(&self) -> RsllmResult<bool> {
        let url = self.base_url.join("tags")?;
        let response = self.client.get(url).send().await?;
        Ok(response.status().is_success())
    }

    /// Non-streaming chat completion via Ollama's `chat` endpoint.
    ///
    /// `_max_tokens` is accepted for trait compatibility but not forwarded.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        _max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat")?;
        let model_name = model.unwrap_or(Provider::Ollama.default_model());

        let mut request_body = serde_json::json!({
            "model": model_name,
            "messages": messages,
            "stream": false,
        });

        if let Some(temp) = temperature {
            // Ollama nests sampling parameters under "options".
            request_body["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        let response = self.client.post(url).json(&request_body).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "Ollama",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        let content = response_data["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Fix: surface the reason Ollama reports via "done_reason"
        // (e.g. "stop", "length") instead of always hard-coding "stop";
        // older servers omit the field, in which case "stop" is kept.
        let finish_reason = response_data["done_reason"].as_str().unwrap_or("stop");

        Ok(ChatResponse::new(content, model_name).with_finish_reason(finish_reason))
    }

    /// Streaming chat completion.
    ///
    /// Currently a mock: a production implementation would POST the built
    /// request and decode Ollama's newline-delimited JSON stream.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        _max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>
    {
        use futures_util::stream;

        let _url = self.base_url.join("chat")?;

        let model_name = model.unwrap_or_else(|| Provider::Ollama.default_model().to_string());
        let mut _request_body = serde_json::json!({
            "model": &model_name,
            "messages": messages,
            "stream": true,
        });

        if let Some(temp) = temperature {
            _request_body["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        // Mock streaming response. Note: the original code created a
        // `tokio::time::sleep` future here and dropped it without awaiting,
        // which does nothing; that dead call has been removed.
        let chunks = vec![
            "This",
            " is",
            " a",
            " response",
            " from",
            " Ollama",
            " running",
            " locally.",
        ];
        let last_index = chunks.len() - 1;

        let stream = stream::iter(chunks.into_iter().enumerate().map(move |(i, chunk)| {
            if i == last_index {
                // Final chunk carries the finish reason.
                Ok(StreamChunk::done(&model_name).with_finish_reason("stop"))
            } else {
                Ok(StreamChunk::delta(chunk, &model_name))
            }
        }));

        Ok(Box::new(stream))
    }

    /// Chat completion with tool calling (Ollama uses the OpenAI tool shape).
    async fn chat_completion_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<crate::tools::ToolDefinition>,
        model: Option<&str>,
        temperature: Option<f32>,
        _max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat")?;
        let model_name = model.unwrap_or(Provider::Ollama.default_model());

        // Build tools in Ollama/OpenAI format
        let tools_json: Vec<serde_json::Value> = tools
            .iter()
            .map(|tool| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters
                    }
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": model_name,
            "messages": messages,
            "stream": false,
            "tools": tools_json,
        });

        if let Some(temp) = temperature {
            request_body["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        let response = self.client.post(url).json(&request_body).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "Ollama",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        let content = response_data["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Parse tool calls if present; entries missing a function name are
        // skipped via `filter_map`.
        let tool_calls =
            if let Some(calls_array) = response_data["message"]["tool_calls"].as_array() {
                let parsed_calls: Vec<crate::message::ToolCall> = calls_array
                    .iter()
                    .enumerate()
                    .filter_map(|(idx, call)| {
                        let function_name = call["function"]["name"].as_str()?;

                        // Ollama returns arguments as an object, sometimes
                        // with string values; convert string numbers to real
                        // numbers for compatibility.
                        //
                        // Fix: try i64 BEFORE f64. The original tried f64
                        // first, which succeeds for every integer string, so
                        // the i64 branch was unreachable and "42" became
                        // the float 42.0.
                        let mut arguments = call["function"]["arguments"].clone();
                        if let serde_json::Value::Object(ref mut args_obj) = arguments {
                            for value in args_obj.values_mut() {
                                if let serde_json::Value::String(s) = value {
                                    if let Ok(int_num) = s.parse::<i64>() {
                                        *value = serde_json::json!(int_num);
                                    } else if let Ok(num) = s.parse::<f64>() {
                                        *value = serde_json::json!(num);
                                    }
                                }
                            }
                        }

                        // Ollama doesn't provide an ID, so generate one
                        let id = call["id"]
                            .as_str()
                            .map(|s| s.to_string())
                            .unwrap_or_else(|| format!("call_{}", idx));

                        Some(crate::message::ToolCall {
                            id,
                            call_type: crate::message::ToolCallType::Function,
                            function: crate::message::ToolFunction {
                                name: function_name.to_string(),
                                arguments,
                            },
                        })
                    })
                    .collect();

                if parsed_calls.is_empty() {
                    None
                } else {
                    Some(parsed_calls)
                }
            } else {
                None
            };

        // Fix: prefer the server-reported "done_reason", fall back to "stop".
        let finish_reason = response_data["done_reason"].as_str().unwrap_or("stop");

        let mut response =
            ChatResponse::new(content, model_name).with_finish_reason(finish_reason);

        if let Some(calls) = tool_calls {
            response = response.with_tool_calls(calls);
        }

        Ok(response)
    }
}
794
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_base_url_without_trailing_slash() {
        let input = Url::parse("http://localhost:11434/api").unwrap();
        assert_eq!(
            normalize_base_url(&input).as_str(),
            "http://localhost:11434/api/"
        );
    }

    #[test]
    fn test_normalize_base_url_with_trailing_slash() {
        // An already-normalized URL must come back unchanged.
        let input = Url::parse("http://localhost:11434/api/").unwrap();
        assert_eq!(
            normalize_base_url(&input).as_str(),
            "http://localhost:11434/api/"
        );
    }

    #[test]
    fn test_normalize_base_url_complex() {
        let input = Url::parse("https://api.openai.com/v1").unwrap();
        assert_eq!(
            normalize_base_url(&input).as_str(),
            "https://api.openai.com/v1/"
        );
    }

    #[test]
    fn test_url_join_after_normalization() {
        // Whichever form the user supplies, joining a relative segment must
        // land under the API prefix rather than replacing its last component.
        for base in ["http://localhost:11434/api", "http://localhost:11434/api/"] {
            let normalized = normalize_base_url(&Url::parse(base).unwrap());
            let joined = normalized.join("chat").unwrap();
            assert_eq!(joined.as_str(), "http://localhost:11434/api/chat");
        }
    }
}