1use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::config::BehaviorModelConfig;
10use super::types::LlmGenerationRequest;
11use mockforge_foundation::Result;
12
13pub struct LlmClient {
15 rag_engine: Arc<RwLock<Option<Box<dyn LlmProvider>>>>,
17 config: BehaviorModelConfig,
19}
20
21impl LlmClient {
22 pub fn new(config: BehaviorModelConfig) -> Self {
24 Self {
25 rag_engine: Arc::new(RwLock::new(None)),
26 config,
27 }
28 }
29
30 async fn ensure_initialized(&self) -> Result<()> {
32 let mut engine = self.rag_engine.write().await;
33
34 if engine.is_none() {
35 let provider = self.create_provider()?;
37 *engine = Some(provider);
38 }
39
40 Ok(())
41 }
42
43 fn create_provider(&self) -> Result<Box<dyn LlmProvider>> {
45 match self.config.llm_provider.to_lowercase().as_str() {
46 "openai" => Ok(Box::new(OpenAIProvider::new(&self.config)?)),
47 "anthropic" => Ok(Box::new(AnthropicProvider::new(&self.config)?)),
48 "ollama" => Ok(Box::new(OllamaProvider::new(&self.config)?)),
49 "openai-compatible" => Ok(Box::new(OpenAICompatibleProvider::new(&self.config)?)),
50 _ => Err(mockforge_foundation::Error::internal(format!(
51 "Unsupported LLM provider: {}",
52 self.config.llm_provider
53 ))),
54 }
55 }
56
57 pub async fn generate(&self, request: &LlmGenerationRequest) -> Result<serde_json::Value> {
59 self.ensure_initialized().await?;
60
61 let engine = self.rag_engine.read().await;
62 let provider = engine
63 .as_ref()
64 .ok_or_else(|| mockforge_foundation::Error::internal("LLM provider not initialized"))?;
65
66 let messages = vec![
68 ChatMessage {
69 role: "system".to_string(),
70 content: request.system_prompt.clone(),
71 },
72 ChatMessage {
73 role: "user".to_string(),
74 content: request.user_prompt.clone(),
75 },
76 ];
77
78 let response_text = provider
80 .generate_chat(messages, request.temperature, request.max_tokens)
81 .await?;
82
83 match serde_json::from_str::<serde_json::Value>(&response_text) {
85 Ok(json) => Ok(json),
86 Err(_) => {
87 if let Some(start) = response_text.find('{') {
89 if let Some(end) = response_text.rfind('}') {
90 let json_str = &response_text[start..=end];
91 if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
92 return Ok(json);
93 }
94 }
95 }
96
97 Ok(serde_json::json!({
99 "response": response_text,
100 "note": "Response was not valid JSON, wrapped in object"
101 }))
102 }
103 }
104 }
105
106 pub async fn generate_with_usage(
108 &self,
109 request: &LlmGenerationRequest,
110 ) -> Result<(serde_json::Value, LlmUsage)> {
111 self.ensure_initialized().await?;
112
113 let engine = self.rag_engine.read().await;
114 let provider = engine
115 .as_ref()
116 .ok_or_else(|| mockforge_foundation::Error::internal("LLM provider not initialized"))?;
117
118 let messages = vec![
120 ChatMessage {
121 role: "system".to_string(),
122 content: request.system_prompt.clone(),
123 },
124 ChatMessage {
125 role: "user".to_string(),
126 content: request.user_prompt.clone(),
127 },
128 ];
129
130 let (response_text, usage) = provider
132 .generate_chat_with_usage(messages, request.temperature, request.max_tokens)
133 .await?;
134
135 let json_value = match serde_json::from_str::<serde_json::Value>(&response_text) {
137 Ok(json) => json,
138 Err(_) => {
139 if let Some(start) = response_text.find('{') {
141 if let Some(end) = response_text.rfind('}') {
142 let json_str = &response_text[start..=end];
143 if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
144 json
145 } else {
146 serde_json::json!({
147 "response": response_text,
148 "note": "Response was not valid JSON, wrapped in object"
149 })
150 }
151 } else {
152 serde_json::json!({
153 "response": response_text,
154 "note": "Response was not valid JSON, wrapped in object"
155 })
156 }
157 } else {
158 serde_json::json!({
159 "response": response_text,
160 "note": "Response was not valid JSON, wrapped in object"
161 })
162 }
163 }
164 };
165
166 Ok((json_value, usage))
167 }
168
169 pub fn config(&self) -> &BehaviorModelConfig {
171 &self.config
172 }
173}
174
/// One turn of a chat conversation handed to a provider.
#[derive(Clone, Debug)]
struct ChatMessage {
    /// Who is speaking; this file uses `"system"` and `"user"`.
    role: String,
    /// The text of the turn.
    content: String,
}
181
/// Token accounting for a single LLM call.
#[derive(Debug, Clone, Default)]
pub struct LlmUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u64,
    /// Tokens attributed to the completion.
    pub completion_tokens: u64,
    /// Prompt plus completion tokens, as computed by [`LlmUsage::new`].
    pub total_tokens: u64,
}

