Skip to main content

mockforge_intelligence/intelligent_behavior/
llm_client.rs

1//! LLM client wrapper for intelligent behavior
2//!
3//! This module provides a simplified interface to the RAG engine for
4//! intelligent mock behavior generation.
5
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::config::BehaviorModelConfig;
10use super::types::LlmGenerationRequest;
11use mockforge_foundation::Result;
12
13/// LLM client for generating intelligent responses
14pub struct LlmClient {
15    /// RAG engine (lazily initialized)
16    rag_engine: Arc<RwLock<Option<Box<dyn LlmProvider>>>>,
17    /// Configuration
18    config: BehaviorModelConfig,
19}
20
21impl LlmClient {
22    /// Create a new LLM client
23    pub fn new(config: BehaviorModelConfig) -> Self {
24        Self {
25            rag_engine: Arc::new(RwLock::new(None)),
26            config,
27        }
28    }
29
30    /// Initialize the RAG engine (lazy initialization)
31    async fn ensure_initialized(&self) -> Result<()> {
32        let mut engine = self.rag_engine.write().await;
33
34        if engine.is_none() {
35            // Create provider based on configuration
36            let provider = self.create_provider()?;
37            *engine = Some(provider);
38        }
39
40        Ok(())
41    }
42
43    /// Create LLM provider based on configuration
44    fn create_provider(&self) -> Result<Box<dyn LlmProvider>> {
45        match self.config.llm_provider.to_lowercase().as_str() {
46            "openai" => Ok(Box::new(OpenAIProvider::new(&self.config)?)),
47            "anthropic" => Ok(Box::new(AnthropicProvider::new(&self.config)?)),
48            "ollama" => Ok(Box::new(OllamaProvider::new(&self.config)?)),
49            "openai-compatible" => Ok(Box::new(OpenAICompatibleProvider::new(&self.config)?)),
50            _ => Err(mockforge_foundation::Error::internal(format!(
51                "Unsupported LLM provider: {}",
52                self.config.llm_provider
53            ))),
54        }
55    }
56
57    /// Generate a response from a prompt
58    pub async fn generate(&self, request: &LlmGenerationRequest) -> Result<serde_json::Value> {
59        self.ensure_initialized().await?;
60
61        let engine = self.rag_engine.read().await;
62        let provider = engine
63            .as_ref()
64            .ok_or_else(|| mockforge_foundation::Error::internal("LLM provider not initialized"))?;
65
66        // Build messages
67        let messages = vec![
68            ChatMessage {
69                role: "system".to_string(),
70                content: request.system_prompt.clone(),
71            },
72            ChatMessage {
73                role: "user".to_string(),
74                content: request.user_prompt.clone(),
75            },
76        ];
77
78        // Generate response
79        let response_text = provider
80            .generate_chat(messages, request.temperature, request.max_tokens)
81            .await?;
82
83        // Try to parse as JSON
84        match serde_json::from_str::<serde_json::Value>(&response_text) {
85            Ok(json) => Ok(json),
86            Err(_) => {
87                // Try to extract JSON from response
88                if let Some(start) = response_text.find('{') {
89                    if let Some(end) = response_text.rfind('}') {
90                        let json_str = &response_text[start..=end];
91                        if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
92                            return Ok(json);
93                        }
94                    }
95                }
96
97                // Fallback: wrap in object
98                Ok(serde_json::json!({
99                    "response": response_text,
100                    "note": "Response was not valid JSON, wrapped in object"
101                }))
102            }
103        }
104    }
105
106    /// Generate a response and return usage information
107    pub async fn generate_with_usage(
108        &self,
109        request: &LlmGenerationRequest,
110    ) -> Result<(serde_json::Value, LlmUsage)> {
111        self.ensure_initialized().await?;
112
113        let engine = self.rag_engine.read().await;
114        let provider = engine
115            .as_ref()
116            .ok_or_else(|| mockforge_foundation::Error::internal("LLM provider not initialized"))?;
117
118        // Build messages
119        let messages = vec![
120            ChatMessage {
121                role: "system".to_string(),
122                content: request.system_prompt.clone(),
123            },
124            ChatMessage {
125                role: "user".to_string(),
126                content: request.user_prompt.clone(),
127            },
128        ];
129
130        // Generate response with usage tracking
131        let (response_text, usage) = provider
132            .generate_chat_with_usage(messages, request.temperature, request.max_tokens)
133            .await?;
134
135        // Try to parse as JSON
136        let json_value = match serde_json::from_str::<serde_json::Value>(&response_text) {
137            Ok(json) => json,
138            Err(_) => {
139                // Try to extract JSON from response
140                if let Some(start) = response_text.find('{') {
141                    if let Some(end) = response_text.rfind('}') {
142                        let json_str = &response_text[start..=end];
143                        if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
144                            json
145                        } else {
146                            serde_json::json!({
147                                "response": response_text,
148                                "note": "Response was not valid JSON, wrapped in object"
149                            })
150                        }
151                    } else {
152                        serde_json::json!({
153                            "response": response_text,
154                            "note": "Response was not valid JSON, wrapped in object"
155                        })
156                    }
157                } else {
158                    serde_json::json!({
159                        "response": response_text,
160                        "note": "Response was not valid JSON, wrapped in object"
161                    })
162                }
163            }
164        };
165
166        Ok((json_value, usage))
167    }
168
169    /// Get configuration
170    pub fn config(&self) -> &BehaviorModelConfig {
171        &self.config
172    }
173}
174
/// One role-tagged message in a chat-style LLM request.
#[derive(Clone, Debug)]
struct ChatMessage {
    /// Role label sent to the provider (this module uses `"system"` and `"user"`).
    role: String,
    /// Text of the message.
    content: String,
}
181
/// Token usage reported (or estimated) for a single LLM call.
#[derive(Debug, Clone, Default)]
pub struct LlmUsage {
    /// Prompt tokens used
    pub prompt_tokens: u64,
    /// Completion tokens used
    pub completion_tokens: u64,
    /// Total tokens used (prompt + completion, saturating at `u64::MAX`)
    pub total_tokens: u64,
}

