Skip to main content

mockforge_intelligence/ai_contract_diff/
recommendation_engine.rs

1//! AI-powered recommendation engine for contract diff analysis
2//!
//! This module uses an LLM to generate contextual recommendations for fixing contract
//! mismatches, going beyond structural diffs to provide intelligent suggestions. When AI
//! recommendations are disabled, simple rule-based recommendations are produced instead.
5
6use super::types::{ContractDiffConfig, Mismatch, Recommendation};
7use crate::intelligent_behavior::config::BehaviorModelConfig;
8use crate::intelligent_behavior::llm_client::LlmClient;
9use crate::intelligent_behavior::types::LlmGenerationRequest;
10use mockforge_foundation::Result;
11use std::collections::HashMap;
12
13/// AI-powered recommendation engine
14pub struct RecommendationEngine {
15    /// LLM client for generating recommendations
16    llm_client: Option<LlmClient>,
17
18    /// Configuration
19    config: ContractDiffConfig,
20}
21
22impl RecommendationEngine {
23    /// Create a new recommendation engine
24    pub fn new(config: ContractDiffConfig) -> Result<Self> {
25        let llm_client = if config.use_ai_recommendations {
26            // Create LLM client configuration
27            let llm_config = BehaviorModelConfig {
28                llm_provider: config.llm_provider.clone(),
29                model: config.llm_model.clone(),
30                api_key: config.api_key.clone(),
31                api_endpoint: None,
32                temperature: 0.7, // Lower temperature for more focused recommendations
33                max_tokens: 2000,
34                rules: crate::intelligent_behavior::BehaviorRules::default(), // No specific rules for contract diff recommendations
35            };
36
37            Some(LlmClient::new(llm_config))
38        } else {
39            None
40        };
41
42        Ok(Self { llm_client, config })
43    }
44
45    /// Generate recommendations for mismatches
46    pub async fn generate_recommendations(
47        &self,
48        mismatches: &[Mismatch],
49        request_context: &RequestContext,
50    ) -> Result<Vec<Recommendation>> {
51        if !self.config.use_ai_recommendations || self.llm_client.is_none() {
52            // Return basic recommendations without AI
53            return Ok(self.generate_basic_recommendations(mismatches));
54        }
55
56        let mut recommendations = Vec::new();
57
58        // Group mismatches by type for batch processing
59        let mut grouped: HashMap<String, Vec<&Mismatch>> = HashMap::new();
60        for mismatch in mismatches {
61            let key = format!("{:?}", mismatch.mismatch_type);
62            grouped.entry(key).or_default().push(mismatch);
63        }
64
65        // Generate recommendations for each group
66        for (_group_key, group_mismatches) in grouped {
67            if group_mismatches.len() > self.config.max_recommendations {
68                // Limit to max_recommendations
69                let limited = group_mismatches
70                    .iter()
71                    .take(self.config.max_recommendations)
72                    .copied()
73                    .collect::<Vec<_>>();
74                let group_recs =
75                    self.generate_ai_recommendations_for_group(&limited, request_context).await?;
76                recommendations.extend(group_recs);
77            } else {
78                let group_recs = self
79                    .generate_ai_recommendations_for_group(&group_mismatches, request_context)
80                    .await?;
81                recommendations.extend(group_recs);
82            }
83        }
84
85        Ok(recommendations)
86    }
87
88    /// Generate AI-powered recommendations for a group of mismatches
89    async fn generate_ai_recommendations_for_group(
90        &self,
91        mismatches: &[&Mismatch],
92        context: &RequestContext,
93    ) -> Result<Vec<Recommendation>> {
94        let llm_client = self
95            .llm_client
96            .as_ref()
97            .ok_or_else(|| mockforge_foundation::Error::internal("LLM client not initialized"))?;
98
99        // Build prompt for LLM
100        let prompt = self.build_recommendation_prompt(mismatches, context);
101
102        // Generate recommendation using LLM
103        let request = LlmGenerationRequest::new(self.get_system_prompt(), prompt)
104            .with_temperature(0.7)
105            .with_max_tokens(2000);
106
107        let response = llm_client.generate(&request).await?;
108
109        // Parse LLM response into recommendations
110        self.parse_llm_recommendations(response, mismatches)
111    }
112
113    /// Build prompt for LLM recommendation generation
114    fn build_recommendation_prompt(
115        &self,
116        mismatches: &[&Mismatch],
117        context: &RequestContext,
118    ) -> String {
119        let mut prompt = String::from(
120            "You are analyzing API contract mismatches between front-end requests and backend specifications.\n\n",
121        );
122
123        prompt.push_str("## Request Context\n");
124        prompt.push_str(&format!("Endpoint: {} {}\n", context.method, context.path));
125        if let Some(body) = &context.request_body {
126            prompt.push_str(&format!(
127                "Request Body: {}\n",
128                serde_json::to_string(body).unwrap_or_default()
129            ));
130        }
131        prompt.push_str(&format!("Contract Format: {}\n\n", context.contract_format));
132
133        prompt.push_str("## Detected Mismatches\n\n");
134        for (idx, mismatch) in mismatches.iter().enumerate() {
135            prompt.push_str(&format!("### Mismatch {}: {:?}\n", idx + 1, mismatch.mismatch_type));
136            prompt.push_str(&format!("Path: {}\n", mismatch.path));
137            prompt.push_str(&format!("Description: {}\n", mismatch.description));
138            if let Some(expected) = &mismatch.expected {
139                prompt.push_str(&format!("Expected: {}\n", expected));
140            }
141            if let Some(actual) = &mismatch.actual {
142                prompt.push_str(&format!("Actual: {}\n", actual));
143            }
144            prompt.push_str(&format!("Severity: {:?}\n\n", mismatch.severity));
145        }
146
147        prompt.push_str("## Task\n\n");
148        prompt.push_str("For each mismatch, provide:\n");
149        prompt.push_str("1. A clear, actionable recommendation for fixing the issue\n");
150        prompt.push_str("2. A suggested fix (code or configuration change)\n");
151        prompt.push_str("3. Reasoning explaining why this fix is appropriate\n");
152        if self.config.include_examples {
153            prompt.push_str("4. An example showing the fix applied\n");
154        }
155        prompt.push_str(
156            "\nReturn your response as a JSON array of recommendation objects with the following structure:\n",
157        );
158        prompt.push_str(
159            r#"[
160  {
161    "mismatch_index": 0,
162    "recommendation": "Clear recommendation text",
163    "suggested_fix": "Specific fix or action",
164    "reasoning": "Why this fix is appropriate",
165    "example": { "before": "...", "after": "..." }
166  }
167]"#,
168        );
169
170        prompt
171    }
172
173    /// Get system prompt for LLM
174    fn get_system_prompt(&self) -> String {
175        String::from(
176            "You are an expert API contract analyst. Your role is to analyze mismatches between \
177            front-end API requests and backend contract specifications, and provide clear, \
178            actionable recommendations for fixing these issues. Your recommendations should be \
179            practical, well-reasoned, and include specific examples when helpful. Always consider \
180            the context of the API and the severity of the mismatch when making recommendations.",
181        )
182    }
183
184    /// Parse LLM response into recommendation objects
185    fn parse_llm_recommendations(
186        &self,
187        response: serde_json::Value,
188        mismatches: &[&Mismatch],
189    ) -> Result<Vec<Recommendation>> {
190        let mut recommendations = Vec::new();
191
192        // Try to extract recommendations array from response
193        let recommendations_array = if response.is_array() {
194            Some(response.as_array().unwrap())
195        } else if let Some(arr) = response.get("recommendations") {
196            arr.as_array()
197        } else if let Some(arr) = response.get("data") {
198            arr.as_array()
199        } else {
200            None
201        };
202
203        if let Some(recs) = recommendations_array {
204            for (idx, rec_json) in recs.iter().enumerate() {
205                let mismatch_index =
206                    rec_json.get("mismatch_index").and_then(|v| v.as_u64()).unwrap_or(idx as u64)
207                        as usize;
208
209                if mismatch_index < mismatches.len() {
210                    let mismatch = mismatches[mismatch_index];
211                    let recommendation = Recommendation {
212                        id: format!("rec_{}_{}", mismatch.path, idx),
213                        mismatch_id: format!("mismatch_{}", mismatch_index),
214                        recommendation: rec_json
215                            .get("recommendation")
216                            .and_then(|v| v.as_str())
217                            .unwrap_or("No recommendation provided")
218                            .to_string(),
219                        suggested_fix: rec_json
220                            .get("suggested_fix")
221                            .and_then(|v| v.as_str())
222                            .map(|s| s.to_string()),
223                        confidence: mismatch.confidence, // Use mismatch confidence as base
224                        reasoning: rec_json
225                            .get("reasoning")
226                            .and_then(|v| v.as_str())
227                            .map(|s| s.to_string()),
228                        example: rec_json.get("example").cloned(),
229                    };
230
231                    recommendations.push(recommendation);
232                }
233            }
234        } else {
235            // Fallback: try to parse as text and extract JSON
236            if let Some(text) = response.as_str() {
237                // Try to find JSON in text
238                if let Some(start) = text.find('[') {
239                    if let Some(end) = text.rfind(']') {
240                        let json_str = &text[start..=end];
241                        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
242                            return self.parse_llm_recommendations(parsed, mismatches);
243                        }
244                    }
245                }
246            }
247
248            // If all else fails, generate basic recommendations
249            return Ok(self.generate_basic_recommendations(
250                &mismatches.iter().map(|m| (*m).clone()).collect::<Vec<_>>(),
251            ));
252        }
253
254        Ok(recommendations)
255    }
256
257    /// Generate basic recommendations without AI
258    fn generate_basic_recommendations(&self, mismatches: &[Mismatch]) -> Vec<Recommendation> {
259        mismatches
260            .iter()
261            .enumerate()
262            .map(|(idx, mismatch)| {
263                let (recommendation, suggested_fix) = match mismatch.mismatch_type {
264                    super::types::MismatchType::MissingRequiredField => (
265                        format!("Add the required field '{}' to the request", mismatch.path),
266                        format!("Add field: {}", mismatch.path),
267                    ),
268                    super::types::MismatchType::TypeMismatch => (
269                        format!(
270                            "Change the type of '{}' from {} to {}",
271                            mismatch.path,
272                            mismatch.actual.as_ref().unwrap_or(&"unknown".to_string()),
273                            mismatch.expected.as_ref().unwrap_or(&"unknown".to_string())
274                        ),
275                        format!(
276                            "Update field type: {} -> {}",
277                            mismatch.path,
278                            mismatch.expected.as_ref().unwrap_or(&"unknown".to_string())
279                        ),
280                    ),
281                    super::types::MismatchType::UnexpectedField => (
282                        format!("Remove the unexpected field '{}' from the request", mismatch.path),
283                        format!("Remove field: {}", mismatch.path),
284                    ),
285                    _ => (mismatch.description.clone(), "Review and fix the mismatch".to_string()),
286                };
287
288                Recommendation {
289                    id: format!("rec_{}_{}", mismatch.path, idx),
290                    mismatch_id: format!("mismatch_{}", idx),
291                    recommendation,
292                    suggested_fix: Some(suggested_fix),
293                    confidence: mismatch.confidence,
294                    reasoning: Some(format!(
295                        "Based on mismatch type: {:?}",
296                        mismatch.mismatch_type
297                    )),
298                    example: None,
299                }
300            })
301            .collect()
302    }
303}
304
305/// Context for recommendation generation
306#[derive(Debug, Clone)]
307pub struct RequestContext {
308    /// HTTP method
309    pub method: String,
310
311    /// Request path
312    pub path: String,
313
314    /// Request body
315    pub request_body: Option<serde_json::Value>,
316
317    /// Contract format
318    pub contract_format: String,
319
320    /// Additional context
321    pub additional_context: HashMap<String, serde_json::Value>,
322}
323
324impl RequestContext {
325    /// Create a new request context
326    pub fn new(method: impl Into<String>, path: impl Into<String>) -> Self {
327        Self {
328            method: method.into(),
329            path: path.into(),
330            request_body: None,
331            contract_format: "openapi-3.0".to_string(),
332            additional_context: HashMap::new(),
333        }
334    }
335
336    /// Add request body
337    pub fn with_body(mut self, body: serde_json::Value) -> Self {
338        self.request_body = Some(body);
339        self
340    }
341
342    /// Set contract format
343    pub fn with_contract_format(mut self, format: impl Into<String>) -> Self {
344        self.contract_format = format.into();
345        self
346    }
347}