Skip to main content

agent_phone/
server.rs

1//! WebSocket server: accept, do Noise XK handshake, host a Session per peer.
2
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
use tokio_tungstenite::tungstenite::http::StatusCode;
use tokio_tungstenite::tungstenite::protocol::Message;

use crate::did::{decode_did_key, ed25519_priv_to_x25519, ed25519_pub_to_x25519};
use crate::envelope::{decode as decode_env, encode as encode_env, Envelope};
use crate::error::{Error, Result};
use crate::frame::FrameCipher;
use crate::noise::{build_prologue, ResponderHandshake};
use crate::session::{Role, Session, SessionTransport};
17
18pub use crate::session::{Handler, HandlerOutput};
19
20pub struct ServerOptions {
21    pub did: String,
22    pub private_key: [u8; 32],
23    pub handlers: HashMap<String, Handler>,
24}
25
26pub struct Server {
27    opts: Arc<ServerOptions>,
28    listener: Option<TcpListener>,
29    bound: Option<SocketAddr>,
30    accept_handle: Option<tokio::task::JoinHandle<()>>,
31    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
32}
33
34impl Server {
35    pub fn new(opts: ServerOptions) -> Self {
36        Self {
37            opts: Arc::new(opts),
38            listener: None,
39            bound: None,
40            accept_handle: None,
41            shutdown_tx: None,
42        }
43    }
44
45    pub async fn listen(&mut self, port: u16, hostname: &str) -> Result<()> {
46        let addr: SocketAddr = format!("{hostname}:{port}")
47            .parse()
48            .map_err(|e: std::net::AddrParseError| Error::Ws(format!("bad addr: {e}")))?;
49        let listener = TcpListener::bind(addr).await?;
50        let bound = listener.local_addr()?;
51        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
52        let opts = self.opts.clone();
53        let accept_handle = tokio::spawn(async move {
54            loop {
55                tokio::select! {
56                    _ = shutdown_rx.changed() => {
57                        if *shutdown_rx.borrow() { return; }
58                    }
59                    accept = listener.accept() => {
60                        let (stream, _peer) = match accept { Ok(t) => t, Err(_) => continue };
61                        let opts = opts.clone();
62                        tokio::spawn(async move {
63                            let _ = handle_connection(opts, stream).await;
64                        });
65                    }
66                }
67            }
68        });
69        self.bound = Some(bound);
70        self.accept_handle = Some(accept_handle);
71        self.shutdown_tx = Some(shutdown_tx);
72        Ok(())
73    }
74
75    pub fn address(&self) -> Option<SocketAddr> {
76        self.bound
77    }
78
79    pub async fn close(&mut self) {
80        if let Some(tx) = self.shutdown_tx.take() {
81            let _ = tx.send(true);
82        }
83        if let Some(handle) = self.accept_handle.take() {
84            handle.abort();
85            let _ = handle.await;
86        }
87        // Drop listener
88        let _ = self.listener.take();
89    }
90}
91
92pub fn create_server(opts: ServerOptions) -> Server {
93    Server::new(opts)
94}
95
96async fn handle_connection(opts: Arc<ServerOptions>, tcp: tokio::net::TcpStream) -> Result<()> {
97    let caller_did_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
98    let cap = caller_did_slot.clone();
99    let callback = move |req: &Request, mut resp: Response| {
100        // Extract caller from query.
101        if let Some(query) = req.uri().query() {
102            for kv in query.split('&') {
103                if let Some(stripped) = kv.strip_prefix("caller=") {
104                    let _ = urlencoding::decode(stripped).map(|s| {
105                        if let Ok(mut g) = cap.try_lock() {
106                            *g = Some(s.into_owned());
107                        }
108                    });
109                }
110            }
111        }
112        resp.headers_mut()
113            .insert("Sec-WebSocket-Protocol", "agent-phone.v1".parse().unwrap());
114        Ok(resp)
115    };
116    let ws_stream = tokio_tungstenite::accept_hdr_async(tcp, callback)
117        .await
118        .map_err(|e| Error::Ws(format!("accept: {e}")))?;
119    let caller_did = caller_did_slot.lock().await.clone();
120    let Some(caller_did) = caller_did else {
121        return Err(Error::Handshake("missing ?caller=<did>".into()));
122    };
123
124    let static_priv = ed25519_priv_to_x25519(&opts.private_key);
125    let static_pub = ed25519_pub_to_x25519(&decode_did_key(&opts.did)?);
126
127    let (mut ws_write, mut ws_read) = ws_stream.split();
128
129    let mut hs = ResponderHandshake::new(
130        &build_prologue(&caller_did, &opts.did),
131        static_priv,
132        static_pub,
133    );
134
135    // <- m1
136    let m1 = match ws_read.next().await {
137        Some(Ok(Message::Binary(b))) => b,
138        _ => return Err(Error::Handshake("expected binary m1".into())),
139    };
140    hs.read_message_1(&m1)?;
141    // -> m2
142    let m2 = hs.write_message_2()?;
143    ws_write
144        .send(Message::Binary(m2))
145        .await
146        .map_err(|e| Error::Ws(format!("send m2: {e}")))?;
147    // <- m3
148    let m3 = match ws_read.next().await {
149        Some(Ok(Message::Binary(b))) => b,
150        _ => return Err(Error::Handshake("expected binary m3".into())),
151    };
152    hs.read_message_3(&m3)?;
153
154    let transport = hs.finish();
155    let cipher = Arc::new(Mutex::new(FrameCipher::new(transport)));
156
157    let (env_tx, mut env_rx) = mpsc::unbounded_channel::<Envelope>();
158    let session = Session::new(SessionTransport { tx: env_tx }, Role::Responder);
159    for (m, h) in opts.handlers.iter() {
160        session.handle(m.clone(), h.clone()).await;
161    }
162
163    let writer_cipher = cipher.clone();
164    let writer = tokio::spawn(async move {
165        while let Some(env) = env_rx.recv().await {
166            let pt = match encode_env(&env) {
167                Ok(b) => b,
168                Err(_) => continue,
169            };
170            let ct = {
171                let mut c = writer_cipher.lock().await;
172                match c.seal(&pt) {
173                    Ok(b) => b,
174                    Err(_) => continue,
175                }
176            };
177            if ws_write.send(Message::Binary(ct)).await.is_err() {
178                return;
179            }
180        }
181    });
182
183    let reader_cipher = cipher.clone();
184    let session_for_reader = session.clone();
185    while let Some(msg) = ws_read.next().await {
186        let Ok(msg) = msg else { break };
187        let bytes = match msg {
188            Message::Binary(b) => b,
189            Message::Close(_) => break,
190            _ => continue,
191        };
192        let pt = {
193            let mut c = reader_cipher.lock().await;
194            match c.open(&bytes) {
195                Ok(b) => b,
196                Err(_) => break,
197            }
198        };
199        let env = match decode_env(&pt) {
200            Ok(e) => e,
201            Err(_) => continue,
202        };
203        let _ = session_for_reader.dispatch(env).await;
204    }
205    writer.abort();
206    Ok(())
207}