impl LlmUsage {
    /// Creates a usage record and derives `total_tokens` from the two counts.
    ///
    /// The total saturates at `u64::MAX` rather than overflowing, so
    /// implausibly large counts can neither panic (debug builds) nor wrap
    /// around (release builds).
    pub fn new(prompt_tokens: u64, completion_tokens: u64) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
203
204#[async_trait::async_trait]
206trait LlmProvider: Send + Sync {
207 async fn generate_chat(
209 &self,
210 messages: Vec<ChatMessage>,
211 temperature: f64,
212 max_tokens: usize,
213 ) -> Result<String>;
214
215 async fn generate_chat_with_usage(
217 &self,
218 messages: Vec<ChatMessage>,
219 temperature: f64,
220 max_tokens: usize,
221 ) -> Result<(String, LlmUsage)> {
222 let response = self.generate_chat(messages, temperature, max_tokens).await?;
224 let estimated_tokens = (response.len() as f64 / 4.0) as u64;
226 Ok((response, LlmUsage::new(estimated_tokens, estimated_tokens)))
227 }
228}
229
230struct OpenAIProvider {
232 client: reqwest::Client,
233 api_key: String,
234 model: String,
235 endpoint: String,
236}
237
238impl OpenAIProvider {
239 fn new(config: &BehaviorModelConfig) -> Result<Self> {
240 let api_key = config
241 .api_key
242 .clone()
243 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
244 .ok_or_else(|| mockforge_foundation::Error::internal("OpenAI API key not found"))?;
245
246 let endpoint = config
247 .api_endpoint
248 .clone()
249 .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
250
251 Ok(Self {
252 client: reqwest::Client::new(),
253 api_key,
254 model: config.model.clone(),
255 endpoint,
256 })
257 }
258}
259
260#[async_trait::async_trait]
261impl LlmProvider for OpenAIProvider {
262 async fn generate_chat(
263 &self,
264 messages: Vec<ChatMessage>,
265 temperature: f64,
266 max_tokens: usize,
267 ) -> Result<String> {
268 let request_body = serde_json::json!({
269 "model": self.model,
270 "messages": messages.iter().map(|m| {
271 serde_json::json!({
272 "role": m.role,
273 "content": m.content
274 })
275 }).collect::<Vec<_>>(),
276 "temperature": temperature,
277 "max_tokens": max_tokens,
278 });
279
280 let response = self
281 .client
282 .post(&self.endpoint)
283 .header("Authorization", format!("Bearer {}", self.api_key))
284 .header("Content-Type", "application/json")
285 .json(&request_body)
286 .send()
287 .await
288 .map_err(|e| {
289 mockforge_foundation::Error::internal(format!("OpenAI API request failed: {}", e))
290 })?;
291
292 if !response.status().is_success() {
293 let error_text = response.text().await.unwrap_or_default();
294 return Err(mockforge_foundation::Error::internal(format!(
295 "OpenAI API error: {}",
296 error_text
297 )));
298 }
299
300 let response_json: serde_json::Value = response.json().await.map_err(|e| {
301 mockforge_foundation::Error::internal(format!("Failed to parse OpenAI response: {}", e))
302 })?;
303
304 let content = response_json["choices"][0]["message"]["content"]
306 .as_str()
307 .ok_or_else(|| mockforge_foundation::Error::internal("Invalid OpenAI response format"))?
308 .to_string();
309
310 Ok(content)
311 }
312
313 async fn generate_chat_with_usage(
314 &self,
315 messages: Vec<ChatMessage>,
316 temperature: f64,
317 max_tokens: usize,
318 ) -> Result<(String, LlmUsage)> {
319 let request_body = serde_json::json!({
320 "model": self.model,
321 "messages": messages.iter().map(|m| {
322 serde_json::json!({
323 "role": m.role,
324 "content": m.content
325 })
326 }).collect::<Vec<_>>(),
327 "temperature": temperature,
328 "max_tokens": max_tokens,
329 });
330
331 let response = self
332 .client
333 .post(&self.endpoint)
334 .header("Authorization", format!("Bearer {}", self.api_key))
335 .header("Content-Type", "application/json")
336 .json(&request_body)
337 .send()
338 .await
339 .map_err(|e| {
340 mockforge_foundation::Error::internal(format!("OpenAI API request failed: {}", e))
341 })?;
342
343 if !response.status().is_success() {
344 let error_text = response.text().await.unwrap_or_default();
345 return Err(mockforge_foundation::Error::internal(format!(
346 "OpenAI API error: {}",
347 error_text
348 )));
349 }
350
351 let response_json: serde_json::Value = response.json().await.map_err(|e| {
352 mockforge_foundation::Error::internal(format!("Failed to parse OpenAI response: {}", e))
353 })?;
354
355 let content = response_json["choices"][0]["message"]["content"]
357 .as_str()
358 .ok_or_else(|| mockforge_foundation::Error::internal("Invalid OpenAI response format"))?
359 .to_string();
360
361 let usage = if let Some(usage_obj) = response_json.get("usage") {
363 LlmUsage::new(
364 usage_obj["prompt_tokens"].as_u64().unwrap_or(0),
365 usage_obj["completion_tokens"].as_u64().unwrap_or(0),
366 )
367 } else {
368 let estimated = (content.len() as f64 / 4.0) as u64;
370 LlmUsage::new(estimated, estimated)
371 };
372
373 Ok((content, usage))
374 }
375}
376
377struct OllamaProvider {
379 client: reqwest::Client,
380 model: String,
381 endpoint: String,
382}
383
384impl OllamaProvider {
385 fn new(config: &BehaviorModelConfig) -> Result<Self> {
386 let endpoint = config
387 .api_endpoint
388 .clone()
389 .unwrap_or_else(|| "http://localhost:11434/api/chat".to_string());
390
391 Ok(Self {
392 client: reqwest::Client::new(),
393 model: config.model.clone(),
394 endpoint,
395 })
396 }
397}
398
399#[async_trait::async_trait]
400impl LlmProvider for OllamaProvider {
401 async fn generate_chat(
402 &self,
403 messages: Vec<ChatMessage>,
404 temperature: f64,
405 max_tokens: usize,
406 ) -> Result<String> {
407 let request_body = serde_json::json!({
408 "model": self.model,
409 "messages": messages.iter().map(|m| {
410 serde_json::json!({
411 "role": m.role,
412 "content": m.content
413 })
414 }).collect::<Vec<_>>(),
415 "options": {
416 "temperature": temperature,
417 "num_predict": max_tokens,
418 },
419 "stream": false,
420 });
421
422 let response = self
423 .client
424 .post(&self.endpoint)
425 .header("Content-Type", "application/json")
426 .json(&request_body)
427 .send()
428 .await
429 .map_err(|e| {
430 mockforge_foundation::Error::internal(format!("Ollama API request failed: {}", e))
431 })?;
432
433 if !response.status().is_success() {
434 let error_text = response.text().await.unwrap_or_default();
435 return Err(mockforge_foundation::Error::internal(format!(
436 "Ollama API error: {}",
437 error_text
438 )));
439 }
440
441 let response_json: serde_json::Value = response.json().await.map_err(|e| {
442 mockforge_foundation::Error::internal(format!("Failed to parse Ollama response: {}", e))
443 })?;
444
445 let content = response_json["message"]["content"]
447 .as_str()
448 .ok_or_else(|| mockforge_foundation::Error::internal("Invalid Ollama response format"))?
449 .to_string();
450
451 Ok(content)
452 }
453}
454
455struct AnthropicProvider {
457 client: reqwest::Client,
458 api_key: String,
459 model: String,
460 endpoint: String,
461}
462
463impl AnthropicProvider {
464 fn new(config: &BehaviorModelConfig) -> Result<Self> {
465 let api_key = config
466 .api_key
467 .clone()
468 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
469 .ok_or_else(|| mockforge_foundation::Error::internal("Anthropic API key not found"))?;
470
471 let endpoint = config
472 .api_endpoint
473 .clone()
474 .unwrap_or_else(|| "https://api.anthropic.com/v1/messages".to_string());
475
476 Ok(Self {
477 client: reqwest::Client::new(),
478 api_key,
479 model: config.model.clone(),
480 endpoint,
481 })
482 }
483}
484
485#[async_trait::async_trait]
486impl LlmProvider for AnthropicProvider {
487 async fn generate_chat(
488 &self,
489 messages: Vec<ChatMessage>,
490 temperature: f64,
491 max_tokens: usize,
492 ) -> Result<String> {
493 let system_message =
495 messages.iter().find(|m| m.role == "system").map(|m| m.content.clone());
496
497 let chat_messages: Vec<_> = messages
498 .iter()
499 .filter(|m| m.role != "system")
500 .map(|m| {
501 serde_json::json!({
502 "role": m.role,
503 "content": m.content
504 })
505 })
506 .collect();
507
508 let mut request_body = serde_json::json!({
509 "model": self.model,
510 "messages": chat_messages,
511 "temperature": temperature,
512 "max_tokens": max_tokens,
513 });
514
515 if let Some(system) = system_message {
516 request_body["system"] = serde_json::Value::String(system);
517 }
518
519 let response = self
520 .client
521 .post(&self.endpoint)
522 .header("x-api-key", &self.api_key)
523 .header("anthropic-version", "2023-06-01")
524 .header("Content-Type", "application/json")
525 .json(&request_body)
526 .send()
527 .await
528 .map_err(|e| {
529 mockforge_foundation::Error::internal(format!(
530 "Anthropic API request failed: {}",
531 e
532 ))
533 })?;
534
535 if !response.status().is_success() {
536 let error_text = response.text().await.unwrap_or_default();
537 return Err(mockforge_foundation::Error::internal(format!(
538 "Anthropic API error: {}",
539 error_text
540 )));
541 }
542
543 let response_json: serde_json::Value = response.json().await.map_err(|e| {
544 mockforge_foundation::Error::internal(format!(
545 "Failed to parse Anthropic response: {}",
546 e
547 ))
548 })?;
549
550 let content = response_json["content"][0]["text"]
552 .as_str()
553 .ok_or_else(|| {
554 mockforge_foundation::Error::internal("Invalid Anthropic response format")
555 })?
556 .to_string();
557
558 Ok(content)
559 }
560}
561
562struct OpenAICompatibleProvider {
564 client: reqwest::Client,
565 api_key: Option<String>,
566 model: String,
567 endpoint: String,
568}
569
570impl OpenAICompatibleProvider {
571 fn new(config: &BehaviorModelConfig) -> Result<Self> {
572 let endpoint = config.api_endpoint.clone().ok_or_else(|| {
573 mockforge_foundation::Error::internal(
574 "API endpoint required for OpenAI-compatible provider",
575 )
576 })?;
577
578 Ok(Self {
579 client: reqwest::Client::new(),
580 api_key: config.api_key.clone(),
581 model: config.model.clone(),
582 endpoint,
583 })
584 }
585}
586
587#[async_trait::async_trait]
588impl LlmProvider for OpenAICompatibleProvider {
589 async fn generate_chat(
590 &self,
591 messages: Vec<ChatMessage>,
592 temperature: f64,
593 max_tokens: usize,
594 ) -> Result<String> {
595 let request_body = serde_json::json!({
596 "model": self.model,
597 "messages": messages.iter().map(|m| {
598 serde_json::json!({
599 "role": m.role,
600 "content": m.content
601 })
602 }).collect::<Vec<_>>(),
603 "temperature": temperature,
604 "max_tokens": max_tokens,
605 });
606
607 let mut request =
608 self.client.post(&self.endpoint).header("Content-Type", "application/json");
609
610 if let Some(api_key) = &self.api_key {
611 request = request.header("Authorization", format!("Bearer {}", api_key));
612 }
613
614 let response = request.json(&request_body).send().await.map_err(|e| {
615 mockforge_foundation::Error::internal(format!("API request failed: {}", e))
616 })?;
617
618 if !response.status().is_success() {
619 let error_text = response.text().await.unwrap_or_default();
620 return Err(mockforge_foundation::Error::internal(format!(
621 "API error: {}",
622 error_text
623 )));
624 }
625
626 let response_json: serde_json::Value = response.json().await.map_err(|e| {
627 mockforge_foundation::Error::internal(format!("Failed to parse API response: {}", e))
628 })?;
629
630 let content = response_json["choices"][0]["message"]["content"]
632 .as_str()
633 .or_else(|| response_json["message"]["content"].as_str())
634 .ok_or_else(|| mockforge_foundation::Error::internal("Invalid API response format"))?
635 .to_string();
636
637 Ok(content)
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    /// A client built from the default configuration exposes that
    /// configuration unchanged, and the test expects its provider to be
    /// "openai".
    #[test]
    fn test_llm_client_creation() {
        let client = LlmClient::new(BehaviorModelConfig::default());
        let provider_name = &client.config().llm_provider;
        assert_eq!(provider_name, "openai");
    }
}