use crate::Result;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

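/// LLM backends that can be selected at runtime.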
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI's hosted API.
    OpenAI,
    /// Anthropic's hosted API.
    Anthropic,
    /// Any endpoint that speaks the OpenAI wire format.
    OpenAICompatible,
    /// A local Ollama instance.
    Ollama,
}

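/// Embedding backends that can be selected at runtime.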
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI's hosted embeddings API.
    OpenAI,
    /// Any endpoint that speaks the OpenAI embeddings wire format.
    OpenAICompatible,
}

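/// Common interface implemented by every LLM backend.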
#[async_trait::async_trait]
pub trait LlmProviderTrait: Send + Sync {
    /// Generate a plain-text completion for `prompt`.
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Generate a completion for a multi-turn chat conversation.
    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// List the model identifiers the backend currently exposes.
    async fn get_available_models(&self) -> Result<Vec<String>>;

    /// Whether the backend is reachable and accepting requests.
    async fn is_available(&self) -> bool;

    /// Human-readable provider name.
    fn name(&self) -> &'static str;

    /// Maximum context window, in tokens.
    fn max_context_length(&self) -> usize;
}

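/// Common interface implemented by every embedding backend.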
#[async_trait::async_trait]
pub trait EmbeddingProviderTrait: Send + Sync {
    /// Generate an embedding vector for a single text.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embedding vectors for a batch of texts.
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Dimensionality of the vectors this provider returns.
    fn embedding_dimensions(&self) -> usize;

    /// Maximum input length, in tokens.
    fn max_tokens(&self) -> usize;

    /// Human-readable provider name.
    fn name(&self) -> &'static str;

    /// Whether the backend is reachable and accepting requests.
    async fn is_available(&self) -> bool;
}

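/// A single message in a chat conversation.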
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: ChatRole,
    /// The message body.
    pub content: String,
    /// Optional provider-agnostic metadata attached to the message.
    pub metadata: Option<HashMap<String, String>>,
}

impl ChatMessage {
    /// Build a system message.
    pub fn system(content: String) -> Self {
        Self {
            role: ChatRole::System,
            content,
            metadata: None,
        }
    }

    /// Build a user message.
    pub fn user(content: String) -> Self {
        Self {
            role: ChatRole::User,
            content,
            metadata: None,
        }
    }

    /// Build an assistant message.
    pub fn assistant(content: String) -> Self {
        Self {
            role: ChatRole::Assistant,
            content,
            metadata: None,
        }
    }

    /// Attach metadata to the message, builder-style.
    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = Some(metadata);
        self
    }
}

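/// The author of a [`ChatMessage`].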
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    System,
    User,
    Assistant,
}

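/// LLM provider backed by the OpenAI API.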
pub struct OpenAiProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
}

impl OpenAiProvider {
    /// Create a provider that talks to the default OpenAI endpoint.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    /// Create a provider with a custom base URL (e.g., for a proxy).
    pub fn new_with_base_url(api_key: String, base_url: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url,
        }
    }
}

#[async_trait::async_trait]
impl LlmProviderTrait for OpenAiProvider {
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let mut request_body = serde_json::json!({
            "model": "gpt-3.5-turbo-instruct",
            "prompt": prompt,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["text"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        // Serialize roles to lowercase to match the OpenAI wire format.
        let openai_messages: Vec<Value> = messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": format!("{:?}", msg.role).to_lowercase(),
                    "content": msg.content
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": "gpt-3.5-turbo",
            "messages": openai_messages,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn get_available_models(&self) -> Result<Vec<String>> {
        let response = self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let models = json["data"]
            .as_array()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid models response format"))?;

        let model_names = models
            .iter()
            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
            .collect();

        Ok(model_names)
    }

    async fn is_available(&self) -> bool {
        self.get_available_models().await.is_ok()
    }

    fn name(&self) -> &'static str {
        "OpenAI"
    }

    fn max_context_length(&self) -> usize {
        // Conservative default for gpt-3.5-turbo-class models.
        4096
    }
}

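/// Embedding provider backed by the OpenAI API.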
pub struct OpenAiEmbeddingProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
    model: String,
}

impl OpenAiEmbeddingProvider {
    /// Create a provider using the default `text-embedding-ada-002` model.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            model: "text-embedding-ada-002".to_string(),
        }
    }

    /// Create a provider with an explicit embedding model.
    pub fn new_with_model(api_key: String, model: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            model,
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingProviderTrait for OpenAiEmbeddingProvider {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": self.model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid embedding response format"))?;

        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }

    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Embeds texts one at a time; fails fast on the first error.
        let mut embeddings = Vec::new();

        for text in texts {
            let embedding = self.generate_embedding(&text).await?;
            embeddings.push(embedding);
        }

        Ok(embeddings)
    }

    fn embedding_dimensions(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }

    fn max_tokens(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-ada-002" => 8191,
            "text-embedding-3-small" => 8191,
            "text-embedding-3-large" => 8191,
            _ => 8191,
        }
    }

    fn name(&self) -> &'static str {
        "OpenAI"
    }

    async fn is_available(&self) -> bool {
        self.generate_embedding("test").await.is_ok()
    }
}

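/// LLM provider for any server that exposes an OpenAI-compatible API.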
pub struct OpenAiCompatibleProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
    model: String,
}

impl OpenAiCompatibleProvider {
    /// Create a provider targeting `base_url` with a fixed model.
    pub fn new(api_key: String, base_url: String, model: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url,
            model,
        }
    }
}

#[async_trait::async_trait]
impl LlmProviderTrait for OpenAiCompatibleProvider {
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let mut request_body = serde_json::json!({
            "model": self.model,
            "prompt": prompt,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["text"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let openai_messages: Vec<Value> = messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": format!("{:?}", msg.role).to_lowercase(),
                    "content": msg.content
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": self.model,
            "messages": openai_messages,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn get_available_models(&self) -> Result<Vec<String>> {
        match self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
        {
            Ok(response) if response.status().is_success() => {
                let json: Value = response.json().await?;
                let models = json["data"].as_array().ok_or_else(|| {
                    mockforge_core::Error::generic("Invalid models response format")
                })?;
                Ok(models
                    .iter()
                    .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
                    .collect())
            }
            // Not every compatible server implements /models; fall back to
            // the configured model instead of failing.
            _ => Ok(vec![self.model.clone()]),
        }
    }

    async fn is_available(&self) -> bool {
        self.generate_completion("test", Some(1), None, None, None).await.is_ok()
    }

    fn name(&self) -> &'static str {
        "OpenAI Compatible"
    }

    fn max_context_length(&self) -> usize {
        // Conservative default; actual limits vary by backing model.
        4096
    }
}

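/// Embedding provider for any server that exposes an OpenAI-compatible embeddings API.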
pub struct OpenAiCompatibleEmbeddingProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
    model: String,
}

impl OpenAiCompatibleEmbeddingProvider {
    /// Create a provider targeting `base_url` with a fixed embedding model.
    pub fn new(api_key: String, base_url: String, model: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url,
            model,
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingProviderTrait for OpenAiCompatibleEmbeddingProvider {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": self.model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid embedding response format"))?;

        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }

    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Embeds texts one at a time; fails fast on the first error.
        let mut embeddings = Vec::new();

        for text in texts {
            let embedding = self.generate_embedding(&text).await?;
            embeddings.push(embedding);
        }

        Ok(embeddings)
    }

    fn embedding_dimensions(&self) -> usize {
        // Hardcoded default; the backing model may differ.
        1536
    }

    fn max_tokens(&self) -> usize {
        // Hardcoded default; the backing model may differ.
        8191
    }

    fn name(&self) -> &'static str {
        "OpenAI Compatible"
    }

    async fn is_available(&self) -> bool {
        self.generate_embedding("test").await.is_ok()
    }
}

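/// Builds boxed provider instances from runtime configuration.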
pub struct ProviderFactory;

impl ProviderFactory {
    /// Build an LLM provider for `provider_type`.
    ///
    /// `base_url` is optional for `OpenAI` and required for
    /// `OpenAICompatible`; `Anthropic` and `Ollama` are not yet implemented.
    pub fn create_llm_provider(
        provider_type: LlmProvider,
        api_key: String,
        base_url: Option<String>,
        model: String,
    ) -> Result<Box<dyn LlmProviderTrait>> {
        match provider_type {
            LlmProvider::OpenAI => {
                let base_url = base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
                Ok(Box::new(OpenAiProvider::new_with_base_url(api_key, base_url)))
            }
            LlmProvider::OpenAICompatible => {
                let base_url = base_url.ok_or_else(|| {
                    mockforge_core::Error::generic(
                        "Base URL required for OpenAI compatible provider",
                    )
                })?;
                Ok(Box::new(OpenAiCompatibleProvider::new(api_key, base_url, model)))
            }
            _ => Err(mockforge_core::Error::generic(format!(
                "Provider type {:?} not yet implemented",
                provider_type
            ))),
        }
    }

    /// Build an embedding provider for `provider_type`.
    ///
    /// `base_url` is required for `OpenAICompatible` and ignored for `OpenAI`.
    pub fn create_embedding_provider(
        provider_type: EmbeddingProvider,
        api_key: String,
        base_url: Option<String>,
        model: String,
    ) -> Result<Box<dyn EmbeddingProviderTrait>> {
        match provider_type {
            EmbeddingProvider::OpenAI => {
                Ok(Box::new(OpenAiEmbeddingProvider::new_with_model(api_key, model)))
            }
            EmbeddingProvider::OpenAICompatible => {
                let base_url = base_url.ok_or_else(|| {
                    mockforge_core::Error::generic(
                        "Base URL required for OpenAI compatible embedding provider",
                    )
                })?;
                Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(api_key, base_url, model)))
            }
        }
    }
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_module_compiles() {
        // Intentionally empty: the module compiling is the assertion.
    }
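
    // A few smoke tests over the pure, non-network parts of this module.
    // These are sketches that use only items defined above and assume
    // nothing beyond what the code itself guarantees.

    #[test]
    fn chat_message_builders_set_role_and_content() {
        let msg = super::ChatMessage::user("hello".to_string());
        assert!(matches!(msg.role, super::ChatRole::User));
        assert_eq!(msg.content, "hello");
        assert!(msg.metadata.is_none());
    }

    #[test]
    fn with_metadata_attaches_metadata() {
        let mut meta = std::collections::HashMap::new();
        meta.insert("source".to_string(), "test".to_string());
        let msg = super::ChatMessage::system("rules".to_string()).with_metadata(meta);
        assert!(msg.metadata.is_some());
    }

    #[test]
    fn factory_rejects_unimplemented_llm_providers() {
        // Anthropic is declared but not yet wired up in the factory.
        let result = super::ProviderFactory::create_llm_provider(
            super::LlmProvider::Anthropic,
            "key".to_string(),
            None,
            "model".to_string(),
        );
        assert!(result.is_err());
    }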
}