1use crate::cli::UpstreamAction;
2use crate::command::Command;
3use crate::core::git::{GitOperations, RemoteOperations};
4use crate::{GitXError, Result};
5use std::collections::HashMap;
6use std::process::Command as StdCommand;
7
8pub fn run(action: UpstreamAction) -> Result<()> {
9 let cmd = UpstreamCommand;
10 cmd.execute(action)
11}
12
13pub struct UpstreamCommand;
15
16impl Command for UpstreamCommand {
17 type Input = UpstreamAction;
18 type Output = ();
19
20 fn execute(&self, action: UpstreamAction) -> Result<()> {
21 run_upstream(action)
22 }
23
24 fn name(&self) -> &'static str {
25 "upstream"
26 }
27
28 fn description(&self) -> &'static str {
29 "Manage upstream tracking for branches"
30 }
31}
32
33fn run_upstream(action: UpstreamAction) -> Result<()> {
34 match action {
35 UpstreamAction::Set { upstream } => set_upstream(upstream),
36 UpstreamAction::Status => show_upstream_status(),
37 UpstreamAction::SyncAll { dry_run, merge } => sync_all_branches(dry_run, merge),
38 }
39}
40
41fn set_upstream(upstream: String) -> Result<()> {
42 validate_upstream_format(&upstream).map_err(|e| GitXError::GitCommand(e.to_string()))?;
44
45 validate_upstream_exists(&upstream).map_err(|e| GitXError::GitCommand(e.to_string()))?;
47
48 let current_branch = GitOperations::current_branch()
50 .map_err(|e| GitXError::GitCommand(format!("Failed to get current branch: {e}")))?;
51
52 println!(
53 "๐ Setting upstream for '{}' to '{}'...",
54 ¤t_branch, &upstream
55 );
56
57 RemoteOperations::set_upstream(¤t_branch, &upstream)
59 .map_err(|e| GitXError::GitCommand(e.to_string()))?;
60
61 println!(
62 "โ
Upstream for '{}' set to '{}'",
63 ¤t_branch, &upstream
64 );
65 Ok(())
66}
67
68fn show_upstream_status() -> Result<()> {
69 let branches = GitOperations::local_branches()
71 .map_err(|e| GitXError::GitCommand(format!("Failed to get local branches: {e}")))?;
72
73 if branches.is_empty() {
74 println!("โน๏ธ No local branches found");
75 return Ok(());
76 }
77
78 let mut branch_upstreams = HashMap::new();
80 for branch in &branches {
81 if let Ok(upstream) = get_branch_upstream(branch) {
82 branch_upstreams.insert(branch.clone(), Some(upstream));
83 } else {
84 branch_upstreams.insert(branch.clone(), None);
85 }
86 }
87
88 let current_branch = GitOperations::current_branch().unwrap_or_default();
90
91 println!("๐ Upstream status for all branches:\n");
92
93 for branch in &branches {
94 let is_current = branch == ¤t_branch;
95 let upstream = branch_upstreams.get(branch).unwrap();
96
97 match upstream {
98 Some(upstream_ref) => {
99 let sync_status =
101 get_branch_sync_status(branch, upstream_ref).unwrap_or(SyncStatus::Unknown);
102
103 println!(
104 "{}",
105 format_branch_with_upstream(branch, upstream_ref, &sync_status, is_current)
106 );
107 }
108 None => {
109 println!("{}", format_branch_without_upstream(branch, is_current));
110 }
111 }
112 }
113 Ok(())
114}
115
116fn sync_all_branches(dry_run: bool, merge: bool) -> Result<()> {
117 let branches =
119 get_branches_with_upstreams().map_err(|e| GitXError::GitCommand(e.to_string()))?;
120
121 if branches.is_empty() {
122 println!("{}", format_no_upstream_branches_message());
123 return Ok(());
124 }
125
126 println!(
127 "{}",
128 format_sync_all_start_message(branches.len(), dry_run, merge)
129 );
130
131 let mut sync_results = Vec::new();
132
133 for (branch, upstream) in &branches {
134 let sync_status = match get_branch_sync_status(branch, upstream) {
135 Ok(status) => status,
136 Err(_) => {
137 sync_results.push((
138 branch.clone(),
139 SyncResult::Error("Failed to get sync status".to_string()),
140 ));
141 continue;
142 }
143 };
144
145 match sync_status {
146 SyncStatus::UpToDate => {
147 sync_results.push((branch.clone(), SyncResult::UpToDate));
148 }
149 SyncStatus::Behind(_) | SyncStatus::Diverged(_, _) => {
150 if dry_run {
151 sync_results.push((branch.clone(), SyncResult::WouldSync));
152 } else {
153 match sync_branch_with_upstream(branch, upstream, merge) {
154 Ok(()) => sync_results.push((branch.clone(), SyncResult::Synced)),
155 Err(msg) => {
156 sync_results.push((branch.clone(), SyncResult::Error(msg.to_string())))
157 }
158 }
159 }
160 }
161 SyncStatus::Ahead(_) => {
162 sync_results.push((branch.clone(), SyncResult::Ahead));
163 }
164 SyncStatus::Unknown => {
165 sync_results.push((
166 branch.clone(),
167 SyncResult::Error("Unknown sync status".to_string()),
168 ));
169 }
170 }
171 }
172
173 println!("{}", format_sync_results_header());
175 for (branch, result) in &sync_results {
176 println!("{}", format_sync_result_line(branch, result));
177 }
178
179 let synced_count = sync_results
181 .iter()
182 .filter(|(_, r)| matches!(r, SyncResult::Synced | SyncResult::WouldSync))
183 .count();
184 println!("{}", format_sync_summary(synced_count, dry_run));
185 Ok(())
186}
187
/// How a local branch relates to its upstream.
///
/// Counts are numbers of commits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SyncStatus {
    /// Neither side has commits the other lacks.
    UpToDate,
    /// The upstream has this many commits the branch lacks.
    Behind(u32),
    /// The branch has this many commits the upstream lacks.
    Ahead(u32),
    /// Both sides have commits the other lacks: `(behind, ahead)`.
    Diverged(u32, u32),
    /// The relationship could not be determined.
    Unknown,
}
196
/// Outcome recorded for one branch during a sync-all run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SyncResult {
    /// The branch already matched its upstream.
    UpToDate,
    /// The branch was merged or rebased successfully.
    Synced,
    /// Dry run: the branch needs syncing but was left untouched.
    WouldSync,
    /// The branch is only ahead of its upstream and was skipped.
    Ahead,
    /// Syncing was not possible; the payload is the message to display.
    Error(String),
}
205
206pub fn validate_upstream_format(upstream: &str) -> Result<()> {
207 if upstream.is_empty() {
208 return Err(GitXError::GitCommand(
209 "Upstream cannot be empty".to_string(),
210 ));
211 }
212
213 if !upstream.contains('/') {
214 return Err(GitXError::GitCommand(
215 "Upstream must be in format 'remote/branch' (e.g., origin/main)".to_string(),
216 ));
217 }
218
219 let parts: Vec<&str> = upstream.split('/').collect();
220 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
221 return Err(GitXError::GitCommand(
222 "Invalid upstream format. Use 'remote/branch' format".to_string(),
223 ));
224 }
225
226 Ok(())
227}
228
229pub fn validate_upstream_exists(upstream: &str) -> Result<()> {
230 let output = StdCommand::new("git")
231 .args(["rev-parse", "--verify", upstream])
232 .output()
233 .map_err(GitXError::Io)?;
234
235 if !output.status.success() {
236 return Err(GitXError::GitCommand(
237 "Upstream branch does not exist".to_string(),
238 ));
239 }
240
241 Ok(())
242}
243
244pub fn get_all_local_branches() -> Result<Vec<String>> {
245 let output = StdCommand::new("git")
246 .args(["branch", "--format=%(refname:short)"])
247 .output()
248 .map_err(GitXError::Io)?;
249
250 if !output.status.success() {
251 return Err(GitXError::GitCommand(
252 "Failed to list local branches".to_string(),
253 ));
254 }
255
256 let stdout = String::from_utf8_lossy(&output.stdout);
257 let branches: Vec<String> = stdout
258 .lines()
259 .map(|line| line.trim().to_string())
260 .filter(|line| !line.is_empty())
261 .collect();
262
263 Ok(branches)
264}
265
266pub fn get_branch_upstream(branch: &str) -> Result<String> {
267 let output = StdCommand::new("git")
268 .args(["rev-parse", "--abbrev-ref", &format!("{branch}@{{u}}")])
269 .output()
270 .map_err(GitXError::Io)?;
271
272 if !output.status.success() {
273 return Err(GitXError::GitCommand("No upstream configured".to_string()));
274 }
275
276 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
277}
278
279pub fn get_branch_sync_status(branch: &str, upstream: &str) -> Result<SyncStatus> {
280 let output = StdCommand::new("git")
281 .args([
282 "rev-list",
283 "--left-right",
284 "--count",
285 &format!("{upstream}...{branch}"),
286 ])
287 .output()
288 .map_err(GitXError::Io)?;
289
290 if !output.status.success() {
291 return Err(GitXError::GitCommand(
292 "Failed to compare with upstream".to_string(),
293 ));
294 }
295
296 let counts = String::from_utf8_lossy(&output.stdout);
297 let mut parts = counts.split_whitespace();
298
299 let behind: u32 = parts
300 .next()
301 .and_then(|s| s.parse().ok())
302 .ok_or_else(|| GitXError::Parse("Invalid sync count format".to_string()))?;
303
304 let ahead: u32 = parts
305 .next()
306 .and_then(|s| s.parse().ok())
307 .ok_or_else(|| GitXError::Parse("Invalid sync count format".to_string()))?;
308
309 Ok(match (behind, ahead) {
310 (0, 0) => SyncStatus::UpToDate,
311 (b, 0) if b > 0 => SyncStatus::Behind(b),
312 (0, a) if a > 0 => SyncStatus::Ahead(a),
313 (b, a) if b > 0 && a > 0 => SyncStatus::Diverged(b, a),
314 _ => SyncStatus::Unknown,
315 })
316}
317
318pub fn get_branches_with_upstreams() -> Result<Vec<(String, String)>> {
319 let branches = get_all_local_branches()?;
320 let mut result = Vec::new();
321
322 for branch in branches {
323 if let Ok(upstream) = get_branch_upstream(&branch) {
324 result.push((branch, upstream));
325 }
326 }
327
328 Ok(result)
329}
330
331fn sync_branch_with_upstream(branch: &str, upstream: &str, merge: bool) -> Result<()> {
332 let status = StdCommand::new("git")
334 .args(["checkout", branch])
335 .status()
336 .map_err(GitXError::Io)?;
337
338 if !status.success() {
339 return Err(GitXError::GitCommand(
340 "Failed to checkout branch".to_string(),
341 ));
342 }
343
344 let args = if merge {
346 ["merge", upstream]
347 } else {
348 ["rebase", upstream]
349 };
350
351 let status = StdCommand::new("git")
352 .args(args)
353 .status()
354 .map_err(GitXError::Io)?;
355
356 if !status.success() {
357 return Err(if merge {
358 GitXError::GitCommand("Merge failed".to_string())
359 } else {
360 GitXError::GitCommand("Rebase failed".to_string())
361 });
362 }
363
364 Ok(())
365}
366
/// Builds the single-line confirmation shown after `branch` has been pointed
/// at `upstream`.
pub fn format_upstream_set_message(branch: &str, upstream: &str) -> String {
    format!("✅ Upstream for '{branch}' set to '{upstream}'")
}
370
371pub fn format_branch_with_upstream(
372 branch: &str,
373 upstream: &str,
374 sync_status: &SyncStatus,
375 is_current: bool,
376) -> String {
377 let current_indicator = if is_current { "* " } else { " " };
378 let status_text = match sync_status {
379 SyncStatus::UpToDate => "โ
up-to-date",
380 SyncStatus::Behind(n) => &format!("โฌ๏ธ {n} behind"),
381 SyncStatus::Ahead(n) => &format!("โฌ๏ธ {n} ahead"),
382 SyncStatus::Diverged(b, a) => &format!("๐ {b} behind, {a} ahead"),
383 SyncStatus::Unknown => "โ unknown",
384 };
385
386 format!("{current_indicator}{branch} -> {upstream} ({status_text})")
387}
388
/// Renders one status line for a branch that tracks no upstream.
///
/// The current branch is prefixed with `"* "`, every other branch with `" "`.
pub fn format_branch_without_upstream(branch: &str, is_current: bool) -> String {
    let mut line = String::from(if is_current { "* " } else { " " });
    line.push_str(branch);
    line.push_str(" -> (no upstream)");
    line
}
393
/// Message shown when no local branch has an upstream configured.
pub fn format_no_upstream_branches_message() -> &'static str {
    "ℹ️ No branches with upstream configuration found"
}
397
/// Builds the banner printed before a sync-all run.
///
/// `count` is the number of branches to process; `merge` selects the word
/// "merge" over "rebase"; `dry_run` switches to the "(dry run) Would sync"
/// wording.
pub fn format_sync_all_start_message(count: usize, dry_run: bool, merge: bool) -> String {
    let action = if merge { "merge" } else { "rebase" };
    if dry_run {
        format!("🧪 (dry run) Would sync {count} branch(es) with upstream using {action}:")
    } else {
        format!("🔄 Syncing {count} branch(es) with upstream using {action}:")
    }
}
406
/// Heading printed above the per-branch sync results (starts with a blank line).
pub fn format_sync_results_header() -> &'static str {
    "\n📊 Sync results:"
}
410
411pub fn format_sync_result_line(branch: &str, result: &SyncResult) -> String {
412 match result {
413 SyncResult::UpToDate => format!(" โ
{branch}: already up-to-date"),
414 SyncResult::Synced => format!(" โ
{branch}: synced successfully"),
415 SyncResult::WouldSync => format!(" ๐ {branch}: would be synced"),
416 SyncResult::Ahead => format!(" โฌ๏ธ {branch}: ahead of upstream (skipped)"),
417 SyncResult::Error(msg) => format!(" โ {branch}: {msg}"),
418 }
419}
420
/// Builds the closing summary of a sync-all run (starts with a blank line).
///
/// With `dry_run` the text says how many branches would be synced and how to
/// apply the changes; otherwise it reports how many branches were synced.
pub fn format_sync_summary(synced_count: usize, dry_run: bool) -> String {
    if dry_run {
        format!(
            "\n💡 Would sync {synced_count} branch(es). Run without --dry-run to apply changes."
        )
    } else {
        format!("\n✅ Synced {synced_count} branch(es) successfully.")
    }
}