Skip to main content

invoice_parser/
models.rs

1use chrono::NaiveDate;
2use serde::{Deserialize, Serialize};
3
4mod decimal_format {
5    use serde::{self, Deserialize, Deserializer, Serializer};
6
7    pub fn serialize<S>(value: &f64, serializer: S) -> Result<S::Ok, S::Error>
8    where
9        S: Serializer,
10    {
11        let abs = value.abs();
12        if abs == 0.0 {
13            serializer.serialize_f64(0.0)
14        } else if abs < 0.0001 {
15            serializer.serialize_str(
16                format!("{:.10}", value)
17                    .trim_end_matches('0')
18                    .trim_end_matches('.'),
19            )
20        } else if abs < 1.0 {
21            serializer.serialize_str(
22                format!("{:.8}", value)
23                    .trim_end_matches('0')
24                    .trim_end_matches('.'),
25            )
26        } else {
27            serializer.serialize_f64(*value)
28        }
29    }
30
31    pub fn deserialize<'de, D>(deserializer: D) -> Result<f64, D::Error>
32    where
33        D: Deserializer<'de>,
34    {
35        #[derive(Deserialize)]
36        #[serde(untagged)]
37        enum StringOrFloat {
38            String(String),
39            Float(f64),
40        }
41
42        match StringOrFloat::deserialize(deserializer)? {
43            StringOrFloat::String(s) => s.parse().map_err(serde::de::Error::custom),
44            StringOrFloat::Float(f) => Ok(f),
45        }
46    }
47}
48
49mod decimal_format_option {
50    use serde::{self, Deserialize, Deserializer, Serializer};
51
52    pub fn serialize<S>(value: &Option<f64>, serializer: S) -> Result<S::Ok, S::Error>
53    where
54        S: Serializer,
55    {
56        match value {
57            Some(v) => {
58                let abs = v.abs();
59                if abs == 0.0 {
60                    serializer.serialize_f64(0.0)
61                } else if abs < 0.0001 {
62                    serializer.serialize_str(
63                        format!("{:.10}", v)
64                            .trim_end_matches('0')
65                            .trim_end_matches('.'),
66                    )
67                } else if abs < 1.0 {
68                    serializer.serialize_str(
69                        format!("{:.8}", v)
70                            .trim_end_matches('0')
71                            .trim_end_matches('.'),
72                    )
73                } else {
74                    serializer.serialize_f64(*v)
75                }
76            }
77            None => serializer.serialize_none(),
78        }
79    }
80
81    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
82    where
83        D: Deserializer<'de>,
84    {
85        #[derive(Deserialize)]
86        #[serde(untagged)]
87        enum StringOrFloat {
88            String(String),
89            Float(f64),
90        }
91
92        let opt: Option<StringOrFloat> = Option::deserialize(deserializer)?;
93        match opt {
94            Some(StringOrFloat::String(s)) => s.parse().map(Some).map_err(serde::de::Error::custom),
95            Some(StringOrFloat::Float(f)) => Ok(Some(f)),
96            None => Ok(None),
97        }
98    }
99}
100
/// Source document format an invoice was parsed from.
///
/// NOTE(review): presumably each variant corresponds to one vendor/reseller invoice layout
/// handled by a dedicated parser — confirm against the parser modules.
///
/// The human-readable labels noted on each variant are the ones produced by this type's
/// `Display` impl. Serde (de)serializes variants by their Rust names (e.g. `"AwsDirect"`),
/// not by those labels.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum DocumentFormat {
    /// Unknown format; this is the `Default` value. Displayed as "Unknown".
    #[default]
    Unknown,
    /// Displayed as "AWS Direct".
    AwsDirect,
    /// Displayed as "eCloudValley AWS".
    ECloudValleyAws,
    /// Displayed as "Microfusion Aliyun".
    MicrofusionAliyun,
    /// Displayed as "Alibaba Cloud Direct".
    AliyunDirect,
    /// Displayed as "UCloud".
    UCloud,
    /// Displayed as "Google Cloud".
    GoogleCloud,
    /// Displayed as "Microsoft Azure".
    Azure,
    /// Displayed as "Lokalise".
    Lokalise,
    /// Displayed as "Sentry".
    Sentry,
    /// Displayed as "Mux".
    Mux,
}
116
117impl std::fmt::Display for DocumentFormat {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        match self {
120            DocumentFormat::Unknown => write!(f, "Unknown"),
121            DocumentFormat::AwsDirect => write!(f, "AWS Direct"),
122            DocumentFormat::ECloudValleyAws => write!(f, "eCloudValley AWS"),
123            DocumentFormat::MicrofusionAliyun => write!(f, "Microfusion Aliyun"),
124            DocumentFormat::AliyunDirect => write!(f, "Alibaba Cloud Direct"),
125            DocumentFormat::UCloud => write!(f, "UCloud"),
126            DocumentFormat::GoogleCloud => write!(f, "Google Cloud"),
127            DocumentFormat::Azure => write!(f, "Microsoft Azure"),
128            DocumentFormat::Lokalise => write!(f, "Lokalise"),
129            DocumentFormat::Sentry => write!(f, "Sentry"),
130            DocumentFormat::Mux => write!(f, "Mux"),
131        }
132    }
133}
134
/// Currency of an invoice.
///
/// `USD` is the `Default`. Currencies without a dedicated variant are kept in `Other`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum Currency {
    /// US dollar
    #[default]
    USD,
    /// Euro
    EUR,
    /// British pound
    GBP,
    /// Japanese yen
    JPY,
    /// Chinese yuan (renminbi)
    CNY,
    /// Hong Kong dollar
    HKD,
    /// Singapore dollar
    SGD,
    /// Australian dollar
    AUD,
    /// Canadian dollar
    CAD,
    /// Swiss franc
    CHF,
    /// Any other currency. `From<&str>` stores the upper-cased input text here.
    Other(String),
}
162
163impl From<&str> for Currency {
164    fn from(s: &str) -> Self {
165        match s.to_uppercase().as_str() {
166            "USD" | "$" | "US$" => Currency::USD,
167            "EUR" | "€" => Currency::EUR,
168            "GBP" | "£" => Currency::GBP,
169            "JPY" | "¥" | "YEN" => Currency::JPY,
170            "CNY" | "RMB" | "元" => Currency::CNY,
171            "HKD" | "HK$" => Currency::HKD,
172            "SGD" | "S$" => Currency::SGD,
173            "AUD" | "A$" => Currency::AUD,
174            "CAD" | "C$" => Currency::CAD,
175            "CHF" => Currency::CHF,
176            other => Currency::Other(other.to_string()),
177        }
178    }
179}
180
/// Kind of invoice document.
///
/// NOTE: the `Default` is `Standard`, not `Unknown`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum InvoiceType {
    /// Standard invoice (the default).
    #[default]
    Standard,
    /// Credit note (refund / allowance document).
    CreditNote,
    /// Debit note (supplementary charge document).
    DebitNote,
    /// Proforma invoice (quotation / advance invoice).
    ProformaInvoice,
    /// Commercial invoice (international trade).
    CommercialInvoice,
    /// Receipt.
    Receipt,
    /// Bill.
    Bill,
    /// Statement (e.g. an AWS bill).
    Statement,
    /// Unknown type.
    Unknown,
}
204
/// Postal address; every part is optional.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Address {
    /// First address line (street address).
    pub line1: Option<String>,
    /// Second address line (unit number, floor, etc.).
    pub line2: Option<String>,
    /// City.
    pub city: Option<String>,
    /// State / province.
    pub state: Option<String>,
    /// Postal code.
    pub postal_code: Option<String>,
    /// Country.
    pub country: Option<String>,
}
221
222impl Address {
223    /// 返回完整地址字符串,各部分用逗号分隔
224    pub fn full_address(&self) -> String {
225        [
226            self.line1.as_deref(),
227            self.line2.as_deref(),
228            self.city.as_deref(),
229            self.state.as_deref(),
230            self.postal_code.as_deref(),
231            self.country.as_deref(),
232        ]
233        .iter()
234        .filter_map(|&s| s)
235        .collect::<Vec<_>>()
236        .join(", ")
237    }
238}
239
/// A party to the transaction (vendor or customer).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Party {
    /// Company or individual name.
    pub name: Option<String>,
    /// Tax identification number (e.g. unified business number, VAT number).
    pub tax_id: Option<String>,
    /// Address.
    pub address: Option<Address>,
    /// Email address.
    pub email: Option<String>,
    /// Phone number.
    pub phone: Option<String>,
}
254
255/// 发票行项目结构体(单个商品/服务明细)
256#[derive(Debug, Clone, Serialize, Deserialize, Default)]
257pub struct LineItem {
258    /// 行号
259    pub line_number: Option<u32>,
260    /// 服务/项目名称(如:Amazon CloudFront)
261    pub service_name: Option<String>,
262    /// 使用类型/描述(如:US-Requests-Tier2-HTTPS)
263    pub description: String,
264    /// 数量
265    #[serde(with = "decimal_format_option")]
266    pub quantity: Option<f64>,
267    /// 单位(如:个、件、小时等)
268    pub unit: Option<String>,
269    /// 单价
270    #[serde(with = "decimal_format_option")]
271    pub unit_price: Option<f64>,
272    /// 折扣金额
273    #[serde(with = "decimal_format_option")]
274    pub discount: Option<f64>,
275    /// 税率(百分比)
276    #[serde(with = "decimal_format_option")]
277    pub tax_rate: Option<f64>,
278    /// 税额
279    #[serde(with = "decimal_format_option")]
280    pub tax_amount: Option<f64>,
281    /// 金额(数量 × 单价 - 折扣 + 税额)
282    #[serde(with = "decimal_format")]
283    pub amount: f64,
284}
285
286impl LineItem {
287    /// 计算行项目金额:数量 × 单价 - 折扣 + 税额
288    pub fn calculate_amount(&self) -> f64 {
289        let qty = self.quantity.unwrap_or(1.0);
290        let price = self.unit_price.unwrap_or(self.amount);
291        let discount = self.discount.unwrap_or(0.0);
292        let tax = self.tax_amount.unwrap_or(0.0);
293
294        (qty * price) - discount + tax
295    }
296
297    /// 验证单价×数量是否等于金额
298    /// 返回 (是否有效, 计算值, 差异值)
299    pub fn validate_amount(&self) -> LineItemValidation {
300        let (qty, price) = match (self.quantity, self.unit_price) {
301            (Some(q), Some(p)) => (q, p),
302            _ => {
303                return LineItemValidation {
304                    is_valid: true,
305                    can_validate: false,
306                    calculated_amount: None,
307                    difference: None,
308                    difference_percent: None,
309                }
310            }
311        };
312
313        let calculated = qty * price;
314        let diff = (self.amount - calculated).abs();
315        let diff_percent = if calculated.abs() > 0.0001 {
316            (diff / calculated.abs()) * 100.0
317        } else {
318            0.0
319        };
320
321        LineItemValidation {
322            is_valid: diff < 0.01 || diff_percent < 1.0,
323            can_validate: true,
324            calculated_amount: Some(calculated),
325            difference: Some(self.amount - calculated),
326            difference_percent: Some(diff_percent),
327        }
328    }
329}
330
/// Result of validating one line item (built by `LineItem::validate_amount`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineItemValidation {
    /// Whether the difference is within the allowed tolerance (`true` when the line could
    /// not be checked).
    pub is_valid: bool,
    /// Whether the line could be checked (both unit price and quantity present).
    pub can_validate: bool,
    /// Computed amount (unit price × quantity); `None` when the line could not be checked.
    pub calculated_amount: Option<f64>,
    /// Difference (recorded amount − computed amount).
    pub difference: Option<f64>,
    /// Absolute difference as a percentage of |computed amount| (reported as 0 when the
    /// computed amount is ~0).
    pub difference_percent: Option<f64>,
}
345
/// One entry of the invoice's tax summary.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaxSummary {
    /// Tax type (e.g. VAT, business tax, GST).
    pub tax_type: Option<String>,
    /// Tax rate (percentage).
    pub tax_rate: Option<f64>,
    /// Taxable amount (tax base).
    pub taxable_amount: Option<f64>,
    /// Tax amount.
    pub tax_amount: f64,
}
358
/// Payment information.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PaymentInfo {
    /// Payment method (e.g. bank transfer, credit card, cheque).
    pub method: Option<String>,
    /// Bank name.
    pub bank_name: Option<String>,
    /// Bank account number.
    pub account_number: Option<String>,
    /// Bank routing number (US banking system).
    pub routing_number: Option<String>,
    /// International Bank Account Number (IBAN).
    pub iban: Option<String>,
    /// SWIFT/BIC code (international transfers).
    pub swift_code: Option<String>,
    /// Payment reference / remark.
    pub reference: Option<String>,
}
377
/// Main invoice structure.
///
/// Holds the core information of an invoice and stores the parse result for any of the
/// supported invoice formats.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Invoice {
    /// Detected source document format.
    pub document_format: DocumentFormat,
    /// Kind of document (standard invoice, credit note, statement, ...).
    pub invoice_type: InvoiceType,
    /// Invoice number.
    pub invoice_number: Option<String>,
    /// Account name / project alias (e.g. the name in parentheses after "AWS Account No").
    pub account_name: Option<String>,
    /// Customer ID.
    pub customer_id: Option<String>,
    /// Billing month, formatted `YYYY-MM` (e.g. `2025-01`).
    pub billing_period: Option<String>,
    /// Invoice date.
    pub invoice_date: Option<NaiveDate>,
    /// Payment due date.
    pub due_date: Option<NaiveDate>,
    /// Currency.
    pub currency: Currency,
    /// Vendor / seller.
    pub vendor: Party,
    /// Customer / buyer.
    pub customer: Party,
    /// Line items (product / service details).
    pub line_items: Vec<LineItem>,
    /// Subtotal (before tax).
    pub subtotal: Option<f64>,
    /// Discount amount (a negative number indicates a discount).
    pub discount_amount: Option<f64>,
    /// Discount rate in percent (e.g. `5.0` means 5 %).
    pub discount_rate: Option<f64>,
    /// Tax summary entries.
    pub tax_summaries: Vec<TaxSummary>,
    /// Total tax.
    pub total_tax: Option<f64>,
    /// Total amount (tax included).
    pub total_amount: f64,
    /// Amount already paid.
    pub amount_paid: Option<f64>,
    /// Amount due (unpaid balance).
    pub amount_due: Option<f64>,
    /// Payment information.
    pub payment_info: Option<PaymentInfo>,
    /// Notes / remarks.
    pub notes: Option<String>,
    /// Raw text extracted from the PDF.
    pub raw_text: Option<String>,
    /// Extension metadata (free-form key/value pairs, e.g. billing-cycle details).
    /// Because this is a `HashMap`, key order in serialized output is not deterministic.
    pub metadata: std::collections::HashMap<String, String>,
}
430
431impl Invoice {
432    /// 创建新的空发票实例
433    pub fn new() -> Self {
434        Self::default()
435    }
436
437    /// 序列化为格式化的 JSON 字符串
438    pub fn to_json(&self) -> Result<String, serde_json::Error> {
439        serde_json::to_string_pretty(self)
440    }
441
442    /// 序列化为紧凑的 JSON 字符串
443    pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
444        serde_json::to_string(self)
445    }
446
447    /// 从 JSON 字符串反序列化
448    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
449        serde_json::from_str(json)
450    }
451
452    /// 根据行项目计算小计金额
453    pub fn calculate_subtotal(&self) -> f64 {
454        self.line_items.iter().map(|item| item.amount).sum()
455    }
456
457    /// 根据税务汇总计算总税额
458    pub fn calculate_total_tax(&self) -> f64 {
459        self.tax_summaries.iter().map(|t| t.tax_amount).sum()
460    }
461
462    /// 判断发票是否已付清
463    pub fn is_paid(&self) -> bool {
464        match (self.amount_paid, self.amount_due) {
465            (Some(paid), _) if paid >= self.total_amount => true,
466            (_, Some(due)) if due <= 0.0 => true,
467            _ => false,
468        }
469    }
470
471    pub fn validate_line_items(&self) -> InvoiceValidation {
472        let mut validations = Vec::new();
473        let mut invalid_count = 0;
474        let mut validatable_count = 0;
475
476        for (idx, item) in self.line_items.iter().enumerate() {
477            let validation = item.validate_amount();
478            if validation.can_validate {
479                validatable_count += 1;
480                if !validation.is_valid {
481                    invalid_count += 1;
482                }
483            }
484            validations.push((idx, item.description.clone(), validation));
485        }
486
487        let line_items_sum = self.calculate_subtotal();
488        let subtotal_diff = self.subtotal.map(|s| {
489            let diff = s - line_items_sum;
490            let pct = if s.abs() > 0.0001 {
491                (diff.abs() / s.abs()) * 100.0
492            } else {
493                0.0
494            };
495            (diff, pct)
496        });
497
498        InvoiceValidation {
499            all_valid: invalid_count == 0 && Self::is_subtotal_valid(subtotal_diff),
500            subtotal_valid: Self::is_subtotal_valid(subtotal_diff),
501            total_items: self.line_items.len(),
502            validatable_items: validatable_count,
503            invalid_items: invalid_count,
504            line_items_sum,
505            subtotal: self.subtotal,
506            subtotal_difference: subtotal_diff.map(|(d, _)| d),
507            subtotal_difference_percent: subtotal_diff.map(|(_, p)| p),
508            item_validations: validations,
509        }
510    }
511
512    fn is_subtotal_valid(subtotal_diff: Option<(f64, f64)>) -> bool {
513        const MAX_ALLOWED_DIFFERENCE_PERCENT: f64 = 1.0;
514        match subtotal_diff {
515            Some((_, pct)) => pct <= MAX_ALLOWED_DIFFERENCE_PERCENT,
516            None => true,
517        }
518    }
519}
520
/// Result of `Invoice::validate_line_items`: per-line checks plus a subtotal cross-check.
#[derive(Debug, Clone)]
pub struct InvoiceValidation {
    /// `true` when no checkable line item is invalid and the subtotal check passed.
    pub all_valid: bool,
    /// Whether the line-item sum agrees with `subtotal` (`true` when there is no subtotal).
    pub subtotal_valid: bool,
    /// Number of line items.
    pub total_items: usize,
    /// Number of line items that had both quantity and unit price.
    pub validatable_items: usize,
    /// Number of checkable line items that failed validation.
    pub invalid_items: usize,
    /// Sum of all line-item amounts.
    pub line_items_sum: f64,
    /// The invoice's stated subtotal, if any.
    pub subtotal: Option<f64>,
    /// `subtotal − line_items_sum`; `None` when there is no subtotal.
    pub subtotal_difference: Option<f64>,
    /// |difference| as a percentage of |subtotal|; `None` when there is no subtotal.
    pub subtotal_difference_percent: Option<f64>,
    /// One entry per line item: (0-based index, description, validation result).
    pub item_validations: Vec<(usize, String, LineItemValidation)>,
}
534
535impl InvoiceValidation {
536    pub fn print_report(&self) {
537        println!("=== Invoice Validation Report ===");
538        println!("Total items: {}", self.total_items);
539        println!("Validatable items: {}", self.validatable_items);
540        println!("Invalid items: {}", self.invalid_items);
541        println!("Line items sum: {:.2}", self.line_items_sum);
542
543        if let Some(subtotal) = self.subtotal {
544            println!("Invoice subtotal: {:.2}", subtotal);
545            if let (Some(diff), Some(pct)) =
546                (self.subtotal_difference, self.subtotal_difference_percent)
547            {
548                println!("Difference: {:.2} ({:.2}%)", diff, pct);
549                if !self.subtotal_valid {
550                    println!("ERROR: Line items sum does not match subtotal (>1% difference)");
551                }
552            }
553        }
554
555        let invalid_items: Vec<_> = self
556            .item_validations
557            .iter()
558            .filter(|(_, _, v)| v.can_validate && !v.is_valid)
559            .collect();
560
561        if !invalid_items.is_empty() {
562            println!("\nInvalid line items:");
563            for (idx, desc, v) in invalid_items {
564                println!(
565                    "  Line {}: {} | calculated: {:.4} | diff: {:.4}",
566                    idx + 1,
567                    desc,
568                    v.calculated_amount.unwrap_or(0.0),
569                    v.difference.unwrap_or(0.0)
570                );
571            }
572        }
573    }
574}
575
/// Output of a parse operation.
///
/// Wraps one or more parsed invoices together with the source file path and any warnings
/// raised while parsing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParseResult {
    /// Parsed invoices.
    pub invoices: Vec<Invoice>,
    /// Path of the source file, if set.
    pub source_file: Option<String>,
    /// Warnings produced while parsing.
    pub parse_warnings: Vec<String>,
}
588
589impl ParseResult {
590    /// 创建包含单个发票的解析结果
591    pub fn single(invoice: Invoice) -> Self {
592        Self {
593            invoices: vec![invoice],
594            source_file: None,
595            parse_warnings: Vec::new(),
596        }
597    }
598
599    /// 创建包含多个发票的解析结果
600    pub fn multiple(invoices: Vec<Invoice>) -> Self {
601        Self {
602            invoices,
603            source_file: None,
604            parse_warnings: Vec::new(),
605        }
606    }
607
608    /// 设置源文件路径(链式调用)
609    pub fn with_source(mut self, source: impl Into<String>) -> Self {
610        self.source_file = Some(source.into());
611        self
612    }
613
614    /// 添加警告信息(链式调用)
615    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
616        self.parse_warnings.push(warning.into());
617        self
618    }
619}