1use std::{
10 collections::HashMap,
11 path::{Path, PathBuf},
12 process::Command,
13 time::Duration,
14};
15
16use serde::{Deserialize, Serialize};
17
18use crate::{
19 config::CommitConfig,
20 diff::smart_truncate_diff,
21 error::{CommitGenError, Result},
22 patch::stage_files,
23 templates,
24 tokens::create_token_counter,
25 types::{
26 ChangelogBoundary, ChangelogCategory, Function, FunctionParameters, Tool, UnreleasedSection,
27 },
28};
29
/// Arguments of the `create_changelog_entries` tool call: a map from category
/// name (e.g. "Added", "Fixed") to the entries generated for that category.
#[derive(Debug, Deserialize)]
struct ChangelogResponse {
    entries: HashMap<String, Vec<String>>,
}

/// Body POSTed to `{api_base_url}/chat/completions` to request changelog entries.
#[derive(Debug, Serialize)]
struct ApiRequest {
    model: String,
    max_tokens: u32,
    temperature: f32,
    tools: Vec<Tool>,
    // Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    messages: Vec<Message>,
}

/// One chat message in the request (`role` is "system" or "user" here).
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}
53
54#[derive(Debug, Deserialize)]
55struct ApiResponse {
56 choices: Vec<Choice>,
57}
58
59#[derive(Debug, Deserialize)]
60struct Choice {
61 message: ResponseMessage,
62}
63
64#[derive(Debug, Deserialize)]
65struct ResponseMessage {
66 #[serde(default)]
67 tool_calls: Vec<ToolCall>,
68 #[serde(default)]
69 content: Option<String>,
70}
71
72#[derive(Debug, Deserialize)]
73struct ToolCall {
74 function: FunctionCall,
75}
76
77#[derive(Debug, Deserialize)]
78struct FunctionCall {
79 name: String,
80 arguments: String,
81}
82
83pub async fn run_changelog_flow(args: &crate::types::Args, config: &CommitConfig) -> Result<()> {
90 let token_counter = create_token_counter(config);
91
92 let staged_files = get_staged_files(&args.dir)?;
94 if staged_files.is_empty() {
95 return Ok(());
96 }
97
98 let non_changelog_files: Vec<_> = staged_files
100 .iter()
101 .filter(|f| !f.to_lowercase().ends_with("changelog.md"))
102 .cloned()
103 .collect();
104
105 if non_changelog_files.is_empty() {
106 return Ok(());
107 }
108
109 let changelogs = find_changelogs(&args.dir)?;
111 if changelogs.is_empty() {
112 return Ok(());
114 }
115
116 let boundaries = detect_boundaries(&non_changelog_files, &changelogs, &args.dir);
118 if boundaries.is_empty() {
119 return Ok(());
120 }
121
122 println!("{}", crate::style::info(&format!("Updating {} changelog(s)...", boundaries.len())));
123
124 let mut modified_changelogs = Vec::new();
125
126 for boundary in boundaries {
127 let diff = get_diff_for_files(&boundary.files, &args.dir)?;
129 let stat = get_stat_for_files(&boundary.files, &args.dir)?;
130
131 if diff.is_empty() {
132 continue;
133 }
134
135 let diff = if diff.len() > config.max_diff_length {
137 smart_truncate_diff(&diff, config.max_diff_length, config, &token_counter)
138 } else {
139 diff
140 };
141
142 let changelog_content = std::fs::read_to_string(&boundary.changelog_path).map_err(|e| {
144 CommitGenError::ChangelogParseError {
145 path: boundary.changelog_path.display().to_string(),
146 reason: e.to_string(),
147 }
148 })?;
149
150 let unreleased = match parse_unreleased_section(&changelog_content, &boundary.changelog_path)
151 {
152 Ok(u) => u,
153 Err(CommitGenError::NoUnreleasedSection { path }) => {
154 eprintln!(
155 "{} No [Unreleased] section in {}, skipping changelog update",
156 crate::style::icons::WARNING,
157 path
158 );
159 continue;
160 },
161 Err(e) => return Err(e),
162 };
163
164 let is_package_changelog = boundary
166 .changelog_path
167 .parent()
168 .is_some_and(|p| p != Path::new(&args.dir) && p != Path::new("."));
169
170 let existing_entries = format_existing_entries(&unreleased);
172
173 let new_entries = match generate_changelog_entries(
175 &boundary.changelog_path,
176 is_package_changelog,
177 &stat,
178 &diff,
179 existing_entries.as_deref(),
180 config,
181 )
182 .await
183 {
184 Ok(entries) => entries,
185 Err(e) => {
186 eprintln!(
187 "{}",
188 crate::style::warning(&format!("Failed to generate changelog entries: {e}"))
189 );
190 continue;
191 },
192 };
193
194 if new_entries.is_empty() {
195 continue;
196 }
197
198 if let Some(debug_dir) = &args.debug_output {
200 let _ = std::fs::create_dir_all(debug_dir);
201 let changelog_json: HashMap<String, Vec<String>> = new_entries
202 .iter()
203 .map(|(cat, entries)| (cat.as_str().to_string(), entries.clone()))
204 .collect();
205 if let Ok(json_str) = serde_json::to_string_pretty(&changelog_json) {
206 let _ = std::fs::write(debug_dir.join("changelog.json"), json_str);
207 }
208 }
209
210 let updated = write_entries(&changelog_content, &unreleased, &new_entries);
212 std::fs::write(&boundary.changelog_path, updated).map_err(|e| {
213 CommitGenError::ChangelogParseError {
214 path: boundary.changelog_path.display().to_string(),
215 reason: format!("Failed to write: {e}"),
216 }
217 })?;
218
219 let entry_count: usize = new_entries.values().map(|v| v.len()).sum();
220 modified_changelogs.push(boundary.changelog_path.display().to_string());
221 println!(
222 "{} Added {} entries to {}",
223 crate::style::icons::SUCCESS,
224 entry_count,
225 boundary.changelog_path.display()
226 );
227 }
228
229 if !modified_changelogs.is_empty() {
231 stage_files(&modified_changelogs, &args.dir)?;
232 }
233
234 Ok(())
235}
236
237async fn generate_changelog_entries(
239 changelog_path: &Path,
240 is_package_changelog: bool,
241 stat: &str,
242 diff: &str,
243 existing_entries: Option<&str>,
244 config: &CommitConfig,
245) -> Result<HashMap<ChangelogCategory, Vec<String>>> {
246 let parts = templates::render_changelog_prompt(
247 "default",
248 &changelog_path.display().to_string(),
249 is_package_changelog,
250 stat,
251 diff,
252 existing_entries,
253 )?;
254
255 let response = call_changelog_api(&parts, config).await?;
256
257 let mut result = HashMap::new();
259 for (key, entries) in response.entries {
260 if entries.is_empty() {
261 continue;
262 }
263 let category = ChangelogCategory::from_name(&key);
264 result.insert(category, entries);
265 }
266
267 Ok(result)
268}
269
270async fn call_changelog_api(
272 parts: &templates::PromptParts,
273 config: &CommitConfig,
274) -> Result<ChangelogResponse> {
275 let client = crate::api::get_client(config);
276
277 let model = config.model.clone();
278
279 let tool = Tool {
281 tool_type: "function".to_string(),
282 function: Function {
283 name: "create_changelog_entries".to_string(),
284 description: "Generate changelog entries grouped by category".to_string(),
285 parameters: FunctionParameters {
286 param_type: "object".to_string(),
287 properties: serde_json::json!({
288 "entries": {
289 "type": "object",
290 "description": "Changelog entries grouped by category",
291 "properties": {
292 "Added": {
293 "type": "array",
294 "items": { "type": "string" },
295 "description": "New features or capabilities"
296 },
297 "Changed": {
298 "type": "array",
299 "items": { "type": "string" },
300 "description": "Changes to existing functionality"
301 },
302 "Fixed": {
303 "type": "array",
304 "items": { "type": "string" },
305 "description": "Bug fixes"
306 },
307 "Deprecated": {
308 "type": "array",
309 "items": { "type": "string" },
310 "description": "Features marked for removal"
311 },
312 "Removed": {
313 "type": "array",
314 "items": { "type": "string" },
315 "description": "Removed features"
316 },
317 "Security": {
318 "type": "array",
319 "items": { "type": "string" },
320 "description": "Security-related changes"
321 },
322 "Breaking Changes": {
323 "type": "array",
324 "items": { "type": "string" },
325 "description": "Breaking API or behavior changes"
326 }
327 },
328 "additionalProperties": false
329 }
330 }),
331 required: vec!["entries".to_string()],
332 },
333 },
334 };
335
336 let mut attempt = 0;
337 loop {
338 attempt += 1;
339
340 let mut messages = Vec::new();
341 if !parts.system.is_empty() {
342 messages.push(Message { role: "system".to_string(), content: parts.system.clone() });
343 }
344 messages.push(Message { role: "user".to_string(), content: parts.user.clone() });
345
346 let request = ApiRequest {
347 model: model.clone(),
348 max_tokens: 2000,
349 temperature: config.temperature,
350 tools: vec![tool.clone()],
351 tool_choice: Some(
352 serde_json::json!({ "type": "function", "function": { "name": "create_changelog_entries" } }),
353 ),
354 messages,
355 };
356
357 let mut request_builder = client
358 .post(format!("{}/chat/completions", config.api_base_url))
359 .header("content-type", "application/json");
360
361 if let Some(api_key) = &config.api_key {
362 request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
363 }
364
365 let (status, response_text) =
366 crate::api::timed_send(request_builder.json(&request), "changelog", &model).await?;
367
368 if status.is_server_error() {
369 if attempt < config.max_retries {
370 let backoff_ms = config.initial_backoff_ms * (1 << (attempt - 1));
371 eprintln!(
372 "{}",
373 crate::style::warning(&format!(
374 "Server error {status}, retry {attempt}/{} after {backoff_ms}ms...",
375 config.max_retries
376 ))
377 );
378 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
379 continue;
380 }
381 return Err(CommitGenError::ApiError { status: status.as_u16(), body: response_text });
382 }
383
384 if !status.is_success() {
385 return Err(CommitGenError::ApiError { status: status.as_u16(), body: response_text });
386 }
387
388 if let Ok(api_response) = serde_json::from_str::<ApiResponse>(&response_text) {
390 let message = &api_response.choices[0].message;
391
392 if !message.tool_calls.is_empty() {
394 let tool_call = &message.tool_calls[0];
395 if tool_call
396 .function
397 .name
398 .ends_with("create_changelog_entries")
399 {
400 let changelog_response: ChangelogResponse =
401 serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
402 CommitGenError::Other(format!(
403 "Failed to parse changelog tool arguments: {e}. Args: {}",
404 tool_call
405 .function
406 .arguments
407 .chars()
408 .take(500)
409 .collect::<String>()
410 ))
411 })?;
412 return Ok(changelog_response);
413 }
414 }
415
416 if let Some(content) = &message.content {
418 let json_str = extract_json_from_content(content);
419 if !json_str.is_empty() {
420 let changelog_response: ChangelogResponse = serde_json::from_str(&json_str)
421 .map_err(|e| {
422 CommitGenError::Other(format!(
423 "Failed to parse changelog response from content: {e}. Content: {}",
424 json_str.chars().take(500).collect::<String>()
425 ))
426 })?;
427 return Ok(changelog_response);
428 }
429 }
430 }
431
432 let json_str = extract_json_from_content(&response_text);
434 if json_str.is_empty() {
435 return Err(CommitGenError::Other(format!(
436 "Changelog API returned no tool calls or parseable content. Raw response: {}",
437 response_text.chars().take(1000).collect::<String>()
438 )));
439 }
440 let changelog_response: ChangelogResponse = serde_json::from_str(&json_str).map_err(|e| {
441 CommitGenError::Other(format!(
442 "Failed to parse changelog response: {e}. Content was: {}",
443 json_str.chars().take(500).collect::<String>()
444 ))
445 })?;
446
447 return Ok(changelog_response);
448 }
449}
450
/// Extract the JSON payload from free-form model output.
///
/// Strategies, tried in order:
/// 1. the body of the first ```` ```json ```` fenced block;
/// 2. the body of the first generic ```` ``` ```` fenced block, skipping the
///    rest of the opening fence line (e.g. a language tag);
/// 3. the span from the first `{` to the last `}`, when the `{` comes first.
///
/// Otherwise the trimmed input is returned unchanged so the caller's JSON
/// parser can report a meaningful error. Whitespace-only input yields `""`.
fn extract_json_from_content(content: &str) -> String {
    let trimmed = content.trim();

    const JSON_FENCE: &str = "```json";
    if let Some(start) = trimmed.find(JSON_FENCE) {
        let after_marker = &trimmed[start + JSON_FENCE.len()..];
        if let Some(end) = after_marker.find("```") {
            return after_marker[..end].trim().to_string();
        }
    }

    if let Some(start) = trimmed.find("```") {
        let after_marker = &trimmed[start + 3..];
        // Skip whatever follows the opening fence on its own line.
        let content_start = after_marker.find('\n').map_or(0, |i| i + 1);
        let body = &after_marker[content_start..];
        if let Some(end) = body.find("```") {
            return body[..end].trim().to_string();
        }
    }

    // Only slice when the braces are correctly ordered: text such as
    // "} ... {" would otherwise build an inverted range and panic.
    match (trimmed.find('{'), trimmed.rfind('}')) {
        (Some(start), Some(end)) if start < end => trimmed[start..=end].to_string(),
        _ => trimmed.to_string(),
    }
}
483
484fn format_existing_entries(unreleased: &UnreleasedSection) -> Option<String> {
486 if unreleased.entries.is_empty() {
487 return None;
488 }
489
490 let mut lines = Vec::new();
491 for category in ChangelogCategory::render_order() {
492 if let Some(entries) = unreleased.entries.get(category) {
493 if entries.is_empty() {
494 continue;
495 }
496 lines.push(format!("### {}", category.as_str()));
497 for entry in entries {
498 lines.push(entry.clone());
499 }
500 lines.push(String::new());
501 }
502 }
503
504 if lines.is_empty() {
505 None
506 } else {
507 Some(lines.join("\n"))
508 }
509}
510
511fn get_staged_files(dir: &str) -> Result<Vec<String>> {
513 let output = Command::new("git")
514 .args(["diff", "--cached", "--name-only"])
515 .current_dir(dir)
516 .output()
517 .map_err(|e| CommitGenError::git(format!("Failed to get staged files: {e}")))?;
518
519 if !output.status.success() {
520 let stderr = String::from_utf8_lossy(&output.stderr);
521 return Err(CommitGenError::git(format!("git diff --cached --name-only failed: {stderr}")));
522 }
523
524 let files: Vec<String> = String::from_utf8_lossy(&output.stdout)
525 .lines()
526 .filter(|s| !s.is_empty())
527 .map(String::from)
528 .collect();
529
530 Ok(files)
531}
532
533fn find_changelogs(dir: &str) -> Result<Vec<PathBuf>> {
535 let output = Command::new("git")
536 .args(["ls-files", "--full-name", "**/CHANGELOG.md", "CHANGELOG.md"])
537 .current_dir(dir)
538 .output()
539 .map_err(|e| CommitGenError::git(format!("Failed to find changelogs: {e}")))?;
540
541 let files: Vec<PathBuf> = String::from_utf8_lossy(&output.stdout)
543 .lines()
544 .filter(|s| !s.is_empty())
545 .map(|s| PathBuf::from(dir).join(s))
546 .collect();
547
548 Ok(files)
549}
550
551fn detect_boundaries(
553 files: &[String],
554 changelogs: &[PathBuf],
555 dir: &str,
556) -> Vec<ChangelogBoundary> {
557 let mut file_to_changelog: HashMap<String, PathBuf> = HashMap::new();
558
559 let mut dir_to_changelog: HashMap<String, PathBuf> = HashMap::new();
563 let mut root_changelog: Option<PathBuf> = None;
564
565 for changelog in changelogs {
566 let rel_path = changelog
568 .strip_prefix(dir)
569 .unwrap_or(changelog)
570 .to_string_lossy();
571
572 if let Some(parent) = Path::new(&*rel_path).parent() {
574 let parent_str = parent.to_string_lossy().to_string();
575 if parent_str.is_empty() || parent_str == "." {
576 root_changelog = Some(changelog.clone());
577 } else {
578 dir_to_changelog.insert(parent_str, changelog.clone());
579 }
580 }
581 }
582
583 for file in files {
584 let mut current_path = Path::new(file)
586 .parent()
587 .map(|p| p.to_string_lossy().to_string());
588 let mut found = false;
589
590 while let Some(ref dir_path) = current_path {
591 if let Some(changelog) = dir_to_changelog.get(dir_path) {
592 file_to_changelog.insert(file.clone(), changelog.clone());
593 found = true;
594 break;
595 }
596
597 let path = Path::new(dir_path);
599 current_path = path.parent().and_then(|p| {
600 let s = p.to_string_lossy().to_string();
601 if s.is_empty() { None } else { Some(s) }
602 });
603 }
604
605 if !found && let Some(ref root) = root_changelog {
607 file_to_changelog.insert(file.clone(), root.clone());
608 }
609 }
611
612 let mut changelog_to_files: HashMap<PathBuf, Vec<String>> = HashMap::new();
614 for (file, changelog) in file_to_changelog {
615 changelog_to_files.entry(changelog).or_default().push(file);
616 }
617
618 let boundaries: Vec<ChangelogBoundary> = changelog_to_files
620 .into_iter()
621 .map(|(changelog_path, files)| ChangelogBoundary {
622 changelog_path,
623 files,
624 diff: String::new(), stat: String::new(), })
627 .collect();
628
629 boundaries
630}
631
632fn get_diff_for_files(files: &[String], dir: &str) -> Result<String> {
634 if files.is_empty() {
635 return Ok(String::new());
636 }
637
638 let output = Command::new("git")
639 .args(["diff", "--cached", "--"])
640 .args(files)
641 .current_dir(dir)
642 .output()
643 .map_err(|e| CommitGenError::git(format!("Failed to get diff for files: {e}")))?;
644
645 Ok(String::from_utf8_lossy(&output.stdout).to_string())
646}
647
648fn get_stat_for_files(files: &[String], dir: &str) -> Result<String> {
650 if files.is_empty() {
651 return Ok(String::new());
652 }
653
654 let output = Command::new("git")
655 .args(["diff", "--cached", "--stat", "--"])
656 .args(files)
657 .current_dir(dir)
658 .output()
659 .map_err(|e| CommitGenError::git(format!("Failed to get stat for files: {e}")))?;
660
661 Ok(String::from_utf8_lossy(&output.stdout).to_string())
662}
663
664fn parse_unreleased_section(content: &str, path: &Path) -> Result<UnreleasedSection> {
666 let lines: Vec<&str> = content.lines().collect();
667
668 let header_line = lines
670 .iter()
671 .position(|l| {
672 let trimmed = l.trim().to_lowercase();
673 trimmed.contains("[unreleased]") || trimmed == "## unreleased"
674 })
675 .ok_or_else(|| CommitGenError::NoUnreleasedSection { path: path.display().to_string() })?;
676
677 let end_line = lines
679 .iter()
680 .skip(header_line + 1)
681 .position(|l| {
682 let trimmed = l.trim();
683 trimmed.starts_with("## [") && trimmed.contains(']')
685 || (trimmed.starts_with("## ")
686 && trimmed.chars().nth(3).is_some_and(|c| c.is_ascii_digit()))
687 })
688 .map_or(lines.len(), |pos| header_line + 1 + pos);
689
690 let mut entries: HashMap<ChangelogCategory, Vec<String>> = HashMap::new();
692 let mut current_category: Option<ChangelogCategory> = None;
693
694 for line in &lines[header_line + 1..end_line] {
695 let trimmed = line.trim();
696
697 if trimmed.starts_with("### ") {
699 let cat_name = trimmed.trim_start_matches("### ").trim();
700 current_category = match cat_name.to_lowercase().as_str() {
701 "added" => Some(ChangelogCategory::Added),
702 "changed" => Some(ChangelogCategory::Changed),
703 "fixed" => Some(ChangelogCategory::Fixed),
704 "deprecated" => Some(ChangelogCategory::Deprecated),
705 "removed" => Some(ChangelogCategory::Removed),
706 "security" => Some(ChangelogCategory::Security),
707 "breaking changes" | "breaking" => Some(ChangelogCategory::Breaking),
708 _ => None,
709 };
710 } else if let Some(cat) = current_category {
711 if trimmed.starts_with("- ") || trimmed.starts_with("* ") {
713 entries.entry(cat).or_default().push(trimmed.to_string());
714 }
715 }
716 }
717
718 Ok(UnreleasedSection { header_line, end_line, entries })
719}
720
721fn write_entries(
723 content: &str,
724 unreleased: &UnreleasedSection,
725 new_entries: &HashMap<ChangelogCategory, Vec<String>>,
726) -> String {
727 let lines: Vec<&str> = content.lines().collect();
728
729 let mut result = Vec::new();
731
732 result.extend(
734 lines[..=unreleased.header_line]
735 .iter()
736 .map(|s| s.to_string()),
737 );
738
739 if unreleased.header_line + 1 < lines.len() && !lines[unreleased.header_line + 1].is_empty() {
741 result.push(String::new());
742 }
743
744 for category in ChangelogCategory::render_order() {
746 let new_in_category = new_entries.get(category);
747 let existing_in_category = unreleased.entries.get(category);
748
749 let has_new = new_in_category.is_some_and(|v| !v.is_empty());
750 let has_existing = existing_in_category.is_some_and(|v| !v.is_empty());
751
752 if !has_new && !has_existing {
753 continue;
754 }
755
756 result.push(format!("### {}", category.as_str()));
757 result.push(String::new());
758
759 if let Some(entries) = new_in_category {
761 for entry in entries {
762 if entry.starts_with("- ") || entry.starts_with("* ") {
764 result.push(entry.clone());
765 } else {
766 result.push(format!("- {entry}"));
767 }
768 }
769 }
770
771 if let Some(entries) = existing_in_category {
773 for entry in entries {
774 result.push(entry.clone());
775 }
776 }
777
778 result.push(String::new());
779 }
780
781 if unreleased.end_line < lines.len() {
783 result.extend(lines[unreleased.end_line..].iter().map(|s| s.to_string()));
784 }
785
786 result.join("\n")
787}
788
/// Unit tests for the pure helpers: JSON extraction from model output,
/// `[Unreleased]` section parsing, and formatting of existing entries.
#[cfg(test)]
mod tests {
    use super::*;

