1use std::io::{Read, Write};
10use std::path::PathBuf;
11
12use clap::Parser;
13use mkit_core::hash::hash;
14use mkit_core::protocol::{PackKey, RefWriteCondition, Transport};
15use mkit_rpc::mkit::rpc::v1::ssh::{
16 DownloadPackHeader, HelloResponse, ListRefsResponse, PackChunk, PackExistsResponse,
17 ReadRefResponse, RefExpectation, SshFrame, UploadPack, UploadPackResponse,
18 list_refs_response::RefEntry, ssh_frame,
19};
20use mkit_rpc::mkit::rpc::v1::{ErrorCode, ProtocolVersion};
21use mkit_rpc::{FrameError, read_frame, write_frame};
22use mkit_transport_file::FileTransport;
23
24use crate::clap_shim;
25use crate::cli::CLI_VERSION;
26use crate::exit;
27
// Command-line options for `mkit serve`.
//
// NOTE: plain `//` comments are used deliberately here: clap's derive turns
// `///` doc comments into `--help` text, and these notes are for maintainers.
#[derive(Debug, Parser)]
#[command(
    name = "mkit serve",
    about = "Speak the mkit-rpc protocol on stdin/stdout (default) or on \
             an encrypted TCP socket (--listen-enc)."
)]
struct ServeOpts {
    // Repository path (positional). `run` canonicalizes and validates it via
    // `resolve_repo_path` (must be a directory containing `.mkit`).
    path: String,

    // When set, `run` serves on this address with the encrypted transport
    // instead of stdin/stdout.
    #[arg(long = "listen-enc", value_name = "ADDR")]
    listen_enc: Option<String>,

    // File with one client public key per line (64 hex chars or 43 url-safe
    // base64 chars; blank lines and `#` lines ignored). Mutually exclusive
    // with `--unsafe-allow-any-enc-peer`.
    #[arg(long = "enc-authorized-peers", value_name = "PATH")]
    enc_authorized_peers: Option<String>,

    // Persistent server key file; created with a fresh random seed if it
    // does not exist. When omitted, an allowlist policy uses a per-user
    // default path and the allow-any policy uses an ephemeral key.
    #[arg(long = "enc-server-key", value_name = "PATH")]
    enc_server_key: Option<String>,

    // Accept any peer that completes the handshake (no client auth).
    #[arg(long = "unsafe-allow-any-enc-peer", default_value_t = false)]
    unsafe_allow_any_enc_peer: bool,

    // Per-frame receive timeout for encrypted sessions; 0 disables it.
    #[arg(
        long = "enc-idle-timeout-secs",
        value_name = "SECS",
        default_value_t = 60
    )]
    enc_idle_timeout_secs: u64,

    // Handshake timeout passed to the encrypted listener's handshake bounds.
    #[arg(
        long = "enc-handshake-timeout-secs",
        value_name = "SECS",
        default_value_t = 60
    )]
    enc_handshake_timeout_secs: u64,
}
100
/// Maximum number of top-level frames accepted on one stdio connection; also
/// caps the number of `PackChunk` frames in a single upload.
pub(crate) const MAX_FRAMES_PER_CONN: u32 = 10_000;

/// Per-connection byte budget (1 GiB), as estimated by `frame_byte_estimate`;
/// also the largest `UploadPack.total_bytes` a client may declare.
pub(crate) const MAX_BYTES_PER_CONN: u64 = 1 << 30;

