mockforge_intelligence/ai_contract_diff/
recommendation_engine.rs1use super::types::{ContractDiffConfig, Mismatch, Recommendation};
7use crate::intelligent_behavior::config::BehaviorModelConfig;
8use crate::intelligent_behavior::llm_client::LlmClient;
9use crate::intelligent_behavior::types::LlmGenerationRequest;
10use mockforge_foundation::Result;
11use std::collections::HashMap;
12
13pub struct RecommendationEngine {
15 llm_client: Option<LlmClient>,
17
18 config: ContractDiffConfig,
20}
21
22impl RecommendationEngine {
23 pub fn new(config: ContractDiffConfig) -> Result<Self> {
25 let llm_client = if config.use_ai_recommendations {
26 let llm_config = BehaviorModelConfig {
28 llm_provider: config.llm_provider.clone(),
29 model: config.llm_model.clone(),
30 api_key: config.api_key.clone(),
31 api_endpoint: None,
32 temperature: 0.7, max_tokens: 2000,
34 rules: crate::intelligent_behavior::BehaviorRules::default(), };
36
37 Some(LlmClient::new(llm_config))
38 } else {
39 None
40 };
41
42 Ok(Self { llm_client, config })
43 }
44
45 pub async fn generate_recommendations(
47 &self,
48 mismatches: &[Mismatch],
49 request_context: &RequestContext,
50 ) -> Result<Vec<Recommendation>> {
51 if !self.config.use_ai_recommendations || self.llm_client.is_none() {
52 return Ok(self.generate_basic_recommendations(mismatches));
54 }
55
56 let mut recommendations = Vec::new();
57
58 let mut grouped: HashMap<String, Vec<&Mismatch>> = HashMap::new();
60 for mismatch in mismatches {
61 let key = format!("{:?}", mismatch.mismatch_type);
62 grouped.entry(key).or_default().push(mismatch);
63 }
64
65 for (_group_key, group_mismatches) in grouped {
67 if group_mismatches.len() > self.config.max_recommendations {
68 let limited = group_mismatches
70 .iter()
71 .take(self.config.max_recommendations)
72 .copied()
73 .collect::<Vec<_>>();
74 let group_recs =
75 self.generate_ai_recommendations_for_group(&limited, request_context).await?;
76 recommendations.extend(group_recs);
77 } else {
78 let group_recs = self
79 .generate_ai_recommendations_for_group(&group_mismatches, request_context)
80 .await?;
81 recommendations.extend(group_recs);
82 }
83 }
84
85 Ok(recommendations)
86 }
87
88 async fn generate_ai_recommendations_for_group(
90 &self,
91 mismatches: &[&Mismatch],
92 context: &RequestContext,
93 ) -> Result<Vec<Recommendation>> {
94 let llm_client = self
95 .llm_client
96 .as_ref()
97 .ok_or_else(|| mockforge_foundation::Error::internal("LLM client not initialized"))?;
98
99 let prompt = self.build_recommendation_prompt(mismatches, context);
101
102 let request = LlmGenerationRequest::new(self.get_system_prompt(), prompt)
104 .with_temperature(0.7)
105 .with_max_tokens(2000);
106
107 let response = llm_client.generate(&request).await?;
108
109 self.parse_llm_recommendations(response, mismatches)
111 }
112
113 fn build_recommendation_prompt(
115 &self,
116 mismatches: &[&Mismatch],
117 context: &RequestContext,
118 ) -> String {
119 let mut prompt = String::from(
120 "You are analyzing API contract mismatches between front-end requests and backend specifications.\n\n",
121 );
122
123 prompt.push_str("## Request Context\n");
124 prompt.push_str(&format!("Endpoint: {} {}\n", context.method, context.path));
125 if let Some(body) = &context.request_body {
126 prompt.push_str(&format!(
127 "Request Body: {}\n",
128 serde_json::to_string(body).unwrap_or_default()
129 ));
130 }
131 prompt.push_str(&format!("Contract Format: {}\n\n", context.contract_format));
132
133 prompt.push_str("## Detected Mismatches\n\n");
134 for (idx, mismatch) in mismatches.iter().enumerate() {
135 prompt.push_str(&format!("### Mismatch {}: {:?}\n", idx + 1, mismatch.mismatch_type));
136 prompt.push_str(&format!("Path: {}\n", mismatch.path));
137 prompt.push_str(&format!("Description: {}\n", mismatch.description));
138 if let Some(expected) = &mismatch.expected {
139 prompt.push_str(&format!("Expected: {}\n", expected));
140 }
141 if let Some(actual) = &mismatch.actual {
142 prompt.push_str(&format!("Actual: {}\n", actual));
143 }
144 prompt.push_str(&format!("Severity: {:?}\n\n", mismatch.severity));
145 }
146
147 prompt.push_str("## Task\n\n");
148 prompt.push_str("For each mismatch, provide:\n");
149 prompt.push_str("1. A clear, actionable recommendation for fixing the issue\n");
150 prompt.push_str("2. A suggested fix (code or configuration change)\n");
151 prompt.push_str("3. Reasoning explaining why this fix is appropriate\n");
152 if self.config.include_examples {
153 prompt.push_str("4. An example showing the fix applied\n");
154 }
155 prompt.push_str(
156 "\nReturn your response as a JSON array of recommendation objects with the following structure:\n",
157 );
158 prompt.push_str(
159 r#"[
160 {
161 "mismatch_index": 0,
162 "recommendation": "Clear recommendation text",
163 "suggested_fix": "Specific fix or action",
164 "reasoning": "Why this fix is appropriate",
165 "example": { "before": "...", "after": "..." }
166 }
167]"#,
168 );
169
170 prompt
171 }
172
173 fn get_system_prompt(&self) -> String {
175 String::from(
176 "You are an expert API contract analyst. Your role is to analyze mismatches between \
177 front-end API requests and backend contract specifications, and provide clear, \
178 actionable recommendations for fixing these issues. Your recommendations should be \
179 practical, well-reasoned, and include specific examples when helpful. Always consider \
180 the context of the API and the severity of the mismatch when making recommendations.",
181 )
182 }
183
184 fn parse_llm_recommendations(
186 &self,
187 response: serde_json::Value,
188 mismatches: &[&Mismatch],
189 ) -> Result<Vec<Recommendation>> {
190 let mut recommendations = Vec::new();
191
192 let recommendations_array = if response.is_array() {
194 Some(response.as_array().unwrap())
195 } else if let Some(arr) = response.get("recommendations") {
196 arr.as_array()
197 } else if let Some(arr) = response.get("data") {
198 arr.as_array()
199 } else {
200 None
201 };
202
203 if let Some(recs) = recommendations_array {
204 for (idx, rec_json) in recs.iter().enumerate() {
205 let mismatch_index =
206 rec_json.get("mismatch_index").and_then(|v| v.as_u64()).unwrap_or(idx as u64)
207 as usize;
208
209 if mismatch_index < mismatches.len() {
210 let mismatch = mismatches[mismatch_index];
211 let recommendation = Recommendation {
212 id: format!("rec_{}_{}", mismatch.path, idx),
213 mismatch_id: format!("mismatch_{}", mismatch_index),
214 recommendation: rec_json
215 .get("recommendation")
216 .and_then(|v| v.as_str())
217 .unwrap_or("No recommendation provided")
218 .to_string(),
219 suggested_fix: rec_json
220 .get("suggested_fix")
221 .and_then(|v| v.as_str())
222 .map(|s| s.to_string()),
223 confidence: mismatch.confidence, reasoning: rec_json
225 .get("reasoning")
226 .and_then(|v| v.as_str())
227 .map(|s| s.to_string()),
228 example: rec_json.get("example").cloned(),
229 };
230
231 recommendations.push(recommendation);
232 }
233 }
234 } else {
235 if let Some(text) = response.as_str() {
237 if let Some(start) = text.find('[') {
239 if let Some(end) = text.rfind(']') {
240 let json_str = &text[start..=end];
241 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
242 return self.parse_llm_recommendations(parsed, mismatches);
243 }
244 }
245 }
246 }
247
248 return Ok(self.generate_basic_recommendations(
250 &mismatches.iter().map(|m| (*m).clone()).collect::<Vec<_>>(),
251 ));
252 }
253
254 Ok(recommendations)
255 }
256
257 fn generate_basic_recommendations(&self, mismatches: &[Mismatch]) -> Vec<Recommendation> {
259 mismatches
260 .iter()
261 .enumerate()
262 .map(|(idx, mismatch)| {
263 let (recommendation, suggested_fix) = match mismatch.mismatch_type {
264 super::types::MismatchType::MissingRequiredField => (
265 format!("Add the required field '{}' to the request", mismatch.path),
266 format!("Add field: {}", mismatch.path),
267 ),
268 super::types::MismatchType::TypeMismatch => (
269 format!(
270 "Change the type of '{}' from {} to {}",
271 mismatch.path,
272 mismatch.actual.as_ref().unwrap_or(&"unknown".to_string()),
273 mismatch.expected.as_ref().unwrap_or(&"unknown".to_string())
274 ),
275 format!(
276 "Update field type: {} -> {}",
277 mismatch.path,
278 mismatch.expected.as_ref().unwrap_or(&"unknown".to_string())
279 ),
280 ),
281 super::types::MismatchType::UnexpectedField => (
282 format!("Remove the unexpected field '{}' from the request", mismatch.path),
283 format!("Remove field: {}", mismatch.path),
284 ),
285 _ => (mismatch.description.clone(), "Review and fix the mismatch".to_string()),
286 };
287
288 Recommendation {
289 id: format!("rec_{}_{}", mismatch.path, idx),
290 mismatch_id: format!("mismatch_{}", idx),
291 recommendation,
292 suggested_fix: Some(suggested_fix),
293 confidence: mismatch.confidence,
294 reasoning: Some(format!(
295 "Based on mismatch type: {:?}",
296 mismatch.mismatch_type
297 )),
298 example: None,
299 }
300 })
301 .collect()
302 }
303}
304
305#[derive(Debug, Clone)]
307pub struct RequestContext {
308 pub method: String,
310
311 pub path: String,
313
314 pub request_body: Option<serde_json::Value>,
316
317 pub contract_format: String,
319
320 pub additional_context: HashMap<String, serde_json::Value>,
322}
323
324impl RequestContext {
325 pub fn new(method: impl Into<String>, path: impl Into<String>) -> Self {
327 Self {
328 method: method.into(),
329 path: path.into(),
330 request_body: None,
331 contract_format: "openapi-3.0".to_string(),
332 additional_context: HashMap::new(),
333 }
334 }
335
336 pub fn with_body(mut self, body: serde_json::Value) -> Self {
338 self.request_body = Some(body);
339 self
340 }
341
342 pub fn with_contract_format(mut self, format: impl Into<String>) -> Self {
344 self.contract_format = format.into();
345 self
346 }
347}