use crate::did::{decode_did_key, ed25519_priv_to_x25519, ed25519_pub_to_x25519};
use crate::envelope::{decode as decode_env, encode as encode_env, Envelope};
use crate::error::{Error, Result};
use crate::frame::FrameCipher;
use crate::noise::{build_prologue, ResponderHandshake};
use crate::session::{Role, Session, SessionTransport};
use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
use tokio_tungstenite::tungstenite::protocol::Message;
pub use crate::session::{Handler, HandlerOutput};
/// Configuration handed to [`Server::new`] / [`create_server`].
///
/// The whole struct is shared (behind an `Arc`) by every connection the
/// server accepts.
pub struct ServerOptions {
    /// This server's own DID. Its key is decoded with `decode_did_key` for the
    /// Noise static key, and the DID string is bound into the handshake prologue.
    pub did: String,
    /// 32-byte Ed25519 private key, converted with `ed25519_priv_to_x25519` for
    /// the Noise handshake. Expected to correspond to `did` — not checked here.
    pub private_key: [u8; 32],
    /// Handlers registered (via `Session::handle`) on every connection's session,
    /// keyed by the name they are registered under.
    pub handlers: HashMap<String, Handler>,
}
/// A WebSocket listener that serves every accepted connection with the shared
/// [`ServerOptions`].
///
/// Call [`Server::listen`] to start accepting and [`Server::close`] to stop.
/// Dropping the `Server` drops its shutdown sender, which also stops the
/// background accept loop.
pub struct Server {
    opts: Arc<ServerOptions>,
    /// Not populated by `listen`: the `TcpListener` is moved into the accept
    /// task. `close` clears it for completeness.
    listener: Option<TcpListener>,
    /// Address actually bound by the last successful `listen` (as reported by
    /// `local_addr`, so an OS-assigned port is visible here).
    bound: Option<SocketAddr>,
    /// Background task running the accept loop.
    accept_handle: Option<tokio::task::JoinHandle<()>>,
    /// Sending `true`, or dropping this sender, stops the accept loop.
    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
}

impl Server {
    /// Creates a server that is not listening yet.
    pub fn new(opts: ServerOptions) -> Self {
        Self {
            opts: Arc::new(opts),
            listener: None,
            bound: None,
            accept_handle: None,
            shutdown_tx: None,
        }
    }

