1use std::collections::HashMap;
4use std::env;
5use std::fmt::{Debug, Write};
6use std::hash::Hash;
7
8use cursive_core::theme::Effect;
9use cursive_core::utils::markup::StyledString;
10use indexmap::IndexMap;
11use itertools::Itertools;
12use lib::core::config::get_main_branch_name;
13use lib::core::dag::CommitSet;
14use lib::core::dag::Dag;
15use lib::core::effects::Effects;
16use lib::core::effects::OperationType;
17use lib::core::eventlog::EventLogDb;
18use lib::core::repo_ext::RepoExt;
19use lib::core::repo_ext::RepoReferencesSnapshot;
20use lib::git::CategorizedReferenceName;
21use lib::git::GitErrorCode;
22use lib::git::GitRunInfo;
23use lib::git::RepoError;
24use lib::git::{BranchType, ConfigRead};
25use lib::git::{NonZeroOid, Repo};
26use lib::try_exit_code;
27use lib::util::ExitCode;
28use lib::util::EyreExitOr;
29
30use tracing::debug;
31use tracing::instrument;
32use tracing::warn;
33
34use crate::branch_forge::BranchForge;
35use crate::SubmitStatus;
36use crate::{CommitStatus, CreateStatus, Forge, SubmitOptions};
37
/// Environment variable that, when set, makes [`GithubForge::client`] return a
/// mock client whose state lives in the repository at the given path, instead
/// of a client that invokes the real `gh` executable.
pub const MOCK_REMOTE_REPO_PATH_ENV_KEY: &str =
    "BRANCHLESS_SUBMIT_GITHUB_MOCK_REMOTE_REPO_PATH";
