Skip to main content

invoice_parser/
models.rs

1use chrono::NaiveDate;
2use serde::{Deserialize, Serialize};
3
4mod decimal_format {
5    use serde::{self, Deserialize, Deserializer, Serializer};
6
7    pub fn serialize<S>(value: &f64, serializer: S) -> Result<S::Ok, S::Error>
8    where
9        S: Serializer,
10    {
11        let abs = value.abs();
12        if abs == 0.0 {
13            serializer.serialize_f64(0.0)
14        } else if abs < 0.0001 {
15            serializer.serialize_str(
16                format!("{:.10}", value)
17                    .trim_end_matches('0')
18                    .trim_end_matches('.'),
19            )
20        } else if abs < 1.0 {
21            serializer.serialize_str(
22                format!("{:.8}", value)
23                    .trim_end_matches('0')
24                    .trim_end_matches('.'),
25            )
26        } else {
27            serializer.serialize_f64(*value)
28        }
29    }
30
31    pub fn deserialize<'de, D>(deserializer: D) -> Result<f64, D::Error>
32    where
33        D: Deserializer<'de>,
34    {
35        #[derive(Deserialize)]
36        #[serde(untagged)]
37        enum StringOrFloat {
38            String(String),
39            Float(f64),
40        }
41
42        match StringOrFloat::deserialize(deserializer)? {
43            StringOrFloat::String(s) => s.parse().map_err(serde::de::Error::custom),
44            StringOrFloat::Float(f) => Ok(f),
45        }
46    }
47}
48
49mod decimal_format_option {
50    use serde::{self, Deserialize, Deserializer, Serializer};
51
52    pub fn serialize<S>(value: &Option<f64>, serializer: S) -> Result<S::Ok, S::Error>
53    where
54        S: Serializer,
55    {
56        match value {
57            Some(v) => {
58                let abs = v.abs();
59                if abs == 0.0 {
60                    serializer.serialize_f64(0.0)
61                } else if abs < 0.0001 {
62                    serializer.serialize_str(
63                        format!("{:.10}", v)
64                            .trim_end_matches('0')
65                            .trim_end_matches('.'),
66                    )
67                } else if abs < 1.0 {
68                    serializer.serialize_str(
69                        format!("{:.8}", v)
70                            .trim_end_matches('0')
71                            .trim_end_matches('.'),
72                    )
73                } else {
74                    serializer.serialize_f64(*v)
75                }
76            }
77            None => serializer.serialize_none(),
78        }
79    }
80
81    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
82    where
83        D: Deserializer<'de>,
84    {
85        #[derive(Deserialize)]
86        #[serde(untagged)]
87        enum StringOrFloat {
88            String(String),
89            Float(f64),
90        }
91
92        let opt: Option<StringOrFloat> = Option::deserialize(deserializer)?;
93        match opt {
94            Some(StringOrFloat::String(s)) => s.parse().map(Some).map_err(serde::de::Error::custom),
95            Some(StringOrFloat::Float(f)) => Ok(Some(f)),
96            None => Ok(None),
97        }
98    }
99}
100
/// Source layout of a parsed invoice document.
///
/// `Unknown` is the default. The `Display` labels are given per variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum DocumentFormat {
    /// Format could not be determined (default); displayed as "Unknown".
    #[default]
    Unknown,
    /// Displayed as "AWS Direct"; presumably an invoice issued by AWS itself — confirm.
    AwsDirect,
    /// Displayed as "eCloudValley AWS"; presumably an AWS invoice issued via the reseller eCloudValley — confirm.
    ECloudValleyAws,
    /// Displayed as "Google Cloud".
    GoogleCloud,
    /// Displayed as "Microsoft Azure".
    Azure,
}
110
111impl std::fmt::Display for DocumentFormat {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        match self {
114            DocumentFormat::Unknown => write!(f, "Unknown"),
115            DocumentFormat::AwsDirect => write!(f, "AWS Direct"),
116            DocumentFormat::ECloudValleyAws => write!(f, "eCloudValley AWS"),
117            DocumentFormat::GoogleCloud => write!(f, "Google Cloud"),
118            DocumentFormat::Azure => write!(f, "Microsoft Azure"),
119        }
120    }
121}
122
/// Currency of an invoice.
///
/// `USD` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum Currency {
    /// US dollar
    #[default]
    USD,
    /// Euro
    EUR,
    /// Pound sterling
    GBP,
    /// Japanese yen
    JPY,
    /// Chinese yuan (renminbi)
    CNY,
    /// Hong Kong dollar
    HKD,
    /// Singapore dollar
    SGD,
    /// Australian dollar
    AUD,
    /// Canadian dollar
    CAD,
    /// Swiss franc
    CHF,
    /// Any other currency; `From<&str>` stores the upper-cased input text here.
    Other(String),
}
150
151impl From<&str> for Currency {
152    fn from(s: &str) -> Self {
153        match s.to_uppercase().as_str() {
154            "USD" | "$" | "US$" => Currency::USD,
155            "EUR" | "€" => Currency::EUR,
156            "GBP" | "£" => Currency::GBP,
157            "JPY" | "¥" | "YEN" => Currency::JPY,
158            "CNY" | "RMB" | "元" => Currency::CNY,
159            "HKD" | "HK$" => Currency::HKD,
160            "SGD" | "S$" => Currency::SGD,
161            "AUD" | "A$" => Currency::AUD,
162            "CAD" | "C$" => Currency::CAD,
163            "CHF" => Currency::CHF,
164            other => Currency::Other(other.to_string()),
165        }
166    }
167}
168
/// Kind of invoice document.
///
/// `Standard` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum InvoiceType {
    /// Standard invoice
    #[default]
    Standard,
    /// Credit note (refund / allowance document)
    CreditNote,
    /// Debit note (supplementary charge document)
    DebitNote,
    /// Proforma invoice (quotation / advance invoice)
    ProformaInvoice,
    /// Commercial invoice (international trade)
    CommercialInvoice,
    /// Receipt
    Receipt,
    /// Bill
    Bill,
    /// Statement (e.g. an AWS bill)
    Statement,
    /// Unknown type
    Unknown,
}
192
/// Postal address; every component is optional.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Address {
    /// First address line (street address)
    pub line1: Option<String>,
    /// Second address line (unit number, floor, etc.)
    pub line2: Option<String>,
    /// City
    pub city: Option<String>,
    /// State / province
    pub state: Option<String>,
    /// Postal code
    pub postal_code: Option<String>,
    /// Country
    pub country: Option<String>,
}
209
210impl Address {
211    /// 返回完整地址字符串,各部分用逗号分隔
212    pub fn full_address(&self) -> String {
213        [
214            self.line1.as_deref(),
215            self.line2.as_deref(),
216            self.city.as_deref(),
217            self.state.as_deref(),
218            self.postal_code.as_deref(),
219            self.country.as_deref(),
220        ]
221        .iter()
222        .filter_map(|&s| s)
223        .collect::<Vec<_>>()
224        .join(", ")
225    }
226}
227
/// A party to the transaction (vendor or customer).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Party {
    /// Company or personal name
    pub name: Option<String>,
    /// Tax identification number (e.g. unified business number, VAT number)
    pub tax_id: Option<String>,
    /// Address
    pub address: Option<Address>,
    /// E-mail address
    pub email: Option<String>,
    /// Phone number
    pub phone: Option<String>,
}
242
243/// 发票行项目结构体(单个商品/服务明细)
244#[derive(Debug, Clone, Serialize, Deserialize, Default)]
245pub struct LineItem {
246    /// 行号
247    pub line_number: Option<u32>,
248    /// 服务/项目名称(如:Amazon CloudFront)
249    pub service_name: Option<String>,
250    /// 使用类型/描述(如:US-Requests-Tier2-HTTPS)
251    pub description: String,
252    /// 数量
253    #[serde(with = "decimal_format_option")]
254    pub quantity: Option<f64>,
255    /// 单位(如:个、件、小时等)
256    pub unit: Option<String>,
257    /// 单价
258    #[serde(with = "decimal_format_option")]
259    pub unit_price: Option<f64>,
260    /// 折扣金额
261    #[serde(with = "decimal_format_option")]
262    pub discount: Option<f64>,
263    /// 税率(百分比)
264    #[serde(with = "decimal_format_option")]
265    pub tax_rate: Option<f64>,
266    /// 税额
267    #[serde(with = "decimal_format_option")]
268    pub tax_amount: Option<f64>,
269    /// 金额(数量 × 单价 - 折扣 + 税额)
270    #[serde(with = "decimal_format")]
271    pub amount: f64,
272}
273
274impl LineItem {
275    /// 计算行项目金额:数量 × 单价 - 折扣 + 税额
276    pub fn calculate_amount(&self) -> f64 {
277        let qty = self.quantity.unwrap_or(1.0);
278        let price = self.unit_price.unwrap_or(self.amount);
279        let discount = self.discount.unwrap_or(0.0);
280        let tax = self.tax_amount.unwrap_or(0.0);
281
282        (qty * price) - discount + tax
283    }
284
285    /// 验证单价×数量是否等于金额
286    /// 返回 (是否有效, 计算值, 差异值)
287    pub fn validate_amount(&self) -> LineItemValidation {
288        let (qty, price) = match (self.quantity, self.unit_price) {
289            (Some(q), Some(p)) => (q, p),
290            _ => {
291                return LineItemValidation {
292                    is_valid: true,
293                    can_validate: false,
294                    calculated_amount: None,
295                    difference: None,
296                    difference_percent: None,
297                }
298            }
299        };
300
301        let calculated = qty * price;
302        let diff = (self.amount - calculated).abs();
303        let diff_percent = if calculated.abs() > 0.0001 {
304            (diff / calculated.abs()) * 100.0
305        } else {
306            0.0
307        };
308
309        LineItemValidation {
310            is_valid: diff < 0.01 || diff_percent < 1.0,
311            can_validate: true,
312            calculated_amount: Some(calculated),
313            difference: Some(self.amount - calculated),
314            difference_percent: Some(diff_percent),
315        }
316    }
317}
318
/// Result of checking a line item's recorded amount against its quantity and unit price.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineItemValidation {
    /// Whether the difference is within the allowed tolerance. Also `true`
    /// when the line could not be checked (see `can_validate`).
    pub is_valid: bool,
    /// Whether the line could be checked at all (it has both a unit price and a quantity).
    pub can_validate: bool,
    /// Calculated amount (unit price × quantity); `None` when not checkable.
    pub calculated_amount: Option<f64>,
    /// Difference (recorded amount − calculated amount); `None` when not checkable.
    pub difference: Option<f64>,
    /// Difference as a percentage of the calculated amount's magnitude;
    /// `None` when no percentage was computed (e.g. the line was not checkable).
    pub difference_percent: Option<f64>,
}
333
/// Tax summary entry (one per tax type/rate on the invoice).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaxSummary {
    /// Tax type (e.g. VAT, business tax, GST)
    pub tax_type: Option<String>,
    /// Tax rate (percent)
    pub tax_rate: Option<f64>,
    /// Taxable amount (tax base)
    pub taxable_amount: Option<f64>,
    /// Tax amount
    pub tax_amount: f64,
}
346
/// Payment details printed on the invoice.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PaymentInfo {
    /// Payment method (e.g. bank transfer, credit card, cheque)
    pub method: Option<String>,
    /// Bank name
    pub bank_name: Option<String>,
    /// Bank account number
    pub account_number: Option<String>,
    /// Bank routing number (US banking system)
    pub routing_number: Option<String>,
    /// International Bank Account Number (IBAN)
    pub iban: Option<String>,
    /// SWIFT/BIC code (international transfers)
    pub swift_code: Option<String>,
    /// Payment reference / remark
    pub reference: Option<String>,
}
365
/// Main invoice structure.
///
/// Holds all core invoice information and stores the parse result for the
/// various supported invoice formats.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Invoice {
    /// Layout/issuer the document was recognised as
    pub document_format: DocumentFormat,
    /// Kind of invoice document
    pub invoice_type: InvoiceType,
    /// Invoice number
    pub invoice_number: Option<String>,
    /// Account name / project alias (e.g. the name in parentheses after the AWS Account No)
    pub account_name: Option<String>,
    /// Invoice date
    pub invoice_date: Option<NaiveDate>,
    /// Payment due date
    pub due_date: Option<NaiveDate>,
    /// Currency
    pub currency: Currency,
    /// Vendor / seller information
    pub vendor: Party,
    /// Customer / buyer information
    pub customer: Party,
    /// Line items (product/service details)
    pub line_items: Vec<LineItem>,
    /// Subtotal (before tax)
    pub subtotal: Option<f64>,
    /// Discount amount (negative values denote a discount)
    pub discount_amount: Option<f64>,
    /// Discount rate in percent (e.g. 5.0 means 5%)
    pub discount_rate: Option<f64>,
    /// Tax summaries
    pub tax_summaries: Vec<TaxSummary>,
    /// Total tax
    pub total_tax: Option<f64>,
    /// Total amount (tax included)
    pub total_amount: f64,
    /// Amount already paid
    pub amount_paid: Option<f64>,
    /// Amount due (unpaid balance)
    pub amount_due: Option<f64>,
    /// Payment information
    pub payment_info: Option<PaymentInfo>,
    /// Notes / remarks
    pub notes: Option<String>,
    /// Raw text (text extracted from the PDF)
    pub raw_text: Option<String>,
    /// Metadata (extension fields, e.g. billing period)
    pub metadata: std::collections::HashMap<String, String>,
}
414
415impl Invoice {
416    /// 创建新的空发票实例
417    pub fn new() -> Self {
418        Self::default()
419    }
420
421    /// 序列化为格式化的 JSON 字符串
422    pub fn to_json(&self) -> Result<String, serde_json::Error> {
423        serde_json::to_string_pretty(self)
424    }
425
426    /// 序列化为紧凑的 JSON 字符串
427    pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
428        serde_json::to_string(self)
429    }
430
431    /// 从 JSON 字符串反序列化
432    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
433        serde_json::from_str(json)
434    }
435
436    /// 根据行项目计算小计金额
437    pub fn calculate_subtotal(&self) -> f64 {
438        self.line_items.iter().map(|item| item.amount).sum()
439    }
440
441    /// 根据税务汇总计算总税额
442    pub fn calculate_total_tax(&self) -> f64 {
443        self.tax_summaries.iter().map(|t| t.tax_amount).sum()
444    }
445
446    /// 判断发票是否已付清
447    pub fn is_paid(&self) -> bool {
448        match (self.amount_paid, self.amount_due) {
449            (Some(paid), _) if paid >= self.total_amount => true,
450            (_, Some(due)) if due <= 0.0 => true,
451            _ => false,
452        }
453    }
454
455    pub fn validate_line_items(&self) -> InvoiceValidation {
456        let mut validations = Vec::new();
457        let mut invalid_count = 0;
458        let mut validatable_count = 0;
459
460        for (idx, item) in self.line_items.iter().enumerate() {
461            let validation = item.validate_amount();
462            if validation.can_validate {
463                validatable_count += 1;
464                if !validation.is_valid {
465                    invalid_count += 1;
466                }
467            }
468            validations.push((idx, item.description.clone(), validation));
469        }
470
471        let line_items_sum = self.calculate_subtotal();
472        let subtotal_diff = self.subtotal.map(|s| {
473            let diff = s - line_items_sum;
474            let pct = if s.abs() > 0.0001 {
475                (diff.abs() / s.abs()) * 100.0
476            } else {
477                0.0
478            };
479            (diff, pct)
480        });
481
482        InvoiceValidation {
483            all_valid: invalid_count == 0,
484            total_items: self.line_items.len(),
485            validatable_items: validatable_count,
486            invalid_items: invalid_count,
487            line_items_sum,
488            subtotal: self.subtotal,
489            subtotal_difference: subtotal_diff.map(|(d, _)| d),
490            subtotal_difference_percent: subtotal_diff.map(|(_, p)| p),
491            item_validations: validations,
492        }
493    }
494}
495
/// Invoice-level validation summary, as produced by `Invoice::validate_line_items`.
///
/// Combines the per-line amount checks with a comparison of the line-item
/// sum against the invoice subtotal.
#[derive(Debug, Clone)]
pub struct InvoiceValidation {
    /// `true` when no checkable line item failed its amount check. The
    /// subtotal comparison does not affect this flag.
    pub all_valid: bool,
    /// Number of line items on the invoice.
    pub total_items: usize,
    /// Number of line items that could be checked (had quantity and unit price).
    pub validatable_items: usize,
    /// Number of checkable line items whose amount was outside tolerance.
    pub invalid_items: usize,
    /// Sum of all line item amounts.
    pub line_items_sum: f64,
    /// Subtotal as recorded on the invoice, if any.
    pub subtotal: Option<f64>,
    /// Recorded subtotal minus `line_items_sum`; `None` when there is no subtotal.
    pub subtotal_difference: Option<f64>,
    /// Absolute subtotal difference as a percentage of the recorded subtotal's
    /// magnitude (0 when the subtotal is within 0.0001 of zero); `None` when
    /// there is no subtotal.
    pub subtotal_difference_percent: Option<f64>,
    /// Per-line results as `(zero-based index, description, validation)`.
    pub item_validations: Vec<(usize, String, LineItemValidation)>,
}
508
509impl InvoiceValidation {
510    pub fn print_report(&self) {
511        println!("=== Invoice Validation Report ===");
512        println!("Total items: {}", self.total_items);
513        println!("Validatable items: {}", self.validatable_items);
514        println!("Invalid items: {}", self.invalid_items);
515        println!("Line items sum: {:.2}", self.line_items_sum);
516
517        if let Some(subtotal) = self.subtotal {
518            println!("Invoice subtotal: {:.2}", subtotal);
519            if let (Some(diff), Some(pct)) =
520                (self.subtotal_difference, self.subtotal_difference_percent)
521            {
522                println!("Difference: {:.2} ({:.2}%)", diff, pct);
523            }
524        }
525
526        let invalid_items: Vec<_> = self
527            .item_validations
528            .iter()
529            .filter(|(_, _, v)| v.can_validate && !v.is_valid)
530            .collect();
531
532        if !invalid_items.is_empty() {
533            println!("\nInvalid line items:");
534            for (idx, desc, v) in invalid_items {
535                println!(
536                    "  Line {}: {} | calculated: {:.4} | diff: {:.4}",
537                    idx + 1,
538                    desc,
539                    v.calculated_amount.unwrap_or(0.0),
540                    v.difference.unwrap_or(0.0)
541                );
542            }
543        }
544    }
545}
546
/// Parse result.
///
/// Wraps the output of a parse operation: one or more invoices, plus any
/// warnings raised while parsing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParseResult {
    /// Invoices extracted from the source
    pub invoices: Vec<Invoice>,
    /// Path of the source file, if recorded
    pub source_file: Option<String>,
    /// Warnings produced during parsing
    pub parse_warnings: Vec<String>,
}
559
560impl ParseResult {
561    /// 创建包含单个发票的解析结果
562    pub fn single(invoice: Invoice) -> Self {
563        Self {
564            invoices: vec![invoice],
565            source_file: None,
566            parse_warnings: Vec::new(),
567        }
568    }
569
570    /// 创建包含多个发票的解析结果
571    pub fn multiple(invoices: Vec<Invoice>) -> Self {
572        Self {
573            invoices,
574            source_file: None,
575            parse_warnings: Vec::new(),
576        }
577    }
578
579    /// 设置源文件路径(链式调用)
580    pub fn with_source(mut self, source: impl Into<String>) -> Self {
581        self.source_file = Some(source.into());
582        self
583    }
584
585    /// 添加警告信息(链式调用)
586    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
587        self.parse_warnings.push(warning.into());
588        self
589    }
590}