use crate::did::{decode_did_key, ed25519_priv_to_x25519, ed25519_pub_to_x25519};
use crate::envelope::{decode as decode_env, encode as encode_env, Envelope};
use crate::error::{Error, Result};
use crate::frame::FrameCipher;
use crate::noise::{build_prologue, InitiatorHandshake};
use crate::session::{ClientStream, Role, Session, SessionTransport};
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::protocol::Message;
use url::Url;
/// Parameters for [`connect`].
///
/// NOTE(review): this type intentionally (?) does not derive `Debug`; if one is
/// ever added, make sure `private_key` is redacted.
#[derive(Clone)]
pub struct ClientOptions {
    /// WebSocket URL to dial. `connect` appends a `caller=<did>` query parameter.
    pub url: String,
    /// The caller's own DID. `connect` decodes it with `decode_did_key` and converts
    /// the result with `ed25519_pub_to_x25519` to obtain the local static public key.
    pub did: String,
    /// The caller's 32-byte private key, converted with `ed25519_priv_to_x25519` for
    /// the handshake. NOTE(review): presumably must be the key behind `did` — confirm.
    pub private_key: [u8; 32],
    /// DID of the responder. Passed to `build_prologue` together with `did`, and,
    /// when `responder_public_key` is `None`, used to derive the responder's static key.
    pub responder_did: String,
    /// Optional responder static public key, used as-is (already in X25519 form)
    /// instead of deriving one from `responder_did`.
    pub responder_public_key: Option<[u8; 32]>,
}
/// A connected client: a [`Session`] whose envelopes travel over an encrypted
/// WebSocket, plus the two background tasks that move frames in each direction.
pub struct Client {
    /// Session used to issue calls and streams. Its outgoing envelopes are
    /// encoded, sealed and written to the socket by `_writer_task`.
    pub session: Session,
    /// Task that seals outgoing envelopes and writes them to the socket.
    pub _writer_task: tokio::task::JoinHandle<()>,
    /// Task that reads frames from the socket, opens them, and dispatches the
    /// decoded envelopes to the session.
    pub _reader_task: tokio::task::JoinHandle<()>,
    /// Firing (or dropping) this tells `_writer_task` to send a WebSocket Close
    /// frame and stop.
    pub _close_tx: tokio::sync::oneshot::Sender<()>,
}

impl Client {
    /// Upper bound on how long [`Client::close`] waits for each background task
    /// before aborting it.
    const CLOSE_GRACE: std::time::Duration = std::time::Duration::from_secs(5);

    /// Invokes `method` with optional JSON `params` by delegating to
    /// [`Session::call`], returning its result.
    ///
    /// # Errors
    /// Whatever [`Session::call`] returns.
    pub async fn call(&self, method: &str, params: Option<Value>) -> Result<Value> {
        self.session.call(method, params).await
    }

    /// Opens a stream for `method` with optional JSON `params` and the given
    /// `credits`, by delegating to [`Session::stream`].
    ///
    /// # Errors
    /// Whatever [`Session::stream`] returns.
    pub async fn stream(
        &self,
        method: &str,
        params: Option<Value>,
        credits: u64,
    ) -> Result<ClientStream> {
        self.session.stream(method, params, credits).await
    }

    /// Shuts the connection down.
    ///
    /// Signals the writer task to send a WebSocket Close frame, then waits for the
    /// writer and then the reader task to finish. Each wait is bounded by
    /// `CLOSE_GRACE`. A task still running after that is aborted, so `close`
    /// always returns. This covers, for example, a peer that never answers the
    /// Close frame, which would leave the reader blocked on the socket.
    pub async fn close(self) {
        // The writer may already have exited (e.g. after a socket error); then the
        // send fails and there is nobody left to notify.
        let _ = self._close_tx.send(());
        Self::join_or_abort(self._writer_task).await;
        Self::join_or_abort(self._reader_task).await;
    }