impl LlmUsage {
    /// Create usage info from prompt and completion token counts.
    ///
    /// `total_tokens` is the sum of both counts. The addition saturates so
    /// that extreme inputs cannot panic (debug builds) or wrap around
    /// (release builds).
    #[must_use]
    pub fn new(prompt_tokens: u64, completion_tokens: u64) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
203
204/// LLM provider trait
205#[async_trait::async_trait]
206trait LlmProvider: Send + Sync {
207    /// Generate chat completion
208    async fn generate_chat(
209        &self,
210        messages: Vec<ChatMessage>,
211        temperature: f64,
212        max_tokens: usize,
213    ) -> Result<String>;
214
215    /// Generate chat completion with usage tracking
216    async fn generate_chat_with_usage(
217        &self,
218        messages: Vec<ChatMessage>,
219        temperature: f64,
220        max_tokens: usize,
221    ) -> Result<(String, LlmUsage)> {
222        // Default implementation: call generate_chat and estimate tokens
223        let response = self.generate_chat(messages, temperature, max_tokens).await?;
224        // Rough estimation: ~4 characters per token
225        let estimated_tokens = (response.len() as f64 / 4.0) as u64;
226        Ok((response, LlmUsage::new(estimated_tokens, estimated_tokens)))
227    }
228}
229
230/// OpenAI provider implementation
231struct OpenAIProvider {
232    client: reqwest::Client,
233    api_key: String,
234    model: String,
235    endpoint: String,
236}
237
238impl OpenAIProvider {
239    fn new(config: &BehaviorModelConfig) -> Result<Self> {
240        let api_key = config
241            .api_key
242            .clone()
243            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
244            .ok_or_else(|| mockforge_foundation::Error::internal("OpenAI API key not found"))?;
245
246        let endpoint = config
247            .api_endpoint
248            .clone()
249            .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
250
251        Ok(Self {
252            client: reqwest::Client::new(),
253            api_key,
254            model: config.model.clone(),
255            endpoint,
256        })
257    }
258}
259
260#[async_trait::async_trait]
261impl LlmProvider for OpenAIProvider {
262    async fn generate_chat(
263        &self,
264        messages: Vec<ChatMessage>,
265        temperature: f64,
266        max_tokens: usize,
267    ) -> Result<String> {
268        let request_body = serde_json::json!({
269            "model": self.model,
270            "messages": messages.iter().map(|m| {
271                serde_json::json!({
272                    "role": m.role,
273                    "content": m.content
274                })
275            }).collect::<Vec<_>>(),
276            "temperature": temperature,
277            "max_tokens": max_tokens,
278        });
279
280        let response = self
281            .client
282            .post(&self.endpoint)
283            .header("Authorization", format!("Bearer {}", self.api_key))
284            .header("Content-Type", "application/json")
285            .json(&request_body)
286            .send()
287            .await
288            .map_err(|e| {
289                mockforge_foundation::Error::internal(format!("OpenAI API request failed: {}", e))
290            })?;
291
292        if !response.status().is_success() {
293            let error_text = response.text().await.unwrap_or_default();
294            return Err(mockforge_foundation::Error::internal(format!(
295                "OpenAI API error: {}",
296                error_text
297            )));
298        }
299
300        let response_json: serde_json::Value = response.json().await.map_err(|e| {
301            mockforge_foundation::Error::internal(format!("Failed to parse OpenAI response: {}", e))
302        })?;
303
304        // Extract content from response
305        let content = response_json["choices"][0]["message"]["content"]
306            .as_str()
307            .ok_or_else(|| mockforge_foundation::Error::internal("Invalid OpenAI response format"))?
308            .to_string();
309
310        Ok(content)
311    }
312
313    async fn generate_chat_with_usage(
314        &self,
315        messages: Vec<ChatMessage>,
316        temperature: f64,
317        max_tokens: usize,
318    ) -> Result<(String, LlmUsage)> {
319        let request_body = serde_json::json!({
320            "model": self.model,
321            "messages": messages.iter().map(|m| {
322                serde_json::json!({
323                    "role": m.role,
324                    "content": m.content
325                })
326            }).collect::<Vec<_>>(),
327            "temperature": temperature,
328            "max_tokens": max_tokens,
329        });
330
331        let response = self
332            .client
333            .post(&self.endpoint)
334            .header("Authorization", format!("Bearer {}", self.api_key))
335            .header("Content-Type", "application/json")
336            .json(&request_body)
337            .send()
338            .await
339            .map_err(|e| {
340                mockforge_foundation::Error::internal(format!("OpenAI API request failed: {}", e))
341            })?;
342
343        if !response.status().is_success() {
344            let error_text = response.text().await.unwrap_or_default();
345            return Err(mockforge_foundation::Error::internal(format!(
346                "OpenAI API error: {}",
347                error_text
348            )));
349        }
350
351        let response_json: serde_json::Value = response.json().await.map_err(|e| {
352            mockforge_foundation::Error::internal(format!("Failed to parse OpenAI response: {}", e))
353        })?;
354
355        // Extract content from response
356        let content = response_json["choices"][0]["message"]["content"]
357            .as_str()
358            .ok_or_else(|| mockforge_foundation::Error::internal("Invalid OpenAI response format"))?
359            .to_string();
360
361        // Extract usage information
362        let usage = if let Some(usage_obj) = response_json.get("usage") {
363            LlmUsage::new(
364                usage_obj["prompt_tokens"].as_u64().unwrap_or(0),
365                usage_obj["completion_tokens"].as_u64().unwrap_or(0),
366            )
367        } else {
368            // Fallback: estimate tokens
369            let estimated = (content.len() as f64 / 4.0) as u64;
370            LlmUsage::new(estimated, estimated)
371        };
372
373        Ok((content, usage))
374    }
375}
376
377/// Ollama provider implementation
378struct OllamaProvider {
379    client: reqwest::Client,
380    model: String,
381    endpoint: String,
382}
383
384impl OllamaProvider {
385    fn new(config: &BehaviorModelConfig) -> Result<Self> {
386        let endpoint = config
387            .api_endpoint
388            .clone()
389            .unwrap_or_else(|| "http://localhost:11434/api/chat".to_string());
390
391        Ok(Self {
392            client: reqwest::Client::new(),
393            model: config.model.clone(),
394            endpoint,
395        })
396    }
397}
398
399#[async_trait::async_trait]
400impl LlmProvider for OllamaProvider {
401    async fn generate_chat(
402        &self,
403        messages: Vec<ChatMessage>,
404        temperature: f64,
405        max_tokens: usize,
406    ) -> Result<String> {
407        let request_body = serde_json::json!({
408            "model": self.model,
409            "messages": messages.iter().map(|m| {
410                serde_json::json!({
411                    "role": m.role,
412                    "content": m.content
413                })
414            }).collect::<Vec<_>>(),
415            "options": {
416                "temperature": temperature,
417                "num_predict": max_tokens,
418            },
419            "stream": false,
420        });
421
422        let response = self
423            .client
424            .post(&self.endpoint)
425            .header("Content-Type", "application/json")
426            .json(&request_body)
427            .send()
428            .await
429            .map_err(|e| {
430                mockforge_foundation::Error::internal(format!("Ollama API request failed: {}", e))
431            })?;
432
433        if !response.status().is_success() {
434            let error_text = response.text().await.unwrap_or_default();
435            return Err(mockforge_foundation::Error::internal(format!(
436                "Ollama API error: {}",
437                error_text
438            )));
439        }
440
441        let response_json: serde_json::Value = response.json().await.map_err(|e| {
442            mockforge_foundation::Error::internal(format!("Failed to parse Ollama response: {}", e))
443        })?;
444
445        // Extract content from response
446        let content = response_json["message"]["content"]
447            .as_str()
448            .ok_or_else(|| mockforge_foundation::Error::internal("Invalid Ollama response format"))?
449            .to_string();
450
451        Ok(content)
452    }
453}
454
455/// Anthropic provider implementation
456struct AnthropicProvider {
457    client: reqwest::Client,
458    api_key: String,
459    model: String,
460    endpoint: String,
461}
462
463impl AnthropicProvider {
464    fn new(config: &BehaviorModelConfig) -> Result<Self> {
465        let api_key = config
466            .api_key
467            .clone()
468            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
469            .ok_or_else(|| mockforge_foundation::Error::internal("Anthropic API key not found"))?;
470
471        let endpoint = config
472            .api_endpoint
473            .clone()
474            .unwrap_or_else(|| "https://api.anthropic.com/v1/messages".to_string());
475
476        Ok(Self {
477            client: reqwest::Client::new(),
478            api_key,
479            model: config.model.clone(),
480            endpoint,
481        })
482    }
483}
484
485#[async_trait::async_trait]
486impl LlmProvider for AnthropicProvider {
487    async fn generate_chat(
488        &self,
489        messages: Vec<ChatMessage>,
490        temperature: f64,
491        max_tokens: usize,
492    ) -> Result<String> {
493        // Separate system message from other messages
494        let system_message =
495            messages.iter().find(|m| m.role == "system").map(|m| m.content.clone());
496
497        let chat_messages: Vec<_> = messages
498            .iter()
499            .filter(|m| m.role != "system")
500            .map(|m| {
501                serde_json::json!({
502                    "role": m.role,
503                    "content": m.content
504                })
505            })
506            .collect();
507
508        let mut request_body = serde_json::json!({
509            "model": self.model,
510            "messages": chat_messages,
511            "temperature": temperature,
512            "max_tokens": max_tokens,
513        });
514
515        if let Some(system) = system_message {
516            request_body["system"] = serde_json::Value::String(system);
517        }
518
519        let response = self
520            .client
521            .post(&self.endpoint)
522            .header("x-api-key", &self.api_key)
523            .header("anthropic-version", "2023-06-01")
524            .header("Content-Type", "application/json")
525            .json(&request_body)
526            .send()
527            .await
528            .map_err(|e| {
529                mockforge_foundation::Error::internal(format!(
530                    "Anthropic API request failed: {}",
531                    e
532                ))
533            })?;
534
535        if !response.status().is_success() {
536            let error_text = response.text().await.unwrap_or_default();
537            return Err(mockforge_foundation::Error::internal(format!(
538                "Anthropic API error: {}",
539                error_text
540            )));
541        }
542
543        let response_json: serde_json::Value = response.json().await.map_err(|e| {
544            mockforge_foundation::Error::internal(format!(
545                "Failed to parse Anthropic response: {}",
546                e
547            ))
548        })?;
549
550        // Extract content from response
551        let content = response_json["content"][0]["text"]
552            .as_str()
553            .ok_or_else(|| {
554                mockforge_foundation::Error::internal("Invalid Anthropic response format")
555            })?
556            .to_string();
557
558        Ok(content)
559    }
560}
561
562/// OpenAI-compatible provider (generic)
563struct OpenAICompatibleProvider {
564    client: reqwest::Client,
565    api_key: Option<String>,
566    model: String,
567    endpoint: String,
568}
569
570impl OpenAICompatibleProvider {
571    fn new(config: &BehaviorModelConfig) -> Result<Self> {
572        let endpoint = config.api_endpoint.clone().ok_or_else(|| {
573            mockforge_foundation::Error::internal(
574                "API endpoint required for OpenAI-compatible provider",
575            )
576        })?;
577
578        Ok(Self {
579            client: reqwest::Client::new(),
580            api_key: config.api_key.clone(),
581            model: config.model.clone(),
582            endpoint,
583        })
584    }
585}
586
587#[async_trait::async_trait]
588impl LlmProvider for OpenAICompatibleProvider {
589    async fn generate_chat(
590        &self,
591        messages: Vec<ChatMessage>,
592        temperature: f64,
593        max_tokens: usize,
594    ) -> Result<String> {
595        let request_body = serde_json::json!({
596            "model": self.model,
597            "messages": messages.iter().map(|m| {
598                serde_json::json!({
599                    "role": m.role,
600                    "content": m.content
601                })
602            }).collect::<Vec<_>>(),
603            "temperature": temperature,
604            "max_tokens": max_tokens,
605        });
606
607        let mut request =
608            self.client.post(&self.endpoint).header("Content-Type", "application/json");
609
610        if let Some(api_key) = &self.api_key {
611            request = request.header("Authorization", format!("Bearer {}", api_key));
612        }
613
614        let response = request.json(&request_body).send().await.map_err(|e| {
615            mockforge_foundation::Error::internal(format!("API request failed: {}", e))
616        })?;
617
618        if !response.status().is_success() {
619            let error_text = response.text().await.unwrap_or_default();
620            return Err(mockforge_foundation::Error::internal(format!(
621                "API error: {}",
622                error_text
623            )));
624        }
625
626        let response_json: serde_json::Value = response.json().await.map_err(|e| {
627            mockforge_foundation::Error::internal(format!("Failed to parse API response: {}", e))
628        })?;
629
630        // Extract content (try both OpenAI and Ollama formats)
631        let content = response_json["choices"][0]["message"]["content"]
632            .as_str()
633            .or_else(|| response_json["message"]["content"].as_str())
634            .ok_or_else(|| mockforge_foundation::Error::internal("Invalid API response format"))?
635            .to_string();
636
637        Ok(content)
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    /// A client built from the default configuration exposes that
    /// configuration through `config()`, with `openai` as the provider name.
    #[test]
    fn test_llm_client_creation() {
        let default_config = BehaviorModelConfig::default();
        let client = LlmClient::new(default_config);

        let exposed = client.config();
        assert_eq!(exposed.llm_provider, "openai");
    }
}