1use crate::did::{decode_did_key, ed25519_priv_to_x25519, ed25519_pub_to_x25519};
4use crate::envelope::{decode as decode_env, encode as encode_env, Envelope};
5use crate::error::{Error, Result};
6use crate::frame::FrameCipher;
7use crate::noise::{build_prologue, ResponderHandshake};
8use crate::session::{Role, Session, SessionTransport};
9use futures_util::{SinkExt, StreamExt};
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::net::TcpListener;
14use tokio::sync::{mpsc, Mutex};
15use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
16use tokio_tungstenite::tungstenite::protocol::Message;
17
18pub use crate::session::{Handler, HandlerOutput};
19
/// Configuration consumed by [`Server`] and shared (behind an `Arc`) with
/// every accepted connection.
pub struct ServerOptions {
    /// This server's identity as a DID string. Each connection decodes it with
    /// `decode_did_key` to derive the Noise static public key, and it is also
    /// fed into the handshake prologue.
    pub did: String,
    /// 32-byte Ed25519 private key, converted to X25519 for the Noise handshake.
    /// NOTE(review): presumably must be the key for `did` — nothing here checks that.
    pub private_key: [u8; 32],
    /// Handlers registered on every connection's session via `Session::handle`,
    /// keyed by what appears to be the method name (confirm in `session`).
    pub handlers: HashMap<String, Handler>,
}
25
26pub struct Server {
27 opts: Arc<ServerOptions>,
28 listener: Option<TcpListener>,
29 bound: Option<SocketAddr>,
30 accept_handle: Option<tokio::task::JoinHandle<()>>,
31 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
32}
33
34impl Server {
35 pub fn new(opts: ServerOptions) -> Self {
36 Self {
37 opts: Arc::new(opts),
38 listener: None,
39 bound: None,
40 accept_handle: None,
41 shutdown_tx: None,
42 }
43 }
44
45 pub async fn listen(&mut self, port: u16, hostname: &str) -> Result<()> {
46 let addr: SocketAddr = format!("{hostname}:{port}")
47 .parse()
48 .map_err(|e: std::net::AddrParseError| Error::Ws(format!("bad addr: {e}")))?;
49 let listener = TcpListener::bind(addr).await?;
50 let bound = listener.local_addr()?;
51 let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
52 let opts = self.opts.clone();
53 let accept_handle = tokio::spawn(async move {
54 loop {
55 tokio::select! {
56 _ = shutdown_rx.changed() => {
57 if *shutdown_rx.borrow() { return; }
58 }
59 accept = listener.accept() => {
60 let (stream, _peer) = match accept { Ok(t) => t, Err(_) => continue };
61 let opts = opts.clone();
62 tokio::spawn(async move {
63 let _ = handle_connection(opts, stream).await;
64 });
65 }
66 }
67 }
68 });
69 self.bound = Some(bound);
70 self.accept_handle = Some(accept_handle);
71 self.shutdown_tx = Some(shutdown_tx);
72 Ok(())
73 }
74
75 pub fn address(&self) -> Option<SocketAddr> {
76 self.bound
77 }
78
79 pub async fn close(&mut self) {
80 if let Some(tx) = self.shutdown_tx.take() {
81 let _ = tx.send(true);
82 }
83 if let Some(handle) = self.accept_handle.take() {
84 handle.abort();
85 let _ = handle.await;
86 }
87 let _ = self.listener.take();
89 }
90}
91
92pub fn create_server(opts: ServerOptions) -> Server {
93 Server::new(opts)
94}
95
/// Serves one accepted TCP connection from upgrade to teardown.
///
/// 1. Completes the WebSocket upgrade. It captures the URL-decoded `caller`
///    query parameter and sets `Sec-WebSocket-Protocol: agent-phone.v1` on
///    the response.
/// 2. Runs the responder side of the Noise handshake (read m1, write m2,
///    read m3). The prologue is built from the caller DID and this server's DID.
/// 3. Spawns a writer task (encode -> seal -> send). This task then reads,
///    opens, decodes, and dispatches incoming frames until the connection ends.
///
/// Returns `Ok(())` when the stream ends, the peer sends Close, the socket
/// errors, or an incoming frame fails to decrypt.
///
/// # Errors
/// - `Error::Ws` if the WebSocket upgrade or sending m2 fails.
/// - `Error::Handshake` if `?caller=` is missing, or if m1/m3 is not a binary frame.
/// - Any error propagated from `decode_did_key` or the handshake steps.
async fn handle_connection(opts: Arc<ServerOptions>, tcp: tokio::net::TcpStream) -> Result<()> {
    // The upgrade callback is a synchronous closure consumed by tungstenite,
    // so the caller DID is passed back out through this shared slot.
    let caller_did_slot: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let cap = caller_did_slot.clone();
    let callback = move |req: &Request, mut resp: Response| {
        if let Some(query) = req.uri().query() {
            // If `caller=` appears more than once, the last value that decodes
            // successfully wins. Values that fail to URL-decode are ignored.
            for kv in query.split('&') {
                if let Some(stripped) = kv.strip_prefix("caller=") {
                    let _ = urlencoding::decode(stripped).map(|s| {
                        // `try_lock` because this closure is not async. The slot
                        // is only locked again after the upgrade completes, so
                        // this should not be contended.
                        if let Ok(mut g) = cap.try_lock() {
                            *g = Some(s.into_owned());
                        }
                    });
                }
            }
        }
        // NOTE(review): the subprotocol is set even if the client did not offer
        // it; strict clients may reject that per RFC 6455 — confirm intended.
        resp.headers_mut()
            .insert("Sec-WebSocket-Protocol", "agent-phone.v1".parse().unwrap());
        Ok(resp)
    };
    let ws_stream = tokio_tungstenite::accept_hdr_async(tcp, callback)
        .await
        .map_err(|e| Error::Ws(format!("accept: {e}")))?;
    let caller_did = caller_did_slot.lock().await.clone();
    // NOTE(review): a missing `caller` is only detected after the upgrade has
    // completed, so the client sees a successful upgrade followed by a drop.
    let Some(caller_did) = caller_did else {
        return Err(Error::Handshake("missing ?caller=<did>".into()));
    };

    // Noise static keypair: the Ed25519 private key from the options and the
    // public key decoded from this server's DID, both converted to X25519.
    let static_priv = ed25519_priv_to_x25519(&opts.private_key);
    let static_pub = ed25519_pub_to_x25519(&decode_did_key(&opts.did)?);

    let (mut ws_write, mut ws_read) = ws_stream.split();

    // The prologue is built from both DIDs (presumably so that disagreement on
    // either identity fails the handshake — confirm in `noise::build_prologue`).
    let mut hs = ResponderHandshake::new(
        &build_prologue(&caller_did, &opts.did),
        static_priv,
        static_pub,
    );

    // Handshake, responder side: m1 (in) -> m2 (out) -> m3 (in). Anything other
    // than a binary frame here (Text, Ping, socket error, end of stream) aborts.
    let m1 = match ws_read.next().await {
        Some(Ok(Message::Binary(b))) => b,
        _ => return Err(Error::Handshake("expected binary m1".into())),
    };
    hs.read_message_1(&m1)?;
    let m2 = hs.write_message_2()?;
    ws_write
        .send(Message::Binary(m2))
        .await
        .map_err(|e| Error::Ws(format!("send m2: {e}")))?;
    let m3 = match ws_read.next().await {
        Some(Ok(Message::Binary(b))) => b,
        _ => return Err(Error::Handshake("expected binary m3".into())),
    };
    hs.read_message_3(&m3)?;

    // Reader and writer share one cipher behind a mutex. Each side holds the
    // lock only for a single seal/open call.
    let transport = hs.finish();
    let cipher = Arc::new(Mutex::new(FrameCipher::new(transport)));

    // Envelopes sent on `env_tx` (presumably by the session when it replies or
    // calls out) are drained by the writer task below.
    let (env_tx, mut env_rx) = mpsc::unbounded_channel::<Envelope>();
    let session = Session::new(SessionTransport { tx: env_tx }, Role::Responder);
    for (m, h) in opts.handlers.iter() {
        session.handle(m.clone(), h.clone()).await;
    }

    // Writer task: encode -> seal -> send. Envelopes that fail to encode or seal
    // are silently dropped. The task exits when a send fails or when every
    // sender for `env_rx` has been dropped.
    let writer_cipher = cipher.clone();
    let writer = tokio::spawn(async move {
        while let Some(env) = env_rx.recv().await {
            let pt = match encode_env(&env) {
                Ok(b) => b,
                Err(_) => continue,
            };
            let ct = {
                let mut c = writer_cipher.lock().await;
                match c.seal(&pt) {
                    Ok(b) => b,
                    Err(_) => continue,
                }
            };
            if ws_write.send(Message::Binary(ct)).await.is_err() {
                return;
            }
        }
    });

    // Reader loop: open -> decode -> dispatch. It stops on a socket error, a
    // Close frame, end of stream, or a frame that fails to decrypt. Frames that
    // decrypt but fail to decode are skipped, as are non-binary frames.
    let reader_cipher = cipher.clone();
    let session_for_reader = session.clone();
    while let Some(msg) = ws_read.next().await {
        let Ok(msg) = msg else { break };
        let bytes = match msg {
            Message::Binary(b) => b,
            Message::Close(_) => break,
            _ => continue,
        };
        let pt = {
            let mut c = reader_cipher.lock().await;
            match c.open(&bytes) {
                Ok(b) => b,
                Err(_) => break,
            }
        };
        let env = match decode_env(&pt) {
            Ok(e) => e,
            Err(_) => continue,
        };
        // Awaited inline, so the next frame is not read until this dispatch
        // returns. Dispatch errors are ignored.
        let _ = session_for_reader.dispatch(env).await;
    }
    // Stop the writer. Anything still queued is discarded, and this function
    // sends no explicit Close frame.
    writer.abort();
    Ok(())
}