1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4
5use ap_noise::{InitiatorHandshake, MultiDeviceTransport, Psk};
6use ap_proxy_client::IncomingMessage;
7use ap_proxy_protocol::{IdentityFingerprint, RendevouzCode};
8use base64::{Engine, engine::general_purpose::STANDARD};
9use rand::RngCore;
10
11use crate::proxy::ProxyClient;
12use tokio::{
13 sync::{Mutex, mpsc, oneshot},
14 time::timeout,
15};
16use tracing::{debug, warn};
17
18use crate::traits::{IdentityProvider, SessionStore};
19use crate::{
20 error::RemoteClientError,
21 types::{
22 CredentialData, CredentialQuery, CredentialRequestPayload, CredentialResponsePayload,
23 ProtocolMessage, RemoteClientEvent, RemoteClientResponse,
24 },
25};
26
/// How long `request_credential` waits for a credential response before
/// returning `RemoteClientError::Timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// In-flight credential requests keyed by request id. Each entry holds the
/// oneshot sender that the message loop uses to deliver the decrypted result
/// (or an error) back to the waiting `request_credential` call.
type PendingRequestMap =
    HashMap<String, oneshot::Sender<Result<CredentialData, RemoteClientError>>>;
/// Client that connects to a proxy, pairs with a remote peer (via a Noise
/// handshake, a pre-shared key, or a cached session), and requests
/// credentials over the resulting encrypted transport.
pub struct RemoteClient {
    /// Persists paired sessions and their transport state.
    session_store: Box<dyn SessionStore>,
    /// Connection to the proxy; used to send messages to the remote peer.
    proxy_client: Box<dyn ProxyClient>,
    /// Messages delivered by the proxy. `Some` after `new` connects; moved into
    /// the background message loop once a connection is finalized.
    incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
    /// Encrypted transport shared with the background message loop; `Some`
    /// once a secure channel has been established.
    transport: Option<Arc<Mutex<MultiDeviceTransport>>>,
    /// Fingerprint of the peer this client is paired with.
    remote_fingerprint: Option<IdentityFingerprint>,
    /// Outstanding credential requests awaiting a response.
    pending_requests: Arc<Mutex<PendingRequestMap>>,
    /// Sink for status events emitted to the embedding application.
    event_tx: mpsc::Sender<RemoteClientEvent>,
    /// Responses from the embedding application (used here for
    /// fingerprint-verification approval).
    response_rx: Option<mpsc::Receiver<RemoteClientResponse>>,
}
44
45impl RemoteClient {
46 pub async fn new(
61 identity_provider: Box<dyn IdentityProvider>,
62 session_store: Box<dyn SessionStore>,
63 event_tx: mpsc::Sender<RemoteClientEvent>,
64 response_rx: mpsc::Receiver<RemoteClientResponse>,
65 mut proxy_client: Box<dyn ProxyClient>,
66 ) -> Result<Self, RemoteClientError> {
67 let identity = identity_provider.identity().to_owned();
68
69 debug!(
70 "Connecting to proxy with identity {:?}",
71 identity.identity().fingerprint()
72 );
73
74 event_tx
75 .send(RemoteClientEvent::Connecting {
76 proxy_url: String::new(),
77 })
78 .await
79 .ok();
80
81 let incoming_rx = proxy_client.connect().await?;
82
83 event_tx
84 .send(RemoteClientEvent::Connected {
85 fingerprint: identity.identity().fingerprint(),
86 })
87 .await
88 .ok();
89
90 debug!("Connected to proxy successfully");
91
92 Ok(Self {
93 session_store,
94 proxy_client,
95 incoming_rx: Some(incoming_rx),
96 transport: None,
97 remote_fingerprint: None,
98 pending_requests: Arc::new(Mutex::new(HashMap::new())),
99 event_tx,
100 response_rx: Some(response_rx),
101 })
102 }
103
104 pub async fn pair_with_handshake(
114 &mut self,
115 rendezvous_code: &str,
116 verify_fingerprint: bool,
117 ) -> Result<IdentityFingerprint, RemoteClientError> {
118 let incoming_rx =
119 self.incoming_rx
120 .as_mut()
121 .ok_or_else(|| RemoteClientError::InvalidState {
122 expected: "proxy connected".to_string(),
123 current: "not connected".to_string(),
124 })?;
125
126 let event_tx = self.event_tx.clone();
127
128 let response_rx = self
129 .response_rx
130 .as_mut()
131 .ok_or(RemoteClientError::NotInitialized)?;
132
133 event_tx
135 .send(RemoteClientEvent::RendevouzResolving {
136 code: rendezvous_code.to_string(),
137 })
138 .await
139 .ok();
140
141 let remote_fingerprint =
142 Self::resolve_rendezvous(self.proxy_client.as_ref(), incoming_rx, rendezvous_code)
143 .await?;
144
145 event_tx
146 .send(RemoteClientEvent::RendevouzResolved {
147 fingerprint: remote_fingerprint,
148 })
149 .await
150 .ok();
151
152 event_tx.send(RemoteClientEvent::HandshakeStart).await.ok();
154
155 let (transport, fingerprint_str) = Self::perform_handshake(
156 self.proxy_client.as_ref(),
157 incoming_rx,
158 remote_fingerprint,
159 None,
160 )
161 .await?;
162
163 event_tx
164 .send(RemoteClientEvent::HandshakeComplete)
165 .await
166 .ok();
167
168 event_tx
170 .send(RemoteClientEvent::HandshakeFingerprint {
171 fingerprint: fingerprint_str,
172 })
173 .await
174 .ok();
175
176 if verify_fingerprint {
177 match timeout(Duration::from_secs(60), response_rx.recv()).await {
179 Ok(Some(RemoteClientResponse::VerifyFingerprint { approved: true })) => {
180 event_tx
181 .send(RemoteClientEvent::FingerprintVerified)
182 .await
183 .ok();
184 }
185 Ok(Some(RemoteClientResponse::VerifyFingerprint { approved: false })) => {
186 self.proxy_client.disconnect().await.ok();
187 event_tx
188 .send(RemoteClientEvent::FingerprintRejected {
189 reason: "User rejected fingerprint verification".to_string(),
190 })
191 .await
192 .ok();
193 return Err(RemoteClientError::FingerprintRejected);
194 }
195 Ok(None) => {
196 return Err(RemoteClientError::ChannelClosed);
197 }
198 Err(_) => {
199 self.proxy_client.disconnect().await.ok();
200 return Err(RemoteClientError::Timeout(
201 "Fingerprint verification timeout".to_string(),
202 ));
203 }
204 }
205 }
206
207 self.session_store.cache_session(remote_fingerprint)?;
209
210 self.finalize_connection(transport, remote_fingerprint, event_tx)
212 .await?;
213
214 Ok(remote_fingerprint)
215 }
216
217 pub async fn pair_with_psk(
222 &mut self,
223 psk: Psk,
224 remote_fingerprint: IdentityFingerprint,
225 ) -> Result<(), RemoteClientError> {
226 let incoming_rx =
227 self.incoming_rx
228 .as_mut()
229 .ok_or_else(|| RemoteClientError::InvalidState {
230 expected: "proxy connected".to_string(),
231 current: "not connected".to_string(),
232 })?;
233
234 let event_tx = self.event_tx.clone();
235
236 event_tx
238 .send(RemoteClientEvent::PskMode {
239 fingerprint: remote_fingerprint,
240 })
241 .await
242 .ok();
243
244 event_tx.send(RemoteClientEvent::HandshakeStart).await.ok();
246
247 let (transport, _fingerprint_str) = Self::perform_handshake(
248 self.proxy_client.as_ref(),
249 incoming_rx,
250 remote_fingerprint,
251 Some(psk),
252 )
253 .await?;
254
255 event_tx
256 .send(RemoteClientEvent::HandshakeComplete)
257 .await
258 .ok();
259
260 event_tx
262 .send(RemoteClientEvent::FingerprintVerified)
263 .await
264 .ok();
265
266 self.session_store.cache_session(remote_fingerprint)?;
268
269 self.finalize_connection(transport, remote_fingerprint, event_tx)
271 .await?;
272
273 Ok(())
274 }
275
276 pub async fn load_cached_session(
281 &mut self,
282 remote_fingerprint: IdentityFingerprint,
283 ) -> Result<(), RemoteClientError> {
284 let event_tx = self.event_tx.clone();
285
286 if !self.session_store.has_session(&remote_fingerprint) {
288 return Err(RemoteClientError::SessionNotFound);
289 }
290
291 event_tx
293 .send(RemoteClientEvent::ReconnectingToSession {
294 fingerprint: remote_fingerprint,
295 })
296 .await
297 .ok();
298
299 let transport_state = self
300 .session_store
301 .load_transport_state(&remote_fingerprint)?
302 .expect("Transport state should exist for cached session");
303
304 event_tx
305 .send(RemoteClientEvent::HandshakeComplete)
306 .await
307 .ok();
308
309 event_tx
311 .send(RemoteClientEvent::FingerprintVerified)
312 .await
313 .ok();
314
315 self.session_store
317 .update_last_connected(&remote_fingerprint)?;
318
319 self.finalize_connection(transport_state, remote_fingerprint, event_tx)
321 .await?;
322
323 Ok(())
324 }
325
326 async fn finalize_connection(
331 &mut self,
332 transport: MultiDeviceTransport,
333 remote_fingerprint: IdentityFingerprint,
334 event_tx: mpsc::Sender<RemoteClientEvent>,
335 ) -> Result<(), RemoteClientError> {
336 self.session_store
338 .save_transport_state(&remote_fingerprint, transport.clone())?;
339
340 let transport = Arc::new(Mutex::new(transport));
342 self.transport = Some(Arc::clone(&transport));
343 self.remote_fingerprint = Some(remote_fingerprint);
344
345 event_tx
347 .send(RemoteClientEvent::Ready {
348 can_request_credentials: true,
349 })
350 .await
351 .ok();
352
353 let incoming_rx = self
355 .incoming_rx
356 .take()
357 .ok_or(RemoteClientError::NotInitialized)?;
358
359 let pending_requests_clone = Arc::clone(&self.pending_requests);
361 tokio::spawn(async move {
362 Self::message_loop(incoming_rx, event_tx, transport, pending_requests_clone).await;
363 });
364
365 debug!("Connection established successfully");
366 Ok(())
367 }
368
369 pub async fn request_credential(
371 &mut self,
372 query: &CredentialQuery,
373 ) -> Result<CredentialData, RemoteClientError> {
374 let transport = self
375 .transport
376 .as_ref()
377 .ok_or(RemoteClientError::SecureChannelNotEstablished)?;
378
379 let remote_fingerprint = self
380 .remote_fingerprint
381 .ok_or(RemoteClientError::NotInitialized)?;
382
383 #[allow(clippy::string_slice)]
385 let request_id = format!("req-{}-{}", now_millis(), &uuid_v4()[..8]);
386
387 debug!("Requesting credential for query: {:?}", query);
388
389 let request = CredentialRequestPayload {
391 request_type: "credential_request".to_string(),
392 query: query.clone(),
393 timestamp: now_millis(),
394 request_id: request_id.clone(),
395 };
396
397 let request_json = serde_json::to_string(&request)?;
398 let mut transport_guard = transport.lock().await;
399 let encrypted_packet = transport_guard
400 .encrypt(request_json.as_bytes())
401 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
402 drop(transport_guard);
403
404 let msg = ProtocolMessage::CredentialRequest {
405 encrypted: STANDARD.encode(encrypted_packet.encode()),
406 };
407
408 let msg_json = serde_json::to_string(&msg)?;
410 self.proxy_client
411 .send_to(remote_fingerprint, msg_json.into_bytes())
412 .await?;
413
414 self.event_tx
416 .send(RemoteClientEvent::CredentialRequestSent {
417 query: query.clone(),
418 })
419 .await
420 .ok();
421
422 let (response_tx, response_rx) = oneshot::channel();
424
425 self.pending_requests
427 .lock()
428 .await
429 .insert(request_id.clone(), response_tx);
430
431 match timeout(DEFAULT_TIMEOUT, response_rx).await {
433 Ok(Ok(Ok(credential))) => {
434 debug!("Received credential for query: {:?}", query);
436 Ok(credential)
437 }
438 Ok(Ok(Err(e))) => {
439 Err(e)
441 }
442 Ok(Err(_)) => {
443 self.pending_requests.lock().await.remove(&request_id);
445 Err(RemoteClientError::ChannelClosed)
446 }
447 Err(_) => {
448 self.pending_requests.lock().await.remove(&request_id);
450 Err(RemoteClientError::Timeout(format!(
451 "Timeout waiting for credential response for query: {query:?}"
452 )))
453 }
454 }
455 }
456
457 pub fn is_ready(&self) -> bool {
459 self.transport.is_some()
460 }
461
462 pub async fn close(&mut self) {
464 let mut pending = self.pending_requests.lock().await;
466 pending.clear(); drop(pending);
468
469 self.proxy_client.disconnect().await.ok();
470 self.transport = None;
471 self.remote_fingerprint = None;
472 self.incoming_rx = None;
473 self.response_rx = None;
474 debug!("Connection closed");
475 }
476
477 pub fn session_store(&self) -> &dyn SessionStore {
479 self.session_store.as_ref()
480 }
481
482 pub fn session_store_mut(&mut self) -> &mut dyn SessionStore {
484 self.session_store.as_mut()
485 }
486
487 async fn resolve_rendezvous(
489 proxy_client: &dyn ProxyClient,
490 incoming_rx: &mut mpsc::UnboundedReceiver<IncomingMessage>,
491 rendezvous_code: &str,
492 ) -> Result<IdentityFingerprint, RemoteClientError> {
493 proxy_client
495 .request_identity(RendevouzCode::from_string(rendezvous_code.to_string()))
496 .await
497 .map_err(|e| RemoteClientError::RendevouzResolutionFailed(e.to_string()))?;
498
499 let timeout_duration = tokio::time::Duration::from_secs(10);
501 match tokio::time::timeout(timeout_duration, async {
502 while let Some(msg) = incoming_rx.recv().await {
503 if let IncomingMessage::IdentityInfo { fingerprint, .. } = msg {
504 return Some(fingerprint);
505 }
506 }
507 None
508 })
509 .await
510 {
511 Ok(Some(fingerprint)) => Ok(fingerprint),
512 Ok(None) => Err(RemoteClientError::RendevouzResolutionFailed(
513 "Connection closed while waiting for identity response".to_string(),
514 )),
515 Err(_) => Err(RemoteClientError::RendevouzResolutionFailed(
516 "Timeout waiting for identity response. The rendezvous code may be invalid, expired, or the target client may be disconnected.".to_string(),
517 )),
518 }
519 }
520
521 async fn perform_handshake(
523 proxy_client: &dyn ProxyClient,
524 incoming_rx: &mut mpsc::UnboundedReceiver<IncomingMessage>,
525 remote_fingerprint: IdentityFingerprint,
526 psk: Option<Psk>,
527 ) -> Result<(MultiDeviceTransport, String), RemoteClientError> {
528 let mut handshake = if let Some(psk) = psk {
530 InitiatorHandshake::with_psk(psk)
531 } else {
532 InitiatorHandshake::new()
533 };
534
535 let init_packet = handshake.send_start()?;
537
538 let msg = ProtocolMessage::HandshakeInit {
540 data: STANDARD.encode(init_packet.encode()?),
541 ciphersuite: format!("{:?}", handshake.ciphersuite()),
542 };
543
544 let msg_json = serde_json::to_string(&msg)?;
545 proxy_client
546 .send_to(remote_fingerprint, msg_json.into_bytes())
547 .await?;
548
549 debug!("Sent handshake init");
550
551 let response_timeout = Duration::from_secs(10);
553 let response: String = timeout(response_timeout, async {
554 loop {
555 if let Some(incoming) = incoming_rx.recv().await {
556 match incoming {
557 IncomingMessage::Send { payload, .. } => {
558 if let Ok(text) = String::from_utf8(payload)
560 && let Ok(ProtocolMessage::HandshakeResponse { data, .. }) =
561 serde_json::from_str::<ProtocolMessage>(&text)
562 {
563 return Ok::<String, RemoteClientError>(data);
564 }
565 }
566 _ => continue,
567 }
568 }
569 }
570 })
571 .await
572 .map_err(|_| RemoteClientError::Timeout("Waiting for handshake response".to_string()))??;
573
574 let response_bytes = STANDARD
576 .decode(&response)
577 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
578
579 let response_packet = ap_noise::HandshakePacket::decode(&response_bytes)
580 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
581
582 handshake.receive_finish(&response_packet)?;
584 let (transport, fingerprint) = handshake.finalize()?;
585
586 debug!("Handshake complete");
587 Ok((transport, fingerprint.to_string()))
588 }
589
590 async fn message_loop(
592 mut incoming_rx: mpsc::UnboundedReceiver<IncomingMessage>,
593 event_tx: mpsc::Sender<RemoteClientEvent>,
594 transport: Arc<Mutex<MultiDeviceTransport>>,
595 pending_requests: Arc<Mutex<PendingRequestMap>>,
596 ) {
597 while let Some(msg) = incoming_rx.recv().await {
598 match msg {
599 IncomingMessage::Send { payload, .. } => {
600 if let Ok(text) = String::from_utf8(payload) {
602 if let Ok(protocol_msg) = serde_json::from_str::<ProtocolMessage>(&text) {
603 match protocol_msg {
604 ProtocolMessage::CredentialResponse { encrypted } => {
605 if let Err(e) = Self::handle_credential_response(
606 encrypted,
607 &transport,
608 &pending_requests,
609 &event_tx,
610 )
611 .await
612 {
613 warn!("Error handling credential response: {:?}", e);
614 event_tx
615 .send(RemoteClientEvent::Error {
616 message: e.to_string(),
617 context: Some("credential_response".to_string()),
618 })
619 .await
620 .ok();
621 }
622 }
623 _ => {
624 debug!("Received other message type");
625 }
626 }
627 }
628 }
629 }
630 IncomingMessage::RendevouzInfo(_) => {
631 }
633 IncomingMessage::IdentityInfo { .. } => {
634 debug!("Received IdentityInfo message");
636 }
637 }
638 }
639 }
640
641 async fn handle_credential_response(
643 encrypted: String,
644 transport: &Arc<Mutex<MultiDeviceTransport>>,
645 pending_requests: &Arc<Mutex<PendingRequestMap>>,
646 event_tx: &mpsc::Sender<RemoteClientEvent>,
647 ) -> Result<(), RemoteClientError> {
648 let encrypted_bytes = STANDARD
650 .decode(&encrypted)
651 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
652
653 let packet = ap_noise::TransportPacket::decode(&encrypted_bytes)
654 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
655
656 let mut transport_guard = transport.lock().await;
657 let decrypted = transport_guard
658 .decrypt(&packet)
659 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
660 drop(transport_guard);
661
662 let response: CredentialResponsePayload = serde_json::from_slice(&decrypted)?;
664
665 let mut pending = pending_requests.lock().await;
667 let sender = match response.request_id {
668 Some(ref req_id) => pending.remove(req_id),
669 None => {
670 warn!("Received credential response without request_id");
671 return Ok(()); }
673 };
674 drop(pending);
675
676 if let Some(sender) = sender {
678 let result = if let Some(error) = response.error {
679 Err(RemoteClientError::CredentialRequestFailed(error))
680 } else if let Some(credential) = response.credential {
681 event_tx
683 .send(RemoteClientEvent::CredentialReceived {
684 credential: credential.clone(),
685 })
686 .await
687 .ok();
688
689 Ok(credential)
690 } else {
691 Err(RemoteClientError::CredentialRequestFailed(
692 "Response contains neither credential nor error".to_string(),
693 ))
694 };
695
696 sender.send(result).ok(); } else {
698 debug!(
700 "Received response for unknown request_id: {:?}",
701 response.request_id
702 );
703 }
704
705 Ok(())
706 }
707}
708
/// Milliseconds elapsed since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
715
716fn uuid_v4() -> String {
717 let mut bytes = [0u8; 16];
719 let mut rng = rand::thread_rng();
720 rng.fill_bytes(&mut bytes);
721
722 bytes[6] = (bytes[6] & 0x0f) | 0x40;
724 bytes[8] = (bytes[8] & 0x3f) | 0x80;
725
726 format!(
727 "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
728 bytes[0],
729 bytes[1],
730 bytes[2],
731 bytes[3],
732 bytes[4],
733 bytes[5],
734 bytes[6],
735 bytes[7],
736 bytes[8],
737 bytes[9],
738 bytes[10],
739 bytes[11],
740 bytes[12],
741 bytes[13],
742 bytes[14],
743 bytes[15]
744 )
745}