1use std::collections::HashMap;
2use std::time::Duration;
3
4use ap_noise::{Ciphersuite, MultiDeviceTransport, Psk, ResponderHandshake};
5use ap_proxy_client::IncomingMessage;
6use ap_proxy_protocol::{IdentityFingerprint, RendezvousCode};
7use base64::{Engine, engine::general_purpose::STANDARD};
8
9use crate::proxy::ProxyClient;
10use crate::types::CredentialData;
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
/// Delay used after the first failed reconnection attempt; each further failure doubles it.
const RECONNECT_BASE_DELAY: Duration = Duration::from_millis(2_000);
/// Ceiling (900 s = 15 minutes) for the exponential reconnection delay, applied before jitter.
const RECONNECT_MAX_DELAY: Duration = Duration::from_secs(900);
18
/// Result of a completed handshake that is parked until the user answers the
/// fingerprint prompt.
///
/// Stored by `handle_handshake_init` for a peer without a stored session when no PSK is
/// configured, and consumed by `handle_fingerprint_verification`.
struct PendingHandshakeVerification {
    /// Identity of the remote peer that initiated the handshake.
    source: IdentityFingerprint,
    /// Transport produced by the handshake; registered only if the user approves.
    transport: MultiDeviceTransport,
}
26
27use crate::{
28 error::RemoteClientError,
29 traits::{
30 AuditConnectionType, AuditEvent, AuditLog, CredentialFieldSet, IdentityProvider,
31 NoOpAuditLog, SessionStore,
32 },
33 types::{CredentialRequestPayload, CredentialResponsePayload, ProtocolMessage},
34};
35
/// Events emitted by `UserClient` on the `event_tx` channel handed to its listen methods.
#[derive(Debug, Clone)]
pub enum UserClientEvent {
    /// The client is about to enter its event loop.
    Listening {},
    /// A rendezvous code was received from the proxy.
    RendezvousCodeGenerated {
        /// String form of the rendezvous code.
        code: String,
    },
    /// A PSK pairing token was generated.
    PskTokenGenerated {
        /// `<psk hex>_<local identity fingerprint hex>`.
        token: String,
    },
    /// A handshake-init message arrived from a remote peer.
    HandshakeStart {},
    /// Free-form handshake progress text.
    /// NOTE(review): not emitted by the code visible in this file section — confirm usage.
    HandshakeProgress {
        message: String,
    },
    /// The handshake finished and the handshake response was sent to the peer.
    HandshakeComplete {},
    /// Fingerprint of a peer that had no stored session.
    ///
    /// Without a PSK the connection is still waiting for a `VerifyFingerprint` response;
    /// with a PSK the connection has already been accepted when this is emitted.
    HandshakeFingerprint {
        /// String form of the fingerprint returned by the handshake.
        fingerprint: String,
        /// Identity of the remote peer.
        identity: IdentityFingerprint,
    },
    /// The user approved the pending fingerprint and the connection was accepted.
    FingerprintVerified {},
    /// The user rejected the pending fingerprint.
    FingerprintRejected {
        reason: String,
    },
    /// A remote peer asked for a credential; answer with `UserClientResponse::RespondCredential`.
    CredentialRequest {
        query: crate::types::CredentialQuery,
        /// Identifier taken from the decrypted request payload.
        request_id: String,
        /// `Debug` rendering of the requesting peer's fingerprint; must be echoed back
        /// unchanged in the response.
        session_id: String,
    },
    /// An approved credential response was sent to the peer.
    CredentialApproved {
        domain: Option<String>,
        credential_id: Option<String>,
    },
    /// A denial was sent to the peer.
    CredentialDenied {
        domain: Option<String>,
        credential_id: Option<String>,
    },
    /// A peer with a stored session completed a new handshake.
    SessionRefreshed {
        fingerprint: IdentityFingerprint,
    },
    /// The incoming message stream from the proxy ended.
    ClientDisconnected {},
    /// A reconnection attempt failed; another one follows after a backoff delay.
    Reconnecting {
        /// 1-based number of the attempt that just failed.
        attempt: u32,
    },
    /// The proxy connection was re-established.
    Reconnected {},
    /// A handler returned an error; the event loop keeps running.
    Error {
        /// Display text of the error.
        message: String,
        /// Name of the handler that failed (`"handle_incoming"` or `"handle_response"`).
        context: Option<String>,
    },
}
119
/// Replies sent by the application to `UserClient` on the `response_rx` channel.
#[derive(Debug, Clone)]
// NOTE(review): presumably needed because `RespondCredential` carries the credential inline —
// confirm before removing.
#[allow(clippy::large_enum_variant)]
pub enum UserClientResponse {
    /// Answer to the pending fingerprint verification.
    VerifyFingerprint {
        /// `true` accepts the pending connection, `false` rejects it.
        approved: bool,
        /// Session name to store; when `None` the client's pending session name is used.
        name: Option<String>,
    },
    /// Answer to a `UserClientEvent::CredentialRequest`.
    RespondCredential {
        /// Echoed to the peer inside the response payload.
        request_id: String,
        /// `session_id` from the request event; used to find the peer's transport.
        session_id: String,
        /// The query being answered; recorded in the audit log.
        query: crate::types::CredentialQuery,
        /// Whether the request is granted.
        approved: bool,
        /// Credential to send; it is only transmitted when `approved` is `true`.
        credential: Option<CredentialData>,
        /// Identifier recorded in the audit log and the outcome event.
        credential_id: Option<String>,
    },
}
147
/// Client that accepts handshakes from remote peers through a proxy and answers their
/// credential requests after consulting the application.
pub struct UserClient {
    /// Supplies the local identity fingerprint (used to build the PSK token).
    identity_provider: Box<dyn IdentityProvider>,
    /// Store for known sessions, their names and their transport state.
    session_store: Box<dyn SessionStore>,
    /// Connection to the proxy; `Some` for every client built by `listen`.
    proxy_client: Option<Box<dyn ProxyClient>>,
    /// In-memory transports keyed by remote identity.
    transports: HashMap<IdentityFingerprint, MultiDeviceTransport>,
    /// Rendezvous code received from the proxy, once `enable_rendezvous` obtained one.
    rendezvous_code: Option<RendezvousCode>,
    /// Pre-shared key set by `enable_psk`; when present, new peers skip manual verification.
    psk: Option<Psk>,
    /// Incoming proxy messages; taken by the event loop when it starts.
    incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
    /// Handshake waiting for the user's fingerprint decision.
    pending_verification: Option<PendingHandshakeVerification>,
    /// Name to attach to the next accepted or refreshed session.
    pending_session_name: Option<String>,
    /// Sink for audit events.
    audit_log: Box<dyn AuditLog>,
}
168
169impl UserClient {
170 pub async fn listen(
177 identity_provider: Box<dyn IdentityProvider>,
178 session_store: Box<dyn SessionStore>,
179 mut proxy_client: Box<dyn ProxyClient>,
180 ) -> Result<Self, RemoteClientError> {
181 let incoming_rx = proxy_client.connect().await?;
182
183 Ok(Self {
184 identity_provider,
185 session_store,
186 proxy_client: Some(proxy_client),
187 transports: HashMap::new(),
188 rendezvous_code: None,
189 psk: None,
190 incoming_rx: Some(incoming_rx),
191 pending_verification: None,
192 pending_session_name: None,
193 audit_log: Box::new(NoOpAuditLog),
194 })
195 }
196
197 pub fn with_audit_log(mut self, audit_log: Box<dyn AuditLog>) -> Self {
199 self.audit_log = audit_log;
200 self
201 }
202
203 pub async fn listen_cached_only(
208 &mut self,
209 event_tx: mpsc::Sender<UserClientEvent>,
210 response_rx: mpsc::Receiver<UserClientResponse>,
211 ) -> Result<(), RemoteClientError> {
212 debug!("User client listening for cached sessions only (no new pairing code)");
213
214 event_tx.send(UserClientEvent::Listening {}).await.ok();
216
217 self.run_event_loop(event_tx, response_rx).await
219 }
220
221 pub async fn enable_psk(
225 &mut self,
226 event_tx: mpsc::Sender<UserClientEvent>,
227 response_rx: mpsc::Receiver<UserClientResponse>,
228 ) -> Result<(), RemoteClientError> {
229 let psk = Psk::generate();
231 let fingerprint = self.identity_provider.fingerprint().await;
232 let token = format!("{}_{}", psk.to_hex(), hex::encode(fingerprint.0));
233
234 self.psk = Some(psk);
235
236 event_tx
237 .send(UserClientEvent::PskTokenGenerated { token })
238 .await
239 .ok();
240
241 debug!("User client listening in PSK mode");
242
243 event_tx.send(UserClientEvent::Listening {}).await.ok();
245
246 self.run_event_loop(event_tx, response_rx).await
248 }
249
250 pub async fn enable_rendezvous(
254 &mut self,
255 event_tx: mpsc::Sender<UserClientEvent>,
256 response_rx: mpsc::Receiver<UserClientResponse>,
257 ) -> Result<(), RemoteClientError> {
258 let proxy_client = self
259 .proxy_client
260 .as_ref()
261 .ok_or(RemoteClientError::NotInitialized)?;
262
263 proxy_client.request_rendezvous().await?;
265
266 let incoming_rx = self
268 .incoming_rx
269 .as_mut()
270 .ok_or(RemoteClientError::NotInitialized)?;
271
272 let code = loop {
273 if let Some(IncomingMessage::RendezvousInfo(c)) = incoming_rx.recv().await {
274 break c;
275 }
276 };
277
278 self.rendezvous_code = Some(code.clone());
279
280 event_tx
281 .send(UserClientEvent::RendezvousCodeGenerated {
282 code: code.as_str().to_string(),
283 })
284 .await
285 .ok();
286
287 debug!("User client listening with rendezvous code: {}", code);
288
289 event_tx.send(UserClientEvent::Listening {}).await.ok();
291
292 self.run_event_loop(event_tx, response_rx).await
294 }
295
296 async fn run_event_loop(
298 &mut self,
299 event_tx: mpsc::Sender<UserClientEvent>,
300 mut response_rx: mpsc::Receiver<UserClientResponse>,
301 ) -> Result<(), RemoteClientError> {
302 let mut incoming_rx = self
304 .incoming_rx
305 .take()
306 .ok_or(RemoteClientError::NotInitialized)?;
307
308 loop {
309 tokio::select! {
310 msg = incoming_rx.recv() => {
311 match msg {
312 Some(msg) => {
313 if let Err(e) = self.handle_incoming(msg, &event_tx).await {
314 warn!("Error handling incoming message: {}", e);
315 event_tx.send(UserClientEvent::Error {
316 message: e.to_string(),
317 context: Some("handle_incoming".to_string()),
318 }).await.ok();
319 }
320 }
321 None => {
322 event_tx.send(UserClientEvent::ClientDisconnected {}).await.ok();
324 match self.attempt_reconnection(&event_tx).await {
325 Ok(new_rx) => {
326 incoming_rx = new_rx;
327 event_tx.send(UserClientEvent::Reconnected {}).await.ok();
328 }
329 Err(e) => {
330 warn!("Reconnection failed permanently: {}", e);
331 return Err(e);
332 }
333 }
334 }
335 }
336 }
337 Some(response) = response_rx.recv() => {
338 if let Err(e) = self.handle_response(response, &event_tx).await {
339 warn!("Error handling response: {}", e);
340 event_tx.send(UserClientEvent::Error {
341 message: e.to_string(),
342 context: Some("handle_response".to_string()),
343 }).await.ok();
344 }
345 }
346 }
347 }
348 }
349
350 async fn attempt_reconnection(
355 &mut self,
356 event_tx: &mpsc::Sender<UserClientEvent>,
357 ) -> Result<mpsc::UnboundedReceiver<IncomingMessage>, RemoteClientError> {
358 use rand::Rng;
359
360 let mut rng = rand::thread_rng();
361 let mut attempt: u32 = 0;
362
363 loop {
364 attempt = attempt.saturating_add(1);
365
366 let proxy_client = self
367 .proxy_client
368 .as_mut()
369 .ok_or(RemoteClientError::NotInitialized)?;
370
371 let _ = proxy_client.disconnect().await;
373
374 match proxy_client.connect().await {
375 Ok(new_rx) => {
376 debug!("Reconnected to proxy on attempt {}", attempt);
377 return Ok(new_rx);
378 }
379 Err(e) => {
380 debug!("Reconnection attempt {} failed: {}", attempt, e);
381 event_tx
382 .send(UserClientEvent::Reconnecting { attempt })
383 .await
384 .ok();
385
386 let exp_delay = RECONNECT_BASE_DELAY
388 .saturating_mul(2u32.saturating_pow(attempt.saturating_sub(1)));
389 let delay = exp_delay.min(RECONNECT_MAX_DELAY);
390 let jitter_max = (delay.as_millis() as u64) / 4;
391 let jitter = if jitter_max > 0 {
392 rng.gen_range(0..=jitter_max)
393 } else {
394 0
395 };
396 let total_delay = delay + Duration::from_millis(jitter);
397
398 tokio::time::sleep(total_delay).await;
399 }
400 }
401 }
402 }
403
404 async fn handle_incoming(
406 &mut self,
407 msg: IncomingMessage,
408 event_tx: &mpsc::Sender<UserClientEvent>,
409 ) -> Result<(), RemoteClientError> {
410 match msg {
411 IncomingMessage::Send {
412 source, payload, ..
413 } => {
414 let text = String::from_utf8(payload)
416 .map_err(|e| RemoteClientError::Serialization(format!("Invalid UTF-8: {e}")))?;
417
418 let protocol_msg: ProtocolMessage = serde_json::from_str(&text)?;
419
420 match protocol_msg {
421 ProtocolMessage::HandshakeInit { data, ciphersuite } => {
422 self.handle_handshake_init(source, data, ciphersuite, event_tx)
423 .await?;
424 }
425 ProtocolMessage::CredentialRequest { encrypted } => {
426 self.handle_credential_request(source, encrypted, event_tx)
427 .await?;
428 }
429 _ => {
430 debug!("Received unexpected message type from {:?}", source);
431 }
432 }
433 }
434 IncomingMessage::RendezvousInfo(_) => {
435 }
437 IncomingMessage::IdentityInfo { .. } => {
438 debug!("Received unexpected IdentityInfo message");
440 }
441 }
442 Ok(())
443 }
444
445 async fn handle_handshake_init(
447 &mut self,
448 source: IdentityFingerprint,
449 data: String,
450 ciphersuite: String,
451 event_tx: &mpsc::Sender<UserClientEvent>,
452 ) -> Result<(), RemoteClientError> {
453 debug!("Received handshake init from source: {:?}", source);
454 event_tx.send(UserClientEvent::HandshakeStart {}).await.ok();
455
456 let (transport, fingerprint_str) =
457 self.complete_handshake(source, &data, &ciphersuite).await?;
458
459 event_tx
460 .send(UserClientEvent::HandshakeComplete {})
461 .await
462 .ok();
463
464 let is_new_connection = !self.session_store.has_session(&source).await;
466 let is_psk_connection = self.psk.is_some();
468
469 if is_new_connection && !is_psk_connection {
470 self.pending_verification = Some(PendingHandshakeVerification { source, transport });
472
473 event_tx
474 .send(UserClientEvent::HandshakeFingerprint {
475 fingerprint: fingerprint_str,
476 identity: source,
477 })
478 .await
479 .ok();
480 } else if !is_new_connection {
481 self.transports.insert(source, transport.clone());
485 self.session_store.cache_session(source).await?;
486 if let Some(name) = self.pending_session_name.take() {
489 self.session_store.set_session_name(&source, name).await?;
490 }
491 self.session_store
492 .save_transport_state(&source, transport)
493 .await?;
494
495 self.audit_log
496 .write(AuditEvent::SessionRefreshed {
497 remote_identity: &source,
498 })
499 .await;
500
501 event_tx
502 .send(UserClientEvent::SessionRefreshed {
503 fingerprint: source,
504 })
505 .await
506 .ok();
507 } else if is_psk_connection {
508 let session_name = self.pending_session_name.take();
510 self.accept_new_connection(
511 source,
512 transport,
513 session_name.as_deref(),
514 AuditConnectionType::Psk,
515 )
516 .await?;
517
518 event_tx
519 .send(UserClientEvent::HandshakeFingerprint {
520 fingerprint: fingerprint_str,
521 identity: source,
522 })
523 .await
524 .ok();
525 }
526
527 Ok(())
528 }
529
530 async fn accept_new_connection(
532 &mut self,
533 fingerprint: IdentityFingerprint,
534 transport: MultiDeviceTransport,
535 session_name: Option<&str>,
536 connection_type: AuditConnectionType,
537 ) -> Result<(), RemoteClientError> {
538 self.transports.insert(fingerprint, transport.clone());
539 self.session_store.cache_session(fingerprint).await?;
540 if let Some(name) = session_name {
541 self.session_store
542 .set_session_name(&fingerprint, name.to_owned())
543 .await?;
544 }
545 self.session_store
546 .save_transport_state(&fingerprint, transport)
547 .await?;
548
549 self.audit_log
550 .write(AuditEvent::ConnectionEstablished {
551 remote_identity: &fingerprint,
552 remote_name: session_name,
553 connection_type,
554 })
555 .await;
556
557 Ok(())
558 }
559
560 async fn handle_fingerprint_verification(
562 &mut self,
563 approved: bool,
564 name: Option<String>,
565 event_tx: &mpsc::Sender<UserClientEvent>,
566 ) -> Result<(), RemoteClientError> {
567 let pending = self
568 .pending_verification
569 .take()
570 .ok_or(RemoteClientError::InvalidState {
571 expected: "pending verification".to_string(),
572 current: "no pending verification".to_string(),
573 })?;
574
575 if approved {
576 let session_name = name.or(self.pending_session_name.take());
577 self.accept_new_connection(
578 pending.source,
579 pending.transport,
580 session_name.as_deref(),
581 AuditConnectionType::Rendezvous,
582 )
583 .await?;
584
585 event_tx
586 .send(UserClientEvent::FingerprintVerified {})
587 .await
588 .ok();
589 } else {
590 self.audit_log
591 .write(AuditEvent::ConnectionRejected {
592 remote_identity: &pending.source,
593 })
594 .await;
595
596 event_tx
597 .send(UserClientEvent::FingerprintRejected {
598 reason: "User rejected fingerprint verification".to_string(),
599 })
600 .await
601 .ok();
602 }
603
604 Ok(())
605 }
606
607 async fn handle_credential_request(
609 &mut self,
610 source: IdentityFingerprint,
611 encrypted: String,
612 event_tx: &mpsc::Sender<UserClientEvent>,
613 ) -> Result<(), RemoteClientError> {
614 if !self.transports.contains_key(&source) {
615 debug!("Loading transport state for source: {:?}", source);
616 let session = self
617 .session_store
618 .load_transport_state(&source)
619 .await?
620 .expect("Transport state should exist for cached session");
621 self.transports.insert(source, session);
622 }
623
624 let transport = self
626 .transports
627 .get_mut(&source)
628 .ok_or(RemoteClientError::SecureChannelNotEstablished)?;
629
630 let encrypted_bytes = STANDARD
632 .decode(&encrypted)
633 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
634
635 let packet = ap_noise::TransportPacket::decode(&encrypted_bytes)
636 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
637
638 let decrypted = transport
639 .decrypt(&packet)
640 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
641
642 let request: CredentialRequestPayload = serde_json::from_slice(&decrypted)?;
643
644 self.audit_log
645 .write(AuditEvent::CredentialRequested {
646 query: &request.query,
647 remote_identity: &source,
648 request_id: &request.request_id,
649 })
650 .await;
651
652 event_tx
654 .send(UserClientEvent::CredentialRequest {
655 query: request.query,
656 request_id: request.request_id.clone(),
657 session_id: format!("{source:?}"),
658 })
659 .await
660 .ok();
661
662 Ok(())
663 }
664
665 async fn handle_response(
667 &mut self,
668 response: UserClientResponse,
669 event_tx: &mpsc::Sender<UserClientEvent>,
670 ) -> Result<(), RemoteClientError> {
671 match response {
672 UserClientResponse::VerifyFingerprint { approved, name } => {
673 self.handle_fingerprint_verification(approved, name, event_tx)
674 .await?;
675 }
676 UserClientResponse::RespondCredential {
677 request_id,
678 session_id,
679 query,
680 approved,
681 credential,
682 credential_id,
683 } => {
684 self.handle_credential_response(
685 request_id,
686 session_id,
687 query,
688 approved,
689 credential,
690 credential_id,
691 event_tx,
692 )
693 .await?;
694 }
695 }
696 Ok(())
697 }
698
699 #[allow(clippy::too_many_arguments)]
701 async fn handle_credential_response(
702 &mut self,
703 request_id: String,
704 session_id: String,
705 query: crate::types::CredentialQuery,
706 approved: bool,
707 credential: Option<CredentialData>,
708 credential_id: Option<String>,
709 event_tx: &mpsc::Sender<UserClientEvent>,
710 ) -> Result<(), RemoteClientError> {
711 let fingerprint = self
713 .transports
714 .keys()
715 .find(|fp| format!("{fp:?}") == session_id)
716 .copied()
717 .ok_or(RemoteClientError::NotInitialized)?;
718
719 let transport = self
720 .transports
721 .get_mut(&fingerprint)
722 .ok_or(RemoteClientError::SecureChannelNotEstablished)?;
723
724 let domain = credential.as_ref().and_then(|c| c.domain.clone());
726 let fields = credential
727 .as_ref()
728 .map_or_else(CredentialFieldSet::default, |c| CredentialFieldSet {
729 has_username: c.username.is_some(),
730 has_password: c.password.is_some(),
731 has_totp: c.totp.is_some(),
732 has_uri: c.uri.is_some(),
733 has_notes: c.notes.is_some(),
734 });
735
736 let response_payload = CredentialResponsePayload {
738 credential: if approved { credential } else { None },
739 error: if !approved {
740 Some("Request denied".to_string())
741 } else {
742 None
743 },
744 request_id: Some(request_id.clone()),
745 };
746
747 let response_json = serde_json::to_string(&response_payload)?;
749 let encrypted = transport
750 .encrypt(response_json.as_bytes())
751 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
752
753 let msg = ProtocolMessage::CredentialResponse {
754 encrypted: STANDARD.encode(encrypted.encode()),
755 };
756
757 let msg_json = serde_json::to_string(&msg)?;
758
759 let proxy_client = self
760 .proxy_client
761 .as_ref()
762 .ok_or(RemoteClientError::NotInitialized)?;
763
764 proxy_client
765 .send_to(fingerprint, msg_json.into_bytes())
766 .await?;
767
768 if approved {
770 self.audit_log
771 .write(AuditEvent::CredentialApproved {
772 query: &query,
773 domain: domain.as_deref(),
774 remote_identity: &fingerprint,
775 request_id: &request_id,
776 credential_id: credential_id.as_deref(),
777 fields,
778 })
779 .await;
780
781 event_tx
782 .send(UserClientEvent::CredentialApproved {
783 domain,
784 credential_id,
785 })
786 .await
787 .ok();
788 } else {
789 self.audit_log
790 .write(AuditEvent::CredentialDenied {
791 query: &query,
792 domain: domain.as_deref(),
793 remote_identity: &fingerprint,
794 request_id: &request_id,
795 credential_id: credential_id.as_deref(),
796 })
797 .await;
798
799 event_tx
800 .send(UserClientEvent::CredentialDenied {
801 domain,
802 credential_id,
803 })
804 .await
805 .ok();
806 }
807
808 Ok(())
809 }
810
811 async fn complete_handshake(
813 &self,
814 remote_fingerprint: IdentityFingerprint,
815 handshake_data: &str,
816 ciphersuite_str: &str,
817 ) -> Result<(MultiDeviceTransport, String), RemoteClientError> {
818 let ciphersuite = match ciphersuite_str {
820 s if s.contains("Kyber768") => Ciphersuite::PQNNpsk2_Kyber768_XChaCha20Poly1305,
821 _ => Ciphersuite::ClassicalNNpsk2_25519_XChaCha20Poly1035,
822 };
823
824 let init_bytes = STANDARD
826 .decode(handshake_data)
827 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
828
829 let init_packet = ap_noise::HandshakePacket::decode(&init_bytes)
830 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
831
832 let mut handshake = if let Some(ref psk) = self.psk {
834 ResponderHandshake::with_psk(psk.clone())
835 } else {
836 ResponderHandshake::new()
837 };
838
839 handshake.receive_start(&init_packet)?;
841 let response_packet = handshake.send_finish()?;
842 let (transport, fingerprint) = handshake.finalize()?;
843
844 let msg = ProtocolMessage::HandshakeResponse {
846 data: STANDARD.encode(response_packet.encode()?),
847 ciphersuite: format!("{ciphersuite:?}"),
848 };
849
850 let msg_json = serde_json::to_string(&msg)?;
851
852 let proxy_client = self
853 .proxy_client
854 .as_ref()
855 .ok_or(RemoteClientError::NotInitialized)?;
856
857 proxy_client
858 .send_to(remote_fingerprint, msg_json.into_bytes())
859 .await?;
860
861 debug!("Sent handshake response to {:?}", remote_fingerprint);
862
863 Ok((transport, fingerprint.to_string()))
864 }
865
866 pub fn rendezvous_code(&self) -> Option<&RendezvousCode> {
868 self.rendezvous_code.as_ref()
869 }
870
871 pub fn set_pending_session_name(&mut self, name: String) {
873 self.pending_session_name = Some(name);
874 }
875}