42
/// Turn a commit summary into a string usable as part of a branch name.
///
/// Every non-alphanumeric character becomes `-` and the result is lowercased.
/// Runs of `-` are collapsed into one, and leading/trailing `-` are stripped.
/// If nothing remains, the placeholder `"to-review"` is returned.
fn commit_summary_slug(summary: &str) -> String {
    let mut slug = String::with_capacity(summary.len());
    for raw in summary.chars() {
        let mapped = if raw.is_alphanumeric() { raw } else { '-' };
        // `to_lowercase` can expand to several chars; each is deduplicated
        // against the last char actually kept.
        for lowered in mapped.to_lowercase() {
            let is_repeated_dash = lowered == '-' && slug.ends_with('-');
            if !is_repeated_dash {
                slug.push(lowered);
            }
        }
    }

    match slug.trim_matches('-') {
        "" => "to-review".to_string(),
        trimmed => trimmed.to_owned(),
    }
}
60
61fn singleton<K: Debug + Eq + Hash, V: Clone>(
62 map: &HashMap<K, V>,
63 key: K,
64 f: impl Fn(V) -> V,
65) -> HashMap<K, V> {
66 let mut result = HashMap::new();
67 match map.get(&key) {
68 Some(value) => {
69 result.insert(key, f(value.clone()));
70 }
71 None => {
72 warn!(?key, "No match for key in map");
73 }
74 }
75 result
76}
77
78pub fn github_push_remote(repo: &Repo) -> eyre::Result<Option<String>> {
85 let config = repo.get_readonly_config()?;
86 for remote_name in repo.get_all_remote_names()? {
87 let gh_resolved: Option<String> =
98 config.get(format!("remote.{remote_name}.gh-resolved"))?;
99 if gh_resolved.as_deref() == Some("base") {
100 return Ok(Some(remote_name));
101 }
102 }
103 Ok(None)
104}
105
106#[allow(missing_docs)]
109#[derive(Debug)]
110pub struct GithubForge<'a> {
111 pub effects: &'a Effects,
112 pub git_run_info: &'a GitRunInfo,
113 pub repo: &'a Repo,
114 pub event_log_db: &'a EventLogDb<'a>,
115 pub dag: &'a Dag,
116 pub client: Box<dyn client::GithubClient>,
117}
118
119impl Forge for GithubForge<'_> {
120 #[instrument]
121 fn query_status(
122 &mut self,
123 commit_set: CommitSet,
124 ) -> EyreExitOr<HashMap<NonZeroOid, CommitStatus>> {
125 let effects = self.effects;
126 let pull_request_infos =
127 try_exit_code!(self.client.query_repo_pull_request_infos(effects)?);
128 let references_snapshot = self.repo.get_references_snapshot()?;
129
130 let mut result = HashMap::new();
131 for branch in self.repo.get_all_local_branches()? {
132 let local_branch_oid = match branch.get_oid()? {
133 Some(branch_oid) => branch_oid,
134 None => continue,
135 };
136 if !self.dag.set_contains(&commit_set, local_branch_oid)? {
137 continue;
138 }
139
140 let local_branch_name = branch.get_name()?;
141 let remote_name = branch.get_push_remote_name()?;
142 let remote_branch_name = branch.get_upstream_branch_name_without_push_remote_name()?;
143
144 let submit_status = match remote_branch_name
145 .as_ref()
146 .and_then(|remote_branch_name| pull_request_infos.get(remote_branch_name))
147 {
148 None => SubmitStatus::Unsubmitted,
149 Some(pull_request_info) => {
150 let updated_pull_request_info = try_exit_code!(self
151 .make_updated_pull_request_info(
152 effects,
153 &references_snapshot,
154 &pull_request_infos,
155 local_branch_oid
156 )?);
157 debug!(
158 ?pull_request_info,
159 ?updated_pull_request_info,
160 "Comparing pull request info"
161 );
162 if updated_pull_request_info
163 .fields_to_update(pull_request_info)
164 .is_empty()
165 {
166 SubmitStatus::UpToDate
167 } else {
168 SubmitStatus::NeedsUpdate
169 }
170 }
171 };
172 result.insert(
173 local_branch_oid,
174 CommitStatus {
175 submit_status,
176 remote_name,
177 local_commit_name: Some(local_branch_name.to_owned()),
178 remote_commit_name: remote_branch_name,
179 },
180 );
181 }
182
183 for commit_oid in self.dag.commit_set_to_vec(&commit_set)? {
184 result.entry(commit_oid).or_insert(CommitStatus {
185 submit_status: SubmitStatus::Unsubmitted,
186 remote_name: None,
187 local_commit_name: None,
188 remote_commit_name: None,
189 });
190 }
191
192 Ok(Ok(result))
193 }
194
195 #[instrument]
196 fn create(
197 &mut self,
198 commits: HashMap<NonZeroOid, CommitStatus>,
199 options: &SubmitOptions,
200 ) -> EyreExitOr<HashMap<NonZeroOid, CreateStatus>> {
201 let effects = self.effects;
202 let commit_oids = self.dag.sort(&commits.keys().copied().collect())?;
203
204 let references_snapshot = self.repo.get_references_snapshot()?;
205 let mut branch_forge = BranchForge {
206 effects,
207 git_run_info: self.git_run_info,
208 dag: self.dag,
209 repo: self.repo,
210 event_log_db: self.event_log_db,
211 references_snapshot: &references_snapshot,
212 };
213 let push_remote_name = match github_push_remote(self.repo)? {
214 Some(remote_name) => remote_name,
215 None => match self.repo.get_default_push_remote()? {
216 Some(remote_name) => remote_name,
217 None => {
218 writeln!(
219 effects.get_output_stream(),
220 "No default push repository configured. To configure, run: {}",
221 effects.get_glyphs().render(StyledString::styled(
222 "gh repo set-default <repo>",
223 Effect::Bold,
224 ))?
225 )?;
226 return Ok(Err(ExitCode(1)));
227 }
228 },
229 };
230 let github_username = try_exit_code!(self.client.query_github_username(effects)?);
231
232 let commits_to_create = commit_oids
234 .into_iter()
235 .map(|commit_oid| (commit_oid, commits.get(&commit_oid).unwrap()))
236 .filter_map(
237 |(commit_oid, commit_status)| match commit_status.submit_status {
238 SubmitStatus::Local
239 | SubmitStatus::Unknown
240 | SubmitStatus::NeedsUpdate
241 | SubmitStatus::UpToDate => None,
242 SubmitStatus::Unsubmitted => Some((commit_oid, commit_status)),
243 },
244 )
245 .collect_vec();
246 let mut created_branches = HashMap::new();
247 for (commit_oid, commit_status) in commits_to_create.iter().copied() {
248 let commit = self.repo.find_commit_or_fail(commit_oid)?;
249
250 let local_branch_name = match &commit_status.local_commit_name {
251 Some(local_branch_name) => local_branch_name.clone(),
252 None => {
253 let summary = commit.get_summary()?;
254 let summary = String::from_utf8_lossy(&summary);
255 let summary_slug = commit_summary_slug(&summary);
256 let new_branch_name_base = format!("{github_username}/{summary_slug}");
257 let mut new_branch_name = new_branch_name_base.clone();
258 for i in 2.. {
259 if i > 6 {
260 writeln!(
261 effects.get_output_stream(),
262 "Could not generate fresh branch name for commit: {}",
263 effects
264 .get_glyphs()
265 .render(commit.friendly_describe(effects.get_glyphs())?)?,
266 )?;
267 return Ok(Err(ExitCode(1)));
268 }
269 match self.repo.find_branch(&new_branch_name, BranchType::Local)? {
270 Some(_) => {
271 new_branch_name = format!("{new_branch_name_base}-{i}");
272 }
273 None => break,
274 }
275 }
276 match self.repo.create_branch(&new_branch_name, &commit, false) {
277 Ok(_branch) => {}
278 Err(RepoError::CreateBranch { source, name: _ })
279 if source.code() == GitErrorCode::Exists => {}
280 Err(err) => return Err(err.into()),
281 };
282 new_branch_name
283 }
284 };
285
286 let created_branch = try_exit_code!(branch_forge.create(
287 singleton(&commits, commit_oid, |commit_status| CommitStatus {
288 local_commit_name: Some(local_branch_name.clone()),
289 ..commit_status.clone()
290 }),
291 options
292 )?);
293 created_branches.extend(created_branch.into_iter());
294 }
295
296 let commit_statuses: HashMap<NonZeroOid, CommitStatus> = commits_to_create
297 .iter()
298 .copied()
299 .map(|(commit_oid, commit_status)| {
300 let commit_status = match created_branches.get(&commit_oid) {
301 Some(CreateStatus {
302 final_commit_oid: _,
303 local_commit_name,
304 }) => CommitStatus {
305 submit_status: SubmitStatus::NeedsUpdate,
307 remote_name: Some(push_remote_name.clone()),
308 local_commit_name: Some(local_commit_name.clone()),
309 remote_commit_name: Some(local_commit_name.clone()),
311 },
312 None => commit_status.clone(),
313 };
314 (commit_oid, commit_status)
315 })
316 .collect();
317
318 for (commit_oid, _) in commits_to_create {
323 let local_branch_name = match commit_statuses.get(&commit_oid) {
324 Some(CommitStatus {
325 local_commit_name: Some(local_commit_name),
326 ..
327 }) => local_commit_name,
328 Some(CommitStatus {
329 local_commit_name: None,
330 ..
331 })
332 | None => {
333 writeln!(
334 effects.get_output_stream(),
335 "Could not find local branch name for commit: {}",
336 effects.get_glyphs().render(
337 self.repo
338 .find_commit_or_fail(commit_oid)?
339 .friendly_describe(effects.get_glyphs())?
340 )?
341 )?;
342 return Ok(Err(ExitCode(1)));
343 }
344 };
345
346 let commit = self.repo.find_commit_or_fail(commit_oid)?;
347 let title = String::from_utf8_lossy(&commit.get_summary()?).into_owned();
348 let body = String::from_utf8_lossy(&commit.get_message_pretty()).into_owned();
349 try_exit_code!(self.client.create_pull_request(
350 effects,
351 client::CreatePullRequestArgs {
352 head_ref_oid: commit_oid,
353 head_ref_name: local_branch_name.clone(),
354 title,
355 body,
356 },
357 options
358 )?);
359 }
360
361 try_exit_code!(self.update(commit_statuses, options)?);
362
363 Ok(Ok(created_branches))
364 }
365
366 #[instrument]
367 fn update(
368 &mut self,
369 commit_statuses: HashMap<NonZeroOid, CommitStatus>,
370 options: &SubmitOptions,
371 ) -> EyreExitOr<()> {
372 let effects = self.effects;
373 let SubmitOptions {
374 create: _,
375 draft: _,
376 execution_strategy: _,
377 num_jobs: _,
378 message: _,
379 } = options;
380
381 let pull_request_infos =
382 try_exit_code!(self.client.query_repo_pull_request_infos(effects)?);
383 let references_snapshot = self.repo.get_references_snapshot()?;
384 let mut branch_forge = BranchForge {
385 effects,
386 git_run_info: self.git_run_info,
387 dag: self.dag,
388 repo: self.repo,
389 event_log_db: self.event_log_db,
390 references_snapshot: &references_snapshot,
391 };
392
393 let commit_set: CommitSet = commit_statuses.keys().copied().collect();
394 let commit_oids = self.dag.sort(&commit_set)?;
395 {
396 let (effects, progress) = effects.start_operation(OperationType::UpdateCommits);
397 progress.notify_progress(0, commit_oids.len());
398 for commit_oid in commit_oids {
399 let commit_status = match commit_statuses.get(&commit_oid) {
400 Some(commit_status) => commit_status,
401 None => {
402 warn!(
403 ?commit_oid,
404 ?commit_statuses,
405 "Commit not found in commit statuses"
406 );
407 continue;
408 }
409 };
410 let remote_branch_name = match &commit_status.remote_commit_name {
411 Some(remote_branch_name) => remote_branch_name,
412 None => {
413 warn!(
414 ?commit_oid,
415 ?commit_statuses,
416 "Commit does not have remote branch name"
417 );
418 continue;
419 }
420 };
421 let pull_request_info = match pull_request_infos.get(remote_branch_name) {
422 Some(pull_request_info) => pull_request_info,
423 None => {
424 warn!(
425 ?commit_oid,
426 ?commit_statuses,
427 "Commit does not have pull request"
428 );
429 continue;
430 }
431 };
432
433 let updated_pull_request_info = try_exit_code!(self
434 .make_updated_pull_request_info(
435 &effects,
436 &references_snapshot,
437 &pull_request_infos,
438 commit_oid
439 )?);
440 let updated_fields = {
441 let fields = updated_pull_request_info.fields_to_update(pull_request_info);
442 if fields.is_empty() {
443 "none (this should not happen)".to_owned()
444 } else {
445 fields.join(", ")
446 }
447 };
448 let client::UpdatePullRequestArgs {
449 head_ref_oid: _, base_ref_name,
451 title,
452 body,
453 } = updated_pull_request_info;
454 writeln!(
455 effects.get_output_stream(),
456 "Updating pull request ({updated_fields}) for commit {}",
457 effects.get_glyphs().render(
458 self.repo
459 .find_commit_or_fail(commit_oid)?
460 .friendly_describe(effects.get_glyphs())?
461 )?
462 )?;
463
464 try_exit_code!(
472 branch_forge.update(singleton(&commit_statuses, commit_oid, |x| x), options)?
473 );
474
475 try_exit_code!(self.client.update_pull_request(
477 &effects,
478 pull_request_info.number,
479 client::UpdatePullRequestArgs {
480 head_ref_oid: commit_oid,
481 base_ref_name,
482 title,
483 body,
484 },
485 options
486 )?);
487 progress.notify_progress_inc(1);
488 }
489 }
490
491 Ok(Ok(()))
492 }
493}
494
495impl GithubForge<'_> {
496 pub fn client(git_run_info: GitRunInfo) -> Box<dyn client::GithubClient> {
498 match env::var(MOCK_REMOTE_REPO_PATH_ENV_KEY) {
499 Ok(path) => Box::new(client::MockGithubClient {
500 remote_repo_path: path.into(),
501 }),
502 Err(_) => {
503 let GitRunInfo {
504 path_to_git: _,
505 working_directory,
506 env,
507 } = git_run_info;
508 let gh_run_info = GitRunInfo {
509 path_to_git: "gh".into(),
510 working_directory: working_directory.clone(),
511 env: env.clone(),
512 };
513 Box::new(client::RealGithubClient { gh_run_info })
514 }
515 }
516 }
517
518 #[instrument]
519 fn make_updated_pull_request_info(
520 &self,
521 effects: &Effects,
522 references_snapshot: &RepoReferencesSnapshot,
523 pull_request_infos: &HashMap<String, client::PullRequestInfo>,
524 commit_oid: NonZeroOid,
525 ) -> EyreExitOr<client::UpdatePullRequestArgs> {
526 let mut stack_index = None;
527 let mut stack_pull_request_infos: IndexMap<NonZeroOid, &client::PullRequestInfo> =
528 Default::default();
529
530 let stack_commit_oids = self
533 .dag
534 .sort(&self.dag.query_stack_commits(CommitSet::from(commit_oid))?)?;
535 let get_pull_request_info =
536 |commit_oid: NonZeroOid| -> eyre::Result<Option<&client::PullRequestInfo>> {
537 let commit = self.repo.find_commit_or_fail(commit_oid)?; debug!(?commit, "Checking commit for pull request info");
540 let stack_branch_names =
541 match references_snapshot.branch_oid_to_names.get(&commit_oid) {
542 Some(stack_branch_names) => stack_branch_names,
543 None => {
544 debug!(?commit, "Commit has no associated branches");
545 return Ok(None);
546 }
547 };
548
549 for stack_branch_name in stack_branch_names.iter().sorted() {
552 let stack_local_branch = match self.repo.find_branch(
553 &CategorizedReferenceName::new(stack_branch_name).render_suffix(),
554 BranchType::Local,
555 )? {
556 Some(stack_local_branch) => stack_local_branch,
557 None => {
558 debug!(
559 ?commit,
560 ?stack_branch_name,
561 "Skipping branch with no local branch"
562 );
563 continue;
564 }
565 };
566
567 let stack_remote_branch_name = match stack_local_branch
568 .get_upstream_branch_name_without_push_remote_name()?
569 {
570 Some(stack_remote_branch_name) => stack_remote_branch_name,
571 None => {
572 debug!(
573 ?commit,
574 ?stack_local_branch,
575 "Skipping local branch with no remote branch"
576 );
577 continue;
578 }
579 };
580
581 let pull_request_info = match pull_request_infos.get(&stack_remote_branch_name)
582 {
583 Some(pull_request_info) => pull_request_info,
584 None => {
585 debug!(
586 ?commit,
587 ?stack_local_branch,
588 ?stack_remote_branch_name,
589 "Skipping remote branch with no pull request info"
590 );
591 continue;
592 }
593 };
594
595 debug!(
596 ?commit,
597 ?pull_request_info,
598 "Found pull request info for commit"
599 );
600 return Ok(Some(pull_request_info));
601 }
602
603 debug!(
604 ?commit,
605 "Commit has no branches with associated pull request info"
606 );
607 Ok(None)
608 };
609 for stack_commit_oid in stack_commit_oids {
610 let pull_request_info = match get_pull_request_info(stack_commit_oid)? {
611 Some(info) => info,
612 None => continue,
613 };
614 stack_pull_request_infos.insert(stack_commit_oid, pull_request_info);
615 if stack_commit_oid == commit_oid {
616 stack_index = Some(stack_pull_request_infos.len());
617 }
618 }
619
620 let stack_size = stack_pull_request_infos.len();
621 if stack_size == 0 {
622 warn!(
623 ?commit_oid,
624 ?stack_pull_request_infos,
625 "No pull requests in stack for commit"
626 );
627 }
628 let stack_index = match stack_index {
629 Some(stack_index) => stack_index.to_string(),
630 None => {
631 warn!(
632 ?commit_oid,
633 ?stack_pull_request_infos,
634 "Could not determine index in stack for commit"
635 );
636 "?".to_string()
637 }
638 };
639
640 let stack_list = {
641 let mut result = String::new();
642 for stack_pull_request_info in stack_pull_request_infos.values() {
643 writeln!(result, "* {}", stack_pull_request_info.url)?;
646 }
647 result
648 };
649
650 let commit = self.repo.find_commit_or_fail(commit_oid)?;
651 let commit_summary = commit.get_summary()?;
652 let commit_summary = String::from_utf8_lossy(&commit_summary).into_owned();
653 let title = format!("[{stack_index}/{stack_size}] {commit_summary}");
654 let commit_message = commit.get_message_pretty();
655 let commit_message = String::from_utf8_lossy(&commit_message);
656 let body = format!(
657 "\
658**Stack:**
659
660{stack_list}
661
662---
663
664{commit_message}
665"
666 );
667
668 let stack_ancestor_oids = {
669 let main_branch_oid = CommitSet::from(references_snapshot.main_branch_oid);
670 let stack_ancestor_oids = self
671 .dag
672 .query_only(CommitSet::from(commit_oid), main_branch_oid)?
673 .difference(&CommitSet::from(commit_oid));
674 self.dag.commit_set_to_vec(&stack_ancestor_oids)?
675 };
676 let nearest_ancestor_with_pull_request_info = {
677 let mut result = None;
678 for stack_ancestor_oid in stack_ancestor_oids.into_iter().rev() {
679 if let Some(info) = get_pull_request_info(stack_ancestor_oid)? {
680 result = Some(info);
681 break;
682 }
683 }
684 result
685 };
686 let base_ref_name = match nearest_ancestor_with_pull_request_info {
687 Some(info) => info.head_ref_name.clone(),
688 None => get_main_branch_name(self.repo)?,
689 };
690
691 Ok(Ok(client::UpdatePullRequestArgs {
692 head_ref_oid: commit_oid,
693 base_ref_name,
694 title,
695 body,
696 }))
697 }
698}
699
700mod client {
701 use std::collections::{BTreeMap, HashMap};
702 use std::fmt::{Debug, Write};
703 use std::fs::{self, File};
704 use std::path::{Path, PathBuf};
705 use std::process::{Command, Stdio};
706 use std::sync::Arc;
707
708 use eyre::Context;
709 use itertools::Itertools;
710 use lib::core::dag::Dag;
711 use lib::core::effects::{Effects, OperationType};
712 use lib::core::eventlog::{EventLogDb, EventReplayer};
713 use lib::core::formatting::Glyphs;
714 use lib::core::repo_ext::RepoExt;
715 use lib::git::{GitRunInfo, NonZeroOid, Repo, SerializedNonZeroOid};
716 use lib::try_exit_code;
717 use lib::util::{ExitCode, EyreExitOr};
718 use serde::{Deserialize, Serialize};
719 use tempfile::NamedTempFile;
720 use tracing::{debug, instrument};
721
722 use crate::SubmitOptions;
723
724 #[derive(Clone, Debug, Deserialize, Serialize)]
725 pub struct PullRequestInfo {
726 #[serde(rename = "number")]
727 pub number: usize,
728 #[serde(rename = "url")]
729 pub url: String,
730 #[serde(rename = "headRefName")]
731 pub head_ref_name: String,
732 #[serde(rename = "headRefOid")]
733 pub head_ref_oid: SerializedNonZeroOid,
734 #[serde(rename = "baseRefName")]
735 pub base_ref_name: String,
736 #[serde(rename = "closed")]
737 pub closed: bool,
738 #[serde(rename = "isDraft")]
739 pub is_draft: bool,
740 #[serde(rename = "title")]
741 pub title: String,
742 #[serde(rename = "body")]
743 pub body: String,
744 }
745
746 #[derive(Debug)]
747 pub struct CreatePullRequestArgs {
748 pub head_ref_oid: NonZeroOid,
749 pub head_ref_name: String,
750 pub title: String,
751 pub body: String,
752 }
753
754 #[derive(Debug, Eq, PartialEq)]
755 pub struct UpdatePullRequestArgs {
756 pub head_ref_oid: NonZeroOid,
757 pub base_ref_name: String,
758 pub title: String,
759 pub body: String,
760 }
761
762 impl UpdatePullRequestArgs {
763 pub fn fields_to_update(&self, pull_request_info: &PullRequestInfo) -> Vec<&'static str> {
764 let PullRequestInfo {
765 number: _,
766 url: _,
767 head_ref_name: _,
768 head_ref_oid: SerializedNonZeroOid(old_head_ref_oid),
769 base_ref_name: old_base_ref_name,
770 closed: _,
771 is_draft: _,
772 title: old_title,
773 body: old_body,
774 } = pull_request_info;
775 let Self {
776 head_ref_oid: new_head_ref_oid,
777 base_ref_name: new_base_ref_name,
778 title: new_title,
779 body: new_body,
780 } = self;
781
782 let mut updated_fields = Vec::new();
783 if old_head_ref_oid != new_head_ref_oid {
784 updated_fields.push("commit");
785 }
786 if old_base_ref_name != new_base_ref_name {
787 updated_fields.push("base branch");
788 }
789 if old_title != new_title {
790 updated_fields.push("title");
791 }
792 if old_body != new_body {
793 updated_fields.push("body");
794 }
795 updated_fields
796 }
797 }
798
799 pub trait GithubClient: Debug {
800 fn query_github_username(&self, effects: &Effects) -> EyreExitOr<String>;
802
803 fn query_repo_pull_request_infos(
807 &self,
808 effects: &Effects,
809 ) -> EyreExitOr<HashMap<String, PullRequestInfo>>;
810
811 fn create_pull_request(
812 &self,
813 effects: &Effects,
814 args: CreatePullRequestArgs,
815 submit_options: &super::SubmitOptions,
816 ) -> EyreExitOr<String>;
817
818 fn update_pull_request(
819 &self,
820 effects: &Effects,
821 number: usize,
822 args: UpdatePullRequestArgs,
823 submit_options: &super::SubmitOptions,
824 ) -> EyreExitOr<()>;
825 }
826
827 #[derive(Debug)]
828 pub struct RealGithubClient {
829 #[allow(dead_code)] pub gh_run_info: GitRunInfo,
831 }
832
833 impl RealGithubClient {
834 #[instrument]
835 fn run_gh(&self, effects: &Effects, args: &[&str]) -> EyreExitOr<Vec<u8>> {
836 let exe = "gh";
837 let exe_invocation = format!("{exe} {}", args.join(" "));
838 debug!(?exe_invocation, "Invoking gh");
839 let (effects, progress) =
840 effects.start_operation(OperationType::RunTests(Arc::new(exe_invocation.clone())));
841 let _progress = progress;
842
843 let child = Command::new("gh")
844 .args(args)
845 .stdin(Stdio::piped())
846 .stdout(Stdio::piped())
847 .stderr(Stdio::piped())
848 .spawn()
849 .context("Invoking `gh` command-line executable")?;
850 let output = child
851 .wait_with_output()
852 .context("Waiting for `gh` invocation")?;
853 if !output.status.success() {
854 writeln!(
855 effects.get_output_stream(),
856 "Call to `{exe_invocation}` failed",
857 )?;
858 writeln!(effects.get_output_stream(), "Stdout:")?;
859 writeln!(
860 effects.get_output_stream(),
861 "{}",
862 String::from_utf8_lossy(&output.stdout)
863 )?;
864 writeln!(effects.get_output_stream(), "Stderr:")?;
865 writeln!(
866 effects.get_output_stream(),
867 "{}",
868 String::from_utf8_lossy(&output.stderr)
869 )?;
870 return Ok(Err(ExitCode::try_from(output.status)?));
871 }
872 Ok(Ok(output.stdout))
873 }
874
875 #[instrument]
876 fn write_body_file(&self, body: &str) -> eyre::Result<NamedTempFile> {
877 use std::io::Write;
878 let mut body_file = NamedTempFile::new()?;
879 body_file.write_all(body.as_bytes())?;
880 body_file.flush()?;
881 Ok(body_file)
882 }
883 }
884
885 impl GithubClient for RealGithubClient {
886 #[instrument]
888 fn query_github_username(&self, effects: &Effects) -> EyreExitOr<String> {
889 let username =
890 try_exit_code!(self.run_gh(effects, &["api", "user", "--jq", ".login"])?);
891 let username = String::from_utf8(username)?;
892 let username = username.trim().to_owned();
893 Ok(Ok(username))
894 }
895
896 #[instrument]
900 fn query_repo_pull_request_infos(
901 &self,
902 effects: &Effects,
903 ) -> EyreExitOr<HashMap<String, PullRequestInfo>> {
904 let output = try_exit_code!(self.run_gh(
905 effects,
906 &[
907 "pr",
908 "list",
909 "--author",
910 "@me",
911 "--json",
912 "number,url,headRefName,headRefOid,baseRefName,closed,isDraft,title,body",
913 ]
914 )?);
915 let pull_request_infos: Vec<PullRequestInfo> =
916 serde_json::from_slice(&output).wrap_err("Deserializing output from gh pr list")?;
917 let pull_request_infos = pull_request_infos
918 .into_iter()
919 .map(|item| (item.head_ref_name.clone(), item))
920 .collect();
921 Ok(Ok(pull_request_infos))
922 }
923
924 #[instrument]
925 fn create_pull_request(
926 &self,
927 effects: &Effects,
928 args: CreatePullRequestArgs,
929 submit_options: &SubmitOptions,
930 ) -> EyreExitOr<String> {
931 let CreatePullRequestArgs {
932 head_ref_oid: _,
933 head_ref_name,
934 title,
935 body,
936 } = args;
937 let body_file = self.write_body_file(&body)?;
938 let mut args = vec![
939 "pr",
940 "create",
941 "--head",
942 &head_ref_name,
943 "--title",
944 &title,
945 "--body-file",
946 body_file.path().to_str().unwrap(),
947 ];
948
949 let SubmitOptions {
950 create: _,
951 draft,
952 execution_strategy: _,
953 num_jobs: _,
954 message: _,
955 } = submit_options;
956 if *draft {
957 args.push("--draft");
958 }
959
960 let stdout = try_exit_code!(self.run_gh(effects, &args)?);
961 let pull_request_url = match std::str::from_utf8(&stdout) {
962 Ok(url) => url,
963 Err(err) => {
964 writeln!(
965 effects.get_output_stream(),
966 "Could not parse output from `gh pr create` as UTF-8: {err}",
967 )?;
968 return Ok(Err(ExitCode(1)));
969 }
970 };
971 let pull_request_url = pull_request_url.trim();
972 Ok(Ok(pull_request_url.to_owned()))
973 }
974
975 fn update_pull_request(
976 &self,
977 effects: &Effects,
978 number: usize,
979 args: UpdatePullRequestArgs,
980 _submit_options: &super::SubmitOptions,
981 ) -> EyreExitOr<()> {
982 let UpdatePullRequestArgs {
983 head_ref_oid: _, base_ref_name,
985 title,
986 body,
987 } = args;
988 let body_file = self.write_body_file(&body)?;
989 try_exit_code!(self.run_gh(
990 effects,
991 &[
992 "pr",
993 "edit",
994 &number.to_string(),
995 "--base",
996 &base_ref_name,
997 "--title",
998 &title,
999 "--body-file",
1000 (body_file.path().to_str().unwrap()),
1001 ],
1002 )?);
1003 Ok(Ok(()))
1004 }
1005 }
1006
1007 #[derive(Debug, Default, Deserialize, Serialize)]
1010 pub struct MockState {
1011 pub pull_request_index: usize,
1013
1014 pub pull_requests: BTreeMap<String, PullRequestInfo>,
1017 }
1018
1019 impl MockState {
1020 fn load(path: &Path) -> eyre::Result<Self> {
1021 let file = match File::open(path) {
1022 Ok(file) => file,
1023 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
1024 return Ok(Default::default());
1025 }
1026 Err(err) => return Err(err).wrap_err("Opening mock GitHub client state file"),
1027 };
1028 let state = serde_json::from_reader(file)?;
1029 Ok(state)
1030 }
1031
1032 fn restore_invariants(&mut self, remote_repo: &Repo) -> eyre::Result<()> {
1033 let effects = Effects::new_suppress_for_test(Glyphs::text());
1034 let conn = remote_repo.get_db_conn()?;
1035 let event_log_db = EventLogDb::new(&conn)?;
1036 let event_replayer =
1037 EventReplayer::from_event_log_db(&effects, remote_repo, &event_log_db)?;
1038 let event_cursor = event_replayer.make_default_cursor();
1039 let references_snapshot = remote_repo.get_references_snapshot()?;
1040 let dag = Dag::open_and_sync(
1041 &effects,
1042 remote_repo,
1043 &event_replayer,
1044 event_cursor,
1045 &references_snapshot,
1046 )?;
1047
1048 let branches: HashMap<String, NonZeroOid> = remote_repo
1049 .get_all_local_branches()?
1050 .into_iter()
1051 .map(|branch| -> eyre::Result<_> {
1052 let branch_name = branch.get_name()?.to_owned();
1053 let branch_oid = branch.get_oid()?.unwrap();
1054 Ok((branch_name, branch_oid))
1055 })
1056 .try_collect()?;
1057 for (_, pull_request_info) in self.pull_requests.iter_mut() {
1058 let base_ref_name = &pull_request_info.base_ref_name;
1059 let base_branch_oid = match branches.get(base_ref_name) {
1060 Some(oid) => *oid,
1061 None => {
1062 eyre::bail!("Could not find base branch {base_ref_name:?} for pull request: {pull_request_info:?}");
1063 }
1064 };
1065 let SerializedNonZeroOid(head_ref_oid) = pull_request_info.head_ref_oid;
1066 if dag.query_is_ancestor(head_ref_oid, base_branch_oid)? {
1067 pull_request_info.closed = true;
1068 }
1069 }
1070 Ok(())
1071 }
1072
1073 fn save(&self, path: &Path) -> eyre::Result<()> {
1074 let state = serde_json::to_string_pretty(self)?;
1075 fs::write(path, state)?;
1076 Ok(())
1077 }
1078 }
1079
1080 #[derive(Debug)]
1082 pub struct MockGithubClient {
1083 pub remote_repo_path: PathBuf,
1085 }
1086
1087 impl GithubClient for MockGithubClient {
1088 fn query_github_username(&self, _effects: &Effects) -> EyreExitOr<String> {
1089 Ok(Ok(Self::username().to_owned()))
1090 }
1091
1092 fn query_repo_pull_request_infos(
1093 &self,
1094 _effects: &Effects,
1095 ) -> EyreExitOr<HashMap<String, PullRequestInfo>> {
1096 let pull_requests_infos = self.with_state_mut(|state| {
1097 let pull_request_infos = state
1098 .pull_requests
1099 .values()
1100 .cloned()
1101 .map(|pull_request_info| {
1102 (pull_request_info.head_ref_name.clone(), pull_request_info)
1103 })
1104 .collect();
1105 Ok(pull_request_infos)
1106 })?;
1107 Ok(Ok(pull_requests_infos))
1108 }
1109
1110 fn create_pull_request(
1111 &self,
1112 _effects: &Effects,
1113 args: CreatePullRequestArgs,
1114 submit_options: &super::SubmitOptions,
1115 ) -> EyreExitOr<String> {
1116 let url = self.with_state_mut(|state| {
1117 state.pull_request_index += 1;
1118 let CreatePullRequestArgs {
1119 head_ref_oid,
1120 head_ref_name,
1121 title,
1122 body,
1123 } = args;
1124 let SubmitOptions {
1125 create,
1126 draft,
1127 execution_strategy: _,
1128 num_jobs: _,
1129 message: _,
1130 } = submit_options;
1131 assert!(create);
1132 let url = format!(
1133 "https://example.com/{}/{}/pulls/{}",
1134 Self::username(),
1135 Self::repo_name(),
1136 state.pull_request_index
1137 );
1138 let pull_request_info = PullRequestInfo {
1139 number: state.pull_request_index,
1140 url: url.clone(),
1141 head_ref_name: head_ref_name.clone(),
1142 head_ref_oid: SerializedNonZeroOid(head_ref_oid),
1143 base_ref_name: Self::main_branch().to_owned(),
1144 closed: false,
1145 is_draft: *draft,
1146 title,
1147 body,
1148 };
1149 state.pull_requests.insert(head_ref_name, pull_request_info);
1150 Ok(url)
1151 })?;
1152 Ok(Ok(url))
1153 }
1154
1155 fn update_pull_request(
1156 &self,
1157 _effects: &Effects,
1158 number: usize,
1159 args: UpdatePullRequestArgs,
1160 _submit_options: &super::SubmitOptions,
1161 ) -> EyreExitOr<()> {
1162 self.with_state_mut(|state| -> eyre::Result<()> {
1163 let UpdatePullRequestArgs {
1164 head_ref_oid,
1165 base_ref_name,
1166 title,
1167 body,
1168 } = args;
1169 let pull_request_info = match state
1170 .pull_requests
1171 .values_mut()
1172 .find(|pull_request_info| pull_request_info.number == number)
1173 {
1174 Some(pull_request_info) => pull_request_info,
1175 None => {
1176 eyre::bail!("Could not find pull request with number {number}");
1177 }
1178 };
1179 pull_request_info.head_ref_oid = SerializedNonZeroOid(head_ref_oid);
1180 pull_request_info.base_ref_name = base_ref_name;
1181 pull_request_info.title = title;
1182 pull_request_info.body = body;
1183 Ok(())
1184 })?;
1185 Ok(Ok(()))
1186 }
1187 }
1188
1189 impl MockGithubClient {
1190 fn username() -> &'static str {
1191 "mock-github-username"
1192 }
1193
1194 fn repo_name() -> &'static str {
1195 "mock-github-repo"
1196 }
1197
1198 fn main_branch() -> &'static str {
1199 "master"
1200 }
1201
1202 pub fn state_path(&self) -> PathBuf {
1204 self.remote_repo_path.join("mock-github-client-state.json")
1205 }
1206
1207 pub fn with_state_mut<T>(
1211 &self,
1212 f: impl FnOnce(&mut MockState) -> eyre::Result<T>,
1213 ) -> eyre::Result<T> {
1214 let repo = Repo::from_dir(&self.remote_repo_path)?;
1215 let state_path = self.state_path();
1216 let mut state = MockState::load(&state_path)?;
1217 state.restore_invariants(&repo)?;
1218 let result = f(&mut state)?;
1219 state.restore_invariants(&repo)?;
1220 state.save(&state_path)?;
1221 Ok(result)
1222 }
1223 }
1224}
1225
1226pub mod testing {
1228 pub use super::client::MockGithubClient;
1229}
1230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_commit_summary_slug() {
        assert_eq!(commit_summary_slug("hello: foo bar"), "hello-foo-bar");
        assert_eq!(
            commit_summary_slug("category(topic): `foo` bar!"),
            "category-topic-foo-bar"
        );
        assert_eq!(commit_summary_slug("foo_~_bar"), "foo-bar");
        assert_eq!(commit_summary_slug("!!!"), "to-review");
    }

    #[test]
    fn test_commit_summary_slug_edge_cases() {
        // Uppercase letters are lowercased.
        assert_eq!(commit_summary_slug("Fix Bug"), "fix-bug");
        // Leading and trailing separators are trimmed after collapsing.
        assert_eq!(commit_summary_slug("  spaced out  "), "spaced-out");
        // An empty summary falls back to the placeholder.
        assert_eq!(commit_summary_slug(""), "to-review");
    }
}