// gnostr_asyncgit/sync/remotes/callbacks.rs

1use std::sync::{
2    atomic::{AtomicBool, Ordering},
3    Arc, Mutex,
4};
5
6use crossbeam_channel::Sender;
7use git2::{Cred, Error as GitError, RemoteCallbacks};
8
9use super::push::ProgressNotification;
10use crate::{error::Result, sync::cred::BasicAuthCredential};
11
/// Information collected by the remote callbacks while a remote operation
/// (for example a push) is running.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct CallbackStats {
    /// `(reference, message)` reported through libgit2's
    /// `push_update_reference` callback. A `Some` value means the remote
    /// reported a status message for that reference, i.e. rejected it.
    pub push_rejected_msg: Option<(String, String)>,
}
17
18///
19#[derive(Clone)]
20pub struct Callbacks {
21    sender: Option<Sender<ProgressNotification>>,
22    basic_credential: Option<BasicAuthCredential>,
23    stats: Arc<Mutex<CallbackStats>>,
24    first_call_to_credentials: Arc<AtomicBool>,
25}
26
27impl Callbacks {
28    ///
29    pub fn new(
30        sender: Option<Sender<ProgressNotification>>,
31        basic_credential: Option<BasicAuthCredential>,
32    ) -> Self {
33        let stats = Arc::new(Mutex::new(CallbackStats::default()));
34
35        Self {
36            sender,
37            basic_credential,
38            stats,
39            first_call_to_credentials: Arc::new(AtomicBool::new(true)),
40        }
41    }
42
43    ///
44    pub fn get_stats(&self) -> Result<CallbackStats> {
45        let stats = self.stats.lock()?;
46        Ok(stats.clone())
47    }
48
49    ///
50    pub fn callbacks<'a>(&self) -> RemoteCallbacks<'a> {
51        let mut callbacks = RemoteCallbacks::new();
52
53        let this = self.clone();
54        callbacks.push_transfer_progress(move |current, total, bytes| {
55            this.push_transfer_progress(current, total, bytes);
56        });
57
58        let this = self.clone();
59        callbacks.update_tips(move |name, a, b| {
60            this.update_tips(name, a, b);
61            true
62        });
63
64        let this = self.clone();
65        callbacks.transfer_progress(move |p| {
66            this.transfer_progress(&p);
67            true
68        });
69
70        let this = self.clone();
71        callbacks.pack_progress(move |stage, current, total| {
72            this.pack_progress(stage, total, current);
73        });
74
75        let this = self.clone();
76        callbacks.push_update_reference(move |reference, msg| {
77            this.push_update_reference(reference, msg);
78            Ok(())
79        });
80
81        let this = self.clone();
82        callbacks.credentials(move |url, username_from_url, allowed_types| {
83            this.credentials(url, username_from_url, allowed_types)
84        });
85
86        callbacks.sideband_progress(move |data| {
87            log::debug!(
88                "sideband transfer: '{}'",
89                String::from_utf8_lossy(data).trim()
90            );
91            true
92        });
93
94        callbacks
95    }
96
97    fn push_update_reference(&self, reference: &str, msg: Option<&str>) {
98        log::debug!("push_update_reference: '{}' {:?}", reference, msg);
99
100        if let Ok(mut stats) = self.stats.lock() {
101            stats.push_rejected_msg = msg.map(|msg| (reference.to_string(), msg.to_string()));
102        }
103    }
104
105    fn pack_progress(&self, stage: git2::PackBuilderStage, total: usize, current: usize) {
106        log::debug!("packing: {:?} - {}/{}", stage, current, total);
107        self.sender.clone().map(|sender| {
108            sender.send(ProgressNotification::Packing {
109                stage,
110                total,
111                current,
112            })
113        });
114    }
115
116    fn transfer_progress(&self, p: &git2::Progress) {
117        log::debug!("transfer: {}/{}", p.received_objects(), p.total_objects());
118        self.sender.clone().map(|sender| {
119            sender.send(ProgressNotification::Transfer {
120                objects: p.received_objects(),
121                total_objects: p.total_objects(),
122            })
123        });
124    }
125
126    fn update_tips(&self, name: &str, a: git2::Oid, b: git2::Oid) {
127        log::debug!("update tips: '{}' [{}] [{}]", name, a, b);
128        self.sender.clone().map(|sender| {
129            sender.send(ProgressNotification::UpdateTips {
130                name: name.to_string(),
131                a: a.into(),
132                b: b.into(),
133            })
134        });
135    }
136
137    fn push_transfer_progress(&self, current: usize, total: usize, bytes: usize) {
138        log::debug!("progress: {}/{} ({} B)", current, total, bytes,);
139        self.sender.clone().map(|sender| {
140            sender.send(ProgressNotification::PushTransfer {
141                current,
142                total,
143                bytes,
144            })
145        });
146    }
147
148    // If credentials are bad, we don't ask the user to re-fill their
149    // creds. We push an error and they will be able to restart their
150    // action (for example a push) and retype their creds. This behavior is explained in a issue on git2-rs project : https://github.com/rust-lang/git2-rs/issues/347
151    // An implementation reference is done in cargo : https://github.com/rust-lang/cargo/blob/9fb208dddb12a3081230a5fd8f470e01df8faa25/src/cargo/sources/git/utils.rs#L588
152    // There is also a guide about libgit2 authentication : https://libgit2.org/docs/guides/authentication/
153    fn credentials(
154        &self,
155        url: &str,
156        username_from_url: Option<&str>,
157        allowed_types: git2::CredentialType,
158    ) -> std::result::Result<Cred, GitError> {
159        log::debug!(
160            "creds: '{}' {:?} ({:?})",
161            url,
162            username_from_url,
163            allowed_types
164        );
165
166        // This boolean is used to avoid multiple calls to credentials
167        // callback.
168        if self.first_call_to_credentials.load(Ordering::Relaxed) {
169            self.first_call_to_credentials
170                .store(false, Ordering::Relaxed);
171        } else {
172            return Err(GitError::from_str("Bad credentials."));
173        }
174
175        match &self.basic_credential {
176            _ if allowed_types.is_ssh_key() => username_from_url.map_or_else(
177                || Err(GitError::from_str(" Couldn't extract username from url.")),
178                Cred::ssh_key_from_agent,
179            ),
180            Some(BasicAuthCredential {
181                username: Some(user),
182                password: Some(pwd),
183            }) if allowed_types.is_user_pass_plaintext() => Cred::userpass_plaintext(user, pwd),
184            Some(BasicAuthCredential {
185                username: Some(user),
186                password: _,
187            }) if allowed_types.is_username() => Cred::username(user),
188            _ if allowed_types.is_default() => Cred::default(),
189            _ => Err(GitError::from_str("Couldn't find credentials")),
190        }
191    }
192}