git_branchless_submit/
github.rs

1//! GitHub backend for submitting patch stacks.
2
3use std::collections::HashMap;
4use std::env;
5use std::fmt::{Debug, Write};
6use std::hash::Hash;
7
8use cursive_core::theme::Effect;
9use cursive_core::utils::markup::StyledString;
10use indexmap::IndexMap;
11use itertools::Itertools;
12use lib::core::config::get_main_branch_name;
13use lib::core::dag::CommitSet;
14use lib::core::dag::Dag;
15use lib::core::effects::Effects;
16use lib::core::effects::OperationType;
17use lib::core::eventlog::EventLogDb;
18use lib::core::repo_ext::RepoExt;
19use lib::core::repo_ext::RepoReferencesSnapshot;
20use lib::git::CategorizedReferenceName;
21use lib::git::GitErrorCode;
22use lib::git::GitRunInfo;
23use lib::git::RepoError;
24use lib::git::{BranchType, ConfigRead};
25use lib::git::{NonZeroOid, Repo};
26use lib::try_exit_code;
27use lib::util::ExitCode;
28use lib::util::EyreExitOr;
29
30use tracing::debug;
31use tracing::instrument;
32use tracing::warn;
33
34use crate::branch_forge::BranchForge;
35use crate::SubmitStatus;
36use crate::{CommitStatus, CreateStatus, Forge, SubmitOptions};
37
/// Testing environment variable. When this is set, the executable will use the
/// mock GitHub implementation. This should be set to the path of an existing
/// repository that represents the remote/GitHub.
///
/// The variable is read by [`GithubForge::client`], which passes its value to
/// the mock client as the remote repository path.
pub const MOCK_REMOTE_REPO_PATH_ENV_KEY: &str = "BRANCHLESS_SUBMIT_GITHUB_MOCK_REMOTE_REPO_PATH";
42
/// Turn a commit summary into a string suitable for use in a branch name.
///
/// Alphanumeric characters are lowercased and kept; every other character
/// becomes a hyphen. Runs of hyphens are collapsed into one, and leading and
/// trailing hyphens are removed. If nothing is left, the placeholder
/// `to-review` is returned instead.
fn commit_summary_slug(summary: &str) -> String {
    let mut slug = String::with_capacity(summary.len());
    for c in summary.chars() {
        if c.is_alphanumeric() {
            // Lowercasing may expand one character into several.
            slug.extend(c.to_lowercase());
        } else if !slug.ends_with('-') {
            // Only emit a separator if the previous character was not one
            // already, so adjacent hyphens never accumulate.
            slug.push('-');
        }
    }

    match slug.trim_matches('-') {
        "" => "to-review".to_string(),
        trimmed => trimmed.to_owned(),
    }
}
60
61fn singleton<K: Debug + Eq + Hash, V: Clone>(
62    map: &HashMap<K, V>,
63    key: K,
64    f: impl Fn(V) -> V,
65) -> HashMap<K, V> {
66    let mut result = HashMap::new();
67    match map.get(&key) {
68        Some(value) => {
69            result.insert(key, f(value.clone()));
70        }
71        None => {
72            warn!(?key, "No match for key in map");
73        }
74    }
75    result
76}
77
78/// Get the name of the remote repository to push to in the course of creating
79/// pull requests.
80///
81/// NOTE: The `gh` command-line utility might infer the push remote if only one
82/// remote is available. This function only returns a remote if it is explicitly
83/// set by the user using `gh repo set-default`.`
84pub fn github_push_remote(repo: &Repo) -> eyre::Result<Option<String>> {
85    let config = repo.get_readonly_config()?;
86    for remote_name in repo.get_all_remote_names()? {
87        // This is set by `gh repo set-default`. Note that `gh` can
88        // sometimes infer which repo to push to without invoking
89        // `gh repo set-default` explicitly. We could probably check
90        // the remote URL to see if it's associated with
91        // `github.com`, but we would still need `gh` to be
92        // installed on the system in that case. The presence of
93        // this value means that `gh` was actively used.
94        //
95        // Possible values seem to be `base` and `other`. See:
96        // https://github.com/search?q=repo%3Acli%2Fcli%20gh-resolved&type=code
97        let gh_resolved: Option<String> =
98            config.get(format!("remote.{remote_name}.gh-resolved"))?;
99        if gh_resolved.as_deref() == Some("base") {
100            return Ok(Some(remote_name));
101        }
102    }
103    Ok(None)
104}
105
/// The [GitHub](https://en.wikipedia.org/wiki/GitHub) code hosting platform.
/// This forge integrates specifically with the `gh` command-line utility.
#[allow(missing_docs)]
#[derive(Debug)]
pub struct GithubForge<'a> {
    /// Used for user-facing output (output stream, glyph rendering) and for
    /// progress reporting while updating commits.
    pub effects: &'a Effects,
    /// Forwarded to the `BranchForge` that this forge delegates branch
    /// creation and pushing to.
    pub git_run_info: &'a GitRunInfo,
    /// The local repository: queried for branches, commits, remotes and
    /// reference snapshots.
    pub repo: &'a Repo,
    /// Forwarded to the `BranchForge` that this forge delegates to.
    pub event_log_db: &'a EventLogDb<'a>,
    /// Commit graph used to sort commits and to query stacks and ancestors.
    pub dag: &'a Dag,
    /// The client used to talk to GitHub. See `GithubForge::client`, which
    /// builds either the real or the mock implementation.
    pub client: Box<dyn client::GithubClient>,
}
118
119impl Forge for GithubForge<'_> {
120    #[instrument]
121    fn query_status(
122        &mut self,
123        commit_set: CommitSet,
124    ) -> EyreExitOr<HashMap<NonZeroOid, CommitStatus>> {
125        let effects = self.effects;
126        let pull_request_infos =
127            try_exit_code!(self.client.query_repo_pull_request_infos(effects)?);
128        let references_snapshot = self.repo.get_references_snapshot()?;
129
130        let mut result = HashMap::new();
131        for branch in self.repo.get_all_local_branches()? {
132            let local_branch_oid = match branch.get_oid()? {
133                Some(branch_oid) => branch_oid,
134                None => continue,
135            };
136            if !self.dag.set_contains(&commit_set, local_branch_oid)? {
137                continue;
138            }
139
140            let local_branch_name = branch.get_name()?;
141            let remote_name = branch.get_push_remote_name()?;
142            let remote_branch_name = branch.get_upstream_branch_name_without_push_remote_name()?;
143
144            let submit_status = match remote_branch_name
145                .as_ref()
146                .and_then(|remote_branch_name| pull_request_infos.get(remote_branch_name))
147            {
148                None => SubmitStatus::Unsubmitted,
149                Some(pull_request_info) => {
150                    let updated_pull_request_info = try_exit_code!(self
151                        .make_updated_pull_request_info(
152                            effects,
153                            &references_snapshot,
154                            &pull_request_infos,
155                            local_branch_oid
156                        )?);
157                    debug!(
158                        ?pull_request_info,
159                        ?updated_pull_request_info,
160                        "Comparing pull request info"
161                    );
162                    if updated_pull_request_info
163                        .fields_to_update(pull_request_info)
164                        .is_empty()
165                    {
166                        SubmitStatus::UpToDate
167                    } else {
168                        SubmitStatus::NeedsUpdate
169                    }
170                }
171            };
172            result.insert(
173                local_branch_oid,
174                CommitStatus {
175                    submit_status,
176                    remote_name,
177                    local_commit_name: Some(local_branch_name.to_owned()),
178                    remote_commit_name: remote_branch_name,
179                },
180            );
181        }
182
183        for commit_oid in self.dag.commit_set_to_vec(&commit_set)? {
184            result.entry(commit_oid).or_insert(CommitStatus {
185                submit_status: SubmitStatus::Unsubmitted,
186                remote_name: None,
187                local_commit_name: None,
188                remote_commit_name: None,
189            });
190        }
191
192        Ok(Ok(result))
193    }
194
    /// Create branches and pull requests for the unsubmitted commits in
    /// `commits`.
    ///
    /// Only commits whose status is `SubmitStatus::Unsubmitted` are acted on;
    /// all other statuses are skipped. The work happens in three phases:
    ///
    /// 1. For each commit, pick a local branch (the existing one, or a newly
    ///    created `{github_username}/{summary_slug}` branch) and hand it to
    ///    `BranchForge::create`.
    /// 2. Create a pull request for each commit via the GitHub client, using
    ///    the commit summary as the title and the commit message as the body.
    /// 3. Call `update` to refresh the pull requests that were just created.
    ///
    /// Returns the `CreateStatus` entries reported by `BranchForge::create`.
    /// Returns exit code 1 (after printing a message) if no push remote is
    /// configured, if no fresh branch name could be generated, or if a commit
    /// unexpectedly has no local branch name when its pull request is created.
    #[instrument]
    fn create(
        &mut self,
        commits: HashMap<NonZeroOid, CommitStatus>,
        options: &SubmitOptions,
    ) -> EyreExitOr<HashMap<NonZeroOid, CreateStatus>> {
        let effects = self.effects;
        let commit_oids = self.dag.sort(&commits.keys().copied().collect())?;

        let references_snapshot = self.repo.get_references_snapshot()?;
        let mut branch_forge = BranchForge {
            effects,
            git_run_info: self.git_run_info,
            dag: self.dag,
            repo: self.repo,
            event_log_db: self.event_log_db,
            references_snapshot: &references_snapshot,
        };
        // Prefer the remote explicitly selected through `gh`; otherwise fall
        // back to the repository's default push remote.
        let push_remote_name = match github_push_remote(self.repo)? {
            Some(remote_name) => remote_name,
            None => match self.repo.get_default_push_remote()? {
                Some(remote_name) => remote_name,
                None => {
                    writeln!(
                        effects.get_output_stream(),
                        "No default push repository configured. To configure, run: {}",
                        effects.get_glyphs().render(StyledString::styled(
                            "gh repo set-default <repo>",
                            Effect::Bold,
                        ))?
                    )?;
                    return Ok(Err(ExitCode(1)));
                }
            },
        };
        let github_username = try_exit_code!(self.client.query_github_username(effects)?);

        // Generate branches for all the commits to create.
        let commits_to_create = commit_oids
            .into_iter()
            // `commit_oids` was derived from the keys of `commits`, so the
            // lookup is expected to succeed.
            .map(|commit_oid| (commit_oid, commits.get(&commit_oid).unwrap()))
            .filter_map(
                |(commit_oid, commit_status)| match commit_status.submit_status {
                    SubmitStatus::Local
                    | SubmitStatus::Unknown
                    | SubmitStatus::NeedsUpdate
                    | SubmitStatus::UpToDate => None,
                    SubmitStatus::Unsubmitted => Some((commit_oid, commit_status)),
                },
            )
            .collect_vec();
        let mut created_branches = HashMap::new();
        for (commit_oid, commit_status) in commits_to_create.iter().copied() {
            let commit = self.repo.find_commit_or_fail(commit_oid)?;

            let local_branch_name = match &commit_status.local_commit_name {
                Some(local_branch_name) => local_branch_name.clone(),
                None => {
                    let summary = commit.get_summary()?;
                    let summary = String::from_utf8_lossy(&summary);
                    let summary_slug = commit_summary_slug(&summary);
                    let new_branch_name_base = format!("{github_username}/{summary_slug}");
                    let mut new_branch_name = new_branch_name_base.clone();
                    // `i` is the numeric suffix to try after a collision: the
                    // first iteration checks the unsuffixed name, and every
                    // collision moves on to `{base}-{i}`. The loop gives up
                    // once `i` exceeds 6, i.e. after the base name and the
                    // `-2` through `-5` variants were all found to exist.
                    // NOTE(review): the `-6` candidate is constructed but
                    // never checked before giving up — confirm whether that
                    // limit is intentional.
                    for i in 2.. {
                        if i > 6 {
                            writeln!(
                                effects.get_output_stream(),
                                "Could not generate fresh branch name for commit: {}",
                                effects
                                    .get_glyphs()
                                    .render(commit.friendly_describe(effects.get_glyphs())?)?,
                            )?;
                            return Ok(Err(ExitCode(1)));
                        }
                        match self.repo.find_branch(&new_branch_name, BranchType::Local)? {
                            Some(_) => {
                                new_branch_name = format!("{new_branch_name_base}-{i}");
                            }
                            None => break,
                        }
                    }
                    // An "already exists" failure is tolerated; any other
                    // error is propagated.
                    match self.repo.create_branch(&new_branch_name, &commit, false) {
                        Ok(_branch) => {}
                        Err(RepoError::CreateBranch { source, name: _ })
                            if source.code() == GitErrorCode::Exists => {}
                        Err(err) => return Err(err.into()),
                    };
                    new_branch_name
                }
            };

            // Delegate to the branch forge one commit at a time, telling it
            // which local branch name to use for this commit.
            let created_branch = try_exit_code!(branch_forge.create(
                singleton(&commits, commit_oid, |commit_status| CommitStatus {
                    local_commit_name: Some(local_branch_name.clone()),
                    ..commit_status.clone()
                }),
                options
            )?);
            created_branches.extend(created_branch.into_iter());
        }

        // Build the statuses to pass to `update` below. Commits for which the
        // branch forge reported a created branch get a fresh status; the rest
        // keep the status they came in with.
        let commit_statuses: HashMap<NonZeroOid, CommitStatus> = commits_to_create
            .iter()
            .copied()
            .map(|(commit_oid, commit_status)| {
                let commit_status = match created_branches.get(&commit_oid) {
                    Some(CreateStatus {
                        final_commit_oid: _,
                        local_commit_name,
                    }) => CommitStatus {
                        // To be updated below:
                        submit_status: SubmitStatus::NeedsUpdate,
                        remote_name: Some(push_remote_name.clone()),
                        local_commit_name: Some(local_commit_name.clone()),
                        // Expecting this to be the same as the local branch name (for now):
                        remote_commit_name: Some(local_commit_name.clone()),
                    },
                    None => commit_status.clone(),
                };
                (commit_oid, commit_status)
            })
            .collect();

        // Create the pull requests only after creating all the branches because
        // we rely on the presence of a branch on each commit in the stack to
        // know that it should be included/linked in the pull request body.
        // FIXME: is this actually necessary?
        for (commit_oid, _) in commits_to_create {
            let local_branch_name = match commit_statuses.get(&commit_oid) {
                Some(CommitStatus {
                    local_commit_name: Some(local_commit_name),
                    ..
                }) => local_commit_name,
                Some(CommitStatus {
                    local_commit_name: None,
                    ..
                })
                | None => {
                    writeln!(
                        effects.get_output_stream(),
                        "Could not find local branch name for commit: {}",
                        effects.get_glyphs().render(
                            self.repo
                                .find_commit_or_fail(commit_oid)?
                                .friendly_describe(effects.get_glyphs())?
                        )?
                    )?;
                    return Ok(Err(ExitCode(1)));
                }
            };

            // The initial title/body are the plain commit summary/message;
            // the `update` call below replaces them with the stack-aware
            // versions.
            let commit = self.repo.find_commit_or_fail(commit_oid)?;
            let title = String::from_utf8_lossy(&commit.get_summary()?).into_owned();
            let body = String::from_utf8_lossy(&commit.get_message_pretty()).into_owned();
            try_exit_code!(self.client.create_pull_request(
                effects,
                client::CreatePullRequestArgs {
                    head_ref_oid: commit_oid,
                    head_ref_name: local_branch_name.clone(),
                    title,
                    body,
                },
                options
            )?);
        }

        try_exit_code!(self.update(commit_statuses, options)?);

        Ok(Ok(created_branches))
    }
