1use crate::Result;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
/// Identifies which backend service supplies LLM text generation.
///
/// Serialized in lowercase by serde (e.g. `"openai"`, `"openaicompatible"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// The official OpenAI API.
    OpenAI,
    /// Anthropic's API. NOTE(review): no implementation exists in this file yet;
    /// `ProviderFactory` rejects this variant.
    Anthropic,
    /// Any server exposing an OpenAI-compatible REST API.
    OpenAICompatible,
    /// An Ollama server (exposes an OpenAI-compatible endpoint).
    Ollama,
}
24
/// Identifies which backend service supplies text embeddings.
///
/// Serialized in lowercase by serde (e.g. `"openai"`, `"ollama"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// The official OpenAI API.
    OpenAI,
    /// Any server exposing an OpenAI-compatible embeddings endpoint.
    OpenAICompatible,
    /// An Ollama server (handled via the OpenAI-compatible path in `ProviderFactory`).
    Ollama,
}
36
/// Common async interface implemented by every LLM backend.
///
/// Implementations must be `Send + Sync` so they can be shared behind
/// `Box<dyn LlmProviderTrait>` across async tasks.
#[async_trait::async_trait]
pub trait LlmProviderTrait: Send + Sync {
    /// Generates a plain-text completion for `prompt`.
    ///
    /// Sampling parameters that are `None` fall back to provider-specific
    /// defaults.
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Generates a completion for a multi-turn chat conversation.
    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Lists the model identifiers this provider can serve.
    async fn get_available_models(&self) -> Result<Vec<String>>;

    /// Returns `true` if the provider currently responds to requests.
    async fn is_available(&self) -> bool;

    /// Human-readable provider name (e.g. for logging).
    fn name(&self) -> &'static str;

    /// Maximum context window, in tokens, supported by the provider.
    fn max_context_length(&self) -> usize;
}
72
/// Common async interface implemented by every embedding backend.
#[async_trait::async_trait]
pub trait EmbeddingProviderTrait: Send + Sync {
    /// Generates an embedding vector for a single text.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generates embeddings for several texts, preserving input order.
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Length of the vectors returned by this provider's model.
    fn embedding_dimensions(&self) -> usize;

    /// Maximum number of input tokens accepted per text.
    fn max_tokens(&self) -> usize;

    /// Human-readable provider name (e.g. for logging).
    fn name(&self) -> &'static str;

