1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4
5use ap_noise::{InitiatorHandshake, MultiDeviceTransport, Psk};
6use ap_proxy_client::IncomingMessage;
7use ap_proxy_protocol::{IdentityFingerprint, RendezvousCode};
8use base64::{Engine, engine::general_purpose::STANDARD};
9use rand::RngCore;
10
11use crate::proxy::ProxyClient;
12use tokio::{
13 sync::{Mutex, mpsc, oneshot},
14 time::timeout,
15};
16use tracing::{debug, warn};
17
18use crate::traits::{IdentityProvider, SessionStore};
19use crate::{
20 error::RemoteClientError,
21 types::{
22 CredentialData, CredentialQuery, CredentialRequestPayload, CredentialResponsePayload,
23 ProtocolMessage, RemoteClientEvent, RemoteClientResponse,
24 },
25};
26
/// How long `RemoteClient::request_credential` waits for a matching
/// credential response before giving up (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
28
29type PendingRequestMap =
31 HashMap<String, oneshot::Sender<Result<CredentialData, RemoteClientError>>>;
32
33pub struct RemoteClient {
35 session_store: Box<dyn SessionStore>,
36 proxy_client: Box<dyn ProxyClient>,
37 incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
38 transport: Option<Arc<Mutex<MultiDeviceTransport>>>,
39 remote_fingerprint: Option<IdentityFingerprint>,
40 pending_requests: Arc<Mutex<PendingRequestMap>>,
41 event_tx: mpsc::Sender<RemoteClientEvent>,
42 response_rx: Option<mpsc::Receiver<RemoteClientResponse>>,
43}
44
45impl RemoteClient {
46 pub async fn new(
61 identity_provider: Box<dyn IdentityProvider>,
62 session_store: Box<dyn SessionStore>,
63 event_tx: mpsc::Sender<RemoteClientEvent>,
64 response_rx: mpsc::Receiver<RemoteClientResponse>,
65 mut proxy_client: Box<dyn ProxyClient>,
66 ) -> Result<Self, RemoteClientError> {
67 let identity = identity_provider.identity().await;
68
69 debug!(
70 "Connecting to proxy with identity {:?}",
71 identity.identity().fingerprint()
72 );
73
74 event_tx
75 .send(RemoteClientEvent::Connecting {
76 proxy_url: String::new(),
77 })
78 .await
79 .ok();
80
81 let incoming_rx = proxy_client.connect().await?;
82
83 event_tx
84 .send(RemoteClientEvent::Connected {
85 fingerprint: identity.identity().fingerprint(),
86 })
87 .await
88 .ok();
89
90 debug!("Connected to proxy successfully");
91
92 Ok(Self {
93 session_store,
94 proxy_client,
95 incoming_rx: Some(incoming_rx),
96 transport: None,
97 remote_fingerprint: None,
98 pending_requests: Arc::new(Mutex::new(HashMap::new())),
99 event_tx,
100 response_rx: Some(response_rx),
101 })
102 }
103
104 pub async fn pair_with_handshake(
114 &mut self,
115 rendezvous_code: &str,
116 verify_fingerprint: bool,
117 ) -> Result<IdentityFingerprint, RemoteClientError> {
118 let incoming_rx =
119 self.incoming_rx
120 .as_mut()
121 .ok_or_else(|| RemoteClientError::InvalidState {
122 expected: "proxy connected".to_string(),
123 current: "not connected".to_string(),
124 })?;
125
126 let event_tx = self.event_tx.clone();
127
128 let response_rx = self
129 .response_rx
130 .as_mut()
131 .ok_or(RemoteClientError::NotInitialized)?;
132
133 event_tx
135 .send(RemoteClientEvent::RendezvousResolving {
136 code: rendezvous_code.to_string(),
137 })
138 .await
139 .ok();
140
141 let remote_fingerprint =
142 Self::resolve_rendezvous(self.proxy_client.as_ref(), incoming_rx, rendezvous_code)
143 .await?;
144
145 event_tx
146 .send(RemoteClientEvent::RendezvousResolved {
147 fingerprint: remote_fingerprint,
148 })
149 .await
150 .ok();
151
152 event_tx.send(RemoteClientEvent::HandshakeStart).await.ok();
154
155 let (transport, fingerprint_str) = Self::perform_handshake(
156 self.proxy_client.as_ref(),
157 incoming_rx,
158 remote_fingerprint,
159 None,
160 )
161 .await?;
162
163 event_tx
164 .send(RemoteClientEvent::HandshakeComplete)
165 .await
166 .ok();
167
168 event_tx
170 .send(RemoteClientEvent::HandshakeFingerprint {
171 fingerprint: fingerprint_str,
172 })
173 .await
174 .ok();
175
176 if verify_fingerprint {
177 match timeout(Duration::from_secs(60), response_rx.recv()).await {
179 Ok(Some(RemoteClientResponse::VerifyFingerprint { approved: true })) => {
180 event_tx
181 .send(RemoteClientEvent::FingerprintVerified)
182 .await
183 .ok();
184 }
185 Ok(Some(RemoteClientResponse::VerifyFingerprint { approved: false })) => {
186 self.proxy_client.disconnect().await.ok();
187 event_tx
188 .send(RemoteClientEvent::FingerprintRejected {
189 reason: "User rejected fingerprint verification".to_string(),
190 })
191 .await
192 .ok();
193 return Err(RemoteClientError::FingerprintRejected);
194 }
195 Ok(None) => {
196 return Err(RemoteClientError::ChannelClosed);
197 }
198 Err(_) => {
199 self.proxy_client.disconnect().await.ok();
200 return Err(RemoteClientError::Timeout(
201 "Fingerprint verification timeout".to_string(),
202 ));
203 }
204 }
205 }
206
207 self.session_store.cache_session(remote_fingerprint).await?;
209
210 self.finalize_connection(transport, remote_fingerprint, event_tx)
212 .await?;
213
214 Ok(remote_fingerprint)
215 }
216
217 pub async fn pair_with_psk(
222 &mut self,
223 psk: Psk,
224 remote_fingerprint: IdentityFingerprint,
225 ) -> Result<(), RemoteClientError> {
226 let incoming_rx =
227 self.incoming_rx
228 .as_mut()
229 .ok_or_else(|| RemoteClientError::InvalidState {
230 expected: "proxy connected".to_string(),
231 current: "not connected".to_string(),
232 })?;
233
234 let event_tx = self.event_tx.clone();
235
236 event_tx
238 .send(RemoteClientEvent::PskMode {
239 fingerprint: remote_fingerprint,
240 })
241 .await
242 .ok();
243
244 event_tx.send(RemoteClientEvent::HandshakeStart).await.ok();
246
247 let (transport, _fingerprint_str) = Self::perform_handshake(
248 self.proxy_client.as_ref(),
249 incoming_rx,
250 remote_fingerprint,
251 Some(psk),
252 )
253 .await?;
254
255 event_tx
256 .send(RemoteClientEvent::HandshakeComplete)
257 .await
258 .ok();
259
260 event_tx
262 .send(RemoteClientEvent::FingerprintVerified)
263 .await
264 .ok();
265
266 self.session_store.cache_session(remote_fingerprint).await?;
268
269 self.finalize_connection(transport, remote_fingerprint, event_tx)
271 .await?;
272
273 Ok(())
274 }
275
276 pub async fn load_cached_session(
281 &mut self,
282 remote_fingerprint: IdentityFingerprint,
283 ) -> Result<(), RemoteClientError> {
284 let event_tx = self.event_tx.clone();
285
286 if !self.session_store.has_session(&remote_fingerprint).await {
288 return Err(RemoteClientError::SessionNotFound);
289 }
290
291 event_tx
293 .send(RemoteClientEvent::ReconnectingToSession {
294 fingerprint: remote_fingerprint,
295 })
296 .await
297 .ok();
298
299 let transport_state = self
300 .session_store
301 .load_transport_state(&remote_fingerprint)
302 .await?
303 .expect("Transport state should exist for cached session");
304
305 event_tx
306 .send(RemoteClientEvent::HandshakeComplete)
307 .await
308 .ok();
309
310 event_tx
312 .send(RemoteClientEvent::FingerprintVerified)
313 .await
314 .ok();
315
316 self.session_store
318 .update_last_connected(&remote_fingerprint)
319 .await?;
320
321 self.finalize_connection(transport_state, remote_fingerprint, event_tx)
323 .await?;
324
325 Ok(())
326 }
327
328 async fn finalize_connection(
333 &mut self,
334 transport: MultiDeviceTransport,
335 remote_fingerprint: IdentityFingerprint,
336 event_tx: mpsc::Sender<RemoteClientEvent>,
337 ) -> Result<(), RemoteClientError> {
338 self.session_store
340 .save_transport_state(&remote_fingerprint, transport.clone())
341 .await?;
342
343 let transport = Arc::new(Mutex::new(transport));
345 self.transport = Some(Arc::clone(&transport));
346 self.remote_fingerprint = Some(remote_fingerprint);
347
348 event_tx
350 .send(RemoteClientEvent::Ready {
351 can_request_credentials: true,
352 })
353 .await
354 .ok();
355
356 let incoming_rx = self
358 .incoming_rx
359 .take()
360 .ok_or(RemoteClientError::NotInitialized)?;
361
362 let pending_requests_clone = Arc::clone(&self.pending_requests);
364 tokio::spawn(async move {
365 Self::message_loop(incoming_rx, event_tx, transport, pending_requests_clone).await;
366 });
367
368 debug!("Connection established successfully");
369 Ok(())
370 }
371
372 pub async fn request_credential(
374 &mut self,
375 query: &CredentialQuery,
376 ) -> Result<CredentialData, RemoteClientError> {
377 let transport = self
378 .transport
379 .as_ref()
380 .ok_or(RemoteClientError::SecureChannelNotEstablished)?;
381
382 let remote_fingerprint = self
383 .remote_fingerprint
384 .ok_or(RemoteClientError::NotInitialized)?;
385
386 #[allow(clippy::string_slice)]
388 let request_id = format!("req-{}-{}", now_millis(), &uuid_v4()[..8]);
389
390 debug!("Requesting credential for query: {:?}", query);
391
392 let request = CredentialRequestPayload {
394 request_type: "credential_request".to_string(),
395 query: query.clone(),
396 timestamp: now_millis(),
397 request_id: request_id.clone(),
398 };
399
400 let request_json = serde_json::to_string(&request)?;
401 let mut transport_guard = transport.lock().await;
402 let encrypted_packet = transport_guard
403 .encrypt(request_json.as_bytes())
404 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
405 drop(transport_guard);
406
407 let msg = ProtocolMessage::CredentialRequest {
408 encrypted: STANDARD.encode(encrypted_packet.encode()),
409 };
410
411 let msg_json = serde_json::to_string(&msg)?;
413 self.proxy_client
414 .send_to(remote_fingerprint, msg_json.into_bytes())
415 .await?;
416
417 self.event_tx
419 .send(RemoteClientEvent::CredentialRequestSent {
420 query: query.clone(),
421 })
422 .await
423 .ok();
424
425 let (response_tx, response_rx) = oneshot::channel();
427
428 self.pending_requests
430 .lock()
431 .await
432 .insert(request_id.clone(), response_tx);
433
434 match timeout(DEFAULT_TIMEOUT, response_rx).await {
436 Ok(Ok(Ok(credential))) => {
437 debug!("Received credential for query: {:?}", query);
439 Ok(credential)
440 }
441 Ok(Ok(Err(e))) => {
442 Err(e)
444 }
445 Ok(Err(_)) => {
446 self.pending_requests.lock().await.remove(&request_id);
448 Err(RemoteClientError::ChannelClosed)
449 }
450 Err(_) => {
451 self.pending_requests.lock().await.remove(&request_id);
453 Err(RemoteClientError::Timeout(format!(
454 "Timeout waiting for credential response for query: {query:?}"
455 )))
456 }
457 }
458 }
459
460 pub fn is_ready(&self) -> bool {
462 self.transport.is_some()
463 }
464
465 pub async fn close(&mut self) {
467 let mut pending = self.pending_requests.lock().await;
469 pending.clear(); drop(pending);
471
472 self.proxy_client.disconnect().await.ok();
473 self.transport = None;
474 self.remote_fingerprint = None;
475 self.incoming_rx = None;
476 self.response_rx = None;
477 debug!("Connection closed");
478 }
479
480 pub fn session_store(&self) -> &dyn SessionStore {
482 self.session_store.as_ref()
483 }
484
485 pub fn session_store_mut(&mut self) -> &mut dyn SessionStore {
487 self.session_store.as_mut()
488 }
489
490 async fn resolve_rendezvous(
492 proxy_client: &dyn ProxyClient,
493 incoming_rx: &mut mpsc::UnboundedReceiver<IncomingMessage>,
494 rendezvous_code: &str,
495 ) -> Result<IdentityFingerprint, RemoteClientError> {
496 proxy_client
498 .request_identity(RendezvousCode::from_string(rendezvous_code.to_string()))
499 .await
500 .map_err(|e| RemoteClientError::RendezvousResolutionFailed(e.to_string()))?;
501
502 let timeout_duration = tokio::time::Duration::from_secs(10);
504 match tokio::time::timeout(timeout_duration, async {
505 while let Some(msg) = incoming_rx.recv().await {
506 if let IncomingMessage::IdentityInfo { fingerprint, .. } = msg {
507 return Some(fingerprint);
508 }
509 }
510 None
511 })
512 .await
513 {
514 Ok(Some(fingerprint)) => Ok(fingerprint),
515 Ok(None) => Err(RemoteClientError::RendezvousResolutionFailed(
516 "Connection closed while waiting for identity response".to_string(),
517 )),
518 Err(_) => Err(RemoteClientError::RendezvousResolutionFailed(
519 "Timeout waiting for identity response. The rendezvous code may be invalid, expired, or the target client may be disconnected.".to_string(),
520 )),
521 }
522 }
523
524 async fn perform_handshake(
526 proxy_client: &dyn ProxyClient,
527 incoming_rx: &mut mpsc::UnboundedReceiver<IncomingMessage>,
528 remote_fingerprint: IdentityFingerprint,
529 psk: Option<Psk>,
530 ) -> Result<(MultiDeviceTransport, String), RemoteClientError> {
531 let mut handshake = if let Some(psk) = psk {
533 InitiatorHandshake::with_psk(psk)
534 } else {
535 InitiatorHandshake::new()
536 };
537
538 let init_packet = handshake.send_start()?;
540
541 let msg = ProtocolMessage::HandshakeInit {
543 data: STANDARD.encode(init_packet.encode()?),
544 ciphersuite: format!("{:?}", handshake.ciphersuite()),
545 };
546
547 let msg_json = serde_json::to_string(&msg)?;
548 proxy_client
549 .send_to(remote_fingerprint, msg_json.into_bytes())
550 .await?;
551
552 debug!("Sent handshake init");
553
554 let response_timeout = Duration::from_secs(10);
556 let response: String = timeout(response_timeout, async {
557 loop {
558 if let Some(incoming) = incoming_rx.recv().await {
559 match incoming {
560 IncomingMessage::Send { payload, .. } => {
561 if let Ok(text) = String::from_utf8(payload)
563 && let Ok(ProtocolMessage::HandshakeResponse { data, .. }) =
564 serde_json::from_str::<ProtocolMessage>(&text)
565 {
566 return Ok::<String, RemoteClientError>(data);
567 }
568 }
569 _ => continue,
570 }
571 }
572 }
573 })
574 .await
575 .map_err(|_| RemoteClientError::Timeout("Waiting for handshake response".to_string()))??;
576
577 let response_bytes = STANDARD
579 .decode(&response)
580 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
581
582 let response_packet = ap_noise::HandshakePacket::decode(&response_bytes)
583 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
584
585 handshake.receive_finish(&response_packet)?;
587 let (transport, fingerprint) = handshake.finalize()?;
588
589 debug!("Handshake complete");
590 Ok((transport, fingerprint.to_string()))
591 }
592
593 async fn message_loop(
595 mut incoming_rx: mpsc::UnboundedReceiver<IncomingMessage>,
596 event_tx: mpsc::Sender<RemoteClientEvent>,
597 transport: Arc<Mutex<MultiDeviceTransport>>,
598 pending_requests: Arc<Mutex<PendingRequestMap>>,
599 ) {
600 while let Some(msg) = incoming_rx.recv().await {
601 match msg {
602 IncomingMessage::Send { payload, .. } => {
603 if let Ok(text) = String::from_utf8(payload) {
605 if let Ok(protocol_msg) = serde_json::from_str::<ProtocolMessage>(&text) {
606 match protocol_msg {
607 ProtocolMessage::CredentialResponse { encrypted } => {
608 if let Err(e) = Self::handle_credential_response(
609 encrypted,
610 &transport,
611 &pending_requests,
612 &event_tx,
613 )
614 .await
615 {
616 warn!("Error handling credential response: {:?}", e);
617 event_tx
618 .send(RemoteClientEvent::Error {
619 message: e.to_string(),
620 context: Some("credential_response".to_string()),
621 })
622 .await
623 .ok();
624 }
625 }
626 _ => {
627 debug!("Received other message type");
628 }
629 }
630 }
631 }
632 }
633 IncomingMessage::RendezvousInfo(_) => {
634 }
636 IncomingMessage::IdentityInfo { .. } => {
637 debug!("Received IdentityInfo message");
639 }
640 }
641 }
642 }
643
644 async fn handle_credential_response(
646 encrypted: String,
647 transport: &Arc<Mutex<MultiDeviceTransport>>,
648 pending_requests: &Arc<Mutex<PendingRequestMap>>,
649 event_tx: &mpsc::Sender<RemoteClientEvent>,
650 ) -> Result<(), RemoteClientError> {
651 let encrypted_bytes = STANDARD
653 .decode(&encrypted)
654 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
655
656 let packet = ap_noise::TransportPacket::decode(&encrypted_bytes)
657 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
658
659 let mut transport_guard = transport.lock().await;
660 let decrypted = transport_guard
661 .decrypt(&packet)
662 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
663 drop(transport_guard);
664
665 let response: CredentialResponsePayload = serde_json::from_slice(&decrypted)?;
667
668 let mut pending = pending_requests.lock().await;
670 let sender = match response.request_id {
671 Some(ref req_id) => pending.remove(req_id),
672 None => {
673 warn!("Received credential response without request_id");
674 return Ok(()); }
676 };
677 drop(pending);
678
679 if let Some(sender) = sender {
681 let result = if let Some(error) = response.error {
682 Err(RemoteClientError::CredentialRequestFailed(error))
683 } else if let Some(credential) = response.credential {
684 event_tx
686 .send(RemoteClientEvent::CredentialReceived {
687 credential: credential.clone(),
688 })
689 .await
690 .ok();
691
692 Ok(credential)
693 } else {
694 Err(RemoteClientError::CredentialRequestFailed(
695 "Response contains neither credential nor error".to_string(),
696 ))
697 };
698
699 sender.send(result).ok(); } else {
701 debug!(
703 "Received response for unknown request_id: {:?}",
704 response.request_id
705 );
706 }
707
708 Ok(())
709 }
710}
711
/// Milliseconds elapsed since the Unix epoch, or `0` if the system clock
/// reports a time before the epoch.
fn now_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
718
719fn uuid_v4() -> String {
720 let mut bytes = [0u8; 16];
722 let mut rng = rand::thread_rng();
723 rng.fill_bytes(&mut bytes);
724
725 bytes[6] = (bytes[6] & 0x0f) | 0x40;
727 bytes[8] = (bytes[8] & 0x3f) | 0x80;
728
729 format!(
730 "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
731 bytes[0],
732 bytes[1],
733 bytes[2],
734 bytes[3],
735 bytes[4],
736 bytes[5],
737 bytes[6],
738 bytes[7],
739 bytes[8],
740 bytes[9],
741 bytes[10],
742 bytes[11],
743 bytes[12],
744 bytes[13],
745 bytes[14],
746 bytes[15]
747 )
748}