git_x/
upstream.rs

1use crate::cli::UpstreamAction;
2use crate::command::Command;
3use crate::core::git::{GitOperations, RemoteOperations};
4use crate::{GitXError, Result};
5use std::collections::HashMap;
6use std::process::Command as StdCommand;
7
8pub fn run(action: UpstreamAction) -> Result<()> {
9    let cmd = UpstreamCommand;
10    cmd.execute(action)
11}
12
13/// Command implementation for git upstream
14pub struct UpstreamCommand;
15
16impl Command for UpstreamCommand {
17    type Input = UpstreamAction;
18    type Output = ();
19
20    fn execute(&self, action: UpstreamAction) -> Result<()> {
21        run_upstream(action)
22    }
23
24    fn name(&self) -> &'static str {
25        "upstream"
26    }
27
28    fn description(&self) -> &'static str {
29        "Manage upstream tracking for branches"
30    }
31}
32
33fn run_upstream(action: UpstreamAction) -> Result<()> {
34    match action {
35        UpstreamAction::Set { upstream } => set_upstream(upstream),
36        UpstreamAction::Status => show_upstream_status(),
37        UpstreamAction::SyncAll { dry_run, merge } => sync_all_branches(dry_run, merge),
38    }
39}
40
41fn set_upstream(upstream: String) -> Result<()> {
42    // Validate upstream format
43    validate_upstream_format(&upstream).map_err(|e| GitXError::GitCommand(e.to_string()))?;
44
45    // Check if upstream exists
46    validate_upstream_exists(&upstream).map_err(|e| GitXError::GitCommand(e.to_string()))?;
47
48    // Get current branch
49    let current_branch = GitOperations::current_branch()
50        .map_err(|e| GitXError::GitCommand(format!("Failed to get current branch: {e}")))?;
51
52    println!(
53        "๐Ÿ”— Setting upstream for '{}' to '{}'...",
54        &current_branch, &upstream
55    );
56
57    // Set upstream
58    RemoteOperations::set_upstream(&current_branch, &upstream)
59        .map_err(|e| GitXError::GitCommand(e.to_string()))?;
60
61    println!(
62        "โœ… Upstream for '{}' set to '{}'",
63        &current_branch, &upstream
64    );
65    Ok(())
66}
67
68fn show_upstream_status() -> Result<()> {
69    // Get all local branches
70    let branches = GitOperations::local_branches()
71        .map_err(|e| GitXError::GitCommand(format!("Failed to get local branches: {e}")))?;
72
73    if branches.is_empty() {
74        println!("โ„น๏ธ No local branches found");
75        return Ok(());
76    }
77
78    // Get upstream info for each branch
79    let mut branch_upstreams = HashMap::new();
80    for branch in &branches {
81        if let Ok(upstream) = get_branch_upstream(branch) {
82            branch_upstreams.insert(branch.clone(), Some(upstream));
83        } else {
84            branch_upstreams.insert(branch.clone(), None);
85        }
86    }
87
88    // Get current branch for highlighting
89    let current_branch = GitOperations::current_branch().unwrap_or_default();
90
91    println!("๐Ÿ”— Upstream status for all branches:\n");
92
93    for branch in &branches {
94        let is_current = branch == &current_branch;
95        let upstream = branch_upstreams.get(branch).unwrap();
96
97        match upstream {
98            Some(upstream_ref) => {
99                // Check sync status
100                let sync_status =
101                    get_branch_sync_status(branch, upstream_ref).unwrap_or(SyncStatus::Unknown);
102
103                println!(
104                    "{}",
105                    format_branch_with_upstream(branch, upstream_ref, &sync_status, is_current)
106                );
107            }
108            None => {
109                println!("{}", format_branch_without_upstream(branch, is_current));
110            }
111        }
112    }
113    Ok(())
114}
115
116fn sync_all_branches(dry_run: bool, merge: bool) -> Result<()> {
117    // Get all branches with upstreams
118    let branches =
119        get_branches_with_upstreams().map_err(|e| GitXError::GitCommand(e.to_string()))?;
120
121    if branches.is_empty() {
122        println!("{}", format_no_upstream_branches_message());
123        return Ok(());
124    }
125
126    println!(
127        "{}",
128        format_sync_all_start_message(branches.len(), dry_run, merge)
129    );
130
131    let mut sync_results = Vec::new();
132
133    for (branch, upstream) in &branches {
134        let sync_status = match get_branch_sync_status(branch, upstream) {
135            Ok(status) => status,
136            Err(_) => {
137                sync_results.push((
138                    branch.clone(),
139                    SyncResult::Error("Failed to get sync status".to_string()),
140                ));
141                continue;
142            }
143        };
144
145        match sync_status {
146            SyncStatus::UpToDate => {
147                sync_results.push((branch.clone(), SyncResult::UpToDate));
148            }
149            SyncStatus::Behind(_) | SyncStatus::Diverged(_, _) => {
150                if dry_run {
151                    sync_results.push((branch.clone(), SyncResult::WouldSync));
152                } else {
153                    match sync_branch_with_upstream(branch, upstream, merge) {
154                        Ok(()) => sync_results.push((branch.clone(), SyncResult::Synced)),
155                        Err(msg) => {
156                            sync_results.push((branch.clone(), SyncResult::Error(msg.to_string())))
157                        }
158                    }
159                }
160            }
161            SyncStatus::Ahead(_) => {
162                sync_results.push((branch.clone(), SyncResult::Ahead));
163            }
164            SyncStatus::Unknown => {
165                sync_results.push((
166                    branch.clone(),
167                    SyncResult::Error("Unknown sync status".to_string()),
168                ));
169            }
170        }
171    }
172
173    // Print results
174    println!("{}", format_sync_results_header());
175    for (branch, result) in &sync_results {
176        println!("{}", format_sync_result_line(branch, result));
177    }
178
179    // Print summary
180    let synced_count = sync_results
181        .iter()
182        .filter(|(_, r)| matches!(r, SyncResult::Synced | SyncResult::WouldSync))
183        .count();
184    println!("{}", format_sync_summary(synced_count, dry_run));
185    Ok(())
186}
187
/// How a local branch relates to its upstream, in commit counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SyncStatus {
    /// Neither side has commits the other lacks.
    UpToDate,
    /// The branch lacks this many upstream commits.
    Behind(u32),
    /// The branch has this many commits the upstream lacks.
    Ahead(u32),
    /// Both sides have unique commits: `(behind, ahead)`.
    Diverged(u32, u32),
    /// The relationship could not be determined.
    Unknown,
}
196
/// Outcome recorded for one branch during a sync-all run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SyncResult {
    /// The branch already matched its upstream.
    UpToDate,
    /// The branch was synced with its upstream.
    Synced,
    /// Dry run: the branch would have been synced.
    WouldSync,
    /// The branch is ahead of its upstream and was left alone.
    Ahead,
    /// Syncing failed; the payload is the message shown to the user.
    Error(String),
}
205
206pub fn validate_upstream_format(upstream: &str) -> Result<()> {
207    if upstream.is_empty() {
208        return Err(GitXError::GitCommand(
209            "Upstream cannot be empty".to_string(),
210        ));
211    }
212
213    if !upstream.contains('/') {
214        return Err(GitXError::GitCommand(
215            "Upstream must be in format 'remote/branch' (e.g., origin/main)".to_string(),
216        ));
217    }
218
219    let parts: Vec<&str> = upstream.split('/').collect();
220    if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
221        return Err(GitXError::GitCommand(
222            "Invalid upstream format. Use 'remote/branch' format".to_string(),
223        ));
224    }
225
226    Ok(())
227}
228
229pub fn validate_upstream_exists(upstream: &str) -> Result<()> {
230    let output = StdCommand::new("git")
231        .args(["rev-parse", "--verify", upstream])
232        .output()
233        .map_err(GitXError::Io)?;
234
235    if !output.status.success() {
236        return Err(GitXError::GitCommand(
237            "Upstream branch does not exist".to_string(),
238        ));
239    }
240
241    Ok(())
242}
243
244pub fn get_all_local_branches() -> Result<Vec<String>> {
245    let output = StdCommand::new("git")
246        .args(["branch", "--format=%(refname:short)"])
247        .output()
248        .map_err(GitXError::Io)?;
249
250    if !output.status.success() {
251        return Err(GitXError::GitCommand(
252            "Failed to list local branches".to_string(),
253        ));
254    }
255
256    let stdout = String::from_utf8_lossy(&output.stdout);
257    let branches: Vec<String> = stdout
258        .lines()
259        .map(|line| line.trim().to_string())
260        .filter(|line| !line.is_empty())
261        .collect();
262
263    Ok(branches)
264}
265
266pub fn get_branch_upstream(branch: &str) -> Result<String> {
267    let output = StdCommand::new("git")
268        .args(["rev-parse", "--abbrev-ref", &format!("{branch}@{{u}}")])
269        .output()
270        .map_err(GitXError::Io)?;
271
272    if !output.status.success() {
273        return Err(GitXError::GitCommand("No upstream configured".to_string()));
274    }
275
276    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
277}
278
279pub fn get_branch_sync_status(branch: &str, upstream: &str) -> Result<SyncStatus> {
280    let output = StdCommand::new("git")
281        .args([
282            "rev-list",
283            "--left-right",
284            "--count",
285            &format!("{upstream}...{branch}"),
286        ])
287        .output()
288        .map_err(GitXError::Io)?;
289
290    if !output.status.success() {
291        return Err(GitXError::GitCommand(
292            "Failed to compare with upstream".to_string(),
293        ));
294    }
295
296    let counts = String::from_utf8_lossy(&output.stdout);
297    let mut parts = counts.split_whitespace();
298
299    let behind: u32 = parts
300        .next()
301        .and_then(|s| s.parse().ok())
302        .ok_or_else(|| GitXError::Parse("Invalid sync count format".to_string()))?;
303
304    let ahead: u32 = parts
305        .next()
306        .and_then(|s| s.parse().ok())
307        .ok_or_else(|| GitXError::Parse("Invalid sync count format".to_string()))?;
308
309    Ok(match (behind, ahead) {
310        (0, 0) => SyncStatus::UpToDate,
311        (b, 0) if b > 0 => SyncStatus::Behind(b),
312        (0, a) if a > 0 => SyncStatus::Ahead(a),
313        (b, a) if b > 0 && a > 0 => SyncStatus::Diverged(b, a),
314        _ => SyncStatus::Unknown,
315    })
316}
317
318pub fn get_branches_with_upstreams() -> Result<Vec<(String, String)>> {
319    let branches = get_all_local_branches()?;
320    let mut result = Vec::new();
321
322    for branch in branches {
323        if let Ok(upstream) = get_branch_upstream(&branch) {
324            result.push((branch, upstream));
325        }
326    }
327
328    Ok(result)
329}
330
331fn sync_branch_with_upstream(branch: &str, upstream: &str, merge: bool) -> Result<()> {
332    // Switch to the branch first
333    let status = StdCommand::new("git")
334        .args(["checkout", branch])
335        .status()
336        .map_err(GitXError::Io)?;
337
338    if !status.success() {
339        return Err(GitXError::GitCommand(
340            "Failed to checkout branch".to_string(),
341        ));
342    }
343
344    // Sync with upstream
345    let args = if merge {
346        ["merge", upstream]
347    } else {
348        ["rebase", upstream]
349    };
350
351    let status = StdCommand::new("git")
352        .args(args)
353        .status()
354        .map_err(GitXError::Io)?;
355
356    if !status.success() {
357        return Err(if merge {
358            GitXError::GitCommand("Merge failed".to_string())
359        } else {
360            GitXError::GitCommand("Rebase failed".to_string())
361        });
362    }
363
364    Ok(())
365}
366
/// Builds the confirmation line shown after an upstream has been set.
///
/// Both `branch` and `upstream` appear single-quoted in the message.
pub fn format_upstream_set_message(branch: &str, upstream: &str) -> String {
    format!("โœ… Upstream for '{}' set to '{}'", branch, upstream)
}
370
371pub fn format_branch_with_upstream(
372    branch: &str,
373    upstream: &str,
374    sync_status: &SyncStatus,
375    is_current: bool,
376) -> String {
377    let current_indicator = if is_current { "* " } else { "  " };
378    let status_text = match sync_status {
379        SyncStatus::UpToDate => "โœ… up-to-date",
380        SyncStatus::Behind(n) => &format!("โฌ‡๏ธ {n} behind"),
381        SyncStatus::Ahead(n) => &format!("โฌ†๏ธ {n} ahead"),
382        SyncStatus::Diverged(b, a) => &format!("๐Ÿ”€ {b} behind, {a} ahead"),
383        SyncStatus::Unknown => "โ“ unknown",
384    };
385
386    format!("{current_indicator}{branch} -> {upstream} ({status_text})")
387}
388
/// Formats one status line for a branch that has no upstream.
///
/// The current branch is prefixed with `"* "`, any other with two spaces.
pub fn format_branch_without_upstream(branch: &str, is_current: bool) -> String {
    let marker = match is_current {
        true => "* ",
        false => "  ",
    };
    format!("{}{} -> (no upstream)", marker, branch)
}
393
/// Message shown when no local branch has an upstream configured.
pub fn format_no_upstream_branches_message() -> &'static str {
    const MESSAGE: &str = "โ„น๏ธ No branches with upstream configuration found";
    MESSAGE
}
397
/// Builds the banner printed before a sync-all run.
///
/// * `count` - number of branches that will be examined.
/// * `dry_run` - selects the "(dry run) Would sync" wording.
/// * `merge` - names the strategy as `merge` when true, `rebase` otherwise.
pub fn format_sync_all_start_message(count: usize, dry_run: bool, merge: bool) -> String {
    let action = if merge { "merge" } else { "rebase" };

    if !dry_run {
        return format!("๐Ÿ”„ Syncing {count} branch(es) with upstream using {action}:");
    }

    format!("๐Ÿงช (dry run) Would sync {count} branch(es) with upstream using {action}:")
}
406
/// Heading printed above the per-branch sync results; starts with a newline.
pub fn format_sync_results_header() -> &'static str {
    const HEADER: &str = "\n๐Ÿ“Š Sync results:";
    HEADER
}
410
411pub fn format_sync_result_line(branch: &str, result: &SyncResult) -> String {
412    match result {
413        SyncResult::UpToDate => format!("  โœ… {branch}: already up-to-date"),
414        SyncResult::Synced => format!("  โœ… {branch}: synced successfully"),
415        SyncResult::WouldSync => format!("  ๐Ÿ”„ {branch}: would be synced"),
416        SyncResult::Ahead => format!("  โฌ†๏ธ {branch}: ahead of upstream (skipped)"),
417        SyncResult::Error(msg) => format!("  โŒ {branch}: {msg}"),
418    }
419}
420
/// Builds the closing summary of a sync-all run; starts with a newline.
///
/// * `synced_count` - number of branches synced (or that would be synced).
/// * `dry_run` - selects the "Would sync" wording and the `--dry-run` hint.
pub fn format_sync_summary(synced_count: usize, dry_run: bool) -> String {
    if !dry_run {
        return format!("\nโœ… Synced {synced_count} branch(es) successfully.");
    }

    format!("\n๐Ÿ’ก Would sync {synced_count} branch(es). Run without --dry-run to apply changes.")
}