1use std::collections::HashMap;
10
11use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
12use sha2::{Sha256, Digest};
13use serde::{Deserialize, Serialize};
14use rand::rngs::OsRng;
15use rand::RngCore;
16use num_bigint::BigUint;
17use num_traits::{One, Zero};
18
19use crate::{OpenADPError, Result};
20use crate::client::{
21 EncryptedOpenADPClient, ServerInfo, parse_server_public_key,
22 RegisterSecretRequest, RecoverSecretRequest, ListBackupsRequest,
23};
24use crate::crypto::{
25 H, point_compress, ShamirSecretSharing, point_decompress, unexpand, point_mul,
26 derive_enc_key, recover_point_secret, PointShare, mod_inverse
27};
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Identity {
32 pub uid: String, pub did: String, pub bid: String, }
36
37impl Identity {
38 pub fn new(uid: String, did: String, bid: String) -> Self {
39 Self { uid, did, bid }
40 }
41}
42
43impl std::fmt::Display for Identity {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(f, "UID={}, DID={}, BID={}", self.uid, self.did, self.bid)
46 }
47}
48
/// Authentication codes for one backup: a base code plus one derived code per server URL.
#[derive(Debug, Clone)]
pub struct AuthCodes {
    /// Root code. When produced by `generate_auth_codes`, this is 32 OS-random bytes
    /// hex-encoded to 64 characters.
    pub base_auth_code: String,
    /// Per-server codes keyed by the server URL.
    pub server_auth_codes: HashMap<String, String>,
}

impl AuthCodes {
    /// Returns the code to present to `server_url`.
    ///
    /// The server-specific code is used when one is stored. Otherwise the result falls back
    /// to `base_auth_code`, so this never returns `None`.
    ///
    /// NOTE(review): `generate_auth_codes` derives every per-server code as
    /// `hex(SHA-256(base ":" url))`. A server that is handed the fallback base code can
    /// therefore compute every other server's code. Confirm the fallback is intended.
    pub fn get_server_code(&self, server_url: &str) -> Option<&String> {
        let code = match self.server_auth_codes.get(server_url) {
            Some(server_code) => server_code,
            None => &self.base_auth_code,
        };
        Some(code)
    }
}
62
63#[derive(Debug)]
65pub struct GenerateEncryptionKeyResult {
66 pub encryption_key: Option<Vec<u8>>,
67 pub error: Option<String>,
68 pub server_infos: Option<Vec<ServerInfo>>,
69 pub threshold: Option<usize>,
70 pub auth_codes: Option<AuthCodes>,
71}
72
73impl GenerateEncryptionKeyResult {
74 pub fn success(
75 encryption_key: Vec<u8>,
76 server_infos: Vec<ServerInfo>,
77 threshold: usize,
78 auth_codes: AuthCodes,
79 ) -> Self {
80 Self {
81 encryption_key: Some(encryption_key),
82 error: None,
83 server_infos: Some(server_infos),
84 threshold: Some(threshold),
85 auth_codes: Some(auth_codes),
86 }
87 }
88
89 pub fn error(error: String) -> Self {
90 Self {
91 encryption_key: None,
92 error: Some(error),
93 server_infos: None,
94 threshold: None,
95 auth_codes: None,
96 }
97 }
98}
99
/// Outcome of `recover_encryption_key`: exactly one of `encryption_key` / `error` is `Some`
/// when built through the constructors below.
#[derive(Debug)]
pub struct RecoverEncryptionKeyResult {
    pub encryption_key: Option<Vec<u8>>,
    pub error: Option<String>,
}

impl RecoverEncryptionKeyResult {
    /// Successful recovery holding the reconstructed key bytes.
    pub fn success(encryption_key: Vec<u8>) -> Self {
        let encryption_key = Some(encryption_key);
        RecoverEncryptionKeyResult { encryption_key, error: None }
    }

