1use std::collections::HashMap;
10
11use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
12use sha2::{Sha256, Digest};
13use serde::{Deserialize, Serialize};
14use rand::rngs::OsRng;
15use rand::RngCore;
16use num_bigint::BigUint;
17use num_traits::{One, Zero};
18
19use crate::{OpenADPError, Result};
20use crate::client::{
21 EncryptedOpenADPClient, ServerInfo, parse_server_public_key,
22 RegisterSecretRequest, RecoverSecretRequest, ListBackupsRequest,
23};
24use crate::crypto::{
25 H, point_compress, ShamirSecretSharing, point_decompress, unexpand, point_mul,
26 derive_enc_key, recover_point_secret, PointShare, mod_inverse
27};
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Identity {
32 pub uid: String, pub did: String, pub bid: String, }
36
37impl Identity {
38 pub fn new(uid: String, did: String, bid: String) -> Self {
39 Self { uid, did, bid }
40 }
41}
42
43impl std::fmt::Display for Identity {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(f, "UID={}, DID={}, BID={}", self.uid, self.did, self.bid)
46 }
47}
48
/// Authentication codes used when talking to OpenADP servers.
///
/// When produced by `generate_auth_codes`, `base_auth_code` is a random 64-character
/// hex string and each per-server code is derived from it.
#[derive(Debug, Clone)]
pub struct AuthCodes {
    /// Random base code from which per-server codes are derived.
    pub base_auth_code: String,
    /// Map from server URL to the auth code derived for that server.
    pub server_auth_codes: HashMap<String, String>,
}
55
56impl AuthCodes {
57 pub fn get_server_code(&self, server_url: &str) -> Option<&String> {
58 self.server_auth_codes.get(server_url)
59 .or_else(|| Some(&self.base_auth_code))
60 }
61}
62
63#[derive(Debug)]
65pub struct GenerateEncryptionKeyResult {
66 pub encryption_key: Option<Vec<u8>>,
67 pub error: Option<String>,
68 pub server_infos: Option<Vec<ServerInfo>>,
69 pub threshold: Option<usize>,
70 pub auth_codes: Option<AuthCodes>,
71}
72
73impl GenerateEncryptionKeyResult {
74 pub fn success(
75 encryption_key: Vec<u8>,
76 server_infos: Vec<ServerInfo>,
77 threshold: usize,
78 auth_codes: AuthCodes,
79 ) -> Self {
80 Self {
81 encryption_key: Some(encryption_key),
82 error: None,
83 server_infos: Some(server_infos),
84 threshold: Some(threshold),
85 auth_codes: Some(auth_codes),
86 }
87 }
88
89 pub fn error(error: String) -> Self {
90 Self {
91 encryption_key: None,
92 error: Some(error),
93 server_infos: None,
94 threshold: None,
95 auth_codes: None,
96 }
97 }
98}
99
/// Outcome of `recover_encryption_key`.
#[derive(Debug)]
pub struct RecoverEncryptionKeyResult {
    /// The recovered encryption key; `None` on failure.
    pub encryption_key: Option<Vec<u8>>,
    /// Human-readable failure description; `None` on success.
    pub error: Option<String>,
    /// Guess counter reported by a server during recovery (0 on failure).
    pub num_guesses: i32,
    /// Maximum guesses reported by a server during recovery (0 on failure).
    pub max_guesses: i32,
}
108
109impl RecoverEncryptionKeyResult {
110 pub fn success(encryption_key: Vec<u8>, num_guesses: i32, max_guesses: i32) -> Self {
111 Self {
112 encryption_key: Some(encryption_key),
113 error: None,
114 num_guesses,
115 max_guesses,
116 }
117 }
118
119 pub fn error(error: String) -> Self {
120 Self {
121 encryption_key: None,
122 error: Some(error),
123 num_guesses: 0,
124 max_guesses: 0,
125 }
126 }
127}
128
129pub fn generate_auth_codes(server_urls: &[String]) -> AuthCodes {
131 let mut rng = OsRng;
132
133 let mut base_auth_bytes = [0u8; 32];
135 rng.fill_bytes(&mut base_auth_bytes);
136 let base_auth_code = hex::encode(base_auth_bytes);
137
138 let mut server_auth_codes = HashMap::new();
140 for server_url in server_urls {
141 let mut hasher = Sha256::new();
142 hasher.update(base_auth_code.as_bytes());
143 hasher.update(b":");
144 hasher.update(server_url.as_bytes());
145 let hash = hasher.finalize();
146 let server_code = hex::encode(&hash); server_auth_codes.insert(server_url.clone(), server_code);
148 }
149
150 AuthCodes {
151 base_auth_code,
152 server_auth_codes,
153 }
154}
155
156pub async fn generate_encryption_key(
158 identity: &Identity,
159 password: &str,
160 max_guesses: i32,
161 expiration: i64,
162 server_infos: Vec<ServerInfo>,
163) -> Result<GenerateEncryptionKeyResult> {
164 if server_infos.is_empty() {
165 return Ok(GenerateEncryptionKeyResult::error("No servers available".to_string()));
166 }
167
168 let server_urls: Vec<String> = server_infos.iter().map(|s| s.url.clone()).collect();
170 let auth_codes = generate_auth_codes(&server_urls);
171
172 let mut clients = Vec::new();
174 let mut live_server_infos = Vec::new();
175
176 for server_info in server_infos {
177 let public_key = if !server_info.public_key.is_empty() {
179 match parse_server_public_key(&server_info.public_key) {
180 Ok(key) => Some(key),
181 Err(_) => {
182 None
183 }
184 }
185 } else {
186 None
187 };
188
189 let client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key, 30);
190
191 match client.test_connection().await {
193 Ok(_) => {
194 clients.push(client);
195 live_server_infos.push(server_info);
196 }
197 Err(_) => {
198 }
199 }
200 }
201
202 if live_server_infos.is_empty() {
203 return Ok(GenerateEncryptionKeyResult::error("No live servers available".to_string()));
204 }
205
206 let pin = password.as_bytes().to_vec();
208
209 let mut random_bytes = [0u8; 32];
211 let mut rng = OsRng;
212 rng.fill_bytes(&mut random_bytes);
213
214 let secret_int = rug::Integer::from_digits(&random_bytes, rug::integer::Order::MsfBe);
216 let q = ShamirSecretSharing::get_q();
217 let secret = secret_int % &q;
218
219 let threshold = live_server_infos.len() / 2 + 1;
221 let num_shares = live_server_infos.len(); if num_shares < threshold {
224 return Ok(GenerateEncryptionKeyResult::error(format!(
225 "Need at least {} servers, only {} available", threshold, num_shares
226 )));
227 }
228
229 let shares = ShamirSecretSharing::split_secret(&secret, threshold, num_shares)?;
231
232 let u = H(identity.uid.as_bytes(), identity.did.as_bytes(), identity.bid.as_bytes(), &pin)?;
234 let _u_2d = unexpand(&u)?;
235
236 let mut registration_errors = Vec::new();
239 let mut successful_registrations = 0;
240
241 for (i, mut client) in clients.into_iter().enumerate() {
242 let (share_id, share_data) = &shares[i];
243 let server_url = &live_server_infos[i].url;
244 let server_auth_code = auth_codes.get_server_code(server_url)
245 .ok_or_else(|| OpenADPError::Authentication("Missing server auth code".to_string()))?;
246
247 let y_big_int = rug::Integer::from(share_data);
249
250 let mut y_bytes = vec![0u8; 32];
252 let y_digits = y_big_int.to_digits::<u8>(rug::integer::Order::LsfLe);
253 let copy_len = std::cmp::min(y_digits.len(), 32);
254 y_bytes[..copy_len].copy_from_slice(&y_digits[..copy_len]);
255
256 let y_string = BASE64.encode(&y_bytes);
258
259 let request = RegisterSecretRequest {
260 auth_code: server_auth_code.clone(),
261 uid: identity.uid.clone(),
262 did: identity.did.clone(),
263 bid: identity.bid.clone(),
264 version: 1,
265 x: *share_id as i32,
266 y: y_string,
267 max_guesses,
268 expiration,
269 encrypted: client.has_public_key(),
270 auth_data: None,
271 };
272
273 match client.register_secret_standardized(request).await {
274 Ok(response) => {
275 if response.success {
276 let _enc_status = if client.has_public_key() { "encrypted" } else { "unencrypted" };
277 successful_registrations += 1;
278 } else {
279 let error_msg = format!("Server {} ({}): Registration returned false: {}",
280 i + 1, server_url, response.message);
281 registration_errors.push(error_msg);
282 }
283 }
284 Err(err) => {
285 let error_msg = format!("Server {} ({}): {}", i + 1, server_url, err);
286 registration_errors.push(error_msg);
287 }
288 }
289 }
290
291 if successful_registrations == 0 {
292 return Ok(GenerateEncryptionKeyResult::error(format!(
293 "Failed to register any shares: {:?}", registration_errors
294 )));
295 }
296
297 let secret_biguint = {
300 let bytes = secret.to_digits::<u8>(rug::integer::Order::MsfBe);
301 BigUint::from_bytes_be(&bytes)
302 };
303
304 let secret_point = point_mul(&secret_biguint, &u);
306 let encryption_key = derive_enc_key(&secret_point)?;
307
308
309 Ok(GenerateEncryptionKeyResult::success(
310 encryption_key,
311 live_server_infos,
312 threshold,
313 auth_codes,
314 ))
315}
316
317pub async fn recover_encryption_key(
319 identity: &Identity,
320 password: &str,
321 server_infos: Vec<ServerInfo>,
322 threshold: usize,
323 auth_codes: AuthCodes,
324) -> Result<RecoverEncryptionKeyResult> {
325
326 let pin = password.as_bytes().to_vec();
328 let u = H(identity.uid.as_bytes(), identity.did.as_bytes(), identity.bid.as_bytes(), &pin)?;
329 let _u_2d = unexpand(&u)?;
330
331 let r_scalar = {
333 use rand::RngCore;
334 let mut rng = OsRng;
335
336 let mut r_bytes = [0u8; 32];
338 rng.fill_bytes(&mut r_bytes);
339 let mut r = BigUint::from_bytes_be(&r_bytes);
340
341 let q = crate::crypto::Q.clone();
343 r = r % &q;
344
345 if r.is_zero() {
347 r = BigUint::one();
348 }
349
350 r
351 };
352
353
354 let updated_server_infos = fetch_remaining_guesses_for_servers(identity, &server_infos).await;
356 let selected_server_infos = select_servers_by_remaining_guesses(&updated_server_infos, threshold);
357
358 if selected_server_infos.is_empty() {
359 return Ok(RecoverEncryptionKeyResult::error("No servers available".to_string()));
360 }
361
362 let registry_servers = match crate::client::get_servers("").await {
364 Ok(servers) => servers,
365 Err(_) => {
366 crate::client::get_fallback_server_info()
368 }
369 };
370
371 let mut public_key_map = std::collections::HashMap::new();
373 for registry_server in ®istry_servers {
374 if !registry_server.public_key.is_empty() {
375 public_key_map.insert(registry_server.url.clone(), registry_server.public_key.clone());
376 }
377 }
378
379 let mut selected_server_infos_with_keys = Vec::new();
381 for mut server_info in selected_server_infos {
382 if let Some(public_key_str) = public_key_map.get(&server_info.url) {
383 server_info.public_key = public_key_str.clone();
384 }
385 selected_server_infos_with_keys.push(server_info);
386 }
387 let selected_server_infos = selected_server_infos_with_keys;
388
389 let b = point_mul(&r_scalar, &u);
391 let _b_2d = unexpand(&b)?;
392
393 let b_compressed = point_compress(&b)?;
395 let b_base64 = BASE64.encode(&b_compressed);
396
397 let mut recovered_point_shares = Vec::new();
399 let mut actual_num_guesses = 0i32;
400 let mut actual_max_guesses = 0i32;
401
402 for server_info in &selected_server_infos {
403 let public_key = if !server_info.public_key.is_empty() {
405 match parse_server_public_key(&server_info.public_key) {
406 Ok(key) => Some(key),
407 Err(_) => {
408 None
409 }
410 }
411 } else {
412 None
413 };
414
415 let mut client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key, 30);
416
417 let server_auth_code = auth_codes.get_server_code(&server_info.url)
419 .ok_or_else(|| OpenADPError::Authentication("No auth code for server".to_string()))?;
420
421 let mut guess_num = 0; let list_request = ListBackupsRequest {
424 uid: identity.uid.clone(),
425 auth_code: String::new(),
426 encrypted: client.has_public_key(), auth_data: None,
428 };
429
430 match client.list_backups_standardized(list_request).await {
431 Ok(response) => {
432 for backup in &response.backups {
434 if backup.uid == identity.uid &&
435 backup.did == identity.did &&
436 backup.bid == identity.bid {
437 guess_num = backup.num_guesses;
438 break;
439 }
440 }
441 }
442 Err(err) => {
443 return Err(OpenADPError::Server(format!("Cannot get current guess number for idempotency: {}", err)));
444 }
445 }
446
447 let public_key_fresh = if !server_info.public_key.is_empty() {
449 match parse_server_public_key(&server_info.public_key) {
450 Ok(key) => Some(key),
451 Err(_) => None,
452 }
453 } else {
454 None
455 };
456 let mut fresh_client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key_fresh, 30);
457
458 let request = RecoverSecretRequest {
459 auth_code: server_auth_code.clone(),
460 uid: identity.uid.clone(),
461 did: identity.did.clone(),
462 bid: identity.bid.clone(),
463 guess_num: guess_num, b: b_base64.clone(),
465 encrypted: fresh_client.has_public_key(),
466 auth_data: None,
467 };
468
469
470
471 match fresh_client.recover_secret_standardized(request).await {
472 Ok(response) => {
473 if response.success {
474
475 if actual_num_guesses == 0 && actual_max_guesses == 0 {
477 actual_num_guesses = response.num_guesses;
478 actual_max_guesses = response.max_guesses;
479 }
480
481 if let Some(si_b) = response.si_b {
483
484 match BASE64.decode(&si_b) {
486 Ok(si_b_bytes) => {
487
488 let si_b_point = point_decompress(&si_b_bytes)?;
490 let _si_b_2d = unexpand(&si_b_point)?;
491
492 recovered_point_shares.push(PointShare::new(response.x as usize, si_b_point));
494 }
495 Err(e) => {
496 return Err(OpenADPError::Crypto(format!("Failed to decompress point: {}", e)));
497 }
498 }
499 } else {
500 return Err(OpenADPError::Server("Server returned success but no si_b".to_string()));
501 }
502 } else {
503 return Err(OpenADPError::Server(format!("Server error: {}", response.message)));
504 }
505 }
506 Err(err) => {
507 return Err(OpenADPError::Server(format!("Cannot recover secret: {}", err)));
508 }
509 }
510 }
511
512 if recovered_point_shares.len() < threshold {
514 return Ok(RecoverEncryptionKeyResult::error(format!(
515 "Not enough shares recovered: got {}, need {}",
516 recovered_point_shares.len(), threshold
517 )));
518 }
519
520
521 let recovered_sb_4d = recover_point_secret(recovered_point_shares)?;
523 let _recovered_sb_2d = unexpand(&recovered_sb_4d)?;
524
525 let r_inv = mod_inverse(&r_scalar, &crate::crypto::Q.clone());
527
528 let original_su = point_mul(&r_inv, &recovered_sb_4d);
530 let _original_su_2d = unexpand(&original_su)?;
531
532 let encryption_key = derive_enc_key(&original_su)?;
534
535
536 Ok(RecoverEncryptionKeyResult::success(encryption_key, actual_num_guesses, actual_max_guesses))
537}
538
539pub async fn fetch_remaining_guesses_for_servers(
541 identity: &Identity,
542 server_infos: &[ServerInfo],
543) -> Vec<ServerInfo> {
544 let mut updated_infos = Vec::new();
545
546 for server_info in server_infos {
547 let mut updated_info = server_info.clone();
548
549 let public_key = if !server_info.public_key.is_empty() {
551 match parse_server_public_key(&server_info.public_key) {
552 Ok(key) => Some(key),
553 Err(_) => {
554 None
555 }
556 }
557 } else {
558 None
559 };
560
561 let mut client = EncryptedOpenADPClient::new(server_info.url.clone(), public_key, 30);
562
563 let request = ListBackupsRequest {
564 uid: identity.uid.clone(),
565 auth_code: String::new(),
566 encrypted: client.has_public_key(), auth_data: None,
568 };
569
570 match client.list_backups_standardized(request).await {
571 Ok(response) => {
572 for backup in &response.backups {
574 if backup.uid == identity.uid &&
575 backup.did == identity.did &&
576 backup.bid == identity.bid {
577 updated_info.remaining_guesses = Some(backup.max_guesses - backup.num_guesses);
578 break;
579 }
580 }
581
582 if updated_info.remaining_guesses.is_none() {
583 updated_info.remaining_guesses = Some(0);
584 }
585 }
586 Err(_) => {
587 updated_info.remaining_guesses = Some(0);
588 }
589 }
590
591 updated_infos.push(updated_info);
592 }
593
594 updated_infos
595}
596
597pub fn select_servers_by_remaining_guesses(
599 server_infos: &[ServerInfo],
600 threshold: usize,
601) -> Vec<ServerInfo> {
602 let mut available_servers: Vec<ServerInfo> = server_infos.iter()
604 .filter(|info| info.remaining_guesses.unwrap_or(-1) != 0)
605 .cloned()
606 .collect();
607
608 if available_servers.is_empty() {
609 return server_infos.to_vec(); }
611
612 available_servers.sort_by(|a, b| {
615 let a_guesses = if a.remaining_guesses.unwrap_or(-1) == -1 {
616 i32::MAX
617 } else {
618 a.remaining_guesses.unwrap_or(0)
619 };
620 let b_guesses = if b.remaining_guesses.unwrap_or(-1) == -1 {
621 i32::MAX
622 } else {
623 b.remaining_guesses.unwrap_or(0)
624 };
625 b_guesses.cmp(&a_guesses)
626 });
627
628 let num_to_select = std::cmp::min(available_servers.len(), threshold + 2);
630 let selected_servers = available_servers.into_iter().take(num_to_select).collect::<Vec<_>>();
631
632 for (_, server) in selected_servers.iter().enumerate() {
633 let _guesses_str = if server.remaining_guesses.unwrap_or(-1) == -1 {
634 "unknown".to_string()
635 } else {
636 server.remaining_guesses.unwrap_or(0).to_string()
637 };
638 }
639
640 selected_servers
641}
642
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identity() {
        let identity = Identity::new(
            "user123".to_string(),
            "device456".to_string(),
            "backup789".to_string(),
        );

        assert_eq!(identity.uid, "user123");
        assert_eq!(identity.did, "device456");
        assert_eq!(identity.bid, "backup789");

        let display = format!("{}", identity);
        assert!(display.contains("user123"));
        assert!(display.contains("device456"));
        assert!(display.contains("backup789"));
    }

    #[test]
    fn test_generate_auth_codes() {
        let servers = vec![
            "https://server1.example.com".to_string(),
            "https://server2.example.com".to_string(),
        ];

        let auth_codes = generate_auth_codes(&servers);

        assert!(!auth_codes.base_auth_code.is_empty());
        assert_eq!(auth_codes.server_auth_codes.len(), 2);

        let code1 = auth_codes.get_server_code("https://server1.example.com");
        let code2 = auth_codes.get_server_code("https://server2.example.com");
        let code_unknown = auth_codes.get_server_code("https://unknown.example.com");

        assert!(code1.is_some());
        assert!(code2.is_some());
        // Unknown servers fall back to the base code.
        assert_eq!(code_unknown, Some(&auth_codes.base_auth_code));
        // Different servers get different derived codes.
        assert_ne!(code1, code2);
    }

    #[test]
    fn test_auth_codes_structure() {
        let servers = vec!["https://test.com".to_string()];
        let auth_codes = generate_auth_codes(&servers);

        // 32 random bytes, hex-encoded.
        assert_eq!(auth_codes.base_auth_code.len(), 64);
        assert!(auth_codes.base_auth_code.chars().all(|c| c.is_ascii_hexdigit()));

        // Per-server codes are hex-encoded SHA-256 digests.
        for code in auth_codes.server_auth_codes.values() {
            assert_eq!(code.len(), 64);
            assert!(code.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn test_result_structures() {
        let key = vec![1, 2, 3, 4];
        let servers = vec![];
        let auth_codes = AuthCodes {
            base_auth_code: "test".to_string(),
            server_auth_codes: HashMap::new(),
        };

        let success = GenerateEncryptionKeyResult::success(key.clone(), servers, 2, auth_codes);
        assert!(success.encryption_key.is_some());
        assert!(success.error.is_none());
        assert_eq!(success.threshold, Some(2));
        assert_eq!(success.encryption_key.unwrap(), key);

        let error = GenerateEncryptionKeyResult::error("test error".to_string());
        assert!(error.encryption_key.is_none());
        assert!(error.error.is_some());
        assert_eq!(error.error.unwrap(), "test error");
    }

    #[test]
    fn test_identity_validation() {
        let identity = Identity::new("user".to_string(), "device".to_string(), "backup".to_string());
        assert_eq!(identity.uid, "user");

        let special_identity = Identity::new(
            "user@domain.com".to_string(),
            "device-123".to_string(),
            "file://path/to/backup".to_string(),
        );
        assert!(special_identity.uid.contains("@"));
        assert!(special_identity.did.contains("-"));
        assert!(special_identity.bid.contains("://"));

        let display = format!("{}", special_identity);
        assert!(display.contains("UID="));
        assert!(display.contains("DID="));
        assert!(display.contains("BID="));
    }

    #[test]
    fn test_auth_codes_comprehensive() {
        let servers = vec![
            "https://server1.com".to_string(),
            "https://server2.com".to_string(),
            "https://server3.com".to_string(),
        ];

        let auth_codes1 = generate_auth_codes(&servers);
        let auth_codes2 = generate_auth_codes(&servers);

        // Each call draws a fresh random base code.
        assert_ne!(auth_codes1.base_auth_code, auth_codes2.base_auth_code);
        assert_eq!(auth_codes1.server_auth_codes.len(), auth_codes2.server_auth_codes.len());

        for server in &servers {
            let code1 = auth_codes1.get_server_code(server).expect("known server has a code");
            let code2 = auth_codes2.get_server_code(server).expect("known server has a code");
            // Per-server codes depend on the base code...
            assert_ne!(code1, code2);
            // ...and are derived deterministically as hex(SHA-256(base || ":" || url)).
            let expected = hex::encode(Sha256::digest(
                format!("{}:{}", auth_codes1.base_auth_code, server).as_bytes(),
            ));
            assert_eq!(code1, &expected);
        }

        let unknown_code = auth_codes1.get_server_code("https://unknown.com");
        assert_eq!(unknown_code, Some(&auth_codes1.base_auth_code));
    }

    #[test]
    fn test_encryption_key_derivation() {
        let identity = Identity::new(
            "test-user".to_string(),
            "test-device".to_string(),
            "test-backup".to_string(),
        );

        let password = "test-password";
        let pin = password.as_bytes().to_vec();

        let result = H(
            identity.uid.as_bytes(),
            identity.did.as_bytes(),
            identity.bid.as_bytes(),
            &pin,
        );

        assert!(result.is_ok(), "H function should succeed with valid inputs");

        let point = result.unwrap();

        let key_result = derive_enc_key(&point);
        assert!(key_result.is_ok(), "Key derivation should succeed");

        let key = key_result.unwrap();
        assert_eq!(key.len(), 32, "Encryption key should be 32 bytes");

        let point2 = H(
            identity.uid.as_bytes(),
            identity.did.as_bytes(),
            identity.bid.as_bytes(),
            &pin,
        )
        .unwrap();
        let key2 = derive_enc_key(&point2).unwrap();
        assert_eq!(key, key2, "Same inputs should produce same key");
    }

    #[test]
    fn test_input_validation_edge_cases() {
        let empty_identity = Identity::new("".to_string(), "".to_string(), "".to_string());
        assert_eq!(empty_identity.uid, "");

        let long_string = "x".repeat(1000);
        let long_identity =
            Identity::new(long_string.clone(), long_string.clone(), long_string.clone());
        assert_eq!(long_identity.uid.len(), 1000);

        let display = format!("{}", empty_identity);
        assert!(display.contains("UID="));
        assert!(display.contains("DID="));
        assert!(display.contains("BID="));
    }
}