#[cfg(feature = "automerge-backend")]
use crate::security::{
FormationAuthResult, FormationChallenge, FormationChallengeResponse, FormationKey,
FORMATION_CHALLENGE_SIZE, FORMATION_RESPONSE_SIZE,
};
#[cfg(feature = "automerge-backend")]
use anyhow::{Context, Result};
#[cfg(feature = "automerge-backend")]
use iroh::endpoint::Connection;
#[cfg(feature = "automerge-backend")]
/// Byte string identifying the formation-authentication handshake protocol.
///
/// Per its name this is intended as a QUIC ALPN value; it is not referenced within
/// this module section — NOTE(review): confirm where the transport registers it.
/// The trailing `/1` presumably versions the wire format.
pub const FORMATION_HANDSHAKE_ALPN: &[u8] = b"peat/formation-auth/1";
#[cfg(feature = "automerge-backend")]
/// Timeout, in seconds, used by the handshake functions for each individually bounded
/// step (opening/accepting the handshake stream and the reads wrapped in
/// `tokio::time::timeout`). It is a per-step limit, not a deadline for the whole handshake.
const HANDSHAKE_TIMEOUT_SECS: u64 = 30;
/// Runs the initiator side of the formation authentication handshake over `connection`.
///
/// Opens a new bidirectional stream and performs this exchange:
///
/// 1. Sends our formation ID as a little-endian `u16` byte length followed by its
///    UTF-8 bytes.
/// 2. Receives the peer's [`FormationChallenge`]: a little-endian `u16` ID length, the
///    ID bytes, then `FORMATION_CHALLENGE_SIZE` further bytes, decoded with
///    [`FormationChallenge::from_bytes`].
/// 3. Checks that the challenge's formation ID equals ours, then sends a
///    [`FormationChallengeResponse`] built by [`FormationKey::respond_to_challenge`]
///    over the challenge nonce.
/// 4. Receives a single [`FormationAuthResult`] byte.
///
/// Opening the stream and each receive step are individually bounded by
/// `HANDSHAKE_TIMEOUT_SECS`; writes are not.
///
/// # Errors
///
/// Returns an error if:
/// - our formation ID is longer than `u16::MAX` bytes and so cannot be length-prefixed;
/// - opening the stream or any receive times out, or any stream read/write fails;
/// - the challenge cannot be decoded;
/// - the challenge carries a different formation ID (`"Formation ID mismatch: ..."`);
/// - the peer answers with [`FormationAuthResult::Rejected`].
#[cfg(feature = "automerge-backend")]
pub async fn perform_initiator_handshake(
    connection: &Connection,
    formation_key: &FormationKey,
) -> Result<()> {
    use std::time::Duration;
    use tokio::io::AsyncWriteExt;

    let (mut send, mut recv) = tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        connection.open_bi(),
    )
    .await
    .map_err(|_| anyhow::anyhow!("Handshake stream open timeout"))?
    .context("Failed to open handshake stream")?;

    // Step 1: announce our formation ID behind a little-endian u16 length prefix.
    // An over-long ID is rejected here. A plain `as u16` cast would silently truncate
    // the prefix, and the peer would then mis-frame every byte that follows.
    let formation_id_bytes = formation_key.formation_id().as_bytes();
    let len = u16::try_from(formation_id_bytes.len()).with_context(|| {
        format!(
            "Formation ID is {} bytes; the handshake length prefix allows at most {}",
            formation_id_bytes.len(),
            u16::MAX
        )
    })?;
    send.write_all(&len.to_le_bytes()).await?;
    send.write_all(formation_id_bytes).await?;
    send.flush().await?;

    // Step 2: receive the challenge. The length prefix is read first so the exact
    // body size (ID bytes + fixed-size trailer) is known before the second read.
    let mut id_len_buf = [0u8; 2];
    tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        recv.read_exact(&mut id_len_buf),
    )
    .await
    .context("Challenge length receive timeout")?
    .context("Failed to read challenge length")?;
    let id_len = usize::from(u16::from_le_bytes(id_len_buf));

    let mut body_buf = vec![0u8; id_len + FORMATION_CHALLENGE_SIZE];
    tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        recv.read_exact(&mut body_buf),
    )
    .await
    .context("Challenge body receive timeout")?
    .context("Failed to read challenge body")?;

    // `FormationChallenge::from_bytes` expects the full encoding, including the
    // length prefix, so stitch the two reads back together.
    let mut challenge_buf = Vec::with_capacity(2 + body_buf.len());
    challenge_buf.extend_from_slice(&id_len_buf);
    challenge_buf.extend_from_slice(&body_buf);
    let challenge = FormationChallenge::from_bytes(&challenge_buf)
        .map_err(|e| anyhow::anyhow!("Invalid challenge: {}", e))?;

    if challenge.formation_id != formation_key.formation_id() {
        anyhow::bail!(
            "Formation ID mismatch: expected '{}', got '{}'",
            formation_key.formation_id(),
            challenge.formation_id
        );
    }

    // Step 3: prove knowledge of the formation secret by answering the nonce.
    let response_bytes = formation_key.respond_to_challenge(&challenge.nonce);
    let response = FormationChallengeResponse {
        response: response_bytes,
    };
    send.write_all(&response.to_bytes()).await?;
    send.flush().await?;

    // Step 4: a single byte tells us whether the responder accepted the proof.
    let mut result_buf = [0u8; 1];
    tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        recv.read_exact(&mut result_buf),
    )
    .await
    .context("Result receive timeout")?
    .context("Failed to read result")?;

    let result = FormationAuthResult::from_byte(result_buf[0]);
    match result {
        FormationAuthResult::Accepted => {
            tracing::debug!(
                "Formation handshake succeeded with {}",
                formation_key.formation_id()
            );
            Ok(())
        }
        FormationAuthResult::Rejected => {
            anyhow::bail!("Formation handshake rejected by peer")
        }
    }
}
/// Runs the responder side of the formation authentication handshake over `connection`.
///
/// Accepts the peer's bidirectional stream and performs this exchange:
///
/// 1. Receives the peer's formation ID (little-endian `u16` byte length, then UTF-8
///    bytes). A mismatch with our own ID is only logged at `warn` level. The handshake
///    continues so that the initiator, which checks the ID carried in our challenge,
///    reports the mismatch itself.
/// 2. Sends a [`FormationChallenge`] carrying our formation ID and a nonce from
///    [`FormationKey::create_challenge`].
/// 3. Receives a `FORMATION_RESPONSE_SIZE`-byte [`FormationChallengeResponse`] and checks
///    it with [`FormationKey::verify_response`].
/// 4. Sends the resulting [`FormationAuthResult`] byte (accepted or rejected).
///
/// Accepting the stream and every receive step are individually bounded by
/// `HANDSHAKE_TIMEOUT_SECS`; writes are not.
///
/// # Errors
///
/// Returns an error if:
/// - accepting the stream or any receive times out, or any stream read/write fails;
/// - the peer's formation ID is not valid UTF-8;
/// - the response cannot be decoded;
/// - the response fails verification. In this case the
///   [`FormationAuthResult::Rejected`] byte has already been sent to the peer.
#[cfg(feature = "automerge-backend")]
pub async fn perform_responder_handshake(
    connection: &Connection,
    formation_key: &FormationKey,
) -> Result<()> {
    use std::time::Duration;
    use tokio::io::AsyncWriteExt;

    let (mut send, mut recv) = tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        connection.accept_bi(),
    )
    .await
    .map_err(|_| anyhow::anyhow!("Handshake stream accept timeout"))?
    .context("Failed to accept handshake stream")?;

    // Step 1: read the peer's formation ID. These reads are bounded like every other
    // read in the handshake. Otherwise a peer that opens the stream and then goes
    // silent could park this task indefinitely.
    let mut len_buf = [0u8; 2];
    tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        recv.read_exact(&mut len_buf),
    )
    .await
    .context("Formation ID length receive timeout")?
    .context("Failed to read formation ID length")?;
    let len = usize::from(u16::from_le_bytes(len_buf));

    let mut formation_id_buf = vec![0u8; len];
    tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        recv.read_exact(&mut formation_id_buf),
    )
    .await
    .context("Formation ID receive timeout")?
    .context("Failed to read formation ID")?;
    let peer_formation_id = String::from_utf8(formation_id_buf)
        .map_err(|e| anyhow::anyhow!("Invalid formation ID from peer: {}", e))?;

    // A mismatch is only logged here. Our challenge still carries our own ID, and the
    // initiator compares it against its own and fails with "Formation ID mismatch".
    if peer_formation_id != formation_key.formation_id() {
        tracing::warn!(
            "Peer formation ID '{}' doesn't match ours '{}'",
            peer_formation_id,
            formation_key.formation_id()
        );
    }

    // Step 2: issue a fresh challenge.
    let (nonce, _expected_response) = formation_key.create_challenge();
    let challenge = FormationChallenge {
        formation_id: formation_key.formation_id().to_string(),
        nonce,
    };
    send.write_all(&challenge.to_bytes()).await?;
    send.flush().await?;

    // Step 3: read and verify the fixed-size response.
    let mut response_buf = [0u8; FORMATION_RESPONSE_SIZE];
    tokio::time::timeout(
        Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
        recv.read_exact(&mut response_buf),
    )
    .await
    .context("Response receive timeout")?
    .context("Failed to read response")?;
    let response = FormationChallengeResponse::from_bytes(&response_buf)
        .map_err(|e| anyhow::anyhow!("Invalid response: {}", e))?;

    let verified = formation_key.verify_response(&nonce, &response.response);

    // Step 4: always report the verdict to the peer before returning, so a rejected
    // initiator gets an explicit answer rather than a dropped stream.
    let result = if verified {
        FormationAuthResult::Accepted
    } else {
        FormationAuthResult::Rejected
    };
    send.write_all(&[result.to_byte()]).await?;
    send.flush().await?;

    if verified {
        tracing::debug!(
            "Formation handshake verified for {}",
            formation_key.formation_id()
        );
        Ok(())
    } else {
        anyhow::bail!("Formation handshake verification failed - peer has wrong key")
    }
}
// End-to-end tests for the formation handshake. Each test creates two local
// `IrohTransport`s, connects them, and runs the initiator half on the test task while
// the responder half runs in a spawned task. `#[serial]` keeps the tests from running
// concurrently. NOTE(review): presumably because they open real local endpoints; confirm.
#[cfg(all(test, feature = "automerge-backend"))]
mod tests {
    use super::*;
    use crate::network::iroh_transport::IrohTransport;
    use serial_test::serial;
    use std::sync::Arc;
    use tokio::sync::oneshot;

