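//! `invoice_parser/models.rs` — data models for parsed invoices: line items, parties,
//! tax and payment details, validation reports, and parse results.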

1use chrono::NaiveDate;
2use serde::{Deserialize, Serialize};
3
4mod decimal_format {
5    use serde::{self, Deserialize, Deserializer, Serializer};
6
7    pub fn serialize<S>(value: &f64, serializer: S) -> Result<S::Ok, S::Error>
8    where
9        S: Serializer,
10    {
11        let abs = value.abs();
12        if abs == 0.0 {
13            serializer.serialize_f64(0.0)
14        } else if abs < 0.0001 {
15            serializer.serialize_str(
16                format!("{:.10}", value)
17                    .trim_end_matches('0')
18                    .trim_end_matches('.'),
19            )
20        } else if abs < 1.0 {
21            serializer.serialize_str(
22                format!("{:.8}", value)
23                    .trim_end_matches('0')
24                    .trim_end_matches('.'),
25            )
26        } else {
27            serializer.serialize_f64(*value)
28        }
29    }
30
31    pub fn deserialize<'de, D>(deserializer: D) -> Result<f64, D::Error>
32    where
33        D: Deserializer<'de>,
34    {
35        #[derive(Deserialize)]
36        #[serde(untagged)]
37        enum StringOrFloat {
38            String(String),
39            Float(f64),
40        }
41
42        match StringOrFloat::deserialize(deserializer)? {
43            StringOrFloat::String(s) => s.parse().map_err(serde::de::Error::custom),
44            StringOrFloat::Float(f) => Ok(f),
45        }
46    }
47}
48
49mod decimal_format_option {
50    use serde::{self, Deserialize, Deserializer, Serializer};
51
52    pub fn serialize<S>(value: &Option<f64>, serializer: S) -> Result<S::Ok, S::Error>
53    where
54        S: Serializer,
55    {
56        match value {
57            Some(v) => {
58                let abs = v.abs();
59                if abs == 0.0 {
60                    serializer.serialize_f64(0.0)
61                } else if abs < 0.0001 {
62                    serializer.serialize_str(
63                        format!("{:.10}", v)
64                            .trim_end_matches('0')
65                            .trim_end_matches('.'),
66                    )
67                } else if abs < 1.0 {
68                    serializer.serialize_str(
69                        format!("{:.8}", v)
70                            .trim_end_matches('0')
71                            .trim_end_matches('.'),
72                    )
73                } else {
74                    serializer.serialize_f64(*v)
75                }
76            }
77            None => serializer.serialize_none(),
78        }
79    }
80
81    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
82    where
83        D: Deserializer<'de>,
84    {
85        #[derive(Deserialize)]
86        #[serde(untagged)]
87        enum StringOrFloat {
88            String(String),
89            Float(f64),
90        }
91
92        let opt: Option<StringOrFloat> = Option::deserialize(deserializer)?;
93        match opt {
94            Some(StringOrFloat::String(s)) => s.parse().map(Some).map_err(serde::de::Error::custom),
95            Some(StringOrFloat::Float(f)) => Ok(Some(f)),
96            None => Ok(None),
97        }
98    }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
102pub enum DocumentFormat {
103    #[default]
104    Unknown,
105    AwsDirect,
106    ECloudValleyAws,
107    MicrofusionAliyun,
108    AliyunDirect,
109    UCloud,
110    GoogleCloud,
111    Azure,
112}
113
114impl std::fmt::Display for DocumentFormat {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        match self {
117            DocumentFormat::Unknown => write!(f, "Unknown"),
118            DocumentFormat::AwsDirect => write!(f, "AWS Direct"),
119            DocumentFormat::ECloudValleyAws => write!(f, "eCloudValley AWS"),
120            DocumentFormat::MicrofusionAliyun => write!(f, "Microfusion Aliyun"),
121            DocumentFormat::AliyunDirect => write!(f, "Alibaba Cloud Direct"),
122            DocumentFormat::UCloud => write!(f, "UCloud"),
123            DocumentFormat::GoogleCloud => write!(f, "Google Cloud"),
124            DocumentFormat::Azure => write!(f, "Microsoft Azure"),
125        }
126    }
127}
128
129/// 货币类型枚举
130#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
131pub enum Currency {
132    /// 美元
133    #[default]
134    USD,
135    /// 欧元
136    EUR,
137    /// 英镑
138    GBP,
139    /// 日元
140    JPY,
141    /// 人民币
142    CNY,
143    /// 港币
144    HKD,
145    /// 新加坡元
146    SGD,
147    /// 澳元
148    AUD,
149    /// 加元
150    CAD,
151    /// 瑞士法郎
152    CHF,
153    /// 其他货币
154    Other(String),
155}
156
157impl From<&str> for Currency {
158    fn from(s: &str) -> Self {
159        match s.to_uppercase().as_str() {
160            "USD" | "$" | "US$" => Currency::USD,
161            "EUR" | "€" => Currency::EUR,
162            "GBP" | "£" => Currency::GBP,
163            "JPY" | "¥" | "YEN" => Currency::JPY,
164            "CNY" | "RMB" | "元" => Currency::CNY,
165            "HKD" | "HK$" => Currency::HKD,
166            "SGD" | "S$" => Currency::SGD,
167            "AUD" | "A$" => Currency::AUD,
168            "CAD" | "C$" => Currency::CAD,
169            "CHF" => Currency::CHF,
170            other => Currency::Other(other.to_string()),
171        }
172    }
173}
174
175/// 发票类型枚举
176#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
177pub enum InvoiceType {
178    /// 标准发票
179    #[default]
180    Standard,
181    /// 贷项通知单(退款/折让凭证)
182    CreditNote,
183    /// 借项通知单(补收款凭证)
184    DebitNote,
185    /// 形式发票(报价/预开发票)
186    ProformaInvoice,
187    /// 商业发票(国际贸易)
188    CommercialInvoice,
189    /// 收据
190    Receipt,
191    /// 账单
192    Bill,
193    /// 对账单(如 AWS 账单)
194    Statement,
195    /// 未知类型
196    Unknown,
197}
198
199/// 地址信息结构体
200#[derive(Debug, Clone, Serialize, Deserialize, Default)]
201pub struct Address {
202    /// 地址第一行(街道地址)
203    pub line1: Option<String>,
204    /// 地址第二行(门牌号、楼层等)
205    pub line2: Option<String>,
206    /// 城市
207    pub city: Option<String>,
208    /// 州/省
209    pub state: Option<String>,
210    /// 邮政编码
211    pub postal_code: Option<String>,
212    /// 国家
213    pub country: Option<String>,
214}
215
216impl Address {
217    /// 返回完整地址字符串,各部分用逗号分隔
218    pub fn full_address(&self) -> String {
219        [
220            self.line1.as_deref(),
221            self.line2.as_deref(),
222            self.city.as_deref(),
223            self.state.as_deref(),
224            self.postal_code.as_deref(),
225            self.country.as_deref(),
226        ]
227        .iter()
228        .filter_map(|&s| s)
229        .collect::<Vec<_>>()
230        .join(", ")
231    }
232}
233
234/// 交易方信息结构体(供应商或客户)
235#[derive(Debug, Clone, Serialize, Deserialize, Default)]
236pub struct Party {
237    /// 公司/个人名称
238    pub name: Option<String>,
239    /// 税务识别号(如统一编号、VAT号等)
240    pub tax_id: Option<String>,
241    /// 地址信息
242    pub address: Option<Address>,
243    /// 电子邮件
244    pub email: Option<String>,
245    /// 电话号码
246    pub phone: Option<String>,
247}
248
249/// 发票行项目结构体(单个商品/服务明细)
250#[derive(Debug, Clone, Serialize, Deserialize, Default)]
251pub struct LineItem {
252    /// 行号
253    pub line_number: Option<u32>,
254    /// 服务/项目名称(如:Amazon CloudFront)
255    pub service_name: Option<String>,
256    /// 使用类型/描述(如:US-Requests-Tier2-HTTPS)
257    pub description: String,
258    /// 数量
259    #[serde(with = "decimal_format_option")]
260    pub quantity: Option<f64>,
261    /// 单位(如:个、件、小时等)
262    pub unit: Option<String>,
263    /// 单价
264    #[serde(with = "decimal_format_option")]
265    pub unit_price: Option<f64>,
266    /// 折扣金额
267    #[serde(with = "decimal_format_option")]
268    pub discount: Option<f64>,
269    /// 税率(百分比)
270    #[serde(with = "decimal_format_option")]
271    pub tax_rate: Option<f64>,
272    /// 税额
273    #[serde(with = "decimal_format_option")]
274    pub tax_amount: Option<f64>,
275    /// 金额(数量 × 单价 - 折扣 + 税额)
276    #[serde(with = "decimal_format")]
277    pub amount: f64,
278}
279
280impl LineItem {
281    /// 计算行项目金额:数量 × 单价 - 折扣 + 税额
282    pub fn calculate_amount(&self) -> f64 {
283        let qty = self.quantity.unwrap_or(1.0);
284        let price = self.unit_price.unwrap_or(self.amount);
285        let discount = self.discount.unwrap_or(0.0);
286        let tax = self.tax_amount.unwrap_or(0.0);
287
288        (qty * price) - discount + tax
289    }
290
291    /// 验证单价×数量是否等于金额
292    /// 返回 (是否有效, 计算值, 差异值)
293    pub fn validate_amount(&self) -> LineItemValidation {
294        let (qty, price) = match (self.quantity, self.unit_price) {
295            (Some(q), Some(p)) => (q, p),
296            _ => {
297                return LineItemValidation {
298                    is_valid: true,
299                    can_validate: false,
300                    calculated_amount: None,
301                    difference: None,
302                    difference_percent: None,
303                }
304            }
305        };
306
307        let calculated = qty * price;
308        let diff = (self.amount - calculated).abs();
309        let diff_percent = if calculated.abs() > 0.0001 {
310            (diff / calculated.abs()) * 100.0
311        } else {
312            0.0
313        };
314
315        LineItemValidation {
316            is_valid: diff < 0.01 || diff_percent < 1.0,
317            can_validate: true,
318            calculated_amount: Some(calculated),
319            difference: Some(self.amount - calculated),
320            difference_percent: Some(diff_percent),
321        }
322    }
323}
324
325/// 行项目验证结果
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct LineItemValidation {
328    /// 是否有效(差异在允许范围内)
329    pub is_valid: bool,
330    /// 是否可以验证(有单价和数量)
331    pub can_validate: bool,
332    /// 计算金额(单价×数量)
333    pub calculated_amount: Option<f64>,
334    /// 差异值(实际金额 - 计算金额)
335    pub difference: Option<f64>,
336    /// 差异百分比
337    pub difference_percent: Option<f64>,
338}
339
340/// 税务汇总结构体
341#[derive(Debug, Clone, Serialize, Deserialize, Default)]
342pub struct TaxSummary {
343    /// 税种(如:增值税、营业税、GST等)
344    pub tax_type: Option<String>,
345    /// 税率(百分比)
346    pub tax_rate: Option<f64>,
347    /// 应税金额(税基)
348    pub taxable_amount: Option<f64>,
349    /// 税额
350    pub tax_amount: f64,
351}
352
353/// 付款信息结构体
354#[derive(Debug, Clone, Serialize, Deserialize, Default)]
355pub struct PaymentInfo {
356    /// 付款方式(如:银行转账、信用卡、支票等)
357    pub method: Option<String>,
358    /// 银行名称
359    pub bank_name: Option<String>,
360    /// 银行账号
361    pub account_number: Option<String>,
362    /// 银行路由号(美国银行系统)
363    pub routing_number: Option<String>,
364    /// 国际银行账号(IBAN)
365    pub iban: Option<String>,
366    /// SWIFT/BIC 代码(国际汇款)
367    pub swift_code: Option<String>,
368    /// 付款参考号/备注
369    pub reference: Option<String>,
370}
371
372/// 发票主结构体
373///
374/// 包含发票的所有核心信息,支持多种发票格式的解析结果存储
375#[derive(Debug, Clone, Serialize, Deserialize, Default)]
376pub struct Invoice {
377    pub document_format: DocumentFormat,
378    pub invoice_type: InvoiceType,
379    /// 发票号码
380    pub invoice_number: Option<String>,
381    /// 账户名称/项目别名(如 AWS Account No 后括号内的名称)
382    pub account_name: Option<String>,
383    /// 客户ID
384    pub customer_id: Option<String>,
385    /// 账单年月(格式:YYYY-MM,如 2025-01)
386    pub billing_period: Option<String>,
387    /// 发票日期
388    pub invoice_date: Option<NaiveDate>,
389    /// 付款截止日期
390    pub due_date: Option<NaiveDate>,
391    /// 货币类型
392    pub currency: Currency,
393    /// 供应商/卖方信息
394    pub vendor: Party,
395    /// 客户/买方信息
396    pub customer: Party,
397    /// 行项目列表(商品/服务明细)
398    pub line_items: Vec<LineItem>,
399    /// 小计金额(税前)
400    pub subtotal: Option<f64>,
401    /// 折扣金额(负数表示折扣)
402    pub discount_amount: Option<f64>,
403    /// 折扣比例(百分比,如 5.0 表示 5%)
404    pub discount_rate: Option<f64>,
405    /// 税务汇总列表
406    pub tax_summaries: Vec<TaxSummary>,
407    /// 总税额
408    pub total_tax: Option<f64>,
409    /// 总金额(含税)
410    pub total_amount: f64,
411    /// 已付金额
412    pub amount_paid: Option<f64>,
413    /// 应付金额(未付余额)
414    pub amount_due: Option<f64>,
415    /// 付款信息
416    pub payment_info: Option<PaymentInfo>,
417    /// 备注/附注
418    pub notes: Option<String>,
419    /// 原始文本(PDF提取的原文)
420    pub raw_text: Option<String>,
421    /// 元数据(扩展字段,如账单周期等)
422    pub metadata: std::collections::HashMap<String, String>,
423}
424
425impl Invoice {
426    /// 创建新的空发票实例
427    pub fn new() -> Self {
428        Self::default()
429    }
430
431    /// 序列化为格式化的 JSON 字符串
432    pub fn to_json(&self) -> Result<String, serde_json::Error> {
433        serde_json::to_string_pretty(self)
434    }
435
436    /// 序列化为紧凑的 JSON 字符串
437    pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
438        serde_json::to_string(self)
439    }
440
441    /// 从 JSON 字符串反序列化
442    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
443        serde_json::from_str(json)
444    }
445
446    /// 根据行项目计算小计金额
447    pub fn calculate_subtotal(&self) -> f64 {
448        self.line_items.iter().map(|item| item.amount).sum()
449    }
450
451    /// 根据税务汇总计算总税额
452    pub fn calculate_total_tax(&self) -> f64 {
453        self.tax_summaries.iter().map(|t| t.tax_amount).sum()
454    }
455
456    /// 判断发票是否已付清
457    pub fn is_paid(&self) -> bool {
458        match (self.amount_paid, self.amount_due) {
459            (Some(paid), _) if paid >= self.total_amount => true,
460            (_, Some(due)) if due <= 0.0 => true,
461            _ => false,
462        }
463    }
464
465    pub fn validate_line_items(&self) -> InvoiceValidation {
466        let mut validations = Vec::new();
467        let mut invalid_count = 0;
468        let mut validatable_count = 0;
469
470        for (idx, item) in self.line_items.iter().enumerate() {
471            let validation = item.validate_amount();
472            if validation.can_validate {
473                validatable_count += 1;
474                if !validation.is_valid {
475                    invalid_count += 1;
476                }
477            }
478            validations.push((idx, item.description.clone(), validation));
479        }
480
481        let line_items_sum = self.calculate_subtotal();
482        let subtotal_diff = self.subtotal.map(|s| {
483            let diff = s - line_items_sum;
484            let pct = if s.abs() > 0.0001 {
485                (diff.abs() / s.abs()) * 100.0
486            } else {
487                0.0
488            };
489            (diff, pct)
490        });
491
492        InvoiceValidation {
493            all_valid: invalid_count == 0 && Self::is_subtotal_valid(subtotal_diff),
494            subtotal_valid: Self::is_subtotal_valid(subtotal_diff),
495            total_items: self.line_items.len(),
496            validatable_items: validatable_count,
497            invalid_items: invalid_count,
498            line_items_sum,
499            subtotal: self.subtotal,
500            subtotal_difference: subtotal_diff.map(|(d, _)| d),
501            subtotal_difference_percent: subtotal_diff.map(|(_, p)| p),
502            item_validations: validations,
503        }
504    }
505
506    fn is_subtotal_valid(subtotal_diff: Option<(f64, f64)>) -> bool {
507        const MAX_ALLOWED_DIFFERENCE_PERCENT: f64 = 1.0;
508        match subtotal_diff {
509            Some((_, pct)) => pct <= MAX_ALLOWED_DIFFERENCE_PERCENT,
510            None => true,
511        }
512    }
513}
514
515#[derive(Debug, Clone)]
516pub struct InvoiceValidation {
517    pub all_valid: bool,
518    pub subtotal_valid: bool,
519    pub total_items: usize,
520    pub validatable_items: usize,
521    pub invalid_items: usize,
522    pub line_items_sum: f64,
523    pub subtotal: Option<f64>,
524    pub subtotal_difference: Option<f64>,
525    pub subtotal_difference_percent: Option<f64>,
526    pub item_validations: Vec<(usize, String, LineItemValidation)>,
527}
528
529impl InvoiceValidation {
530    pub fn print_report(&self) {
531        println!("=== Invoice Validation Report ===");
532        println!("Total items: {}", self.total_items);
533        println!("Validatable items: {}", self.validatable_items);
534        println!("Invalid items: {}", self.invalid_items);
535        println!("Line items sum: {:.2}", self.line_items_sum);
536
537        if let Some(subtotal) = self.subtotal {
538            println!("Invoice subtotal: {:.2}", subtotal);
539            if let (Some(diff), Some(pct)) =
540                (self.subtotal_difference, self.subtotal_difference_percent)
541            {
542                println!("Difference: {:.2} ({:.2}%)", diff, pct);
543                if !self.subtotal_valid {
544                    println!("ERROR: Line items sum does not match subtotal (>1% difference)");
545                }
546            }
547        }
548
549        let invalid_items: Vec<_> = self
550            .item_validations
551            .iter()
552            .filter(|(_, _, v)| v.can_validate && !v.is_valid)
553            .collect();
554
555        if !invalid_items.is_empty() {
556            println!("\nInvalid line items:");
557            for (idx, desc, v) in invalid_items {
558                println!(
559                    "  Line {}: {} | calculated: {:.4} | diff: {:.4}",
560                    idx + 1,
561                    desc,
562                    v.calculated_amount.unwrap_or(0.0),
563                    v.difference.unwrap_or(0.0)
564                );
565            }
566        }
567    }
568}
569
570/// 解析结果结构体
571///
572/// 封装解析操作的输出,支持单个或多个发票,并包含警告信息
573#[derive(Debug, Clone, Serialize, Deserialize)]
574pub struct ParseResult {
575    /// 解析出的发票列表
576    pub invoices: Vec<Invoice>,
577    /// 源文件路径
578    pub source_file: Option<String>,
579    /// 解析过程中的警告信息
580    pub parse_warnings: Vec<String>,
581}
582
583impl ParseResult {
584    /// 创建包含单个发票的解析结果
585    pub fn single(invoice: Invoice) -> Self {
586        Self {
587            invoices: vec![invoice],
588            source_file: None,
589            parse_warnings: Vec::new(),
590        }
591    }
592
593    /// 创建包含多个发票的解析结果
594    pub fn multiple(invoices: Vec<Invoice>) -> Self {
595        Self {
596            invoices,
597            source_file: None,
598            parse_warnings: Vec::new(),
599        }
600    }
601
602    /// 设置源文件路径(链式调用)
603    pub fn with_source(mut self, source: impl Into<String>) -> Self {
604        self.source_file = Some(source.into());
605        self
606    }
607
608    /// 添加警告信息(链式调用)
609    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
610        self.parse_warnings.push(warning.into());
611        self
612    }
613}