365
    /// Push the branches and refresh the pull request metadata for the given
    /// commits.
    ///
    /// Commits are processed in the order produced by `Dag::sort`. A commit is
    /// skipped with a logged warning if it has no remote branch name or if no
    /// pull request is known for that branch. For every remaining commit the
    /// branch is pushed through `BranchForge::update` and then the pull
    /// request's base branch, title and body are updated through the GitHub
    /// client.
    #[instrument]
    fn update(
        &mut self,
        commit_statuses: HashMap<NonZeroOid, CommitStatus>,
        options: &SubmitOptions,
    ) -> EyreExitOr<()> {
        let effects = self.effects;
        // None of the options are read here directly; `options` is only
        // forwarded below. NOTE(review): the exhaustive destructuring
        // presumably exists so that adding a field to `SubmitOptions` forces
        // this function to be revisited — confirm.
        let SubmitOptions {
            create: _,
            draft: _,
            execution_strategy: _,
            num_jobs: _,
            message: _,
        } = options;

        let pull_request_infos =
            try_exit_code!(self.client.query_repo_pull_request_infos(effects)?);
        let references_snapshot = self.repo.get_references_snapshot()?;
        let mut branch_forge = BranchForge {
            effects,
            git_run_info: self.git_run_info,
            dag: self.dag,
            repo: self.repo,
            event_log_db: self.event_log_db,
            references_snapshot: &references_snapshot,
        };

        let commit_set: CommitSet = commit_statuses.keys().copied().collect();
        let commit_oids = self.dag.sort(&commit_set)?;
        {
            // Shadow `effects` with the operation-scoped handle for the
            // duration of the update loop.
            let (effects, progress) = effects.start_operation(OperationType::UpdateCommits);
            progress.notify_progress(0, commit_oids.len());
            for commit_oid in commit_oids {
                let commit_status = match commit_statuses.get(&commit_oid) {
                    Some(commit_status) => commit_status,
                    None => {
                        warn!(
                            ?commit_oid,
                            ?commit_statuses,
                            "Commit not found in commit statuses"
                        );
                        continue;
                    }
                };
                let remote_branch_name = match &commit_status.remote_commit_name {
                    Some(remote_branch_name) => remote_branch_name,
                    None => {
                        warn!(
                            ?commit_oid,
                            ?commit_statuses,
                            "Commit does not have remote branch name"
                        );
                        continue;
                    }
                };
                let pull_request_info = match pull_request_infos.get(remote_branch_name) {
                    Some(pull_request_info) => pull_request_info,
                    None => {
                        warn!(
                            ?commit_oid,
                            ?commit_statuses,
                            "Commit does not have pull request"
                        );
                        continue;
                    }
                };

                let updated_pull_request_info = try_exit_code!(self
                    .make_updated_pull_request_info(
                        &effects,
                        &references_snapshot,
                        &pull_request_infos,
                        commit_oid
                    )?);
                // Human-readable list of what is about to change, used only
                // in the progress message below.
                let updated_fields = {
                    let fields = updated_pull_request_info.fields_to_update(pull_request_info);
                    if fields.is_empty() {
                        "none (this should not happen)".to_owned()
                    } else {
                        fields.join(", ")
                    }
                };
                let client::UpdatePullRequestArgs {
                    head_ref_oid: _, // Updated by `branch_forge.update`.
                    base_ref_name,
                    title,
                    body,
                } = updated_pull_request_info;
                writeln!(
                    effects.get_output_stream(),
                    "Updating pull request ({updated_fields}) for commit {}",
                    effects.get_glyphs().render(
                        self.repo
                            .find_commit_or_fail(commit_oid)?
                            .friendly_describe(effects.get_glyphs())?
                    )?
                )?;

                // Make sure to update the branch and metadata at the same time,
                // rather than all the branches at first. Otherwise, when
                // reordering commits, GitHub may close one of the pull requests
                // as it seems to have all the commits of its parent (or
                // something like that).

                // Push branch:
                try_exit_code!(
                    branch_forge.update(singleton(&commit_statuses, commit_oid, |x| x), options)?
                );

                // Update metadata:
                try_exit_code!(self.client.update_pull_request(
                    &effects,
                    pull_request_info.number,
                    client::UpdatePullRequestArgs {
                        head_ref_oid: commit_oid,
                        base_ref_name,
                        title,
                        body,
                    },
                    options
                )?);
                // Progress only advances for commits that were actually
                // updated; skipped commits (the `continue`s above) do not
                // count.
                progress.notify_progress_inc(1);
            }
        }

        Ok(Ok(()))
    }
