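//! `ai/multi_step_analysis.rs` — tool definitions and local heuristics for the
//! analyze, score and generate steps of commit message creation.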

1use serde::{Deserialize, Serialize};
2use serde_json::json;
3use async_openai::types::{ChatCompletionTool, ChatCompletionToolType, FunctionObjectArgs};
4use anyhow::Result;
5
6/// File analysis result from the analyze function
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct FileAnalysisResult {
9  pub lines_added:   u32,
10  pub lines_removed: u32,
11  pub file_category: String,
12  pub summary:       String
13}
14
15/// File data with analysis results for scoring
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct FileDataForScoring {
18  pub file_path:      String,
19  pub operation_type: String,
20  pub lines_added:    u32,
21  pub lines_removed:  u32,
22  pub file_category:  String,
23  pub summary:        String
24}
25
26/// File data with calculated impact score
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct FileWithScore {
29  pub file_path:      String,
30  pub operation_type: String,
31  pub lines_added:    u32,
32  pub lines_removed:  u32,
33  pub file_category:  String,
34  pub summary:        String,
35  pub impact_score:   f32
36}
37
38/// Score calculation result
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ScoreResult {
41  pub files_with_scores: Vec<FileWithScore>
42}
43
44/// Commit message generation result
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct GenerateResult {
47  pub candidates: Vec<String>,
48  pub reasoning:  String
49}
50
51/// Creates the analyze function tool definition
52pub fn create_analyze_function_tool() -> Result<ChatCompletionTool> {
53  log::debug!("Creating analyze function tool");
54
55  let function = FunctionObjectArgs::default()
56    .name("analyze")
57    .description("Analyze a single file's changes from the git diff")
58    .parameters(json!({
59        "type": "object",
60        "properties": {
61            "file_path": {
62                "type": "string",
63                "description": "Relative path to the file"
64            },
65            "diff_content": {
66                "type": "string",
67                "description": "The git diff content for this specific file only"
68            },
69            "operation_type": {
70                "type": "string",
71                "enum": ["added", "modified", "deleted", "renamed", "binary"],
72                "description": "Type of operation performed on the file"
73            }
74        },
75        "required": ["file_path", "diff_content", "operation_type"]
76    }))
77    .build()?;
78
79  Ok(ChatCompletionTool { r#type: ChatCompletionToolType::Function, function })
80}
81
82/// Creates the score function tool definition
83pub fn create_score_function_tool() -> Result<ChatCompletionTool> {
84  log::debug!("Creating score function tool");
85
86  let function = FunctionObjectArgs::default()
87    .name("score")
88    .description("Calculate impact scores for all analyzed files")
89    .parameters(json!({
90        "type": "object",
91        "properties": {
92            "files_data": {
93                "type": "array",
94                "description": "Array of analyzed file data",
95                "items": {
96                    "type": "object",
97                    "properties": {
98                        "file_path": {
99                            "type": "string",
100                            "description": "Relative path to the file"
101                        },
102                        "operation_type": {
103                            "type": "string",
104                            "enum": ["added", "modified", "deleted", "renamed", "binary"],
105                            "description": "Type of operation performed on the file"
106                        },
107                        "lines_added": {
108                            "type": "integer",
109                            "description": "Number of lines added",
110                            "minimum": 0
111                        },
112                        "lines_removed": {
113                            "type": "integer",
114                            "description": "Number of lines removed",
115                            "minimum": 0
116                        },
117                        "file_category": {
118                            "type": "string",
119                            "enum": ["source", "test", "config", "docs", "binary", "build"],
120                            "description": "Category of the file"
121                        },
122                        "summary": {
123                            "type": "string",
124                            "description": "Brief description of changes"
125                        }
126                    },
127                    "required": ["file_path", "operation_type", "lines_added", "lines_removed", "file_category", "summary"]
128                }
129            }
130        },
131        "required": ["files_data"]
132    }))
133    .build()?;
134
135  Ok(ChatCompletionTool { r#type: ChatCompletionToolType::Function, function })
136}
137
138/// Creates the generate function tool definition
139pub fn create_generate_function_tool() -> Result<ChatCompletionTool> {
140  log::debug!("Creating generate function tool");
141
142  let function = FunctionObjectArgs::default()
143    .name("generate")
144    .description("Generate commit message candidates based on scored files")
145    .parameters(json!({
146        "type": "object",
147        "properties": {
148            "files_with_scores": {
149                "type": "array",
150                "description": "All files with calculated impact scores",
151                "items": {
152                    "type": "object",
153                    "properties": {
154                        "file_path": {
155                            "type": "string"
156                        },
157                        "operation_type": {
158                            "type": "string"
159                        },
160                        "lines_added": {
161                            "type": "integer"
162                        },
163                        "lines_removed": {
164                            "type": "integer"
165                        },
166                        "file_category": {
167                            "type": "string"
168                        },
169                        "summary": {
170                            "type": "string"
171                        },
172                        "impact_score": {
173                            "type": "number",
174                            "minimum": 0.0,
175                            "maximum": 1.0
176                        }
177                    },
178                    "required": ["file_path", "operation_type", "lines_added", "lines_removed", "file_category", "summary", "impact_score"]
179                }
180            },
181            "max_length": {
182                "type": "integer",
183                "description": "Maximum character length for commit message",
184                "default": 72
185            }
186        },
187        "required": ["files_with_scores"]
188    }))
189    .build()?;
190
191  Ok(ChatCompletionTool { r#type: ChatCompletionToolType::Function, function })
192}
193
194/// Analyzes a single file's changes
195pub fn analyze_file(file_path: &str, diff_content: &str, operation_type: &str) -> FileAnalysisResult {
196  log::debug!("Analyzing file: {file_path} ({operation_type})");
197
198  // Count lines added and removed
199  let mut lines_added = 0u32;
200  let mut lines_removed = 0u32;
201
202  for line in diff_content.lines() {
203    if line.starts_with('+') && !line.starts_with("+++") {
204      lines_added += 1;
205    } else if line.starts_with('-') && !line.starts_with("---") {
206      lines_removed += 1;
207    }
208  }
209
210  // Determine file category
211  let file_category = categorize_file(file_path);
212
213  // Generate summary based on diff content
214  let summary = generate_file_summary(file_path, diff_content, operation_type);
215
216  log::debug!("File analysis complete: +{lines_added} -{lines_removed} lines, category: {file_category}");
217
218  FileAnalysisResult { lines_added, lines_removed, file_category, summary }
219}
220
221/// Calculates impact scores for all files
222pub fn calculate_impact_scores(files_data: Vec<FileDataForScoring>) -> ScoreResult {
223  log::debug!("Calculating impact scores for {} files", files_data.len());
224
225  let mut files_with_scores = Vec::new();
226
227  for file_data in files_data {
228    let impact_score = calculate_single_impact_score(&file_data);
229
230    files_with_scores.push(FileWithScore {
231      file_path: file_data.file_path,
232      operation_type: file_data.operation_type,
233      lines_added: file_data.lines_added,
234      lines_removed: file_data.lines_removed,
235      file_category: file_data.file_category,
236      summary: file_data.summary,
237      impact_score
238    });
239  }
240
241  // Sort by impact score descending
242  files_with_scores.sort_by(|a, b| {
243    b.impact_score
244      .partial_cmp(&a.impact_score)
245      .unwrap_or(std::cmp::Ordering::Equal)
246  });
247
248  ScoreResult { files_with_scores }
249}
250
251/// Generates commit message candidates
252pub fn generate_commit_messages(files_with_scores: Vec<FileWithScore>, max_length: usize) -> GenerateResult {
253  log::debug!("Generating commit messages (max length: {max_length})");
254
255  // Find the highest impact changes
256  let primary_change = files_with_scores.first();
257  let mut candidates = Vec::new();
258
259  if let Some(primary) = primary_change {
260    // Generate different styles of commit messages
261
262    // Style 1: Action-focused
263    let action_msg = generate_action_message(primary, &files_with_scores, max_length);
264    candidates.push(action_msg);
265
266    // Style 2: Component-focused
267    let component_msg = generate_component_message(primary, &files_with_scores, max_length);
268    candidates.push(component_msg);
269
270    // Style 3: Impact-focused
271    let impact_msg = generate_impact_message(primary, &files_with_scores, max_length);
272    candidates.push(impact_msg);
273  }
274
275  let reasoning = generate_reasoning(&files_with_scores);
276
277  GenerateResult { candidates, reasoning }
278}
279
280// Helper functions
281
/// Maps a file path to one of the category labels `"test"`, `"build"`,
/// `"docs"`, `"config"`, `"binary"` or `"source"`.
///
/// Matching is case-insensitive and checked in that order; the first match
/// wins and `"source"` is the fallback.
///
/// * Build manifests (`package.json`, `Cargo.toml`, `go.mod`,
///   `requirements.txt`, `Gemfile`) are recognised by file name in any
///   directory, and are checked before the docs rule so that
///   `requirements.txt` is not swallowed by the `.txt` extension.
/// * `test`, `tests` and `docs` directories are recognised anywhere in the
///   path, including as the first component.
fn categorize_file(file_path: &str) -> String {
  /// Returns true when `path` ends with any of `suffixes`.
  fn has_suffix(path: &str, suffixes: &[&str]) -> bool {
    suffixes.iter().any(|suffix| path.ends_with(*suffix))
  }

  /// Returns true when any directory component equals `name`.
  fn in_dir(dirs: &[&str], name: &str) -> bool {
    dirs.iter().any(|dir| *dir == name)
  }

  let path = file_path.to_lowercase();

  // Split into directory components and the final file name.
  let mut dirs: Vec<&str> = path.split('/').collect();
  let file_name = dirs.pop().unwrap_or("");

  let category = if has_suffix(&path, &[".test.js", ".spec.js", "_test.go", "_test.rs"])
    || in_dir(&dirs, "test")
    || in_dir(&dirs, "tests")
  {
    "test"
  } else if matches!(file_name, "package.json" | "cargo.toml" | "go.mod" | "requirements.txt" | "gemfile")
    || path.ends_with(".lock")
  {
    "build"
  } else if has_suffix(&path, &[".md", ".txt", ".rst"]) || in_dir(&dirs, "docs") {
    "docs"
  } else if has_suffix(&path, &[".yml", ".yaml", ".json", ".toml", ".ini", ".conf"])
    || path.contains("config")
    || path.contains(".github/")
  {
    "config"
  } else if has_suffix(&path, &[".png", ".jpg", ".gif", ".ico", ".pdf", ".zip"]) {
    "binary"
  } else {
    "source"
  };

  category.to_string()
}
325
326fn generate_file_summary(file_path: &str, _diff_content: &str, operation_type: &str) -> String {
327  // This is a simplified version - in practice, you'd analyze the diff content
328  // more thoroughly to generate meaningful summaries
329  match operation_type {
330    "added" => format!("New {} file added", categorize_file(file_path)),
331    "deleted" => format!("Removed {} file", categorize_file(file_path)),
332    "renamed" => "File renamed".to_string(),
333    "binary" => "Binary file updated".to_string(),
334    _ => "File modified".to_string()
335  }
336}
337
338fn calculate_single_impact_score(file_data: &FileDataForScoring) -> f32 {
339  let mut score = 0.0f32;
340
341  // Base score from operation type
342  score += match file_data.operation_type.as_str() {
343    "added" => 0.3,
344    "modified" => 0.2,
345    "deleted" => 0.25,
346    "renamed" => 0.1,
347    "binary" => 0.05,
348    _ => 0.1
349  };
350
351  // Score from file category
352  score += match file_data.file_category.as_str() {
353    "source" => 0.4,
354    "test" => 0.2,
355    "config" => 0.25,
356    "build" => 0.3,
357    "docs" => 0.1,
358    "binary" => 0.05,
359    _ => 0.1
360  };
361
362  // Score from lines changed (normalized)
363  let total_lines = file_data.lines_added + file_data.lines_removed;
364  let line_score = (total_lines as f32 / 100.0).min(0.3);
365  score += line_score;
366
367  score.min(1.0) // Cap at 1.0
368}
369
370fn generate_action_message(primary: &FileWithScore, _all_files: &[FileWithScore], max_length: usize) -> String {
371  let base = match primary.operation_type.as_str() {
372    "added" => "Add",
373    "modified" => "Update",
374    "deleted" => "Remove",
375    "renamed" => "Rename",
376    _ => "Change"
377  };
378
379  let component = extract_component_name(&primary.file_path);
380  let message = format!("{base} {component}");
381
382  if message.len() > max_length {
383    message.chars().take(max_length).collect()
384  } else {
385    message
386  }
387}
388
389fn generate_component_message(primary: &FileWithScore, _all_files: &[FileWithScore], max_length: usize) -> String {
390  let component = extract_component_name(&primary.file_path);
391  let action = match primary.operation_type.as_str() {
392    "added" => "implementation",
393    "modified" => "updates",
394    "deleted" => "removal",
395    _ => "changes"
396  };
397
398  let message = format!("{component}: {action}");
399
400  if message.len() > max_length {
401    message.chars().take(max_length).collect()
402  } else {
403    message
404  }
405}
406
407fn generate_impact_message(primary: &FileWithScore, all_files: &[FileWithScore], max_length: usize) -> String {
408  let impact_type = if all_files
409    .iter()
410    .any(|f| f.file_category == "source" && f.operation_type == "added")
411  {
412    "feature"
413  } else if all_files.iter().any(|f| f.file_category == "test") {
414    "test"
415  } else if all_files.iter().any(|f| f.file_category == "config") {
416    "configuration"
417  } else {
418    "update"
419  };
420
421  let component = extract_component_name(&primary.file_path);
422  let message = format!(
423    "{} {} for {}",
424    if impact_type == "feature" {
425      "New"
426    } else {
427      "Update"
428    },
429    impact_type,
430    component
431  );
432
433  if message.len() > max_length {
434    message.chars().take(max_length).collect()
435  } else {
436    message
437  }
438}
439
/// Derives a short component name from a file path.
///
/// Uses the last non-empty path segment and keeps everything before its first
/// extension dot (`"web/app.test.js"` becomes `"app"`). A leading dot is
/// treated as part of the name rather than as an extension separator, so
/// dotfiles keep a usable name (`".gitignore"` stays `".gitignore"`,
/// `".eslintrc.json"` becomes `".eslintrc"`) instead of collapsing to an empty
/// string. Returns `"component"` when the path has no non-empty segment.
fn extract_component_name(file_path: &str) -> String {
  // Skip empty segments so a trailing slash falls back to the directory name.
  let filename = match file_path.rsplit('/').find(|part| !part.is_empty()) {
    Some(name) => name,
    None => return "component".to_string()
  };

  // Separate an optional hidden-file dot from the rest of the name.
  let (hidden_prefix, rest) = match filename.strip_prefix('.') {
    Some(rest) => (".", rest),
    None => ("", filename)
  };

  let stem = rest.split('.').next().unwrap_or(rest);

  if stem.is_empty() {
    // Names made only of dots (e.g. "..") have no stem; keep them as they are.
    filename.to_string()
  } else {
    format!("{hidden_prefix}{stem}")
  }
}
455
456fn generate_reasoning(files_with_scores: &[FileWithScore]) -> String {
457  if files_with_scores.is_empty() {
458    return "No files to analyze".to_string();
459  }
460
461  let primary = &files_with_scores[0];
462  let total_files = files_with_scores.len();
463  let total_lines: u32 = files_with_scores
464    .iter()
465    .map(|f| f.lines_added + f.lines_removed)
466    .sum();
467
468  format!(
469    "{} changes have highest impact ({:.2}) affecting {} functionality. \
470        Total {} files changed with {} lines modified.",
471    primary
472      .file_category
473      .chars()
474      .next()
475      .unwrap_or('u')
476      .to_uppercase()
477      .collect::<String>()
478      + primary.file_category.get(1..).unwrap_or(""),
479    primary.impact_score,
480    extract_component_name(&primary.file_path),
481    total_files,
482    total_lines
483  )
484}
485
#[cfg(test)]
mod tests {
  use super::*;

  /// Each sample path must map to its expected category label.
  #[test]
  fn test_file_categorization() {
    let expectations = [
      ("src/main.rs", "source"),
      ("tests/integration_test.rs", "test"),
      ("package.json", "build"),
      (".github/workflows/ci.yml", "config"),
      ("README.md", "docs"),
      ("logo.png", "binary")
    ];

    for &(path, expected) in expectations.iter() {
      assert_eq!(categorize_file(path), expected);
    }
  }

  /// A modified source file with 70 changed lines scores within (0, 1].
  #[test]
  fn test_impact_score_calculation() {
    let sample = FileDataForScoring {
      file_path:      "src/auth.rs".to_string(),
      operation_type: "modified".to_string(),
      lines_added:    50,
      lines_removed:  20,
      file_category:  "source".to_string(),
      summary:        "Updated authentication logic".to_string()
    };

    let score = calculate_single_impact_score(&sample);

    assert!(score > 0.0);
    assert!(score <= 1.0);
  }
}