/// Largest `data` payload placed in one outgoing `PackChunk` (800 KiB).
const PACK_CHUNK_DATA_MAX: usize = 800 << 10;
113
114#[must_use]
115pub fn run(args: &[String]) -> u8 {
116 let opts = match clap_shim::parse::<ServeOpts>("mkit serve", args) {
117 Ok(o) => o,
118 Err(code) => return code,
119 };
120
121 let repo_root = match resolve_repo_path(&opts.path) {
122 Ok(p) => p,
123 Err(code) => return code,
124 };
125
126 if let Some(addr) = opts.listen_enc.as_deref() {
127 return run_listen_enc(
128 addr,
129 repo_root,
130 opts.enc_authorized_peers.as_deref(),
131 opts.enc_server_key.as_deref(),
132 opts.unsafe_allow_any_enc_peer,
133 opts.enc_idle_timeout_secs,
134 opts.enc_handshake_timeout_secs,
135 );
136 }
137
138 let tx = FileTransport::new(&repo_root);
139 let stdin = std::io::stdin();
140 let stdout = std::io::stdout();
141 let mut r = stdin.lock();
142 let mut w = stdout.lock();
143
144 serve_loop(&tx, &mut r, &mut w)
145}
146
147#[cfg(not(feature = "enc-transport"))]
152#[allow(clippy::too_many_arguments)]
153fn run_listen_enc(
154 _addr: &str,
155 _repo_root: PathBuf,
156 _authorized_peers: Option<&str>,
157 _server_key: Option<&str>,
158 _unsafe_allow_any: bool,
159 _idle_timeout_secs: u64,
160 _handshake_timeout_secs: u64,
161) -> u8 {
162 eprintln!(
163 "mkit serve --listen-enc requires the `enc-transport` cargo feature; \
164 rebuild with `--features enc-transport` to enable it."
165 );
166 exit::UNAVAILABLE
167}
168
/// Runs the encrypted TCP listener for `mkit serve --listen-enc`.
///
/// Resolves the peer-authorization policy from the flags (fail-closed unless
/// an allowlist or the explicit unsafe flag is given), loads or creates the
/// server key, then serves every accepted session with `serve_enc_session`.
/// Returns `exit::OK` when the listener returns cleanly, a config/usage code
/// for bad flags, or `exit::TEMPFAIL` if the listener fails.
#[cfg(feature = "enc-transport")]
#[allow(
    clippy::needless_pass_by_value,
    clippy::manual_let_else,
    clippy::items_after_statements,
    clippy::cast_possible_truncation,
    clippy::box_default,
    clippy::too_many_lines,
    clippy::too_many_arguments
)]
fn run_listen_enc(
    addr: &str,
    repo_root: PathBuf,
    authorized_peers: Option<&str>,
    server_key: Option<&str>,
    unsafe_allow_any: bool,
    idle_timeout_secs: u64,
    handshake_timeout_secs: u64,
) -> u8 {
    use commonware_cryptography::Signer as _;
    use mkit_transport_enc::{EncHandshakeBounds, PeerPolicy};
    use std::sync::Arc;
    use std::time::Duration;

    // Maps the two authorization flags onto a peer policy; on rejection it has
    // already printed the diagnostic and yields the exit code to return.
    fn peer_policy(
        authorized_peers: Option<&str>,
        unsafe_allow_any: bool,
    ) -> Result<PeerPolicy, u8> {
        if unsafe_allow_any {
            if authorized_peers.is_some() {
                eprintln!(
                    "mkit serve --listen-enc: --enc-authorized-peers and \
                     --unsafe-allow-any-enc-peer are mutually exclusive"
                );
                return Err(exit::USAGE);
            }
            eprintln!(
                "============================================================\n\
                 WARNING: mkit serve --listen-enc --unsafe-allow-any-enc-peer\n\
                 The encrypted listener will accept ANY client that completes\n\
                 the handshake. There is NO client authentication. Use this\n\
                 only for local development or testing, NEVER in production.\n\
                 ============================================================"
            );
            return Ok(PeerPolicy::AllowAny);
        }

        let Some(path) = authorized_peers else {
            eprintln!(
                "mkit serve --listen-enc: refusing to bind without peer authorization.\n\
                 Pass --enc-authorized-peers <PATH> with an allowlist of client public keys,\n\
                 or --unsafe-allow-any-enc-peer to accept any peer (development only)."
            );
            return Err(exit::CONFIG_ERROR);
        };

        match load_authorized_peers(path) {
            Ok(set) if !set.is_empty() => Ok(PeerPolicy::Allowlist(set)),
            Ok(_) => {
                eprintln!(
                    "mkit serve --listen-enc: --enc-authorized-peers '{path}' \
                     contained no valid peer keys; refusing to bind (fail-closed)"
                );
                Err(exit::CONFIG_ERROR)
            }
            Err(msg) => {
                eprintln!("mkit serve --listen-enc: {msg}");
                Err(exit::CONFIG_ERROR)
            }
        }
    }

    let policy = match peer_policy(authorized_peers, unsafe_allow_any) {
        Ok(policy) => policy,
        Err(code) => return code,
    };
    let sk = match resolve_server_key(server_key, &policy) {
        Ok(key) => key,
        Err(code) => return code,
    };

    let pk = sk.public_key().to_string();
    eprintln!(
        "mkit serve --listen-enc on {addr} (server pubkey = {pk}); \
         clients dial mkit+enc://<host>:<port>?pubkey={pk}"
    );

    let tx = Arc::new(FileTransport::new(&repo_root));
    // A zero idle timeout means "wait forever" for the next frame.
    let idle_timeout = if idle_timeout_secs == 0 {
        None
    } else {
        Some(Duration::from_secs(idle_timeout_secs))
    };

    let serve_fn = move |sess: mkit_transport_enc::EncSession<
        mkit_transport_enc::tokio_io::TokioStream,
        mkit_transport_enc::tokio_io::TokioSink,
    >,
                         _peer: commonware_cryptography::ed25519::PublicKey| {
        let session_tx = Arc::clone(&tx);
        async move { serve_enc_session(session_tx, sess, idle_timeout).await }
    };

    let bounds = EncHandshakeBounds {
        handshake_timeout: Duration::from_secs(handshake_timeout_secs),
        ..EncHandshakeBounds::default()
    };

    if let Err(e) =
        mkit_transport_enc::serve_tcp_with_policy_and_bounds(addr, sk, policy, bounds, serve_fn)
    {
        eprintln!("mkit serve --listen-enc: {e}");
        return exit::TEMPFAIL;
    }
    exit::OK
}
295
/// Reads the peer allowlist file at `path` into a set of 32-byte public keys.
///
/// Each non-blank line that does not start with `#` (after trimming) must be a
/// key accepted by `decode_peer_pubkey_line`. The first bad line aborts the
/// load with an error naming the file and its 1-based line number.
#[cfg(feature = "enc-transport")]
fn load_authorized_peers(path: &str) -> Result<std::collections::HashSet<[u8; 32]>, String> {
    let text = std::fs::read_to_string(path)
        .map_err(|e| format!("failed to read authorized-peers file '{path}': {e}"))?;
    text.lines()
        .enumerate()
        .map(|(idx, raw)| (idx + 1, raw.trim()))
        .filter(|(_, entry)| !entry.is_empty() && !entry.starts_with('#'))
        .map(|(line_no, entry)| {
            decode_peer_pubkey_line(entry)
                .map_err(|msg| format!("authorized-peers '{path}' line {line_no}: {msg}"))
        })
        .collect()
}
317
/// Decodes one allowlist entry into a 32-byte public key.
///
/// Accepts exactly 64 hex digits (either case) or exactly 43 url-safe base64
/// characters without padding; anything else is rejected.
#[cfg(feature = "enc-transport")]
fn decode_peer_pubkey_line(s: &str) -> Result<[u8; 32], String> {
    let raw = s.as_bytes();

    if raw.len() == 64 && raw.iter().all(u8::is_ascii_hexdigit) {
        let mut key = [0u8; 32];
        for (dst, pair) in key.iter_mut().zip(raw.chunks_exact(2)) {
            let hi = hex_nibble(pair[0])?;
            let lo = hex_nibble(pair[1])?;
            *dst = (hi << 4) | lo;
        }
        return Ok(key);
    }

    let looks_b64 = raw.len() == 43 && raw.iter().copied().all(is_b64url_byte);
    if looks_b64 {
        decode_b64url_pubkey(s)
    } else {
        Err("peer key must be 64 hex chars or 43 url-safe base64 chars".to_string())
    }
}
337
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its 4-bit value.
#[cfg(feature = "enc-transport")]
fn hex_nibble(b: u8) -> Result<u8, String> {
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
        .ok_or_else(|| "invalid hex digit".to_string())
}
347
/// Returns whether `b` belongs to the url-safe base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `-`, `_`).
#[cfg(feature = "enc-transport")]
const fn is_b64url_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'-' || b == b'_'
}
352
/// Decodes an unpadded 43-character url-safe base64 string into 32 bytes.
///
/// 43 digits carry 258 bits: the first 256 form the key and the final two
/// (the low bits of the last digit) must be zero, otherwise the encoding is
/// non-canonical and rejected.
///
/// # Panics
/// Panics if `s` is not exactly 43 bytes long; callers check the length first.
#[cfg(feature = "enc-transport")]
#[allow(clippy::cast_possible_truncation)]
fn decode_b64url_pubkey(s: &str) -> Result<[u8; 32], String> {
    let mut digits = [0u8; 43];
    digits.copy_from_slice(s.as_bytes());

    let mut key = [0u8; 32];
    let mut filled = 0usize;
    // `acc` holds the not-yet-emitted low `acc_bits` bits.
    let mut acc: u32 = 0;
    let mut acc_bits: u32 = 0;
    for &digit in &digits {
        acc = (acc << 6) | u32::from(b64url_nibble(digit)?);
        acc_bits += 6;
        if acc_bits >= 8 {
            acc_bits -= 8;
            key[filled] = (acc >> acc_bits) as u8;
            filled += 1;
            acc &= (1u32 << acc_bits) - 1;
        }
    }

    // Exactly two leftover bits remain; a canonical encoding zeroes them.
    if acc != 0 {
        return Err("base64 peer key has non-zero trailing bits".to_string());
    }
    Ok(key)
}
393
/// Maps one url-safe base64 digit to its 6-bit value.
#[cfg(feature = "enc-transport")]
fn b64url_nibble(b: u8) -> Result<u8, String> {
    let value = match b {
        b'A'..=b'Z' => b - b'A',
        b'a'..=b'z' => b - b'a' + 26,
        b'0'..=b'9' => b - b'0' + 52,
        b'-' => 62,
        b'_' => 63,
        _ => return Err("invalid base64 url-safe digit".to_string()),
    };
    Ok(value)
}
405
/// Chooses the listener's signing key.
///
/// An explicit `--enc-server-key` path always wins (loaded or created). If it
/// is absent, an allowlist policy uses `~/.config/mkit/enc/server.key` for the
/// effective user, and the allow-any policy uses a throwaway key.
#[cfg(feature = "enc-transport")]
fn resolve_server_key(
    server_key: Option<&str>,
    policy: &mkit_transport_enc::PeerPolicy,
) -> Result<commonware_cryptography::ed25519::PrivateKey, u8> {
    use mkit_transport_enc::PeerPolicy;

    if let Some(path) = server_key {
        return load_or_create_server_key(std::path::Path::new(path));
    }

    match policy {
        PeerPolicy::AllowAny => ephemeral_server_key(),
        PeerPolicy::Allowlist(_) => match crate::config::home_dir_for_euid() {
            Some(home) => load_or_create_server_key(&home.join(".config/mkit/enc/server.key")),
            None => {
                eprintln!(
                    "mkit serve --listen-enc: cannot resolve a user-scoped key path; \
                     pass --enc-server-key <PATH>"
                );
                Err(exit::CONFIG_ERROR)
            }
        },
    }
}
437
/// Loads the server's ed25519 key seed from `path`, creating it first if absent.
///
/// When the file does not exist, a missing parent directory is created and,
/// on Unix, restricted to mode 0700. A fresh 32-byte seed is then drawn from
/// the system RNG and written with `save_raw_32_create_new`. Permissions are
/// only tightened on a directory this call created: a pre-existing parent
/// (for example `.` or `/tmp` when `--enc-server-key` points there) is left
/// untouched rather than silently chmodded.
///
/// # Errors
/// Returns the exit code to use after printing a diagnostic:
/// `CANTCREAT` (dir/key write), `TEMPFAIL` (RNG), `NOPERM` (key read),
/// `DATAERR` (seed is not a valid key).
#[cfg(feature = "enc-transport")]
fn load_or_create_server_key(
    path: &std::path::Path,
) -> Result<commonware_cryptography::ed25519::PrivateKey, u8> {
    use commonware_codec::DecodeExt as _;
    use commonware_cryptography::ed25519::PrivateKey;

    if !path.exists() {
        if let Some(parent) = path.parent() {
            // An empty parent means a bare file name in the current directory:
            // there is nothing to create. An existing directory is not ours to
            // re-permission.
            let must_create = !parent.as_os_str().is_empty() && !parent.is_dir();
            if must_create {
                if let Err(e) = std::fs::create_dir_all(parent) {
                    eprintln!(
                        "mkit serve --listen-enc: create key dir '{}': {e}",
                        parent.display()
                    );
                    return Err(exit::CANTCREAT);
                }
                #[cfg(unix)]
                {
                    use std::os::unix::fs::PermissionsExt as _;
                    // Best-effort hardening of the directory we just created.
                    let _ =
                        std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700));
                }
            }
        }
        let mut secret = zeroize::Zeroizing::new([0u8; 32]);
        if getrandom::fill(secret.as_mut()).is_err() {
            eprintln!("mkit serve --listen-enc: failed to read system RNG for server key");
            return Err(exit::TEMPFAIL);
        }
        if let Err(e) = mkit_core::sign::save_raw_32_create_new(path, &secret) {
            eprintln!(
                "mkit serve --listen-enc: write server key '{}': {e}",
                path.display()
            );
            return Err(exit::CANTCREAT);
        }
    }

    let seed = match mkit_core::sign::load_raw_32(path) {
        Ok(s) => s,
        Err(e) => {
            eprintln!(
                "mkit serve --listen-enc: load server key '{}': {e}",
                path.display()
            );
            return Err(exit::NOPERM);
        }
    };
    PrivateKey::decode(seed.as_ref()).map_err(|e| {
        eprintln!("mkit serve --listen-enc: server key construction failed: {e}");
        exit::DATAERR
    })
}
496
/// Generates a throwaway ed25519 key from 32 bytes of system randomness.
///
/// The seed lives in a `Zeroizing` buffer so it is wiped when dropped. Both
/// RNG and key-construction failures map to `exit::TEMPFAIL`.
#[cfg(feature = "enc-transport")]
fn ephemeral_server_key() -> Result<commonware_cryptography::ed25519::PrivateKey, u8> {
    use commonware_codec::DecodeExt as _;
    use commonware_cryptography::ed25519::PrivateKey;

    let mut seed = zeroize::Zeroizing::new([0u8; 32]);
    getrandom::fill(seed.as_mut()).map_err(|_| {
        eprintln!("mkit serve --listen-enc: failed to read system RNG for ephemeral key");
        exit::TEMPFAIL
    })?;
    PrivateKey::decode(seed.as_ref()).map_err(|e| {
        eprintln!("mkit serve --listen-enc: ephemeral key construction failed: {e}");
        exit::TEMPFAIL
    })
}
513
/// Serves one authenticated encrypted session until the peer closes it.
///
/// The first frame must be a `Hello` announcing protocol version 1; otherwise
/// the session is dropped silently. After the `HelloResponse`, request frames
/// are dispatched one at a time via `dispatch_enc_one`.
///
/// Like the stdio `serve_loop`, the session enforces per-connection budgets:
/// more than `MAX_FRAMES_PER_CONN` top-level frames, or an estimated byte
/// total (see `frame_byte_estimate`) above `MAX_BYTES_PER_CONN`, gets an
/// `InvalidRequest` error frame and ends the session. Receive errors or idle
/// timeouts, a `Close` frame, and dispatch failures also end it.
#[cfg(feature = "enc-transport")]
async fn serve_enc_session(
    tx: std::sync::Arc<FileTransport>,
    sess: mkit_transport_enc::EncSession<
        mkit_transport_enc::tokio_io::TokioStream,
        mkit_transport_enc::tokio_io::TokioSink,
    >,
    idle_timeout: Option<std::time::Duration>,
) {
    use mkit_transport_enc::send_frame;

    let (mut sender, mut receiver) = sess.into_parts();

    // Version negotiation.
    let Ok(frame) = recv_frame_idle(&mut receiver, idle_timeout).await else {
        return;
    };
    let proto = match frame.body {
        Some(ssh_frame::Body::Hello(h)) => h.proto.unwrap_or_default(),
        _ => return,
    };
    if proto != ProtocolVersion::ProtocolVersion1 {
        return;
    }
    let resp = SshFrame {
        body: Some(ssh_frame::Body::HelloResponse(Box::new(HelloResponse {
            proto: Some(ProtocolVersion::ProtocolVersion1.into()),
            server_id: Some(format!("mkit serve-enc/{}", crate::cli::CLI_VERSION)),
            ..Default::default()
        }))),
        ..Default::default()
    };
    if send_frame(&mut sender, &resp).await.is_err() {
        return;
    }

    // Request loop, bounded by the same budgets as the stdio transport.
    let mut frame_count: u32 = 0;
    let mut byte_count: u64 = 0;
    loop {
        let Ok(frame) = recv_frame_idle(&mut receiver, idle_timeout).await else {
            return;
        };

        frame_count = frame_count.saturating_add(1);
        byte_count = byte_count.saturating_add(frame_byte_estimate(&frame));
        let over_budget = if frame_count > MAX_FRAMES_PER_CONN {
            Some("per-connection frame budget exceeded")
        } else if byte_count > MAX_BYTES_PER_CONN {
            Some("per-connection byte budget exceeded")
        } else {
            None
        };
        if let Some(reason) = over_budget {
            let _ = send_frame(
                &mut sender,
                &mkit_rpc::ssh_error_frame(ErrorCode::InvalidRequest, reason),
            )
            .await;
            return;
        }

        if let Some(ssh_frame::Body::Close(_)) = frame.body {
            return;
        }
        if dispatch_enc_one(&tx, frame, &mut sender, &mut receiver, idle_timeout)
            .await
            .is_err()
        {
            return;
        }
    }
}
569
/// Receives the next frame, bounded by `idle_timeout` when one is configured.
///
/// Every failure (transport error, decode error, or timeout) collapses to
/// `Err(())`; callers treat any of them as "end the session".
#[cfg(feature = "enc-transport")]
async fn recv_frame_idle(
    receiver: &mut mkit_transport_enc::EncReceiver<mkit_transport_enc::tokio_io::TokioStream>,
    idle_timeout: Option<std::time::Duration>,
) -> Result<SshFrame, ()> {
    if let Some(limit) = idle_timeout {
        mkit_transport_enc::recv_frame_within(receiver, limit)
            .await
            .map_err(drop)
    } else {
        mkit_transport_enc::recv_frame(receiver).await.map_err(drop)
    }
}
587
/// Handles one post-handshake request frame on an encrypted session.
///
/// Mirrors the stdio `dispatch`. Protocol-level problems (unknown pack,
/// malformed `UpdateRef`, bad upload stream) are answered with an error frame
/// and return `Ok(())`. `Err(())` means the session must end: a send or
/// receive failed, or a request carried a missing or malformed `pack_id`.
///
/// Downloads are streamed in `PackChunk`s of at most the module-level
/// `PACK_CHUNK_DATA_MAX` bytes. This is the same constant the stdio path uses,
/// so the two transports cannot drift apart.
#[cfg(feature = "enc-transport")]
#[allow(
    clippy::too_many_lines,
    clippy::items_after_statements,
    clippy::cast_possible_truncation,
    clippy::box_default,
    clippy::manual_let_else
)]
async fn dispatch_enc_one(
    tx: &FileTransport,
    frame: SshFrame,
    sender: &mut mkit_transport_enc::EncSender<mkit_transport_enc::tokio_io::TokioSink>,
    receiver: &mut mkit_transport_enc::EncReceiver<mkit_transport_enc::tokio_io::TokioStream>,
    idle_timeout: Option<std::time::Duration>,
) -> Result<(), ()> {
    use mkit_transport_enc::send_frame;

    // Wraps `body` in an `SshFrame` and sends it.
    async fn send_body(
        sender: &mut mkit_transport_enc::EncSender<mkit_transport_enc::tokio_io::TokioSink>,
        body: ssh_frame::Body,
    ) -> Result<(), ()> {
        let frame = SshFrame {
            body: Some(body),
            ..Default::default()
        };
        send_frame(sender, &frame).await.map_err(|_| ())
    }
    // Sends an error frame carrying `code` and `msg`.
    async fn send_err(
        sender: &mut mkit_transport_enc::EncSender<mkit_transport_enc::tokio_io::TokioSink>,
        code: ErrorCode,
        msg: &str,
    ) -> Result<(), ()> {
        send_frame(sender, &mkit_rpc::ssh_error_frame(code, msg))
            .await
            .map_err(|_| ())
    }
    // Parses a required 32-byte pack id.
    fn pack_key_from(b: Option<&Vec<u8>>) -> Result<PackKey, ()> {
        let v = b.ok_or(())?;
        if v.len() != 32 {
            return Err(());
        }
        let mut h = [0u8; 32];
        h.copy_from_slice(v);
        Ok(PackKey(h))
    }

    match frame.body {
        Some(ssh_frame::Body::PackExists(req)) => {
            let key = pack_key_from(req.pack_id.as_ref())?;
            // Lookup failures are reported as "absent".
            let exists = tx.pack_exists(&key).unwrap_or(false);
            send_body(
                sender,
                ssh_frame::Body::PackExistsResponse(Box::new(PackExistsResponse {
                    exists: Some(exists),
                    ..Default::default()
                })),
            )
            .await
        }
        Some(ssh_frame::Body::DownloadPack(req)) => {
            let key = pack_key_from(req.pack_id.as_ref())?;
            match tx.download_pack(&key) {
                Ok(bytes) => {
                    send_body(
                        sender,
                        ssh_frame::Body::DownloadPackHeader(Box::new(DownloadPackHeader {
                            total_bytes: Some(bytes.len() as u64),
                            ..Default::default()
                        })),
                    )
                    .await?;
                    let mut iter_pos = 0usize;
                    let mut offset = 0u64;
                    let total = bytes.len();
                    if total == 0 {
                        // An empty pack still gets exactly one terminating chunk.
                        return send_body(
                            sender,
                            ssh_frame::Body::PackChunk(Box::new(PackChunk {
                                pack_id: req.pack_id.clone(),
                                offset: Some(0),
                                data: Some(Vec::new()),
                                last: Some(true),
                                ..Default::default()
                            })),
                        )
                        .await;
                    }
                    while iter_pos < total {
                        let end = core::cmp::min(iter_pos + PACK_CHUNK_DATA_MAX, total);
                        send_body(
                            sender,
                            ssh_frame::Body::PackChunk(Box::new(PackChunk {
                                pack_id: req.pack_id.clone(),
                                offset: Some(offset),
                                data: Some(bytes[iter_pos..end].to_vec()),
                                last: Some(end == total),
                                ..Default::default()
                            })),
                        )
                        .await?;
                        offset += (end - iter_pos) as u64;
                        iter_pos = end;
                    }
                    Ok(())
                }
                Err(_) => send_err(sender, ErrorCode::KeyNotFound, "pack not found").await,
            }
        }
        Some(ssh_frame::Body::UploadPack(header)) => {
            let mut upload = match UploadDrain::new(&header) {
                Ok(upload) => upload,
                Err(e) => {
                    return send_err(sender, ErrorCode::InvalidRequest, e.message()).await;
                }
            };
            // Drain PackChunk frames until the declared stream is complete.
            loop {
                let f = recv_frame_idle(receiver, idle_timeout).await?;
                let Some(ssh_frame::Body::PackChunk(chunk)) = f.body else {
                    return send_err(
                        sender,
                        ErrorCode::InvalidRequest,
                        "expected PackChunk after UploadPack",
                    )
                    .await;
                };
                let complete = match upload.push_chunk(&chunk) {
                    Ok(complete) => complete,
                    Err(e) => {
                        return send_err(sender, ErrorCode::InvalidRequest, e.message()).await;
                    }
                };
                if complete {
                    break;
                }
            }
            let (bytes, key) = upload.into_parts();
            match tx.upload_pack(&bytes, &key) {
                Ok(()) => {
                    send_body(
                        sender,
                        ssh_frame::Body::UploadPackResponse(
                            Box::new(UploadPackResponse::default()),
                        ),
                    )
                    .await
                }
                Err(_) => send_err(sender, ErrorCode::Internal, "upload failed").await,
            }
        }
        Some(ssh_frame::Body::ReadRef(req)) => {
            let name = req.name.unwrap_or_default();
            match tx.read_ref(&name) {
                Ok(Some(h)) => {
                    send_body(
                        sender,
                        ssh_frame::Body::ReadRefResponse(Box::new(ReadRefResponse {
                            object_id: Some(h.to_vec()),
                            ..Default::default()
                        })),
                    )
                    .await
                }
                // A missing ref is signalled by an empty object id.
                Ok(None) => {
                    send_body(
                        sender,
                        ssh_frame::Body::ReadRefResponse(Box::new(ReadRefResponse {
                            object_id: Some(Vec::new()),
                            ..Default::default()
                        })),
                    )
                    .await
                }
                Err(_) => send_err(sender, ErrorCode::Internal, "read ref failed").await,
            }
        }
        Some(ssh_frame::Body::UpdateRef(req)) => {
            let name = req.name.unwrap_or_default();
            let new_id = req.new_id.unwrap_or_default();
            if new_id.len() != 32 {
                return send_err(sender, ErrorCode::InvalidRequest, "new_id must be 32 bytes")
                    .await;
            }
            let mut new_h = [0u8; 32];
            new_h.copy_from_slice(&new_id);
            let expectation = req
                .expectation
                .as_ref()
                .and_then(buffa::EnumValue::as_known)
                .unwrap_or(RefExpectation::Unspecified);
            let condition = match expectation {
                RefExpectation::Any => RefWriteCondition::Any,
                RefExpectation::Missing => RefWriteCondition::Missing,
                RefExpectation::Match => {
                    let bytes = req.expected_id.as_deref().unwrap_or(&[]);
                    if bytes.len() != 32 {
                        return send_err(
                            sender,
                            ErrorCode::InvalidRequest,
                            "MATCH expectation requires a 32-byte expected_id",
                        )
                        .await;
                    }
                    let mut e = [0u8; 32];
                    e.copy_from_slice(bytes);
                    RefWriteCondition::Match(e)
                }
                RefExpectation::Unspecified => {
                    return send_err(
                        sender,
                        ErrorCode::InvalidRequest,
                        "UpdateRef.expectation is required",
                    )
                    .await;
                }
            };
            match tx.update_ref(&name, condition, &new_h) {
                Ok(()) => {
                    send_body(sender, ssh_frame::Body::UpdateRefResponse(Box::default())).await
                }
                Err(_) => send_err(sender, ErrorCode::InvalidRequest, "update ref failed").await,
            }
        }
        Some(ssh_frame::Body::ListRefs(req)) => {
            let prefix = req.prefix.unwrap_or_default();
            match tx.list_refs(&prefix) {
                Ok(refs) => {
                    let entries: Vec<RefEntry> = refs
                        .into_iter()
                        .map(|r| RefEntry {
                            name: Some(r.name),
                            object_id: r.hash.map(|h| h.to_vec()),
                            ..Default::default()
                        })
                        .collect();
                    send_body(
                        sender,
                        ssh_frame::Body::ListRefsResponse(Box::new(ListRefsResponse {
                            refs: entries,
                            ..Default::default()
                        })),
                    )
                    .await
                }
                Err(_) => send_err(sender, ErrorCode::Internal, "list refs failed").await,
            }
        }
        _ => send_err(sender, ErrorCode::InvalidRequest, "unexpected frame").await,
    }
}
849
850pub(crate) fn resolve_repo_path(path: &str) -> Result<PathBuf, u8> {
852 let resolved = std::fs::canonicalize(path).map_err(|_| exit::NOINPUT)?;
853 if !resolved.is_dir() {
854 return Err(exit::DATAERR);
855 }
856 if !resolved.join(".mkit").is_dir() {
857 return Err(exit::DATAERR);
858 }
859 if let Ok(root) = std::env::var("MKIT_SERVE_ROOT") {
860 let pinned = std::fs::canonicalize(&root).map_err(|_| exit::NOPERM)?;
861 if !resolved.starts_with(&pinned) {
862 return Err(exit::NOPERM);
863 }
864 }
865 Ok(resolved)
866}
867
868pub(crate) fn serve_loop(tx: &FileTransport, r: &mut impl Read, w: &mut impl Write) -> u8 {
871 if !handshake(r, w) {
872 return exit::PROTOCOL_ERROR;
873 }
874 let mut frame_count: u32 = 0;
875 let mut byte_count: u64 = 0;
876
877 loop {
878 let frame: SshFrame = match read_frame(r) {
879 Ok(f) => f,
880 Err(FrameError::LengthTruncated) => return exit::OK,
881 Err(_) => {
882 let _ = emit_error(w, ErrorCode::InvalidRequest, "frame parse error");
883 return exit::PROTOCOL_ERROR;
884 }
885 };
886
887 frame_count = frame_count.saturating_add(1);
888 if frame_count > MAX_FRAMES_PER_CONN {
889 let _ = emit_error(
890 w,
891 ErrorCode::InvalidRequest,
892 "per-connection frame budget exceeded",
893 );
894 return exit::PROTOCOL_ERROR;
895 }
896
897 byte_count = byte_count.saturating_add(frame_byte_estimate(&frame));
902 if byte_count > MAX_BYTES_PER_CONN {
903 let _ = emit_error(
904 w,
905 ErrorCode::InvalidRequest,
906 "per-connection byte budget exceeded",
907 );
908 return exit::PROTOCOL_ERROR;
909 }
910
911 match frame.body {
912 Some(ssh_frame::Body::Close(_)) => return exit::OK,
913 body => {
914 if dispatch(tx, body, w, r).is_err() {
915 return exit::OK;
916 }
917 }
918 }
919 }
920}
921
922fn handshake(r: &mut impl Read, w: &mut impl Write) -> bool {
923 let frame: SshFrame = match read_frame(r) {
924 Ok(f) => f,
925 Err(_) => return false,
926 };
927 let Some(ssh_frame::Body::Hello(hello)) = frame.body else {
928 let _ = emit_error(w, ErrorCode::InvalidRequest, "first frame must be Hello");
929 return false;
930 };
931 let proto = hello.proto.unwrap_or_default();
932 if proto != ProtocolVersion::ProtocolVersion1 {
933 let _ = emit_error(
934 w,
935 ErrorCode::InvalidRequest,
936 &format!("unsupported proto_version {}", proto.to_i32()),
937 );
938 return false;
939 }
940 let resp = SshFrame {
941 body: Some(ssh_frame::Body::HelloResponse(Box::new(HelloResponse {
942 proto: Some(ProtocolVersion::ProtocolVersion1.into()),
943 server_id: Some(format!("mkit serve/{CLI_VERSION}")),
944 ..Default::default()
945 }))),
946 ..Default::default()
947 };
948 write_frame(w, &resp).is_ok()
949}
950
951#[allow(clippy::too_many_lines)]
952fn dispatch(
953 tx: &FileTransport,
954 body: Option<ssh_frame::Body>,
955 w: &mut impl Write,
956 r: &mut impl Read,
957) -> std::io::Result<()> {
958 match body {
959 Some(ssh_frame::Body::PackExists(req)) => {
960 let key = pack_key_from_bytes(req.pack_id.as_ref())?;
961 let exists = tx.pack_exists(&key).unwrap_or(false);
962 send(
963 w,
964 ssh_frame::Body::PackExistsResponse(Box::new(PackExistsResponse {
965 exists: Some(exists),
966 ..Default::default()
967 })),
968 )
969 }
970 Some(ssh_frame::Body::DownloadPack(req)) => {
971 let key = pack_key_from_bytes(req.pack_id.as_ref())?;
972 match tx.download_pack(&key) {
973 Ok(bytes) => {
974 send(
975 w,
976 ssh_frame::Body::DownloadPackHeader(Box::new(DownloadPackHeader {
977 total_bytes: Some(bytes.len() as u64),
978 ..Default::default()
979 })),
980 )?;
981 let mut iter_pos = 0usize;
982 let mut offset = 0u64;
983 let total = bytes.len();
984 if total == 0 {
985 send(
986 w,
987 ssh_frame::Body::PackChunk(Box::new(PackChunk {
988 pack_id: req.pack_id.clone(),
989 offset: Some(0),
990 data: Some(Vec::new()),
991 last: Some(true),
992 ..Default::default()
993 })),
994 )?;
995 } else {
996 while iter_pos < total {
997 let end = core::cmp::min(iter_pos + PACK_CHUNK_DATA_MAX, total);
998 send(
999 w,
1000 ssh_frame::Body::PackChunk(Box::new(PackChunk {
1001 pack_id: req.pack_id.clone(),
1002 offset: Some(offset),
1003 data: Some(bytes[iter_pos..end].to_vec()),
1004 last: Some(end == total),
1005 ..Default::default()
1006 })),
1007 )?;
1008 offset += (end - iter_pos) as u64;
1009 iter_pos = end;
1010 }
1011 }
1012 Ok(())
1013 }
1014 Err(_) => emit_error(w, ErrorCode::KeyNotFound, "pack not found"),
1015 }
1016 }
1017 Some(ssh_frame::Body::UploadPack(header)) => {
1018 let mut upload = match UploadDrain::new(&header) {
1019 Ok(upload) => upload,
1020 Err(e) => {
1021 return emit_error(w, ErrorCode::InvalidRequest, e.message());
1022 }
1023 };
1024 loop {
1025 let frame: SshFrame = match read_frame(r) {
1026 Ok(f) => f,
1027 Err(_) => {
1028 return emit_error(w, ErrorCode::InvalidRequest, "pack chunk read failed");
1029 }
1030 };
1031 let Some(ssh_frame::Body::PackChunk(chunk)) = frame.body else {
1032 return emit_error(
1033 w,
1034 ErrorCode::InvalidRequest,
1035 "expected PackChunk after UploadPack",
1036 );
1037 };
1038 let complete = match upload.push_chunk(&chunk) {
1039 Ok(complete) => complete,
1040 Err(e) => {
1041 return emit_error(w, ErrorCode::InvalidRequest, e.message());
1042 }
1043 };
1044 if complete {
1045 break;
1046 }
1047 }
1048 let (bytes, key) = upload.into_parts();
1049 match tx.upload_pack(&bytes, &key) {
1050 Ok(()) => send(
1051 w,
1052 ssh_frame::Body::UploadPackResponse(Box::new(UploadPackResponse {
1053 ..Default::default()
1054 })),
1055 ),
1056 Err(_) => emit_error(w, ErrorCode::Internal, "upload failed"),
1057 }
1058 }
1059 Some(ssh_frame::Body::ReadRef(req)) => {
1060 let name = req.name.unwrap_or_default();
1061 match tx.read_ref(&name) {
1062 Ok(Some(h)) => send(
1063 w,
1064 ssh_frame::Body::ReadRefResponse(Box::new(ReadRefResponse {
1065 object_id: Some(h.to_vec()),
1066 ..Default::default()
1067 })),
1068 ),
1069 Ok(None) => send(
1070 w,
1071 ssh_frame::Body::ReadRefResponse(Box::new(ReadRefResponse {
1072 object_id: Some(Vec::new()),
1073 ..Default::default()
1074 })),
1075 ),
1076 Err(_) => emit_error(w, ErrorCode::Internal, "read ref failed"),
1077 }
1078 }
1079 Some(ssh_frame::Body::UpdateRef(req)) => {
1080 let name = req.name.unwrap_or_default();
1081 let new_id = req.new_id.unwrap_or_default();
1082 if new_id.len() != 32 {
1083 return emit_error(w, ErrorCode::InvalidRequest, "new_id must be 32 bytes");
1084 }
1085 let mut new_h = [0u8; 32];
1086 new_h.copy_from_slice(&new_id);
1087 let expectation = req
1091 .expectation
1092 .as_ref()
1093 .and_then(buffa::EnumValue::as_known)
1094 .unwrap_or(RefExpectation::Unspecified);
1095 let condition = match expectation {
1096 RefExpectation::Any => RefWriteCondition::Any,
1097 RefExpectation::Missing => RefWriteCondition::Missing,
1098 RefExpectation::Match => {
1099 let bytes = req.expected_id.as_deref().unwrap_or(&[]);
1100 if bytes.len() != 32 {
1101 return emit_error(
1102 w,
1103 ErrorCode::InvalidRequest,
1104 "MATCH expectation requires a 32-byte expected_id",
1105 );
1106 }
1107 let mut e = [0u8; 32];
1108 e.copy_from_slice(bytes);
1109 RefWriteCondition::Match(e)
1110 }
1111 RefExpectation::Unspecified => {
1112 return emit_error(
1113 w,
1114 ErrorCode::InvalidRequest,
1115 "UpdateRef.expectation is required",
1116 );
1117 }
1118 };
1119 match tx.update_ref(&name, condition, &new_h) {
1120 Ok(()) => send(w, ssh_frame::Body::UpdateRefResponse(Box::default())),
1121 Err(_) => emit_error(w, ErrorCode::InvalidRequest, "update ref failed"),
1122 }
1123 }
1124 Some(ssh_frame::Body::ListRefs(req)) => {
1125 let prefix = req.prefix.unwrap_or_default();
1126 match tx.list_refs(&prefix) {
1127 Ok(refs) => {
1128 let entries: Vec<RefEntry> = refs
1129 .into_iter()
1130 .map(|r| RefEntry {
1131 name: Some(r.name),
1132 object_id: r.hash.map(|h| h.to_vec()),
1133 ..Default::default()
1134 })
1135 .collect();
1136 send(
1137 w,
1138 ssh_frame::Body::ListRefsResponse(Box::new(ListRefsResponse {
1139 refs: entries,
1140 ..Default::default()
1141 })),
1142 )
1143 }
1144 Err(_) => emit_error(w, ErrorCode::Internal, "list refs failed"),
1145 }
1146 }
1147 Some(ssh_frame::Body::PackChunk(_)) => emit_error(
1148 w,
1149 ErrorCode::InvalidRequest,
1150 "PackChunk arrived without UploadPack header",
1151 ),
1152 Some(ssh_frame::Body::Hello(_)) => {
1153 emit_error(w, ErrorCode::InvalidRequest, "Hello after handshake")
1154 }
1155 Some(_) => emit_error(w, ErrorCode::InvalidRequest, "unexpected request frame"),
1156 None => emit_error(w, ErrorCode::InvalidRequest, "empty frame"),
1157 }
1158}
1159
1160fn send(w: &mut impl Write, body: ssh_frame::Body) -> std::io::Result<()> {
1161 let frame = SshFrame {
1162 body: Some(body),
1163 ..Default::default()
1164 };
1165 write_frame(w, &frame).map_err(|_| std::io::Error::other("frame write"))
1166}
1167
/// Reassembly state for one `UploadPack` request and its `PackChunk` stream.
///
/// Created from the header by `UploadDrain::new`. Each chunk is fed through
/// `push_chunk`, and the finished bytes are taken with `into_parts` once
/// `push_chunk` reports completion.
struct UploadDrain {
    /// Pack key declared in the `UploadPack` header. Every chunk must repeat
    /// it, and the reassembled bytes must hash to it.
    key: PackKey,
    /// `UploadPack.total_bytes`; the stream must end exactly at this length.
    expected_total: u64,
    /// Offset the next `PackChunk` must declare (bytes accepted so far).
    next_offset: u64,
    /// Number of `PackChunk` frames seen; capped at `MAX_FRAMES_PER_CONN`.
    chunks: u32,
    /// Concatenated chunk payloads received so far.
    bytes: Vec<u8>,
}
1175
/// Validation failure raised while reassembling an uploaded pack.
///
/// Carries a static, human-readable reason that the dispatchers forward to
/// the client in an `InvalidRequest` error frame.
#[derive(Debug, Clone, Copy)]
struct UploadDrainError(&'static str);

impl UploadDrainError {
    /// Returns the static reason carried by this error.
    fn message(self) -> &'static str {
        let Self(reason) = self;
        reason
    }
}
1184
1185impl UploadDrain {
1186 fn new(header: &UploadPack) -> Result<Self, UploadDrainError> {
1187 let key = pack_key_from_upload(header.pack_id.as_deref())?;
1188 let expected_total = header
1189 .total_bytes
1190 .ok_or(UploadDrainError("UploadPack.total_bytes is required"))?;
1191 if expected_total > MAX_BYTES_PER_CONN {
1192 return Err(UploadDrainError(
1193 "UploadPack.total_bytes exceeds server cap",
1194 ));
1195 }
1196 Ok(Self {
1197 key,
1198 expected_total,
1199 next_offset: 0,
1200 chunks: 0,
1201 bytes: Vec::new(),
1202 })
1203 }
1204
1205 fn push_chunk(&mut self, chunk: &PackChunk) -> Result<bool, UploadDrainError> {
1206 self.chunks = self.chunks.saturating_add(1);
1207 if self.chunks > MAX_FRAMES_PER_CONN {
1208 return Err(UploadDrainError(
1209 "too many PackChunk frames before last=true",
1210 ));
1211 }
1212
1213 let chunk_key = pack_key_from_upload(chunk.pack_id.as_deref())?;
1214 if chunk_key.as_bytes() != self.key.as_bytes() {
1215 return Err(UploadDrainError(
1216 "PackChunk.pack_id does not match UploadPack",
1217 ));
1218 }
1219
1220 let offset = chunk
1221 .offset
1222 .ok_or(UploadDrainError("PackChunk.offset is required"))?;
1223 if offset != self.next_offset {
1224 return Err(UploadDrainError(
1225 "PackChunk.offset is not the expected next offset",
1226 ));
1227 }
1228
1229 let data = chunk.data.as_deref().unwrap_or(&[]);
1230 let data_len = u64::try_from(data.len())
1231 .map_err(|_| UploadDrainError("PackChunk.data length overflows u64"))?;
1232 let new_total = self
1233 .next_offset
1234 .checked_add(data_len)
1235 .ok_or(UploadDrainError("PackChunk byte count overflow"))?;
1236 if new_total > self.expected_total {
1237 return Err(UploadDrainError(
1238 "PackChunk data exceeds declared total_bytes",
1239 ));
1240 }
1241
1242 self.bytes.extend_from_slice(data);
1243 self.next_offset = new_total;
1244
1245 if !chunk.last.unwrap_or(false) {
1246 return Ok(false);
1247 }
1248 if self.next_offset != self.expected_total {
1249 return Err(UploadDrainError(
1250 "PackChunk stream ended before declared total_bytes",
1251 ));
1252 }
1253 if hash(&self.bytes) != *self.key.as_bytes() {
1254 return Err(UploadDrainError(
1255 "uploaded pack bytes do not match UploadPack.pack_id",
1256 ));
1257 }
1258 Ok(true)
1259 }
1260
1261 fn into_parts(self) -> (Vec<u8>, PackKey) {
1262 (self.bytes, self.key)
1263 }
1264}
1265
1266fn emit_error(w: &mut impl Write, code: ErrorCode, message: &str) -> std::io::Result<()> {
1269 write_frame(w, &mkit_rpc::ssh_error_frame(code, message))
1270 .map_err(|_| std::io::Error::other("frame write"))
1271}
1272
1273fn pack_key_from_bytes(bytes: Option<&Vec<u8>>) -> std::io::Result<PackKey> {
1274 let b = bytes.ok_or_else(|| std::io::Error::other("pack_id missing"))?;
1275 if b.len() != 32 {
1276 return Err(std::io::Error::other("pack_id must be 32 bytes"));
1277 }
1278 let mut h = [0u8; 32];
1279 h.copy_from_slice(b);
1280 Ok(PackKey(h))
1281}
1282
1283fn pack_key_from_upload(bytes: Option<&[u8]>) -> Result<PackKey, UploadDrainError> {
1284 let b = bytes.ok_or(UploadDrainError("pack_id missing"))?;
1285 if b.len() != 32 {
1286 return Err(UploadDrainError("pack_id must be 32 bytes"));
1287 }
1288 let mut h = [0u8; 32];
1289 h.copy_from_slice(b);
1290 Ok(PackKey(h))
1291}
1292
1293fn frame_byte_estimate(f: &SshFrame) -> u64 {
1296 use ssh_frame::Body;
1297 match &f.body {
1298 Some(Body::PackChunk(c)) => c.data.as_ref().map_or(0, Vec::len) as u64,
1299 Some(Body::UploadPack(h)) => h.total_bytes.unwrap_or(0),
1300 Some(Body::DownloadPackHeader(h)) => h.total_bytes.unwrap_or(0),
1301 _ => 64, }
1303}
1304
#[cfg(feature = "sparse-checkout")]
/// Error type of `build_sparse_response_from_tree`.
#[derive(Debug, thiserror::Error)]
pub enum SparseServeError {
    /// Building the sparse response in `mkit_core::sparse` failed.
    ///
    /// `#[from]` lets `?` convert a `SparseError` into this variant.
    /// Displays as `sparse build: <inner error>`.
    #[error("sparse build: {0}")]
    Build(#[from] mkit_core::sparse::SparseError),
}
1340
#[cfg(feature = "sparse-checkout")]
/// Runs `mkit_core::sparse::build_sparse` over `tree` with `filter` and
/// packages the resulting entries, manifest and proof into a `SparseResponse`.
///
/// `filter` is forwarded unchanged. Presumably it holds the paths to include;
/// confirm against `build_sparse`.
///
/// # Errors
///
/// Returns `SparseServeError::Build` when `build_sparse` fails.
pub fn build_sparse_response_from_tree(
    tree: &mkit_core::object::Tree,
    filter: &[std::path::PathBuf],
) -> Result<mkit_core::sparse::SparseResponse, SparseServeError> {
    use mkit_core::sparse::{SparseResponse, build_sparse};

    build_sparse(tree, filter)
        .map(|(entries, manifest, proof)| SparseResponse {
            manifest,
            entries,
            proof,
        })
        .map_err(SparseServeError::from)
}
1367
#[cfg(feature = "sparse-checkout")]
/// Loads the object at `tree_hash` from `store` and builds a sparse response
/// for it via `build_sparse_response_from_tree`.
///
/// # Errors
///
/// All failures are stringified:
/// - `read tree: <cause>` when the object cannot be read;
/// - `addressed object is not a tree` when it is some other object kind;
/// - otherwise the display text of the `SparseServeError`.
pub fn build_sparse_response_from_store(
    store: &mkit_core::store::ObjectStore,
    tree_hash: &mkit_core::hash::Hash,
    filter: &[std::path::PathBuf],
) -> Result<mkit_core::sparse::SparseResponse, String> {
    use mkit_core::object::Object;

    let object = store
        .read_object(tree_hash)
        .map_err(|e| format!("read tree: {e}"))?;
    let Object::Tree(tree) = object else {
        return Err("addressed object is not a tree".to_string());
    };
    build_sparse_response_from_tree(&tree, filter).map_err(|e| e.to_string())
}
1394}
1395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exit;
    use std::fs;
    use std::io::Cursor;

    /// `UploadPack` header with the given pack id and declared size.
    fn upload_header(pack_id: Vec<u8>, total_bytes: Option<u64>) -> UploadPack {
        UploadPack {
            pack_id: Some(pack_id),
            total_bytes,
            ..Default::default()
        }
    }

    /// `PackChunk` carrying a copy of `data` at `offset`.
    fn upload_chunk(pack_id: Vec<u8>, offset: Option<u64>, data: &[u8], last: bool) -> PackChunk {
        PackChunk {
            pack_id: Some(pack_id),
            offset,
            data: Some(data.to_vec()),
            last: Some(last),
            ..Default::default()
        }
    }

    /// A small pack whose key is the hash of its bytes.
    fn valid_pack() -> (Vec<u8>, PackKey) {
        let bytes = b"valid pack bytes".to_vec();
        let key = PackKey::new(hash(&bytes));
        (bytes, key)
    }

    /// Rejection message from `UploadDrain::new`; panics if the header is
    /// accepted. `UploadDrain` is not `Debug`, so `unwrap_err` is unavailable.
    fn new_err(header: &UploadPack) -> &'static str {
        match UploadDrain::new(header) {
            Ok(_) => panic!("UploadDrain::new accepted a header it should reject"),
            Err(e) => e.message(),
        }
    }

    /// Rejection message from `push_chunk`; panics if the chunk is accepted.
    fn push_err(drain: &mut UploadDrain, chunk: &PackChunk) -> &'static str {
        drain.push_chunk(chunk).unwrap_err().message()
    }

    /// Fresh drain expecting `total` bytes for `key`.
    fn drain_for(key: &PackKey, total: u64) -> UploadDrain {
        UploadDrain::new(&upload_header(key.as_bytes().to_vec(), Some(total))).unwrap()
    }

    #[test]
    fn resolve_repo_path_rejects_missing_path() {
        let err = resolve_repo_path("/definitely/does/not/exist/xyzzy").unwrap_err();
        assert_eq!(err, exit::NOINPUT);
    }

    #[test]
    fn resolve_repo_path_rejects_non_repo_dir() {
        let td = tempfile::tempdir().unwrap();
        let err = resolve_repo_path(td.path().to_str().unwrap()).unwrap_err();
        assert_eq!(err, exit::DATAERR);
    }

    #[test]
    fn resolve_repo_path_accepts_repo_dir() {
        let td = tempfile::tempdir().unwrap();
        fs::create_dir_all(td.path().join(".mkit")).unwrap();
        let resolved = resolve_repo_path(td.path().to_str().unwrap()).unwrap();
        assert!(resolved.join(".mkit").is_dir());
    }

    #[test]
    fn upload_drain_accepts_valid_chunks() {
        let (bytes, key) = valid_pack();
        let mut drain = UploadDrain::new(&upload_header(
            key.as_bytes().to_vec(),
            Some(bytes.len() as u64),
        ))
        .unwrap();
        // A non-final chunk is accepted but does not complete the stream.
        assert!(
            !drain
                .push_chunk(&upload_chunk(
                    key.as_bytes().to_vec(),
                    Some(0),
                    &bytes[..5],
                    false
                ))
                .unwrap()
        );
        // The final chunk completes the declared size and the hash matches.
        assert!(
            drain
                .push_chunk(&upload_chunk(
                    key.as_bytes().to_vec(),
                    Some(5),
                    &bytes[5..],
                    true,
                ))
                .unwrap()
        );
        let (got, got_key) = drain.into_parts();
        assert_eq!(got, bytes);
        assert_eq!(got_key.as_bytes(), key.as_bytes());
    }

    /// Each malformed stream must be rejected for its own specific reason,
    /// not just "some error".
    #[test]
    fn upload_drain_rejects_malformed_streams() {
        let (bytes, key) = valid_pack();
        let id = || key.as_bytes().to_vec();
        let len = bytes.len() as u64;

        assert_eq!(
            new_err(&upload_header(id(), None)),
            "UploadPack.total_bytes is required"
        );
        assert_eq!(
            new_err(&upload_header(id(), Some(MAX_BYTES_PER_CONN + 1))),
            "UploadPack.total_bytes exceeds server cap"
        );

        // The first chunk must start at offset 0.
        let mut drain = drain_for(&key, len);
        assert_eq!(
            push_err(&mut drain, &upload_chunk(id(), Some(1), &bytes, true)),
            "PackChunk.offset is not the expected next offset"
        );

        // The chunk names a different pack than the header.
        let mut drain = drain_for(&key, len);
        assert_eq!(
            push_err(
                &mut drain,
                &upload_chunk(vec![0xAA; 32], Some(0), &bytes, true)
            ),
            "PackChunk.pack_id does not match UploadPack"
        );

        // More data than the header declared.
        let mut drain = drain_for(&key, len - 1);
        assert_eq!(
            push_err(&mut drain, &upload_chunk(id(), Some(0), &bytes, true)),
            "PackChunk data exceeds declared total_bytes"
        );

        // Stream marked `last` before the declared size arrived.
        let mut drain = drain_for(&key, len);
        assert_eq!(
            push_err(
                &mut drain,
                &upload_chunk(id(), Some(0), &bytes[..bytes.len() - 1], true)
            ),
            "PackChunk stream ended before declared total_bytes"
        );

        // Right size, wrong content: the content-address check fails.
        let wrong_bytes = b"wrong pack bytes";
        let mut drain = drain_for(&key, wrong_bytes.len() as u64);
        assert_eq!(
            push_err(&mut drain, &upload_chunk(id(), Some(0), wrong_bytes, true)),
            "uploaded pack bytes do not match UploadPack.pack_id"
        );
    }

    /// Missing or malformed identifying fields are rejected with their
    /// dedicated messages.
    #[test]
    fn upload_drain_rejects_missing_or_malformed_fields() {
        let (bytes, key) = valid_pack();
        let len = bytes.len() as u64;

        let no_id = UploadPack {
            pack_id: None,
            total_bytes: Some(len),
            ..Default::default()
        };
        assert_eq!(new_err(&no_id), "pack_id missing");
        assert_eq!(
            new_err(&upload_header(vec![0u8; 31], Some(len))),
            "pack_id must be 32 bytes"
        );

        let mut drain = drain_for(&key, len);
        assert_eq!(
            push_err(
                &mut drain,
                &upload_chunk(key.as_bytes().to_vec(), None, &bytes, true)
            ),
            "PackChunk.offset is required"
        );
    }

    /// Appends one length-delimited `SshFrame` carrying `body` to `buf`.
    fn write_body(buf: &mut Vec<u8>, body: ssh_frame::Body) {
        mkit_rpc::write_frame(
            buf,
            &SshFrame {
                body: Some(body),
                ..Default::default()
            },
        )
        .unwrap();
    }

    /// An upload whose bytes do not hash to the announced key gets an Error
    /// frame and never reaches the transport.
    #[test]
    fn serve_loop_rejects_invalid_upload_before_storage() {
        let td = tempfile::tempdir().unwrap();
        let tx = FileTransport::new(td.path());
        let bogus_key = PackKey::new([0x77; 32]);

        let mut input = Vec::new();
        write_body(
            &mut input,
            ssh_frame::Body::Hello(Box::new(
                mkit_rpc::mkit::rpc::v1::ssh::Hello::default()
                    .with_proto(ProtocolVersion::ProtocolVersion1),
            )),
        );
        write_body(
            &mut input,
            ssh_frame::Body::UploadPack(Box::new(upload_header(
                bogus_key.as_bytes().to_vec(),
                Some(5),
            ))),
        );
        write_body(
            &mut input,
            ssh_frame::Body::PackChunk(Box::new(upload_chunk(
                bogus_key.as_bytes().to_vec(),
                Some(0),
                b"wrong",
                true,
            ))),
        );

        let mut reader = Cursor::new(input);
        let mut output = Vec::new();
        assert_eq!(serve_loop(&tx, &mut reader, &mut output), exit::OK);
        assert!(!tx.pack_exists(&bogus_key).unwrap());

        // The first reply answers Hello; the second must be the rejection.
        let mut out = Cursor::new(output);
        let _hello: SshFrame = mkit_rpc::read_frame(&mut out).unwrap();
        let err: SshFrame = mkit_rpc::read_frame(&mut out).unwrap();
        assert!(matches!(err.body, Some(ssh_frame::Body::Error(_))));
    }

    /// A rejected upload under an existing key must leave the stored pack intact.
    #[test]
    fn serve_loop_rejected_upload_does_not_overwrite_existing_pack() {
        let td = tempfile::tempdir().unwrap();
        let tx = FileTransport::new(td.path());
        let (bytes, key) = valid_pack();
        tx.upload_pack(&bytes, &key).unwrap();

        let mut input = Vec::new();
        write_body(
            &mut input,
            ssh_frame::Body::Hello(Box::new(
                mkit_rpc::mkit::rpc::v1::ssh::Hello::default()
                    .with_proto(ProtocolVersion::ProtocolVersion1),
            )),
        );
        write_body(
            &mut input,
            ssh_frame::Body::UploadPack(Box::new(upload_header(key.as_bytes().to_vec(), Some(5)))),
        );
        write_body(
            &mut input,
            ssh_frame::Body::PackChunk(Box::new(upload_chunk(
                key.as_bytes().to_vec(),
                Some(0),
                b"wrong",
                true,
            ))),
        );

        let mut reader = Cursor::new(input);
        let mut output = Vec::new();
        assert_eq!(serve_loop(&tx, &mut reader, &mut output), exit::OK);
        assert_eq!(tx.download_pack(&key).unwrap(), bytes);
    }

    /// Same guarantee as above, exercised end-to-end over the encrypted TCP
    /// listener. The server thread is never joined; it lives until the test
    /// process exits.
    #[cfg(feature = "enc-transport")]
    #[test]
    fn listen_enc_rejected_upload_does_not_overwrite_existing_pack() {
        use commonware_codec::Encode as _;
        use commonware_cryptography::Signer as _;
        use commonware_cryptography::ed25519::PrivateKey;
        use mkit_transport_enc::tcp::{
            TokioExecutor, connect_tcp_with_executor, serve_tcp_with_addr,
        };
        use std::sync::{Arc, mpsc};
        use std::thread;
        use std::time::Duration;

        let td = tempfile::tempdir().unwrap();
        let tx = FileTransport::new(td.path());
        let (bytes, key) = valid_pack();
        tx.upload_pack(&bytes, &key).unwrap();

        let exec = TokioExecutor::new().expect("tokio runtime");
        let server_key = PrivateKey::from_seed(1001);
        let server_pubkey = {
            let encoded = server_key.public_key().encode();
            let bytes = encoded.as_ref();
            assert_eq!(bytes.len(), 32);
            let mut out = [0u8; 32];
            out.copy_from_slice(bytes);
            out
        };

        let server_tx = Arc::new(FileTransport::new(td.path()));
        let (addr_tx, addr_rx) = mpsc::channel();
        let exec_for_server = exec.clone();
        let _server_handle = thread::spawn(move || {
            let serve_fn =
                move |sess: mkit_transport_enc::EncSession<
                    mkit_transport_enc::tokio_io::TokioStream,
                    mkit_transport_enc::tokio_io::TokioSink,
                >,
                      _peer: commonware_cryptography::ed25519::PublicKey| {
                    let tx = server_tx.clone();
                    async move {
                        serve_enc_session(tx, sess, Some(std::time::Duration::from_secs(30))).await;
                    }
                };
            let _ = serve_tcp_with_addr(
                "127.0.0.1:0",
                server_key,
                exec_for_server,
                move |addr| {
                    let _ = addr_tx.send(addr);
                },
                serve_fn,
            );
        });

        let addr = addr_rx
            .recv_timeout(Duration::from_secs(10))
            .expect("encrypted listener address");
        let client_key = PrivateKey::from_seed(2002);
        let client = connect_tcp_with_executor(
            &addr.ip().to_string(),
            addr.port(),
            &server_pubkey,
            client_key,
            exec,
        )
        .expect("connect encrypted client");

        assert!(client.upload_pack(b"wrong", &key).is_err());
        assert_eq!(tx.download_pack(&key).unwrap(), bytes);
    }

    /// Temporary directory containing an empty `.mkit` repo marker.
    #[cfg(feature = "enc-transport")]
    fn enc_repo() -> tempfile::TempDir {
        let td = tempfile::tempdir().unwrap();
        fs::create_dir_all(td.path().join(".mkit")).unwrap();
        td
    }

    /// `--listen-enc` without any peer-auth flag is a configuration error.
    #[cfg(feature = "enc-transport")]
    #[test]
    fn listen_enc_fails_closed_without_peer_auth() {
        let td = enc_repo();
        let args = vec![
            td.path().to_str().unwrap().to_string(),
            "--listen-enc".to_string(),
            "127.0.0.1:0".to_string(),
        ];
        assert_eq!(run(&args), exit::CONFIG_ERROR);
    }

    /// An authorized-peers file with only comments and blanks is rejected.
    #[cfg(feature = "enc-transport")]
    #[test]
    fn listen_enc_rejects_empty_authorized_peers() {
        let td = enc_repo();
        let peers = td.path().join("peers.txt");
        fs::write(&peers, "# only comments\n\n").unwrap();
        let args = vec![
            td.path().to_str().unwrap().to_string(),
            "--listen-enc".to_string(),
            "127.0.0.1:0".to_string(),
            "--enc-authorized-peers".to_string(),
            peers.to_str().unwrap().to_string(),
        ];
        assert_eq!(run(&args), exit::CONFIG_ERROR);
    }

    /// An allow-list combined with `--unsafe-allow-any-enc-peer` is a usage error.
    #[cfg(feature = "enc-transport")]
    #[test]
    fn listen_enc_rejects_conflicting_flags() {
        let td = enc_repo();
        let peers = td.path().join("peers.txt");
        fs::write(&peers, format!("{}\n", "aa".repeat(32))).unwrap();
        let args = vec![
            td.path().to_str().unwrap().to_string(),
            "--listen-enc".to_string(),
            "127.0.0.1:0".to_string(),
            "--enc-authorized-peers".to_string(),
            peers.to_str().unwrap().to_string(),
            "--unsafe-allow-any-enc-peer".to_string(),
        ];
        assert_eq!(run(&args), exit::USAGE);
    }

    #[cfg(feature = "enc-transport")]
    #[test]
    fn authorized_peers_parses_hex_and_skips_comments() {
        let td = tempfile::tempdir().unwrap();
        let peers = td.path().join("peers.txt");
        let k1 = "aa".repeat(32);
        let k2 = "bb".repeat(32);
        fs::write(&peers, format!("# header\n{k1}\n\n  {k2}  \n")).unwrap();
        let set = load_authorized_peers(peers.to_str().unwrap()).unwrap();
        assert_eq!(set.len(), 2);
        assert!(set.contains(&[0xAA; 32]));
        assert!(set.contains(&[0xBB; 32]));
    }

    #[cfg(feature = "enc-transport")]
    #[test]
    // The hand-rolled encoder below masks to 6 bits before the `as usize`
    // cast, so the cast cannot truncate.
    #[allow(clippy::cast_possible_truncation)]
    fn authorized_peers_parses_base64_matching_hex() {
        use commonware_codec::Encode as _;
        use commonware_cryptography::Signer as _;
        use commonware_cryptography::ed25519::PrivateKey;

        let pk = PrivateKey::from_seed(4242).public_key();
        let raw: [u8; 32] = {
            let enc = pk.encode();
            let mut out = [0u8; 32];
            out.copy_from_slice(enc.as_ref());
            out
        };
        let hex = mkit_core::hash::to_hex(&raw);
        // Unpadded URL-safe base64, encoded by hand to avoid a test dependency.
        let b64 = {
            const A: &[u8; 64] =
                b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
            let mut s = String::new();
            for chunk in raw.chunks(3) {
                let b0 = u32::from(chunk[0]);
                let b1 = chunk.get(1).copied().map_or(0, u32::from);
                let b2 = chunk.get(2).copied().map_or(0, u32::from);
                let n = (b0 << 16) | (b1 << 8) | b2;
                let chars = match chunk.len() {
                    1 => 2,
                    2 => 3,
                    _ => 4,
                };
                for i in 0..chars {
                    let idx = ((n >> (18 - 6 * i)) & 0x3F) as usize;
                    s.push(A[idx] as char);
                }
            }
            s
        };
        assert_eq!(b64.len(), 43, "ed25519 key encodes to 43 b64 chars");

        let td = tempfile::tempdir().unwrap();
        let peers = td.path().join("peers.txt");
        fs::write(&peers, format!("{hex}\n{b64}\n")).unwrap();
        let set = load_authorized_peers(peers.to_str().unwrap()).unwrap();
        assert_eq!(set.len(), 1, "hex and base64 forms must coincide");
        assert!(set.contains(&raw));
    }

    #[cfg(feature = "enc-transport")]
    #[test]
    fn authorized_peers_rejects_malformed_key() {
        let td = tempfile::tempdir().unwrap();
        let peers = td.path().join("peers.txt");
        fs::write(&peers, "not-a-valid-key\n").unwrap();
        assert!(load_authorized_peers(peers.to_str().unwrap()).is_err());
    }
}