    /// Failed recovery holding a human-readable message.
    pub fn error(error: String) -> Self {
        let error = Some(error);
        RecoverEncryptionKeyResult { encryption_key: None, error }
    }
}
122
123pub fn generate_auth_codes(server_urls: &[String]) -> AuthCodes {
125 let mut rng = OsRng;
126
127 let mut base_auth_bytes = [0u8; 32];
129 rng.fill_bytes(&mut base_auth_bytes);
130 let base_auth_code = hex::encode(base_auth_bytes);
131
132 let mut server_auth_codes = HashMap::new();
134 for server_url in server_urls {
135 let mut hasher = Sha256::new();
136 hasher.update(base_auth_code.as_bytes());
137 hasher.update(b":");
138 hasher.update(server_url.as_bytes());
139 let hash = hasher.finalize();
140 let server_code = hex::encode(&hash); server_auth_codes.insert(server_url.clone(), server_code);
142 }
143
144 AuthCodes {
145 base_auth_code,
146 server_auth_codes,
147 }
148}
149
150pub async fn generate_encryption_key(
152 identity: &Identity,
153 password: &str,
154 max_guesses: i32,
155 expiration: i64,
156 server_infos: Vec<ServerInfo>,
157) -> Result<GenerateEncryptionKeyResult> {
158 if server_infos.is_empty() {
159 return Ok(GenerateEncryptionKeyResult::error("No servers available".to_string()));
160 }
161
162 let server_urls: Vec<String> = server_infos.iter().map(|s| s.url.clone()).collect();
164 let auth_codes = generate_auth_codes(&server_urls);
165
166 let mut clients = Vec::new();
168 let mut live_server_infos = Vec::new();
169
170 for server_info in server_infos {
171 let public_key = if !server_info.public_key.is_empty() {
173 match parse_server_public_key(&server_info.public_key) {
174 Ok(key) => Some(key),
175 Err(_) => {
176 None
177 }
178 }
179 } else {
180 None
181 };
182
183 let client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key, 30);
184
185 match client.test_connection().await {
187 Ok(_) => {
188 clients.push(client);
189 live_server_infos.push(server_info);
190 }
191 Err(_) => {
192 }
193 }
194 }
195
196 if live_server_infos.is_empty() {
197 return Ok(GenerateEncryptionKeyResult::error("No live servers available".to_string()));
198 }
199
200 let pin = password.as_bytes().to_vec();
202
203 let mut random_bytes = [0u8; 32];
205 let mut rng = OsRng;
206 rng.fill_bytes(&mut random_bytes);
207
208 let secret_int = rug::Integer::from_digits(&random_bytes, rug::integer::Order::MsfBe);
210 let q = ShamirSecretSharing::get_q();
211 let secret = secret_int % &q;
212
213 let threshold = live_server_infos.len() / 2 + 1;
215 let num_shares = live_server_infos.len(); if num_shares < threshold {
218 return Ok(GenerateEncryptionKeyResult::error(format!(
219 "Need at least {} servers, only {} available", threshold, num_shares
220 )));
221 }
222
223 let shares = ShamirSecretSharing::split_secret(&secret, threshold, num_shares)?;
225
226 let u = H(identity.uid.as_bytes(), identity.did.as_bytes(), identity.bid.as_bytes(), &pin)?;
228 let _u_2d = unexpand(&u)?;
229
230 let mut registration_errors = Vec::new();
233 let mut successful_registrations = 0;
234
235 for (i, mut client) in clients.into_iter().enumerate() {
236 let (share_id, share_data) = &shares[i];
237 let server_url = &live_server_infos[i].url;
238 let server_auth_code = auth_codes.get_server_code(server_url)
239 .ok_or_else(|| OpenADPError::Authentication("Missing server auth code".to_string()))?;
240
241 let y_big_int = rug::Integer::from(share_data);
243
244 let mut y_bytes = vec![0u8; 32];
246 let y_digits = y_big_int.to_digits::<u8>(rug::integer::Order::LsfLe);
247 let copy_len = std::cmp::min(y_digits.len(), 32);
248 y_bytes[..copy_len].copy_from_slice(&y_digits[..copy_len]);
249
250 let y_string = BASE64.encode(&y_bytes);
252
253 let request = RegisterSecretRequest {
254 auth_code: server_auth_code.clone(),
255 uid: identity.uid.clone(),
256 did: identity.did.clone(),
257 bid: identity.bid.clone(),
258 version: 1,
259 x: *share_id as i32,
260 y: y_string,
261 max_guesses,
262 expiration,
263 encrypted: client.has_public_key(),
264 auth_data: None,
265 };
266
267 match client.register_secret_standardized(request).await {
268 Ok(response) => {
269 if response.success {
270 let _enc_status = if client.has_public_key() { "encrypted" } else { "unencrypted" };
271 successful_registrations += 1;
272 } else {
273 let error_msg = format!("Server {} ({}): Registration returned false: {}",
274 i + 1, server_url, response.message);
275 registration_errors.push(error_msg);
276 }
277 }
278 Err(err) => {
279 let error_msg = format!("Server {} ({}): {}", i + 1, server_url, err);
280 registration_errors.push(error_msg);
281 }
282 }
283 }
284
285 if successful_registrations == 0 {
286 return Ok(GenerateEncryptionKeyResult::error(format!(
287 "Failed to register any shares: {:?}", registration_errors
288 )));
289 }
290
291 let secret_biguint = {
294 let bytes = secret.to_digits::<u8>(rug::integer::Order::MsfBe);
295 BigUint::from_bytes_be(&bytes)
296 };
297
298 let secret_point = point_mul(&secret_biguint, &u);
300 let encryption_key = derive_enc_key(&secret_point)?;
301
302
303 Ok(GenerateEncryptionKeyResult::success(
304 encryption_key,
305 live_server_infos,
306 threshold,
307 auth_codes,
308 ))
309}
310
311pub async fn recover_encryption_key(
313 identity: &Identity,
314 password: &str,
315 server_infos: Vec<ServerInfo>,
316 threshold: usize,
317 auth_codes: AuthCodes,
318) -> Result<RecoverEncryptionKeyResult> {
319
320 let pin = password.as_bytes().to_vec();
322 let u = H(identity.uid.as_bytes(), identity.did.as_bytes(), identity.bid.as_bytes(), &pin)?;
323 let _u_2d = unexpand(&u)?;
324
325 let r_scalar = {
327 use rand::RngCore;
328 let mut rng = OsRng;
329
330 let mut r_bytes = [0u8; 32];
332 rng.fill_bytes(&mut r_bytes);
333 let mut r = BigUint::from_bytes_be(&r_bytes);
334
335 let q = crate::crypto::Q.clone();
337 r = r % &q;
338
339 if r.is_zero() {
341 r = BigUint::one();
342 }
343
344 r
345 };
346
347
348 let updated_server_infos = fetch_remaining_guesses_for_servers(identity, &server_infos).await;
350 let selected_server_infos = select_servers_by_remaining_guesses(&updated_server_infos, threshold);
351
352 if selected_server_infos.is_empty() {
353 return Ok(RecoverEncryptionKeyResult::error("No servers available".to_string()));
354 }
355
356 let registry_servers = match crate::client::get_servers("").await {
358 Ok(servers) => servers,
359 Err(_) => {
360 crate::client::get_fallback_server_info()
362 }
363 };
364
365 let mut public_key_map = std::collections::HashMap::new();
367 for registry_server in ®istry_servers {
368 if !registry_server.public_key.is_empty() {
369 public_key_map.insert(registry_server.url.clone(), registry_server.public_key.clone());
370 }
371 }
372
373 let mut selected_server_infos_with_keys = Vec::new();
375 for mut server_info in selected_server_infos {
376 if let Some(public_key_str) = public_key_map.get(&server_info.url) {
377 server_info.public_key = public_key_str.clone();
378 }
379 selected_server_infos_with_keys.push(server_info);
380 }
381 let selected_server_infos = selected_server_infos_with_keys;
382
383 let b = point_mul(&r_scalar, &u);
385 let _b_2d = unexpand(&b)?;
386
387 let b_compressed = point_compress(&b)?;
389 let b_base64 = BASE64.encode(&b_compressed);
390
391 let mut recovered_point_shares = Vec::new();
393
394 for server_info in &selected_server_infos {
395 let public_key = if !server_info.public_key.is_empty() {
397 match parse_server_public_key(&server_info.public_key) {
398 Ok(key) => Some(key),
399 Err(_) => {
400 None
401 }
402 }
403 } else {
404 None
405 };
406
407 let mut client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key, 30);
408
409 let server_auth_code = auth_codes.get_server_code(&server_info.url)
411 .ok_or_else(|| OpenADPError::Authentication("No auth code for server".to_string()))?;
412
413 let mut guess_num = 0; let list_request = ListBackupsRequest {
416 uid: identity.uid.clone(),
417 auth_code: String::new(),
418 encrypted: client.has_public_key(), auth_data: None,
420 };
421
422 match client.list_backups_standardized(list_request).await {
423 Ok(response) => {
424 for backup in &response.backups {
426 if backup.uid == identity.uid &&
427 backup.did == identity.did &&
428 backup.bid == identity.bid {
429 guess_num = backup.num_guesses;
430 break;
431 }
432 }
433 }
434 Err(err) => {
435 return Err(OpenADPError::Server(format!("Cannot get current guess number for idempotency: {}", err)));
436 }
437 }
438
439 let public_key_fresh = if !server_info.public_key.is_empty() {
441 match parse_server_public_key(&server_info.public_key) {
442 Ok(key) => Some(key),
443 Err(_) => None,
444 }
445 } else {
446 None
447 };
448 let mut fresh_client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key_fresh, 30);
449
450 let request = RecoverSecretRequest {
451 auth_code: server_auth_code.clone(),
452 uid: identity.uid.clone(),
453 did: identity.did.clone(),
454 bid: identity.bid.clone(),
455 guess_num: guess_num, b: b_base64.clone(),
457 encrypted: fresh_client.has_public_key(),
458 auth_data: None,
459 };
460
461
462
463 match fresh_client.recover_secret_standardized(request).await {
464 Ok(response) => {
465 if response.success {
466
467 if let Some(si_b) = response.si_b {
469
470 match BASE64.decode(&si_b) {
472 Ok(si_b_bytes) => {
473
474 let si_b_point = point_decompress(&si_b_bytes)?;
476 let _si_b_2d = unexpand(&si_b_point)?;
477
478 recovered_point_shares.push(PointShare::new(response.x as usize, si_b_point));
480 }
481 Err(e) => {
482 return Err(OpenADPError::Crypto(format!("Failed to decompress point: {}", e)));
483 }
484 }
485 } else {
486 return Err(OpenADPError::Server("Server returned success but no si_b".to_string()));
487 }
488 } else {
489 return Err(OpenADPError::Server(format!("Server error: {}", response.message)));
490 }
491 }
492 Err(err) => {
493 return Err(OpenADPError::Server(format!("Cannot recover secret: {}", err)));
494 }
495 }
496 }
497
498 if recovered_point_shares.len() < threshold {
500 return Ok(RecoverEncryptionKeyResult::error(format!(
501 "Not enough shares recovered: got {}, need {}",
502 recovered_point_shares.len(), threshold
503 )));
504 }
505
506
507 let recovered_sb_4d = recover_point_secret(recovered_point_shares)?;
509 let _recovered_sb_2d = unexpand(&recovered_sb_4d)?;
510
511 let r_inv = mod_inverse(&r_scalar, &crate::crypto::Q.clone());
513
514 let original_su = point_mul(&r_inv, &recovered_sb_4d);
516 let _original_su_2d = unexpand(&original_su)?;
517
518 let encryption_key = derive_enc_key(&original_su)?;
520
521
522 Ok(RecoverEncryptionKeyResult::success(encryption_key))
523}
524
525pub async fn fetch_remaining_guesses_for_servers(
527 identity: &Identity,
528 server_infos: &[ServerInfo],
529) -> Vec<ServerInfo> {
530 let mut updated_infos = Vec::new();
531
532 for server_info in server_infos {
533 let mut updated_info = server_info.clone();
534
535 let public_key = if !server_info.public_key.is_empty() {
537 match parse_server_public_key(&server_info.public_key) {
538 Ok(key) => Some(key),
539 Err(_) => {
540 None
541 }
542 }
543 } else {
544 None
545 };
546
547 let mut client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key, 30);
548
549 let request = ListBackupsRequest {
550 uid: identity.uid.clone(),
551 auth_code: String::new(),
552 encrypted: client.has_public_key(), auth_data: None,
554 };
555
556 match client.list_backups_standardized(request).await {
557 Ok(response) => {
558 for backup in &response.backups {
560 if backup.uid == identity.uid &&
561 backup.did == identity.did &&
562 backup.bid == identity.bid {
563 updated_info.remaining_guesses = Some(backup.max_guesses - backup.num_guesses);
564 break;
565 }
566 }
567
568 if updated_info.remaining_guesses.is_none() {
569 updated_info.remaining_guesses = Some(0);
570 }
571 }
572 Err(_) => {
573 updated_info.remaining_guesses = Some(0);
574 }
575 }
576
577 updated_infos.push(updated_info);
578 }
579
580 updated_infos
581}
582
583pub fn select_servers_by_remaining_guesses(
585 server_infos: &[ServerInfo],
586 threshold: usize,
587) -> Vec<ServerInfo> {
588 let mut available_servers: Vec<ServerInfo> = server_infos.iter()
590 .filter(|info| info.remaining_guesses.unwrap_or(-1) != 0)
591 .cloned()
592 .collect();
593
594 if available_servers.is_empty() {
595 return server_infos.to_vec(); }
597
598 available_servers.sort_by(|a, b| {
601 let a_guesses = if a.remaining_guesses.unwrap_or(-1) == -1 {
602 i32::MAX
603 } else {
604 a.remaining_guesses.unwrap_or(0)
605 };
606 let b_guesses = if b.remaining_guesses.unwrap_or(-1) == -1 {
607 i32::MAX
608 } else {
609 b.remaining_guesses.unwrap_or(0)
610 };
611 b_guesses.cmp(&a_guesses)
612 });
613
614 let num_to_select = std::cmp::min(available_servers.len(), threshold + 2);
616 let selected_servers = available_servers.into_iter().take(num_to_select).collect::<Vec<_>>();
617
618 for (_, server) in selected_servers.iter().enumerate() {
619 let _guesses_str = if server.remaining_guesses.unwrap_or(-1) == -1 {
620 "unknown".to_string()
621 } else {
622 server.remaining_guesses.unwrap_or(0).to_string()
623 };
624 }
625
626 selected_servers
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identity() {
        let identity = Identity::new(
            "user123".to_string(),
            "device456".to_string(),
            "backup789".to_string(),
        );

        assert_eq!(identity.uid, "user123");
        assert_eq!(identity.did, "device456");
        assert_eq!(identity.bid, "backup789");

        let display = format!("{}", identity);
        assert!(display.contains("user123"));
        assert!(display.contains("device456"));
        assert!(display.contains("backup789"));
    }

    #[test]
    fn test_generate_auth_codes() {
        let servers = vec![
            "https://server1.example.com".to_string(),
            "https://server2.example.com".to_string(),
        ];

        let auth_codes = generate_auth_codes(&servers);

        assert!(!auth_codes.base_auth_code.is_empty());
        assert_eq!(auth_codes.server_auth_codes.len(), 2);

        let code1 = auth_codes.get_server_code("https://server1.example.com");
        let code2 = auth_codes.get_server_code("https://server2.example.com");
        let code_unknown = auth_codes.get_server_code("https://unknown.example.com");

        assert!(code1.is_some());
        assert!(code2.is_some());
        assert_eq!(code_unknown, Some(&auth_codes.base_auth_code));
        assert_ne!(code1, code2);
    }

    #[test]
    fn test_generate_auth_codes_empty_server_list() {
        let auth_codes = generate_auth_codes(&[]);
        assert_eq!(auth_codes.base_auth_code.len(), 64);
        assert!(auth_codes.server_auth_codes.is_empty());
    }

    #[test]
    fn test_server_codes_are_derived_from_base_code() {
        let servers = vec!["https://server1.com".to_string()];
        let auth_codes = generate_auth_codes(&servers);

        let mut hasher = Sha256::new();
        hasher.update(auth_codes.base_auth_code.as_bytes());
        hasher.update(b":");
        hasher.update(servers[0].as_bytes());
        let expected = hex::encode(hasher.finalize());

        assert_eq!(auth_codes.get_server_code(&servers[0]), Some(&expected));
    }

    #[test]
    fn test_auth_codes_structure() {
        let servers = vec!["https://test.com".to_string()];
        let auth_codes = generate_auth_codes(&servers);

        assert_eq!(auth_codes.base_auth_code.len(), 64);
        assert!(auth_codes.base_auth_code.chars().all(|c| c.is_ascii_hexdigit()));

        for code in auth_codes.server_auth_codes.values() {
            assert_eq!(code.len(), 64);
            assert!(code.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_result_structures() {
        let key = vec![1, 2, 3, 4];
        let servers = vec![];
        let auth_codes = AuthCodes {
            base_auth_code: "test".to_string(),
            server_auth_codes: HashMap::new(),
        };

        let success = GenerateEncryptionKeyResult::success(key.clone(), servers, 2, auth_codes);
        assert!(success.encryption_key.is_some());
        assert!(success.error.is_none());
        assert_eq!(success.encryption_key.unwrap(), key);

        let error = GenerateEncryptionKeyResult::error("test error".to_string());
        assert!(error.encryption_key.is_none());
        assert!(error.error.is_some());
        assert_eq!(error.error.unwrap(), "test error");
    }

    #[test]
    fn test_recover_result_structures() {
        let success = RecoverEncryptionKeyResult::success(vec![9, 8, 7]);
        assert_eq!(success.encryption_key, Some(vec![9, 8, 7]));
        assert!(success.error.is_none());

        let error = RecoverEncryptionKeyResult::error("boom".to_string());
        assert!(error.encryption_key.is_none());
        assert_eq!(error.error.as_deref(), Some("boom"));
    }

    #[test]
    fn test_identity_validation() {
        let identity = Identity::new("user".to_string(), "device".to_string(), "backup".to_string());
        assert_eq!(identity.uid, "user");

        let special_identity = Identity::new(
            "user@domain.com".to_string(),
            "device-123".to_string(),
            "file://path/to/backup".to_string(),
        );
        assert!(special_identity.uid.contains("@"));
        assert!(special_identity.did.contains("-"));
        assert!(special_identity.bid.contains("://"));

        let display = format!("{}", special_identity);
        assert!(display.contains("UID="));
        assert!(display.contains("DID="));
        assert!(display.contains("BID="));
    }

    #[test]
    fn test_auth_codes_comprehensive() {
        let servers = vec![
            "https://server1.com".to_string(),
            "https://server2.com".to_string(),
            "https://server3.com".to_string(),
        ];

        let auth_codes1 = generate_auth_codes(&servers);
        let auth_codes2 = generate_auth_codes(&servers);

        // Independent generations draw independent random base codes.
        assert_ne!(auth_codes1.base_auth_code, auth_codes2.base_auth_code);
        assert_eq!(auth_codes1.server_auth_codes.len(), auth_codes2.server_auth_codes.len());

        for server in &servers {
            let code1 = auth_codes1.get_server_code(server);
            let code1_again = auth_codes1.get_server_code(server);
            let code2 = auth_codes2.get_server_code(server);

            // Lookups are deterministic for one AuthCodes value...
            assert_eq!(code1, code1_again);
            assert!(code1.is_some());
            // ...per-server codes differ from the base code...
            assert_ne!(code1, Some(&auth_codes1.base_auth_code));
            // ...and differ between independent generations.
            assert_ne!(code1, code2);
        }

        let unknown_code = auth_codes1.get_server_code("https://unknown.com");
        assert_eq!(unknown_code, Some(&auth_codes1.base_auth_code));
    }

    #[test]
    fn test_encryption_key_derivation() {
        let identity = Identity::new(
            "test-user".to_string(),
            "test-device".to_string(),
            "test-backup".to_string(),
        );

        let password = "test-password";
        let pin = password.as_bytes().to_vec();

        let result = H(
            identity.uid.as_bytes(),
            identity.did.as_bytes(),
            identity.bid.as_bytes(),
            &pin,
        );

        assert!(result.is_ok(), "H function should succeed with valid inputs");

        let point = result.unwrap();

        let key_result = derive_enc_key(&point);
        assert!(key_result.is_ok(), "Key derivation should succeed");

        let key = key_result.unwrap();
        assert_eq!(key.len(), 32, "Encryption key should be 32 bytes");

        let point2 = H(
            identity.uid.as_bytes(),
            identity.did.as_bytes(),
            identity.bid.as_bytes(),
            &pin,
        )
        .unwrap();
        let key2 = derive_enc_key(&point2).unwrap();
        assert_eq!(key, key2, "Same inputs should produce same key");
    }

    #[test]
    fn test_input_validation_edge_cases() {
        let empty_identity = Identity::new("".to_string(), "".to_string(), "".to_string());
        assert_eq!(empty_identity.uid, "");

        let long_string = "x".repeat(1000);
        let long_identity =
            Identity::new(long_string.clone(), long_string.clone(), long_string.clone());
        assert_eq!(long_identity.uid.len(), 1000);

        let display = format!("{}", empty_identity);
        assert!(display.contains("UID="));
        assert!(display.contains("DID="));
        assert!(display.contains("BID="));
    }
}