493}
494
495impl GithubForge<'_> {
496    /// Construct a real or mock GitHub client according to the environment.
497    pub fn client(git_run_info: GitRunInfo) -> Box<dyn client::GithubClient> {
498        match env::var(MOCK_REMOTE_REPO_PATH_ENV_KEY) {
499            Ok(path) => Box::new(client::MockGithubClient {
500                remote_repo_path: path.into(),
501            }),
502            Err(_) => {
503                let GitRunInfo {
504                    path_to_git: _,
505                    working_directory,
506                    env,
507                } = git_run_info;
508                let gh_run_info = GitRunInfo {
509                    path_to_git: "gh".into(),
510                    working_directory: working_directory.clone(),
511                    env: env.clone(),
512                };
513                Box::new(client::RealGithubClient { gh_run_info })
514            }
515        }
516    }
517
    /// Compute the pull request metadata that the pull request for
    /// `commit_oid` should have.
    ///
    /// The result has:
    ///
    /// - `head_ref_oid`: `commit_oid` itself.
    /// - `base_ref_name`: the head branch of the nearest ancestor commit (not
    ///   on the main branch) that has a pull request, or the main branch name
    ///   if there is none.
    /// - `title`: `[index/size] summary`, where `size` is the number of
    ///   commits in the stack that have a pull request and `index` is the
    ///   1-based position of `commit_oid` among them (`?` if it could not be
    ///   determined).
    /// - `body`: a bulleted list of the stack's pull request URLs followed by
    ///   the commit message.
    ///
    /// `pull_request_infos` is keyed by remote branch name. `effects` is not
    /// used by the body of this function.
    #[instrument]
    fn make_updated_pull_request_info(
        &self,
        effects: &Effects,
        references_snapshot: &RepoReferencesSnapshot,
        pull_request_infos: &HashMap<String, client::PullRequestInfo>,
        commit_oid: NonZeroOid,
    ) -> EyreExitOr<client::UpdatePullRequestArgs> {
        let mut stack_index = None;
        let mut stack_pull_request_infos: IndexMap<NonZeroOid, &client::PullRequestInfo> =
            Default::default();

        // Ensure we iterate over the stack in topological order so that the
        // stack indexes are correct.
        let stack_commit_oids = self
            .dag
            .sort(&self.dag.query_stack_commits(CommitSet::from(commit_oid))?)?;
        // Look up the pull request associated with a commit, going from the
        // commit to its branches, to each local branch's upstream branch
        // name, to the pull request keyed by that name. Returns the first
        // match (branches are visited in sorted name order) or `None`.
        let get_pull_request_info =
            |commit_oid: NonZeroOid| -> eyre::Result<Option<&client::PullRequestInfo>> {
                let commit = self.repo.find_commit_or_fail(commit_oid)?; // for debug output

                debug!(?commit, "Checking commit for pull request info");
                let stack_branch_names =
                    match references_snapshot.branch_oid_to_names.get(&commit_oid) {
                        Some(stack_branch_names) => stack_branch_names,
                        None => {
                            debug!(?commit, "Commit has no associated branches");
                            return Ok(None);
                        }
                    };

                // The commit should have at most one associated branch with a pull
                // request.
                for stack_branch_name in stack_branch_names.iter().sorted() {
                    let stack_local_branch = match self.repo.find_branch(
                        &CategorizedReferenceName::new(stack_branch_name).render_suffix(),
                        BranchType::Local,
                    )? {
                        Some(stack_local_branch) => stack_local_branch,
                        None => {
                            debug!(
                                ?commit,
                                ?stack_branch_name,
                                "Skipping branch with no local branch"
                            );
                            continue;
                        }
                    };

                    let stack_remote_branch_name = match stack_local_branch
                        .get_upstream_branch_name_without_push_remote_name()?
                    {
                        Some(stack_remote_branch_name) => stack_remote_branch_name,
                        None => {
                            debug!(
                                ?commit,
                                ?stack_local_branch,
                                "Skipping local branch with no remote branch"
                            );
                            continue;
                        }
                    };

                    let pull_request_info = match pull_request_infos.get(&stack_remote_branch_name)
                    {
                        Some(pull_request_info) => pull_request_info,
                        None => {
                            debug!(
                                ?commit,
                                ?stack_local_branch,
                                ?stack_remote_branch_name,
                                "Skipping remote branch with no pull request info"
                            );
                            continue;
                        }
                    };

                    debug!(
                        ?commit,
                        ?pull_request_info,
                        "Found pull request info for commit"
                    );
                    return Ok(Some(pull_request_info));
                }

                debug!(
                    ?commit,
                    "Commit has no branches with associated pull request info"
                );
                Ok(None)
            };
        // Collect the pull requests of the stack in order. Commits without a
        // pull request are left out, so they do not count towards the stack
        // size or indexes.
        for stack_commit_oid in stack_commit_oids {
            let pull_request_info = match get_pull_request_info(stack_commit_oid)? {
                Some(info) => info,
                None => continue,
            };
            stack_pull_request_infos.insert(stack_commit_oid, pull_request_info);
            if stack_commit_oid == commit_oid {
                // Taken right after the insertion, so the length is the
                // 1-based position of this commit in the stack.
                stack_index = Some(stack_pull_request_infos.len());
            }
        }

        let stack_size = stack_pull_request_infos.len();
        if stack_size == 0 {
            warn!(
                ?commit_oid,
                ?stack_pull_request_infos,
                "No pull requests in stack for commit"
            );
        }
        let stack_index = match stack_index {
            Some(stack_index) => stack_index.to_string(),
            None => {
                warn!(
                    ?commit_oid,
                    ?stack_pull_request_infos,
                    "Could not determine index in stack for commit"
                );
                "?".to_string()
            }
        };

        let stack_list = {
            let mut result = String::new();
            for stack_pull_request_info in stack_pull_request_infos.values() {
                // Github will render a lone pull request URL as a title and
                // open/closed status.
                writeln!(result, "* {}", stack_pull_request_info.url)?;
            }
            result
        };

        let commit = self.repo.find_commit_or_fail(commit_oid)?;
        let commit_summary = commit.get_summary()?;
        let commit_summary = String::from_utf8_lossy(&commit_summary).into_owned();
        let title = format!("[{stack_index}/{stack_size}] {commit_summary}");
        let commit_message = commit.get_message_pretty();
        let commit_message = String::from_utf8_lossy(&commit_message);
        // The template lines are deliberately unindented: they are part of a
        // string literal and any leading whitespace would end up in the body.
        let body = format!(
            "\
**Stack:**

{stack_list}

---

{commit_message}
"
        );

        // Ancestors of `commit_oid` as selected by `query_only` relative to
        // the main branch, excluding `commit_oid` itself.
        let stack_ancestor_oids = {
            let main_branch_oid = CommitSet::from(references_snapshot.main_branch_oid);
            let stack_ancestor_oids = self
                .dag
                .query_only(CommitSet::from(commit_oid), main_branch_oid)?
                .difference(&CommitSet::from(commit_oid));
            self.dag.commit_set_to_vec(&stack_ancestor_oids)?
        };
        // Walk the ancestors in reverse and stop at the first one with a pull
        // request. NOTE(review): this presumably relies on
        // `commit_set_to_vec` returning ancestors oldest-first, so that
        // reversing visits the nearest ancestor first — confirm.
        let nearest_ancestor_with_pull_request_info = {
            let mut result = None;
            for stack_ancestor_oid in stack_ancestor_oids.into_iter().rev() {
                if let Some(info) = get_pull_request_info(stack_ancestor_oid)? {
                    result = Some(info);
                    break;
                }
            }
            result
        };
        // Stack the pull request on top of its nearest submitted ancestor,
        // or on the main branch if there is none.
        let base_ref_name = match nearest_ancestor_with_pull_request_info {
            Some(info) => info.head_ref_name.clone(),
            None => get_main_branch_name(self.repo)?,
        };

        Ok(Ok(client::UpdatePullRequestArgs {
            head_ref_oid: commit_oid,
            base_ref_name,
            title,
            body,
        }))
    }
