// git_editor/rewrite/rewrite_range.rs

1use crate::utils::types::CommitInfo;
2use crate::utils::types::Result;
3use crate::{args::Args, utils::commit_history::get_commit_history};
4use chrono::NaiveDateTime;
5use colored::Colorize;
6use git2::{Repository, Signature, Sort, Time};
7use std::collections::HashMap;
8use std::io::{self, Write};
9
10pub fn parse_range_input(input: &str) -> Result<(usize, usize)> {
11    let parts: Vec<&str> = input.trim().split('-').collect();
12
13    if parts.len() != 2 {
14        return Err("Invalid range format. Use format like '5-11'".into());
15    }
16
17    let start = parts[0]
18        .trim()
19        .parse::<usize>()
20        .map_err(|_| "Invalid start number in range")?;
21    let end = parts[1]
22        .trim()
23        .parse::<usize>()
24        .map_err(|_| "Invalid end number in range")?;
25
26    if start < 1 {
27        return Err("Start position must be 1 or greater".into());
28    }
29
30    if end < start {
31        return Err("End position must be greater than or equal to start position".into());
32    }
33
34    Ok((start, end))
35}
36
37pub fn select_commit_range(commits: &[CommitInfo]) -> Result<(usize, usize)> {
38    println!("\n{}", "Commit History:".bold().green());
39    println!("{}", "-".repeat(80).cyan());
40
41    for (i, commit) in commits.iter().enumerate() {
42        println!(
43            "{:3}. {} {} {} {}",
44            i + 1,
45            commit.short_hash.yellow().bold(),
46            commit
47                .timestamp
48                .format("%Y-%m-%d %H:%M:%S")
49                .to_string()
50                .blue(),
51            commit.author_name.magenta(),
52            commit.message.lines().next().unwrap_or("").white()
53        );
54    }
55
56    println!("{}", "-".repeat(80).cyan());
57    println!(
58        "\n{}",
59        "Enter range in format 'start-end' (e.g., '5-11'):"
60            .bold()
61            .green()
62    );
63    print!("{} ", "Range:".bold());
64    io::stdout().flush()?;
65
66    let mut input = String::new();
67    io::stdin().read_line(&mut input)?;
68
69    let (start, end) = parse_range_input(&input)?;
70
71    if start > commits.len() || end > commits.len() {
72        return Err(format!(
73            "Range out of bounds. Available commits: 1-{}",
74            commits.len()
75        )
76        .into());
77    }
78
79    Ok((start - 1, end - 1)) // Convert to 0-based indexing
80}
81
82pub fn show_range_details(commits: &[CommitInfo], start_idx: usize, end_idx: usize) -> Result<()> {
83    println!("\n{}", "Selected Commit Range:".bold().green());
84    println!("{}", "=".repeat(80).cyan());
85
86    for (idx, commit) in commits[start_idx..=end_idx].iter().enumerate() {
87        println!(
88            "\n{}: {} ({})",
89            format!("Commit {}", start_idx + idx + 1).bold(),
90            commit.short_hash.yellow(),
91            &commit.oid.to_string()[..8]
92        );
93        println!(
94            "{}: {}",
95            "Author".bold(),
96            format!("{} <{}>", commit.author_name, commit.author_email).magenta()
97        );
98        println!(
99            "{}: {}",
100            "Date".bold(),
101            commit
102                .timestamp
103                .format("%Y-%m-%d %H:%M:%S")
104                .to_string()
105                .blue()
106        );
107        println!(
108            "{}: {}",
109            "Message".bold(),
110            commit.message.lines().next().unwrap_or("").white()
111        );
112    }
113
114    println!("\n{}", "=".repeat(80).cyan());
115    println!(
116        "{} {} commits selected for editing",
117        "Total:".bold(),
118        (end_idx - start_idx + 1).to_string().green()
119    );
120
121    Ok(())
122}
123
124pub fn get_range_edit_info(args: &Args) -> Result<(String, String, NaiveDateTime, NaiveDateTime)> {
125    println!("\n{}", "Range Edit Configuration:".bold().green());
126
127    // Get author name
128    let author_name = if let Some(name) = &args.name {
129        name.clone()
130    } else {
131        print!("{} ", "New author name:".bold());
132        io::stdout().flush()?;
133        let mut input = String::new();
134        io::stdin().read_line(&mut input)?;
135        input.trim().to_string()
136    };
137
138    // Get author email
139    let author_email = if let Some(email) = &args.email {
140        email.clone()
141    } else {
142        print!("{} ", "New author email:".bold());
143        io::stdout().flush()?;
144        let mut input = String::new();
145        io::stdin().read_line(&mut input)?;
146        input.trim().to_string()
147    };
148
149    // Get start timestamp
150    let start_timestamp = if let Some(start) = &args.start {
151        NaiveDateTime::parse_from_str(start, "%Y-%m-%d %H:%M:%S")
152            .map_err(|_| "Invalid start timestamp format")?
153    } else {
154        print!("{} ", "Start timestamp (YYYY-MM-DD HH:MM:SS):".bold());
155        io::stdout().flush()?;
156        let mut input = String::new();
157        io::stdin().read_line(&mut input)?;
158        NaiveDateTime::parse_from_str(input.trim(), "%Y-%m-%d %H:%M:%S")
159            .map_err(|_| "Invalid start timestamp format")?
160    };
161
162    // Get end timestamp
163    let end_timestamp = if let Some(end) = &args.end {
164        NaiveDateTime::parse_from_str(end, "%Y-%m-%d %H:%M:%S")
165            .map_err(|_| "Invalid end timestamp format")?
166    } else {
167        print!("{} ", "End timestamp (YYYY-MM-DD HH:MM:SS):".bold());
168        io::stdout().flush()?;
169        let mut input = String::new();
170        io::stdin().read_line(&mut input)?;
171        NaiveDateTime::parse_from_str(input.trim(), "%Y-%m-%d %H:%M:%S")
172            .map_err(|_| "Invalid end timestamp format")?
173    };
174
175    if end_timestamp <= start_timestamp {
176        return Err("End timestamp must be after start timestamp".into());
177    }
178
179    Ok((author_name, author_email, start_timestamp, end_timestamp))
180}
181
182pub fn generate_range_timestamps(
183    start_time: NaiveDateTime,
184    end_time: NaiveDateTime,
185    count: usize,
186) -> Vec<NaiveDateTime> {
187    if count == 0 {
188        return vec![];
189    }
190
191    if count == 1 {
192        return vec![start_time];
193    }
194
195    let total_duration = end_time.signed_duration_since(start_time);
196    let step_duration = total_duration / (count - 1) as i32;
197
198    (0..count)
199        .map(|i| start_time + step_duration * i as i32)
200        .collect()
201}
202
203pub fn rewrite_range_commits(args: &Args) -> Result<()> {
204    let commits = get_commit_history(args, false)?;
205
206    if commits.is_empty() {
207        println!("{}", "No commits found!".red());
208        return Ok(());
209    }
210
211    let (start_idx, end_idx) = select_commit_range(&commits)?;
212    show_range_details(&commits, start_idx, end_idx)?;
213
214    let (author_name, author_email, start_time, end_time) = get_range_edit_info(args)?;
215
216    let range_size = end_idx - start_idx + 1;
217    let timestamps = generate_range_timestamps(start_time, end_time, range_size);
218
219    // Show planned changes
220    println!("\n{}", "Planned Changes:".bold().yellow());
221    for (i, (commit, timestamp)) in commits[start_idx..=end_idx]
222        .iter()
223        .zip(timestamps.iter())
224        .enumerate()
225    {
226        println!(
227            "  {}: {} -> {}",
228            format!("Commit {}", start_idx + i + 1).bold(),
229            format!(
230                "{} <{}> {}",
231                commit.author_name,
232                commit.author_email,
233                commit.timestamp.format("%Y-%m-%d %H:%M:%S")
234            )
235            .red(),
236            format!(
237                "{} <{}> {}",
238                author_name,
239                author_email,
240                timestamp.format("%Y-%m-%d %H:%M:%S")
241            )
242            .green()
243        );
244    }
245
246    print!("\n{} (y/n): ", "Proceed with changes?".bold());
247    io::stdout().flush()?;
248
249    let mut confirm = String::new();
250    io::stdin().read_line(&mut confirm)?;
251
252    if confirm.trim().to_lowercase() != "y" {
253        println!("{}", "Operation cancelled.".yellow());
254        return Ok(());
255    }
256
257    // Apply changes
258    apply_range_changes(
259        args,
260        &commits,
261        start_idx,
262        end_idx,
263        &author_name,
264        &author_email,
265        &timestamps,
266    )?;
267
268    println!("\n{}", "✓ Commit range successfully edited!".green().bold());
269
270    if args.show_history {
271        get_commit_history(args, true)?;
272    }
273
274    Ok(())
275}
276
277fn apply_range_changes(
278    args: &Args,
279    _commits: &[CommitInfo],
280    start_idx: usize,
281    end_idx: usize,
282    author_name: &str,
283    author_email: &str,
284    timestamps: &[NaiveDateTime],
285) -> Result<()> {
286    let repo = Repository::open(args.repo_path.as_ref().unwrap())?;
287    let head_ref = repo.head()?;
288    let branch_name = head_ref
289        .shorthand()
290        .ok_or("Detached HEAD or invalid branch")?;
291    let full_ref = format!("refs/heads/{branch_name}");
292
293    let mut revwalk = repo.revwalk()?;
294    revwalk.push_head()?;
295    revwalk.set_sorting(Sort::TOPOLOGICAL | Sort::TIME)?;
296    let mut orig_oids: Vec<_> = revwalk.filter_map(|id| id.ok()).collect();
297    orig_oids.reverse();
298
299    let mut new_map: HashMap<git2::Oid, git2::Oid> = HashMap::new();
300    let mut last_new_oid = None;
301    let mut range_timestamp_idx = 0;
302
303    for (commit_idx, &oid) in orig_oids.iter().enumerate() {
304        let orig = repo.find_commit(oid)?;
305        let tree = orig.tree()?;
306
307        let new_parents: Result<Vec<_>> = orig
308            .parent_ids()
309            .map(|pid| {
310                let new_pid = *new_map.get(&pid).unwrap_or(&pid);
311                repo.find_commit(new_pid).map_err(|e| e.into())
312            })
313            .collect();
314
315        let new_oid = if commit_idx >= start_idx && commit_idx <= end_idx {
316            // This commit is in our range - update it
317            let timestamp = timestamps[range_timestamp_idx];
318            range_timestamp_idx += 1;
319
320            let sig = Signature::new(
321                author_name,
322                author_email,
323                &Time::new(timestamp.and_utc().timestamp(), 0),
324            )?;
325
326            repo.commit(
327                None,
328                &sig,
329                &sig,
330                orig.message().unwrap_or_default(),
331                &tree,
332                &new_parents?.iter().collect::<Vec<_>>(),
333            )?
334        } else {
335            // Keep other commits as-is but update parent references
336            let author = orig.author();
337            let committer = orig.committer();
338
339            repo.commit(
340                None,
341                &author,
342                &committer,
343                orig.message().unwrap_or_default(),
344                &tree,
345                &new_parents?.iter().collect::<Vec<_>>(),
346            )?
347        };
348
349        new_map.insert(oid, new_oid);
350        last_new_oid = Some(new_oid);
351    }
352
353    if let Some(new_head) = last_new_oid {
354        repo.reference(&full_ref, new_head, true, "edited commit range")?;
355        println!(
356            "{} '{}' -> {}",
357            "Updated branch".green(),
358            branch_name.cyan(),
359            new_head.to_string()[..8].to_string().cyan()
360        );
361    }
362
363    Ok(())
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    const TS_FORMAT: &str = "%Y-%m-%d %H:%M:%S";

    /// Parses a `YYYY-MM-DD HH:MM:SS` literal used by the tests.
    fn ts(raw: &str) -> NaiveDateTime {
        NaiveDateTime::parse_from_str(raw, TS_FORMAT).unwrap()
    }

    /// Builds `Args` for `repo_path` with a fixed new author and a 10-hour time window.
    fn test_args(repo_path: String) -> Args {
        Args {
            repo_path: Some(repo_path),
            email: Some("new@example.com".to_string()),
            name: Some("New User".to_string()),
            start: Some("2023-01-01 00:00:00".to_string()),
            end: Some("2023-01-01 10:00:00".to_string()),
            show_history: false,
            pic_specific_commits: false,
            range: false,
        }
    }

    /// Creates a temporary repository with five linear commits ("Commit 1" .. "Commit 5").
    /// Each commit adds one file and is authored by "Test User" one hour after the previous.
    /// The returned `TempDir` must stay alive for as long as the repository is used.
    fn create_test_repo_with_commits() -> (TempDir, String) {
        let temp_dir = TempDir::new().unwrap();
        let repo_path = temp_dir.path().to_str().unwrap().to_string();

        let repo = git2::Repository::init(&repo_path).unwrap();

        for i in 1..=5 {
            let file_path = temp_dir.path().join(format!("test{i}.txt"));
            fs::write(&file_path, format!("test content {i}")).unwrap();

            let mut index = repo.index().unwrap();
            index
                .add_path(std::path::Path::new(&format!("test{i}.txt")))
                .unwrap();
            index.write().unwrap();

            let tree_id = index.write_tree().unwrap();
            let tree = repo.find_tree(tree_id).unwrap();

            let sig = git2::Signature::new(
                "Test User",
                "test@example.com",
                &git2::Time::new(1234567890 + i as i64 * 3600, 0),
            )
            .unwrap();

            let parents = if i == 1 {
                vec![]
            } else {
                let head = repo.head().unwrap();
                let parent_commit = head.peel_to_commit().unwrap();
                vec![parent_commit]
            };

            repo.commit(
                Some("HEAD"),
                &sig,
                &sig,
                &format!("Commit {i}"),
                &tree,
                &parents.iter().collect::<Vec<_>>(),
            )
            .unwrap();
        }

        (temp_dir, repo_path)
    }

    #[test]
    fn test_parse_range_input_valid() {
        assert_eq!(parse_range_input("5-11").unwrap(), (5, 11));
    }

    #[test]
    fn test_parse_range_input_with_spaces() {
        assert_eq!(parse_range_input(" 3 - 8 ").unwrap(), (3, 8));
    }

    #[test]
    fn test_parse_range_input_single_commit() {
        assert_eq!(parse_range_input("4-4").unwrap(), (4, 4));
    }

    #[test]
    fn test_parse_range_input_invalid_format() {
        assert!(parse_range_input("5").is_err());
        assert!(parse_range_input("5-11-15").is_err());
        assert!(parse_range_input("abc-def").is_err());
        assert!(parse_range_input("").is_err());
    }

    #[test]
    fn test_parse_range_input_invalid_range() {
        assert!(parse_range_input("11-5").is_err());
        assert!(parse_range_input("0-5").is_err());
    }

    #[test]
    fn test_generate_range_timestamps() {
        let start = ts("2023-01-01 00:00:00");
        let end = ts("2023-01-01 10:00:00");

        let timestamps = generate_range_timestamps(start, end, 5);

        assert_eq!(timestamps.len(), 5);
        assert_eq!(timestamps[0], start);
        assert_eq!(timestamps[4], end);

        // 10 hours over 4 gaps: every neighbouring pair must be exactly 2.5 hours apart.
        for pair in timestamps.windows(2) {
            assert_eq!(
                pair[1].signed_duration_since(pair[0]),
                chrono::Duration::minutes(150)
            );
        }
    }

    #[test]
    fn test_generate_range_timestamps_edge_cases() {
        let start = ts("2023-01-01 00:00:00");
        let end = ts("2023-01-01 10:00:00");

        // Zero count
        assert!(generate_range_timestamps(start, end, 0).is_empty());

        // Single timestamp
        assert_eq!(generate_range_timestamps(start, end, 1), vec![start]);
    }

    #[test]
    fn test_rewrite_range_commits_with_repo() {
        let (_temp_dir, repo_path) = create_test_repo_with_commits();
        let args = test_args(repo_path);

        let commits = get_commit_history(&args, false).unwrap();
        assert_eq!(commits.len(), 5);

        // 0-based range that fits the history.
        let (start, end) = (0, 2);
        assert!(start <= end);
        assert!(end < commits.len());

        let timestamps = generate_range_timestamps(
            ts("2023-01-01 00:00:00"),
            ts("2023-01-01 10:00:00"),
            end - start + 1,
        );
        assert_eq!(timestamps.len(), 3);
    }

    #[test]
    fn test_apply_range_changes_rewrites_three_commits() {
        let (_temp_dir, repo_path) = create_test_repo_with_commits();
        let args = test_args(repo_path);

        let commits = get_commit_history(&args, false).unwrap();
        assert_eq!(commits.len(), 5);

        let (start_idx, end_idx) = (1, 3);
        let timestamps = generate_range_timestamps(
            ts("2023-01-01 00:00:00"),
            ts("2023-01-01 10:00:00"),
            end_idx - start_idx + 1,
        );
        apply_range_changes(
            &args,
            &commits,
            start_idx,
            end_idx,
            "New User",
            "new@example.com",
            &timestamps,
        )
        .unwrap();

        // History length is unchanged; exactly three commits now carry the new author and
        // the remaining two keep the original one. (Counted, so the check does not depend
        // on the order `get_commit_history` returns commits in.)
        let after = get_commit_history(&args, false).unwrap();
        assert_eq!(after.len(), 5);
        let new_author = after
            .iter()
            .filter(|c| c.author_name == "New User" && c.author_email == "new@example.com")
            .count();
        let old_author = after
            .iter()
            .filter(|c| c.author_name == "Test User" && c.author_email == "test@example.com")
            .count();
        assert_eq!(new_author, 3);
        assert_eq!(old_author, 2);
    }
}