    // A bare JSON object is returned unchanged.
    #[test]
    fn test_extract_json_from_content_raw() {
        let content = r#"{"entries": {"Added": ["entry 1"]}}"#;
        let result = extract_json_from_content(content);
        assert_eq!(result, r#"{"entries": {"Added": ["entry 1"]}}"#);
    }

    // Surrounding prose is discarded when the JSON sits in a ```json fence.
    #[test]
    fn test_extract_json_from_content_code_block() {
        let content = r#"Here's the changelog:

```json
{"entries": {"Added": ["entry 1"]}}
```

That's all!"#;
        let result = extract_json_from_content(content);
        assert_eq!(result, r#"{"entries": {"Added": ["entry 1"]}}"#);
    }

    // A fence without a language tag is handled too.
    #[test]
    fn test_extract_json_from_content_generic_block() {
        let content = r#"```
{"entries": {"Fixed": ["bug fix"]}}
```"#;
        let result = extract_json_from_content(content);
        assert_eq!(result, r#"{"entries": {"Fixed": ["bug fix"]}}"#);
    }

    // Header index, exclusive end index (the "## [1.0.0]" line) and
    // per-category entry counts of a typical changelog.
    #[test]
    fn test_parse_unreleased_section() {
        let content = r"# Changelog

## [Unreleased]

### Added

- Feature one
- Feature two

### Fixed

- Bug fix

## [1.0.0] - 2024-01-01

### Added

- Initial release
";

        let section = parse_unreleased_section(content, Path::new("CHANGELOG.md")).unwrap();
        assert_eq!(section.header_line, 2);
        assert_eq!(section.end_line, 13);
        assert_eq!(
            section
                .entries
                .get(&ChangelogCategory::Added)
                .unwrap()
                .len(),
            2
        );
        assert_eq!(
            section
                .entries
                .get(&ChangelogCategory::Fixed)
                .unwrap()
                .len(),
            1
        );
    }

    // Formatted output contains category headings and their entries.
    #[test]
    fn test_format_existing_entries() {
        let mut entries = HashMap::new();
        entries.insert(ChangelogCategory::Added, vec![
            "- Feature one".to_string(),
            "- Feature two".to_string(),
        ]);
        entries.insert(ChangelogCategory::Fixed, vec!["- Bug fix".to_string()]);

        let unreleased = UnreleasedSection { header_line: 0, end_line: 10, entries };

        let formatted = format_existing_entries(&unreleased).unwrap();
        assert!(formatted.contains("### Added"));
        assert!(formatted.contains("- Feature one"));
        assert!(formatted.contains("### Fixed"));
        assert!(formatted.contains("- Bug fix"));
    }

    // No existing entries means no "existing entries" block at all.
    #[test]
    fn test_format_existing_entries_empty() {
        let unreleased =
            UnreleasedSection { header_line: 0, end_line: 10, entries: HashMap::new() };

        assert!(format_existing_entries(&unreleased).is_none());
    }
}
890}