pijul_remote/
ssh.rs

1use std::collections::HashSet;
2use std::convert::TryInto;
3use std::io::Write;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime};
8
9use anyhow::bail;
10use byteorder::{BigEndian, ReadBytesExt};
11use lazy_static::lazy_static;
12use libpijul::pristine::Position;
13use libpijul::{Base32, Hash, Merkle};
14use log::{debug, error, info, trace, warn};
15use regex::Regex;
16use thrussh::client::Session;
17use tokio::sync::Mutex;
18
19use super::parse_line;
20use crate::CS;
21use pijul_interaction::ProgressBar;
22
/// An authenticated SSH connection to a pijul remote, through which protocol
/// requests (`state`, `id`, `challenge`, `archive`, …) are sent.
pub struct Ssh {
    /// Handle on the underlying SSH session.
    pub h: thrussh::client::Handle<SshClient>,
    /// Session channel opened by `Remote::connect`; requests are written here.
    pub c: thrussh::client::Channel,
    /// Name of the pijul channel that requests refer to.
    pub channel: String,
    /// Command used to start pijul on the remote (`$REMOTE_PIJUL`, or `pijul`).
    pub remote_cmd: String,
    /// Repository path on the remote, passed to `--repository`.
    pub path: String,
    /// Whether the remote `protocol` command has already been started on `c`.
    pub is_running: bool,
    /// The `name` given to `Remote::connect` (presumably the remote's display
    /// name — TODO confirm against callers).
    pub name: String,
    /// Shared with the `SshClient` handler: how incoming data is interpreted.
    state: Arc<Mutex<State>>,
    /// Shared with the `SshClient` handler: set when the remote writes to
    /// stderr or reports an exit status.
    has_errors: Arc<Mutex<bool>>,
}
34
lazy_static! {
    // SSH address with a repository path:
    //   [ssh://][user@]host  followed by one of
    //   `:port/path` (captured as `port` + `path0`),
    //   `:path`      (captured as `path1`),
    //   `/path`      (captured as `path2`).
    // `host` may be a bracketed literal such as `[::1]`; the brackets are part
    // of the `host` capture.
    // NOTE(review): the pattern is not anchored, so `captures` matches the
    // first occurrence anywhere in the input.
    static ref ADDRESS: Regex = Regex::new(
        r#"(ssh://)?((?P<user>[^@]+)@)?((?P<host>(\[([^\]]+)\])|([^:/]+)))((:(?P<port>\d+)(?P<path0>(/.+)))|(:(?P<path1>.+))|(?P<path2>(/.+)))"#
    )
        .unwrap();

    // Same host syntax as `ADDRESS` but without a path:
    //   [ssh://][user@]host[:port]
    static ref ADDRESS_NOPATH: Regex = Regex::new(
        r#"(ssh://)?((?P<user>[^@]+)@)?((?P<host>(\[([^\]]+)\])|([^:/]+)))(:(?P<port>\d+))?"#
    )
        .unwrap();
}
46
/// A parsed SSH remote: the repository path on the host plus the SSH
/// connection settings for that host.
#[derive(Debug)]
pub struct Remote<'a> {
    /// Repository path on the remote, borrowed from the parsed address; empty
    /// when the address was parsed without a path.
    path: &'a str,
    /// Settings from `thrussh_config::parse_home` (or the defaults for the
    /// host), with the port and user from the address/arguments applied.
    config: thrussh_config::Config,
}
52
53pub fn ssh_remote<'a>(user: Option<&str>, addr: &'a str, with_path: bool) -> Option<Remote<'a>> {
54    let cap = if with_path {
55        ADDRESS.captures(addr)?
56    } else {
57        ADDRESS_NOPATH.captures(addr)?
58    };
59    debug!("ssh_remote: {:?}", cap);
60    let host = cap.name("host").unwrap().as_str();
61
62    let mut config =
63        thrussh_config::parse_home(&host).unwrap_or(thrussh_config::Config::default(host));
64    if let Some(port) = cap.name("port").map(|x| x.as_str().parse().unwrap()) {
65        config.port = port
66    }
67    if let Some(u) = cap.name("user") {
68        config.user.clear();
69        config.user.push_str(u.as_str());
70    }
71    if let Some(user) = user {
72        if !user.is_empty() {
73            config.user.clear();
74            config.user.push_str(user)
75        }
76    }
77    let path = if with_path {
78        let p = cap
79            .name("path0")
80            .unwrap_or_else(|| {
81                cap.name("path1")
82                    .unwrap_or_else(|| cap.name("path2").unwrap())
83            })
84            .as_str();
85        if p.starts_with("/~") {
86            p.split_at(1).1
87        } else {
88            p
89        }
90    } else {
91        ""
92    };
93    Some(Remote { path, config })
94}
95
96impl<'a> Remote<'a> {
97    pub async fn connect(
98        &mut self,
99        name: &str,
100        channel: &str,
101    ) -> Result<Option<Ssh>, anyhow::Error> {
102        let mut home = dirs_next::home_dir().unwrap();
103        home.push(".ssh");
104        home.push("known_hosts");
105        let state = Arc::new(Mutex::new(State::None));
106        let has_errors = Arc::new(Mutex::new(false));
107        let client = SshClient {
108            addr: self.config.host_name.clone(),
109            port: self.config.port,
110            known_hosts: home,
111            last_window_adjustment: SystemTime::now(),
112            state: state.clone(),
113            has_errors: has_errors.clone(),
114        };
115        let stream = match self.config.stream().await {
116            Ok(stream) => stream,
117            Err(e) => {
118                info!("remote connect error: {:?}", e);
119                return Ok(None);
120            }
121        };
122        let config = Arc::new(thrussh::client::Config::default());
123        let mut h = thrussh::client::connect_stream(config, stream, client).await?;
124
125        let mut key_path = dirs_next::home_dir().unwrap().join(".ssh");
126
127        // First try agent auth
128        let authenticated = match self.auth_agent(&mut h, &mut key_path).await {
129            Ok(true) => true,
130            Ok(false) => {
131                if self.auth_pk(&mut h, &mut key_path).await? {
132                    true
133                } else {
134                    let mut stderr = std::io::stderr();
135                    writeln!(stderr, "Warning: Unable to automatically authenticate with server. Please make sure your SSH keys have been uploaded to the Nest.")?;
136                    writeln!(stderr, "For more information, please visit https://pijul.org/manual/the_nest/public_keys.html#ssh-public-keys")?;
137                    self.auth_password(&mut h).await?
138                }
139            }
140            Err(e) => return Err(e.into()),
141        };
142
143        if !authenticated {
144            bail!("Not authenticated. Please check your credentials and try again.");
145        }
146
147        let c = h.channel_open_session().await?;
148        let remote_cmd = if let Ok(cmd) = std::env::var("REMOTE_PIJUL") {
149            cmd
150        } else {
151            "pijul".to_string()
152        };
153        Ok(Some(Ssh {
154            h,
155            c,
156            channel: channel.to_string(),
157            remote_cmd,
158            path: self.path.to_string(),
159            is_running: false,
160            name: name.to_string(),
161            state,
162            has_errors,
163        }))
164    }
165
166    async fn auth_agent(
167        &self,
168        h: &mut thrussh::client::Handle<SshClient>,
169        key_path: &mut PathBuf,
170    ) -> Result<bool, thrussh::Error> {
171        let mut authenticated = false;
172        let mut agent = match thrussh_keys::agent::client::AgentClient::connect_env().await {
173            Ok(agent) => agent,
174            Err(thrussh_keys::Error::EnvVar(_)) => return Ok(false),
175            Err(thrussh_keys::Error::AgentFailure) => return Ok(false),
176            Err(e) => return Err(e.into()),
177        };
178        let identities = if let Some(ref file) = self.config.identity_file {
179            key_path.push(file);
180            key_path.set_extension("pub");
181            let k = thrussh_keys::load_public_key(&key_path);
182            key_path.pop();
183            if let Ok(k) = k {
184                vec![k]
185            } else {
186                return Ok(false);
187            }
188        } else {
189            agent.request_identities().await?
190        };
191        debug!("identities = {:?}", identities);
192        let mut agent = Some(agent);
193        for key in identities {
194            debug!("Trying key {:?}", key);
195            debug!("fingerprint = {:?}", key.fingerprint());
196            if let Some(a) = agent.take() {
197                debug!("authenticate future");
198                match h.authenticate_future(&self.config.user, key, a).await {
199                    (a, Ok(auth)) => {
200                        authenticated = auth;
201                        agent = Some(a);
202                    }
203                    (_, Err(thrussh::AgentAuthError::Send(e))) => {
204                        debug!("send error {:?}", e);
205                        return Err(thrussh::Error::SendError);
206                    }
207                    (a, Err(e)) => {
208                        agent = Some(a);
209                        debug!("not auth {:?}", e);
210                        if let thrussh::AgentAuthError::Key(e) = e {
211                            debug!("error: {:?}", e);
212                            writeln!(std::io::stderr(), "Failed to sign with agent")?;
213                        }
214                    }
215                }
216            }
217            if authenticated {
218                return Ok(true);
219            }
220        }
221        Ok(false)
222    }
223
224    async fn auth_pk(
225        &self,
226        h: &mut thrussh::client::Handle<SshClient>,
227        key_path: &mut PathBuf,
228    ) -> Result<bool, anyhow::Error> {
229        if h.is_closed() {
230            return Ok(false);
231        }
232        let mut authenticated = false;
233        let mut keys = Vec::new();
234        if let Some(ref file) = self.config.identity_file {
235            keys.push(file.as_str())
236        } else {
237            keys.push("id_ed25519");
238            keys.push("id_rsa");
239        }
240        for k in keys.iter() {
241            key_path.push(k);
242            let k = match thrussh_keys::load_secret_key(&key_path, None) {
243                Ok(k) => k,
244                Err(thrussh_keys::Error::KeyIsEncrypted) => {
245                    let password = pijul_interaction::Password::new()?
246                        .with_prompt(format!("Password for encrypted private key"))
247                        .with_allow_empty(false)
248                        .interact()?;
249                    if let Ok(k) = thrussh_keys::load_secret_key(&key_path, Some(&password)) {
250                        k
251                    } else {
252                        continue;
253                    }
254                }
255                Err(_) => {
256                    key_path.pop();
257                    continue;
258                }
259            };
260            if let Ok(auth) = h
261                .authenticate_publickey(&self.config.user, Arc::new(k))
262                .await
263            {
264                authenticated = auth
265            }
266            key_path.pop();
267            if authenticated {
268                return Ok(true);
269            }
270        }
271
272        Ok(false)
273    }
274
275    async fn auth_password(
276        &self,
277        h: &mut thrussh::client::Handle<SshClient>,
278    ) -> Result<bool, anyhow::Error> {
279        if h.is_closed() {
280            return Ok(false);
281        }
282
283        // Authentication can be attempted multiple times
284        let mut authenticated = false;
285        let username = format!("{}@{}", self.config.user, self.config.host_name);
286
287        // Try authenticate using the user's keyring
288        if let Ok(password) = keyring::Entry::new("pijul", &username).and_then(|x| x.get_password())
289        {
290            authenticated = h
291                .authenticate_password(self.config.user.to_string(), &password)
292                .await?;
293        }
294
295        // Try authenticate using user's password
296        if !authenticated {
297            let password = pijul_interaction::Password::new()?
298                .with_prompt(format!("Password for {username}"))
299                .with_allow_empty(true)
300                .interact()?;
301
302            authenticated = h
303                .authenticate_password(self.config.user.to_string(), &password)
304                .await?;
305
306            // If the new password is valid, update the keyring to match
307            if authenticated {
308                if let Err(e) =
309                    keyring::Entry::new("pijul", &username).and_then(|x| x.set_password(&password))
310                {
311                    warn!("Unable to set password: {e:?}");
312                }
313            }
314        }
315
316        Ok(authenticated)
317    }
318}
319
/// `thrussh` client handler for a pijul SSH session: verifies the server key
/// and routes channel data according to the shared [`State`].
pub struct SshClient {
    /// Host name checked against (and learned into) the known hosts file.
    addr: String,
    /// Port used together with `addr` for the host-key lookup.
    port: u16,
    /// Path of the known hosts file consulted by `check_server_key`.
    known_hosts: PathBuf,
    /// When `adjust_window` last ran; drives its window-size heuristic.
    last_window_adjustment: SystemTime,
    /// Shared with `Ssh`: which request the next incoming data answers.
    state: Arc<Mutex<State>>,
    /// Shared with `Ssh`: set when the remote writes extended (non-zero type)
    /// data or reports an exit status.
    has_errors: Arc<Mutex<bool>>,
}
328
/// What the `SshClient::data` handler should do with the next bytes received
/// on the channel.
///
/// `Ssh` request methods such as `get_state` install a variant before sending
/// their command. `channel_eof` and `exit_status` reset it to `None`, which
/// drops any pending sender.
enum State {
    /// No request in flight: incoming data is ignored.
    None,
    /// Awaiting the reply to `state`: `<n> <merkle> <merkle>`, or `None` if
    /// the reply has fewer than three space-separated fields.
    State {
        sender: Option<tokio::sync::oneshot::Sender<Option<(u64, Merkle, Merkle)>>>,
    },
    /// Awaiting the reply to `id`: a base32 remote id followed by `\n`.
    Id {
        sender: Option<tokio::sync::oneshot::Sender<Option<libpijul::pristine::RemoteId>>>,
    },
    /// Downloading change/tag files, each framed as an 8-byte big-endian
    /// length followed by that many bytes.
    Changes {
        /// Receives each hash once its file has been moved into place.
        sender: Option<tokio::sync::mpsc::Sender<CS>>,
        /// Bytes still expected for the current file (0 = a length header is
        /// expected next).
        remaining_len: usize,
        /// Temporary file the current download is written to.
        file: std::fs::File,
        /// Path of `file`; it is renamed into place once complete.
        path: PathBuf,
        /// Base directory onto which the change/tag filename is pushed (and
        /// then popped) to compute the rename target.
        final_path: PathBuf,
        /// Hashes expected, in the order the remote sends them.
        hashes: Vec<CS>,
        /// Index into `hashes` of the file currently being received.
        current: usize,
    },
    /// Streaming a newline-separated change list.
    Changelist {
        /// Receives each parsed line; `None` for empty/unparseable lines and
        /// for the final "\n" packet.
        sender: tokio::sync::mpsc::Sender<Option<super::ListLine>>,
        /// Partial line carried over between packets.
        pending: Vec<u8>,
    },
    /// Downloading an archive: a 16-byte header (length, conflict count, both
    /// big-endian u64), then the archive bytes.
    Archive {
        /// Receives the conflict count once the whole archive is written.
        sender: Option<tokio::sync::oneshot::Sender<u64>>,
        /// Archive bytes still expected (decoded from the header).
        len: u64,
        /// Conflict count decoded from the header.
        conflicts: u64,
        /// Number of header bytes consumed so far (0..=16).
        len_n: u64,
        /// Destination of the archive bytes.
        w: Box<dyn Write + Send>,
    },
    /// Answering a `challenge`: the remote's challenge is signed with `key`.
    Prove {
        key: libpijul::key::SKey,
        /// Notified once the signed answer has been sent.
        sender: Option<tokio::sync::oneshot::Sender<()>>,
        /// Whether the challenge has been answered already.
        signed: bool,
    },
    /// Streaming identities as one JSON document per line.
    Identities {
        sender: Option<tokio::sync::mpsc::Sender<pijul_identity::Complete>>,
        /// Bytes accumulated until a packet ends with a newline.
        buf: Vec<u8>,
    },
}
367
/// Heap-allocated, pinned, `Send` future yielding `T`, used for the handler's
/// asynchronous callbacks. `futures::future::Future` is a re-export of this
/// same standard trait.
type BoxFuture<T> = Pin<Box<dyn std::future::Future<Output = T> + Send>>;
369
370impl thrussh::client::Handler for SshClient {
371    type Error = anyhow::Error;
372    type FutureBool = futures::future::Ready<Result<(Self, bool), anyhow::Error>>;
373    type FutureUnit = BoxFuture<Result<(Self, Session), anyhow::Error>>;
374
375    fn finished_bool(self, b: bool) -> Self::FutureBool {
376        futures::future::ready(Ok((self, b)))
377    }
378    fn finished(self, session: Session) -> Self::FutureUnit {
379        Box::pin(async move { Ok((self, session)) })
380    }
381    fn check_server_key(
382        self,
383        server_public_key: &thrussh_keys::key::PublicKey,
384    ) -> Self::FutureBool {
385        debug!("addr = {:?} port = {:?}", self.addr, self.port);
386        match thrussh_keys::check_known_hosts_path(
387            &self.addr,
388            self.port,
389            server_public_key,
390            &self.known_hosts,
391        ) {
392            Ok(e) => {
393                if e {
394                    futures::future::ready(Ok((self, true)))
395                } else {
396                    match learn(&self.addr, self.port, server_public_key) {
397                        Ok(x) => futures::future::ready(Ok((self, x))),
398                        Err(e) => futures::future::ready(Err(e)),
399                    }
400                }
401            }
402            Err(e) => {
403                writeln!(std::io::stderr(), "Key changed for {:?}", self.addr).unwrap_or(());
404
405                futures::future::ready(Err(e.into()))
406            }
407        }
408    }
409
410    fn adjust_window(&mut self, _channel: thrussh::ChannelId, target: u32) -> u32 {
411        let elapsed = self.last_window_adjustment.elapsed().unwrap();
412        self.last_window_adjustment = SystemTime::now();
413        if target >= 10_000_000 {
414            return target;
415        }
416        if elapsed < Duration::from_secs(2) {
417            target * 2
418        } else if elapsed > Duration::from_secs(8) {
419            target / 2
420        } else {
421            target
422        }
423    }
424
425    fn channel_eof(
426        self,
427        _channel: thrussh::ChannelId,
428        session: thrussh::client::Session,
429    ) -> Self::FutureUnit {
430        Box::pin(async move {
431            *self.state.lock().await = State::None;
432            Ok((self, session))
433        })
434    }
435
436    fn exit_status(
437        self,
438        channel: thrussh::ChannelId,
439        exit_status: u32,
440        session: thrussh::client::Session,
441    ) -> Self::FutureUnit {
442        session.send_channel_msg(channel, thrussh::ChannelMsg::ExitStatus { exit_status });
443        Box::pin(async move {
444            *self.state.lock().await = State::None;
445            *self.has_errors.lock().await = true;
446            Ok((self, session))
447        })
448    }
449
450    fn extended_data(
451        self,
452        channel: thrussh::ChannelId,
453        ext: u32,
454        data: &[u8],
455        session: thrussh::client::Session,
456    ) -> Self::FutureUnit {
457        debug!("extended data {:?}, {:?}", std::str::from_utf8(data), ext);
458        if ext == 0 {
459            self.data(channel, data, session)
460        } else {
461            let data = data.to_vec();
462            Box::pin(async move {
463                *self.has_errors.lock().await = true;
464                let stderr = std::io::stderr();
465                let mut handle = stderr.lock();
466                handle.write_all(&data)?;
467                Ok((self, session))
468            })
469        }
470    }
471
    /// Handles bytes the remote wrote to the channel, interpreting them
    /// according to the request currently recorded in `self.state`.
    ///
    /// The state lock is held for the whole dispatch, including the `.await`s
    /// on channel sends. Decoded results are forwarded through the senders
    /// stored in the state. The future resolves to `(self, session)` unless
    /// an I/O, UTF-8, signing or send error is propagated with `?`.
    fn data(
        self,
        channel: thrussh::ChannelId,
        data: &[u8],
        mut session: thrussh::client::Session,
    ) -> Self::FutureUnit {
        trace!("data {:?} {:?}", channel, data.len());
        // Copy the borrowed buffer so it can move into the returned future.
        let data = data.to_vec();
        Box::pin(async move {
            match *self.state.lock().await {
                // Reply to `state`: "<n> <merkle> <merkle>".
                State::State { ref mut sender } => {
                    debug!("state: State");
                    if let Some(sender) = sender.take() {
                        // If we can't parse `data` (for example if the
                        // remote returns the standard "-\n"), this
                        // returns None.
                        // NOTE(review): only "fewer than three fields" yields
                        // None. Non-UTF-8 data, a non-numeric first field or
                        // invalid base32 make the `unwrap`s below panic.
                        let mut s = std::str::from_utf8(&data).unwrap().split(' ');
                        if let (Some(n), Some(m), Some(m2)) = (s.next(), s.next(), s.next()) {
                            let n = n.parse().unwrap();
                            sender
                                .send(Some((
                                    n,
                                    Merkle::from_base32(m.trim().as_bytes()).unwrap(),
                                    Merkle::from_base32(m2.trim().as_bytes()).unwrap(),
                                )))
                                .unwrap_or(());
                        } else {
                            sender.send(None).unwrap_or(());
                        }
                    }
                }
                // Reply to `id`: a base32 remote id terminated by b'\n' (10).
                State::Id { ref mut sender } => {
                    debug!("state: Id {:?}", std::str::from_utf8(&data));
                    if let Some(sender) = sender.take() {
                        // At least 16 bytes with a trailing newline are
                        // required; anything else is reported as `None`.
                        let line = if data.len() >= 16 && data.last() == Some(&10) {
                            libpijul::pristine::RemoteId::from_base32(&data[..data.len() - 1])
                        } else {
                            None
                        };
                        if let Some(b) = line {
                            sender.send(Some(b)).unwrap_or(());
                        } else {
                            sender.send(None).unwrap_or(());
                        }
                    }
                }
                // Download of change/tag files. Each file is framed as an
                // 8-byte big-endian length followed by that many bytes; one
                // packet may hold several frames or only part of one.
                State::Changes {
                    ref mut sender,
                    ref mut remaining_len,
                    ref mut file,
                    ref mut path,
                    ref mut final_path,
                    ref hashes,
                    ref mut current,
                } => {
                    trace!("state changes");
                    let mut p = 0;
                    while p < data.len() {
                        if *remaining_len == 0 {
                            // Start of a new frame: read its length header.
                            // NOTE(review): assumes all 8 header bytes are in
                            // this packet. A header split across packets makes
                            // `read_u64` fail and the `unwrap` panic.
                            *remaining_len = (&data[p..]).read_u64::<BigEndian>().unwrap() as usize;
                            p += 8;
                            debug!("remaining_len = {:?}", remaining_len);
                        }
                        if data.len() >= p + *remaining_len {
                            debug!("writing {:?} bytes", *remaining_len);
                            file.write_all(&data[p..p + *remaining_len])?;
                            // We have enough data to write the
                            // file, write it and move to the next
                            // file.
                            p += *remaining_len;
                            *remaining_len = 0;
                            file.flush()?;

                            // Move the temp file into place, then restore
                            // `final_path` to its base directory before
                            // propagating a rename error.
                            match hashes[*current] {
                                CS::Change(ref h) => {
                                    libpijul::changestore::filesystem::push_filename(final_path, h);
                                    debug!("moving {:?} to {:?}", path, final_path);
                                    std::fs::create_dir_all(&final_path.parent().unwrap())?;
                                    let r = std::fs::rename(&path, &final_path);
                                    libpijul::changestore::filesystem::pop_filename(final_path);
                                    r?;
                                }
                                CS::State(h) => {
                                    libpijul::changestore::filesystem::push_tag_filename(
                                        final_path, &h,
                                    );
                                    debug!("moving {:?} to {:?}", path, final_path);
                                    std::fs::create_dir_all(&final_path.parent().unwrap())?;
                                    let r = std::fs::rename(&path, &final_path);
                                    libpijul::changestore::filesystem::pop_filename(final_path);
                                    r?;
                                }
                            }
                            debug!("sending {:?}", hashes[*current]);
                            // Notify the receiver; stop if it has gone away.
                            if let Some(ref mut sender) = sender {
                                if sender.send(hashes[*current]).await.is_err() {
                                    break;
                                }
                            }
                            debug!("sent");
                            *current += 1;
                            if *current < hashes.len() {
                                // If we're still waiting for another
                                // change.
                                // The same temporary path is reused.
                                *file = std::fs::File::create(&path)?;
                            } else {
                                // Else, just finish.
                                debug!("dropping channel");
                                std::mem::drop(sender.take());
                                break;
                            }
                        } else {
                            // not enough data, we need more.
                            trace!(
                                "writing to {:?} {:?} {:?}",
                                path,
                                final_path,
                                hashes[*current]
                            );

                            file.write_all(&data[p..])?;
                            file.flush()?;
                            *remaining_len -= data.len() - p;
                            trace!("need more data");
                            break;
                        }
                    }
                    trace!("finished, {:?} {:?}", p, data.len());
                }
                // Streaming change list: newline-separated lines parsed with
                // `parse_line`.
                State::Changelist {
                    ref mut sender,
                    ref mut pending,
                } => {
                    debug!("state changelist");
                    // A packet consisting only of "\n" marks the end of the list.
                    if &data[..] == b"\n" {
                        debug!("log done");
                        sender.send(None).await.unwrap_or(())
                    } else {
                        trace!("{:?}", data);
                        let mut p = 0;
                        while let Some(i) = (&data[p..]).iter().position(|i| *i == b'\n') {
                            // Complete the line carried over from the
                            // previous packet, if there is one.
                            let line = if !pending.is_empty() {
                                pending.extend(&data[p..p + i]);
                                &pending
                            } else {
                                &data[p..p + i]
                            };
                            let l = std::str::from_utf8(line)?;
                            if !l.is_empty() {
                                debug!("line = {:?}", l);
                                // Unparseable lines are also sent as `None`.
                                sender.send(parse_line(l).ok()).await.unwrap_or(())
                            } else {
                                sender.send(None).await.unwrap_or(());
                            }
                            pending.clear();
                            p += i + 1;
                        }
                        // Keep the trailing partial line for the next packet.
                        pending.extend(&data[p..]);
                    }
                }
                // Archive download: a 16-byte header (8-byte big-endian length,
                // then 8-byte big-endian conflict count) followed by the
                // archive bytes, which are streamed into `w`.
                State::Archive {
                    ref mut sender,
                    ref mut w,
                    ref mut len,
                    ref mut len_n,
                    ref mut conflicts,
                } => {
                    debug!("state archive");
                    let mut off = 0;
                    // Consume header bytes; the header may span several packets.
                    while *len_n < 16 && off < data.len() {
                        if *len_n < 8 {
                            *len = (*len << 8) | (data[off] as u64);
                        } else {
                            *conflicts = (*conflicts << 8) | (data[off] as u64);
                        }
                        *len_n += 1;
                        off += 1;
                    }
                    if *len_n >= 16 {
                        w.write_all(&data[off..])?;
                        // NOTE(review): underflows if the remote sends more
                        // bytes than the header announced.
                        *len -= (data.len() - off) as u64;
                        if *len == 0 {
                            // Whole archive received: report the conflict count.
                            if let Some(sender) = sender.take() {
                                sender.send(*conflicts).unwrap_or(())
                            }
                        }
                    }
                }
                // Challenge/response: sign the challenge text and reply with
                // "prove <signature>\n"; later non-blank output goes to stderr.
                // Non-UTF-8 data is ignored.
                State::Prove {
                    ref mut key,
                    ref mut sender,
                    ref mut signed,
                } => {
                    if let Ok(data) = std::str::from_utf8(&data) {
                        if *signed && !data.trim().is_empty() {
                            std::io::stderr().write_all(data.as_bytes())?;
                        } else {
                            // NOTE(review): this branch also runs when already
                            // signed and the data is blank, which signs the
                            // empty string and sends another `prove` line.
                            let data = data.trim();
                            debug!("signing {:?}", data);
                            let s = key.sign_raw(data.as_bytes())?;
                            session.data(
                                channel,
                                thrussh::CryptoVec::from_slice(format!("prove {}\n", s).as_bytes()),
                            );
                            if let Some(sender) = sender.take() {
                                sender.send(()).unwrap_or(());
                            }
                            *signed = true;
                        }
                    }
                }
                // Identities: one JSON document per line.
                State::Identities {
                    ref mut sender,
                    ref mut buf,
                } => {
                    debug!("data = {:?}", data);
                    // Only parse once a packet ends with '\n' (10); until then
                    // accumulate the bytes in `buf`.
                    if data.ends_with(&[10]) {
                        let buf_ = if buf.is_empty() {
                            &data
                        } else {
                            buf.extend(&data);
                            &buf
                        };
                        for data in buf_.split(|c| *c == 10) {
                            if let Ok(p) = serde_json::from_slice(data) {
                                debug!("p = {:?}", p);
                                if let Some(ref mut sender) = sender {
                                    sender.send(p).await?;
                                }
                            } else {
                                // NOTE(review): the slice after the final '\n'
                                // is empty and fails to parse, so this branch
                                // runs at the end of every batch and drops the
                                // sender; identities in later packets are then
                                // discarded. Confirm this is the intended
                                // end-of-list signal.
                                debug!("could not parse {:?}", std::str::from_utf8(&data));
                                *sender = None;
                                break;
                            }
                        }
                        buf.clear()
                    } else {
                        buf.extend(&data);
                    }
                }
                // No request in flight: the data is dropped.
                State::None => {
                    debug!("None state");
                }
            }
            Ok((self, session))
        })
    }
719}
720
721fn learn(addr: &str, port: u16, pk: &thrussh_keys::key::PublicKey) -> Result<bool, anyhow::Error> {
722    if port == 22 {
723        print!(
724            "Unknown key for {:?}, fingerprint {:?}. Learn it (y/N)? ",
725            addr,
726            pk.fingerprint()
727        );
728    } else {
729        print!(
730            "Unknown key for {:?}:{}, fingerprint {:?}. Learn it (y/N)? ",
731            addr,
732            port,
733            pk.fingerprint()
734        );
735    }
736    std::io::stdout().flush()?;
737    let mut buffer = String::new();
738    std::io::stdin().read_line(&mut buffer)?;
739    let buffer = buffer.trim();
740    if buffer == "Y" || buffer == "y" {
741        thrussh_keys::learn_known_hosts(addr, port, pk)?;
742        Ok(true)
743    } else {
744        Ok(false)
745    }
746}
747
748impl Ssh {
749    pub async fn finish(&mut self) -> Result<(), anyhow::Error> {
750        self.c.eof().await?;
751        while let Some(msg) = self.c.wait().await {
752            debug!("msg = {:?}", msg);
753            match msg {
754                thrussh::ChannelMsg::WindowAdjusted { .. } => {}
755                thrussh::ChannelMsg::Eof => {}
756                thrussh::ChannelMsg::ExitStatus { exit_status } => {
757                    if exit_status != 0 {
758                        bail!("Remote exited with status {:?}", exit_status)
759                    }
760                }
761                msg => error!("wrong message {:?}", msg),
762            }
763        }
764        Ok(())
765    }
766
767    pub async fn get_state(
768        &mut self,
769        mid: Option<u64>,
770    ) -> Result<Option<(u64, Merkle, Merkle)>, anyhow::Error> {
771        debug!("get_state");
772        let (sender, receiver) = tokio::sync::oneshot::channel();
773        *self.state.lock().await = State::State {
774            sender: Some(sender),
775        };
776        self.run_protocol().await?;
777        if let Some(mid) = mid {
778            self.c
779                .data(format!("state {} {}\n", self.channel, mid).as_bytes())
780                .await?;
781        } else {
782            self.c
783                .data(format!("state {}\n", self.channel).as_bytes())
784                .await?;
785        }
786        Ok(receiver.await?)
787    }
788
789    pub async fn get_id(&mut self) -> Result<Option<libpijul::pristine::RemoteId>, anyhow::Error> {
790        let (sender, receiver) = tokio::sync::oneshot::channel();
791        *self.state.lock().await = State::Id {
792            sender: Some(sender),
793        };
794        self.run_protocol().await?;
795        self.c
796            .data(format!("id {}\n", self.channel).as_bytes())
797            .await?;
798        Ok(receiver.await?)
799    }
800
801    pub async fn prove(&mut self, key: libpijul::key::SKey) -> Result<(), anyhow::Error> {
802        debug!("get_state");
803        let (sender, receiver) = tokio::sync::oneshot::channel();
804        let k = serde_json::to_string(&key.public_key())?;
805        *self.state.lock().await = State::Prove {
806            key,
807            sender: Some(sender),
808            signed: false,
809        };
810        self.run_protocol().await?;
811        self.c.data(format!("challenge {}\n", k).as_bytes()).await?;
812        Ok(receiver.await?)
813    }
814
815    pub async fn archive<W: std::io::Write + Send + 'static>(
816        &mut self,
817        prefix: Option<String>,
818        state: Option<(Merkle, &[Hash])>,
819        w: W,
820    ) -> Result<u64, anyhow::Error> {
821        debug!("archive");
822        let (sender, receiver) = tokio::sync::oneshot::channel();
823        *self.state.lock().await = State::Archive {
824            sender: Some(sender),
825            len: 0,
826            conflicts: 0,
827            len_n: 0,
828            w: Box::new(w),
829        };
830        self.run_protocol().await?;
831        if let Some((ref state, ref extra)) = state {
832            let mut cmd = format!("archive {} {}", self.channel, state.to_base32(),);
833            for e in extra.iter() {
834                cmd.push_str(&format!(" {}", e.to_base32()));
835            }
836            if let Some(ref p) = prefix {
837                cmd.push_str(" :");
838                cmd.push_str(p)
839            }
840            cmd.push('\n');
841            self.c.data(cmd.as_bytes()).await?;
842        } else {
843            self.c
844                .data(
845                    format!(
846                        "archive {}{}{}\n",
847                        self.channel,
848                        if prefix.is_some() { " :" } else { "" },
849                        prefix.unwrap_or_else(String::new)
850                    )
851                    .as_bytes(),
852                )
853                .await?;
854        }
855        let conflicts = receiver.await.unwrap_or(0);
856        Ok(conflicts)
857    }
858
859    pub async fn run_protocol(&mut self) -> Result<(), anyhow::Error> {
860        if !self.is_running {
861            self.is_running = true;
862            debug!("run_protocol");
863            self.c
864                .exec(
865                    true,
866                    format!(
867                        "{} protocol --version {} --repository {}",
868                        self.remote_cmd,
869                        crate::PROTOCOL_VERSION,
870                        self.path
871                    ),
872                )
873                .await?;
874            debug!("waiting for a message");
875            while let Some(msg) = self.c.wait().await {
876                debug!("msg = {:?}", msg);
877                match msg {
878                    thrussh::ChannelMsg::Success => break,
879                    thrussh::ChannelMsg::WindowAdjusted { .. } => {}
880                    thrussh::ChannelMsg::Eof => {}
881                    thrussh::ChannelMsg::ExitStatus { exit_status } => {
882                        if exit_status != 0 {
883                            bail!("Remote exited with status {:?}", exit_status)
884                        }
885                    }
886                    _ => {}
887                }
888            }
889            debug!("run_protocol done");
890        }
891        Ok(())
892    }
893
894    pub async fn download_changelist<
895        A,
896        F: FnMut(&mut A, u64, Hash, libpijul::Merkle, bool) -> Result<(), anyhow::Error>,
897    >(
898        &mut self,
899        mut f: F,
900        a: &mut A,
901        from: u64,
902        paths: &[String],
903    ) -> Result<HashSet<Position<Hash>>, anyhow::Error> {
904        let (sender, mut receiver) = tokio::sync::mpsc::channel(10);
905        *self.state.lock().await = State::Changelist {
906            sender,
907            pending: Vec::new(),
908        };
909        self.run_protocol().await?;
910        debug!("download_changelist");
911        let mut command = Vec::new();
912        write!(command, "changelist {} {}", self.channel, from).unwrap();
913        for p in paths {
914            write!(command, " {:?}", p).unwrap()
915        }
916        command.push(b'\n');
917        self.c.data(&command[..]).await?;
918        debug!("waiting ssh, command: {:?}", std::str::from_utf8(&command));
919        let mut result = HashSet::new();
920        while let Some(Some(m)) = receiver.recv().await {
921            match m {
922                super::ListLine::Change { n, h, m, tag } => f(a, n, h, m, tag)?,
923                super::ListLine::Position(pos) => {
924                    result.insert(pos);
925                }
926                super::ListLine::Error(err) => {
927                    bail!(err)
928                }
929            }
930        }
931        if *self.has_errors.lock().await {
932            bail!("Remote sent an error")
933        }
934        debug!("no msg, result = {:?}", result);
935        Ok(result)
936    }
937
938    pub async fn upload_changes(
939        &mut self,
940        progress_bar: ProgressBar,
941        mut local: PathBuf,
942        to_channel: Option<&str>,
943        changes: &[CS],
944    ) -> Result<(), anyhow::Error> {
945        self.run_protocol().await?;
946        debug!("upload_changes");
947        for c in changes {
948            debug!("{:?}", c);
949            let to_channel = if let Some(t) = to_channel {
950                t
951            } else {
952                self.channel.as_str()
953            };
954            match c {
955                CS::Change(c) => {
956                    libpijul::changestore::filesystem::push_filename(&mut local, &c);
957                    let mut change_file = std::fs::File::open(&local)?;
958                    let change_len = change_file.metadata()?.len();
959                    let mut change = thrussh::CryptoVec::new_zeroed(change_len as usize);
960                    use std::io::Read;
961                    change_file.read_exact(&mut change[..])?;
962                    self.c
963                        .data(
964                            format!("apply {} {} {}\n", to_channel, c.to_base32(), change_len)
965                                .as_bytes(),
966                        )
967                        .await?;
968                    self.c.data(&change[..]).await?;
969                    libpijul::changestore::filesystem::pop_filename(&mut local);
970                }
971                CS::State(c) => {
972                    libpijul::changestore::filesystem::push_tag_filename(&mut local, &c);
973                    let mut tag_file = libpijul::tag::OpenTagFile::open(&local, &c)?;
974                    let mut v = Vec::new();
975                    tag_file.short(&mut v)?;
976                    self.c
977                        .data(
978                            format!("tagup {} {} {}\n", c.to_base32(), to_channel, v.len())
979                                .as_bytes(),
980                        )
981                        .await?;
982                    self.c.data(&v[..]).await?;
983                    libpijul::changestore::filesystem::pop_filename(&mut local);
984                }
985            }
986            progress_bar.inc(1);
987        }
988        Ok(())
989    }
990
991    pub async fn download_changes(
992        &mut self,
993        progress_bar: ProgressBar,
994        c: &mut tokio::sync::mpsc::UnboundedReceiver<CS>,
995        sender: &mut tokio::sync::mpsc::Sender<(CS, bool)>,
996        changes_dir: &mut PathBuf,
997        full: bool,
998    ) -> Result<(), anyhow::Error> {
999        self.download_changes_(progress_bar, c, Some(sender), changes_dir, full)
1000            .await
1001    }
1002
1003    async fn download_changes_(
1004        &mut self,
1005        progress_bar: ProgressBar,
1006        c: &mut tokio::sync::mpsc::UnboundedReceiver<CS>,
1007        sender: Option<&mut tokio::sync::mpsc::Sender<(CS, bool)>>,
1008        changes_dir: &mut PathBuf,
1009        full: bool,
1010    ) -> Result<(), anyhow::Error> {
1011        let (sender_, mut recv) = tokio::sync::mpsc::channel(100);
1012        let path = changes_dir.join("tmp");
1013        std::fs::create_dir_all(&changes_dir)?;
1014        let file = std::fs::File::create(&path)?;
1015        *self.state.lock().await = State::Changes {
1016            sender: Some(sender_),
1017            remaining_len: 0,
1018            path,
1019            final_path: changes_dir.clone(),
1020            file,
1021            hashes: Vec::new(),
1022            current: 0,
1023        };
1024        self.run_protocol().await?;
1025        let mut sender = sender.map(|x| x.clone());
1026        let t = tokio::spawn(async move {
1027            while let Some(hash) = recv.recv().await {
1028                debug!("received hash {:?}", hash);
1029                progress_bar.inc(1);
1030                debug!("received");
1031                if let Some(ref mut sender) = sender {
1032                    sender.send((hash, true)).await.unwrap_or(());
1033                }
1034            }
1035        });
1036        let mut received = false;
1037        while let Some(h) = c.recv().await {
1038            received = true;
1039            if let State::Changes { ref mut hashes, .. } = *self.state.lock().await {
1040                hashes.push(h);
1041            }
1042            debug!("download_change {:?} {:?}", h, full);
1043            match h {
1044                CS::Change(h) if full => {
1045                    self.c
1046                        .data(format!("change {}\n", h.to_base32()).as_bytes())
1047                        .await?;
1048                }
1049                CS::Change(h) => {
1050                    self.c
1051                        .data(format!("partial {}\n", h.to_base32()).as_bytes())
1052                        .await?;
1053                }
1054                CS::State(h) => {
1055                    self.c
1056                        .data(format!("tag {}\n", h.to_base32()).as_bytes())
1057                        .await?;
1058                }
1059            }
1060        }
1061        if !received {
1062            *self.state.lock().await = State::None;
1063        };
1064        t.await?;
1065        debug!("done downloading {:?}", changes_dir);
1066        Ok(())
1067    }
1068
1069    pub async fn update_identities(
1070        &mut self,
1071        rev: Option<u64>,
1072        mut path: PathBuf,
1073    ) -> Result<u64, anyhow::Error> {
1074        let (sender_, mut recv) = tokio::sync::mpsc::channel(100);
1075        *self.state.lock().await = State::Identities {
1076            sender: Some(sender_),
1077            buf: Vec::new(),
1078        };
1079        self.run_protocol().await?;
1080        if let Some(rev) = rev {
1081            self.c
1082                .data(format!("identities {}\n", rev).as_bytes())
1083                .await?;
1084        } else {
1085            self.c.data("identities\n".as_bytes()).await?;
1086        }
1087        let mut revision = 0;
1088        std::fs::create_dir_all(&path)?;
1089        while let Some(id) = recv.recv().await {
1090            path.push(&id.public_key.key);
1091            debug!("recv identity: {:?} {:?}", id, path);
1092            let mut id_file = std::fs::File::create(&path)?;
1093            serde_json::to_writer_pretty(&mut id_file, &id)?;
1094            path.pop();
1095            revision = revision.max(id.last_modified.timestamp());
1096        }
1097        debug!("done receiving");
1098        Ok(revision.try_into().unwrap())
1099    }
1100}