1use crate::{ChatMessage, ChatResponse, RsllmError, RsllmResult, StreamChunk};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::str::FromStr;
11use url::Url;
12
13fn normalize_base_url(url: &Url) -> Url {
16 let url_str = url.as_str();
17 if url_str.ends_with('/') {
18 url.clone()
19 } else {
20 format!("{}/", url_str)
22 .parse()
23 .unwrap_or_else(|_| url.clone())
24 }
25}
26
/// Identifies which LLM backend a request should be routed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    /// OpenAI's hosted API.
    OpenAI,
    /// Anthropic's Claude API.
    Claude,
    /// A locally running Ollama server.
    Ollama,
}
37
38impl Provider {
39 pub fn default_base_url(&self) -> Url {
41 match self {
42 Provider::OpenAI => "https://api.openai.com/v1/".parse().unwrap(),
43 Provider::Claude => "https://api.anthropic.com/v1/".parse().unwrap(),
44 Provider::Ollama => "http://localhost:11434/api/".parse().unwrap(),
45 }
46 }
47
48 pub fn default_models(&self) -> Vec<&'static str> {
50 match self {
51 Provider::OpenAI => vec![
52 "gpt-4o",
53 "gpt-4o-mini",
54 "gpt-4-turbo",
55 "gpt-4",
56 "gpt-3.5-turbo",
57 "gpt-3.5-turbo-instruct",
58 ],
59 Provider::Claude => vec![
60 "claude-3-5-sonnet-20241022",
61 "claude-3-5-haiku-20241022",
62 "claude-3-opus-20240229",
63 "claude-3-sonnet-20240229",
64 "claude-3-haiku-20240307",
65 ],
66 Provider::Ollama => vec![
67 "llama3.1",
68 "llama3.1:70b",
69 "llama3.1:405b",
70 "mistral",
71 "codellama",
72 "vicuna",
73 ],
74 }
75 }
76
77 pub fn default_model(&self) -> &'static str {
79 match self {
80 Provider::OpenAI => "gpt-4o-mini",
81 Provider::Claude => "claude-3-5-haiku-20241022",
82 Provider::Ollama => "llama3.1",
83 }
84 }
85
86 pub fn supports_streaming(&self) -> bool {
88 match self {
89 Provider::OpenAI => true,
90 Provider::Claude => true,
91 Provider::Ollama => true,
92 }
93 }
94
95 pub fn requires_auth(&self) -> bool {
97 match self {
98 Provider::OpenAI => true,
99 Provider::Claude => true,
100 Provider::Ollama => false, }
102 }
103}
104
105impl fmt::Display for Provider {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 match self {
108 Provider::OpenAI => write!(f, "openai"),
109 Provider::Claude => write!(f, "claude"),
110 Provider::Ollama => write!(f, "ollama"),
111 }
112 }
113}
114
115impl FromStr for Provider {
116 type Err = RsllmError;
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 match s.to_lowercase().as_str() {
120 "openai" | "gpt" => Ok(Provider::OpenAI),
121 "claude" | "anthropic" => Ok(Provider::Claude),
122 "ollama" => Ok(Provider::Ollama),
123 _ => Err(RsllmError::configuration(format!(
124 "Unknown provider: {}",
125 s
126 ))),
127 }
128 }
129}
130
/// Serializable configuration describing how to connect to a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Which backend to use.
    pub provider: Provider,

    /// API key; required by providers where `requires_auth()` is true,
    /// unused by Ollama.
    pub api_key: Option<String>,

    /// Override for the provider's default endpoint (e.g. a proxy or
    /// self-hosted gateway). `None` means use `Provider::default_base_url`.
    pub base_url: Option<Url>,

    /// Optional OpenAI organization id (sent as the `OpenAI-Organization`
    /// header by the OpenAI provider).
    pub organization_id: Option<String>,
}
146
147impl Default for ProviderConfig {
148 fn default() -> Self {
149 Self {
150 provider: Provider::OpenAI,
151 api_key: None,
152 base_url: None,
153 organization_id: None,
154 }
155 }
156}
157
/// Abstraction over a chat-capable LLM backend.
///
/// Implementations must be thread-safe (`Send + Sync`) so one provider
/// instance can be shared across async tasks.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Human-readable provider name, e.g. "OpenAI".
    fn name(&self) -> &str;

    /// The [`Provider`] variant this implementation corresponds to.
    fn provider_type(&self) -> Provider;

    /// Model identifiers this provider is known to support.
    fn supported_models(&self) -> Vec<String>;

    /// Checks whether the backing service is reachable and responding.
    async fn health_check(&self) -> RsllmResult<bool>;

    /// Performs a single (non-streaming) chat completion.
    ///
    /// `model`, `temperature`, and `max_tokens` fall back to
    /// provider-specific defaults when `None`.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse>;

    /// Streams a chat completion as incremental [`StreamChunk`]s.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>;

    /// Chat completion with tool/function definitions attached.
    ///
    /// The default implementation IGNORES `tools` and delegates to plain
    /// `chat_completion`; providers with native tool support override it.
    async fn chat_completion_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<crate::tools::ToolDefinition>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        // Tools are intentionally discarded in this fallback.
        let _ = tools; self.chat_completion(messages, model, temperature, max_tokens)
            .await
    }
}
206
/// `LLMProvider` implementation backed by the OpenAI HTTP API.
#[cfg(feature = "openai")]
pub struct OpenAIProvider {
    // Reused HTTP client (connection pooling; 30s timeout set in `new`).
    client: reqwest::Client,
    // Secret bearer token sent in the Authorization header.
    api_key: String,
    // Normalized to end with '/' so `Url::join` appends path segments.
    base_url: Url,
    // Optional organization id, sent as the OpenAI-Organization header.
    organization_id: Option<String>,
}
215
#[cfg(feature = "openai")]
impl OpenAIProvider {
    /// Builds a provider backed by the OpenAI HTTP API.
    ///
    /// When `base_url` is `None`, the provider's default endpoint is used.
    /// Whatever URL is chosen gets normalized to end with `/` so that
    /// `Url::join` appends path segments instead of replacing them.
    ///
    /// # Errors
    /// Returns a configuration error if the HTTP client cannot be built.
    pub fn new(
        api_key: String,
        base_url: Option<Url>,
        organization_id: Option<String>,
    ) -> RsllmResult<Self> {
        let http = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| {
                RsllmError::configuration_with_source("Failed to create HTTP client", e)
            })?;