698}
699
700mod client {
701    use std::collections::{BTreeMap, HashMap};
702    use std::fmt::{Debug, Write};
703    use std::fs::{self, File};
704    use std::path::{Path, PathBuf};
705    use std::process::{Command, Stdio};
706    use std::sync::Arc;
707
708    use eyre::Context;
709    use itertools::Itertools;
710    use lib::core::dag::Dag;
711    use lib::core::effects::{Effects, OperationType};
712    use lib::core::eventlog::{EventLogDb, EventReplayer};
713    use lib::core::formatting::Glyphs;
714    use lib::core::repo_ext::RepoExt;
715    use lib::git::{GitRunInfo, NonZeroOid, Repo, SerializedNonZeroOid};
716    use lib::try_exit_code;
717    use lib::util::{ExitCode, EyreExitOr};
718    use serde::{Deserialize, Serialize};
719    use tempfile::NamedTempFile;
720    use tracing::{debug, instrument};
721
722    use crate::SubmitOptions;
723
    /// The state of a single pull request as reported by a `GithubClient`.
    ///
    /// The type is (de)serialized with serde under the key names given in the
    /// `rename` attributes. NOTE(review): those names presumably match the
    /// JSON fields produced by `gh` — confirm against the client
    /// implementations.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct PullRequestInfo {
        /// The pull request number; passed to
        /// `GithubClient::update_pull_request` to identify the pull request.
        #[serde(rename = "number")]
        pub number: usize,
        /// The pull request URL; listed in the "Stack" section of generated
        /// pull request bodies.
        #[serde(rename = "url")]
        pub url: String,
        /// Name of the branch the pull request was opened from; used as the
        /// base branch of pull requests stacked on top of this one.
        #[serde(rename = "headRefName")]
        pub head_ref_name: String,
        /// The commit the pull request currently points at.
        #[serde(rename = "headRefOid")]
        pub head_ref_oid: SerializedNonZeroOid,
        /// Name of the branch the pull request targets.
        #[serde(rename = "baseRefName")]
        pub base_ref_name: String,
        /// Whether the pull request is closed. Not consulted by
        /// `UpdatePullRequestArgs::fields_to_update`.
        #[serde(rename = "closed")]
        pub closed: bool,
        /// Whether the pull request is a draft. Not consulted by
        /// `UpdatePullRequestArgs::fields_to_update`.
        #[serde(rename = "isDraft")]
        pub is_draft: bool,
        /// The current pull request title.
        #[serde(rename = "title")]
        pub title: String,
        /// The current pull request body.
        #[serde(rename = "body")]
        pub body: String,
    }
