Skip to main content

agent_phone/
client.rs

1//! WebSocket client: dial, Noise XK handshake, expose a Session.
2
3use crate::did::{decode_did_key, ed25519_priv_to_x25519, ed25519_pub_to_x25519};
4use crate::envelope::{decode as decode_env, encode as encode_env, Envelope};
5use crate::error::{Error, Result};
6use crate::frame::FrameCipher;
7use crate::noise::{build_prologue, InitiatorHandshake};
8use crate::session::{ClientStream, Role, Session, SessionTransport};
9use futures_util::{SinkExt, StreamExt};
10use serde_json::Value;
11use tokio::sync::{mpsc, Mutex};
12use tokio_tungstenite::tungstenite::protocol::Message;
13use url::Url;
14
/// Parameters consumed by [`connect`].
///
/// Key material is carried as raw 32-byte arrays. `connect` converts our own
/// key (and, when needed, the responder's DID key) to X25519 for the Noise
/// handshake.
#[derive(Clone)]
pub struct ClientOptions {
    /// WebSocket endpoint to dial. `connect` appends a `caller=<did>` query
    /// parameter before connecting.
    pub url: String,
    /// Our own DID. Its embedded key is decoded and converted to derive our
    /// X25519 static public key.
    pub did: String,
    /// Our private key (32 bytes), passed through `ed25519_priv_to_x25519` to
    /// obtain the static secret. Note that cloning the options copies it.
    pub private_key: [u8; 32],
    /// DID of the peer we expect to reach. It is bound into the handshake
    /// prologue and used to derive the responder key when none is pinned.
    pub responder_did: String,
    /// Optional pinned X25519 static key of the responder, used as-is. When
    /// `None`, the key is derived from `responder_did`.
    pub responder_public_key: Option<[u8; 32]>,
}
23
24pub struct Client {
25    pub session: Session,
26    pub _writer_task: tokio::task::JoinHandle<()>,
27    pub _reader_task: tokio::task::JoinHandle<()>,
28    pub _close_tx: tokio::sync::oneshot::Sender<()>,
29}
30
31impl Client {
32    pub async fn call(&self, method: &str, params: Option<Value>) -> Result<Value> {
33        self.session.call(method, params).await
34    }
35    pub async fn stream(
36        &self,
37        method: &str,
38        params: Option<Value>,
39        credits: u64,
40    ) -> Result<ClientStream> {
41        self.session.stream(method, params, credits).await
42    }
43    pub async fn close(self) {
44        let _ = self._close_tx.send(());
45        let _ = self._writer_task.await;
46        let _ = self._reader_task.await;
47    }
48}
49
50pub async fn connect(opts: ClientOptions) -> Result<Client> {
51    let responder_static_pub = match opts.responder_public_key {
52        Some(p) => p,
53        None => ed25519_pub_to_x25519(&decode_did_key(&opts.responder_did)?),
54    };
55    let static_priv = ed25519_priv_to_x25519(&opts.private_key);
56    let static_pub = ed25519_pub_to_x25519(&decode_did_key(&opts.did)?);
57
58    let mut url = Url::parse(&opts.url).map_err(|e| Error::Ws(format!("bad url: {e}")))?;
59    url.query_pairs_mut().append_pair("caller", &opts.did);
60
61    let req = tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(
62        url.as_str(),
63    )
64    .map_err(|e| Error::Ws(format!("bad request: {e}")))?;
65    // Add subprotocol header.
66    let mut req = req;
67    req.headers_mut()
68        .insert("Sec-WebSocket-Protocol", "agent-phone.v1".parse().unwrap());
69
70    let (ws, _resp) = tokio_tungstenite::connect_async(req)
71        .await
72        .map_err(|e| Error::Ws(format!("connect: {e}")))?;
73    let (mut ws_write, mut ws_read) = ws.split();
74
75    let mut hs = InitiatorHandshake::new(
76        &build_prologue(&opts.did, &opts.responder_did),
77        static_priv,
78        static_pub,
79        responder_static_pub,
80    );
81
82    // -> m1
83    let m1 = hs.write_message_1()?;
84    ws_write
85        .send(Message::Binary(m1))
86        .await
87        .map_err(|e| Error::Ws(format!("send m1: {e}")))?;
88
89    // <- m2 (with timeout)
90    let m2_msg = tokio::time::timeout(std::time::Duration::from_secs(1), ws_read.next())
91        .await
92        .map_err(|_| {
93            Error::Handshake(format!(
94                "handshake failed before message 2 (timeout). \
95                 Most likely cause: the server at {} does not hold the static key \
96                 pinned by {}.",
97                opts.url, opts.responder_did
98            ))
99        })?
100        .ok_or_else(|| Error::Handshake("ws closed before message 2".into()))?
101        .map_err(|e| Error::Ws(format!("recv m2: {e}")))?;
102    let m2 = match m2_msg {
103        Message::Binary(b) => b,
104        _ => return Err(Error::Handshake("non-binary frame during handshake".into())),
105    };
106    hs.read_message_2(&m2).map_err(|_| {
107        Error::Handshake(format!(
108            "message 2 AEAD failed. Responder's advertised static does not match {}.",
109            opts.responder_did
110        ))
111    })?;
112
113    // -> m3
114    let m3 = hs.write_message_3()?;
115    ws_write
116        .send(Message::Binary(m3))
117        .await
118        .map_err(|e| Error::Ws(format!("send m3: {e}")))?;
119
120    let transport = hs.finish();
121    let cipher = std::sync::Arc::new(Mutex::new(FrameCipher::new(transport)));
122
123    // Build the session transport: a channel for envelopes the session wants
124    // to send → the writer task seals + writes them.
125    let (env_tx, mut env_rx) = mpsc::unbounded_channel::<Envelope>();
126    let session = Session::new(SessionTransport { tx: env_tx }, Role::Initiator);
127
128    let writer_cipher = cipher.clone();
129    let (close_tx, mut close_rx) = tokio::sync::oneshot::channel::<()>();
130
131    let writer_task = tokio::spawn(async move {
132        loop {
133            tokio::select! {
134                _ = &mut close_rx => {
135                    let _ = ws_write.send(Message::Close(None)).await;
136                    return;
137                }
138                msg = env_rx.recv() => {
139                    let Some(env) = msg else { return };
140                    let pt = match encode_env(&env) { Ok(b) => b, Err(_) => continue };
141                    let ct = {
142                        let mut c = writer_cipher.lock().await;
143                        match c.seal(&pt) { Ok(b) => b, Err(_) => continue }
144                    };
145                    if ws_write.send(Message::Binary(ct)).await.is_err() { return; }
146                }
147            }
148        }
149    });
150
151    let reader_cipher = cipher.clone();
152    let session_for_reader = session.clone();
153    let reader_task = tokio::spawn(async move {
154        while let Some(msg) = ws_read.next().await {
155            let Ok(msg) = msg else { return };
156            let bytes = match msg {
157                Message::Binary(b) => b,
158                Message::Close(_) => return,
159                _ => continue,
160            };
161            let pt = {
162                let mut c = reader_cipher.lock().await;
163                match c.open(&bytes) {
164                    Ok(b) => b,
165                    Err(_) => return,
166                }
167            };
168            let env = match decode_env(&pt) {
169                Ok(e) => e,
170                Err(_) => continue,
171            };
172            if session_for_reader.dispatch(env).await.is_err() {
173                return;
174            }
175        }
176    });
177
178    Ok(Client {
179        session,
180        _writer_task: writer_task,
181        _reader_task: reader_task,
182        _close_tx: close_tx,
183    })
184}