        let resolved =
            normalize_base_url(&base_url.unwrap_or_else(|| Provider::OpenAI.default_base_url()));

        Ok(Self {
            client: http,
            api_key,
            base_url: resolved,
            organization_id,
        })
    }

    /// Standard headers for every OpenAI request: bearer auth, a JSON
    /// content type, and the optional organization header.
    ///
    /// # Panics
    /// Panics if the API key or organization id contains bytes that are not
    /// valid inside an HTTP header value.
    fn build_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();

        headers.insert(
            reqwest::header::AUTHORIZATION,
            format!("Bearer {}", self.api_key).parse().unwrap(),
        );
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            "application/json".parse().unwrap(),
        );
        if let Some(org) = self.organization_id.as_deref() {
            headers.insert("OpenAI-Organization", org.parse().unwrap());
        }

        headers
    }
}
263
#[cfg(feature = "openai")]
#[async_trait]
impl LLMProvider for OpenAIProvider {
    fn name(&self) -> &str {
        "OpenAI"
    }

    fn provider_type(&self) -> Provider {
        Provider::OpenAI
    }

    /// Static catalog from `Provider::default_models`; not fetched live.
    fn supported_models(&self) -> Vec<String> {
        Provider::OpenAI
            .default_models()
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// Healthy when `GET /models` (sent with auth headers) returns 2xx.
    async fn health_check(&self) -> RsllmResult<bool> {
        let url = self.base_url.join("models")?;
        let response = self
            .client
            .get(url)
            .headers(self.build_headers())
            .send()
            .await?;

        Ok(response.status().is_success())
    }

    /// Non-streaming chat completion via `POST /chat/completions`.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat/completions")?;

        let mut request_body = serde_json::json!({
            "model": model.unwrap_or(Provider::OpenAI.default_model()),
            "messages": messages,
        });

        // Optional sampling parameters are only serialized when explicitly set.
        if let Some(temp) = temperature {
            request_body["temperature"] = temp.into();
        }

        if let Some(max_tokens) = max_tokens {
            request_body["max_tokens"] = max_tokens.into();
        }

        let response = self
            .client
            .post(url)
            .headers(self.build_headers())
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Surface the response body in the error; fall back to a generic
            // message if the body itself cannot be read.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "OpenAI",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        // Missing or non-string content degrades to an empty reply rather
        // than an error.
        let content = response_data["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // NOTE(review): finish_reason is hard-coded to "stop" instead of
        // being read from the response payload.
        Ok(
            ChatResponse::new(content, model.unwrap_or(Provider::OpenAI.default_model()))
                .with_finish_reason("stop"),
        )
    }

    /// Streaming chat completion.
    ///
    /// NOTE(review): this is currently a MOCK. The request body is built but
    /// never sent (`_url` and `_request_body` are unused), and a fixed
    /// sequence of hard-coded chunks is returned instead of server output.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>
    {
        use futures_util::stream;

        let _url = self.base_url.join("chat/completions")?;

        let model_name = model.unwrap_or_else(|| Provider::OpenAI.default_model().to_string());
        let mut _request_body = serde_json::json!({
            "model": &model_name,
            "messages": messages,
            "stream": true,
        });

        if let Some(temp) = temperature {
            _request_body["temperature"] = temp.into();
        }

        if let Some(max_tokens) = max_tokens {
            _request_body["max_tokens"] = max_tokens.into();
        }

        // Canned text yielded by the mock stream.
        let chunks = vec![
            "Hello",
            " there!",
            " This",
            " is",
            " a",
            " streaming",
            " response",
            " from",
            " OpenAI.",
        ];

        let stream = stream::iter(chunks.into_iter().enumerate().map(move |(i, chunk)| {
            // NOTE(review): this sleep future is created but never awaited,
            // so it has no effect — the stream yields immediately.
            let _ = tokio::time::sleep(std::time::Duration::from_millis(100));

            // NOTE(review): index 8 is the LAST entry (" OpenAI."); it is
            // replaced by the terminal "done" chunk, so that text is never
            // emitted as a delta.
            if i == 8 {
                Ok(StreamChunk::done(&model_name).with_finish_reason("stop"))
            } else {
                Ok(StreamChunk::delta(chunk, &model_name))
            }
        }));

        Ok(Box::new(stream))
    }

    /// Chat completion with OpenAI function-calling tools attached.
    async fn chat_completion_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<crate::tools::ToolDefinition>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat/completions")?;

        // Convert internal tool definitions into OpenAI's wire format.
        let tools_json: Vec<serde_json::Value> = tools
            .iter()
            .map(|tool| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters
                    }
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": model.unwrap_or(Provider::OpenAI.default_model()),
            "messages": messages,
            "tools": tools_json,
        });

        if let Some(temp) = temperature {
            request_body["temperature"] = temp.into();
        }

        if let Some(max_tokens) = max_tokens {
            request_body["max_tokens"] = max_tokens.into();
        }

        let response = self
            .client
            .post(url)
            .headers(self.build_headers())
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "OpenAI",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        let content = response_data["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Parse tool calls out of the first choice. `filter_map` means a
        // malformed entry (missing id/name, or arguments that are not valid
        // JSON text) is silently skipped rather than failing the request.
        let tool_calls = if let Some(calls_array) =
            response_data["choices"][0]["message"]["tool_calls"].as_array()
        {
            let parsed_calls: Vec<crate::message::ToolCall> = calls_array
                .iter()
                .filter_map(|call| {
                    Some(crate::message::ToolCall {
                        id: call["id"].as_str()?.to_string(),
                        call_type: crate::message::ToolCallType::Function,
                        function: crate::message::ToolFunction {
                            name: call["function"]["name"].as_str()?.to_string(),
                            // OpenAI sends arguments as a JSON-encoded string;
                            // decode it into a structured value here.
                            arguments: serde_json::from_str(
                                call["function"]["arguments"].as_str()?,
                            )
                            .ok()?,
                        },
                    })
                })
                .collect();

            // Treat "present but empty" the same as absent.
            if parsed_calls.is_empty() {
                None
            } else {
                Some(parsed_calls)
            }
        } else {
            None
        };

        let mut response =
            ChatResponse::new(content, model.unwrap_or(Provider::OpenAI.default_model()))
                .with_finish_reason("stop");

        if let Some(calls) = tool_calls {
            response = response.with_tool_calls(calls);
        }

        Ok(response)
    }
}
517
/// `LLMProvider` implementation that talks to a local Ollama server.
/// No API key is needed (`Provider::Ollama.requires_auth()` is false).
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    // Reused HTTP client (60s timeout set in `new`).
    client: reqwest::Client,
    // Normalized to end with '/' so `Url::join` appends path segments.
    base_url: Url,
}
524
#[cfg(feature = "ollama")]
impl OllamaProvider {
    /// Builds a provider that talks to an Ollama server.
    ///
    /// When `base_url` is `None`, the default local endpoint is used. The
    /// chosen URL is normalized to end with `/` so `Url::join` appends path
    /// segments. The 60-second timeout is longer than OpenAI's because local
    /// model inference can be slow.
    ///
    /// # Errors
    /// Returns a configuration error if the HTTP client cannot be built.
    pub fn new(base_url: Option<Url>) -> RsllmResult<Self> {
        let http = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()
            .map_err(|e| {
                RsllmError::configuration_with_source("Failed to create HTTP client", e)
            })?;

        let resolved =
            normalize_base_url(&base_url.unwrap_or_else(|| Provider::Ollama.default_base_url()));

        Ok(Self {
            client: http,
            base_url: resolved,
        })
    }
}
545
#[cfg(feature = "ollama")]
#[async_trait]
impl LLMProvider for OllamaProvider {
    fn name(&self) -> &str {
        "Ollama"
    }

