1use base64;
2use base64::Engine;
3use base64::engine::general_purpose::STANDARD;
4
5use std::time::Duration;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::io::{Cursor, Write};
9
10use tokio;
11use tokio::time::sleep;
12use urlencoding::encode;
13
/// Extension (matched against the lower-cased entry name) that marks an invoice XML file
/// inside a decrypted export package.
const XML_FILE_EXTENSION: &str = ".xml";
/// Name of the JSON metadata entry inside a decrypted export package
/// (matched ASCII-case-insensitively).
const METADATA_ENTRY_NAME: &str = "_metadata.json";
16
17mod certificates;
18mod cryptography;
19pub mod invoice;
20mod models;
21mod utils;
22
23pub struct KsefClient {
24 base_url_parsed: url::Url,
25 base_url: String,
26 sleep_time: u64,
27 public_certificates: RefCell<Option<Vec<models::PemCertificateInfo>>>,
28}
29
/// Credentials of the taxpayer on whose behalf the client authenticates.
pub struct CompanyInfo {
    /// KSeF token; RSA-encrypted together with the challenge timestamp during authentication.
    pub ksef_token: String,
    /// Tax identification number (NIP) sent as the authentication context identifier.
    pub nip: String,
}
34
35impl KsefClient {
36 pub fn new(base_url: String, sleep_time: u64) -> Result<Self, url::ParseError> {
37 let base_url_parsed = url::Url::parse(&base_url)?;
38 let base_url = base_url_parsed
39 .to_string()
40 .trim_end_matches('/')
41 .to_string();
42 Ok(Self {
43 base_url_parsed,
44 base_url,
45 sleep_time,
46 public_certificates: RefCell::new(None),
47 })
48 }
49
50 fn join_url(&self, path: &str) -> url::Url {
51 self.base_url_parsed.join(path).unwrap()
52 }
53
54 pub async fn get_access_tokens(
55 &self,
56 company_info: &CompanyInfo,
57 ) -> Result<models::TokenPair, &str> {
58 let ksef_token_cert = match certificates::public_certificate(
59 &self,
60 &models::PublicKeyCertificateUsage::KsefTokenEncryption,
61 )
62 .await
63 {
64 Ok(ksef_token_cert) => ksef_token_cert,
65 Err(e) => {
66 return Err(e);
67 }
68 };
69
70 let challenge = match self.get_auth_challenge().await {
71 Ok(challenge) => challenge,
72 Err(_) => {
73 return Err("challenge_error");
74 }
75 };
76
77 let timestamp_ms = challenge.timestamp.timestamp_millis();
78
79 let token_with_timestamp = format!("{}|{}", &company_info.ksef_token, timestamp_ms);
80 let token_bytes: Vec<u8> = token_with_timestamp.as_bytes().to_vec();
81 let encrypted: Vec<u8> = cryptography::encrypt_ksef_token_with_rsa_using_public_key(
82 &ksef_token_cert,
83 &token_bytes,
84 )
85 .unwrap();
86
87 let encrypted_token_b64 = STANDARD.encode(&encrypted);
88
89 let request = models::AuthenticationKsefTokenRequest {
90 challenge: challenge.challenge,
91 context_identifier: models::AuthenticationTokenContextIdentifier {
92 auth_type: models::AuthenticationTokenContextIdentifierType::Nip,
93 value: Some(company_info.nip.clone()),
94 },
95 encrypted_token: encrypted_token_b64,
96 };
97
98 let signature = match self.submit_ksef_token_auth_request(&request).await {
99 Ok(signature) => signature,
100 Err(_) => {
101 return Err("signature_error");
102 }
103 };
104
105 let poll_timeout = Duration::from_secs(2 * 60); let total_millis = poll_timeout.as_millis();
108
109 let status_attempts = std::cmp::max(1, (total_millis / self.sleep_time as u128) as i32);
110
111 for attempt in 1..=status_attempts {
112 match self
113 .get_auth_status(
114 &signature.reference_number,
115 &signature.authentication_token.token,
116 )
117 .await
118 {
119 Ok(auth_status) => {
120 if auth_status.status.code == 200 {
121 break;
122 }
123 }
124 Err(_) => {
125 return Err("auth_status_error");
126 }
127 }
128 if attempt == status_attempts {
129 return Err("Maximum number of attempts exceeded");
130 }
131
132 sleep(Duration::from_millis(self.sleep_time)).await;
133 }
134
135 let tokens = match self
136 .get_access_token_by_authentication_token(&signature.authentication_token.token)
137 .await
138 {
139 Ok(tokens) => tokens,
140 Err(_) => {
141 return Err("token_error");
142 }
143 };
144
145 Ok(tokens)
146 }
147
148 pub async fn refresh_access_token(
149 &self,
150 refresh_token: &String,
151 ) -> Result<models::TokenInfo, &str> {
152 let url = "/v2/auth/token/refresh";
153
154 let reqwest_client = reqwest::Client::new();
155 let resp = reqwest_client
156 .post(self.join_url(url))
157 .bearer_auth(&refresh_token)
158 .send()
159 .await
160 .map_err(|_| "network error")?;
161
162 if resp.status().is_success() {
163 let result = resp
164 .json::<models::RefreshTokenResponse>()
165 .await
166 .map_err(|_| "invalid success response")?;
167 return Ok(result.access_token);
168 }
169
170 Err("server returned error status")
171 }
172
173 pub async fn query_invoice_metadata(
174 &self,
175 request: &invoice::InvoiceQueryFilters,
176 access_token: &String,
177 page_offset: i32,
178 page_size: i32,
179 sort_order: invoice::SortOrder,
180 ) -> Result<invoice::PagedInvoiceResponse, models::ErrorResponse> {
181 let mut url = format!("/v2/invoices/query/metadata?sortOrder={}", sort_order);
182
183 if page_offset > 0 {
184 url = format!("{}&pageOffset={}", url, page_offset);
185 }
186
187 if page_size > 0 {
188 url = format!("{}&pageSize={}", url, page_size);
189 }
190
191
192 let reqwest_client = reqwest::Client::new();
193 let resp = reqwest_client
194 .post(self.join_url(url.as_str()))
195 .bearer_auth(access_token)
196 .header("Content-Type", "application/json")
197 .json(request)
198 .send()
199 .await
200 .map_err(|_| models::ErrorResponse {
201 code: "network_error".into(),
202 message: "Failed to send request".into(),
203 })?;
204
205 let status = resp.status();
206
207 if status.is_success() {
208 let ok = resp
209 .json::<invoice::PagedInvoiceResponse>()
210 .await
211 .map_err(|_| models::ErrorResponse {
212 code: "invalid_response".into(),
213 message: "Failed to parse success response".into(),
214 })?;
215 return Ok(ok);
216 }
217
218 let err = resp
219 .json::<models::ErrorResponse>()
220 .await
221 .unwrap_or_else(|_| models::ErrorResponse {
222 code: "unknown_error".into(),
223 message: format!("Server returned HTTP {}", status),
224 });
225
226 Err(err)
227 }
228
229 async fn get_auth_challenge(
230 &self,
231 ) -> Result<models::AuthenticationChallengeResponse, reqwest::Error> {
232 let url = "/v2/auth/challenge";
233
234 let reqwest_client = reqwest::Client::new();
235 let result = reqwest_client
236 .post(self.join_url(url))
237 .send()
238 .await?
239 .json::<models::AuthenticationChallengeResponse>()
240 .await?;
241 Ok(result)
242 }
243
244 async fn submit_ksef_token_auth_request(
245 &self,
246 request: &models::AuthenticationKsefTokenRequest,
247 ) -> Result<models::SignatureResponse, reqwest::Error> {
248 let url = "/v2/auth/ksef-token";
249
250 let reqwest_client = reqwest::Client::new();
251 let result = reqwest_client
252 .post(self.join_url(url))
253 .json(&request)
254 .send()
255 .await?
256 .json::<models::SignatureResponse>()
257 .await?;
258 Ok(result)
259 }
260
261 async fn get_auth_status(
262 &self,
263 auth_operation_reference_number: &String,
264 authentication_token: &String,
265 ) -> Result<models::AuthStatus, reqwest::Error> {
266 let escaped = encode(auth_operation_reference_number);
267 let url = format!("/v2/auth/{}", escaped);
268
269 let reqwest_client = reqwest::Client::new();
270 let result = reqwest_client
271 .get(self.join_url(url.as_str()))
272 .bearer_auth(&authentication_token)
273 .send()
274 .await?
275 .json::<models::AuthStatus>()
276 .await?;
277 Ok(result)
278 }
279
280 async fn get_access_token_by_authentication_token(
281 &self,
282 authentication_token: &String,
283 ) -> Result<models::TokenPair, reqwest::Error> {
284 let url = "/v2/auth/token/redeem";
285
286 let reqwest_client = reqwest::Client::new();
287 let result = reqwest_client
288 .post(self.join_url(url))
289 .bearer_auth(&authentication_token)
290 .send()
291 .await?
292 .json::<models::TokenPair>()
293 .await?;
294 Ok(result)
295 }
296
297 async fn start_invoices_export(
298 &self,
299 request: &invoice::InvoiceExportRequest,
300 access_token: &String,
301 ) -> Result<invoice::OperationResponse, reqwest::Error> {
302 let url = "/v2/invoices/exports";
303
304 let reqwest_client = reqwest::Client::new();
305 let result = reqwest_client
306 .post(self.join_url(url))
307 .json(&request)
308 .bearer_auth(&access_token)
309 .send()
310 .await?
311 .json::<invoice::OperationResponse>()
312 .await?;
313 Ok(result)
314 }
315
316 async fn get_invoice_export_status_try(
317 &self,
318 reference_number: &String,
319 access_token: &String,
320 ) -> Result<invoice::InvoiceExportStatusResponse, reqwest::Error> {
321 let url = format!("/v2/invoices/exports/{}", encode(reference_number));
322
323 let reqwest_client = reqwest::Client::new();
324 let result = reqwest_client
325 .get(self.join_url(url.as_str()))
326 .bearer_auth(&access_token)
327 .send()
328 .await?
329 .json::<invoice::InvoiceExportStatusResponse>()
330 .await?;
331 Ok(result)
332 }
333
334 async fn get_invoice_export_status(
335 &self,
336 reference_number: &String,
337 access_token: &String,
338 ) -> Result<invoice::InvoiceExportStatusResponse, &'static str> {
339 let poll_timeout = Duration::from_secs(2 * 60);
340
341 let total_millis = poll_timeout.as_millis();
342 let status_attempts = std::cmp::max(1, (total_millis / self.sleep_time as u128) as i32);
343
344 for attempt in 1..=status_attempts {
345 match self
346 .get_invoice_export_status_try(&reference_number, &access_token)
347 .await
348 {
349 Ok(try_status) => {
350 if try_status.status.code == 200 {
351 return Ok(try_status);
352 }
353 }
354 Err(_) => {
355 return Err("try_status_error");
356 }
357 }
358 if attempt == status_attempts {
359 return Err("Maximum number of attempts exceeded");
360 }
361
362 sleep(Duration::from_millis(self.sleep_time)).await;
363 }
364
365 Err("export_error")
366 }
367
368 pub async fn get_invoice_export(
369 &self,
370 filters: &invoice::InvoiceQueryFilters,
371 access_token: &String,
372 ) -> Result<invoice::InvoiceExportResult, models::ErrorResponse> {
373 let encryption = match cryptography::get_encryption_data(&self).await {
374 Ok(encryption) => encryption,
375 Err(e) => return Err(models::ErrorResponse {
376 code: "encryption_error".into(),
377 message: e.into(),
378 }),
379 };
380
381 let invoice_export_request = invoice::InvoiceExportRequest {
382 encryption: encryption.encryption_info.clone(),
383 filters: (*filters).clone(),
384 };
385
386 let start_invoices_export = match self
387 .start_invoices_export(&invoice_export_request, &access_token)
388 .await
389 {
390 Ok(start_invoices_export) => start_invoices_export,
391 Err(e) => return Err(models::ErrorResponse {
392 code: "start_invoices_export_error".into(),
393 message: format!("Status: {}", e),
394 }),
395 };
396
397 let invoice_export_status = match self
398 .get_invoice_export_status(&start_invoices_export.reference_number, &access_token)
399 .await
400 {
401 Ok(export_status) => export_status,
402 Err(e) => return Err(models::ErrorResponse {
403 code: "invoice_export_status_error".into(),
404 message: e.into(),
405 }),
406 };
407
408 let mut metadata_summaries: Vec<invoice::InvoiceSummary> = Vec::new();
409 let mut xml_files: HashMap<String, String> = HashMap::new();
410
411
412 if !invoice_export_status.package.parts.is_empty() {
413 let decrypted_archive_stream = match self
414 .download_package_parts(&invoice_export_status.package.parts, &encryption)
415 .await
416 {
417 Ok(decrypted_archive_stream) => decrypted_archive_stream,
418 Err(e) => return Err(models::ErrorResponse {
419 code: "download_package_parts_error".into(),
420 message: e.into(),
421 }),
422 };
423
424 let unzipped_files = utils::unzip(decrypted_archive_stream);
425
426 for (file_name, content) in unzipped_files {
427 if file_name.eq_ignore_ascii_case(METADATA_ENTRY_NAME) {
428 if let Ok(metadata) =
429 serde_json::from_str::<invoice::InvoicePackageMetadata>(&content)
430 {
431 if let Some(invoices) = metadata.invoices {
432 metadata_summaries.extend(invoices);
433 }
434 }
435 } else if file_name.to_lowercase().ends_with(XML_FILE_EXTENSION) {
436 xml_files.insert(file_name.to_lowercase(), content);
437 }
438 }
439
440 }
441
442 let result = invoice::InvoiceExportResult{
443 metadata_summaries: metadata_summaries,
444 xml_files: xml_files,
445 is_truncated: invoice_export_status.package.is_truncated,
446 last_permanent_storage_date: invoice_export_status.package.last_permanent_storage_date,
447 permanent_storage_hwm_date: invoice_export_status.package.permanent_storage_hwm_date,
448 };
449
450 Ok(result)
451 }
452
453 async fn download_package_parts(
454 &self,
455 parts: &Vec<invoice::InvoiceExportPackagePart>,
456 encryption: &models::EncryptionData,
457 ) -> Result<Cursor<Vec<u8>>, &str> {
458 let mut buffer = Cursor::new(Vec::new());
459
460 let mut parts_sorted: Vec<_> = parts.iter().collect();
461 parts_sorted.sort_by_key(|p| p.ordinal_number);
462
463 for part in parts_sorted {
464 let encrypted_bytes = match self.download_package_part(&part).await {
465 Ok(encrypted_bytes) => encrypted_bytes,
466 Err(e) => return Err(e),
467 };
468
469 let decrypted_bytes = match cryptography::decrypt_bytes_with_aes256(
470 &encrypted_bytes,
471 &encryption.cipher_key,
472 &encryption.cipher_iv,
473 ) {
474 Ok(decrypted_bytes) => decrypted_bytes,
475 Err(_) => return Err("decrypted_bytes_error"),
476 };
477
478 buffer.write_all(&decrypted_bytes).unwrap();
479 }
480
481 buffer.set_position(0);
482 Ok(buffer)
483 }
484
485 async fn download_package_part(
486 &self,
487 part: &invoice::InvoiceExportPackagePart,
488 ) -> Result<Vec<u8>, &str> {
489 let method_str = if part.method.is_empty() {
490 "GET"
491 } else {
492 part.method.as_str()
493 };
494
495 let method = method_str
496 .parse::<reqwest::Method>()
497 .map_err(|e| format!("Invalid HTTP method: {}", e))
498 .unwrap();
499
500 let reqwest_client = reqwest::Client::new();
501 let request = reqwest_client.request(method, &part.url);
502
503 let response = request
504 .send()
505 .await
506 .map_err(|e| format!("Response error: {}", e))
507 .unwrap();
508 let response = response
509 .error_for_status()
510 .map_err(|e| format!("EnsureSuccessStatusCode error: {}", e))
511 .unwrap();
512
513 let bytes = response
514 .bytes()
515 .await
516 .map_err(|e| format!("Get bytes error: {}", e))
517 .unwrap();
518 Ok(bytes.to_vec())
519 }
520}