Skip to main content

invoice_parser/
models.rs

1use chrono::NaiveDate;
2use serde::{Deserialize, Serialize};
3
4mod decimal_format {
5    use serde::{self, Deserialize, Deserializer, Serializer};
6
7    pub fn serialize<S>(value: &f64, serializer: S) -> Result<S::Ok, S::Error>
8    where
9        S: Serializer,
10    {
11        let abs = value.abs();
12        if abs == 0.0 {
13            serializer.serialize_f64(0.0)
14        } else if abs < 0.0001 {
15            serializer.serialize_str(
16                format!("{:.10}", value)
17                    .trim_end_matches('0')
18                    .trim_end_matches('.'),
19            )
20        } else if abs < 1.0 {
21            serializer.serialize_str(
22                format!("{:.8}", value)
23                    .trim_end_matches('0')
24                    .trim_end_matches('.'),
25            )
26        } else {
27            serializer.serialize_f64(*value)
28        }
29    }
30
31    pub fn deserialize<'de, D>(deserializer: D) -> Result<f64, D::Error>
32    where
33        D: Deserializer<'de>,
34    {
35        #[derive(Deserialize)]
36        #[serde(untagged)]
37        enum StringOrFloat {
38            String(String),
39            Float(f64),
40        }
41
42        match StringOrFloat::deserialize(deserializer)? {
43            StringOrFloat::String(s) => s.parse().map_err(serde::de::Error::custom),
44            StringOrFloat::Float(f) => Ok(f),
45        }
46    }
47}
48
49mod decimal_format_option {
50    use serde::{self, Deserialize, Deserializer, Serializer};
51
52    pub fn serialize<S>(value: &Option<f64>, serializer: S) -> Result<S::Ok, S::Error>
53    where
54        S: Serializer,
55    {
56        match value {
57            Some(v) => {
58                let abs = v.abs();
59                if abs == 0.0 {
60                    serializer.serialize_f64(0.0)
61                } else if abs < 0.0001 {
62                    serializer.serialize_str(
63                        format!("{:.10}", v)
64                            .trim_end_matches('0')
65                            .trim_end_matches('.'),
66                    )
67                } else if abs < 1.0 {
68                    serializer.serialize_str(
69                        format!("{:.8}", v)
70                            .trim_end_matches('0')
71                            .trim_end_matches('.'),
72                    )
73                } else {
74                    serializer.serialize_f64(*v)
75                }
76            }
77            None => serializer.serialize_none(),
78        }
79    }
80
81    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
82    where
83        D: Deserializer<'de>,
84    {
85        #[derive(Deserialize)]
86        #[serde(untagged)]
87        enum StringOrFloat {
88            String(String),
89            Float(f64),
90        }
91
92        let opt: Option<StringOrFloat> = Option::deserialize(deserializer)?;
93        match opt {
94            Some(StringOrFloat::String(s)) => s.parse().map(Some).map_err(serde::de::Error::custom),
95            Some(StringOrFloat::Float(f)) => Ok(Some(f)),
96            None => Ok(None),
97        }
98    }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
102pub enum DocumentFormat {
103    #[default]
104    Unknown,
105    AwsDirect,
106    ECloudValleyAws,
107    MicrofusionAliyun,
108    AliyunDirect,
109    UCloud,
110    GoogleCloud,
111    Azure,
112    Lokalise,
113    Sentry,
114    Mux,
115    MlyticsConsolidated,
116    AzureCsp,
117    AliyunUsageDetail,
118    MicrofusionGcpUsage,
119}
120
121impl std::fmt::Display for DocumentFormat {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            DocumentFormat::Unknown => write!(f, "Unknown"),
125            DocumentFormat::AwsDirect => write!(f, "AWS Direct"),
126            DocumentFormat::ECloudValleyAws => write!(f, "eCloudValley AWS"),
127            DocumentFormat::MicrofusionAliyun => write!(f, "Microfusion Aliyun"),
128            DocumentFormat::AliyunDirect => write!(f, "Alibaba Cloud Direct"),
129            DocumentFormat::UCloud => write!(f, "UCloud"),
130            DocumentFormat::GoogleCloud => write!(f, "Google Cloud"),
131            DocumentFormat::Azure => write!(f, "Microsoft Azure"),
132            DocumentFormat::Lokalise => write!(f, "Lokalise"),
133            DocumentFormat::Sentry => write!(f, "Sentry"),
134            DocumentFormat::Mux => write!(f, "Mux"),
135            DocumentFormat::MlyticsConsolidated => write!(f, "Mlytics Consolidated"),
136            DocumentFormat::AzureCsp => write!(f, "Azure CSP"),
137            DocumentFormat::AliyunUsageDetail => write!(f, "Aliyun Usage Detail"),
138            DocumentFormat::MicrofusionGcpUsage => write!(f, "Microfusion GCP Usage"),
139        }
140    }
141}
142
143impl DocumentFormat {
144    pub fn vendor_name(&self) -> &str {
145        match self {
146            DocumentFormat::AwsDirect => "AWS",
147            DocumentFormat::ECloudValleyAws => "AWS",
148            DocumentFormat::MicrofusionAliyun => "阿里云",
149            DocumentFormat::AliyunDirect => "阿里云",
150            DocumentFormat::UCloud => "UCloud",
151            DocumentFormat::GoogleCloud => "Google Cloud",
152            DocumentFormat::Azure => "Microsoft Azure",
153            DocumentFormat::Lokalise => "Lokalise",
154            DocumentFormat::Sentry => "Sentry",
155            DocumentFormat::Mux => "Mux",
156            DocumentFormat::MlyticsConsolidated => "Mlytics",
157            DocumentFormat::AzureCsp => "Microsoft Azure",
158            DocumentFormat::AliyunUsageDetail => "阿里云",
159            DocumentFormat::MicrofusionGcpUsage => "Google Cloud",
160            DocumentFormat::Unknown => "",
161        }
162    }
163
164    pub fn format_type(&self) -> &str {
165        match self {
166            DocumentFormat::UCloud => "XLSX",
167            DocumentFormat::AliyunUsageDetail => "XLSX",
168            DocumentFormat::MicrofusionGcpUsage => "XLSX",
169            DocumentFormat::Unknown => "其他",
170            _ => "PDF",
171        }
172    }
173}
174
175/// 货币类型枚举
176#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
177pub enum Currency {
178    /// 美元
179    #[default]
180    USD,
181    /// 欧元
182    EUR,
183    /// 英镑
184    GBP,
185    /// 日元
186    JPY,
187    /// 人民币
188    CNY,
189    /// 港币
190    HKD,
191    /// 新加坡元
192    SGD,
193    /// 澳元
194    AUD,
195    /// 加元
196    CAD,
197    /// 瑞士法郎
198    CHF,
199    /// 其他货币
200    Other(String),
201}
202
203impl From<&str> for Currency {
204    fn from(s: &str) -> Self {
205        match s.to_uppercase().as_str() {
206            "USD" | "$" | "US$" => Currency::USD,
207            "EUR" | "€" => Currency::EUR,
208            "GBP" | "£" => Currency::GBP,
209            "JPY" | "¥" | "YEN" => Currency::JPY,
210            "CNY" | "RMB" | "元" => Currency::CNY,
211            "HKD" | "HK$" => Currency::HKD,
212            "SGD" | "S$" => Currency::SGD,
213            "AUD" | "A$" => Currency::AUD,
214            "CAD" | "C$" => Currency::CAD,
215            "CHF" => Currency::CHF,
216            other => Currency::Other(other.to_string()),
217        }
218    }
219}
220
221/// 发票类型枚举
222#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
223pub enum InvoiceType {
224    /// 标准发票
225    #[default]
226    Standard,
227    /// 贷项通知单(退款/折让凭证)
228    CreditNote,
229    /// 借项通知单(补收款凭证)
230    DebitNote,
231    /// 形式发票(报价/预开发票)
232    ProformaInvoice,
233    /// 商业发票(国际贸易)
234    CommercialInvoice,
235    /// 收据
236    Receipt,
237    /// 账单
238    Bill,
239    /// 对账单(如 AWS 账单)
240    Statement,
241    /// 未知类型
242    Unknown,
243}
244
245/// 地址信息结构体
246#[derive(Debug, Clone, Serialize, Deserialize, Default)]
247pub struct Address {
248    /// 地址第一行(街道地址)
249    pub line1: Option<String>,
250    /// 地址第二行(门牌号、楼层等)
251    pub line2: Option<String>,
252    /// 城市
253    pub city: Option<String>,
254    /// 州/省
255    pub state: Option<String>,
256    /// 邮政编码
257    pub postal_code: Option<String>,
258    /// 国家
259    pub country: Option<String>,
260}
261
262impl Address {
263    /// 返回完整地址字符串,各部分用逗号分隔
264    pub fn full_address(&self) -> String {
265        [
266            self.line1.as_deref(),
267            self.line2.as_deref(),
268            self.city.as_deref(),
269            self.state.as_deref(),
270            self.postal_code.as_deref(),
271            self.country.as_deref(),
272        ]
273        .iter()
274        .filter_map(|&s| s)
275        .collect::<Vec<_>>()
276        .join(", ")
277    }
278}
279
280/// 交易方信息结构体(供应商或客户)
281#[derive(Debug, Clone, Serialize, Deserialize, Default)]
282pub struct Party {
283    /// 公司/个人名称
284    pub name: Option<String>,
285    /// 税务识别号(如统一编号、VAT号等)
286    pub tax_id: Option<String>,
287    /// 地址信息
288    pub address: Option<Address>,
289    /// 电子邮件
290    pub email: Option<String>,
291    /// 电话号码
292    pub phone: Option<String>,
293}
294
295/// 发票行项目结构体(单个商品/服务明细)
296#[derive(Debug, Clone, Serialize, Deserialize, Default)]
297pub struct LineItem {
298    /// 行号
299    pub line_number: Option<u32>,
300    /// 服务/项目名称(如:Amazon CloudFront)
301    pub service_name: Option<String>,
302    /// 使用类型/描述(如:US-Requests-Tier2-HTTPS)
303    pub description: String,
304    /// 数量
305    #[serde(with = "decimal_format_option")]
306    pub quantity: Option<f64>,
307    /// 单位(如:个、件、小时等)
308    pub unit: Option<String>,
309    /// 单价
310    #[serde(with = "decimal_format_option")]
311    pub unit_price: Option<f64>,
312    /// 折扣金额
313    #[serde(with = "decimal_format_option")]
314    pub discount: Option<f64>,
315    /// 税率(百分比)
316    #[serde(with = "decimal_format_option")]
317    pub tax_rate: Option<f64>,
318    /// 税额
319    #[serde(with = "decimal_format_option")]
320    pub tax_amount: Option<f64>,
321    /// 金额(数量 × 单价 - 折扣 + 税额)
322    #[serde(with = "decimal_format")]
323    pub amount: f64,
324}
325
326impl LineItem {
327    /// 计算行项目金额:数量 × 单价 - 折扣 + 税额
328    pub fn calculate_amount(&self) -> f64 {
329        let qty = self.quantity.unwrap_or(1.0);
330        let price = self.unit_price.unwrap_or(self.amount);
331        let discount = self.discount.unwrap_or(0.0);
332        let tax = self.tax_amount.unwrap_or(0.0);
333
334        (qty * price) - discount + tax
335    }
336
337    /// 验证单价×数量是否等于金额
338    /// 返回 (是否有效, 计算值, 差异值)
339    pub fn validate_amount(&self) -> LineItemValidation {
340        let (qty, price) = match (self.quantity, self.unit_price) {
341            (Some(q), Some(p)) => (q, p),
342            _ => {
343                return LineItemValidation {
344                    is_valid: true,
345                    can_validate: false,
346                    calculated_amount: None,
347                    difference: None,
348                    difference_percent: None,
349                }
350            }
351        };
352
353        let calculated = qty * price;
354        let diff = (self.amount - calculated).abs();
355        let diff_percent = if calculated.abs() > 0.0001 {
356            (diff / calculated.abs()) * 100.0
357        } else {
358            0.0
359        };
360
361        LineItemValidation {
362            is_valid: diff < 0.01 || diff_percent < 1.0,
363            can_validate: true,
364            calculated_amount: Some(calculated),
365            difference: Some(self.amount - calculated),
366            difference_percent: Some(diff_percent),
367        }
368    }
369}
370
371/// 行项目验证结果
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct LineItemValidation {
374    /// 是否有效(差异在允许范围内)
375    pub is_valid: bool,
376    /// 是否可以验证(有单价和数量)
377    pub can_validate: bool,
378    /// 计算金额(单价×数量)
379    pub calculated_amount: Option<f64>,
380    /// 差异值(实际金额 - 计算金额)
381    pub difference: Option<f64>,
382    /// 差异百分比
383    pub difference_percent: Option<f64>,
384}
385
386/// 税务汇总结构体
387#[derive(Debug, Clone, Serialize, Deserialize, Default)]
388pub struct TaxSummary {
389    /// 税种(如:增值税、营业税、GST等)
390    pub tax_type: Option<String>,
391    /// 税率(百分比)
392    pub tax_rate: Option<f64>,
393    /// 应税金额(税基)
394    pub taxable_amount: Option<f64>,
395    /// 税额
396    pub tax_amount: f64,
397}
398
399/// 付款信息结构体
400#[derive(Debug, Clone, Serialize, Deserialize, Default)]
401pub struct PaymentInfo {
402    /// 付款方式(如:银行转账、信用卡、支票等)
403    pub method: Option<String>,
404    /// 银行名称
405    pub bank_name: Option<String>,
406    /// 银行账号
407    pub account_number: Option<String>,
408    /// 银行路由号(美国银行系统)
409    pub routing_number: Option<String>,
410    /// 国际银行账号(IBAN)
411    pub iban: Option<String>,
412    /// SWIFT/BIC 代码(国际汇款)
413    pub swift_code: Option<String>,
414    /// 付款参考号/备注
415    pub reference: Option<String>,
416}
417
418/// 发票主结构体
419///
420/// 包含发票的所有核心信息,支持多种发票格式的解析结果存储
421#[derive(Debug, Clone, Serialize, Deserialize, Default)]
422pub struct Invoice {
423    pub document_format: DocumentFormat,
424    pub invoice_type: InvoiceType,
425    /// 发票号码
426    pub invoice_number: Option<String>,
427    /// 账户名称/项目别名(如 AWS Account No 后括号内的名称)
428    pub account_name: Option<String>,
429    /// 客户ID
430    pub customer_id: Option<String>,
431    /// 账单年月(格式:YYYY-MM,如 2025-01)
432    pub billing_period: Option<String>,
433    /// 发票日期
434    pub invoice_date: Option<NaiveDate>,
435    /// 付款截止日期
436    pub due_date: Option<NaiveDate>,
437    /// 货币类型
438    pub currency: Currency,
439    /// 供应商/卖方信息
440    pub vendor: Party,
441    /// 客户/买方信息
442    pub customer: Party,
443    /// 行项目列表(商品/服务明细)
444    pub line_items: Vec<LineItem>,
445    /// 小计金额(税前)
446    pub subtotal: Option<f64>,
447    /// 折扣金额(负数表示折扣)
448    pub discount_amount: Option<f64>,
449    /// 折扣比例(百分比,如 5.0 表示 5%)
450    pub discount_rate: Option<f64>,
451    /// 税务汇总列表
452    pub tax_summaries: Vec<TaxSummary>,
453    /// 总税额
454    pub total_tax: Option<f64>,
455    /// 总金额(含税)
456    pub total_amount: f64,
457    /// 已付金额
458    pub amount_paid: Option<f64>,
459    /// 应付金额(未付余额)
460    pub amount_due: Option<f64>,
461    /// 付款信息
462    pub payment_info: Option<PaymentInfo>,
463    /// 备注/附注
464    pub notes: Option<String>,
465    /// 原始文本(PDF提取的原文)
466    pub raw_text: Option<String>,
467    /// 元数据(扩展字段,如账单周期等)
468    pub metadata: std::collections::HashMap<String, String>,
469}
470
471impl Invoice {
472    /// 创建新的空发票实例
473    pub fn new() -> Self {
474        Self::default()
475    }
476
477    /// 序列化为格式化的 JSON 字符串
478    pub fn to_json(&self) -> Result<String, serde_json::Error> {
479        serde_json::to_string_pretty(self)
480    }
481
482    /// 序列化为紧凑的 JSON 字符串
483    pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
484        serde_json::to_string(self)
485    }
486
487    /// 从 JSON 字符串反序列化
488    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
489        serde_json::from_str(json)
490    }
491
492    /// 根据行项目计算小计金额
493    pub fn calculate_subtotal(&self) -> f64 {
494        self.line_items.iter().map(|item| item.amount).sum()
495    }
496
497    /// 根据税务汇总计算总税额
498    pub fn calculate_total_tax(&self) -> f64 {
499        self.tax_summaries.iter().map(|t| t.tax_amount).sum()
500    }
501
502    /// 判断发票是否已付清
503    pub fn is_paid(&self) -> bool {
504        match (self.amount_paid, self.amount_due) {
505            (Some(paid), _) if paid >= self.total_amount => true,
506            (_, Some(due)) if due <= 0.0 => true,
507            _ => false,
508        }
509    }
510
511    pub fn validate_line_items(&self) -> InvoiceValidation {
512        let mut validations = Vec::new();
513        let mut invalid_count = 0;
514        let mut validatable_count = 0;
515
516        for (idx, item) in self.line_items.iter().enumerate() {
517            let validation = item.validate_amount();
518            if validation.can_validate {
519                validatable_count += 1;
520                if !validation.is_valid {
521                    invalid_count += 1;
522                }
523            }
524            validations.push((idx, item.description.clone(), validation));
525        }
526
527        let line_items_sum = self.calculate_subtotal();
528        let subtotal_diff = self.subtotal.map(|s| {
529            let diff = s - line_items_sum;
530            let pct = if s.abs() > 0.0001 {
531                (diff.abs() / s.abs()) * 100.0
532            } else {
533                0.0
534            };
535            (diff, pct)
536        });
537
538        InvoiceValidation {
539            all_valid: invalid_count == 0 && Self::is_subtotal_valid(subtotal_diff),
540            subtotal_valid: Self::is_subtotal_valid(subtotal_diff),
541            total_items: self.line_items.len(),
542            validatable_items: validatable_count,
543            invalid_items: invalid_count,
544            line_items_sum,
545            subtotal: self.subtotal,
546            subtotal_difference: subtotal_diff.map(|(d, _)| d),
547            subtotal_difference_percent: subtotal_diff.map(|(_, p)| p),
548            item_validations: validations,
549        }
550    }
551
552    fn is_subtotal_valid(subtotal_diff: Option<(f64, f64)>) -> bool {
553        const MAX_ALLOWED_DIFFERENCE_PERCENT: f64 = 1.0;
554        match subtotal_diff {
555            Some((_, pct)) => pct <= MAX_ALLOWED_DIFFERENCE_PERCENT,
556            None => true,
557        }
558    }
559}
560
561#[derive(Debug, Clone)]
562pub struct InvoiceValidation {
563    pub all_valid: bool,
564    pub subtotal_valid: bool,
565    pub total_items: usize,
566    pub validatable_items: usize,
567    pub invalid_items: usize,
568    pub line_items_sum: f64,
569    pub subtotal: Option<f64>,
570    pub subtotal_difference: Option<f64>,
571    pub subtotal_difference_percent: Option<f64>,
572    pub item_validations: Vec<(usize, String, LineItemValidation)>,
573}
574
575impl InvoiceValidation {
576    pub fn print_report(&self) {
577        println!("=== Invoice Validation Report ===");
578        println!("Total items: {}", self.total_items);
579        println!("Validatable items: {}", self.validatable_items);
580        println!("Invalid items: {}", self.invalid_items);
581        println!("Line items sum: {:.2}", self.line_items_sum);
582
583        if let Some(subtotal) = self.subtotal {
584            println!("Invoice subtotal: {:.2}", subtotal);
585            if let (Some(diff), Some(pct)) =
586                (self.subtotal_difference, self.subtotal_difference_percent)
587            {
588                println!("Difference: {:.2} ({:.2}%)", diff, pct);
589                if !self.subtotal_valid {
590                    println!("ERROR: Line items sum does not match subtotal (>1% difference)");
591                }
592            }
593        }
594
595        let invalid_items: Vec<_> = self
596            .item_validations
597            .iter()
598            .filter(|(_, _, v)| v.can_validate && !v.is_valid)
599            .collect();
600
601        if !invalid_items.is_empty() {
602            println!("\nInvalid line items:");
603            for (idx, desc, v) in invalid_items {
604                println!(
605                    "  Line {}: {} | calculated: {:.4} | diff: {:.4}",
606                    idx + 1,
607                    desc,
608                    v.calculated_amount.unwrap_or(0.0),
609                    v.difference.unwrap_or(0.0)
610                );
611            }
612        }
613    }
614}
615
616/// 解析结果结构体
617///
618/// 封装解析操作的输出,支持单个或多个发票,并包含警告信息
619#[derive(Debug, Clone, Serialize, Deserialize)]
620pub struct ParseResult {
621    /// 解析出的发票列表
622    pub invoices: Vec<Invoice>,
623    /// 源文件路径
624    pub source_file: Option<String>,
625    /// 解析过程中的警告信息
626    pub parse_warnings: Vec<String>,
627}
628
629impl ParseResult {
630    /// 创建包含单个发票的解析结果
631    pub fn single(invoice: Invoice) -> Self {
632        Self {
633            invoices: vec![invoice],
634            source_file: None,
635            parse_warnings: Vec::new(),
636        }
637    }
638
639    /// 创建包含多个发票的解析结果
640    pub fn multiple(invoices: Vec<Invoice>) -> Self {
641        Self {
642            invoices,
643            source_file: None,
644            parse_warnings: Vec::new(),
645        }
646    }
647
648    /// 设置源文件路径(链式调用)
649    pub fn with_source(mut self, source: impl Into<String>) -> Self {
650        self.source_file = Some(source.into());
651        self
652    }
653
654    /// 添加警告信息(链式调用)
655    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
656        self.parse_warnings.push(warning.into());
657        self
658    }
659}