1use std::collections::HashMap;
2use std::time::Duration;
3
4use ap_noise::{Ciphersuite, MultiDeviceTransport, Psk, ResponderHandshake};
5use ap_proxy_client::IncomingMessage;
6use ap_proxy_protocol::{IdentityFingerprint, RendevouzCode};
7use base64::{Engine, engine::general_purpose::STANDARD};
8
9use crate::proxy::ProxyClient;
10use crate::types::CredentialData;
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
/// Back-off before the first reconnection retry; each further failed attempt doubles it.
const RECONNECT_BASE_DELAY: Duration = Duration::from_millis(2_000);
/// Ceiling for the exponential back-off (15 minutes); random jitter is added on top of it.
const RECONNECT_MAX_DELAY: Duration = Duration::from_secs(900);
18
19struct PendingHandshakeVerification {
21 source: IdentityFingerprint,
23 transport: MultiDeviceTransport,
25}
26
27use crate::{
28 error::RemoteClientError,
29 traits::{
30 AuditConnectionType, AuditEvent, AuditLog, CredentialFieldSet, IdentityProvider,
31 NoOpAuditLog, SessionStore,
32 },
33 types::{CredentialRequestPayload, CredentialResponsePayload, ProtocolMessage},
34};
35
36#[derive(Debug, Clone)]
38pub enum UserClientEvent {
39 Listening {},
41 RendevouzCodeGenerated {
43 code: String,
45 },
46 PskTokenGenerated {
48 token: String,
50 },
51 HandshakeStart {},
53 HandshakeProgress {
55 message: String,
57 },
58 HandshakeComplete {},
60 HandshakeFingerprint {
62 fingerprint: String,
64 },
65 FingerprintVerified {},
67 FingerprintRejected {
69 reason: String,
71 },
72 CredentialRequest {
74 query: crate::types::CredentialQuery,
76 request_id: String,
78 session_id: String,
80 },
81 CredentialApproved {
83 domain: Option<String>,
85 credential_id: Option<String>,
87 },
88 CredentialDenied {
90 domain: Option<String>,
92 credential_id: Option<String>,
94 },
95 SessionRefreshed {
97 fingerprint: IdentityFingerprint,
99 },
100 ClientDisconnected {},
102 Reconnecting {
104 attempt: u32,
106 },
107 Reconnected {},
109 Error {
111 message: String,
113 context: Option<String>,
115 },
116}
117
118#[derive(Debug, Clone)]
120#[allow(clippy::large_enum_variant)]
121pub enum UserClientResponse {
122 VerifyFingerprint {
124 approved: bool,
126 name: Option<String>,
128 },
129 RespondCredential {
131 request_id: String,
133 session_id: String,
135 query: crate::types::CredentialQuery,
137 approved: bool,
139 credential: Option<CredentialData>,
141 credential_id: Option<String>,
143 },
144}
145
146pub struct UserClient {
148 identity_provider: Box<dyn IdentityProvider>,
149 session_store: Box<dyn SessionStore>,
150 proxy_client: Option<Box<dyn ProxyClient>>,
151 transports: HashMap<IdentityFingerprint, MultiDeviceTransport>,
153 rendezvous_code: Option<RendevouzCode>,
155 psk: Option<Psk>,
157 incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
159 pending_verification: Option<PendingHandshakeVerification>,
161 pending_session_name: Option<String>,
163 audit_log: Box<dyn AuditLog>,
165}
166
167impl UserClient {
168 pub async fn listen(
175 identity_provider: Box<dyn IdentityProvider>,
176 session_store: Box<dyn SessionStore>,
177 mut proxy_client: Box<dyn ProxyClient>,
178 ) -> Result<Self, RemoteClientError> {
179 let incoming_rx = proxy_client.connect().await?;
180
181 Ok(Self {
182 identity_provider,
183 session_store,
184 proxy_client: Some(proxy_client),
185 transports: HashMap::new(),
186 rendezvous_code: None,
187 psk: None,
188 incoming_rx: Some(incoming_rx),
189 pending_verification: None,
190 pending_session_name: None,
191 audit_log: Box::new(NoOpAuditLog),
192 })
193 }
194
195 pub fn with_audit_log(mut self, audit_log: Box<dyn AuditLog>) -> Self {
197 self.audit_log = audit_log;
198 self
199 }
200
201 pub async fn listen_cached_only(
206 &mut self,
207 event_tx: mpsc::Sender<UserClientEvent>,
208 response_rx: mpsc::Receiver<UserClientResponse>,
209 ) -> Result<(), RemoteClientError> {
210 debug!("User client listening for cached sessions only (no new pairing code)");
211
212 event_tx.send(UserClientEvent::Listening {}).await.ok();
214
215 self.run_event_loop(event_tx, response_rx).await
217 }
218
219 pub async fn enable_psk(
223 &mut self,
224 event_tx: mpsc::Sender<UserClientEvent>,
225 response_rx: mpsc::Receiver<UserClientResponse>,
226 ) -> Result<(), RemoteClientError> {
227 let psk = Psk::generate();
229 let fingerprint = self.identity_provider.fingerprint();
230 let token = format!("{}_{}", psk.to_hex(), hex::encode(fingerprint.0));
231
232 self.psk = Some(psk);
233
234 event_tx
235 .send(UserClientEvent::PskTokenGenerated { token })
236 .await
237 .ok();
238
239 debug!("User client listening in PSK mode");
240
241 event_tx.send(UserClientEvent::Listening {}).await.ok();
243
244 self.run_event_loop(event_tx, response_rx).await
246 }
247
248 pub async fn enable_rendezvous(
252 &mut self,
253 event_tx: mpsc::Sender<UserClientEvent>,
254 response_rx: mpsc::Receiver<UserClientResponse>,
255 ) -> Result<(), RemoteClientError> {
256 let proxy_client = self
257 .proxy_client
258 .as_ref()
259 .ok_or(RemoteClientError::NotInitialized)?;
260
261 proxy_client.request_rendezvous().await?;
263
264 let incoming_rx = self
266 .incoming_rx
267 .as_mut()
268 .ok_or(RemoteClientError::NotInitialized)?;
269
270 let code = loop {
271 if let Some(IncomingMessage::RendevouzInfo(c)) = incoming_rx.recv().await {
272 break c;
273 }
274 };
275
276 self.rendezvous_code = Some(code.clone());
277
278 event_tx
279 .send(UserClientEvent::RendevouzCodeGenerated {
280 code: code.as_str().to_string(),
281 })
282 .await
283 .ok();
284
285 debug!("User client listening with rendezvous code: {}", code);
286
287 event_tx.send(UserClientEvent::Listening {}).await.ok();
289
290 self.run_event_loop(event_tx, response_rx).await
292 }
293
294 async fn run_event_loop(
296 &mut self,
297 event_tx: mpsc::Sender<UserClientEvent>,
298 mut response_rx: mpsc::Receiver<UserClientResponse>,
299 ) -> Result<(), RemoteClientError> {
300 let mut incoming_rx = self
302 .incoming_rx
303 .take()
304 .ok_or(RemoteClientError::NotInitialized)?;
305
306 loop {
307 tokio::select! {
308 msg = incoming_rx.recv() => {
309 match msg {
310 Some(msg) => {
311 if let Err(e) = self.handle_incoming(msg, &event_tx).await {
312 warn!("Error handling incoming message: {}", e);
313 event_tx.send(UserClientEvent::Error {
314 message: e.to_string(),
315 context: Some("handle_incoming".to_string()),
316 }).await.ok();
317 }
318 }
319 None => {
320 event_tx.send(UserClientEvent::ClientDisconnected {}).await.ok();
322 match self.attempt_reconnection(&event_tx).await {
323 Ok(new_rx) => {
324 incoming_rx = new_rx;
325 event_tx.send(UserClientEvent::Reconnected {}).await.ok();
326 }
327 Err(e) => {
328 warn!("Reconnection failed permanently: {}", e);
329 return Err(e);
330 }
331 }
332 }
333 }
334 }
335 Some(response) = response_rx.recv() => {
336 if let Err(e) = self.handle_response(response, &event_tx).await {
337 warn!("Error handling response: {}", e);
338 event_tx.send(UserClientEvent::Error {
339 message: e.to_string(),
340 context: Some("handle_response".to_string()),
341 }).await.ok();
342 }
343 }
344 }
345 }
346 }
347
348 async fn attempt_reconnection(
353 &mut self,
354 event_tx: &mpsc::Sender<UserClientEvent>,
355 ) -> Result<mpsc::UnboundedReceiver<IncomingMessage>, RemoteClientError> {
356 use rand::Rng;
357
358 let mut rng = rand::thread_rng();
359 let mut attempt: u32 = 0;
360
361 loop {
362 attempt = attempt.saturating_add(1);
363
364 let proxy_client = self
365 .proxy_client
366 .as_mut()
367 .ok_or(RemoteClientError::NotInitialized)?;
368
369 let _ = proxy_client.disconnect().await;
371
372 match proxy_client.connect().await {
373 Ok(new_rx) => {
374 debug!("Reconnected to proxy on attempt {}", attempt);
375 return Ok(new_rx);
376 }
377 Err(e) => {
378 debug!("Reconnection attempt {} failed: {}", attempt, e);
379 event_tx
380 .send(UserClientEvent::Reconnecting { attempt })
381 .await
382 .ok();
383
384 let exp_delay = RECONNECT_BASE_DELAY
386 .saturating_mul(2u32.saturating_pow(attempt.saturating_sub(1)));
387 let delay = exp_delay.min(RECONNECT_MAX_DELAY);
388 let jitter_max = (delay.as_millis() as u64) / 4;
389 let jitter = if jitter_max > 0 {
390 rng.gen_range(0..=jitter_max)
391 } else {
392 0
393 };
394 let total_delay = delay + Duration::from_millis(jitter);
395
396 tokio::time::sleep(total_delay).await;
397 }
398 }
399 }
400 }
401
402 async fn handle_incoming(
404 &mut self,
405 msg: IncomingMessage,
406 event_tx: &mpsc::Sender<UserClientEvent>,
407 ) -> Result<(), RemoteClientError> {
408 match msg {
409 IncomingMessage::Send {
410 source, payload, ..
411 } => {
412 let text = String::from_utf8(payload)
414 .map_err(|e| RemoteClientError::Serialization(format!("Invalid UTF-8: {e}")))?;
415
416 let protocol_msg: ProtocolMessage = serde_json::from_str(&text)?;
417
418 match protocol_msg {
419 ProtocolMessage::HandshakeInit { data, ciphersuite } => {
420 self.handle_handshake_init(source, data, ciphersuite, event_tx)
421 .await?;
422 }
423 ProtocolMessage::CredentialRequest { encrypted } => {
424 self.handle_credential_request(source, encrypted, event_tx)
425 .await?;
426 }
427 _ => {
428 debug!("Received unexpected message type from {:?}", source);
429 }
430 }
431 }
432 IncomingMessage::RendevouzInfo(_) => {
433 }
435 IncomingMessage::IdentityInfo { .. } => {
436 debug!("Received unexpected IdentityInfo message");
438 }
439 }
440 Ok(())
441 }
442
443 async fn handle_handshake_init(
445 &mut self,
446 source: IdentityFingerprint,
447 data: String,
448 ciphersuite: String,
449 event_tx: &mpsc::Sender<UserClientEvent>,
450 ) -> Result<(), RemoteClientError> {
451 debug!("Received handshake init from source: {:?}", source);
452 event_tx.send(UserClientEvent::HandshakeStart {}).await.ok();
453
454 let (transport, fingerprint_str) =
455 self.complete_handshake(source, &data, &ciphersuite).await?;
456
457 event_tx
458 .send(UserClientEvent::HandshakeComplete {})
459 .await
460 .ok();
461
462 let is_new_connection = !self.session_store.has_session(&source);
464 let is_psk_connection = self.psk.is_some();
466
467 if is_new_connection && !is_psk_connection {
468 self.pending_verification = Some(PendingHandshakeVerification { source, transport });
470
471 event_tx
472 .send(UserClientEvent::HandshakeFingerprint {
473 fingerprint: fingerprint_str,
474 })
475 .await
476 .ok();
477 } else if !is_new_connection {
478 self.transports.insert(source, transport.clone());
482 self.session_store.cache_session(source)?;
483 if let Some(name) = self.pending_session_name.take() {
486 self.session_store.set_session_name(&source, name)?;
487 }
488 self.session_store
489 .save_transport_state(&source, transport)?;
490
491 self.audit_log
492 .write(AuditEvent::SessionRefreshed {
493 remote_identity: &source,
494 })
495 .await;
496
497 event_tx
498 .send(UserClientEvent::SessionRefreshed {
499 fingerprint: source,
500 })
501 .await
502 .ok();
503 } else if is_psk_connection {
504 let session_name = self.pending_session_name.take();
506 self.accept_new_connection(
507 source,
508 transport,
509 session_name.as_deref(),
510 AuditConnectionType::Psk,
511 )
512 .await?;
513
514 event_tx
515 .send(UserClientEvent::HandshakeFingerprint {
516 fingerprint: fingerprint_str,
517 })
518 .await
519 .ok();
520 }
521
522 Ok(())
523 }
524
525 async fn accept_new_connection(
527 &mut self,
528 fingerprint: IdentityFingerprint,
529 transport: MultiDeviceTransport,
530 session_name: Option<&str>,
531 connection_type: AuditConnectionType,
532 ) -> Result<(), RemoteClientError> {
533 self.transports.insert(fingerprint, transport.clone());
534 self.session_store.cache_session(fingerprint)?;
535 if let Some(name) = session_name {
536 self.session_store
537 .set_session_name(&fingerprint, name.to_owned())?;
538 }
539 self.session_store
540 .save_transport_state(&fingerprint, transport)?;
541
542 self.audit_log
543 .write(AuditEvent::ConnectionEstablished {
544 remote_identity: &fingerprint,
545 remote_name: session_name,
546 connection_type,
547 })
548 .await;
549
550 Ok(())
551 }
552
553 async fn handle_fingerprint_verification(
555 &mut self,
556 approved: bool,
557 name: Option<String>,
558 event_tx: &mpsc::Sender<UserClientEvent>,
559 ) -> Result<(), RemoteClientError> {
560 let pending = self
561 .pending_verification
562 .take()
563 .ok_or(RemoteClientError::InvalidState {
564 expected: "pending verification".to_string(),
565 current: "no pending verification".to_string(),
566 })?;
567
568 if approved {
569 let session_name = name.or(self.pending_session_name.take());
570 self.accept_new_connection(
571 pending.source,
572 pending.transport,
573 session_name.as_deref(),
574 AuditConnectionType::Rendezvous,
575 )
576 .await?;
577
578 event_tx
579 .send(UserClientEvent::FingerprintVerified {})
580 .await
581 .ok();
582 } else {
583 self.audit_log
584 .write(AuditEvent::ConnectionRejected {
585 remote_identity: &pending.source,
586 })
587 .await;
588
589 event_tx
590 .send(UserClientEvent::FingerprintRejected {
591 reason: "User rejected fingerprint verification".to_string(),
592 })
593 .await
594 .ok();
595 }
596
597 Ok(())
598 }
599
600 async fn handle_credential_request(
602 &mut self,
603 source: IdentityFingerprint,
604 encrypted: String,
605 event_tx: &mpsc::Sender<UserClientEvent>,
606 ) -> Result<(), RemoteClientError> {
607 if !self.transports.contains_key(&source) {
608 debug!("Loading transport state for source: {:?}", source);
609 let session = self
610 .session_store
611 .load_transport_state(&source)?
612 .expect("Transport state should exist for cached session");
613 self.transports.insert(source, session);
614 }
615
616 let transport = self
618 .transports
619 .get_mut(&source)
620 .ok_or(RemoteClientError::SecureChannelNotEstablished)?;
621
622 let encrypted_bytes = STANDARD
624 .decode(&encrypted)
625 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
626
627 let packet = ap_noise::TransportPacket::decode(&encrypted_bytes)
628 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
629
630 let decrypted = transport
631 .decrypt(&packet)
632 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
633
634 let request: CredentialRequestPayload = serde_json::from_slice(&decrypted)?;
635
636 self.audit_log
637 .write(AuditEvent::CredentialRequested {
638 query: &request.query,
639 remote_identity: &source,
640 request_id: &request.request_id,
641 })
642 .await;
643
644 event_tx
646 .send(UserClientEvent::CredentialRequest {
647 query: request.query,
648 request_id: request.request_id.clone(),
649 session_id: format!("{source:?}"),
650 })
651 .await
652 .ok();
653
654 Ok(())
655 }
656
657 async fn handle_response(
659 &mut self,
660 response: UserClientResponse,
661 event_tx: &mpsc::Sender<UserClientEvent>,
662 ) -> Result<(), RemoteClientError> {
663 match response {
664 UserClientResponse::VerifyFingerprint { approved, name } => {
665 self.handle_fingerprint_verification(approved, name, event_tx)
666 .await?;
667 }
668 UserClientResponse::RespondCredential {
669 request_id,
670 session_id,
671 query,
672 approved,
673 credential,
674 credential_id,
675 } => {
676 self.handle_credential_response(
677 request_id,
678 session_id,
679 query,
680 approved,
681 credential,
682 credential_id,
683 event_tx,
684 )
685 .await?;
686 }
687 }
688 Ok(())
689 }
690
691 #[allow(clippy::too_many_arguments)]
693 async fn handle_credential_response(
694 &mut self,
695 request_id: String,
696 session_id: String,
697 query: crate::types::CredentialQuery,
698 approved: bool,
699 credential: Option<CredentialData>,
700 credential_id: Option<String>,
701 event_tx: &mpsc::Sender<UserClientEvent>,
702 ) -> Result<(), RemoteClientError> {
703 let fingerprint = self
705 .transports
706 .keys()
707 .find(|fp| format!("{fp:?}") == session_id)
708 .copied()
709 .ok_or(RemoteClientError::NotInitialized)?;
710
711 let transport = self
712 .transports
713 .get_mut(&fingerprint)
714 .ok_or(RemoteClientError::SecureChannelNotEstablished)?;
715
716 let domain = credential.as_ref().and_then(|c| c.domain.clone());
718 let fields = credential
719 .as_ref()
720 .map_or_else(CredentialFieldSet::default, |c| CredentialFieldSet {
721 has_username: c.username.is_some(),
722 has_password: c.password.is_some(),
723 has_totp: c.totp.is_some(),
724 has_uri: c.uri.is_some(),
725 has_notes: c.notes.is_some(),
726 });
727
728 let response_payload = CredentialResponsePayload {
730 credential: if approved { credential } else { None },
731 error: if !approved {
732 Some("Request denied".to_string())
733 } else {
734 None
735 },
736 request_id: Some(request_id.clone()),
737 };
738
739 let response_json = serde_json::to_string(&response_payload)?;
741 let encrypted = transport
742 .encrypt(response_json.as_bytes())
743 .map_err(|e| RemoteClientError::NoiseProtocol(e.to_string()))?;
744
745 let msg = ProtocolMessage::CredentialResponse {
746 encrypted: STANDARD.encode(encrypted.encode()),
747 };
748
749 let msg_json = serde_json::to_string(&msg)?;
750
751 let proxy_client = self
752 .proxy_client
753 .as_ref()
754 .ok_or(RemoteClientError::NotInitialized)?;
755
756 proxy_client
757 .send_to(fingerprint, msg_json.into_bytes())
758 .await?;
759
760 if approved {
762 self.audit_log
763 .write(AuditEvent::CredentialApproved {
764 query: &query,
765 domain: domain.as_deref(),
766 remote_identity: &fingerprint,
767 request_id: &request_id,
768 credential_id: credential_id.as_deref(),
769 fields,
770 })
771 .await;
772
773 event_tx
774 .send(UserClientEvent::CredentialApproved {
775 domain,
776 credential_id,
777 })
778 .await
779 .ok();
780 } else {
781 self.audit_log
782 .write(AuditEvent::CredentialDenied {
783 query: &query,
784 domain: domain.as_deref(),
785 remote_identity: &fingerprint,
786 request_id: &request_id,
787 credential_id: credential_id.as_deref(),
788 })
789 .await;
790
791 event_tx
792 .send(UserClientEvent::CredentialDenied {
793 domain,
794 credential_id,
795 })
796 .await
797 .ok();
798 }
799
800 Ok(())
801 }
802
803 async fn complete_handshake(
805 &self,
806 remote_fingerprint: IdentityFingerprint,
807 handshake_data: &str,
808 ciphersuite_str: &str,
809 ) -> Result<(MultiDeviceTransport, String), RemoteClientError> {
810 let ciphersuite = match ciphersuite_str {
812 s if s.contains("Kyber768") => Ciphersuite::PQNNpsk2_Kyber768_XChaCha20Poly1305,
813 _ => Ciphersuite::ClassicalNNpsk2_25519_XChaCha20Poly1035,
814 };
815
816 let init_bytes = STANDARD
818 .decode(handshake_data)
819 .map_err(|e| RemoteClientError::Serialization(format!("Invalid base64: {e}")))?;
820
821 let init_packet = ap_noise::HandshakePacket::decode(&init_bytes)
822 .map_err(|e| RemoteClientError::NoiseProtocol(format!("Invalid packet: {e}")))?;
823
824 let mut handshake = if let Some(ref psk) = self.psk {
826 ResponderHandshake::with_psk(psk.clone())
827 } else {
828 ResponderHandshake::new()
829 };
830
831 handshake.receive_start(&init_packet)?;
833 let response_packet = handshake.send_finish()?;
834 let (transport, fingerprint) = handshake.finalize()?;
835
836 let msg = ProtocolMessage::HandshakeResponse {
838 data: STANDARD.encode(response_packet.encode()?),
839 ciphersuite: format!("{ciphersuite:?}"),
840 };
841
842 let msg_json = serde_json::to_string(&msg)?;
843
844 let proxy_client = self
845 .proxy_client
846 .as_ref()
847 .ok_or(RemoteClientError::NotInitialized)?;
848
849 proxy_client
850 .send_to(remote_fingerprint, msg_json.into_bytes())
851 .await?;
852
853 debug!("Sent handshake response to {:?}", remote_fingerprint);
854
855 Ok((transport, fingerprint.to_string()))
856 }
857
858 pub fn rendezvous_code(&self) -> Option<&RendevouzCode> {
860 self.rendezvous_code.as_ref()
861 }
862
863 pub fn set_pending_session_name(&mut self, name: String) {
865 self.pending_session_name = Some(name);
866 }
867}