use modkit_macros::domain_model;
use crate::domain::model::quota::SettlementMethod;
use crate::infra::db::entity::chat_turn::TurnState;
/// Billing classification of a chat turn once it has reached a terminal state.
///
/// Produced by [`derive_billing_outcome`].
#[domain_model]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BillingOutcome {
    /// The turn finished in `TurnState::Completed`.
    Completed,
    /// The turn finished in `TurnState::Failed` with any error code other
    /// than `orphan_timeout` (including no error code at all).
    Failed,
    /// The turn was cancelled (`TurnState::Cancelled`) or failed with the
    /// `orphan_timeout` error code.
    Aborted,
}
impl BillingOutcome {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Completed => "completed",
Self::Failed => "failed",
Self::Aborted => "aborted",
}
}
}
/// Result of [`derive_billing_outcome`]: how a finished turn should be billed.
#[domain_model]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BillingDerivation {
    /// Billing classification of the turn.
    pub outcome: BillingOutcome,
    /// Settlement method to apply for the turn (see `quota::SettlementMethod`).
    pub settlement_method: SettlementMethod,
    /// `true` when a failed turn carried an error code that
    /// [`derive_billing_outcome`] does not recognise, or carried no error code
    /// at all; such turns fall back to `Failed` / `Estimated`.
    pub unknown_error_code: bool,
}
/// Facts about a finished turn that [`derive_billing_outcome`] needs in order to
/// classify it for billing.
#[domain_model]
#[derive(Debug, Clone)]
pub struct BillingDerivationInput {
    /// State the turn ended in. Must not be `TurnState::Running`;
    /// [`derive_billing_outcome`] panics on that state.
    pub terminal_state: TurnState,
    /// Error code of the turn. Only consulted when `terminal_state` is
    /// `TurnState::Failed`; ignored for every other state.
    pub error_code: Option<String>,
    /// Whether actual usage figures are available for the turn (presumably
    /// reported by the provider — confirm with the caller). Selects `Actual`
    /// over `Estimated` settlement for cancelled turns and for the
    /// provider-side failure codes.
    pub has_usage: bool,
}
/// Classifies a finished turn for billing and picks its settlement method.
///
/// Mapping (by `input.terminal_state`, then `input.error_code` for failures):
///
/// | terminal state / error code                                   | outcome     | settlement                            |
/// |---------------------------------------------------------------|-------------|---------------------------------------|
/// | `Completed`                                                   | `Completed` | `Actual`                              |
/// | `Cancelled`                                                   | `Aborted`   | `Actual` if usage, else `Estimated`   |
/// | `Failed` + `orphan_timeout`                                   | `Aborted`   | `Estimated`                           |
/// | `Failed` + `context_length_exceeded` / `validation_error`     | `Failed`    | `Released`                            |
/// | `Failed` + `provider_error` / `provider_timeout` /            | `Failed`    | `Actual` if usage, else `Estimated`   |
/// |   `rate_limited` / `web_search_calls_exceeded`                |             |                                       |
/// | `Failed` + any other code, or no code                         | `Failed`    | `Estimated`                           |
///
/// Only the last row sets [`BillingDerivation::unknown_error_code`].
///
/// # Panics
///
/// Panics if `input.terminal_state` is `TurnState::Running`: this must only be
/// called once the turn has reached a terminal state.
#[must_use]
pub fn derive_billing_outcome(input: &BillingDerivationInput) -> BillingDerivation {
    // Settlement used wherever the rule is "actual usage if we have it".
    let usage_or_estimate = if input.has_usage {
        SettlementMethod::Actual
    } else {
        SettlementMethod::Estimated
    };

    let (outcome, settlement_method, unknown_error_code) = match input.terminal_state {
        TurnState::Running => unreachable!("finalization called with Running state"),
        TurnState::Completed => (BillingOutcome::Completed, SettlementMethod::Actual, false),
        TurnState::Cancelled => (BillingOutcome::Aborted, usage_or_estimate, false),
        TurnState::Failed => match input.error_code.as_deref() {
            // Billed as an abort on the estimate, regardless of reported usage.
            Some("orphan_timeout") => {
                (BillingOutcome::Aborted, SettlementMethod::Estimated, false)
            }
            Some("context_length_exceeded" | "validation_error") => {
                (BillingOutcome::Failed, SettlementMethod::Released, false)
            }
            Some(
                "provider_error"
                | "provider_timeout"
                | "rate_limited"
                | "web_search_calls_exceeded",
            ) => (BillingOutcome::Failed, usage_or_estimate, false),
            // Unrecognised or missing code: bill on the estimate and flag it.
            _ => (BillingOutcome::Failed, SettlementMethod::Estimated, true),
        },
    };

    BillingDerivation {
        outcome,
        settlement_method,
        unknown_error_code,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a derivation input, converting the borrowed error code to the
    /// owned form the struct stores.
    fn input(
        state: TurnState,
        error_code: Option<&str>,
        has_usage: bool,
    ) -> BillingDerivationInput {
        BillingDerivationInput {
            terminal_state: state,
            error_code: error_code.map(String::from),
            has_usage,
        }
    }

    #[test]
    fn completed_derives_completed_actual() {
        let r = derive_billing_outcome(&input(TurnState::Completed, None, true));
        assert_eq!(r.outcome, BillingOutcome::Completed);
        assert_eq!(r.settlement_method, SettlementMethod::Actual);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn completed_without_usage_still_derives_completed_actual() {
        let r = derive_billing_outcome(&input(TurnState::Completed, None, false));
        assert_eq!(r.outcome, BillingOutcome::Completed);
        assert_eq!(r.settlement_method, SettlementMethod::Actual);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn cancelled_with_usage_derives_aborted_actual() {
        let r = derive_billing_outcome(&input(TurnState::Cancelled, None, true));
        assert_eq!(r.outcome, BillingOutcome::Aborted);
        assert_eq!(r.settlement_method, SettlementMethod::Actual);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn cancelled_without_usage_derives_aborted_estimated() {
        let r = derive_billing_outcome(&input(TurnState::Cancelled, None, false));
        assert_eq!(r.outcome, BillingOutcome::Aborted);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn orphan_timeout_derives_aborted_estimated() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("orphan_timeout"), true));
        assert_eq!(r.outcome, BillingOutcome::Aborted);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn orphan_timeout_without_usage_derives_aborted_estimated() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("orphan_timeout"), false));
        assert_eq!(r.outcome, BillingOutcome::Aborted);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn context_length_exceeded_derives_failed_released() {
        let r = derive_billing_outcome(&input(
            TurnState::Failed,
            Some("context_length_exceeded"),
            false,
        ));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Released);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn validation_error_derives_failed_released() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("validation_error"), false));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Released);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn provider_error_with_usage_derives_failed_actual() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("provider_error"), true));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Actual);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn provider_error_without_usage_derives_failed_estimated() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("provider_error"), false));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn provider_timeout_with_usage_derives_failed_actual() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("provider_timeout"), true));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Actual);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn rate_limited_with_usage_derives_failed_actual() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("rate_limited"), true));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Actual);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn web_search_calls_exceeded_without_usage_derives_failed_estimated() {
        let r = derive_billing_outcome(&input(
            TurnState::Failed,
            Some("web_search_calls_exceeded"),
            false,
        ));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(!r.unknown_error_code);
    }

    #[test]
    fn unknown_error_code_derives_failed_estimated_with_flag() {
        let r = derive_billing_outcome(&input(TurnState::Failed, Some("some_new_code"), true));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(r.unknown_error_code);
    }

    #[test]
    fn failed_with_no_error_code_derives_failed_estimated_with_flag() {
        let r = derive_billing_outcome(&input(TurnState::Failed, None, false));
        assert_eq!(r.outcome, BillingOutcome::Failed);
        assert_eq!(r.settlement_method, SettlementMethod::Estimated);
        assert!(r.unknown_error_code);
    }

    #[test]
    #[should_panic(expected = "finalization called with Running state")]
    // The call is expected to panic before returning. `drop` consumes the
    // `#[must_use]` result explicitly, and `BillingDerivation` is `Copy`, so
    // rustc's `dropping_copy_types` lint would otherwise warn here.
    #[allow(dropping_copy_types)]
    fn running_state_panics() {
        drop(derive_billing_outcome(&input(
            TurnState::Running,
            None,
            false,
        )));
    }
}