    /// Returns `true` if the provider currently responds to requests.
    async fn is_available(&self) -> bool;
}
94
/// A single message in a chat conversation sent to an LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    // Who authored the message (system / user / assistant).
    pub role: ChatRole,
    // The message text.
    pub content: String,
    // Optional application-level key/value annotations; not sent to the API
    // by the providers in this file (only role and content are serialized).
    pub metadata: Option<HashMap<String, String>>,
}
105
106impl ChatMessage {
107 pub fn system(content: String) -> Self {
109 Self {
110 role: ChatRole::System,
111 content,
112 metadata: None,
113 }
114 }
115
116 pub fn user(content: String) -> Self {
118 Self {
119 role: ChatRole::User,
120 content,
121 metadata: None,
122 }
123 }
124
125 pub fn assistant(content: String) -> Self {
127 Self {
128 role: ChatRole::Assistant,
129 content,
130 metadata: None,
131 }
132 }
133
134 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
136 self.metadata = Some(metadata);
137 self
138 }
139}
140
/// The author of a chat message, matching the OpenAI chat role vocabulary.
///
/// Serialized in lowercase by serde (`"system"`, `"user"`, `"assistant"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// Instructions that frame the conversation.
    System,
    /// Input from the end user.
    User,
    /// A prior model response.
    Assistant,
}
152
/// LLM provider backed by the official OpenAI REST API.
pub struct OpenAiProvider {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Shared HTTP client reused for all requests.
    client: reqwest::Client,
    // API root, e.g. "https://api.openai.com/v1".
    base_url: String,
}
159
160impl OpenAiProvider {
161 pub fn new(api_key: String) -> Self {
163 Self {
164 api_key,
165 client: reqwest::Client::new(),
166 base_url: "https://api.openai.com/v1".to_string(),
167 }
168 }
169
170 pub fn new_with_base_url(api_key: String, base_url: String) -> Self {
172 Self {
173 api_key,
174 client: reqwest::Client::new(),
175 base_url,
176 }
177 }
178}
179
180#[async_trait::async_trait]
181impl LlmProviderTrait for OpenAiProvider {
182 async fn generate_completion(
183 &self,
184 prompt: &str,
185 max_tokens: Option<usize>,
186 temperature: Option<f32>,
187 top_p: Option<f32>,
188 stop_sequences: Option<Vec<String>>,
189 ) -> Result<String> {
190 let mut request_body = serde_json::json!({
191 "model": "gpt-3.5-turbo-instruct",
192 "prompt": prompt,
193 "max_tokens": max_tokens.unwrap_or(1024),
194 "temperature": temperature.unwrap_or(0.7),
195 });
196
197 if let Some(top_p) = top_p {
198 request_body["top_p"] = serde_json::json!(top_p);
199 }
200
201 if let Some(stop) = stop_sequences {
202 request_body["stop"] = serde_json::json!(stop);
203 }
204
205 let response = self
206 .client
207 .post(format!("{}/completions", self.base_url))
208 .header("Authorization", format!("Bearer {}", self.api_key))
209 .header("Content-Type", "application/json")
210 .json(&request_body)
211 .send()
212 .await?;
213
214 if !response.status().is_success() {
215 return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
216 }
217
218 let json: Value = response.json().await?;
219 let content = json["choices"][0]["text"]
220 .as_str()
221 .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
222
223 Ok(content.to_string())
224 }
225
226 async fn generate_chat_completion(
227 &self,
228 messages: Vec<ChatMessage>,
229 max_tokens: Option<usize>,
230 temperature: Option<f32>,
231 top_p: Option<f32>,
232 stop_sequences: Option<Vec<String>>,
233 ) -> Result<String> {
234 let openai_messages: Vec<Value> = messages
235 .iter()
236 .map(|msg| {
237 serde_json::json!({
238 "role": format!("{:?}", msg.role).to_lowercase(),
239 "content": msg.content
240 })
241 })
242 .collect();
243
244 let mut request_body = serde_json::json!({
245 "model": "gpt-3.5-turbo",
246 "messages": openai_messages,
247 "max_tokens": max_tokens.unwrap_or(1024),
248 "temperature": temperature.unwrap_or(0.7),
249 });
250
251 if let Some(top_p) = top_p {
252 request_body["top_p"] = serde_json::json!(top_p);
253 }
254
255 if let Some(stop) = stop_sequences {
256 request_body["stop"] = serde_json::json!(stop);
257 }
258
259 let response = self
260 .client
261 .post(format!("{}/chat/completions", self.base_url))
262 .header("Authorization", format!("Bearer {}", self.api_key))
263 .header("Content-Type", "application/json")
264 .json(&request_body)
265 .send()
266 .await?;
267
268 if !response.status().is_success() {
269 return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
270 }
271
272 let json: Value = response.json().await?;
273 let content = json["choices"][0]["message"]["content"]
274 .as_str()
275 .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
276
277 Ok(content.to_string())
278 }
279
280 async fn get_available_models(&self) -> Result<Vec<String>> {
281 let response = self
282 .client
283 .get(format!("{}/models", self.base_url))
284 .header("Authorization", format!("Bearer {}", self.api_key))
285 .send()
286 .await?;
287
288 if !response.status().is_success() {
289 return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
290 }
291
292 let json: Value = response.json().await?;
293 let models = json["data"]
294 .as_array()
295 .ok_or_else(|| crate::Error::generic("Invalid models response format"))?;
296
297 let model_names = models
298 .iter()
299 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
300 .collect();
301
302 Ok(model_names)
303 }
304
305 async fn is_available(&self) -> bool {
306 (self.get_available_models().await).is_ok()
307 }
308
309 fn name(&self) -> &'static str {
310 "OpenAI"
311 }
312
313 fn max_context_length(&self) -> usize {
314 4096 }
316}
317
/// Embedding provider backed by the official OpenAI REST API.
pub struct OpenAiEmbeddingProvider {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Shared HTTP client reused for all requests.
    client: reqwest::Client,
    // API root, e.g. "https://api.openai.com/v1".
    base_url: String,
    // Embedding model id, e.g. "text-embedding-ada-002".
    model: String,
}
325
326impl OpenAiEmbeddingProvider {
327 pub fn new(api_key: String) -> Self {
329 Self {
330 api_key,
331 client: reqwest::Client::new(),
332 base_url: "https://api.openai.com/v1".to_string(),
333 model: "text-embedding-ada-002".to_string(),
334 }
335 }
336
337 pub fn new_with_model(api_key: String, model: String) -> Self {
339 Self {
340 api_key,
341 client: reqwest::Client::new(),
342 base_url: "https://api.openai.com/v1".to_string(),
343 model,
344 }
345 }
346}
347
348#[async_trait::async_trait]
349impl EmbeddingProviderTrait for OpenAiEmbeddingProvider {
350 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
351 let response = self
352 .client
353 .post(format!("{}/embeddings", self.base_url))
354 .header("Authorization", format!("Bearer {}", self.api_key))
355 .header("Content-Type", "application/json")
356 .json(&serde_json::json!({
357 "input": text,
358 "model": self.model
359 }))
360 .send()
361 .await?;
362
363 if !response.status().is_success() {
364 return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
365 }
366
367 let json: Value = response.json().await?;
368 let embedding = json["data"][0]["embedding"]
369 .as_array()
370 .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
371
372 Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
373 }
374
375 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
376 let mut embeddings = Vec::new();
377
378 for text in texts {
379 let embedding = self.generate_embedding(&text).await?;
380 embeddings.push(embedding);
381 }
382
383 Ok(embeddings)
384 }
385
386 fn embedding_dimensions(&self) -> usize {
387 match self.model.as_str() {
388 "text-embedding-ada-002" => 1536,
389 "text-embedding-3-small" => 1536,
390 "text-embedding-3-large" => 3072,
391 _ => 1536, }
393 }
394
395 fn max_tokens(&self) -> usize {
396 match self.model.as_str() {
397 "text-embedding-ada-002" => 8191,
398 "text-embedding-3-small" => 8191,
399 "text-embedding-3-large" => 8191,
400 _ => 8191, }
402 }
403
404 fn name(&self) -> &'static str {
405 "OpenAI"
406 }
407
408 async fn is_available(&self) -> bool {
409 (self.generate_embedding("test").await).is_ok()
410 }
411}
412
/// LLM provider for any server exposing an OpenAI-compatible REST API.
pub struct OpenAiCompatibleProvider {
    // Bearer token; may be empty for servers that do not require auth.
    api_key: String,
    // Shared HTTP client reused for all requests.
    client: reqwest::Client,
    // API root of the compatible server (required, no default).
    base_url: String,
    // Model id sent with every request.
    model: String,
}
420
421impl OpenAiCompatibleProvider {
422 pub fn new(api_key: String, base_url: String, model: String) -> Self {
424 Self {
425 api_key,
426 client: reqwest::Client::new(),
427 base_url,
428 model,
429 }
430 }
431}
432
433#[async_trait::async_trait]
434impl LlmProviderTrait for OpenAiCompatibleProvider {
435 async fn generate_completion(
436 &self,
437 prompt: &str,
438 max_tokens: Option<usize>,
439 temperature: Option<f32>,
440 top_p: Option<f32>,
441 stop_sequences: Option<Vec<String>>,
442 ) -> Result<String> {
443 let mut request_body = serde_json::json!({
444 "model": self.model,
445 "prompt": prompt,
446 "max_tokens": max_tokens.unwrap_or(1024),
447 "temperature": temperature.unwrap_or(0.7),
448 });
449
450 if let Some(top_p) = top_p {
451 request_body["top_p"] = serde_json::json!(top_p);
452 }
453
454 if let Some(stop) = stop_sequences {
455 request_body["stop"] = serde_json::json!(stop);
456 }
457
458 let response = self
459 .client
460 .post(format!("{}/completions", self.base_url))
461 .header("Authorization", format!("Bearer {}", self.api_key))
462 .header("Content-Type", "application/json")
463 .json(&request_body)
464 .send()
465 .await?;
466
467 if !response.status().is_success() {
468 return Err(crate::Error::generic(format!("API error: {}", response.status())));
469 }
470
471 let json: Value = response.json().await?;
472 let content = json["choices"][0]["text"]
473 .as_str()
474 .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
475
476 Ok(content.to_string())
477 }
478
479 async fn generate_chat_completion(
480 &self,
481 messages: Vec<ChatMessage>,
482 max_tokens: Option<usize>,
483 temperature: Option<f32>,
484 top_p: Option<f32>,
485 stop_sequences: Option<Vec<String>>,
486 ) -> Result<String> {
487 let openai_messages: Vec<Value> = messages
488 .iter()
489 .map(|msg| {
490 serde_json::json!({
491 "role": format!("{:?}", msg.role).to_lowercase(),
492 "content": msg.content
493 })
494 })
495 .collect();
496
497 let mut request_body = serde_json::json!({
498 "model": self.model,
499 "messages": openai_messages,
500 "max_tokens": max_tokens.unwrap_or(1024),
501 "temperature": temperature.unwrap_or(0.7),
502 });
503
504 if let Some(top_p) = top_p {
505 request_body["top_p"] = serde_json::json!(top_p);
506 }
507
508 if let Some(stop) = stop_sequences {
509 request_body["stop"] = serde_json::json!(stop);
510 }
511
512 let response = self
513 .client
514 .post(format!("{}/chat/completions", self.base_url))
515 .header("Authorization", format!("Bearer {}", self.api_key))
516 .header("Content-Type", "application/json")
517 .json(&request_body)
518 .send()
519 .await?;
520
521 if !response.status().is_success() {
522 return Err(crate::Error::generic(format!("API error: {}", response.status())));
523 }
524
525 let json: Value = response.json().await?;
526 let content = json["choices"][0]["message"]["content"]
527 .as_str()
528 .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
529
530 Ok(content.to_string())
531 }
532
533 async fn get_available_models(&self) -> Result<Vec<String>> {
534 match self
536 .client
537 .get(format!("{}/models", self.base_url))
538 .header("Authorization", format!("Bearer {}", self.api_key))
539 .send()
540 .await
541 {
542 Ok(response) if response.status().is_success() => {
543 let json: Value = response.json().await?;
544 let models = json["data"]
545 .as_array()
546 .ok_or_else(|| crate::Error::generic("Invalid models response format"))?;
547 Ok(models
548 .iter()
549 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
550 .collect())
551 }
552 _ => Ok(vec![self.model.clone()]), }
554 }
555
556 async fn is_available(&self) -> bool {
557 (self.generate_completion("test", Some(1), None, None, None).await).is_ok()
558 }
559
560 fn name(&self) -> &'static str {
561 "OpenAI Compatible"
562 }
563
564 fn max_context_length(&self) -> usize {
565 4096 }
567}
568
/// Embedding provider for any server exposing an OpenAI-compatible API.
pub struct OpenAiCompatibleEmbeddingProvider {
    // Bearer token; may be empty for servers that do not require auth.
    api_key: String,
    // Shared HTTP client reused for all requests.
    client: reqwest::Client,
    // API root of the compatible server (required, no default).
    base_url: String,
    // Embedding model id sent with every request.
    model: String,
}
576
577impl OpenAiCompatibleEmbeddingProvider {
578 pub fn new(api_key: String, base_url: String, model: String) -> Self {
580 Self {
581 api_key,
582 client: reqwest::Client::new(),
583 base_url,
584 model,
585 }
586 }
587}
588
589#[async_trait::async_trait]
590impl EmbeddingProviderTrait for OpenAiCompatibleEmbeddingProvider {
591 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
592 let response = self
593 .client
594 .post(format!("{}/embeddings", self.base_url))
595 .header("Authorization", format!("Bearer {}", self.api_key))
596 .header("Content-Type", "application/json")
597 .json(&serde_json::json!({
598 "input": text,
599 "model": self.model
600 }))
601 .send()
602 .await?;
603
604 if !response.status().is_success() {
605 return Err(crate::Error::generic(format!("API error: {}", response.status())));
606 }
607
608 let json: Value = response.json().await?;
609 let embedding = json["data"][0]["embedding"]
610 .as_array()
611 .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
612
613 Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
614 }
615
616 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
617 let mut embeddings = Vec::new();
618
619 for text in texts {
620 let embedding = self.generate_embedding(&text).await?;
621 embeddings.push(embedding);
622 }
623
624 Ok(embeddings)
625 }
626
627 fn embedding_dimensions(&self) -> usize {
628 1536 }
630
631 fn max_tokens(&self) -> usize {
632 8191 }
634
635 fn name(&self) -> &'static str {
636 "OpenAI Compatible"
637 }
638
639 async fn is_available(&self) -> bool {
640 (self.generate_embedding("test").await).is_ok()
641 }
642}
643
644pub struct ProviderFactory;
646
647impl ProviderFactory {
648 pub fn create_llm_provider(
650 provider_type: LlmProvider,
651 api_key: String,
652 base_url: Option<String>,
653 model: String,
654 ) -> Result<Box<dyn LlmProviderTrait>> {
655 match provider_type {
656 LlmProvider::OpenAI => {
657 let base_url = base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
658 Ok(Box::new(OpenAiProvider::new_with_base_url(api_key, base_url)))
659 }
660 LlmProvider::OpenAICompatible => {
661 let base_url = base_url.ok_or_else(|| {
662 crate::Error::generic("Base URL required for OpenAI compatible provider")
663 })?;
664 Ok(Box::new(OpenAiCompatibleProvider::new(api_key, base_url, model)))
665 }
666 _ => Err(crate::Error::generic(format!(
667 "Provider type {:?} not yet implemented",
668 provider_type
669 ))),
670 }
671 }
672
673 pub fn create_embedding_provider(
675 provider_type: EmbeddingProvider,
676 api_key: String,
677 base_url: Option<String>,
678 model: String,
679 ) -> Result<Box<dyn EmbeddingProviderTrait>> {
680 match provider_type {
681 EmbeddingProvider::OpenAI => {
682 Ok(Box::new(OpenAiEmbeddingProvider::new_with_model(api_key, model)))
683 }
684 EmbeddingProvider::OpenAICompatible => {
685 let base_url = base_url.ok_or_else(|| {
686 crate::Error::generic(
687 "Base URL required for OpenAI compatible embedding provider",
688 )
689 })?;
690 Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(api_key, base_url, model)))
691 }
692 EmbeddingProvider::Ollama => {
693 let base_url = base_url.ok_or_else(|| {
695 crate::Error::generic("Base URL required for Ollama embedding provider")
696 })?;
697 Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(String::new(), base_url, model)))
699 }
700 }
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors must assign the matching role and leave metadata unset.
    #[test]
    fn chat_message_constructors_assign_roles() {
        let sys = ChatMessage::system("s".to_string());
        assert!(matches!(sys.role, ChatRole::System));
        assert!(sys.metadata.is_none());

        let user = ChatMessage::user("u".to_string());
        assert!(matches!(user.role, ChatRole::User));
        assert_eq!(user.content, "u");

        let asst = ChatMessage::assistant("a".to_string());
        assert!(matches!(asst.role, ChatRole::Assistant));
    }

    /// `with_metadata` must store the provided map verbatim.
    #[test]
    fn with_metadata_attaches_map() {
        let mut meta = HashMap::new();
        meta.insert("k".to_string(), "v".to_string());
        let msg = ChatMessage::user("hi".to_string()).with_metadata(meta);
        assert_eq!(msg.metadata.unwrap().get("k").map(String::as_str), Some("v"));
    }

    /// Provider enums serialize in lowercase per `#[serde(rename_all)]`.
    #[test]
    fn provider_enums_serialize_lowercase() {
        assert_eq!(serde_json::to_string(&LlmProvider::OpenAI).unwrap(), "\"openai\"");
        assert_eq!(
            serde_json::to_string(&EmbeddingProvider::OpenAICompatible).unwrap(),
            "\"openaicompatible\""
        );
    }
}