    fn provider_type(&self) -> Provider {
        Provider::Ollama
    }

    /// Static list of commonly installed models; the live server may differ.
    fn supported_models(&self) -> Vec<String> {
        Provider::Ollama
            .default_models()
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// Healthy when `GET /tags` (Ollama's model-list endpoint) returns 2xx.
    async fn health_check(&self) -> RsllmResult<bool> {
        let url = self.base_url.join("tags")?;
        let response = self.client.get(url).send().await?;
        Ok(response.status().is_success())
    }

    /// Non-streaming chat completion via `POST /chat`.
    ///
    /// `max_tokens` is accepted for interface compatibility but ignored —
    /// this request shape has no equivalent field.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        _max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat")?;

        let mut request_body = serde_json::json!({
            "model": model.unwrap_or(Provider::Ollama.default_model()),
            "messages": messages,
            "stream": false,
        });

        // Sampling parameters live under a nested "options" object in Ollama.
        if let Some(temp) = temperature {
            request_body["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        let response = self.client.post(url).json(&request_body).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            // Surface the response body in the error; fall back to a generic
            // message if the body itself cannot be read.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "Ollama",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        // Missing or non-string content degrades to an empty reply rather
        // than an error.
        let content = response_data["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        Ok(
            ChatResponse::new(content, model.unwrap_or(Provider::Ollama.default_model()))
                .with_finish_reason("stop"),
        )
    }

    /// Streaming chat completion.
    ///
    /// NOTE(review): this is currently a MOCK. The request body is built but
    /// never sent (`_url` and `_request_body` are unused), and a fixed
    /// sequence of hard-coded chunks is returned instead of server output.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        _max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>
    {
        use futures_util::stream;

        let _url = self.base_url.join("chat")?;

        let model_name = model.unwrap_or_else(|| Provider::Ollama.default_model().to_string());
        let mut _request_body = serde_json::json!({
            "model": &model_name,
            "messages": messages,
            "stream": true,
        });

        if let Some(temp) = temperature {
            _request_body["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        // Canned text yielded by the mock stream. The chunks are emitted
        // immediately; the previous per-chunk `tokio::time::sleep` future was
        // never awaited (it had no effect) and has been removed.
        let chunks = vec![
            "This",
            " is",
            " a",
            " response",
            " from",
            " Ollama",
            " running",
            " locally.",
        ];

        let stream = stream::iter(chunks.into_iter().enumerate().map(move |(i, chunk)| {
            // Index 7 (the last entry) is replaced by the terminal "done"
            // chunk carrying the finish reason, matching the original mock.
            if i == 7 {
                Ok(StreamChunk::done(&model_name).with_finish_reason("stop"))
            } else {
                Ok(StreamChunk::delta(chunk, &model_name))
            }
        }));

        Ok(Box::new(stream))
    }

    /// Chat completion with tool definitions attached (`POST /chat` with a
    /// `tools` array). `max_tokens` is ignored, as in `chat_completion`.
    async fn chat_completion_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<crate::tools::ToolDefinition>,
        model: Option<&str>,
        temperature: Option<f32>,
        _max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = self.base_url.join("chat")?;

        // Convert internal tool definitions into the OpenAI-style wire
        // format Ollama accepts.
        let tools_json: Vec<serde_json::Value> = tools
            .iter()
            .map(|tool| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters
                    }
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": model.unwrap_or(Provider::Ollama.default_model()),
            "messages": messages,
            "stream": false,
            "tools": tools_json,
        });

        if let Some(temp) = temperature {
            request_body["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        let response = self.client.post(url).json(&request_body).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "Ollama",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }

        let response_data: serde_json::Value = response.json().await?;

        let content = response_data["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Parse tool calls; `filter_map` silently skips entries missing a
        // function name rather than failing the whole request.
        let tool_calls =
            if let Some(calls_array) = response_data["message"]["tool_calls"].as_array() {
                let parsed_calls: Vec<crate::message::ToolCall> = calls_array
                    .iter()
                    .enumerate()
                    .filter_map(|(idx, call)| {
                        let function_name = call["function"]["name"].as_str()?;

                        // Some models return numeric arguments as strings;
                        // coerce them back to JSON numbers. FIX: integers must
                        // be tried BEFORE floats — every integer string also
                        // parses as f64, so the previous f64-first order made
                        // the i64 branch unreachable and turned "5" into 5.0.
                        let mut arguments = call["function"]["arguments"].clone();
                        if let serde_json::Value::Object(ref mut args_obj) = arguments {
                            for (_key, value) in args_obj.iter_mut() {
                                if let serde_json::Value::String(s) = value {
                                    if let Ok(int_num) = s.parse::<i64>() {
                                        *value = serde_json::json!(int_num);
                                    } else if let Ok(num) = s.parse::<f64>() {
                                        *value = serde_json::json!(num);
                                    }
                                }
                            }
                        }

                        // Ollama may omit call ids; synthesize a stable one
                        // from the array index.
                        let id = call["id"]
                            .as_str()
                            .map(|s| s.to_string())
                            .unwrap_or_else(|| format!("call_{}", idx));

                        Some(crate::message::ToolCall {
                            id,
                            call_type: crate::message::ToolCallType::Function,
                            function: crate::message::ToolFunction {
                                name: function_name.to_string(),
                                arguments,
                            },
                        })
                    })
                    .collect();

                // Treat "present but empty" the same as absent.
                if parsed_calls.is_empty() {
                    None
                } else {
                    Some(parsed_calls)
                }
            } else {
                None
            };

        let mut response =
            ChatResponse::new(content, model.unwrap_or(Provider::Ollama.default_model()))
                .with_finish_reason("stop");

        if let Some(calls) = tool_calls {
            response = response.with_tool_calls(calls);
        }

        Ok(response)
    }
}
794
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_base_url_without_trailing_slash() {
        // A slash must be appended when it is missing.
        let input = Url::parse("http://localhost:11434/api").unwrap();
        assert_eq!(
            normalize_base_url(&input).as_str(),
            "http://localhost:11434/api/"
        );
    }

    #[test]
    fn test_normalize_base_url_with_trailing_slash() {
        // Already-normalized URLs pass through unchanged.
        let input = Url::parse("http://localhost:11434/api/").unwrap();
        assert_eq!(
            normalize_base_url(&input).as_str(),
            "http://localhost:11434/api/"
        );
    }

    #[test]
    fn test_normalize_base_url_complex() {
        let input = Url::parse("https://api.openai.com/v1").unwrap();
        assert_eq!(
            normalize_base_url(&input).as_str(),
            "https://api.openai.com/v1/"
        );
    }

    #[test]
    fn test_url_join_after_normalization() {
        // Whether or not the input has a trailing slash, joining a relative
        // segment must land under the API path, not replace its last segment.
        for raw in &["http://localhost:11434/api", "http://localhost:11434/api/"] {
            let base = normalize_base_url(&Url::parse(raw).unwrap());
            assert_eq!(
                base.join("chat").unwrap().as_str(),
                "http://localhost:11434/api/chat"
            );
        }
    }
}