Skip to main content

ksef_client/
lib.rs

1use base64;
2use base64::Engine;
3use base64::engine::general_purpose::STANDARD;
4
5use std::time::Duration;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::io::{Cursor, Write};
9
10use tokio;
11use tokio::time::sleep;
12use urlencoding::encode;
13
/// Name of the ZIP entry in an export package that carries the invoice metadata.
const METADATA_ENTRY_NAME: &'static str = "_metadata.json";
/// File extension identifying invoice XML entries in an export package.
const XML_FILE_EXTENSION: &'static str = ".xml";
16
17mod certificates;
18mod cryptography;
19pub mod invoice;
20mod models;
21mod utils;
22
23pub struct KsefClient {
24    base_url_parsed: url::Url,
25    base_url: String,
26    sleep_time: u64,
27    public_certificates: RefCell<Option<Vec<models::PemCertificateInfo>>>,
28}
29
/// Credentials of the company on whose behalf the client authenticates.
pub struct CompanyInfo {
    /// KSeF token of the company. During authentication it is joined with the
    /// challenge timestamp (`token|timestamp_ms`) and RSA-encrypted.
    pub ksef_token: String,
    /// Tax identification number (NIP), sent as the authentication context identifier.
    pub nip: String,
}
34
35impl KsefClient {
36    pub fn new(base_url: String, sleep_time: u64) -> Result<Self, url::ParseError> {
37        let base_url_parsed = url::Url::parse(&base_url)?;
38        let base_url = base_url_parsed
39            .to_string()
40            .trim_end_matches('/')
41            .to_string();
42        Ok(Self {
43            base_url_parsed,
44            base_url,
45            sleep_time,
46            public_certificates: RefCell::new(None),
47        })
48    }
49
50    fn join_url(&self, path: &str) -> url::Url {
51        self.base_url_parsed.join(path).unwrap()
52    }
53
54    pub async fn get_access_tokens(
55        &self,
56        company_info: &CompanyInfo,
57    ) -> Result<models::TokenPair, &str> {
58        let ksef_token_cert = match certificates::public_certificate(
59            &self,
60            &models::PublicKeyCertificateUsage::KsefTokenEncryption,
61        )
62        .await
63        {
64            Ok(ksef_token_cert) => ksef_token_cert,
65            Err(e) => {
66                return Err(e);
67            }
68        };
69
70        let challenge = match self.get_auth_challenge().await {
71            Ok(challenge) => challenge,
72            Err(_) => {
73                return Err("challenge_error");
74            }
75        };
76
77        let timestamp_ms = challenge.timestamp.timestamp_millis();
78
79        let token_with_timestamp = format!("{}|{}", &company_info.ksef_token, timestamp_ms);
80        let token_bytes: Vec<u8> = token_with_timestamp.as_bytes().to_vec();
81        let encrypted: Vec<u8> = cryptography::encrypt_ksef_token_with_rsa_using_public_key(
82            &ksef_token_cert,
83            &token_bytes,
84        )
85        .unwrap();
86
87        let encrypted_token_b64 = STANDARD.encode(&encrypted);
88
89        let request = models::AuthenticationKsefTokenRequest {
90            challenge: challenge.challenge,
91            context_identifier: models::AuthenticationTokenContextIdentifier {
92                auth_type: models::AuthenticationTokenContextIdentifierType::Nip,
93                value: Some(company_info.nip.clone()),
94            },
95            encrypted_token: encrypted_token_b64,
96        };
97
98        let signature = match self.submit_ksef_token_auth_request(&request).await {
99            Ok(signature) => signature,
100            Err(_) => {
101                return Err("signature_error");
102            }
103        };
104
105        let poll_timeout = Duration::from_secs(2 * 60); // 2 minuty
106
107        let total_millis = poll_timeout.as_millis();
108
109        let status_attempts = std::cmp::max(1, (total_millis / self.sleep_time as u128) as i32);
110
111        for attempt in 1..=status_attempts {
112            match self
113                .get_auth_status(
114                    &signature.reference_number,
115                    &signature.authentication_token.token,
116                )
117                .await
118            {
119                Ok(auth_status) => {
120                    if auth_status.status.code == 200 {
121                        break;
122                    }
123                }
124                Err(_) => {
125                    return Err("auth_status_error");
126                }
127            }
128            if attempt == status_attempts {
129                return Err("Maximum number of attempts exceeded");
130            }
131
132            sleep(Duration::from_millis(self.sleep_time)).await;
133        }
134
135        let tokens = match self
136            .get_access_token_by_authentication_token(&signature.authentication_token.token)
137            .await
138        {
139            Ok(tokens) => tokens,
140            Err(_) => {
141                return Err("token_error");
142            }
143        };
144
145        Ok(tokens)
146    }
147
148    pub async fn refresh_access_token(
149        &self,
150        refresh_token: &String,
151    ) -> Result<models::TokenInfo, &str> {
152        let url = "/v2/auth/token/refresh";
153
154        let reqwest_client = reqwest::Client::new();
155        let resp = reqwest_client
156            .post(self.join_url(url))
157            .bearer_auth(&refresh_token)
158            .send()
159            .await
160            .map_err(|_| "network error")?;
161
162        if resp.status().is_success() {
163            let result = resp
164                .json::<models::RefreshTokenResponse>()
165                .await
166                .map_err(|_| "invalid success response")?;
167            return Ok(result.access_token);
168        }
169
170        Err("server returned error status")
171    }
172
173    pub async fn query_invoice_metadata(
174        &self,
175        request: &invoice::InvoiceQueryFilters,
176        access_token: &String,
177        page_offset: i32,
178        page_size: i32,
179        sort_order: invoice::SortOrder,
180    ) -> Result<invoice::PagedInvoiceResponse, models::ErrorResponse> {
181        let mut url = format!("/v2/invoices/query/metadata?sortOrder={}", sort_order);
182
183        if page_offset > 0 {
184            url = format!("{}&pageOffset={}", url, page_offset);
185        }
186
187        if page_size > 0 {
188            url = format!("{}&pageSize={}", url, page_size);
189        }
190
191
192        let reqwest_client = reqwest::Client::new();
193        let resp = reqwest_client
194            .post(self.join_url(url.as_str()))
195            .bearer_auth(access_token)
196            .header("Content-Type", "application/json")
197            .json(request)
198            .send()
199            .await
200            .map_err(|_| models::ErrorResponse {
201                code: "network_error".into(),
202                message: "Failed to send request".into(),
203            })?;
204
205        let status = resp.status();
206
207        if status.is_success() {
208            let ok = resp
209                .json::<invoice::PagedInvoiceResponse>()
210                .await
211                .map_err(|_| models::ErrorResponse {
212                    code: "invalid_response".into(),
213                    message: "Failed to parse success response".into(),
214                })?;
215            return Ok(ok);
216        }
217
218        let err = resp
219            .json::<models::ErrorResponse>()
220            .await
221            .unwrap_or_else(|_| models::ErrorResponse {
222                code: "unknown_error".into(),
223                message: format!("Server returned HTTP {}", status),
224            });
225
226        Err(err)
227    }
228
229    async fn get_auth_challenge(
230        &self,
231    ) -> Result<models::AuthenticationChallengeResponse, reqwest::Error> {
232        let url = "/v2/auth/challenge";
233
234        let reqwest_client = reqwest::Client::new();
235        let result = reqwest_client
236            .post(self.join_url(url))
237            .send()
238            .await?
239            .json::<models::AuthenticationChallengeResponse>()
240            .await?;
241        Ok(result)
242    }
243
244    async fn submit_ksef_token_auth_request(
245        &self,
246        request: &models::AuthenticationKsefTokenRequest,
247    ) -> Result<models::SignatureResponse, reqwest::Error> {
248        let url = "/v2/auth/ksef-token";
249
250        let reqwest_client = reqwest::Client::new();
251        let result = reqwest_client
252            .post(self.join_url(url))
253            .json(&request)
254            .send()
255            .await?
256            .json::<models::SignatureResponse>()
257            .await?;
258        Ok(result)
259    }
260
261    async fn get_auth_status(
262        &self,
263        auth_operation_reference_number: &String,
264        authentication_token: &String,
265    ) -> Result<models::AuthStatus, reqwest::Error> {
266        let escaped = encode(auth_operation_reference_number);
267        let url = format!("/v2/auth/{}", escaped);
268
269        let reqwest_client = reqwest::Client::new();
270        let result = reqwest_client
271            .get(self.join_url(url.as_str()))
272            .bearer_auth(&authentication_token)
273            .send()
274            .await?
275            .json::<models::AuthStatus>()
276            .await?;
277        Ok(result)
278    }
279
280    async fn get_access_token_by_authentication_token(
281        &self,
282        authentication_token: &String,
283    ) -> Result<models::TokenPair, reqwest::Error> {
284        let url = "/v2/auth/token/redeem";
285
286        let reqwest_client = reqwest::Client::new();
287        let result = reqwest_client
288            .post(self.join_url(url))
289            .bearer_auth(&authentication_token)
290            .send()
291            .await?
292            .json::<models::TokenPair>()
293            .await?;
294        Ok(result)
295    }
296
297    async fn start_invoices_export(
298        &self,
299        request: &invoice::InvoiceExportRequest,
300        access_token: &String,
301    ) -> Result<invoice::OperationResponse, reqwest::Error> {
302        let url = "/v2/invoices/exports";
303
304        let reqwest_client = reqwest::Client::new();
305        let result = reqwest_client
306            .post(self.join_url(url))
307            .json(&request)
308            .bearer_auth(&access_token)
309            .send()
310            .await?
311            .json::<invoice::OperationResponse>()
312            .await?;
313        Ok(result)
314    }
315
316    async fn get_invoice_export_status_try(
317        &self,
318        reference_number: &String,
319        access_token: &String,
320    ) -> Result<invoice::InvoiceExportStatusResponse, reqwest::Error> {
321        let url = format!("/v2/invoices/exports/{}", encode(reference_number));
322
323        let reqwest_client = reqwest::Client::new();
324        let result = reqwest_client
325            .get(self.join_url(url.as_str()))
326            .bearer_auth(&access_token)
327            .send()
328            .await?
329            .json::<invoice::InvoiceExportStatusResponse>()
330            .await?;
331        Ok(result)
332    }
333
334    async fn get_invoice_export_status(
335        &self,
336        reference_number: &String,
337        access_token: &String,
338    ) -> Result<invoice::InvoiceExportStatusResponse, &'static str> {
339        let poll_timeout = Duration::from_secs(2 * 60);
340
341        let total_millis = poll_timeout.as_millis();
342        let status_attempts = std::cmp::max(1, (total_millis / self.sleep_time as u128) as i32);
343
344        for attempt in 1..=status_attempts {
345            match self
346                .get_invoice_export_status_try(&reference_number, &access_token)
347                .await
348            {
349                Ok(try_status) => {
350                    if try_status.status.code == 200 {
351                        return Ok(try_status);
352                    }
353                }
354                Err(_) => {
355                    return Err("try_status_error");
356                }
357            }
358            if attempt == status_attempts {
359                return Err("Maximum number of attempts exceeded");
360            }
361
362            sleep(Duration::from_millis(self.sleep_time)).await;
363        }
364
365        Err("export_error")
366    }
367
368    pub async fn get_invoice_export(
369        &self,
370        filters: &invoice::InvoiceQueryFilters,
371        access_token: &String,
372    ) -> Result<invoice::InvoiceExportResult, models::ErrorResponse> {
373        let encryption = match cryptography::get_encryption_data(&self).await {
374            Ok(encryption) => encryption,
375            Err(e) => return Err(models::ErrorResponse {
376                code: "encryption_error".into(),
377                message: e.into(),
378            }),
379        };
380
381        let invoice_export_request = invoice::InvoiceExportRequest {
382            encryption: encryption.encryption_info.clone(),
383            filters: (*filters).clone(),
384        };
385
386        let start_invoices_export = match self
387            .start_invoices_export(&invoice_export_request, &access_token)
388            .await
389        {
390            Ok(start_invoices_export) => start_invoices_export,
391            Err(e) => return Err(models::ErrorResponse {
392                code: "start_invoices_export_error".into(),
393                message: format!("Status: {}", e),
394            }),
395        };
396
397        let invoice_export_status = match self
398            .get_invoice_export_status(&start_invoices_export.reference_number, &access_token)
399            .await
400        {
401            Ok(export_status) => export_status,
402            Err(e) => return Err(models::ErrorResponse {
403                code: "invoice_export_status_error".into(),
404                message: e.into(),
405            }),
406        };
407
408        let mut metadata_summaries: Vec<invoice::InvoiceSummary> = Vec::new();
409        let mut xml_files: HashMap<String, String> = HashMap::new();
410
411
412        if !invoice_export_status.package.parts.is_empty() {
413            let decrypted_archive_stream = match self
414                .download_package_parts(&invoice_export_status.package.parts, &encryption)
415                .await
416            {
417                Ok(decrypted_archive_stream) => decrypted_archive_stream,
418                Err(e) => return Err(models::ErrorResponse {
419                    code: "download_package_parts_error".into(),
420                    message: e.into(),
421                }),
422            };
423
424            let unzipped_files = utils::unzip(decrypted_archive_stream);
425
426            for (file_name, content) in unzipped_files {
427                if file_name.eq_ignore_ascii_case(METADATA_ENTRY_NAME) {
428                    if let Ok(metadata) =
429                        serde_json::from_str::<invoice::InvoicePackageMetadata>(&content)
430                    {
431                        if let Some(invoices) = metadata.invoices {
432                            metadata_summaries.extend(invoices);
433                        }
434                    }
435                } else if file_name.to_lowercase().ends_with(XML_FILE_EXTENSION) {
436                    xml_files.insert(file_name.to_lowercase(), content);
437                }
438            }
439
440        }
441
442        let result = invoice::InvoiceExportResult{
443            metadata_summaries: metadata_summaries,
444            xml_files: xml_files,
445            is_truncated: invoice_export_status.package.is_truncated,
446            last_permanent_storage_date: invoice_export_status.package.last_permanent_storage_date,
447            permanent_storage_hwm_date: invoice_export_status.package.permanent_storage_hwm_date,
448        };
449        
450        Ok(result)
451    }
452
453    async fn download_package_parts(
454        &self,
455        parts: &Vec<invoice::InvoiceExportPackagePart>,
456        encryption: &models::EncryptionData,
457    ) -> Result<Cursor<Vec<u8>>, &str> {
458        let mut buffer = Cursor::new(Vec::new());
459
460        let mut parts_sorted: Vec<_> = parts.iter().collect();
461        parts_sorted.sort_by_key(|p| p.ordinal_number);
462
463        for part in parts_sorted {
464            let encrypted_bytes = match self.download_package_part(&part).await {
465                Ok(encrypted_bytes) => encrypted_bytes,
466                Err(e) => return Err(e),
467            };
468
469            let decrypted_bytes = match cryptography::decrypt_bytes_with_aes256(
470                &encrypted_bytes,
471                &encryption.cipher_key,
472                &encryption.cipher_iv,
473            ) {
474                Ok(decrypted_bytes) => decrypted_bytes,
475                Err(_) => return Err("decrypted_bytes_error"),
476            };
477
478            buffer.write_all(&decrypted_bytes).unwrap();
479        }
480
481        buffer.set_position(0);
482        Ok(buffer)
483    }
484
485    async fn download_package_part(
486        &self,
487        part: &invoice::InvoiceExportPackagePart,
488    ) -> Result<Vec<u8>, &str> {
489        let method_str = if part.method.is_empty() {
490            "GET"
491        } else {
492            part.method.as_str()
493        };
494
495        let method = method_str
496            .parse::<reqwest::Method>()
497            .map_err(|e| format!("Invalid HTTP method: {}", e))
498            .unwrap();
499
500        let reqwest_client = reqwest::Client::new();
501        let request = reqwest_client.request(method, &part.url);
502
503        let response = request
504            .send()
505            .await
506            .map_err(|e| format!("Response error: {}", e))
507            .unwrap();
508        let response = response
509            .error_for_status()
510            .map_err(|e| format!("EnsureSuccessStatusCode error: {}", e))
511            .unwrap();
512
513        let bytes = response
514            .bytes()
515            .await
516            .map_err(|e| format!("Get bytes error: {}", e))
517            .unwrap();
518        Ok(bytes.to_vec())
519    }
520}