    /// Binds `hostname:port` and starts accepting connections on a background task.
    ///
    /// `hostname` is parsed as an IP literal, not resolved: use e.g. `127.0.0.1`
    /// or `[::1]`, not `localhost`. Pass port `0` to let the OS pick one, then read
    /// it back with [`Server::address`]. Each accepted connection is served on its
    /// own task, and per-connection errors are discarded.
    ///
    /// Calling `listen` again replaces the stored shutdown sender. Dropping the
    /// old sender stops the previous accept loop.
    ///
    /// # Errors
    /// Returns `Error::Ws` if `hostname:port` is not a valid socket address.
    /// Returns the (converted) I/O error if binding or `local_addr` fails.
    pub async fn listen(&mut self, port: u16, hostname: &str) -> Result<()> {
        let addr: SocketAddr = format!("{hostname}:{port}")
            .parse()
            .map_err(|e: std::net::AddrParseError| Error::Ws(format!("bad addr: {e}")))?;
        let listener = TcpListener::bind(addr).await?;
        let bound = listener.local_addr()?;
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        let opts = self.opts.clone();
        let accept_handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    changed = shutdown_rx.changed() => {
                        // `Err` means every sender was dropped. Treat that as shutdown:
                        // otherwise `changed()` resolves `Err` immediately on every
                        // iteration and this loop spins the CPU forever.
                        if changed.is_err() || *shutdown_rx.borrow() {
                            return;
                        }
                    }
                    accept = listener.accept() => {
                        // NOTE(review): accept errors (e.g. fd exhaustion) are retried
                        // immediately with no backoff.
                        let Ok((stream, _peer)) = accept else { continue };
                        let opts = opts.clone();
                        tokio::spawn(async move {
                            let _ = handle_connection(opts, stream).await;
                        });
                    }
                }
            }
        });
        self.bound = Some(bound);
        self.accept_handle = Some(accept_handle);
        self.shutdown_tx = Some(shutdown_tx);
        Ok(())
    }

    /// Address bound by the last successful [`Server::listen`], if any.
    pub fn address(&self) -> Option<SocketAddr> {
        self.bound
    }

    /// Stops accepting new connections and waits for the accept task to end.
    ///
    /// Connections that were already accepted run on their own detached tasks
    /// and are not interrupted.
    pub async fn close(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(true);
        }
        if let Some(handle) = self.accept_handle.take() {
            handle.abort();
            let _ = handle.await;
        }
        let _ = self.listener.take();
    }
}
/// Builds a [`Server`] from `opts` without starting it.
///
/// Thin convenience wrapper around [`Server::new`]; call [`Server::listen`]
/// on the result to begin accepting connections.
pub fn create_server(opts: ServerOptions) -> Server {
    Server::new(opts)
}
/// Serves one accepted TCP connection from upgrade to teardown.
///
/// 1. Upgrades `tcp` to a WebSocket. The caller DID is taken from the
///    `?caller=<did>` query parameter of the upgrade request (URL-decoded; the
///    last non-empty, decodable occurrence wins). The `agent-phone.v1`
///    subprotocol is selected only if the client offered it.
/// 2. Runs the responder side of the Noise handshake (read m1, write m2, read
///    m3). The static keypair is derived from `opts.private_key` / `opts.did`
///    via `ed25519_priv_to_x25519` / `ed25519_pub_to_x25519`, and the prologue
///    is built from the caller DID and `opts.did`.
/// 3. Registers every handler from `opts` on a responder [`Session`] and pumps
///    traffic. A spawned writer task seals and sends outgoing envelopes, while
///    this task opens incoming frames and dispatches the decoded envelopes.
///
/// # Errors
/// Returns an error in these cases:
/// - the WebSocket upgrade fails;
/// - the caller DID is missing or empty;
/// - `opts.did` cannot be decoded;
/// - any handshake step fails.
///
/// After the handshake, per-frame problems are not reported. Undecodable
/// envelopes are skipped, dispatch results are ignored, and a decryption
/// failure, socket error, or Close frame ends the connection with `Ok(())`.
async fn handle_connection(opts: Arc<ServerOptions>, tcp: tokio::net::TcpStream) -> Result<()> {
    const SUBPROTOCOL: &str = "agent-phone.v1";

    // The upgrade callback is synchronous, so the caller DID is smuggled out
    // through this slot and read back once the upgrade completes.
    let caller_did_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let cap = caller_did_slot.clone();
    let callback = move |req: &Request, mut resp: Response| {
        let caller = req
            .uri()
            .query()
            .into_iter()
            .flat_map(|q| q.split('&'))
            .filter_map(|kv| kv.strip_prefix("caller="))
            .filter_map(|v| urlencoding::decode(v).ok())
            .filter(|s| !s.is_empty())
            .last();
        if let Some(caller) = caller {
            // Nothing else holds the slot during the upgrade, so this cannot contend.
            if let Ok(mut g) = cap.try_lock() {
                *g = Some(caller.into_owned());
            }
        }
        // RFC 6455 §4.2.2: the server may only select a subprotocol the client
        // offered. An unsolicited one makes conforming clients fail the upgrade.
        let offered = req
            .headers()
            .get_all("Sec-WebSocket-Protocol")
            .iter()
            .filter_map(|v| v.to_str().ok())
            .flat_map(|v| v.split(','))
            .any(|p| p.trim() == SUBPROTOCOL);
        if offered {
            resp.headers_mut().insert(
                "Sec-WebSocket-Protocol",
                SUBPROTOCOL
                    .parse()
                    .expect("static subprotocol is a valid header value"),
            );
        }
        Ok(resp)
    };
    let ws_stream = tokio_tungstenite::accept_hdr_async(tcp, callback)
        .await
        .map_err(|e| Error::Ws(format!("accept: {e}")))?;
    let caller_did = caller_did_slot.lock().await.take();
    let Some(caller_did) = caller_did else {
        return Err(Error::Handshake("missing ?caller=<did>".into()));
    };
    let static_priv = ed25519_priv_to_x25519(&opts.private_key);
    let static_pub = ed25519_pub_to_x25519(&decode_did_key(&opts.did)?);
    let (mut ws_write, mut ws_read) = ws_stream.split();
    let mut hs = ResponderHandshake::new(
        &build_prologue(&caller_did, &opts.did),
        static_priv,
        static_pub,
    );

    // tungstenite answers pings itself but still yields Ping/Pong to the
    // stream, so skip them instead of treating a keep-alive as a protocol error.
    let m1 = loop {
        match ws_read.next().await {
            Some(Ok(Message::Binary(b))) => break b,
            Some(Ok(Message::Ping(_) | Message::Pong(_))) => {}
            _ => return Err(Error::Handshake("expected binary m1".into())),
        }
    };
    hs.read_message_1(&m1)?;
    let m2 = hs.write_message_2()?;
    ws_write
        .send(Message::Binary(m2))
        .await
        .map_err(|e| Error::Ws(format!("send m2: {e}")))?;
    let m3 = loop {
        match ws_read.next().await {
            Some(Ok(Message::Binary(b))) => break b,
            Some(Ok(Message::Ping(_) | Message::Pong(_))) => {}
            _ => return Err(Error::Handshake("expected binary m3".into())),
        }
    };
    hs.read_message_3(&m3)?;
    let transport = hs.finish();

    // One cipher shared by both directions; each side locks only around seal/open.
    let cipher = Arc::new(Mutex::new(FrameCipher::new(transport)));
    let (env_tx, mut env_rx) = mpsc::unbounded_channel::<Envelope>();
    let session = Session::new(SessionTransport { tx: env_tx }, Role::Responder);
    for (m, h) in opts.handlers.iter() {
        session.handle(m.clone(), h.clone()).await;
    }

    // Writer: envelopes queued by the session -> encode -> seal -> socket.
    // Envelopes that fail to encode or seal are dropped; a send error stops the writer.
    let writer_cipher = cipher.clone();
    let writer = tokio::spawn(async move {
        while let Some(env) = env_rx.recv().await {
            let pt = match encode_env(&env) {
                Ok(b) => b,
                Err(_) => continue,
            };
            let ct = {
                let mut c = writer_cipher.lock().await;
                match c.seal(&pt) {
                    Ok(b) => b,
                    Err(_) => continue,
                }
            };
            if ws_write.send(Message::Binary(ct)).await.is_err() {
                return;
            }
        }
    });

    // Reader: socket -> open -> decode -> dispatch, until close, error, or a
    // frame that fails to decrypt.
    let reader_cipher = cipher.clone();
    let session_for_reader = session.clone();
    while let Some(msg) = ws_read.next().await {
        let Ok(msg) = msg else { break };
        let bytes = match msg {
            Message::Binary(b) => b,
            Message::Close(_) => break,
            _ => continue,
        };
        let pt = {
            let mut c = reader_cipher.lock().await;
            match c.open(&bytes) {
                Ok(b) => b,
                Err(_) => break,
            }
        };
        let env = match decode_env(&pt) {
            Ok(e) => e,
            Err(_) => continue,
        };
        let _ = session_for_reader.dispatch(env).await;
    }
    writer.abort();
    Ok(())
}