745
    /// Arguments for `GithubClient::create_pull_request`.
    #[derive(Debug)]
    pub struct CreatePullRequestArgs {
        /// The commit the pull request is being created for.
        pub head_ref_oid: NonZeroOid,
        /// Name of the branch to open the pull request from.
        pub head_ref_name: String,
        /// Title of the new pull request.
        pub title: String,
        /// Body of the new pull request.
        pub body: String,
    }
753
    /// The desired state of an existing pull request, as passed to
    /// `GithubClient::update_pull_request`. Compare it against the current
    /// `PullRequestInfo` with `fields_to_update` to see what would change.
    #[derive(Debug, Eq, PartialEq)]
    pub struct UpdatePullRequestArgs {
        /// The commit the pull request should point at.
        pub head_ref_oid: NonZeroOid,
        /// Name of the branch the pull request should target.
        pub base_ref_name: String,
        /// The desired pull request title.
        pub title: String,
        /// The desired pull request body.
        pub body: String,
    }
761
762    impl UpdatePullRequestArgs {
763        pub fn fields_to_update(&self, pull_request_info: &PullRequestInfo) -> Vec<&'static str> {
764            let PullRequestInfo {
765                number: _,
766                url: _,
767                head_ref_name: _,
768                head_ref_oid: SerializedNonZeroOid(old_head_ref_oid),
769                base_ref_name: old_base_ref_name,
770                closed: _,
771                is_draft: _,
772                title: old_title,
773                body: old_body,
774            } = pull_request_info;
775            let Self {
776                head_ref_oid: new_head_ref_oid,
777                base_ref_name: new_base_ref_name,
778                title: new_title,
779                body: new_body,
780            } = self;
781
782            let mut updated_fields = Vec::new();
783            if old_head_ref_oid != new_head_ref_oid {
784                updated_fields.push("commit");
785            }
786            if old_base_ref_name != new_base_ref_name {
787                updated_fields.push("base branch");
788            }
789            if old_title != new_title {
790                updated_fields.push("title");
791            }
792            if old_body != new_body {
793                updated_fields.push("body");
794            }
795            updated_fields
796        }
797    }
798
799    pub trait GithubClient: Debug {
800        /// Get the username of the currently-logged-in user.
801        fn query_github_username(&self, effects: &Effects) -> EyreExitOr<String>;
802
803        /// Get the details of all pull requests for the currently-logged-in user in
804        /// the current repository. The resulting map is keyed by remote branch
805        /// name.
806        fn query_repo_pull_request_infos(
807            &self,
808            effects: &Effects,
809        ) -> EyreExitOr<HashMap<String, PullRequestInfo>>;
810
811        fn create_pull_request(
812            &self,
813            effects: &Effects,
814            args: CreatePullRequestArgs,
815            submit_options: &super::SubmitOptions,
816        ) -> EyreExitOr<String>;
817
818        fn update_pull_request(
819            &self,
820            effects: &Effects,
821            number: usize,
822            args: UpdatePullRequestArgs,
823            submit_options: &super::SubmitOptions,
824        ) -> EyreExitOr<()>;
825    }
826
827    #[derive(Debug)]
828    pub struct RealGithubClient {
829        #[allow(dead_code)] // FIXME: destructure and use in `run_gh`?
830        pub gh_run_info: GitRunInfo,
831    }
832
833    impl RealGithubClient {
834        #[instrument]
835        fn run_gh(&self, effects: &Effects, args: &[&str]) -> EyreExitOr<Vec<u8>> {
836            let exe = "gh";
837            let exe_invocation = format!("{exe} {}", args.join(" "));
838            debug!(?exe_invocation, "Invoking gh");
839            let (effects, progress) =
840                effects.start_operation(OperationType::RunTests(Arc::new(exe_invocation.clone())));
841            let _progress = progress;
842
843            let child = Command::new("gh")
844                .args(args)
845                .stdin(Stdio::piped())
846                .stdout(Stdio::piped())
847                .stderr(Stdio::piped())
848                .spawn()
849                .context("Invoking `gh` command-line executable")?;
850            let output = child
851                .wait_with_output()
852                .context("Waiting for `gh` invocation")?;
853            if !output.status.success() {
854                writeln!(
855                    effects.get_output_stream(),
856                    "Call to `{exe_invocation}` failed",
857                )?;
858                writeln!(effects.get_output_stream(), "Stdout:")?;
859                writeln!(
860                    effects.get_output_stream(),
861                    "{}",
862                    String::from_utf8_lossy(&output.stdout)
863                )?;
864                writeln!(effects.get_output_stream(), "Stderr:")?;
865                writeln!(
866                    effects.get_output_stream(),
867                    "{}",
868                    String::from_utf8_lossy(&output.stderr)
869                )?;
870                return Ok(Err(ExitCode::try_from(output.status)?));
871            }
872            Ok(Ok(output.stdout))
873        }
874
875        #[instrument]
876        fn write_body_file(&self, body: &str) -> eyre::Result<NamedTempFile> {
877            use std::io::Write;
878            let mut body_file = NamedTempFile::new()?;
879            body_file.write_all(body.as_bytes())?;
880            body_file.flush()?;
881            Ok(body_file)
882        }
883    }
884
885    impl GithubClient for RealGithubClient {
886        /// Get the username of the currently-logged-in user.
887        #[instrument]
888        fn query_github_username(&self, effects: &Effects) -> EyreExitOr<String> {
889            let username =
890                try_exit_code!(self.run_gh(effects, &["api", "user", "--jq", ".login"])?);
891            let username = String::from_utf8(username)?;
892            let username = username.trim().to_owned();
893            Ok(Ok(username))
894        }
895
896        /// Get the details of all pull requests for the currently-logged-in user in
897        /// the current repository. The resulting map is keyed by remote branch
898        /// name.
899        #[instrument]
900        fn query_repo_pull_request_infos(
901            &self,
902            effects: &Effects,
903        ) -> EyreExitOr<HashMap<String, PullRequestInfo>> {
904            let output = try_exit_code!(self.run_gh(
905                effects,
906                &[
907                    "pr",
908                    "list",
909                    "--author",
910                    "@me",
911                    "--json",
912                    "number,url,headRefName,headRefOid,baseRefName,closed,isDraft,title,body",
913                ]
914            )?);
915            let pull_request_infos: Vec<PullRequestInfo> =
916                serde_json::from_slice(&output).wrap_err("Deserializing output from gh pr list")?;
917            let pull_request_infos = pull_request_infos
918                .into_iter()
919                .map(|item| (item.head_ref_name.clone(), item))
920                .collect();
921            Ok(Ok(pull_request_infos))
922        }
923
924        #[instrument]
925        fn create_pull_request(
926            &self,
927            effects: &Effects,
928            args: CreatePullRequestArgs,
929            submit_options: &SubmitOptions,
930        ) -> EyreExitOr<String> {
931            let CreatePullRequestArgs {
932                head_ref_oid: _,
933                head_ref_name,
934                title,
935                body,
936            } = args;
937            let body_file = self.write_body_file(&body)?;
938            let mut args = vec![
939                "pr",
940                "create",
941                "--head",
942                &head_ref_name,
943                "--title",
944                &title,
945                "--body-file",
946                body_file.path().to_str().unwrap(),
947            ];
948
949            let SubmitOptions {
950                create: _,
951                draft,
952                execution_strategy: _,
953                num_jobs: _,
954                message: _,
955            } = submit_options;
956            if *draft {
957                args.push("--draft");
958            }
959
960            let stdout = try_exit_code!(self.run_gh(effects, &args)?);
961            let pull_request_url = match std::str::from_utf8(&stdout) {
962                Ok(url) => url,
963                Err(err) => {
964                    writeln!(
965                        effects.get_output_stream(),
966                        "Could not parse output from `gh pr create` as UTF-8: {err}",
967                    )?;
968                    return Ok(Err(ExitCode(1)));
969                }
970            };
971            let pull_request_url = pull_request_url.trim();
972            Ok(Ok(pull_request_url.to_owned()))
973        }
974
975        fn update_pull_request(
976            &self,
977            effects: &Effects,
978            number: usize,
979            args: UpdatePullRequestArgs,
980            _submit_options: &super::SubmitOptions,
981        ) -> EyreExitOr<()> {
982            let UpdatePullRequestArgs {
983                head_ref_oid: _, // branch should have been pushed by caller
984                base_ref_name,
985                title,
986                body,
987            } = args;
988            let body_file = self.write_body_file(&body)?;
989            try_exit_code!(self.run_gh(
990                effects,
991                &[
992                    "pr",
993                    "edit",
994                    &number.to_string(),
995                    "--base",
996                    &base_ref_name,
997                    "--title",
998                    &title,
999                    "--body-file",
1000                    (body_file.path().to_str().unwrap()),
1001                ],
1002            )?);
1003            Ok(Ok(()))
1004        }
1005    }
1006
1007    /// The mock state on disk, representing the remote Github repository and
1008    /// server.
1009    #[derive(Debug, Default, Deserialize, Serialize)]
1010    pub struct MockState {
1011        /// The next index to assign a newly-created pull request.
1012        pub pull_request_index: usize,
1013
1014        /// Information about all pull requests open for the repository. Sorted
1015        /// for determinism when dumping state for testing.
1016        pub pull_requests: BTreeMap<String, PullRequestInfo>,
1017    }
1018
1019    impl MockState {
1020        fn load(path: &Path) -> eyre::Result<Self> {
1021            let file = match File::open(path) {
1022                Ok(file) => file,
1023                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
1024                    return Ok(Default::default());
1025                }
1026                Err(err) => return Err(err).wrap_err("Opening mock GitHub client state file"),
1027            };
1028            let state = serde_json::from_reader(file)?;
1029            Ok(state)
1030        }
1031
1032        fn restore_invariants(&mut self, remote_repo: &Repo) -> eyre::Result<()> {
1033            let effects = Effects::new_suppress_for_test(Glyphs::text());
1034            let conn = remote_repo.get_db_conn()?;
1035            let event_log_db = EventLogDb::new(&conn)?;
1036            let event_replayer =
1037                EventReplayer::from_event_log_db(&effects, remote_repo, &event_log_db)?;
1038            let event_cursor = event_replayer.make_default_cursor();
1039            let references_snapshot = remote_repo.get_references_snapshot()?;
1040            let dag = Dag::open_and_sync(
1041                &effects,
1042                remote_repo,
1043                &event_replayer,
1044                event_cursor,
1045                &references_snapshot,
1046            )?;
1047
1048            let branches: HashMap<String, NonZeroOid> = remote_repo
1049                .get_all_local_branches()?
1050                .into_iter()
1051                .map(|branch| -> eyre::Result<_> {
1052                    let branch_name = branch.get_name()?.to_owned();
1053                    let branch_oid = branch.get_oid()?.unwrap();
1054                    Ok((branch_name, branch_oid))
1055                })
1056                .try_collect()?;
1057            for (_, pull_request_info) in self.pull_requests.iter_mut() {
1058                let base_ref_name = &pull_request_info.base_ref_name;
1059                let base_branch_oid = match branches.get(base_ref_name) {
1060                    Some(oid) => *oid,
1061                    None => {
1062                        eyre::bail!("Could not find base branch {base_ref_name:?} for pull request: {pull_request_info:?}");
1063                    }
1064                };
1065                let SerializedNonZeroOid(head_ref_oid) = pull_request_info.head_ref_oid;
1066                if dag.query_is_ancestor(head_ref_oid, base_branch_oid)? {
1067                    pull_request_info.closed = true;
1068                }
1069            }
1070            Ok(())
1071        }
1072
1073        fn save(&self, path: &Path) -> eyre::Result<()> {
1074            let state = serde_json::to_string_pretty(self)?;
1075            fs::write(path, state)?;
1076            Ok(())
1077        }
1078    }
1079
1080    /// A mock client representing the remote Github repository and server.
1081    #[derive(Debug)]
1082    pub struct MockGithubClient {
1083        /// The path to the remote repository on disk.
1084        pub remote_repo_path: PathBuf,
1085    }
1086
1087    impl GithubClient for MockGithubClient {
1088        fn query_github_username(&self, _effects: &Effects) -> EyreExitOr<String> {
1089            Ok(Ok(Self::username().to_owned()))
1090        }
1091
1092        fn query_repo_pull_request_infos(
1093            &self,
1094            _effects: &Effects,
1095        ) -> EyreExitOr<HashMap<String, PullRequestInfo>> {
1096            let pull_requests_infos = self.with_state_mut(|state| {
1097                let pull_request_infos = state
1098                    .pull_requests
1099                    .values()
1100                    .cloned()
1101                    .map(|pull_request_info| {
1102                        (pull_request_info.head_ref_name.clone(), pull_request_info)
1103                    })
1104                    .collect();
1105                Ok(pull_request_infos)
1106            })?;
1107            Ok(Ok(pull_requests_infos))
1108        }
1109
1110        fn create_pull_request(
1111            &self,
1112            _effects: &Effects,
1113            args: CreatePullRequestArgs,
1114            submit_options: &super::SubmitOptions,
1115        ) -> EyreExitOr<String> {
1116            let url = self.with_state_mut(|state| {
1117                state.pull_request_index += 1;
1118                let CreatePullRequestArgs {
1119                    head_ref_oid,
1120                    head_ref_name,
1121                    title,
1122                    body,
1123                } = args;
1124                let SubmitOptions {
1125                    create,
1126                    draft,
1127                    execution_strategy: _,
1128                    num_jobs: _,
1129                    message: _,
1130                } = submit_options;
1131                assert!(create);
1132                let url = format!(
1133                    "https://example.com/{}/{}/pulls/{}",
1134                    Self::username(),
1135                    Self::repo_name(),
1136                    state.pull_request_index
1137                );
1138                let pull_request_info = PullRequestInfo {
1139                    number: state.pull_request_index,
1140                    url: url.clone(),
1141                    head_ref_name: head_ref_name.clone(),
1142                    head_ref_oid: SerializedNonZeroOid(head_ref_oid),
1143                    base_ref_name: Self::main_branch().to_owned(),
1144                    closed: false,
1145                    is_draft: *draft,
1146                    title,
1147                    body,
1148                };
1149                state.pull_requests.insert(head_ref_name, pull_request_info);
1150                Ok(url)
1151            })?;
1152            Ok(Ok(url))
1153        }
1154
1155        fn update_pull_request(
1156            &self,
1157            _effects: &Effects,
1158            number: usize,
1159            args: UpdatePullRequestArgs,
1160            _submit_options: &super::SubmitOptions,
1161        ) -> EyreExitOr<()> {
1162            self.with_state_mut(|state| -> eyre::Result<()> {
1163                let UpdatePullRequestArgs {
1164                    head_ref_oid,
1165                    base_ref_name,
1166                    title,
1167                    body,
1168                } = args;
1169                let pull_request_info = match state
1170                    .pull_requests
1171                    .values_mut()
1172                    .find(|pull_request_info| pull_request_info.number == number)
1173                {
1174                    Some(pull_request_info) => pull_request_info,
1175                    None => {
1176                        eyre::bail!("Could not find pull request with number {number}");
1177                    }
1178                };
1179                pull_request_info.head_ref_oid = SerializedNonZeroOid(head_ref_oid);
1180                pull_request_info.base_ref_name = base_ref_name;
1181                pull_request_info.title = title;
1182                pull_request_info.body = body;
1183                Ok(())
1184            })?;
1185            Ok(Ok(()))
1186        }
1187    }
1188
1189    impl MockGithubClient {
1190        fn username() -> &'static str {
1191            "mock-github-username"
1192        }
1193
1194        fn repo_name() -> &'static str {
1195            "mock-github-repo"
1196        }
1197
1198        fn main_branch() -> &'static str {
1199            "master"
1200        }
1201
1202        /// Get the path on disk where the mock state is stored.
1203        pub fn state_path(&self) -> PathBuf {
1204            self.remote_repo_path.join("mock-github-client-state.json")
1205        }
1206
1207        /// Load the mock state from disk, run the given function, and then save
1208        /// the state back to disk. Github-specific pull request invariants are
1209        /// restored before and after running the function.
1210        pub fn with_state_mut<T>(
1211            &self,
1212            f: impl FnOnce(&mut MockState) -> eyre::Result<T>,
1213        ) -> eyre::Result<T> {
1214            let repo = Repo::from_dir(&self.remote_repo_path)?;
1215            let state_path = self.state_path();
1216            let mut state = MockState::load(&state_path)?;
1217            state.restore_invariants(&repo)?;
1218            let result = f(&mut state)?;
1219            state.restore_invariants(&repo)?;
1220            state.save(&state_path)?;
1221            Ok(result)
1222        }
1223    }
1224}
1225
1226/// Testing utilities.
1227pub mod testing {
1228    pub use super::client::MockGithubClient;
1229}
1230
#[cfg(test)]
mod tests {
    use super::*;

    /// Each commit summary must map to the expected branch-name slug,
    /// including the fallback for a summary with no alphanumeric characters.
    #[test]
    fn test_commit_summary_slug() {
        let cases = [
            ("hello: foo bar", "hello-foo-bar"),
            ("category(topic): `foo` bar!", "category-topic-foo-bar"),
            ("foo_~_bar", "foo-bar"),
            ("!!!", "to-review"),
        ];
        for (summary, expected) in cases.iter() {
            assert_eq!(commit_summary_slug(summary), *expected);
        }
    }
}