    /// Waits up to `CLOSE_GRACE` for `task` to finish; aborts it otherwise.
    async fn join_or_abort(mut task: tokio::task::JoinHandle<()>) {
        // `&mut JoinHandle` is itself a future (JoinHandle is Unpin), so the handle
        // stays usable for `abort` if the timeout elapses.
        if tokio::time::timeout(Self::CLOSE_GRACE, &mut task).await.is_err() {
            task.abort();
            // Reap the cancelled task; the resulting cancellation JoinError is expected.
            let _ = task.await;
        }
    }
}
pub async fn connect(opts: ClientOptions) -> Result<Client> {
let responder_static_pub = match opts.responder_public_key {
Some(p) => p,
None => ed25519_pub_to_x25519(&decode_did_key(&opts.responder_did)?),
};
let static_priv = ed25519_priv_to_x25519(&opts.private_key);
let static_pub = ed25519_pub_to_x25519(&decode_did_key(&opts.did)?);
let mut url = Url::parse(&opts.url).map_err(|e| Error::Ws(format!("bad url: {e}")))?;
url.query_pairs_mut().append_pair("caller", &opts.did);
let req = tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(
url.as_str(),
)
.map_err(|e| Error::Ws(format!("bad request: {e}")))?;
let mut req = req;
req.headers_mut()
.insert("Sec-WebSocket-Protocol", "agent-phone.v1".parse().unwrap());
let (ws, _resp) = tokio_tungstenite::connect_async(req)
.await
.map_err(|e| Error::Ws(format!("connect: {e}")))?;
let (mut ws_write, mut ws_read) = ws.split();
let mut hs = InitiatorHandshake::new(
&build_prologue(&opts.did, &opts.responder_did),
static_priv,
static_pub,
responder_static_pub,
);
let m1 = hs.write_message_1()?;
ws_write
.send(Message::Binary(m1))
.await
.map_err(|e| Error::Ws(format!("send m1: {e}")))?;
let m2_msg = tokio::time::timeout(std::time::Duration::from_secs(1), ws_read.next())
.await
.map_err(|_| {
Error::Handshake(format!(
"handshake failed before message 2 (timeout). \
Most likely cause: the server at {} does not hold the static key \
pinned by {}.",
opts.url, opts.responder_did
))
})?
.ok_or_else(|| Error::Handshake("ws closed before message 2".into()))?
.map_err(|e| Error::Ws(format!("recv m2: {e}")))?;
let m2 = match m2_msg {
Message::Binary(b) => b,
_ => return Err(Error::Handshake("non-binary frame during handshake".into())),
};
hs.read_message_2(&m2).map_err(|_| {
Error::Handshake(format!(
"message 2 AEAD failed. Responder's advertised static does not match {}.",
opts.responder_did
))
})?;
let m3 = hs.write_message_3()?;
ws_write
.send(Message::Binary(m3))
.await
.map_err(|e| Error::Ws(format!("send m3: {e}")))?;
let transport = hs.finish();
let cipher = std::sync::Arc::new(Mutex::new(FrameCipher::new(transport)));
let (env_tx, mut env_rx) = mpsc::unbounded_channel::<Envelope>();
let session = Session::new(SessionTransport { tx: env_tx }, Role::Initiator);
let writer_cipher = cipher.clone();
let (close_tx, mut close_rx) = tokio::sync::oneshot::channel::<()>();
let writer_task = tokio::spawn(async move {
loop {
tokio::select! {
_ = &mut close_rx => {
let _ = ws_write.send(Message::Close(None)).await;
return;
}
msg = env_rx.recv() => {
let Some(env) = msg else { return };
let pt = match encode_env(&env) { Ok(b) => b, Err(_) => continue };
let ct = {
let mut c = writer_cipher.lock().await;
match c.seal(&pt) { Ok(b) => b, Err(_) => continue }
};
if ws_write.send(Message::Binary(ct)).await.is_err() { return; }
}
}
}
});
let reader_cipher = cipher.clone();
let session_for_reader = session.clone();
let reader_task = tokio::spawn(async move {
while let Some(msg) = ws_read.next().await {
let Ok(msg) = msg else { return };
let bytes = match msg {
Message::Binary(b) => b,
Message::Close(_) => return,
_ => continue,
};
let pt = {
let mut c = reader_cipher.lock().await;
match c.open(&bytes) {
Ok(b) => b,
Err(_) => return,
}
};
let env = match decode_env(&pt) {
Ok(e) => e,
Err(_) => continue,
};
if session_for_reader.dispatch(env).await.is_err() {
return;
}
}
});
Ok(Client {
session,
_writer_task: writer_task,
_reader_task: reader_task,
_close_tx: close_tx,
})
}