use briefcase_core::{BudgetAlert, BudgetStatus, CostCalculator, CostEstimate, ModelPricing};
use wasm_bindgen::prelude::*;
/// JavaScript-facing wrapper around [`CostCalculator`].
///
/// Every exported method delegates to the wrapped core calculator and converts
/// arguments and results across the wasm boundary.
#[wasm_bindgen]
pub struct WasmCostCalculator {
    inner: CostCalculator,
}

impl Default for WasmCostCalculator {
    /// Same as [`WasmCostCalculator::new`].
    fn default() -> Self {
        WasmCostCalculator::new()
    }
}
#[wasm_bindgen]
impl WasmCostCalculator {
#[wasm_bindgen(constructor)]
pub fn new() -> WasmCostCalculator {
WasmCostCalculator {
inner: CostCalculator::new(),
}
}
#[wasm_bindgen(js_name = estimateCost)]
pub fn estimate_cost(
&self,
model_name: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<WasmCostEstimate, JsValue> {
let estimate = self
.inner
.estimate_cost(model_name, input_tokens as usize, output_tokens as usize)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(WasmCostEstimate { inner: estimate })
}
#[wasm_bindgen(js_name = estimateCostFromText)]
pub fn estimate_cost_from_text(
&self,
model_name: &str,
input_text: &str,
estimated_output_tokens: u32,
) -> Result<WasmCostEstimate, JsValue> {
let estimate = self
.inner
.estimate_cost_from_text(model_name, input_text, estimated_output_tokens as usize)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(WasmCostEstimate { inner: estimate })
}
#[wasm_bindgen(js_name = checkBudget)]
pub fn check_budget(&self, spent: f64, budget: f64) -> WasmBudgetStatus {
let status = self.inner.check_budget(spent, budget);
WasmBudgetStatus { inner: status }
}
#[wasm_bindgen(js_name = getCheapestModel)]
pub fn get_cheapest_model(&self, min_context_window: u32) -> Option<WasmModelPricing> {
self.inner
.get_cheapest_model(min_context_window as usize)
.map(|pricing| WasmModelPricing {
inner: pricing.clone(),
})
}
#[wasm_bindgen(js_name = getModelsUnderCost)]
pub fn get_models_under_cost(&self, max_cost_per_1k: f64) -> Result<JsValue, JsValue> {
let models = self.inner.get_models_under_cost(max_cost_per_1k);
let model_objects: Vec<serde_json::Value> = models
.iter()
.map(|m| {
serde_json::json!({
"model_name": m.model_name,
"provider": m.provider,
"input_cost_per_1k_tokens": m.input_cost_per_1k_tokens,
"output_cost_per_1k_tokens": m.output_cost_per_1k_tokens,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
})
})
.collect();
serde_wasm_bindgen::to_value(&model_objects)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
#[wasm_bindgen(js_name = getModelsByProvider)]
pub fn get_models_by_provider(&self, provider: &str) -> Result<JsValue, JsValue> {
let models = self.inner.get_models_by_provider(provider);
let model_objects: Vec<serde_json::Value> = models
.iter()
.map(|m| {
serde_json::json!({
"model_name": m.model_name,
"provider": m.provider,
"input_cost_per_1k_tokens": m.input_cost_per_1k_tokens,
"output_cost_per_1k_tokens": m.output_cost_per_1k_tokens,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
})
})
.collect();
serde_wasm_bindgen::to_value(&model_objects)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
#[wasm_bindgen(js_name = compareModels)]
pub fn compare_models(
&self,
model_a: &str,
model_b: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<WasmModelComparison, JsValue> {
let comparison = self
.inner
.compare_models(
model_a,
model_b,
input_tokens as usize,
output_tokens as usize,
)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(WasmModelComparison { inner: comparison })
}
#[wasm_bindgen(js_name = addModel)]
pub fn add_model(&mut self, model_pricing: &WasmModelPricing) {
self.inner.add_model(model_pricing.inner.clone());
}
#[wasm_bindgen(js_name = removeModel)]
pub fn remove_model(&mut self, model_name: &str) -> Option<WasmModelPricing> {
self.inner
.remove_model(model_name)
.map(|pricing| WasmModelPricing { inner: pricing })
}
#[wasm_bindgen(js_name = getAllModels)]
pub fn get_all_models(&self) -> Result<JsValue, JsValue> {
let models = self.inner.get_all_models();
let model_objects: Vec<serde_json::Value> = models
.iter()
.map(|m| {
serde_json::json!({
"model_name": m.model_name,
"provider": m.provider,
"input_cost_per_1k_tokens": m.input_cost_per_1k_tokens,
"output_cost_per_1k_tokens": m.output_cost_per_1k_tokens,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
})
})
.collect();
serde_wasm_bindgen::to_value(&model_objects)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
#[wasm_bindgen(js_name = projectMonthlyCost)]
pub fn project_monthly_cost(
&self,
model_name: &str,
daily_input_tokens: u32,
daily_output_tokens: u32,
days_per_month: f64,
) -> Result<WasmCostProjection, JsValue> {
let projection = self
.inner
.project_monthly_cost(
model_name,
daily_input_tokens as usize,
daily_output_tokens as usize,
days_per_month,
)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(WasmCostProjection { inner: projection })
}
}
#[wasm_bindgen]
pub struct WasmCostEstimate {
inner: CostEstimate,
}
#[wasm_bindgen]
impl WasmCostEstimate {
#[wasm_bindgen(getter)]
pub fn model_name(&self) -> String {
self.inner.model_name.clone()
}
#[wasm_bindgen(getter)]
pub fn input_tokens(&self) -> u32 {
self.inner.input_tokens as u32
}
#[wasm_bindgen(getter)]
pub fn output_tokens(&self) -> u32 {
self.inner.output_tokens as u32
}
#[wasm_bindgen(getter)]
pub fn input_cost(&self) -> f64 {
self.inner.input_cost
}
#[wasm_bindgen(getter)]
pub fn output_cost(&self) -> f64 {
self.inner.output_cost
}
#[wasm_bindgen(getter)]
pub fn total_cost(&self) -> f64 {
self.inner.total_cost
}
#[wasm_bindgen(getter)]
pub fn currency(&self) -> String {
self.inner.currency.clone()
}
#[wasm_bindgen(js_name = toObject)]
pub fn to_object(&self) -> Result<JsValue, JsValue> {
serde_wasm_bindgen::to_value(&self.inner)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
}
#[wasm_bindgen]
pub struct WasmBudgetStatus {
inner: BudgetStatus,
}
#[wasm_bindgen]
impl WasmBudgetStatus {
#[wasm_bindgen(getter)]
pub fn budget_usd(&self) -> f64 {
self.inner.budget_usd
}
#[wasm_bindgen(getter)]
pub fn spent_usd(&self) -> f64 {
self.inner.spent_usd
}
#[wasm_bindgen(getter)]
pub fn remaining_usd(&self) -> f64 {
self.inner.remaining_usd
}
#[wasm_bindgen(getter)]
pub fn percent_used(&self) -> f64 {
self.inner.percent_used
}
#[wasm_bindgen(getter)]
pub fn status(&self) -> String {
match self.inner.status {
BudgetAlert::Ok => "ok".to_string(),
BudgetAlert::Warning => "warning".to_string(),
BudgetAlert::Critical => "critical".to_string(),
BudgetAlert::Exceeded => "exceeded".to_string(),
}
}
#[wasm_bindgen(js_name = toObject)]
pub fn to_object(&self) -> Result<JsValue, JsValue> {
let obj = serde_json::json!({
"budget_usd": self.inner.budget_usd,
"spent_usd": self.inner.spent_usd,
"remaining_usd": self.inner.remaining_usd,
"percent_used": self.inner.percent_used,
"status": self.status(),
});
serde_wasm_bindgen::to_value(&obj)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
}
#[wasm_bindgen]
pub struct WasmModelPricing {
inner: ModelPricing,
}
#[wasm_bindgen]
impl WasmModelPricing {
#[wasm_bindgen(constructor)]
pub fn new(
model_name: &str,
provider: &str,
input_cost: f64,
output_cost: f64,
context_window: u32,
max_output_tokens: Option<u32>,
) -> WasmModelPricing {
WasmModelPricing {
inner: ModelPricing {
model_name: model_name.to_string(),
provider: provider.to_string(),
input_cost_per_1k_tokens: input_cost,
output_cost_per_1k_tokens: output_cost,
context_window: context_window as usize,
max_output_tokens: max_output_tokens.map(|t| t as usize),
},
}
}
#[wasm_bindgen(getter)]
pub fn model_name(&self) -> String {
self.inner.model_name.clone()
}
#[wasm_bindgen(getter)]
pub fn provider(&self) -> String {
self.inner.provider.clone()
}
#[wasm_bindgen(getter)]
pub fn input_cost_per_1k_tokens(&self) -> f64 {
self.inner.input_cost_per_1k_tokens
}
#[wasm_bindgen(getter)]
pub fn output_cost_per_1k_tokens(&self) -> f64 {
self.inner.output_cost_per_1k_tokens
}
#[wasm_bindgen(getter)]
pub fn context_window(&self) -> u32 {
self.inner.context_window as u32
}
#[wasm_bindgen(getter)]
pub fn max_output_tokens(&self) -> Option<u32> {
self.inner.max_output_tokens.map(|t| t as u32)
}
#[wasm_bindgen(js_name = toObject)]
pub fn to_object(&self) -> Result<JsValue, JsValue> {
serde_wasm_bindgen::to_value(&self.inner)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
}
#[wasm_bindgen]
pub struct WasmModelComparison {
inner: briefcase_core::ModelComparison,
}
#[wasm_bindgen]
impl WasmModelComparison {
#[wasm_bindgen(getter)]
pub fn model_a(&self) -> WasmCostEstimate {
WasmCostEstimate {
inner: self.inner.model_a.clone(),
}
}
#[wasm_bindgen(getter)]
pub fn model_b(&self) -> WasmCostEstimate {
WasmCostEstimate {
inner: self.inner.model_b.clone(),
}
}
#[wasm_bindgen(getter)]
pub fn cheaper_model(&self) -> String {
self.inner.cheaper_model.clone()
}
#[wasm_bindgen(getter)]
pub fn savings(&self) -> f64 {
self.inner.savings
}
#[wasm_bindgen(getter)]
pub fn percent_difference(&self) -> f64 {
self.inner.percent_difference
}
#[wasm_bindgen(js_name = toObject)]
pub fn to_object(&self) -> Result<JsValue, JsValue> {
let obj = serde_json::json!({
"model_a": self.inner.model_a,
"model_b": self.inner.model_b,
"cheaper_model": self.inner.cheaper_model,
"savings": self.inner.savings,
"percent_difference": self.inner.percent_difference,
});
serde_wasm_bindgen::to_value(&obj)
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
}
/// Cost projection returned by `projectMonthlyCost`.
#[wasm_bindgen]
pub struct WasmCostProjection {
    inner: briefcase_core::CostProjection,
}

#[wasm_bindgen]
impl WasmCostProjection {
    /// Model the projection was computed for.
    #[wasm_bindgen(getter)]
    pub fn model_name(&self) -> String {
        self.inner.model_name.to_owned()
    }

    /// Projected cost per day.
    #[wasm_bindgen(getter)]
    pub fn daily_cost(&self) -> f64 {
        let value = self.inner.daily_cost;
        value
    }

    /// Projected cost per month.
    #[wasm_bindgen(getter)]
    pub fn monthly_cost(&self) -> f64 {
        let value = self.inner.monthly_cost;
        value
    }

    /// Projected cost per year.
    #[wasm_bindgen(getter)]
    pub fn annual_cost(&self) -> f64 {
        let value = self.inner.annual_cost;
        value
    }

    /// Currency code of the cost fields.
    #[wasm_bindgen(getter)]
    pub fn currency(&self) -> String {
        self.inner.currency.to_owned()
    }

    /// Serializes the whole projection into a plain JS object.
    ///
    /// # Errors
    /// Throws `"Serialization error: …"` if conversion to a JS value fails.
    #[wasm_bindgen(js_name = toObject)]
    pub fn to_object(&self) -> Result<JsValue, JsValue> {
        let converted = serde_wasm_bindgen::to_value(&self.inner);
        converted.map_err(|err| {
            let message = format!("Serialization error: {}", err);
            JsValue::from_str(&message)
        })
    }
}
#[cfg(test)]
mod tests {
    //! Native tests for the core calculator behavior that the wasm bindings expose.
    use briefcase_core::{BudgetAlert, CostCalculator, ModelPricing};

    /// Fresh calculator with the core library's built-in model table.
    fn calculator() -> CostCalculator {
        CostCalculator::new()
    }

    #[test]
    fn test_cost_estimate_known_model() {
        let calc = calculator();
        let estimate = calc
            .estimate_cost("gpt-4", 1000, 500)
            .expect("gpt-4 should be a built-in model");
        assert_eq!(estimate.model_name, "gpt-4");
        assert_eq!(estimate.input_tokens, 1000);
        assert_eq!(estimate.output_tokens, 500);
        assert!(estimate.total_cost > 0.0, "non-empty request should cost something");
        assert_eq!(estimate.currency, "USD");
    }

    #[test]
    fn test_cost_estimate_unknown_model_fails() {
        let calc = calculator();
        assert!(calc.estimate_cost("nonexistent-model", 100, 100).is_err());
    }

    #[test]
    fn test_cost_estimate_zero_tokens() {
        let calc = calculator();
        assert!(calc.estimate_cost("gpt-4", 0, 0).is_err());
    }

    #[test]
    fn test_cost_estimate_total_equals_sum() {
        let calc = calculator();
        let estimate = calc.estimate_cost("gpt-4", 1000, 500).unwrap();
        let parts = estimate.input_cost + estimate.output_cost;
        assert!((estimate.total_cost - parts).abs() < 1e-10);
    }

    #[test]
    fn test_budget_ok() {
        let calc = calculator();
        let report = calc.check_budget(10.0, 100.0);
        assert_eq!(report.budget_usd, 100.0);
        assert_eq!(report.spent_usd, 10.0);
        assert_eq!(report.remaining_usd, 90.0);
        assert!((report.percent_used - 10.0).abs() < 0.1);
        assert!(matches!(report.status, BudgetAlert::Ok));
    }

    #[test]
    fn test_budget_warning() {
        let calc = calculator();
        assert!(matches!(
            calc.check_budget(80.0, 100.0).status,
            BudgetAlert::Warning
        ));
    }

    #[test]
    fn test_budget_critical() {
        let calc = calculator();
        assert!(matches!(
            calc.check_budget(95.0, 100.0).status,
            BudgetAlert::Critical
        ));
    }

    #[test]
    fn test_budget_exceeded() {
        let calc = calculator();
        assert!(matches!(
            calc.check_budget(110.0, 100.0).status,
            BudgetAlert::Exceeded
        ));
    }

    #[test]
    fn test_compare_models() {
        let calc = calculator();
        let outcome = calc
            .compare_models("gpt-4", "gpt-3.5-turbo", 1000, 500)
            .expect("both models should be built-in");
        assert!(!outcome.cheaper_model.is_empty());
        assert!(outcome.savings >= 0.0);
        assert!(outcome.percent_difference >= 0.0);
    }

    #[test]
    fn test_add_and_remove_custom_model() {
        let mut calc = calculator();
        calc.add_model(ModelPricing {
            model_name: "custom-model".to_string(),
            provider: "custom".to_string(),
            input_cost_per_1k_tokens: 0.001,
            output_cost_per_1k_tokens: 0.002,
            context_window: 8192,
            max_output_tokens: Some(4096),
        });

        let estimate = calc
            .estimate_cost("custom-model", 1000, 500)
            .expect("custom model should be registered");
        assert_eq!(estimate.model_name, "custom-model");

        assert!(calc.remove_model("custom-model").is_some());
        assert!(calc.estimate_cost("custom-model", 100, 100).is_err());
    }

    #[test]
    fn test_get_cheapest_model() {
        let calc = calculator();
        assert!(calc.get_cheapest_model(0).is_some());
    }

    #[test]
    fn test_get_models_by_provider() {
        let calc = calculator();
        let openai_models = calc.get_models_by_provider("openai");
        assert!(!openai_models.is_empty());
        assert!(openai_models.iter().all(|model| model.provider == "openai"));
    }

    #[test]
    fn test_get_all_models_nonempty() {
        let calc = calculator();
        assert!(!calc.get_all_models().is_empty());
    }

    #[test]
    fn test_project_monthly_cost() {
        let calc = calculator();
        let projection = calc
            .project_monthly_cost("gpt-4", 4000, 2000, 30.0)
            .expect("gpt-4 should be a built-in model");
        assert_eq!(projection.model_name, "gpt-4");
        assert!(projection.daily_cost > 0.0);
        assert!(projection.monthly_cost > 0.0);
        assert!(projection.annual_cost > 0.0);
        let expected_monthly = projection.daily_cost * 30.0;
        assert!((projection.monthly_cost - expected_monthly).abs() < 0.01);
    }

    #[test]
    fn test_estimate_cost_from_text() {
        let calc = calculator();
        let estimate = calc
            .estimate_cost_from_text("gpt-4", "Hello world, this is a test", 100)
            .expect("gpt-4 should be a built-in model");
        assert!(estimate.input_tokens > 0);
        assert_eq!(estimate.output_tokens, 100);
        assert!(estimate.total_cost > 0.0);
    }
}