    /// Connects two fresh local transports and runs the handshake between them.
    ///
    /// `key1` goes with the first transport created and `key2` with the second.
    /// Whichever transport has the lower endpoint ID acts as initiator, and the keys are
    /// swapped along with the roles.
    ///
    /// Returns `(initiator_result, responder_result)`. Panics if a transport cannot be
    /// created, the connection is not established as a new connection, or the
    /// responder task panics.
    async fn run_handshake_test(
        key1: FormationKey,
        key2: FormationKey,
    ) -> (Result<()>, Result<()>) {
        let transport1 = Arc::new(IrohTransport::new_local().await.unwrap());
        let transport2 = Arc::new(IrohTransport::new_local().await.unwrap());
        // The lower endpoint ID dials. NOTE(review): presumably this matches the
        // transport's duplicate-connection tie-break. The `expect` messages below imply
        // that `accept`/`connect` can yield `None` for the side that should not own the
        // connection. Confirm in `IrohTransport`.
        let t1_is_lower = transport1.endpoint_id().as_bytes() < transport2.endpoint_id().as_bytes();
        let (initiator, responder, initiator_key, responder_key) = if t1_is_lower {
            (transport1, transport2, key1, key2)
        } else {
            (transport2, transport1, key2, key1)
        };
        let responder_addr = responder.endpoint_addr();
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let responder_clone = Arc::clone(&responder);
        let responder_task = tokio::spawn(async move {
            // Signals only that this task has started, not that `accept()` is already
            // waiting. The short sleep after `ready_rx` covers that gap on a best-effort basis.
            let _ = ready_tx.send(());
            let conn = responder_clone
                .accept()
                .await
                .unwrap()
                .expect("Expected new connection, not duplicate");
            perform_responder_handshake(&conn, &responder_key).await
        });
        ready_rx.await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let conn = initiator
            .connect(responder_addr)
            .await
            .expect("Connection should succeed")
            .expect("Should get new connection (not handled by accept)");
        let initiator_result = perform_initiator_handshake(&conn, &initiator_key).await;
        let responder_result = responder_task.await.unwrap();
        // Close errors are ignored so teardown cannot mask the handshake outcome.
        let _ = initiator.close().await;
        let _ = responder.close().await;
        (initiator_result, responder_result)
    }

    // Same formation ID and secret on both sides: both halves must succeed.
    #[tokio::test]
    #[serial]
    async fn test_formation_handshake_success() {
        let secret = [0x42u8; 32];
        let key1 = FormationKey::new("test-formation", &secret);
        let key2 = FormationKey::new("test-formation", &secret);
        let (initiator_result, responder_result) = run_handshake_test(key1, key2).await;
        assert!(
            initiator_result.is_ok(),
            "Initiator failed: {:?}",
            initiator_result
        );
        assert!(
            responder_result.is_ok(),
            "Responder failed: {:?}",
            responder_result
        );
    }

    // Same formation ID, different secrets: the responder's verification fails, and the
    // initiator receives the rejection byte. Both halves must therefore error.
    #[tokio::test]
    #[serial]
    async fn test_formation_handshake_wrong_key() {
        let key1 = FormationKey::new("test-formation", &[0x42u8; 32]);
        let key2 = FormationKey::new("test-formation", &[0x43u8; 32]);
        let (initiator_result, responder_result) = run_handshake_test(key1, key2).await;
        assert!(responder_result.is_err());
        assert!(initiator_result.is_err());
    }

    // Same secret, different formation IDs: the responder only warns. The initiator must
    // fail with the "Formation ID mismatch" error it raises after reading the challenge.
    // The responder's outcome is deliberately not asserted.
    #[tokio::test]
    #[serial]
    async fn test_formation_handshake_wrong_formation_id() {
        let secret = [0x42u8; 32];
        let key1 = FormationKey::new("formation-alpha", &secret);
        let key2 = FormationKey::new("formation-bravo", &secret);
        let (initiator_result, _responder_result) = run_handshake_test(key1, key2).await;
        assert!(initiator_result.is_err());
        let err_msg = initiator_result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Formation ID mismatch"),
            "Expected 'Formation ID mismatch' but got: {}",
            err_msg
        );
    }
}